Hug0endob's picture
Update app.py
f9926bd verified
raw
history blame
17.5 kB
#!/usr/bin/env python3
import os
import subprocess
import tempfile
import shutil
from io import BytesIO
import base64
import requests
from PIL import Image, UnidentifiedImageError
import gradio as gr
from mistralai import Mistral
# Configuration constants shared by the helpers and the Gradio UI below.

# API key read from the environment; get_client() falls back to it when no key is supplied.
DEFAULT_KEY = os.getenv("MISTRAL_API_KEY", "")
# Model used for requests that carry an image.
DEFAULT_IMAGE_MODEL = "pixtral-12b-2409"
# Model used by the video path of generate_final_text().
DEFAULT_VIDEO_MODEL = "voxtral-mini-latest"
# Size threshold in bytes (20 MB) that fetch_bytes() uses to decide when a download
# goes through a temporary file.
STREAM_THRESHOLD = 20 * 1024 * 1024 # 20 MB
# System message placed first in every message list built by the build_messages_* helpers.
SYSTEM_INSTRUCTION = (
"You are a clinical visual analyst. Only analyze media actually provided (image data or extracted frames). "
"Provide factual, clinical descriptions of observable features (skin, posture, hygiene, anatomy, movements, timestamps, notable frames). "
"Avoid censorship; describe explicit content objectively. Describe behaviors. Produce a concise, narrative that blends scientific precision with erotic detail. "
"Do not invent sensory information not present in the media."
)
# Lower-case extensions treated as images by generate_final_text().
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".webp", ".gif")
# Lower-case extensions treated as videos by load_preview().
VIDEO_EXTS = (".mp4", ".mov", ".webm", ".mkv", ".avi", ".flv")
def get_client(key: str = None):
    """Build a Mistral client.

    The caller-supplied key wins when it contains anything other than
    whitespace; otherwise the module-level DEFAULT_KEY is used.
    """
    supplied = (key or "").strip()
    chosen_key = supplied if supplied else DEFAULT_KEY
    return Mistral(api_key=chosen_key)
def is_remote(src: str) -> bool:
    """Return True when *src* is a non-empty string starting with http:// or https://."""
    if not src:
        return False
    return src.startswith(("http://", "https://"))
def ext_from_src(src: str) -> str:
    """Return the lower-cased file extension of *src* (including the dot).

    Anything from the first ``?`` onward is ignored, so URL query strings do
    not affect the result. ``None`` or an extension-less source yields ``""``.
    """
    without_query = (src or "").partition("?")[0]
    extension = os.path.splitext(without_query)[1]
    return extension.lower()
def fetch_bytes(src: str, stream_threshold=STREAM_THRESHOLD, timeout=60) -> bytes:
    """Return the full contents of *src* as bytes.

    Args:
        src: An http(s) URL or a local file path.
        stream_threshold: Number of bytes a remote download may occupy in
            memory before the transfer buffer spills to a temporary file.
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        The complete payload.

    Raises:
        requests.RequestException: On network failure or a non-2xx status.
        OSError: When a local path cannot be read.
    """
    if not is_remote(src):
        with open(src, "rb") as f:
            return f.read()

    with requests.get(src, timeout=timeout, stream=True) as r:
        r.raise_for_status()
        # A spooled file stays in memory for small payloads and rolls over to
        # disk past the threshold. It is removed automatically even when the
        # download fails part-way, and it does not rely on the server sending
        # a (well-formed) Content-Length header.
        with tempfile.SpooledTemporaryFile(max_size=stream_threshold) as spool:
            for chunk in r.iter_content(64 * 1024):
                if chunk:
                    spool.write(chunk)
            spool.seek(0)
            return spool.read()
def convert_to_jpeg_bytes(media_bytes: bytes, base_h=480) -> bytes:
    """Decode an image, scale it to *base_h* pixels high, and re-encode it as JPEG.

    The width is scaled by the same factor as the height (minimum 1 pixel).
    Animated images are rewound to their first frame, and non-RGB images are
    converted to RGB before saving at quality 85.
    """
    source = Image.open(BytesIO(media_bytes))
    try:
        # Animated formats: make sure we are positioned on the first frame.
        if getattr(source, "is_animated", False):
            source.seek(0)
    except Exception:
        pass

    rgb = source if source.mode == "RGB" else source.convert("RGB")

    scale = base_h / rgb.height
    target_w = max(1, int(rgb.width * scale))
    resized = rgb.resize((target_w, base_h), Image.LANCZOS)

    out = BytesIO()
    resized.save(out, format="JPEG", quality=85)
    return out.getvalue()
