Hug0endob committed on
Commit
c48daf3
·
verified ·
1 Parent(s): 73fc44c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -89
app.py CHANGED
@@ -219,62 +219,87 @@ def upload_file_to_mistral(client, path, filename=None, purpose="batch"):
219
 
220
 
221
  def build_messages_for_image(prompt: str, b64_jpg: str = None, image_url: str = None):
222
- user_content = [{"type": "text", "text": prompt}]
223
  if image_url:
224
- user_content.append({"type": "image_url", "image_url": image_url})
225
  elif b64_jpg:
226
- user_content.append({"type": "image_base64", "image_base64": b64_jpg})
 
227
  else:
228
  raise ValueError("Either image_url or b64_jpg required")
229
- return [{"role": "system", "content": SYSTEM_INSTRUCTION}, {"role": "user", "content": user_content}]
230
 
231
 
232
  def build_messages_for_text(prompt: str, extra_text: str):
233
  return [{"role": "system", "content": SYSTEM_INSTRUCTION}, {"role": "user", "content": f"{prompt}\n\n{extra_text}"}]
234
 
235
 
236
- # New helper: normalize messages so content is always a plain string
237
- def normalize_messages(messages):
238
- out = []
239
- for m in messages:
240
- if not isinstance(m, dict):
241
- out.append(m)
242
- continue
243
- c = m.get("content")
244
- if isinstance(c, list):
245
- parts = []
246
- for item in c:
247
- if isinstance(item, str):
248
- parts.append(item)
249
- elif isinstance(item, dict):
250
- typ = item.get("type")
251
- if typ == "text" and item.get("text"):
252
- parts.append(item["text"])
253
- elif typ == "image_url" and item.get("image_url"):
254
- parts.append(item["image_url"])
255
- elif typ == "image_base64" and item.get("image_base64"):
256
- # convert to data URL to satisfy string requirement
257
- parts.append("data:image/jpeg;base64," + item["image_base64"])
258
- else:
259
- parts.append(item.get("text") or item.get("image_url") or item.get("image_base64") or "")
260
  else:
261
- parts.append(str(item))
262
- newc = "\n\n".join(p for p in parts if p).strip()
263
- nm = m.copy()
264
- nm["content"] = newc
265
- out.append(nm)
266
- elif not isinstance(c, str):
267
- nm = m.copy()
268
- nm["content"] = str(c or "")
269
- out.append(nm)
270
- else:
271
- out.append(m)
272
- return out
273
 
274
 
275
  def stream_and_collect(client, model, messages, parts: list):
276
  try:
277
- norm_msgs = normalize_messages(messages)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  stream_gen = None
279
  try:
280
  stream_gen = client.chat.stream(model=model, messages=norm_msgs)
@@ -289,34 +314,9 @@ def stream_and_collect(client, model, messages, parts: list):
289
  continue
290
  parts.append(d)
291
  return
 
292
  res = client.chat.complete(model=model, messages=norm_msgs, stream=False)
293
- try:
294
- choices = getattr(res, "choices", None) or res.get("choices", [])
295
- except Exception:
296
- choices = []
297
- if choices:
298
- try:
299
- msg = choices[0].message
300
- if isinstance(msg, dict):
301
- content = msg.get("content")
302
- else:
303
- content = getattr(msg, "content", None)
304
- if content:
305
- if isinstance(content, str):
306
- parts.append(content)
307
- else:
308
- if isinstance(content, list):
309
- for c in content:
310
- if isinstance(c, dict) and c.get("type") == "text":
311
- parts.append(c.get("text", ""))
312
- elif isinstance(content, dict):
313
- text = content.get("text") or content.get("content")
314
- if text:
315
- parts.append(text)
316
- except Exception:
317
- parts.append(str(res))
318
- else:
319
- parts.append(str(res))
320
  except Exception as e:
321
  parts.append(f"[Model error: {e}]")
322
 
