Hug0endob committed on
Commit
b5d1da5
·
verified ·
1 Parent(s): aab85a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +250 -1
app.py CHANGED
@@ -30,7 +30,7 @@ except Exception: # pragma: no cover
30
  DEFAULT_KEY = os.getenv("MISTRAL_API_KEY", "")
31
  PIXTRAL_MODEL = "pixtral-12b-2409"
32
  VIDEO_MODEL = "voxtral-mini-latest"
33
- STREAM_THRESHOLD = 20 * 1024 * 1024 # 20 MiB
34
  FFMPEG_BIN = shutil.which("ffmpeg")
35
  IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".webp", ".gif")
36
  VIDEO_EXTS = (".mp4", ".mov", ".webm", ".mkv", ".avi", ".flv")
@@ -56,6 +56,7 @@ def get_client(key: str | None = None):
56
  if Mistral is None:
57
  class Dummy:
58
  def __init__(self, k): self.api_key = k
 
59
  return Dummy(api_key)
60
  return Mistral(api_key=api_key)
61
 
@@ -262,3 +263,251 @@ def upload_file_to_mistral(
262
  fid = res["data"][0]["id"]
263
  return fid
264
  except Exception:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  DEFAULT_KEY = os.getenv("MISTRAL_API_KEY", "")
31
  PIXTRAL_MODEL = "pixtral-12b-2409"
32
  VIDEO_MODEL = "voxtral-mini-latest"
33
+ STREAM_THRESHOLD = 20 * 1024 * 1024 # 20 MiB
34
  FFMPEG_BIN = shutil.which("ffmpeg")
35
  IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".webp", ".gif")
36
  VIDEO_EXTS = (".mp4", ".mov", ".webm", ".mkv", ".avi", ".flv")
 
56
  if Mistral is None:
57
  class Dummy:
58
  def __init__(self, k): self.api_key = k
59
+
60
  return Dummy(api_key)
61
  return Mistral(api_key=api_key)
62
 
 
263
  fid = res["data"][0]["id"]
264
  return fid
265
  except Exception:
266
+ pass
267
+ # Raw‑HTTP fallback ---------------------------------------
268
+ api_key = getattr(client, "api_key", "") or DEFAULT_KEY
269
+ url = "https://api.mistral.ai/v1/files"
270
+ headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
271
+ with open(path, "rb") as fh:
272
+ files = {"file": (fname, fh)}
273
+ data = {"purpose": purpose}
274
+ r = requests.post(url, headers=headers, files=files, data=data, timeout=timeout)
275
+ r.raise_for_status()
276
+ jr = r.json()
277
+ return jr.get("id") or jr.get("data", [{}])[0].get("id")
278
+
279
+
280
def analyze_image_structured(client, img_bytes: bytes, prompt: str) -> str:
    """Send one image plus a text prompt to the Pixtral model and return the reply.

    The raw bytes are first passed through ``convert_to_jpeg_bytes`` (with
    ``base_h=1024``) and then through ``b64_bytes`` with a JPEG MIME type; the
    resulting value is attached to the user message as an ``image_url`` part.

    Args:
        client: Client object forwarded unchanged to ``chat_complete``.
        img_bytes: Raw bytes of the source image.
        prompt: Text instruction placed before the image in the user message.

    Returns:
        Whatever ``chat_complete`` returns for ``PIXTRAL_MODEL``.
    """
    encoded_image = b64_bytes(
        convert_to_jpeg_bytes(img_bytes, base_h=1024),
        mime="image/jpeg",
    )
    text_part = {"type": "text", "text": prompt}
    image_part = {"type": "image_url", "image_url": encoded_image}
    system_msg = {"role": "system", "content": SYSTEM_INSTRUCTION}
    user_msg = {"role": "user", "content": [text_part, image_part]}
    return chat_complete(client, PIXTRAL_MODEL, [system_msg, user_msg])
