Spaces:
Sleeping
Sleeping
Commit ·
89c5a3f
1
Parent(s): 8c3c2ec
{commit_message}
Browse files
- .official_space.py +0 -3
- app.py +16 -9
.official_space.py
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:51292da76658340750a198cba125f12668ff88f79248ffc920e168b645698b43
|
| 3 |
-
size 24588
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
|
@@ -59,9 +59,6 @@ pipe.to("cuda", torch.bfloat16)
|
|
| 59 |
|
| 60 |
print("Model loaded successfully!")
|
| 61 |
|
| 62 |
-
# Initialize LLM for prompt enhancement
|
| 63 |
-
llm_client = InferenceClient()
|
| 64 |
-
|
| 65 |
# Vision-Language model for prompt enhancement
|
| 66 |
VL_MODEL = "Qwen/Qwen3-VL-30B-A3B-Instruct"
|
| 67 |
|
|
@@ -93,9 +90,16 @@ def image_to_base64(image) -> str:
|
|
| 93 |
image.save(buffered, format="JPEG", quality=85)
|
| 94 |
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 95 |
|
| 96 |
-
def enhance_prompt(prompt: str, reference_image=None) -> str:
|
| 97 |
"""Enhance the prompt using a VL model, optionally with a reference image."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
try:
|
|
|
|
|
|
|
|
|
|
| 99 |
# Build user content based on whether image is provided
|
| 100 |
if reference_image is not None:
|
| 101 |
# Convert image to base64 for the API
|
|
@@ -118,7 +122,7 @@ def enhance_prompt(prompt: str, reference_image=None) -> str:
|
|
| 118 |
{"role": "user", "content": user_content}
|
| 119 |
]
|
| 120 |
|
| 121 |
-
response = llm_client.chat_completion(
|
| 122 |
messages=messages,
|
| 123 |
model=VL_MODEL,
|
| 124 |
max_tokens=250,
|
|
@@ -130,7 +134,7 @@ def enhance_prompt(prompt: str, reference_image=None) -> str:
|
|
| 130 |
# Remove any thinking tags if present
|
| 131 |
if "<think>" in enhanced:
|
| 132 |
enhanced = re.sub(r'<think>.*?</think>', '', enhanced, flags=re.DOTALL).strip()
|
| 133 |
-
print(f"[Prompt Enhancement] Model: {
|
| 134 |
print(f"[Prompt Enhancement] Original: {prompt}")
|
| 135 |
print(f"[Prompt Enhancement] Enhanced: {enhanced}")
|
| 136 |
return enhanced
|
|
@@ -151,11 +155,13 @@ def infer(
|
|
| 151 |
num_inference_steps,
|
| 152 |
use_prompt_enhancement,
|
| 153 |
reference_image,
|
|
|
|
| 154 |
progress=gr.Progress(track_tqdm=True),
|
| 155 |
):
|
| 156 |
# Enhance prompt if requested
|
| 157 |
if use_prompt_enhancement:
|
| 158 |
-
|
|
|
|
| 159 |
|
| 160 |
if randomize_seed:
|
| 161 |
seed = random.randint(0, MAX_SEED)
|
|
@@ -286,8 +292,9 @@ footer { display: none !important; }
|
|
| 286 |
with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
|
| 287 |
with gr.Column(elem_id="col-container"):
|
| 288 |
gr.Markdown("# Z-Image", elem_classes="title")
|
| 289 |
-
|
| 290 |
-
|
|
|
|
| 291 |
|
| 292 |
# Prompt
|
| 293 |
prompt = gr.Textbox(
|
|
|
|
| 59 |
|
| 60 |
print("Model loaded successfully!")
|
| 61 |
|
|
|
|
|
|
|
|
|
|
| 62 |
# Vision-Language model for prompt enhancement
|
| 63 |
VL_MODEL = "Qwen/Qwen3-VL-30B-A3B-Instruct"
|
| 64 |
|
|
|
|
| 90 |
image.save(buffered, format="JPEG", quality=85)
|
| 91 |
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 92 |
|
| 93 |
+
def enhance_prompt(prompt: str, reference_image=None, oauth_token: str = None) -> str:
|
| 94 |
"""Enhance the prompt using a VL model, optionally with a reference image."""
|
| 95 |
+
if not oauth_token:
|
| 96 |
+
print("[Prompt Enhancement] No auth token provided")
|
| 97 |
+
return prompt
|
| 98 |
+
|
| 99 |
try:
|
| 100 |
+
# Create client with user's token
|
| 101 |
+
client = InferenceClient(token=oauth_token)
|
| 102 |
+
|
| 103 |
# Build user content based on whether image is provided
|
| 104 |
if reference_image is not None:
|
| 105 |
# Convert image to base64 for the API
|
|
|
|
| 122 |
{"role": "user", "content": user_content}
|
| 123 |
]
|
| 124 |
|
| 125 |
+
response = client.chat_completion(
|
| 126 |
messages=messages,
|
| 127 |
model=VL_MODEL,
|
| 128 |
max_tokens=250,
|
|
|
|
| 134 |
# Remove any thinking tags if present
|
| 135 |
if "<think>" in enhanced:
|
| 136 |
enhanced = re.sub(r'<think>.*?</think>', '', enhanced, flags=re.DOTALL).strip()
|
| 137 |
+
print(f"[Prompt Enhancement] Model: {VL_MODEL}")
|
| 138 |
print(f"[Prompt Enhancement] Original: {prompt}")
|
| 139 |
print(f"[Prompt Enhancement] Enhanced: {enhanced}")
|
| 140 |
return enhanced
|
|
|
|
| 155 |
num_inference_steps,
|
| 156 |
use_prompt_enhancement,
|
| 157 |
reference_image,
|
| 158 |
+
oauth_token: gr.OAuthToken | None,
|
| 159 |
progress=gr.Progress(track_tqdm=True),
|
| 160 |
):
|
| 161 |
# Enhance prompt if requested
|
| 162 |
if use_prompt_enhancement:
|
| 163 |
+
token = oauth_token.token if oauth_token else None
|
| 164 |
+
prompt = enhance_prompt(prompt, reference_image, token)
|
| 165 |
|
| 166 |
if randomize_seed:
|
| 167 |
seed = random.randint(0, MAX_SEED)
|
|
|
|
| 292 |
with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
|
| 293 |
with gr.Column(elem_id="col-container"):
|
| 294 |
gr.Markdown("# Z-Image", elem_classes="title")
|
| 295 |
+
|
| 296 |
+
# Login button for HF authentication
|
| 297 |
+
login_btn = gr.LoginButton(value="Sign in with Hugging Face")
|
| 298 |
|
| 299 |
# Prompt
|
| 300 |
prompt = gr.Textbox(
|