@@ -328,6 +328,7 @@ def generate_final_text(src: str, custom_prompt: str, api_key: str):
328
  is_image = ext in IMAGE_EXTS or (not is_remote(src) and os.path.isfile(src) and ext in IMAGE_EXTS)
329
  parts = []
330
 
 
331
  if is_image:
332
  try:
333
  if is_remote(src):
@@ -342,6 +343,7 @@ def generate_final_text(src: str, custom_prompt: str, api_key: str):
342
  stream_and_collect(client, DEFAULT_IMAGE_MODEL, msgs, parts)
343
  return "".join(parts).strip()
344
 
 
345
  if is_remote(src):
346
  try:
347
  media_bytes = fetch_bytes(src, timeout=120)
@@ -350,12 +352,21 @@ def generate_final_text(src: str, custom_prompt: str, api_key: str):
350
  ext = ext_from_src(src) or ".mp4"
351
  tmp_media = save_bytes_to_temp(media_bytes, suffix=ext)
352
  try:
 
353
  try:
354
  file_id = upload_file_to_mistral(client, tmp_media, filename=os.path.basename(src.split("?")[0]))
355
- except Exception as e:
 
 
 
 
 
 
 
 
356
  frame_bytes = extract_best_frame_bytes(tmp_media)
357
  if not frame_bytes:
358
- return f"Error uploading to Mistral and no frame fallback available: {e}"
359
  try:
360
  jpg = convert_to_jpeg_bytes(frame_bytes, base_h=480)
361
  except UnidentifiedImageError:
@@ -364,11 +375,6 @@ def generate_final_text(src: str, custom_prompt: str, api_key: str):
364
  msgs = build_messages_for_image(prompt, b64_jpg=b64)
365
  stream_and_collect(client, DEFAULT_VIDEO_MODEL, msgs, parts)
366
  return "".join(parts).strip()
367
-
368
- extra = f"Remote video uploaded to Mistral Files with id: {file_id}\n\nInstruction: Analyze the video contents using the uploaded file id. Do not invent frames not present."
369
- msgs = build_messages_for_text(prompt, extra)
370
- stream_and_collect(client, DEFAULT_VIDEO_MODEL, msgs, parts)
371
- return "".join(parts).strip()
372
  finally:
373
  try:
374
  if tmp_media and os.path.exists(tmp_media):
@@ -376,6 +382,7 @@ def generate_final_text(src: str, custom_prompt: str, api_key: str):
376
  except Exception:
377
  pass
378
 
 
379
  tmp_media = None
380
  try:
381
  media_bytes = fetch_bytes(src)
@@ -383,12 +390,17 @@ def generate_final_text(src: str, custom_prompt: str, api_key: str):
383
  ext = ext or ".mp4"
384
  tmp_media = save_bytes_to_temp(media_bytes, suffix=ext)
385
  try:
 
386
  file_id = upload_file_to_mistral(client, tmp_media, filename=os.path.basename(src))
387
- extra = f"Local video uploaded to Mistral Files with id: {file_id}\n\nInstruction: Analyze the video contents using the uploaded file id. Do not invent frames not present."
 
 
 
388
  msgs = build_messages_for_text(prompt, extra)
389
  stream_and_collect(client, DEFAULT_VIDEO_MODEL, msgs, parts)
390
  return "".join(parts).strip()
391
  except Exception:
 
392
  frame_bytes = extract_best_frame_bytes(tmp_media)
393
  if not frame_bytes:
394
  return "Unable to process the provided file. Provide a direct image/frame URL or a remote video URL."
@@ -409,12 +421,14 @@ css = ".preview_media img, .preview_media video { max-width: 100%; height: auto;
409
 
410
 
411
  def load_preview(url: str):
 
412
  if not url:
413
  return None, None, ""
414
  try:
415
  r = requests.get(url, timeout=30, stream=True)
416
  r.raise_for_status()
417
  ctype = (r.headers.get("content-type") or "").lower()
 