295
+
296
+
297
def analyze_video_cohesive(client, video_path: str, prompt: str) -> str:
    """Analyze a video, preferring a whole-file upload over sampled frames.

    First the file is uploaded with ``upload_file_to_mistral`` and
    ``VIDEO_MODEL`` is asked for a single cohesive narrative. If any step of
    that attempt raises, up to six frames are pulled with
    ``extract_best_frames_bytes``, converted to JPEG and sent to
    ``PIXTRAL_MODEL`` instead.

    Args:
        client: Client object forwarded to the upload and chat helpers.
        video_path: Local path of the video file.
        prompt: User instruction appended to the request.

    Returns:
        The model reply, or a string starting with ``"Error:"`` when neither
        the upload path nor the frame path can produce a request.

    Raises:
        Exception: Whatever ``extract_best_frames_bytes`` or the fallback
            ``chat_complete`` call raises; only the upload attempt and the
            per-frame conversion are best-effort.
    """
    try:
        file_id = upload_file_to_mistral(
            client, video_path, filename=os.path.basename(video_path)
        )
        extra_msg = (
            f"Uploaded video file id: {file_id}\n\n"
            "Instruction: Analyze the entire video and produce a single cohesive narrative describing consistent observations."
        )
        messages = [
            {"role": "system", "content": SYSTEM_INSTRUCTION},
            {"role": "user", "content": extra_msg + "\n\n" + prompt},
        ]
        return chat_complete(client, VIDEO_MODEL, messages)
    except Exception:
        # Deliberate best-effort: any failure of the upload path (upload or
        # chat call) is handled by the frame-based fallback below. The
        # fallback lives outside this handler so its own errors are not
        # reported as "during handling of" the upload failure.
        pass

    # Fallback: extract a few representative frames.
    frames = extract_best_frames_bytes(video_path, sample_count=6)
    if not frames:
        return "Error: could not upload video and no frames could be extracted."

    image_entries = []
    for frame_index, frame_bytes in enumerate(frames, start=1):
        try:
            jpeg = convert_to_jpeg_bytes(frame_bytes, base_h=720)
            data_url = b64_bytes(jpeg, mime="image/jpeg")
        except Exception:
            # Skip frames that cannot be converted; the rest are still useful.
            continue
        image_entries.append(
            {
                "type": "image_url",
                "image_url": data_url,
                "meta": {"frame_index": frame_index},
            }
        )

    if not image_entries:
        # Without this guard the model would receive a text-only request that
        # refers to frames which were never attached.
        return "Error: could not upload video and no frames could be converted."

    content = [
        {
            "type": "text",
            "text": prompt
            + "\n\nPlease consolidate observations across these frames into a single cohesive narrative.",
        }
    ] + image_entries
    messages = [
        {"role": "system", "content": SYSTEM_INSTRUCTION},
        {"role": "user", "content": content},
    ]
    return chat_complete(client, PIXTRAL_MODEL, messages)
336
+
337
+
338
def determine_media_type(src: str) -> Tuple[bool, bool]:
    """Classify *src* and return the pair ``(is_image, is_video)``.

    The extension reported by ``ext_from_src`` is checked against
    ``IMAGE_EXTS`` and ``VIDEO_EXTS`` first. For remote sources a HEAD
    response (via ``safe_head``) whose content type starts with ``image/`` or
    ``video/`` overrides the extension-based guess; any other content type,
    or a missing response, leaves that guess untouched.
    """
    suffix = ext_from_src(src)
    by_ext = (suffix in IMAGE_EXTS, suffix in VIDEO_EXTS)

    if not is_remote(src):
        return by_ext

    response = safe_head(src)
    if not response:
        return by_ext

    content_type = (response.headers.get("content-type") or "").lower()
    if content_type.startswith("image/"):
        return True, False
    if content_type.startswith("video/"):
        return False, True
    return by_ext
356
+
357
+
358
def process_media(src: str, custom_prompt: str, api_key: str, progress=gr.Progress()) -> str:
    """Fetch the media at *src*, analyze it, and return the result as text.

    The source is classified with ``determine_media_type``. Images go to
    ``analyze_image_structured``; videos are written to a temporary file and
    go to ``analyze_video_cohesive``; anything unclassified is attempted as an
    image. Failures are reported as message strings rather than raised, so
    the caller can display them directly.

    Args:
        src: URL or local path of the media.
        custom_prompt: Optional user prompt; a default review prompt is used
            when it is empty or whitespace.
        api_key: Key forwarded to ``get_client``.
        progress: Gradio progress callback, invoked with a fraction and a
            ``desc`` label.

    Returns:
        The analysis text, or a human-readable error message.
    """
    client = get_client(api_key)
    prompt = (custom_prompt or "").strip() or "Please provide a detailed visual review."
    if not src:
        return "No URL or path provided."
    progress(0.05, desc="Determining media type")
    is_image, is_video = determine_media_type(src)

    if is_image:
        try:
            raw = fetch_bytes(src)
        except Exception as e:
            return f"Error fetching image: {e}"
        progress(0.2, desc="Analyzing image")
        try:
            return analyze_image_structured(client, raw, prompt)
        except UnidentifiedImageError:
            return "Error: provided file is not a valid image."
        except Exception as e:
            return f"Error analyzing image: {e}"

    if is_video:
        try:
            raw = fetch_bytes(src, timeout=120)
        except Exception as e:
            return f"Error fetching video: {e}"
        tmp_path = save_bytes_to_temp(raw, suffix=ext_from_src(src) or ".mp4")
        try:
            progress(0.2, desc="Analyzing video")
            return analyze_video_cohesive(client, tmp_path, prompt)
        except Exception as e:
            # Mirror the image branch: report the failure as text instead of
            # letting it escape to the UI callback.
            return f"Error analyzing video: {e}"
        finally:
            # Best-effort cleanup; a leftover temp file must not mask the result.
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    # Fallback: treat as image
    try:
        raw = fetch_bytes(src)
        progress(0.2, desc="Treating as image")
        return analyze_image_structured(client, raw, prompt)
    except Exception as e:
        return f"Unable to determine media type or fetch file: {e}"
