linoyts HF Staff committed on
Commit
8106715
·
verified ·
1 Parent(s): ac6b97a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -32
app.py CHANGED
@@ -97,63 +97,79 @@ Please strictly follow the rewriting rules below:
97
  "Rewritten": "..."
98
  }
99
  '''
100
- # --- Prompt Enhancement using Hugging Face InferenceClient ---
101
  def polish_prompt_hf(prompt, img_list):
102
  """
103
  Rewrites the prompt using a Hugging Face InferenceClient.
 
104
  """
105
  # Ensure HF_TOKEN is set
106
  api_key = os.environ.get("HF_TOKEN")
107
  if not api_key:
108
  print("Warning: HF_TOKEN not set. Falling back to original prompt.")
109
  return prompt
110
-
 
111
  try:
112
  # Initialize the client
113
- prompt = f"{SYSTEM_PROMPT}\n\nUser Input: {prompt}\n\nRewritten Prompt:"
114
  client = InferenceClient(
115
  provider="nebius",
116
  api_key=api_key,
117
  )
118
 
 
 
 
 
 
 
 
 
119
  image_url = None
120
- if img is not None:
121
- # If img is a PIL Image
122
- if hasattr(img, 'save'): # Check if it's a PIL Image
123
- buffered = BytesIO()
124
- img.save(buffered, format="PNG")
125
- img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
126
- image_url = f"data:image/png;base64,{img_base64}"
127
- # If img is already a file path (string)
128
- elif isinstance(img, str):
129
- with open(img, "rb") as image_file:
130
- img_base64 = base64.b64encode(image_file.read()).decode('utf-8')
131
- image_url = f"data:image/png;base64,{img_base64}"
132
- else:
133
- print(f"Warning: Unexpected image type: {type(img)}")
134
- return original_prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  # Format the messages for the chat completions API
137
  messages = [
138
  {"role": "system", "content": system_prompt},
139
  {
140
  "role": "user",
141
- "content": [
142
- {
143
- "type": "text",
144
- "text": original_prompt
145
- },
146
- {
147
- "type": "image_url",
148
- "image_url": {
149
- "url": image_url
150
- }
151
- }
152
- ]
153
  }
154
  ]
155
 
156
-
157
  # Call the API
158
  completion = client.chat.completions.create(
159
  model="Qwen/Qwen2.5-VL-72B-Instruct",
@@ -181,7 +197,9 @@ def polish_prompt_hf(prompt, img_list):
181
  except Exception as e:
182
  print(f"Error during API call to Hugging Face: {e}")
183
  # Fallback to original prompt if enhancement fails
184
- return original_prompt
 
 
185
 
186
  def encode_image(pil_image):
187
  import io
 
97
  "Rewritten": "..."
98
  }
99
  '''
100
+
101
  def polish_prompt_hf(prompt, img_list):
102
  """
103
  Rewrites the prompt using a Hugging Face InferenceClient.
104
+ Supports multiple images via img_list.
105
  """
106
  # Ensure HF_TOKEN is set
107
  api_key = os.environ.get("HF_TOKEN")
108
  if not api_key:
109
  print("Warning: HF_TOKEN not set. Falling back to original prompt.")
110
  return prompt
111
+ prompt = f"{SYSTEM_PROMPT}\n\nUser Input: {prompt}\n\nRewritten Prompt:"
112
+ system_prompt = "you are a helpful assistant, you should provide useful answers to users."
113
  try:
114
  # Initialize the client
 
115
  client = InferenceClient(
116
  provider="nebius",
117
  api_key=api_key,
118
  )
119
 
120
+ # Convert list of images to base64 data URLs
121
+ image_urls = []
122
+ if img_list is not None:
123
+ # Ensure img_list is actually a list
124
+ if not isinstance(img_list, list):
125
+ img_list = [img_list]
126
+
127
+ for img in img_list:
128
  image_url = None
129
+ # If img is a PIL Image
130
+ if hasattr(img, 'save'): # Check if it's a PIL Image
131
+ buffered = BytesIO()
132
+ img.save(buffered, format="PNG")
133
+ img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
134
+ image_url = f"data:image/png;base64,{img_base64}"
135
+ # If img is already a file path (string)
136
+ elif isinstance(img, str):
137
+ with open(img, "rb") as image_file:
138
+ img_base64 = base64.b64encode(image_file.read()).decode('utf-8')
139
+ image_url = f"data:image/png;base64,{img_base64}"
140
+ else:
141
+ print(f"Warning: Unexpected image type: {type(img)}, skipping...")
142
+ continue
143
+
144
+ if image_url:
145
+ image_urls.append(image_url)
146
+
147
+ # Build the content array with text first, then all images
148
+ content = [
149
+ {
150
+ "type": "text",
151
+ "text": prompt
152
+ }
153
+ ]
154
+
155
+ # Add all images to the content
156
+ for image_url in image_urls:
157
+ content.append({
158
+ "type": "image_url",
159
+ "image_url": {
160
+ "url": image_url
161
+ }
162
+ })
163
 
164
  # Format the messages for the chat completions API
165
  messages = [
166
  {"role": "system", "content": system_prompt},
167
  {
168
  "role": "user",
169
+ "content": content
 
 
 
 
 
 
 
 
 
 
 
170
  }
171
  ]
172
 
 
173
  # Call the API
174
  completion = client.chat.completions.create(
175
  model="Qwen/Qwen2.5-VL-72B-Instruct",
 
197
  except Exception as e:
198
  print(f"Error during API call to Hugging Face: {e}")
199
  # Fallback to original prompt if enhancement fails
200
+ return prompt
201
+
202
+
203
 
204
  def encode_image(pil_image):
205
  import io