Hug0endob commited on
Commit
b140dcf
·
verified ·
1 Parent(s): cf25146

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -135
app.py CHANGED
@@ -3,16 +3,17 @@ import os
3
  import subprocess
4
  import tempfile
5
  import shutil
6
- from io import BytesIO
7
  import base64
8
  import requests
9
- from PIL import Image, UnidentifiedImageError, ImageFile
 
10
  import gradio as gr
11
  from mistralai import Mistral
12
 
 
13
  DEFAULT_KEY = os.getenv("MISTRAL_API_KEY", "")
14
- DEFAULT_IMAGE_MODEL = "pixtral-12b-2409"
15
- DEFAULT_VIDEO_MODEL = "voxtral-mini-latest"
16
  STREAM_THRESHOLD = 20 * 1024 * 1024
17
  FFMPEG_BIN = shutil.which("ffmpeg")
18
 
@@ -23,27 +24,23 @@ SYSTEM_INSTRUCTION = (
23
  "Do not invent sensory information not present in the media."
24
  )
25
 
26
- IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".webp", ".gif")
27
- VIDEO_EXTS = (".mp4", ".mov", ".webm", ".mkv", ".avi", ".flv")
28
-
29
  ImageFile.LOAD_TRUNCATED_IMAGES = True
30
  Image.MAX_IMAGE_PIXELS = 10000 * 10000
31
 
 
 
32
 
33
  def get_client(key: str = None):
34
  api_key = (key or "").strip() or DEFAULT_KEY
35
  return Mistral(api_key=api_key)
36
 
37
-
38
  def is_remote(src: str) -> bool:
39
  return bool(src) and src.startswith(("http://", "https://"))
40
 
41
-
42
  def ext_from_src(src: str) -> str:
43
  _, ext = os.path.splitext((src or "").split("?")[0])
44
  return ext.lower()
45
 
46
-
47
  def fetch_bytes(src: str, stream_threshold=STREAM_THRESHOLD, timeout=60) -> bytes:
48
  if is_remote(src):
49
  with requests.get(src, timeout=timeout, stream=True) as r:
@@ -58,20 +55,23 @@ def fetch_bytes(src: str, stream_threshold=STREAM_THRESHOLD, timeout=60) -> byte
58
  if chunk:
59
  f.write(chunk)
60
  with open(path, "rb") as f:
61
- data = f.read()
62
  finally:
63
- try:
64
- os.remove(path)
65
- except Exception:
66
- pass
67
- return data
68
  return r.content
69
  with open(src, "rb") as f:
70
  return f.read()
71
 
 
 
 
 
 
 
72
 
73
- def convert_to_jpeg_bytes(media_bytes: bytes, base_h=480) -> bytes:
74
- img = Image.open(BytesIO(media_bytes))
75
  try:
76
  if getattr(img, "is_animated", False):
77
  img.seek(0)
@@ -86,75 +86,54 @@ def convert_to_jpeg_bytes(media_bytes: bytes, base_h=480) -> bytes:
86
  img.save(buf, format="JPEG", quality=85)
87
  return buf.getvalue()
88
 
89
-
90
  def b64_jpeg(img_bytes: bytes) -> str:
91
  return base64.b64encode(img_bytes).decode("utf-8")
92
 
93
-
94
- def save_bytes_to_temp(b: bytes, suffix: str):
95
- fd, path = tempfile.mkstemp(suffix=suffix)
96
- os.close(fd)
97
- with open(path, "wb") as f:
98
- f.write(b)
99
- return path
100
-
101
-
102
  def extract_best_frames_bytes(media_path: str, sample_count: int = 5, timeout_probe: int = 10, timeout_extract: int = 15):
103
  if not FFMPEG_BIN or not os.path.exists(media_path):
104
  return []
105
- tmp_frames = []
 
 
 
 
 
 
 
 
 
106
  try:
107
- probe_cmd = [FFMPEG_BIN, "-v", "error", "-show_entries", "format=duration",
108
- "-of", "default=noprint_wrappers=1:nokey=1", media_path]
109
- proc = subprocess.Popen(probe_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  try:
111
- out, _ = proc.communicate(timeout=timeout_probe)
112
  except subprocess.TimeoutExpired:
113
- proc.kill()
114
- out, _ = proc.communicate()
115
- duration = None
 
 
 
116
  try:
117
- duration = float(out.strip().split(b"\n")[0]) if out else None
118
  except Exception:
119
- duration = None
120
-
121
- if duration and duration > 0:
122
- timestamps = [(duration * i) / (sample_count + 1) for i in range(1, sample_count + 1)]
123
- else:
124
- timestamps = [0.5, 1.0, 2.0][:sample_count]
125
-
126
- for i, t in enumerate(timestamps):
127
- fd, tmp_frame = tempfile.mkstemp(suffix=f"_{i}.jpg")
128
- os.close(fd)
129
- cmd = [
130
- FFMPEG_BIN, "-nostdin", "-y", "-i", media_path,
131
- "-ss", str(t),
132
- "-frames:v", "1",
133
- "-q:v", "2",
134
- tmp_frame
135
- ]
136
- proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
137
- try:
138
- proc.communicate(timeout=timeout_extract)
139
- except subprocess.TimeoutExpired:
140
- try:
141
- proc.kill()
142
- except Exception:
143
- pass
144
- proc.communicate()
145
- if proc.returncode == 0 and os.path.exists(tmp_frame) and os.path.getsize(tmp_frame) > 0:
146
- with open(tmp_frame, "rb") as f:
147
- tmp_frames.append(f.read())
148
- try:
149
- if os.path.exists(tmp_frame):
150
- os.remove(tmp_frame)
151
- except Exception:
152
- pass
153
-
154
- return tmp_frames
155
- finally:
156
- pass
157
-
158
 
159
  def upload_file_to_mistral(client, path, filename=None, purpose="batch"):
160
  fname = filename or os.path.basename(path)
@@ -185,7 +164,6 @@ def upload_file_to_mistral(client, path, filename=None, purpose="batch"):
185
  raise RuntimeError(f"Upload failed to return id: {jr}")
186
  return fid
187
 
188
-
189
  def build_messages_for_image(prompt: str, b64_jpg: str = None, image_url: str = None):
190
  if image_url:
191
  content = f"{prompt}\n\nImage: {image_url}"
@@ -195,11 +173,9 @@ def build_messages_for_image(prompt: str, b64_jpg: str = None, image_url: str =
195
  raise ValueError("Either image_url or b64_jpg required")
196
  return [{"role": "system", "content": SYSTEM_INSTRUCTION}, {"role": "user", "content": content}]
197
 
198
-
199
  def build_messages_for_text(prompt: str, extra_text: str):
200
  return [{"role": "system", "content": SYSTEM_INSTRUCTION}, {"role": "user", "content": f"{prompt}\n\n{extra_text}"}]
201
 
202
-
203
  def extract_delta(chunk):
204
  if not chunk:
205
  return None
@@ -229,7 +205,6 @@ def extract_delta(chunk):
229
  except Exception:
230
  return None
231
 
232
-
233
  def extract_text_from_response(res, parts: list):
234
  try:
235
  choices = getattr(res, "choices", None) or res.get("choices", [])
@@ -259,7 +234,6 @@ def extract_text_from_response(res, parts: list):
259
  else:
260
  parts.append(str(res))
261
 
262
-
263
  def stream_and_collect(client, model, messages, parts: list):
264
  norm_msgs = []
265
  for m in messages:
@@ -309,8 +283,7 @@ def stream_and_collect(client, model, messages, parts: list):
309
  res = client.chat.complete(model=model, messages=norm_msgs, stream=False)
310
  extract_text_from_response(res, parts)
311
 
312
-
313
- def analyze_image_bytes(client, img_bytes: bytes, prompt: str, model=DEFAULT_IMAGE_MODEL):
314
  jpg = convert_to_jpeg_bytes(img_bytes, base_h=480)
315
  b64 = b64_jpeg(jpg)
316
  msgs = build_messages_for_image(prompt, b64_jpg=b64)
@@ -318,17 +291,14 @@ def analyze_image_bytes(client, img_bytes: bytes, prompt: str, model=DEFAULT_IMA
318
  stream_and_collect(client, model, msgs, parts)
319
  return "".join(parts).strip()
320
 
321
-
322
- def analyze_multiple_frames(client, frames_bytes_list, prompt: str, model=DEFAULT_IMAGE_MODEL):
323
  results = []
324
  for i, fb in enumerate(frames_bytes_list):
325
  res = analyze_image_bytes(client, fb, f"{prompt}\n\nFrame index: {i+1}", model=model)
326
  results.append((i, res))
327
-
328
  merged = []
329
  for i, text in results:
330
  merged.append(f"Frame {i+1} analysis:\n{text}")
331
-
332
  consolidation_prompt = (
333
  prompt
334
  + "\n\nConsolidate the key consistent observations across the provided frame analyses below. "
@@ -337,53 +307,67 @@ def analyze_multiple_frames(client, frames_bytes_list, prompt: str, model=DEFAUL
337
  )
338
  parts = []
339
  msgs = build_messages_for_text(consolidation_prompt, "")
340
- stream_and_collect(client, DEFAULT_IMAGE_MODEL, msgs, parts)
341
  consolidated = "".join(parts).strip()
342
  if consolidated:
343
  merged.append("Consolidated summary:\n" + consolidated)
344
  return "\n\n".join(merged)
345
 
346
-
347
  def generate_final_text(src: str, custom_prompt: str, api_key: str):
348
  client = get_client(api_key)
349
  prompt = (custom_prompt.strip() if custom_prompt and custom_prompt.strip() else "Please provide a detailed visual review.")
 
350
  ext = ext_from_src(src)
351
  is_image = ext in IMAGE_EXTS or (not is_remote(src) and os.path.isfile(src) and ext in IMAGE_EXTS)
352
- parts = []
 
 
 
 
 
 
 
 
 
 
 
 
353
 
354
  if is_image:
355
  try:
356
- if is_remote(src):
357
- raw = fetch_bytes(src)
358
- return analyze_image_bytes(client, raw, prompt, model=DEFAULT_IMAGE_MODEL)
359
- else:
360
- raw = fetch_bytes(src)
361
- return analyze_image_bytes(client, raw, prompt, model=DEFAULT_IMAGE_MODEL)
362
  except UnidentifiedImageError:
363
  return "Error: provided file is not a valid image."
364
  except Exception as e:
365
  return f"Error processing image: {e}"
366
 
367
- if is_remote(src):
368
  tmp_media = None
369
  try:
370
- media_bytes = fetch_bytes(src, timeout=120)
 
 
 
371
  ext = ext_from_src(src) or ".mp4"
372
  tmp_media = save_bytes_to_temp(media_bytes, suffix=ext)
373
  try:
374
  file_id = upload_file_to_mistral(client, tmp_media, filename=os.path.basename(src.split("?")[0]))
375
  extra = (
376
- f"Remote video uploaded to Mistral Files with id: {file_id}\n\n"
377
  "Instruction: Analyze the video contents using the uploaded file id. Do not invent frames not present."
378
  )
379
  msgs = build_messages_for_text(prompt, extra)
380
- stream_and_collect(client, DEFAULT_VIDEO_MODEL, msgs, parts)
381
  return "".join(parts).strip()
382
  except Exception:
383
  frames = extract_best_frames_bytes(tmp_media, sample_count=5)
384
  if not frames:
385
  return "Error: could not upload remote video and no frames extracted."
386
- return analyze_multiple_frames(client, frames, prompt, model=DEFAULT_IMAGE_MODEL)
387
  finally:
388
  try:
389
  if tmp_media and os.path.exists(tmp_media):
@@ -391,36 +375,10 @@ def generate_final_text(src: str, custom_prompt: str, api_key: str):
391
  except Exception:
392
  pass
393
 
394
- tmp_media = None
395
- try:
396
- media_bytes = fetch_bytes(src)
397
- _, ext = os.path.splitext(src) if src else ("", ".mp4")
398
- ext = ext or ".mp4"
399
- tmp_media = save_bytes_to_temp(media_bytes, suffix=ext)
400
- try:
401
- file_id = upload_file_to_mistral(client, tmp_media, filename=os.path.basename(src))
402
- extra = (
403
- f"Local video uploaded to Mistral Files with id: {file_id}\n\n"
404
- "Instruction: Analyze the video contents using the uploaded file id. Do not invent frames not present."
405
- )
406
- msgs = build_messages_for_text(prompt, extra)
407
- stream_and_collect(client, DEFAULT_VIDEO_MODEL, msgs, parts)
408
- return "".join(parts).strip()
409
- except Exception:
410
- frames = extract_best_frames_bytes(tmp_media, sample_count=5)
411
- if not frames:
412
- return "Unable to process the provided file. Provide a direct image/frame URL or a remote video URL."
413
- return analyze_multiple_frames(client, frames, prompt, model=DEFAULT_IMAGE_MODEL)
414
- finally:
415
- try:
416
- if tmp_media and os.path.exists(tmp_media):
417
- os.remove(tmp_media)
418
- except Exception:
419
- pass
420
-
421
 
 
422
  css = ".preview_media img, .preview_media video { max-width: 100%; height: auto; }"
423
-
424
  def load_preview(url: str):
425
  if not url:
426
  return None, None, ""
@@ -442,24 +400,28 @@ def load_preview(url: str):
442
  except Exception:
443
  return None, None, "Preview failed"
444
 
445
- with gr.Blocks(title="Flux", css=css) as demo:
 
446
  with gr.Row():
447
  with gr.Column(scale=1):
448
- url_input = gr.Textbox(label="Image or Video URL", placeholder="https://...", lines=1)
449
  custom_prompt = gr.Textbox(label="Prompt (optional)", lines=2, value="")
450
  with gr.Accordion("Mistral API Key (optional)", open=False):
451
  api_key = gr.Textbox(label="API Key", type="password", max_lines=1)
452
  submit = gr.Button("Submit")
453
- preview_image = gr.Image(label="Preview", type="pil", elem_classes="preview_media", visible=False)
454
- preview_video = gr.Video(label="Preview", elem_classes="preview_media", visible=False)
455
 
456
  with gr.Column(scale=2):
457
  final_text = gr.Markdown(value="")
458
 
459
- url_input.change(fn=load_preview, inputs=[url_input], outputs=[preview_image, preview_video, gr.Textbox(visible=False)])
 
 
 
 
460
  submit.click(fn=generate_final_text, inputs=[url_input, custom_prompt, api_key], outputs=[final_text])
461
  demo.queue()
462
 
463
  if __name__ == "__main__":
464
  demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
465
-
 
3
  import subprocess
4
  import tempfile
5
  import shutil
 
6
  import base64
7
  import requests
8
+ from io import BytesIO
9
+ from PIL import Image, ImageFile, UnidentifiedImageError
10
  import gradio as gr
11
  from mistralai import Mistral
12
 
13
+ # Config
14
  DEFAULT_KEY = os.getenv("MISTRAL_API_KEY", "")
15
+ PIXTRAL_MODEL = "pixtral-12b-2409" # image-capable multimodal model
16
+ VIDEO_MODEL = "voxtral-mini-latest" # replace with your preferred video model
17
  STREAM_THRESHOLD = 20 * 1024 * 1024
18
  FFMPEG_BIN = shutil.which("ffmpeg")
19
 
 
24
  "Do not invent sensory information not present in the media."
25
  )
26
 
 
 
 
27
  ImageFile.LOAD_TRUNCATED_IMAGES = True
28
  Image.MAX_IMAGE_PIXELS = 10000 * 10000
29
 
30
+ IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".webp", ".gif")
31
+ VIDEO_EXTS = (".mp4", ".mov", ".webm", ".mkv", ".avi", ".flv")
32
 
33
  def get_client(key: str = None):
34
  api_key = (key or "").strip() or DEFAULT_KEY
35
  return Mistral(api_key=api_key)
36
 
 
37
  def is_remote(src: str) -> bool:
38
  return bool(src) and src.startswith(("http://", "https://"))
39
 
 
40
  def ext_from_src(src: str) -> str:
41
  _, ext = os.path.splitext((src or "").split("?")[0])
42
  return ext.lower()
43
 
 
44
  def fetch_bytes(src: str, stream_threshold=STREAM_THRESHOLD, timeout=60) -> bytes:
45
  if is_remote(src):
46
  with requests.get(src, timeout=timeout, stream=True) as r:
 
55
  if chunk:
56
  f.write(chunk)
57
  with open(path, "rb") as f:
58
+ return f.read()
59
  finally:
60
+ try: os.remove(path)
61
+ except Exception: pass
 
 
 
62
  return r.content
63
  with open(src, "rb") as f:
64
  return f.read()
65
 
66
+ def save_bytes_to_temp(b: bytes, suffix: str):
67
+ fd, path = tempfile.mkstemp(suffix=suffix)
68
+ os.close(fd)
69
+ with open(path, "wb") as f:
70
+ f.write(b)
71
+ return path
72
 
73
+ def convert_to_jpeg_bytes(img_bytes: bytes, base_h=480) -> bytes:
74
+ img = Image.open(BytesIO(img_bytes))
75
  try:
76
  if getattr(img, "is_animated", False):
77
  img.seek(0)
 
86
  img.save(buf, format="JPEG", quality=85)
87
  return buf.getvalue()
88
 
 
89
  def b64_jpeg(img_bytes: bytes) -> str:
90
  return base64.b64encode(img_bytes).decode("utf-8")
91
 
 
 
 
 
 
 
 
 
 
92
  def extract_best_frames_bytes(media_path: str, sample_count: int = 5, timeout_probe: int = 10, timeout_extract: int = 15):
93
  if not FFMPEG_BIN or not os.path.exists(media_path):
94
  return []
95
+ frames = []
96
+ probe_cmd = [FFMPEG_BIN, "-v", "error", "-show_entries", "format=duration",
97
+ "-of", "default=noprint_wrappers=1:nokey=1", media_path]
98
+ proc = subprocess.Popen(probe_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
99
+ try:
100
+ out, _ = proc.communicate(timeout=timeout_probe)
101
+ except subprocess.TimeoutExpired:
102
+ proc.kill()
103
+ out, _ = proc.communicate()
104
+ duration = None
105
  try:
106
+ duration = float(out.strip().split(b"\n")[0]) if out else None
107
+ except Exception:
108
+ duration = None
109
+
110
+ if duration and duration > 0:
111
+ timestamps = [(duration * i) / (sample_count + 1) for i in range(1, sample_count + 1)]
112
+ else:
113
+ timestamps = [0.5, 1.0, 2.0][:sample_count]
114
+
115
+ for i, t in enumerate(timestamps):
116
+ fd, tmp_frame = tempfile.mkstemp(suffix=f"_{i}.jpg")
117
+ os.close(fd)
118
+ cmd = [
119
+ FFMPEG_BIN, "-nostdin", "-y", "-i", media_path,
120
+ "-ss", str(t), "-frames:v", "1", "-q:v", "2", tmp_frame
121
+ ]
122
+ proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
123
  try:
124
+ proc.communicate(timeout=timeout_extract)
125
  except subprocess.TimeoutExpired:
126
+ try: proc.kill()
127
+ except Exception: pass
128
+ proc.communicate()
129
+ if proc.returncode == 0 and os.path.exists(tmp_frame) and os.path.getsize(tmp_frame) > 0:
130
+ with open(tmp_frame, "rb") as f:
131
+ frames.append(f.read())
132
  try:
133
+ if os.path.exists(tmp_frame): os.remove(tmp_frame)
134
  except Exception:
135
+ pass
136
+ return frames
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  def upload_file_to_mistral(client, path, filename=None, purpose="batch"):
139
  fname = filename or os.path.basename(path)
 
164
  raise RuntimeError(f"Upload failed to return id: {jr}")
165
  return fid
166
 
 
167
  def build_messages_for_image(prompt: str, b64_jpg: str = None, image_url: str = None):
168
  if image_url:
169
  content = f"{prompt}\n\nImage: {image_url}"
 
173
  raise ValueError("Either image_url or b64_jpg required")
174
  return [{"role": "system", "content": SYSTEM_INSTRUCTION}, {"role": "user", "content": content}]
175
 
 
176
  def build_messages_for_text(prompt: str, extra_text: str):
177
  return [{"role": "system", "content": SYSTEM_INSTRUCTION}, {"role": "user", "content": f"{prompt}\n\n{extra_text}"}]
178
 
 
179
  def extract_delta(chunk):
180
  if not chunk:
181
  return None
 
205
  except Exception:
206
  return None
207
 
 
208
  def extract_text_from_response(res, parts: list):
209
  try:
210
  choices = getattr(res, "choices", None) or res.get("choices", [])
 
234
  else:
235
  parts.append(str(res))
236
 
 
237
  def stream_and_collect(client, model, messages, parts: list):
238
  norm_msgs = []
239
  for m in messages:
 
283
  res = client.chat.complete(model=model, messages=norm_msgs, stream=False)
284
  extract_text_from_response(res, parts)
285
 
286
+ def analyze_image_bytes(client, img_bytes: bytes, prompt: str, model=PIXTRAL_MODEL):
 
287
  jpg = convert_to_jpeg_bytes(img_bytes, base_h=480)
288
  b64 = b64_jpeg(jpg)
289
  msgs = build_messages_for_image(prompt, b64_jpg=b64)
 
291
  stream_and_collect(client, model, msgs, parts)
292
  return "".join(parts).strip()
293
 
294
+ def analyze_multiple_frames(client, frames_bytes_list, prompt: str, model=PIXTRAL_MODEL):
 
295
  results = []
296
  for i, fb in enumerate(frames_bytes_list):
297
  res = analyze_image_bytes(client, fb, f"{prompt}\n\nFrame index: {i+1}", model=model)
298
  results.append((i, res))
 
299
  merged = []
300
  for i, text in results:
301
  merged.append(f"Frame {i+1} analysis:\n{text}")
 
302
  consolidation_prompt = (
303
  prompt
304
  + "\n\nConsolidate the key consistent observations across the provided frame analyses below. "
 
307
  )
308
  parts = []
309
  msgs = build_messages_for_text(consolidation_prompt, "")
310
+ stream_and_collect(client, PIXTRAL_MODEL, msgs, parts)
311
  consolidated = "".join(parts).strip()
312
  if consolidated:
313
  merged.append("Consolidated summary:\n" + consolidated)
314
  return "\n\n".join(merged)
315
 
 
316
  def generate_final_text(src: str, custom_prompt: str, api_key: str):
317
  client = get_client(api_key)
318
  prompt = (custom_prompt.strip() if custom_prompt and custom_prompt.strip() else "Please provide a detailed visual review.")
319
+ parts = []
320
  ext = ext_from_src(src)
321
  is_image = ext in IMAGE_EXTS or (not is_remote(src) and os.path.isfile(src) and ext in IMAGE_EXTS)
322
+ is_video = ext in VIDEO_EXTS or (not is_remote(src) and os.path.isfile(src) and ext in VIDEO_EXTS)
323
+
324
+ # If remote and content-type suggests video, treat as video
325
+ if is_remote(src):
326
+ try:
327
+ r = requests.head(src, timeout=10, allow_redirects=True)
328
+ ctype = (r.headers.get("content-type") or "").lower()
329
+ if ctype.startswith("video/"):
330
+ is_video = True
331
+ elif ctype.startswith("image/"):
332
+ is_image = True
333
+ except Exception:
334
+ pass
335
 
336
  if is_image:
337
  try:
338
+ raw = fetch_bytes(src)
339
+ except Exception as e:
340
+ return f"Error fetching image: {e}"
341
+ try:
342
+ return analyze_image_bytes(client, raw, prompt, model=PIXTRAL_MODEL)
 
343
  except UnidentifiedImageError:
344
  return "Error: provided file is not a valid image."
345
  except Exception as e:
346
  return f"Error processing image: {e}"
347
 
348
+ if is_video:
349
  tmp_media = None
350
  try:
351
+ try:
352
+ media_bytes = fetch_bytes(src, timeout=120)
353
+ except Exception as e:
354
+ return f"Error fetching video: {e}"
355
  ext = ext_from_src(src) or ".mp4"
356
  tmp_media = save_bytes_to_temp(media_bytes, suffix=ext)
357
  try:
358
  file_id = upload_file_to_mistral(client, tmp_media, filename=os.path.basename(src.split("?")[0]))
359
  extra = (
360
+ f"Uploaded video to Mistral Files with id: {file_id}\n\n"
361
  "Instruction: Analyze the video contents using the uploaded file id. Do not invent frames not present."
362
  )
363
  msgs = build_messages_for_text(prompt, extra)
364
+ stream_and_collect(client, VIDEO_MODEL, msgs, parts)
365
  return "".join(parts).strip()
366
  except Exception:
367
  frames = extract_best_frames_bytes(tmp_media, sample_count=5)
368
  if not frames:
369
  return "Error: could not upload remote video and no frames extracted."
370
+ return analyze_multiple_frames(client, frames, prompt, model=PIXTRAL_MODEL)
371
  finally:
372
  try:
373
  if tmp_media and os.path.exists(tmp_media):
 
375
  except Exception:
376
  pass
377
 
378
+ return "Unable to determine media type from the provided URL or file extension."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
 
380
+ # UI helpers
381
  css = ".preview_media img, .preview_media video { max-width: 100%; height: auto; }"
 
382
  def load_preview(url: str):
383
  if not url:
384
  return None, None, ""
 
400
  except Exception:
401
  return None, None, "Preview failed"
402
 
403
+ # Gradio app
404
+ with gr.Blocks(title="Flux Multimodal", css=css) as demo:
405
  with gr.Row():
406
  with gr.Column(scale=1):
407
+ url_input = gr.Textbox(label="Image or Video URL or local path", placeholder="https://... or /path/to/file", lines=1)
408
  custom_prompt = gr.Textbox(label="Prompt (optional)", lines=2, value="")
409
  with gr.Accordion("Mistral API Key (optional)", open=False):
410
  api_key = gr.Textbox(label="API Key", type="password", max_lines=1)
411
  submit = gr.Button("Submit")
412
+ preview_image = gr.Image(label="Preview Image", type="pil", elem_classes="preview_media", visible=False)
413
+ preview_video = gr.Video(label="Preview Video", elem_classes="preview_media", visible=False)
414
 
415
  with gr.Column(scale=2):
416
  final_text = gr.Markdown(value="")
417
 
418
+ def _preview_wrapper(url):
419
+ img, vid, label = load_preview(url)
420
+ return img, vid, label
421
+
422
+ url_input.change(fn=_preview_wrapper, inputs=[url_input], outputs=[preview_image, preview_video, gr.Textbox(visible=False)])
423
  submit.click(fn=generate_final_text, inputs=[url_input, custom_prompt, api_key], outputs=[final_text])
424
  demo.queue()
425
 
426
  if __name__ == "__main__":
427
  demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))