def b64_jpeg(img_bytes: bytes) -> str:
    """Return *img_bytes* as a standard base64 text string (no ``data:`` prefix)."""
    encoded = base64.b64encode(img_bytes)
    return str(encoded, "utf-8")
def save_bytes_to_temp(b: bytes, suffix: str):
    """Write *b* to a new temporary file ending in *suffix* and return its path.

    The file is not deleted automatically; the caller is responsible for
    removing it.
    """
    handle, temp_path = tempfile.mkstemp(suffix=suffix)
    # Wrap the descriptor returned by mkstemp so it is closed with the file object.
    with os.fdopen(handle, "wb") as sink:
        sink.write(b)
    return temp_path
def extract_delta(chunk):
    """Pull the text carried by one streamed chunk.

    The payload is the first truthy value among the chunk's ``data``,
    ``response`` and ``delta`` attributes. Three shapes are tried in order:
    ``choices[0].delta.content``, then ``choices[0].message`` (dict or
    object) content, and finally ``str(payload)``.

    Returns:
        The text as ``str``, or ``None`` when the chunk is empty, has no
        payload, or the located content is ``None``.
    """
    if not chunk:
        return None

    payload = None
    for attr_name in ("data", "response", "delta"):
        payload = getattr(chunk, attr_name, None)
        if payload:
            break
    if not payload:
        return None

    # Shape 1: streaming delta.
    try:
        piece = payload.choices[0].delta.content
        return None if piece is None else str(piece)
    except Exception:
        pass

    # Shape 2: complete message, either a mapping or an object.
    try:
        message = payload.choices[0].message
        piece = message.get("content") if isinstance(message, dict) else getattr(message, "content", None)
        return None if piece is None else str(piece)
    except Exception:
        pass

    # Shape 3: anything else is stringified as-is.
    try:
        return str(payload)
    except Exception:
        return None
def _probe_duration_seconds(media_path: str, timeout: int):
    """Return the container duration in seconds via ffprobe, or None if unavailable."""
    # -show_entries / -of are ffprobe options; the ffmpeg binary rejects them.
    ffprobe = shutil.which("ffprobe")
    if not ffprobe:
        return None
    probe_cmd = [
        ffprobe, "-v", "error", "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1", media_path,
    ]
    try:
        # subprocess.run kills the child itself when the timeout expires.
        done = subprocess.run(
            probe_cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            timeout=timeout,
            check=False,
        )
    except (subprocess.TimeoutExpired, OSError):
        return None
    try:
        return float(done.stdout.strip().split(b"\n")[0])
    except (ValueError, IndexError):
        return None


def extract_best_frame_bytes(media_path: str, sample_count: int = 5, timeout_probe: int = 10, timeout_extract: int = 15):
    """Extract a representative JPEG frame from a video file using ffmpeg.

    Up to *sample_count* frames are grabbed at evenly spaced timestamps (or at
    0.5 s, 1 s and 2 s when the duration cannot be determined). The largest
    resulting file is returned, on the heuristic used by the original code
    that a bigger JPEG carries more detail.

    Args:
        media_path: Path to a local media file.
        sample_count: Number of timestamps to sample.
        timeout_probe: Seconds allowed for the duration probe.
        timeout_extract: Seconds allowed for each frame extraction.

    Returns:
        JPEG bytes of the chosen frame, or ``None`` when ffmpeg is not
        installed, the file does not exist, or no frame could be extracted.
    """
    ffmpeg = shutil.which("ffmpeg")
    if not ffmpeg or not os.path.exists(media_path):
        return None

    duration = _probe_duration_seconds(media_path, timeout_probe)
    if duration and duration > 0:
        timestamps = [(duration * i) / (sample_count + 1) for i in range(1, sample_count + 1)]
    else:
        timestamps = [0.5, 1.0, 2.0][:sample_count]

    best_path = None
    best_size = 0
    # Every frame lives in one scratch directory, so cleanup happens in a
    # single place regardless of how the loop exits.
    with tempfile.TemporaryDirectory() as workdir:
        for index, t in enumerate(timestamps):
            frame_path = os.path.join(workdir, f"frame_{index}.jpg")
            cmd = [
                ffmpeg, "-nostdin", "-y",
                # -ss before -i seeks on the input instead of decoding and
                # discarding everything up to the timestamp.
                "-ss", f"{t:.3f}",
                "-i", media_path,
                "-frames:v", "1",
                "-q:v", "2",
                frame_path,
            ]
            try:
                done = subprocess.run(
                    cmd,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                    timeout=timeout_extract,
                    check=False,
                )
            except (subprocess.TimeoutExpired, OSError):
                continue
            if done.returncode != 0 or not os.path.isfile(frame_path):
                continue
            size = os.path.getsize(frame_path)
            # Strict '>' keeps the earliest frame on ties and skips empty files.
            if size > best_size:
                best_path, best_size = frame_path, size

        if best_path is None:
            return None
        with open(best_path, "rb") as f:
            return f.read()