418
  if (ctype and ctype.startswith("video/")) or any(url.lower().split("?")[0].endswith(ext) for ext in VIDEO_EXTS):
419
  return None, url, "Video"
420
  data = r.content
@@ -438,16 +452,4 @@ with gr.Blocks(title="Flux", css=css) as demo:
438
  with gr.Accordion("Mistral API Key (optional)", open=False):
439
  api_key = gr.Textbox(label="API Key", type="password", max_lines=1)
440
  submit = gr.Button("Submit")
441
- preview_image = gr.Image(label="Preview", type="pil", elem_classes="preview_media")
442
- preview_video = gr.Video(label="Preview", elem_classes="preview_media")
443
-
444
- with gr.Column(scale=2):
445
- final_text = gr.Markdown(value="")
446
-
447
- # Ensure preview outputs get None when not applicable so Gradio hides them
448
- url_input.change(fn=load_preview, inputs=[url_input], outputs=[preview_image, preview_video, gr.Textbox(visible=False)])
449
- submit.click(fn=generate_final_text, inputs=[url_input, custom_prompt, api_key], outputs=[final_text])
450
- demo.queue()
451
-
452
- if __name__ == "__main__":
453
- demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
 
219
 
220
 
def build_messages_for_image(prompt: str, b64_jpg: str = None, image_url: str = None):
    """Build the system + user chat messages for an image request.

    The user message content is a single plain string: the prompt followed by
    either the image URL or a ``data:image/jpeg;base64,...`` URL built from
    ``b64_jpg``. When both are supplied, ``image_url`` takes precedence.

    Args:
        prompt: Instruction text placed at the start of the user message.
        b64_jpg: Base64-encoded JPEG payload, used when no URL is given.
        image_url: Image URL to reference.

    Returns:
        A two-element list: the system message (``SYSTEM_INSTRUCTION``) and
        the user message.

    Raises:
        ValueError: If neither ``image_url`` nor ``b64_jpg`` is truthy.
    """
    if not (image_url or b64_jpg):
        raise ValueError("Either image_url or b64_jpg required")

    # NOTE(review): the image reference is embedded as text inside the message
    # string rather than sent as a structured image chunk -- confirm the
    # target model actually resolves it as an image.
    if image_url:
        user_text = f"{prompt}\n\nImage: {image_url}"
    else:
        user_text = f"{prompt}\n\nImage (base64): data:image/jpeg;base64,{b64_jpg}"

    system_message = {"role": "system", "content": SYSTEM_INSTRUCTION}
    user_message = {"role": "user", "content": user_text}
    return [system_message, user_message]
231
 
232
 
def build_messages_for_text(prompt: str, extra_text: str):
    """Build the system + user chat messages for a text-only request.

    Args:
        prompt: Instruction text placed at the start of the user message.
        extra_text: Additional text appended after a blank line.

    Returns:
        A two-element list: the system message (``SYSTEM_INSTRUCTION``) and
        a user message whose content is ``prompt``, a blank line, then
        ``extra_text``.
    """
    system_message = {"role": "system", "content": SYSTEM_INSTRUCTION}
    user_message = {"role": "user", "content": f"{prompt}\n\n{extra_text}"}
    return [system_message, user_message]
235
 
236
 