401
+
402
+
403
+ # ----------------------------------------------------------------------
404
+ # Gradio UI helpers
405
+ # ----------------------------------------------------------------------
406
+ css = ".preview_media img, .preview_media video { max-width: 100%; height: auto; }"
407
+
408
+
409
def load_preview(url: str):
    """Return ``(image_update, video_update)`` for the two preview components.

    Exactly one of the two updates makes its component visible when a preview
    can be produced; otherwise both hide their component.

    * Existing local files are dispatched on the extension from
      ``ext_from_src``: videos are previewed by absolute path, images are
      opened with Pillow (first frame only when animated) and converted to RGB.
    * Otherwise the source is probed with ``safe_head``. A ``video/*`` content
      type, or a remote URL ending in one of ``VIDEO_EXTS``, is shown as a
      video; the extension check does not depend on the HEAD request
      succeeding.
    * Anything else is downloaded with ``safe_get`` and opened as an image.

    Args:
        url: URL or local path typed by the user; may be empty.
    """
    empty_img = gr.update(value=None, visible=False)
    empty_vid = gr.update(value=None, visible=False)

    if not url:
        return empty_img, empty_vid

    # Local file handling
    if not is_remote(url) and os.path.exists(url):
        ext = ext_from_src(url)
        if ext in VIDEO_EXTS:
            return empty_img, gr.update(value=os.path.abspath(url), visible=True)
        if ext in IMAGE_EXTS:
            try:
                # Context manager closes the underlying file handle; the
                # converted copy is fully loaded and independent of it.
                with Image.open(url) as img:
                    if getattr(img, "is_animated", False):
                        img.seek(0)
                    preview = img.convert("RGB")
            except Exception:
                return empty_img, empty_vid
            return gr.update(value=preview, visible=True), empty_vid

    # Remote handling - try to infer from headers, then from the URL suffix.
    head = safe_head(url)
    ctype = (head.headers.get("content-type") or "").lower() if head else ""
    looks_like_remote_video = is_remote(url) and url.lower().endswith(VIDEO_EXTS)
    if ctype.startswith("video/") or looks_like_remote_video:
        return empty_img, gr.update(value=url, visible=True)

    # Try to load as image
    try:
        r = safe_get(url, timeout=15)
        img = Image.open(BytesIO(r.content))
        if getattr(img, "is_animated", False):
            img.seek(0)
        return gr.update(value=img.convert("RGB"), visible=True), empty_vid
    except Exception:
        return empty_img, empty_vid
447
+
448
+
449
+ def _btn_label_for_status(status: str) -> str:
450
+ return {
451
+ "idle": "Submit",
452
+ "busy": "Processing…",
453
+ "done": "Submit",
454
+ "error": "Retry",
455
+ }.get(status or "idle", "Submit")
456
+
457
+
458
+ # ----------------------------------------------------------------------
459
+ # Build Gradio demo
460
+ # ----------------------------------------------------------------------
461
+ def create_demo():
462
+ with gr.Blocks(title="Flux Multimodal (Pixtral / Voxtral)", css=css) as demo:
463
+ with gr.Row():
464
+ with gr.Column(scale=1):
465
+ url_input = gr.Textbox(
466
+ label="Image / Video URL or local path",
467
+ placeholder="https://... or /path/to/file",
468
+ lines=1,
469
+ )
470
+ custom_prompt = gr.Textbox(label="Prompt (optional)", lines=2, value="")
471
+ with gr.Accordion("Mistral API Key (optional)", open=False):
472
+ api_key = gr.Textbox(label="API Key", type="password", max_lines=1)
473
+ submit_btn = gr.Button("Submit")
474
+ clear_btn = gr.Button("Clear")
475
+ preview_image = gr.Image(
476
+ label="Preview Image",
477
+ type="pil",
478
+ elem_classes="preview_media",
479
+ visible=False,
480
+ )
481
+ preview_video = gr.Video(
482
+ label="Preview Video",
483
+ elem_classes="preview_media",
484
+ visible=False,
485
+ )
486
+ with gr.Column(scale=2):
487
+ final_md = gr.Markdown(value="")
488
+
489
+ # Live preview
490
+ url_input.change(fn=load_preview, inputs=[url_input], outputs=[preview_image, preview_video])
491
+
492
+ # Clear button
493
+ clear_btn.click(
494
+ fn=lambda: (
495
+ "", # clear textbox
496
+ gr.update(value=None, visible=False), # hide image
497
+ gr.update(value=None, visible=False), # hide video
498
+ ),
499
+ inputs=[],
500
+ outputs=[url_input, preview_image, preview_video],
501
+ )
502
+
503
+ # State to track button status
504
+ status = gr.State("idle")
505
+
506
+ def start_busy() -> str:
507
+ return "busy"
508
+
509
+ def worker(url: str, prompt: str, key: str, progress=gr.Progress()):
510
+ return process_media(url or "", prompt or "", key or "", progress=progress)
511
+
512
+ def finish(result: str) -> tuple[str, str]:
513
+ if not result or result.lower().startswith(("error", "unhandled"))