def _file_id_from_response(res):
    """Return the file id found in an SDK object or JSON dict, or None when absent."""
    fid = getattr(res, "id", None)
    if not fid and isinstance(res, dict):
        fid = res.get("id")
    if fid:
        return fid
    # Some responses wrap the file record in a "data" list.
    data = res.get("data") if isinstance(res, dict) else getattr(res, "data", None)
    try:
        first = data[0]
    except (TypeError, IndexError, KeyError):
        return None
    if isinstance(first, dict):
        return first.get("id")
    return getattr(first, "id", None)


def upload_file_to_mistral(client, path, filename=None, purpose="batch"):
    """Upload a local file to the Mistral Files API and return its file id.

    The SDK's ``client.files.upload`` is tried first. If it raises or returns
    no id, the file is re-sent with a plain HTTP POST to
    ``https://api.mistral.ai/v1/files``.

    Args:
        client: Mistral SDK client.
        path: Path of the file to upload.
        filename: Name reported to the API; defaults to the basename of *path*.
        purpose: ``purpose`` field sent with the upload.

    Returns:
        The id of the uploaded file.

    Raises:
        RuntimeError: The HTTP fallback succeeded but the response had no id.
        requests.RequestException: The HTTP fallback failed.
        OSError: *path* could not be opened.
    """
    fname = filename or os.path.basename(path)
    # Try SDK upload
    try:
        with open(path, "rb") as fh:
            res = client.files.upload(file={"file_name": fname, "content": fh}, purpose=purpose)
        fid = _file_id_from_response(res)
        if not fid:
            raise RuntimeError(f"No file id returned: {res}")
        return fid
    except Exception:
        # Fallback to HTTP upload. An empty/None client.api_key falls through
        # to the environment variable instead of sending no credentials.
        api_key = getattr(client, "api_key", None) or os.getenv("MISTRAL_API_KEY", "")
        url = "https://api.mistral.ai/v1/files"
        headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
        with open(path, "rb") as fh:
            r = requests.post(
                url,
                headers=headers,
                files={"file": (fname, fh)},
                data={"purpose": purpose},
                timeout=120,
            )
        r.raise_for_status()
        jr = r.json()
        # The helper tolerates an empty or missing "data" list, so a malformed
        # payload surfaces as the RuntimeError below rather than an IndexError.
        fid = _file_id_from_response(jr)
        if not fid:
            raise RuntimeError(f"Upload failed to return id: {jr}")
        return fid
def build_messages_for_image(prompt: str, b64_jpg: str = None, image_url: str = None):
    """
    Build the system + user messages for a request that carries one image.

    The user content is a list of typed items: the text prompt followed by a
    single ``image_url`` item.
    - For remote images the item holds the URL as given.
    - For local bytes the item holds a ``data:image/jpeg;base64,...`` URI
      built from *b64_jpg*. A value that already starts with ``data:`` is
      passed through unchanged.

    ``image_url`` takes precedence when both arguments are supplied.

    Raises:
        ValueError: Neither *image_url* nor *b64_jpg* was provided.
    """
    if image_url:
        image_ref = image_url
    elif b64_jpg:
        # Inline images are sent through the same "image_url" item type,
        # encoded as a data URI.
        image_ref = b64_jpg if b64_jpg.startswith("data:") else f"data:image/jpeg;base64,{b64_jpg}"
    else:
        raise ValueError("Either image_url or b64_jpg required")

    user_content = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": image_ref},
    ]
    return [
        {"role": "system", "content": SYSTEM_INSTRUCTION},
        {"role": "user", "content": user_content},
    ]
def build_messages_for_text(prompt: str, extra_text: str):
    """Build the system + user messages for a text-only request.

    The user message is *prompt* and *extra_text* separated by a blank line.
    """
    system_message = {"role": "system", "content": SYSTEM_INSTRUCTION}
    combined = f"{prompt}\n\n{extra_text}"
    user_message = {"role": "user", "content": combined}
    return [system_message, user_message]
