Hug0endob committed on
Commit
3c2126a
·
verified ·
1 Parent(s): e6159f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -110
app.py CHANGED
@@ -1,20 +1,38 @@
1
  #!/usr/bin/env python3
 
 
 
 
 
 
 
 
2
  import os
 
3
  import subprocess
4
  import tempfile
5
- import shutil
6
  import base64
7
  import requests
8
  from io import BytesIO
 
9
  from PIL import Image, ImageFile, UnidentifiedImageError
10
  import gradio as gr
11
- from mistralai import Mistral
12
 
 
 
 
 
 
 
 
 
13
  DEFAULT_KEY = os.getenv("MISTRAL_API_KEY", "")
14
  PIXTRAL_MODEL = "pixtral-12b-2409"
15
  VIDEO_MODEL = "voxtral-mini-latest"
16
  STREAM_THRESHOLD = 20 * 1024 * 1024
17
  FFMPEG_BIN = shutil.which("ffmpeg")
 
 
18
 
19
  SYSTEM_INSTRUCTION = (
20
  "You are a clinical visual analyst. Only analyze media actually provided (image data or extracted frames). "
@@ -23,51 +41,75 @@ SYSTEM_INSTRUCTION = (
23
  "Do not invent sensory information not present in the media."
24
  )
25
 
 
26
  ImageFile.LOAD_TRUNCATED_IMAGES = True
27
  Image.MAX_IMAGE_PIXELS = 10000 * 10000
28
 
29
- IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".webp", ".gif")
30
- VIDEO_EXTS = (".mp4", ".mov", ".webm", ".mkv", ".avi", ".flv")
31
-
32
- def get_client(key: str = None) -> Mistral:
33
  api_key = (key or "").strip() or DEFAULT_KEY
 
 
 
 
 
34
  return Mistral(api_key=api_key)
35
 
36
  def is_remote(src: str) -> bool:
37
  return bool(src) and src.startswith(("http://", "https://"))
38
 
39
  def ext_from_src(src: str) -> str:
 
 
40
  _, ext = os.path.splitext((src or "").split("?")[0])
41
  return ext.lower()
42
 
43
- def fetch_bytes(src: str, stream_threshold=STREAM_THRESHOLD, timeout=60) -> bytes:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  if is_remote(src):
45
- try:
46
- h = requests.head(src, timeout=6, allow_redirects=True)
47
- h.raise_for_status()
48
- cl = h.headers.get("content-length")
49
- if cl and int(cl) > stream_threshold:
50
- with requests.get(src, timeout=timeout, stream=True) as r:
51
- r.raise_for_status()
52
- fd, path = tempfile.mkstemp()
53
- os.close(fd)
54
- try:
55
- with open(path, "wb") as f:
56
- for chunk in r.iter_content(8192):
57
- if chunk:
58
- f.write(chunk)
59
- with open(path, "rb") as f:
60
- return f.read()
61
- finally:
62
- try: os.remove(path)
63
- except Exception: pass
64
- except Exception:
65
- pass
66
- with requests.get(src, timeout=timeout) as r:
67
- r.raise_for_status()
68
- return r.content
69
- with open(src, "rb") as f:
70
- return f.read()
 
 
 
 
71
 
72
  def save_bytes_to_temp(b: bytes, suffix: str) -> str:
73
  fd, path = tempfile.mkstemp(suffix=suffix)
@@ -86,6 +128,7 @@ def convert_to_jpeg_bytes(img_bytes: bytes, base_h: int = 480) -> bytes:
86
  if img.mode != "RGB":
87
  img = img.convert("RGB")
88
  h = base_h
 
89
  w = max(1, int(img.width * (h / img.height)))
90
  img = img.resize((w, h), Image.LANCZOS)
91
  buf = BytesIO()
@@ -95,11 +138,11 @@ def convert_to_jpeg_bytes(img_bytes: bytes, base_h: int = 480) -> bytes:
95
  def b64_jpeg(img_bytes: bytes) -> str:
96
  return base64.b64encode(img_bytes).decode("utf-8")
97
 
98
- def extract_best_frames_bytes(media_path: str, sample_count: int = 5, timeout_extract: int = 15) -> list:
 