def extract_text_from_response(res, parts: list):
    """Append the assistant text found in a chat completion response to ``parts``.

    The response, its first choice, the message and the content chunks may each
    be either a plain ``dict`` or an object exposing the same fields as
    attributes. Only the first choice is inspected.

    Content handling:
        * ``str``  -> appended as-is.
        * ``list`` -> every chunk whose ``type`` is ``"text"`` contributes its
          non-empty ``text``; other chunks are ignored.
        * ``dict`` -> ``text`` (or, failing that, ``content``) is appended.
        * empty / ``None`` -> nothing is appended.

    If the response has no choices, or anything goes wrong while reading the
    first choice, ``str(res)`` is appended instead so the caller still gets
    something to show.

    Args:
        res: Chat completion response (dict-like or attribute-style object).
        parts: List that extracted text fragments are appended to, in place.

    Returns:
        None. Results are delivered through ``parts``.
    """
    try:
        choices = getattr(res, "choices", None) or res.get("choices", [])
    except Exception:
        # Objects without ``choices`` and without ``.get`` land here.
        choices = []

    if not choices:
        parts.append(str(res))
        return

    try:
        first = choices[0]
        # A dict-shaped response carries dict-shaped choices; attribute access
        # on those would always fail and discard a perfectly good answer.
        message = first["message"] if isinstance(first, dict) else first.message

        if isinstance(message, dict):
            content = message.get("content")
        else:
            content = getattr(message, "content", None)

        if not content:
            return

        if isinstance(content, str):
            parts.append(content)
        elif isinstance(content, list):
            for chunk in content:
                # Chunks may be dicts or objects with ``type``/``text``
                # attributes (presumably SDK models -- verify against the
                # client version in use).
                if isinstance(chunk, dict):
                    kind = chunk.get("type")
                    text = chunk.get("text")
                else:
                    kind = getattr(chunk, "type", None)
                    text = getattr(chunk, "text", None)
                # Skip empty/None text so a later "".join(parts) cannot fail.
                if kind == "text" and text:
                    parts.append(text)
        elif isinstance(content, dict):
            text = content.get("text") or content.get("content")
            if text:
                parts.append(text)
    except Exception:
        # Best-effort fallback: surface the raw response rather than nothing.
        parts.append(str(res))
265
 
266
 
267
  def stream_and_collect(client, model, messages, parts: list):
268
  try:
269
+ # We assume messages are plain strings already for image/text functions above.
270
+ # For safety, if content is a list/dict (old codepath), coerce minimally here:
271
+ norm_msgs = []
272
+ for m in messages:
273
+ if not isinstance(m, dict):
274
+ norm_msgs.append(m)
275
+ continue
276
+ c = m.get("content")
277
+ if isinstance(c, list):
278
+ # prefer image_url or image_base64 if present, else join text chunks
279
+ picked = []
280
+ for item in c:
281
+ if isinstance(item, dict):
282
+ if item.get("type") == "image_url" and item.get("image_url"):
283
+ picked.append(item["image_url"])
284
+ elif item.get("type") == "image_base64" and item.get("image_base64"):
285
+ picked.append("data:image/jpeg;base64," + item["image_base64"])
286
+ elif item.get("type") == "text" and item.get("text"):
287
+ picked.append(item["text"])
288
+ elif isinstance(item, str):
289
+ picked.append(item)
290
+ newc = "\n\n".join(p for p in picked if p).strip()
291
+ nm = m.copy()
292
+ nm["content"] = newc
293
+ norm_msgs.append(nm)
294
+ else:
295
+ if not isinstance(c, str):
296
+ nm = m.copy()
297
+ nm["content"] = str(c or "")
298
+ norm_msgs.append(nm)
299
+ else:
300
+ norm_msgs.append(m)
301
+
302
+ # Try streaming first (client may expose .chat.stream); fall back to non-streaming complete
303
  stream_gen = None
304
  try:
305
  stream_gen = client.chat.stream(model=model, messages=norm_msgs)
 
314
  continue
315
  parts.append(d)
316
  return
317
+
318
  res = client.chat.complete(model=model, messages=norm_msgs, stream=False)