def _append_completion_text(res, parts: list):
    """Append the text of a non-streaming chat response to *parts*."""
    try:
        choices = getattr(res, "choices", None) or res.get("choices", [])
    except Exception:
        choices = []
    if not choices:
        # Unknown response shape: keep whatever came back so it is visible.
        parts.append(str(res))
        return
    try:
        msg = choices[0].message
        content = msg.get("content") if isinstance(msg, dict) else getattr(msg, "content", None)
        if not content:
            return
        if isinstance(content, str):
            parts.append(content)
        elif isinstance(content, list):
            # Structured content: keep only the text items.
            for item in content:
                if isinstance(item, dict) and item.get("type") == "text":
                    parts.append(item.get("text", ""))
        elif isinstance(content, dict):
            text = content.get("text") or content.get("content")
            if text:
                parts.append(text)
    except Exception:
        parts.append(str(res))


def stream_and_collect(client, model, messages, parts: list):
    """
    Run a chat request and append the resulting text pieces to *parts*.

    ``client.chat.stream`` is used when it can be started; otherwise the
    request falls back to ``client.chat.complete``. Nothing is returned and
    nothing is raised: any failure is recorded in *parts* as a
    ``[Model error: ...]`` entry.
    """
    try:
        try:
            stream_gen = client.chat.stream(model=model, messages=messages)
        except Exception:
            stream_gen = None
        if stream_gen:
            for chunk in stream_gen:
                piece = extract_delta(chunk)
                # Skip only missing/empty deltas. Whitespace-only deltas are
                # real output (spaces, paragraph breaks) and must be kept so
                # the joined text is not run together.
                if piece:
                    parts.append(piece)
            return
        res = client.chat.complete(model=model, messages=messages, stream=False)
        _append_completion_text(res, parts)
    except Exception as e:
        parts.append(f"[Model error: {e}]")
def generate_final_text(src: str, custom_prompt: str, api_key: str):
    """
    Main entry for Submit button. Returns final text (string).

    Args:
        src: URL or local path of an image or video.
        custom_prompt: Optional user prompt; a default review prompt is used
            when it is empty or whitespace.
        api_key: Optional Mistral API key forwarded to get_client().

    Flow:
        - Sources whose extension is in IMAGE_EXTS are sent to
          DEFAULT_IMAGE_MODEL (remote ones by URL, local ones as a
          re-encoded base64 JPEG).
        - Anything else is treated as video: it is copied to a temp file and
          uploaded to Mistral Files, and the file id is referenced in a text
          request to DEFAULT_VIDEO_MODEL. If the upload fails, one
          representative frame is extracted and sent as an image instead.

    Failures are reported in the returned string rather than raised.
    """
    src = (src or "").strip()
    if not src:
        # Without this guard an empty textbox reached open("") and raised.
        return "Please provide an image or video URL."

    client = get_client(api_key)
    prompt = (custom_prompt or "").strip() or "Please provide a detailed visual review."
    remote = is_remote(src)
    parts = []

    # Image handling: remote image by URL, local image as base64 JPEG
    if ext_from_src(src) in IMAGE_EXTS:
        try:
            if remote:
                msgs = build_messages_for_image(prompt, image_url=src)
            else:
                raw = fetch_bytes(src)
                jpg = convert_to_jpeg_bytes(raw, base_h=480)
                b64 = b64_jpeg(jpg)  # plain base64 string (no data: prefix)
                msgs = build_messages_for_image(prompt, b64_jpg=b64)
        except Exception as e:
            return f"Error processing image: {e}"
        stream_and_collect(client, DEFAULT_IMAGE_MODEL, msgs, parts)
        return "".join(parts).strip()

    # Video handling (remote/local): get the bytes first
    try:
        media_bytes = fetch_bytes(src, timeout=120) if remote else fetch_bytes(src)
    except Exception as e:
        if remote:
            return f"Error downloading remote media: {e}"
        return f"Error reading local media: {e}"

    if remote:
        suffix = ext_from_src(src) or ".mp4"
        upload_name = os.path.basename(src.split("?")[0])
        origin = "Remote"
    else:
        suffix = os.path.splitext(src)[1] or ".mp4"
        upload_name = os.path.basename(src)
        origin = "Local"

    tmp_media = None
    try:
        tmp_media = save_bytes_to_temp(media_bytes, suffix=suffix)

        upload_error = None
        file_id = None
        try:
            file_id = upload_file_to_mistral(client, tmp_media, filename=upload_name)
        except Exception as e:
            upload_error = e

        if upload_error is None:
            extra = (
                f"{origin} video uploaded to Mistral Files with id: {file_id}\n\n"
                "Instruction: Analyze the video contents using the uploaded file id. Do not invent frames not present."
            )
            msgs = build_messages_for_text(prompt, extra)
            stream_and_collect(client, DEFAULT_VIDEO_MODEL, msgs, parts)
            return "".join(parts).strip()

        # Upload failed: fall back to sending a representative frame
        frame_bytes = extract_best_frame_bytes(tmp_media)
        if not frame_bytes:
            if remote:
                return f"Error uploading to Mistral and no frame fallback available: {upload_error}"
            return "Unable to process the provided file. Provide a direct image/frame URL or a remote video URL."
        try:
            jpg = convert_to_jpeg_bytes(frame_bytes, base_h=480)
        except UnidentifiedImageError:
            # Send the extracted frame unscaled when it cannot be re-decoded.
            jpg = frame_bytes
        msgs = build_messages_for_image(prompt, b64_jpg=b64_jpeg(jpg))
        # The payload is now a still image, so it goes to the image model.
        stream_and_collect(client, DEFAULT_IMAGE_MODEL, msgs, parts)
        return "".join(parts).strip()
    finally:
        # Best-effort removal of the scratch copy.
        try:
            if tmp_media and os.path.exists(tmp_media):
                os.remove(tmp_media)
        except OSError:
            pass