99
  if not FFMPEG_BIN or not os.path.exists(media_path):
100
- return []
101
  timestamps = [0.5, 1.0, 2.0, 3.0, 4.0][:sample_count]
102
- frames = []
103
  for i, t in enumerate(timestamps):
104
  fd, tmp = tempfile.mkstemp(suffix=f"_{i}.jpg")
105
  os.close(fd)
@@ -109,37 +152,19 @@ def extract_best_frames_bytes(media_path: str, sample_count: int = 5, timeout_ex
109
  if os.path.exists(tmp) and os.path.getsize(tmp) > 0:
110
  with open(tmp, "rb") as f:
111
  frames.append(f.read())
 
 
112
  finally:
113
  try: os.remove(tmp)
114
  except Exception: pass
115
  return frames
116
 
117
- def upload_file_to_mistral(client: Mistral, path: str, filename: str | None = None, purpose: str = "batch") -> str:
118
- fname = filename or os.path.basename(path)
119
- try:
120
- with open(path, "rb") as fh:
121
- res = client.files.upload(file={"file_name": fname, "content": fh}, purpose=purpose)
122
- fid = getattr(res, "id", None) or (res.get("id") if isinstance(res, dict) else None)
123
- if not fid:
124
- fid = res["data"][0]["id"]
125
- return fid
126
- except Exception:
127
- api_key = client.api_key if hasattr(client, "api_key") else os.getenv("MISTRAL_API_KEY", "")
128
- url = "https://api.mistral.ai/v1/files"
129
- headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
130
- with open(path, "rb") as fh:
131
- files = {"file": (fname, fh)}
132
- data = {"purpose": purpose}
133
- r = requests.post(url, headers=headers, files=files, data=data, timeout=120)
134
- r.raise_for_status()
135
- jr = r.json()
136
- return jr.get("id") or jr.get("data", [{}])[0].get("id")
137
-
138
- def build_messages_for_image(prompt: str, b64_jpg: str) -> list:
139
  content = f"{prompt}\n\nImage (base64): data:image/jpeg;base64,{b64_jpg}"
140
  return [{"role": "system", "content": SYSTEM_INSTRUCTION}, {"role": "user", "content": content}]
141
 
142
- def build_messages_for_text(prompt: str, extra: str) -> list:
143
  return [{"role": "system", "content": SYSTEM_INSTRUCTION}, {"role": "user", "content": f"{prompt}\n\n{extra}"}]
144
 
145
  def extract_text_from_response(res, parts: list):
@@ -164,19 +189,58 @@ def extract_text_from_response(res, parts: list):
164
  except Exception:
165
  parts.append(str(res))
166
 
167
- def chat_complete(client: Mistral, model: str, messages: list) -> str:
 
168
  parts = []
169
- res = client.chat.complete(model=model, messages=messages, stream=False)
170
- extract_text_from_response(res, parts)
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  return "".join(parts).strip()
172
 
173
- def analyze_image(client: Mistral, img_bytes: bytes, prompt: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  jpeg = convert_to_jpeg_bytes(img_bytes, base_h=640)
175
  b64 = b64_jpeg(jpeg)
176
  msgs = build_messages_for_image(prompt, b64)
177
  return chat_complete(client, PIXTRAL_MODEL, msgs)
178
 
179
- def analyze_frames_and_consolidate(client: Mistral, frames: list, prompt: str) -> str:
180
  per_frame = []
181
  for i, fb in enumerate(frames):
182
  txt = analyze_image(client, fb, f"{prompt}\n\nFrame index: {i + 1}")
@@ -191,6 +255,29 @@ def analyze_frames_and_consolidate(client: Mistral, frames: list, prompt: str) -
191
  summary = chat_complete(client, PIXTRAL_MODEL, msgs)
192
  return "\n\n".join(per_frame + [f"Consolidated summary:\n{summary}"])
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  def process_media(src: str, custom_prompt: str, api_key: str) -> str:
195
  client = get_client(api_key)
196
  prompt = custom_prompt.strip() or "Please provide a detailed visual review."
@@ -198,16 +285,9 @@ def process_media(src: str, custom_prompt: str, api_key: str) -> str:
198
  is_image = ext in IMAGE_EXTS
199
  is_video = ext in VIDEO_EXTS
200
  if is_remote(src):
201
- try:
202
- h = requests.head(src, timeout=6, allow_redirects=True)
203
- h.raise_for_status()
204
- ctype = (h.headers.get("content-type") or "").lower()
205
- if ctype.startswith("video/"):
206
- is_video = True; is_image = False
207
- elif ctype.startswith("image/"):
208
- is_image = True; is_video = False
209
- except Exception:
210
- pass
211
  if is_image:
212
  try:
213
  raw = fetch_bytes(src)
@@ -224,75 +304,92 @@ def process_media(src: str, custom_prompt: str, api_key: str) -> str:
224
  raw = fetch_bytes(src, timeout=120)
225
  except Exception as e:
226
  return f"Error fetching video: {e}"
227
- tmp_path = save_bytes_to_temp(raw, suffix=ext or ".mp4")
 
228
  try:
 
229
  try:
230
  file_id = upload_file_to_mistral(client, tmp_path, filename=os.path.basename(src.split("?")[0]))
231
  extra = f"Uploaded video to Mistral Files with id: {file_id}\n\nInstruction: Analyze the video contents using the uploaded file id. Do not invent frames not present."
232
  msgs = build_messages_for_text(prompt, extra)
233
  return chat_complete(client, VIDEO_MODEL, msgs)
234
  except Exception:
 
235
  frames = extract_best_frames_bytes(tmp_path, sample_count=5)
236
  if not frames:
237
- return "Error: could not upload remote video and no frames extracted."
238
  return analyze_frames_and_consolidate(client, frames, prompt)
239
  finally:
240
  try: os.remove(tmp_path)
241
  except Exception: pass
242
  return "Unable to determine media type from the provided URL or file extension."
243
 
 
244
  css = ".preview_media img, .preview_media video { max-width: 100%; height: auto; }"
245
 
246
  def load_preview(url: str):
 
 
 
247
  if not url:
248
- return gr.update(value=None, visible=False), gr.update(value=None, visible=False)
 
249
  if not is_remote(url) and os.path.exists(url):
250
  ext = ext_from_src(url)
251
  if ext in VIDEO_EXTS:
252
- return gr.update(value=None, visible=False), gr.update(value=os.path.abspath(url), visible=True)
253
  if ext in IMAGE_EXTS:
254
  try:
255
  img = Image.open(url)
256
  if getattr(img, "is_animated", False):
257
  img.seek(0)
258
- return gr.update(value=img.convert("RGB"), visible=True), gr.update(value=None, visible=False)
259
  except Exception:
260
- return gr.update(value=None, visible=False), gr.update(value=None, visible=False)
 
 
 
 
 
 
 
261
  try:
262
- h = requests.head(url, timeout=6, allow_redirects=True)
263
- if h.ok:
264
- ctype = (h.headers.get("content-type") or "").lower()
265
- if ctype.startswith("video/") or any(url.lower().split("?")[0].endswith(ext) for ext in VIDEO_EXTS):
266
- return gr.update(value=None, visible=False), gr.update(value=url, visible=True)
267
- except Exception:
268
- pass
269
- try:
270
- r = requests.get(url, timeout=15)
271
- r.raise_for_status()
272
  img = Image.open(BytesIO(r.content))
273
  if getattr(img, "is_animated", False):
274
  img.seek(0)
275
- return gr.update(value=img.convert("RGB"), visible=True), gr.update(value=None, visible=False)
276
  except Exception:
277
- return gr.update(value=None, visible=False), gr.update(value=None, visible=False)
278
-
279
- with gr.Blocks(title="Flux Multimodal (fixed)", css=css) as demo:
280
- with gr.Row():
281
- with gr.Column(scale=1):
282
- url_input = gr.Textbox(label="Image / Video URL or local path", placeholder="https://... or /path/to/file", lines=1)
283
- custom_prompt = gr.Textbox(label="Prompt (optional)", lines=2, value="")
284
- with gr.Accordion("Mistral API Key (optional)", open=False):
285
- api_key = gr.Textbox(label="API Key", type="password", max_lines=1)
286
- submit_btn = gr.Button("Submit")
287
- preview_image = gr.Image(label="Preview Image", type="pil", elem_classes="preview_media", visible=False)
288
- preview_video = gr.Video(label="Preview Video", elem_classes="preview_media", visible=False)
289
- with gr.Column(scale=2):
290
- final_md = gr.Markdown(value="")
291
-
292
- url_input.change(fn=load_preview, inputs=[url_input], outputs=[preview_image, preview_video])
293
- def submit_wrapper(url, prompt, key):
294
- return process_media(url, prompt, key)
295
- submit_btn.click(fn=submit_wrapper, inputs=[url_input, custom_prompt, api_key], outputs=[final_md])
 
 
 
 
 
 
 
 
 
 
296
 
297
  if __name__ == "__main__":
 
298
  demo.queue().launch()
 
1
  #!/usr/bin/env python3
2
+ """
3
+ flux_multimodal_fixed.py
4
+ Streamlined Gradio app to preview an image/video from URL or local path,
5
+ send media (or extracted frames) to Mistral API for analysis using the
6
+ default SYSTEM_INSTRUCTION prompt (unless user supplies one).
7
+ """
8
+
9
+ from __future__ import annotations
10
  import os
11
+ import shutil
12
  import subprocess
13
  import tempfile
 
14
  import base64
15
  import requests
16
  from io import BytesIO
17
+ from typing import List, Tuple
18
  from PIL import Image, ImageFile, UnidentifiedImageError
19
  import gradio as gr
 
20
 
21
+ # Import Mistral client in the same way original code did.
22
+ # If you have a different client interface, adjust get_client/upload_file_to_mistral accordingly.
23
+ try:
24
+ from mistralai import Mistral
25
+ except Exception:
26
+ Mistral = None # Fallback; upload will use raw HTTP if needed
27
+
28
+ # --- Configuration / constants ---
29
  DEFAULT_KEY = os.getenv("MISTRAL_API_KEY", "")
30
  PIXTRAL_MODEL = "pixtral-12b-2409"
31
  VIDEO_MODEL = "voxtral-mini-latest"
32
  STREAM_THRESHOLD = 20 * 1024 * 1024
33
  FFMPEG_BIN = shutil.which("ffmpeg")
34
+ IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".webp", ".gif")
35
+ VIDEO_EXTS = (".mp4", ".mov", ".webm", ".mkv", ".avi", ".flv")
36
 
37
  SYSTEM_INSTRUCTION = (
38
  "You are a clinical visual analyst. Only analyze media actually provided (image data or extracted frames). "
 
41
  "Do not invent sensory information not present in the media."
42
  )
43
 
44
+ # Pillow config
45
  ImageFile.LOAD_TRUNCATED_IMAGES = True
46
  Image.MAX_IMAGE_PIXELS = 10000 * 10000
47
 
48
+ # --- Utilities ---
49
def get_client(key: str | None = None):
    """Return the client object used for Mistral API calls.

    The API key is ``key`` with surrounding whitespace removed; when that is
    missing or blank, ``DEFAULT_KEY`` is used instead.  If the ``mistralai``
    import failed (``Mistral is None``), a minimal stand-in exposing only an
    ``api_key`` attribute is returned.
    """
    resolved_key = (key or "").strip()
    if not resolved_key:
        resolved_key = DEFAULT_KEY

    if Mistral is not None:
        return Mistral(api_key=resolved_key)

    # SDK unavailable: hand back a thin holder that only carries the key.
    class Dummy:
        def __init__(self, k):
            self.api_key = k

    return Dummy(resolved_key)
57
 
58
def is_remote(src: str) -> bool:
    """Return True when ``src`` is a non-empty string starting with http:// or https://."""
    if not src:
        return False
    return src.startswith(("http://", "https://"))
60
 
61
def ext_from_src(src: str) -> str:
    """Return the lower-cased file extension (including the dot) of ``src``.

    ``src`` may be an http(s) URL or a local path.  For URLs only the path
    component is inspected, so query strings and ``#fragments`` never leak
    into the result and a bare host such as ``https://example.com`` yields
    ``""`` rather than ``".com"``.  For anything else the text before the
    first ``?`` is used.  An empty or ``None`` source yields ``""``.
    """
    if not src:
        return ""
    path = src.split("?")[0]
    if src.startswith(("http://", "https://")):
        # Local import keeps this block self-contained.
        from urllib.parse import urlsplit

        try:
            path = urlsplit(src).path
        except ValueError:
            # Malformed URL (e.g. bad IPv6 literal): keep the simple split.
            pass
    _, ext = os.path.splitext(path)
    return ext.lower()
66
 
67
def safe_head(url: str, timeout: int = 6) -> requests.Response | None:
    """Send a HEAD request (following redirects) without ever raising.

    Returns the response, or ``None`` when the request raised or the status
    code is 400 or higher.
    """
    try:
        response = requests.head(url, timeout=timeout, allow_redirects=True)
        return None if response.status_code >= 400 else response
    except Exception:
        return None
75
+
76
def safe_get(url: str, timeout: int = 15) -> requests.Response:
    """GET ``url`` and return the response.

    Errors are not swallowed: request exceptions propagate, and
    ``raise_for_status()`` raises for HTTP error statuses.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response
80
+
81
def fetch_bytes(src: str, stream_threshold: int = STREAM_THRESHOLD, timeout: int = 60) -> bytes:
    """Return the full contents of ``src`` (http(s) URL or local path) as bytes.

    Local paths are read directly.  For URLs a HEAD request is tried first;
    when it reports a ``content-length`` above ``stream_threshold`` the body
    is streamed to a temporary file in 8 KiB chunks and then read back.  Any
    failure in that streaming path is ignored and a plain GET is used
    instead.  Errors from the plain GET or from opening a local file
    propagate to the caller.
    """
    if not is_remote(src):
        with open(src, "rb") as local_file:
            return local_file.read()

    probe = safe_head(src)
    if probe is not None:
        declared_len = probe.headers.get("content-length")
        try:
            if declared_len and int(declared_len) > stream_threshold:
                # Large download: spool to disk first, then read the file back.
                with requests.get(src, timeout=timeout, stream=True) as resp:
                    resp.raise_for_status()
                    handle, spool_path = tempfile.mkstemp()
                    os.close(handle)
                    try:
                        with open(spool_path, "wb") as sink:
                            for piece in resp.iter_content(8192):
                                if piece:
                                    sink.write(piece)
                        with open(spool_path, "rb") as spooled:
                            return spooled.read()
                    finally:
                        try:
                            os.remove(spool_path)
                        except Exception:
                            pass
        except Exception:
            # Best effort only: fall through to the simple GET below.
            pass

    return safe_get(src, timeout=timeout).content
113
 
114
  def save_bytes_to_temp(b: bytes, suffix: str) -> str:
115
  fd, path = tempfile.mkstemp(suffix=suffix)
 
128
  if img.mode != "RGB":
129
  img = img.convert("RGB")
130
  h = base_h
131
+ # maintain aspect
132
  w = max(1, int(img.width * (h / img.height)))
133
  img = img.resize((w, h), Image.LANCZOS)
134
  buf = BytesIO()
 
138
def b64_jpeg(img_bytes: bytes) -> str:
    """Return the base64 encoding of ``img_bytes`` as a text string."""
    encoded = base64.b64encode(img_bytes)
    return str(encoded, "utf-8")
140
 
141
+ def extract_best_frames_bytes(media_path: str, sample_count: int = 5, timeout_extract: int = 15) -> List[bytes]:
142
+ frames = []
143
  if not FFMPEG_BIN or not os.path.exists(media_path):
144
+ return frames
145
  timestamps = [0.5, 1.0, 2.0, 3.0, 4.0][:sample_count]
 
146
  for i, t in enumerate(timestamps):
147
  fd, tmp = tempfile.mkstemp(suffix=f"_{i}.jpg")
148
  os.close(fd)
 
152
  if os.path.exists(tmp) and os.path.getsize(tmp) > 0:
153
  with open(tmp, "rb") as f:
154
  frames.append(f.read())
155
+ except Exception:
156
+ pass
157
  finally:
158
  try: os.remove(tmp)
159
  except Exception: pass
160
  return frames
161
 
162
+ # --- Mistral interaction helpers ---
163
def build_messages_for_image(prompt: str, b64_jpg: str):
    """Build the chat messages for a single-image request.

    Args:
        prompt: Instruction text for the model.
        b64_jpg: Base64-encoded JPEG data (no ``data:`` prefix).

    Returns:
        A two-item list: the system message carrying ``SYSTEM_INSTRUCTION``
        and a user message whose content is a list of parts, one ``text``
        part with the prompt and one ``image_url`` part with the JPEG as a
        data URL.
    """
    # The image must be its own ``image_url`` content part.  Pasting the
    # base64 into the prompt string makes the model see it as plain text.
    user_content = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": f"data:image/jpeg;base64,{b64_jpg}"},
    ]
    return [
        {"role": "system", "content": SYSTEM_INSTRUCTION},
        {"role": "user", "content": user_content},
    ]
166
 
167
def build_messages_for_text(prompt: str, extra: str):
    """Build a text-only chat request.

    Returns the system message carrying ``SYSTEM_INSTRUCTION`` followed by a
    user message made of ``prompt``, a blank line, and ``extra``.
    """
    system_msg = {"role": "system", "content": SYSTEM_INSTRUCTION}
    user_msg = {"role": "user", "content": f"{prompt}\n\n{extra}"}
    return [system_msg, user_msg]
169
 
170
  def extract_text_from_response(res, parts: list):
 
189
  except Exception:
190
  parts.append(str(res))
191
 
192
def chat_complete(client, model: str, messages: list) -> str:
    """Run one non-streaming chat completion and return the reply text.

    Uses ``client.chat.complete`` when the client exposes it; otherwise posts
    to the Mistral REST chat-completions endpoint with the client's
    ``api_key`` (or ``DEFAULT_KEY``).  This function does not raise: any
    exception is converted into an "Error during model call: ..." string.
    """
    collected = []
    try:
        has_sdk_chat = hasattr(client, "chat") and hasattr(client.chat, "complete")
        if has_sdk_chat:
            res = client.chat.complete(model=model, messages=messages, stream=False)
        else:
            # No SDK chat interface: talk to the REST endpoint directly.
            token = getattr(client, "api_key", "") or DEFAULT_KEY
            reply = requests.post(
                "https://api.mistral.ai/v1/chat/completions",
                json={"model": model, "messages": messages},
                headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"},
                timeout=120,
            )
            reply.raise_for_status()
            res = reply.json()
        extract_text_from_response(res, collected)
    except Exception as e:
        collected.append(f"Error during model call: {e}")
    return "".join(collected).strip()
211
 
212
def upload_file_to_mistral(client, path: str, filename: str | None = None, purpose: str = "batch") -> str:
    """Upload a local file to Mistral Files and return its file id.

    Args:
        client: Object that may expose ``files.upload`` (SDK client) and/or
            an ``api_key`` attribute.
        path: Local path of the file to upload.
        filename: Name to report for the upload; defaults to the basename
            of ``path``.
        purpose: Value of the ``purpose`` field sent with the upload.

    Returns:
        The non-empty file id reported by the API.

    Raises:
        RuntimeError: The HTTP fallback succeeded but the response carried
            no file id.
        Exception: Whatever the HTTP fallback raises (I/O errors, request
            errors, HTTP error statuses).  SDK failures are not raised;
            they trigger the HTTP fallback instead.
    """
    fname = filename or os.path.basename(path)

    # Prefer the SDK upload when the client offers one.
    try:
        if hasattr(client, "files") and hasattr(client.files, "upload"):
            with open(path, "rb") as fh:
                res = client.files.upload(file={"file_name": fname, "content": fh}, purpose=purpose)
            fid = getattr(res, "id", None) or (res.get("id") if isinstance(res, dict) else None)
            if not fid:
                fid = res["data"][0]["id"]
            if fid:
                return fid
            # An empty id is useless to the caller; try the HTTP route.
    except Exception:
        pass

    # Fallback: raw multipart upload against the REST endpoint.
    api_key = getattr(client, "api_key", "") or DEFAULT_KEY
    url = "https://api.mistral.ai/v1/files"
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    with open(path, "rb") as fh:
        r = requests.post(
            url,
            headers=headers,
            files={"file": (fname, fh)},
            data={"purpose": purpose},
            timeout=120,
        )
    r.raise_for_status()
    jr = r.json()

    fid = jr.get("id")
    if not fid:
        entries = jr.get("data") or []
        if entries and isinstance(entries[0], dict):
            fid = entries[0].get("id")
    if not fid:
        # Fail loudly instead of returning None and letting the caller
        # build a prompt around a non-existent file id.
        raise RuntimeError("Mistral file upload response did not contain a file id")
    return fid
236
+
237
def analyze_image(client, img_bytes: bytes, prompt: str) -> str:
    """Analyze one image with ``PIXTRAL_MODEL`` and return the reply text.

    The image bytes are converted via ``convert_to_jpeg_bytes`` (base height
    640), base64-encoded, wrapped into chat messages together with
    ``prompt``, and sent through ``chat_complete``.
    """
    encoded = b64_jpeg(convert_to_jpeg_bytes(img_bytes, base_h=640))
    request_messages = build_messages_for_image(prompt, encoded)
    return chat_complete(client, PIXTRAL_MODEL, request_messages)
242
 
243
+ def analyze_frames_and_consolidate(client, frames: List[bytes], prompt: str) -> str:
244
  per_frame = []
245
  for i, fb in enumerate(frames):
246
  txt = analyze_image(client, fb, f"{prompt}\n\nFrame index: {i + 1}")
 
255
  summary = chat_complete(client, PIXTRAL_MODEL, msgs)
256
  return "\n\n".join(per_frame + [f"Consolidated summary:\n{summary}"])
257
 
258
+ # --- Core processing ---
259
def determine_media_type_from_remote(url: str) -> Tuple[bool, bool]:
    """Classify a remote URL as image and/or video.

    Returns ``(is_image, is_video)``.  The URL's extension gives the initial
    guess.  When a HEAD request succeeds and its content-type starts with
    ``video/`` or ``image/``, that overrides the extension.  An empty URL
    yields ``(False, False)``.
    """
    if not url:
        return False, False

    suffix = ext_from_src(url)
    by_ext = (suffix in IMAGE_EXTS, suffix in VIDEO_EXTS)

    probe = safe_head(url)
    if probe is None:
        return by_ext

    mime = (probe.headers.get("content-type") or "").lower()
    if mime.startswith("video/"):
        return False, True
    if mime.startswith("image/"):
        return True, False
    return by_ext
280
+
281
  def process_media(src: str, custom_prompt: str, api_key: str) -> str:
282
  client = get_client(api_key)
283
  prompt = custom_prompt.strip() or "Please provide a detailed visual review."
 
285
  is_image = ext in IMAGE_EXTS
286
  is_video = ext in VIDEO_EXTS
287
  if is_remote(src):
288
+ ri, rv = determine_media_type_from_remote(src)
289
+ if ri or rv:
290
+ is_image, is_video = ri, rv
 
 
 
 
 
 
 
291
  if is_image:
292
  try:
293
  raw = fetch_bytes(src)
 
304
  raw = fetch_bytes(src, timeout=120)
305
  except Exception as e:
306
  return f"Error fetching video: {e}"
307
+ tmp_suffix = ext or ".mp4"
308
+ tmp_path = save_bytes_to_temp(raw, suffix=tmp_suffix)
309
  try:
310
+ # Try uploading file to Mistral first
311
  try:
312
  file_id = upload_file_to_mistral(client, tmp_path, filename=os.path.basename(src.split("?")[0]))
313
  extra = f"Uploaded video to Mistral Files with id: {file_id}\n\nInstruction: Analyze the video contents using the uploaded file id. Do not invent frames not present."
314
  msgs = build_messages_for_text(prompt, extra)
315
  return chat_complete(client, VIDEO_MODEL, msgs)
316
  except Exception:
317
+ # fallback to extracting frames
318
  frames = extract_best_frames_bytes(tmp_path, sample_count=5)
319
  if not frames:
320
+ return "Error: could not upload remote video and no frames extracted (ffmpeg missing or extraction failed)."
321
  return analyze_frames_and_consolidate(client, frames, prompt)
322
  finally:
323
  try: os.remove(tmp_path)
324
  except Exception: pass
325
  return "Unable to determine media type from the provided URL or file extension."
326
 
327
+ # --- Gradio app UI helpers ---
328
  css = ".preview_media img, .preview_media video { max-width: 100%; height: auto; }"
329
 
330
def load_preview(url: str):
    """Compute Gradio updates for the image and video preview components.

    Args:
        url: An http(s) URL or a local file path (may be empty).

    Returns:
        ``(image_update, video_update)``.  At most one of them is visible;
        both are hidden when nothing can be previewed.
    """
    empty_img = gr.update(value=None, visible=False)
    empty_vid = gr.update(value=None, visible=False)
    if not url:
        return empty_img, empty_vid

    if not is_remote(url):
        # Local file: decide purely from the extension.
        if os.path.exists(url):
            ext = ext_from_src(url)
            if ext in VIDEO_EXTS:
                return empty_img, gr.update(value=os.path.abspath(url), visible=True)
            if ext in IMAGE_EXTS:
                try:
                    # Context manager closes the file handle once the RGB
                    # copy has been made.
                    with Image.open(url) as img:
                        if getattr(img, "is_animated", False):
                            img.seek(0)
                        frame = img.convert("RGB")
                    return gr.update(value=frame, visible=True), empty_vid
                except Exception:
                    return empty_img, empty_vid
        # Not an http(s) URL and not a previewable local file: requests
        # could only fail on it, so skip the network attempts.
        return empty_img, empty_vid

    # Remote: HEAD for the content-type, but honour a video extension even
    # when HEAD fails, so a video is never downloaded just to be rejected
    # by the image decoder.
    head = safe_head(url)
    ctype = ""
    if head is not None:
        ctype = (head.headers.get("content-type") or "").lower()
    if ctype.startswith("video/") or url.lower().split("?")[0].endswith(VIDEO_EXTS):
        return empty_img, gr.update(value=url, visible=True)

    # Finally try GET and attempt to open the body as an image.
    try:
        r = safe_get(url, timeout=15)
        img = Image.open(BytesIO(r.content))
        if getattr(img, "is_animated", False):
            img.seek(0)
        return gr.update(value=img.convert("RGB"), visible=True), empty_vid
    except Exception:
        return empty_img, empty_vid
364
+
365
+ # --- Gradio app layout ---
366
def create_app():
    """Build the Gradio Blocks UI and return it without launching.

    Left column: source textbox, optional prompt, collapsible API-key field,
    submit button, and the image/video previews (hidden initially).  Right
    column: a Markdown area for the analysis result.
    """
    def submit_wrapper(url, prompt, key):
        # Turn any failure into text so it is rendered instead of raised.
        try:
            return process_media(url or "", prompt or "", key or "")
        except Exception as e:
            return f"Unhandled error: {e}"

    with gr.Blocks(title="Flux Multimodal (fixed)", css=css) as app:
        with gr.Row():
            with gr.Column(scale=1):
                source_box = gr.Textbox(
                    label="Image / Video URL or local path",
                    placeholder="https://... or /path/to/file",
                    lines=1,
                )
                prompt_box = gr.Textbox(label="Prompt (optional)", lines=2, value="")
                with gr.Accordion("Mistral API Key (optional)", open=False):
                    key_box = gr.Textbox(label="API Key", type="password", max_lines=1)
                run_button = gr.Button("Submit")
                image_preview = gr.Image(
                    label="Preview Image",
                    type="pil",
                    elem_classes="preview_media",
                    visible=False,
                )
                video_preview = gr.Video(
                    label="Preview Video",
                    elem_classes="preview_media",
                    visible=False,
                )
            with gr.Column(scale=2):
                result_md = gr.Markdown(value="")

        # Refresh the preview whenever the source text changes.
        source_box.change(
            fn=load_preview,
            inputs=[source_box],
            outputs=[image_preview, video_preview],
        )
        run_button.click(
            fn=submit_wrapper,
            inputs=[source_box, prompt_box, key_box],
            outputs=[result_md],
        )

    return app
392
 
393
  if __name__ == "__main__":
394
+ demo = create_app()
395
  demo.queue().launch()