319
+ extract_text_from_response(res, parts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  except Exception as e:
321
  parts.append(f"[Model error: {e}]")
322
 
 
328
  is_image = ext in IMAGE_EXTS or (not is_remote(src) and os.path.isfile(src) and ext in IMAGE_EXTS)
329
  parts = []
330
 
331
+ # IMAGE path: keep exactly the previous behavior but send content as a single string data URL
332
  if is_image:
333
  try:
334
  if is_remote(src):
 
343
  stream_and_collect(client, DEFAULT_IMAGE_MODEL, msgs, parts)
344
  return "".join(parts).strip()
345
 
346
+ # REMOTE VIDEO path
347
  if is_remote(src):
348
  try:
349
  media_bytes = fetch_bytes(src, timeout=120)
 
352
  ext = ext_from_src(src) or ".mp4"
353
  tmp_media = save_bytes_to_temp(media_bytes, suffix=ext)
354
  try:
355
+ # Try uploading full video to Mistral files first
356
  try:
357
  file_id = upload_file_to_mistral(client, tmp_media, filename=os.path.basename(src.split("?")[0]))
358
+ extra = (
359
+ f"Remote video uploaded to Mistral Files with id: {file_id}\n\n"
360
+ "Instruction: Analyze the video contents using the uploaded file id. Do not invent frames not present."
361
+ )
362
+ msgs = build_messages_for_text(prompt, extra)
363
+ stream_and_collect(client, DEFAULT_VIDEO_MODEL, msgs, parts)
364
+ return "".join(parts).strip()
365
+ except Exception as e_upload:
366
+ # Fallback: extract a representative frame and analyze as image
367
  frame_bytes = extract_best_frame_bytes(tmp_media)
368
  if not frame_bytes:
369
+ return f"Error uploading to Mistral and no frame fallback available: {e_upload}"
370
  try:
371
  jpg = convert_to_jpeg_bytes(frame_bytes, base_h=480)
372
  except UnidentifiedImageError:
 
375
  msgs = build_messages_for_image(prompt, b64_jpg=b64)
376
  stream_and_collect(client, DEFAULT_VIDEO_MODEL, msgs, parts)
377
  return "".join(parts).strip()
 
 
 
 
 
378
  finally:
379
  try:
380
  if tmp_media and os.path.exists(tmp_media):
 
382
  except Exception:
383
  pass
384
 
385
+ # LOCAL VIDEO path
386
  tmp_media = None
387
  try:
388
  media_bytes = fetch_bytes(src)
 
390
  ext = ext or ".mp4"
391
  tmp_media = save_bytes_to_temp(media_bytes, suffix=ext)
392
  try:
393
+ # Upload local video file to Mistral files and analyze by file id
394
  file_id = upload_file_to_mistral(client, tmp_media, filename=os.path.basename(src))
395
+ extra = (
396
+ f"Local video uploaded to Mistral Files with id: {file_id}\n\n"
397
+ "Instruction: Analyze the video contents using the uploaded file id. Do not invent frames not present."
398
+ )
399
  msgs = build_messages_for_text(prompt, extra)
400
  stream_and_collect(client, DEFAULT_VIDEO_MODEL, msgs, parts)
401
  return "".join(parts).strip()
402
  except Exception:
403
+ # fallback to frame extraction + image analysis
404
  frame_bytes = extract_best_frame_bytes(tmp_media)
405
  if not frame_bytes:
406
  return "Unable to process the provided file. Provide a direct image/frame URL or a remote video URL."
 
421
 
422
 
423
  def load_preview(url: str):
424
+ # Return (image_or_None, video_or_None, status_text)
425
  if not url:
426
  return None, None, ""
427
  try:
428
  r = requests.get(url, timeout=30, stream=True)
429
  r.raise_for_status()
430
  ctype = (r.headers.get("content-type") or "").lower()
431
+ # If content-type indicates video or URL ends with a video ext, return video preview
432
  if (ctype and ctype.startswith("video/")) or any(url.lower().split("?")[0].endswith(ext) for ext in VIDEO_EXTS):
433
  return None, url, "Video"
434
  data = r.content
 
452
  with gr.Accordion("Mistral API Key (optional)", open=False):
453
  api_key = gr.Textbox(label="API Key", type="password", max_lines=1)
454
  submit = gr.Button("Submit")
455
+ preview_image = gr.Image(label="Preview image", type="pil", visible