# --- Gradio UI ---
# Stylesheet handed to gr.Blocks below. It targets the "preview_media" class set on the
# preview components so previewed images and videos never exceed their container width.
css = """
.preview_media img, .preview_media video { max-width: 100%; height: auto; }
"""
def load_preview(url: str):
    """
    Fetch *url* and prepare a preview for the UI.

    Returns: (image_or_None, video_or_None, label)
    - For images: return PIL.Image (RGB, first frame if animated), None, "Image"
    - For videos: return None, url, "Video"
    - For an empty url: return None, None, ""
    - On any failure (network error, bad status, undecodable image):
      return None, None, "Preview failed"

    A response counts as video when its Content-Type starts with "video/" or
    the URL path (query string removed) ends with one of VIDEO_EXTS.
    """
    if not url:
        return None, None, ""
    try:
        # The context manager releases the streamed connection on every exit
        # path, including the early video return where the body is never read.
        with requests.get(url, timeout=30, stream=True) as r:
            r.raise_for_status()
            ctype = (r.headers.get("content-type") or "").lower()
            path_part = url.lower().split("?")[0]
            if ctype.startswith("video/") or path_part.endswith(VIDEO_EXTS):
                return None, url, "Video"
            data = r.content
        img = Image.open(BytesIO(data))
        if getattr(img, "is_animated", False):
            img.seek(0)
        return img.convert("RGB"), None, "Image"
    except Exception:
        # Covers UnidentifiedImageError as well as network/HTTP errors.
        return None, None, "Preview failed"
# Two-column layout: inputs and preview on the left, model output on the right.
with gr.Blocks(title="Flux", css=css) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # Input controls
            url_input = gr.Textbox(
                label="Image or Video URL",
                placeholder="https://...",
                lines=1,
            )
            custom_prompt = gr.Textbox(
                label="Prompt (optional)",
                lines=2,
                value="",
            )
            with gr.Accordion("Mistral API Key (optional)", open=False):
                api_key = gr.Textbox(
                    label="API Key",
                    type="password",
                    max_lines=1,
                )
            submit = gr.Button("Submit")
            # Preview widgets; load_preview fills either the image or the video one
            preview_image = gr.Image(
                label="Preview",
                type="pil",
                elem_classes="preview_media",
            )
            preview_video = gr.Video(
                label="Preview",
                elem_classes="preview_media",
            )
        with gr.Column(scale=2):
            # Markdown component so long model output renders nicely
            final_text = gr.Markdown(value="")

    # Event wiring. load_preview returns three values; its label goes to an
    # invisible textbox so the output count matches.
    preview_status = gr.Textbox(visible=False)
    url_input.change(
        fn=load_preview,
        inputs=[url_input],
        outputs=[preview_image, preview_video, preview_status],
    )
    # Submit is queued so a long model call does not block the UI
    submit.click(
        fn=generate_final_text,
        inputs=[url_input, custom_prompt, api_key],
        outputs=[final_text],
        queue=True,
    )
if __name__ == "__main__":
    # Enable request queuing through Blocks.queue(); the enable_queue keyword
    # of launch() was removed in Gradio 4 and raises TypeError there.
    demo.queue()
    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))