GheeButter committed on
Commit
89c5a3f
·
1 Parent(s): 8c3c2ec

{commit_message}

Browse files
Files changed (2) hide show
  1. .official_space.py +0 -3
  2. app.py +16 -9
.official_space.py DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:51292da76658340750a198cba125f12668ff88f79248ffc920e168b645698b43
3
- size 24588
 
 
 
 
app.py CHANGED
@@ -59,9 +59,6 @@ pipe.to("cuda", torch.bfloat16)
59
 
60
  print("Model loaded successfully!")
61
 
62
- # Initialize LLM for prompt enhancement
63
- llm_client = InferenceClient()
64
-
65
  # Vision-Language model for prompt enhancement
66
  VL_MODEL = "Qwen/Qwen3-VL-30B-A3B-Instruct"
67
 
@@ -93,9 +90,16 @@ def image_to_base64(image) -> str:
93
  image.save(buffered, format="JPEG", quality=85)
94
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
95
 
96
- def enhance_prompt(prompt: str, reference_image=None) -> str:
97
  """Enhance the prompt using a VL model, optionally with a reference image."""
 
 
 
 
98
  try:
 
 
 
99
  # Build user content based on whether image is provided
100
  if reference_image is not None:
101
  # Convert image to base64 for the API
@@ -118,7 +122,7 @@ def enhance_prompt(prompt: str, reference_image=None) -> str:
118
  {"role": "user", "content": user_content}
119
  ]
120
 
121
- response = llm_client.chat_completion(
122
  messages=messages,
123
  model=VL_MODEL,
124
  max_tokens=250,
@@ -130,7 +134,7 @@ def enhance_prompt(prompt: str, reference_image=None) -> str:
130
  # Remove any thinking tags if present
131
  if "<think>" in enhanced:
132
  enhanced = re.sub(r'<think>.*?</think>', '', enhanced, flags=re.DOTALL).strip()
133
- print(f"[Prompt Enhancement] Model: {model}")
134
  print(f"[Prompt Enhancement] Original: {prompt}")
135
  print(f"[Prompt Enhancement] Enhanced: {enhanced}")
136
  return enhanced
@@ -151,11 +155,13 @@ def infer(
151
  num_inference_steps,
152
  use_prompt_enhancement,
153
  reference_image,
 
154
  progress=gr.Progress(track_tqdm=True),
155
  ):
156
  # Enhance prompt if requested
157
  if use_prompt_enhancement:
158
- prompt = enhance_prompt(prompt, reference_image)
 
159
 
160
  if randomize_seed:
161
  seed = random.randint(0, MAX_SEED)
@@ -286,8 +292,9 @@ footer { display: none !important; }
286
  with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
287
  with gr.Column(elem_id="col-container"):
288
  gr.Markdown("# Z-Image", elem_classes="title")
289
-
290
- gr.Markdown("", elem_classes="spacer")
 
291
 
292
  # Prompt
293
  prompt = gr.Textbox(
 
59
 
60
  print("Model loaded successfully!")
61
 
 
 
 
62
  # Vision-Language model for prompt enhancement
63
  VL_MODEL = "Qwen/Qwen3-VL-30B-A3B-Instruct"
64
 
 
90
  image.save(buffered, format="JPEG", quality=85)
91
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
92
 
93
+ def enhance_prompt(prompt: str, reference_image=None, oauth_token: str = None) -> str:
94
  """Enhance the prompt using a VL model, optionally with a reference image."""
95
+ if not oauth_token:
96
+ print("[Prompt Enhancement] No auth token provided")
97
+ return prompt
98
+
99
  try:
100
+ # Create client with user's token
101
+ client = InferenceClient(token=oauth_token)
102
+
103
  # Build user content based on whether image is provided
104
  if reference_image is not None:
105
  # Convert image to base64 for the API
 
122
  {"role": "user", "content": user_content}
123
  ]
124
 
125
+ response = client.chat_completion(
126
  messages=messages,
127
  model=VL_MODEL,
128
  max_tokens=250,
 
134
  # Remove any thinking tags if present
135
  if "<think>" in enhanced:
136
  enhanced = re.sub(r'<think>.*?</think>', '', enhanced, flags=re.DOTALL).strip()
137
+ print(f"[Prompt Enhancement] Model: {VL_MODEL}")
138
  print(f"[Prompt Enhancement] Original: {prompt}")
139
  print(f"[Prompt Enhancement] Enhanced: {enhanced}")
140
  return enhanced
 
155
  num_inference_steps,
156
  use_prompt_enhancement,
157
  reference_image,
158
+ oauth_token: gr.OAuthToken | None,
159
  progress=gr.Progress(track_tqdm=True),
160
  ):
161
  # Enhance prompt if requested
162
  if use_prompt_enhancement:
163
+ token = oauth_token.token if oauth_token else None
164
+ prompt = enhance_prompt(prompt, reference_image, token)
165
 
166
  if randomize_seed:
167
  seed = random.randint(0, MAX_SEED)
 
292
  with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
293
  with gr.Column(elem_id="col-container"):
294
  gr.Markdown("# Z-Image", elem_classes="title")
295
+
296
+ # Login button for HF authentication
297
+ login_btn = gr.LoginButton(value="Sign in with Hugging Face")
298
 
299
  # Prompt
300
  prompt = gr.Textbox(