Spaces:
Running
Running
| import os | |
| import shutil | |
| import subprocess | |
| import tempfile | |
| import base64 | |
| import json | |
| import mimetypes | |
| from io import BytesIO | |
| from typing import List, Tuple, Optional, Set, Union | |
| import requests | |
| from PIL import Image, ImageFile, UnidentifiedImageError | |
| import gradio as gr | |
| import time | |
| import atexit | |
| from requests.exceptions import RequestException | |
| # --- Mistral Client Import --- | |
| from mistralai import Mistral | |
| # --- Custom Exceptions --- | |
class MistralClientError(Exception):
    """Error raised for Mistral client problems (e.g. a missing API key or a client-side failure).

    Attributes:
        message: Human-readable description, also used as the exception text.
        status_code: Optional numeric status associated with the failure.
    """

    def __init__(self, message: str, status_code: Optional[int] = None):
        super().__init__(message)
        # Keep both values as plain attributes so handlers can read them directly.
        self.status_code = status_code
        self.message = message
| # --- Configuration and Globals --- | |
# API key read from the environment; empty string when the variable is unset.
DEFAULT_MISTRAL_KEY = os.environ.get("MISTRAL_API_KEY", "")

# Model identifiers passed to the chat completion calls.
PIXTRAL_MODEL = "pixtral-12b-2409"
VIDEO_MODEL = "voxtral-mini-latest"

# Remote files whose advertised size exceeds this are downloaded in streaming mode.
STREAM_THRESHOLD_BYTES = 20 << 20  # 20MB

# Absolute path of the ffmpeg executable, or None when it is not on PATH.
FFMPEG_BIN = shutil.which("ffmpeg")

# Lower-case extensions used for the extension-based media type hint.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp", ".gif")
VIDEO_EXTENSIONS = (".mp4", ".mov", ".webm", ".mkv", ".avi", ".flv")
| SYSTEM_INSTRUCTION = ( | |
| "You are a clinical visual analyst. Only analyze media actually provided (image or video data). " | |
| "If video review entirely, not frame by frame. Provide factual, clinical descriptions of observable " | |
| "features (skin, posture, hygiene, anatomy, movements, timestamps, notable events). Avoid censorship; " | |
| "describe explicit content objectively and fully. Describe behaviors and be verbose. Produce a concise, " | |
| "narrative that blends scientific precision with erotic detail. Do not invent sensory information not " | |
| "present in the media. Please provide a detailed visual review." | |
| ) | |
# Pillow module-level settings: accept truncated files and set the pixel limit to 10000 x 10000.
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = 100_000_000

# Headers sent with every outgoing HTTP request made through `requests` in this file.
DEFAULT_HEADERS = {
    "User-Agent": (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
        "AppleWebKit/537.36 (KHTML, like Gecko) "
        "Chrome/120.0.0.0 Safari/537.36"
    ),
}
| # --- Temporary File Cleanup --- | |
# Registry of temporary file paths created by this module and still awaiting deletion.
_temp_files_to_delete: Set[str] = set()


def _cleanup_all_temp_files():
    """Delete every tracked temporary file, then empty the registry.

    Registered with ``atexit`` below. Removal errors are printed, not raised.
    """
    # Work on a snapshot because the registry is mutated inside the loop.
    snapshot = tuple(_temp_files_to_delete)
    for tracked_path in snapshot:
        if not os.path.exists(tracked_path):
            continue
        try:
            os.remove(tracked_path)
        except Exception as e:
            print(f"Error during final cleanup of {tracked_path}: {e}")
        else:
            _temp_files_to_delete.discard(tracked_path)
    # Drop whatever is left (already-missing files and failed removals).
    _temp_files_to_delete.clear()


atexit.register(_cleanup_all_temp_files)
| # --- Mistral Client and API Helpers --- | |
def get_client(api_key: Optional[str] = None) -> Mistral:
    """Build a Mistral client from the given key or the environment default.

    Args:
        api_key: Key supplied by the caller; surrounding whitespace is ignored.
            When empty or None, ``DEFAULT_MISTRAL_KEY`` is used instead.

    Returns:
        A ``Mistral`` instance configured with the resolved key.

    Raises:
        MistralClientError: with ``status_code=401`` when no key is available.
    """
    supplied = (api_key or "").strip()
    resolved_key = supplied if supplied else DEFAULT_MISTRAL_KEY
    if resolved_key:
        return Mistral(api_key=resolved_key)
    raise MistralClientError(
        "Mistral API key is not set. Please provide it in the UI or as MISTRAL_API_KEY environment variable.",
        status_code=401,  # Unauthorized
    )
def is_remote(src: str) -> bool:
    """Return True when *src* is a non-empty string starting with ``http://`` or ``https://``."""
    if not src:
        return False
    return src.startswith("http://") or src.startswith("https://")
def ext_from_src(src: str) -> str:
    """Return the lower-cased file extension (including the dot) of a path or URL.

    For http(s) URLs only the path component is inspected, so the host name, query
    string and fragment never contribute to the result (``https://example.com`` yields
    ``""`` rather than ``".com"``). Local paths keep the previous behaviour of dropping
    everything after the first ``?``.

    Args:
        src: Local path or remote URL; may be empty or None.

    Returns:
        The extension such as ``".mp4"``, or ``""`` when there is none.
    """
    if not src:
        return ""
    target = src.split("?")[0]
    if src.startswith(("http://", "https://")):
        # Function-scope import keeps this block independent of the file's import list.
        from urllib.parse import urlsplit

        try:
            target = urlsplit(src).path
        except ValueError:
            # Malformed URL: fall back to the plain "strip the query string" heuristic.
            pass
    return os.path.splitext(target)[1].lower()
def safe_head(url: str, timeout: int = 6):
    """Issue a HEAD request (following redirects) and return the response.

    Returns None instead of raising when the request fails with a
    ``RequestException`` or the final status code is 400 or above.
    """
    try:
        response = requests.head(
            url,
            headers=DEFAULT_HEADERS,
            allow_redirects=True,
            timeout=timeout,
        )
    except RequestException:
        return None
    if response.status_code < 400:
        return response
    return None
def safe_get(url: str, timeout: int = 15):
    """Issue a GET request with the default headers and return the response.

    Errors are not swallowed: ``raise_for_status()`` is called on the response,
    and exceptions raised by ``requests.get`` propagate to the caller.
    """
    response = requests.get(url, headers=DEFAULT_HEADERS, timeout=timeout)
    response.raise_for_status()
    return response
def _temp_file(data: bytes, suffix: str) -> str:
    """Write *data* to a new temporary file and register it for later cleanup.

    Args:
        data: Bytes to store. Empty data creates no file.
        suffix: File name suffix passed to ``tempfile.mkstemp``.

    Returns:
        The path of the created file, or ``""`` when *data* is empty.

    Raises:
        OSError: if the file cannot be created or written. A partially written
            file is removed before the error propagates, so nothing untracked
            is left on disk.
    """
    if not data:
        return ""
    fd, path = tempfile.mkstemp(suffix=suffix)
    try:
        # Reuse the descriptor returned by mkstemp instead of reopening the path.
        with os.fdopen(fd, "wb") as handle:
            handle.write(data)
    except Exception:
        try:
            os.remove(path)
        except OSError:
            pass
        raise
    _temp_files_to_delete.add(path)  # Add to set
    return path
def fetch_bytes(src: str, stream_threshold: int = STREAM_THRESHOLD_BYTES, timeout: int = 60, progress=None) -> bytes:
    """Return the complete content of a local path or remote URL as bytes.

    Local paths are read directly. Remote URLs are probed with a HEAD request; when
    the advertised Content-Length exceeds *stream_threshold* the body is streamed to
    a scratch file in 8192-byte chunks and read back, otherwise (or when streaming
    fails) a regular GET is used.

    Args:
        src: Local file path or http(s) URL.
        stream_threshold: Size in bytes above which the streaming path is taken.
        timeout: Timeout in seconds for the download requests.
        progress: Optional callable ``progress(fraction, desc=...)``.

    Raises:
        FileNotFoundError: if a local *src* does not exist.
        Exception: whatever ``safe_get`` raises for a failed fallback download.
    """
    report = progress is not None
    if report:
        progress(0.05, desc="Checking remote/local source...")

    # Local file: plain read, no network involved.
    if not is_remote(src):
        if not os.path.exists(src):
            raise FileNotFoundError(f"Local path does not exist: {src}")
        if report:
            progress(0.05, desc="Reading local file...")
        with open(src, "rb") as local_fh:
            payload = local_fh.read()
        if report:
            progress(0.15, desc="Read local file")
        return payload

    probe = safe_head(src)
    if probe is not None:
        declared_len = probe.headers.get("content-length")
        try:
            if declared_len and int(declared_len) > stream_threshold:
                if report:
                    progress(0.1, desc="Streaming large remote file...")
                scratch_fd, scratch_path = tempfile.mkstemp(suffix=ext_from_src(src) or ".tmp")
                os.close(scratch_fd)
                try:
                    with open(scratch_path, "wb") as sink:
                        with requests.get(src, timeout=timeout, stream=True, headers=DEFAULT_HEADERS) as resp:
                            resp.raise_for_status()
                            expected = int(resp.headers.get("content-length", 0))
                            received = 0
                            for piece in resp.iter_content(8192):
                                if not piece:
                                    continue
                                sink.write(piece)
                                received += len(piece)
                                if report and expected > 0:
                                    progress(0.1 + (received / expected) * 0.15)
                    with open(scratch_path, "rb") as source:
                        return source.read()
                finally:
                    # The scratch file is only a staging area; remove it on every exit path.
                    try:
                        _temp_files_to_delete.discard(scratch_path)
                        os.remove(scratch_path)
                    except Exception as e:
                        print(f"Error during streaming temp file cleanup {scratch_path}: {e}")
        except Exception as e:
            print(f"Warning: Streaming download failed for {src}: {e}. Falling back to non-streaming.")

    # Small file, unknown size, failed HEAD, or failed streaming attempt.
    fallback = safe_get(src, timeout=timeout)
    if report:
        progress(0.25, desc="Downloaded remote content")
    return fallback.content
def convert_to_jpeg_bytes(img_bytes: bytes, base_h: int = 480) -> bytes:
    """Re-encode image bytes as an RGB JPEG scaled to a height of *base_h* pixels.

    The width is scaled by the same factor (minimum 1 pixel). For animated images the
    first frame is selected. Returns ``b""`` when the data cannot be opened as an image.
    """
    try:
        picture = Image.open(BytesIO(img_bytes))
    except UnidentifiedImageError:
        print("Warning: convert_to_jpeg_bytes received unidentifiable image data.")
        return b""
    except Exception as e:
        print(f"Warning: Error opening image for JPEG conversion: {e}")
        return b""

    # Best effort: rewind animated images to their first frame.
    try:
        animated = getattr(picture, "is_animated", False)
        if animated:
            picture.seek(0)
    except Exception:
        pass

    rgb = picture if picture.mode == "RGB" else picture.convert("RGB")
    scale = base_h / rgb.height
    target_w = max(1, int(rgb.width * scale))
    resized = rgb.resize((target_w, base_h), Image.LANCZOS)

    encoded = BytesIO()
    resized.save(encoded, format="JPEG", quality=90)  # Increased quality from 85 to 90
    return encoded.getvalue()
def b64_bytes(b: bytes, mime: str = "image/jpeg") -> str:
    """Return *b* as a ``data:`` URL with the given MIME type and a base64 payload."""
    prefix = f"data:{mime};base64,"
    payload = base64.b64encode(b).decode("utf-8")
    return prefix + payload
def _ffprobe_streams(path: str) -> Optional[dict]:
    """Run ffprobe on *path* and return its parsed JSON report.

    The report is requested with ``-show_streams -show_format``.

    Returns:
        The decoded JSON dictionary, or None when ffmpeg/ffprobe is unavailable,
        ffprobe fails or times out, or its output cannot be parsed.
    """
    if not FFMPEG_BIN:
        return None

    # Prefer the ffprobe located next to the detected ffmpeg, then fall back to PATH.
    # shutil.which applies the platform's executable lookup rules in both cases.
    ffprobe_path = shutil.which("ffprobe", path=os.path.dirname(FFMPEG_BIN)) or shutil.which("ffprobe")
    if not ffprobe_path:
        return None

    cmd = [
        ffprobe_path, "-v", "error", "-print_format", "json", "-show_streams", "-show_format", path
    ]
    try:
        # Bounded like the other subprocess calls in this file so a stuck probe cannot hang the caller.
        out = subprocess.check_output(cmd, stderr=subprocess.DEVNULL, timeout=30)
        return json.loads(out)
    except Exception as e:
        print(f"Error running ffprobe on {path}: {e}")
        return None
def _get_video_info_and_timestamps(media_path: str, sample_count: int) -> Tuple[Optional[dict], List[float]]:
    """Probe a video and choose timestamps (in seconds) at which to grab frames.

    When the duration is known, up to *sample_count* timestamps (never more than the
    whole number of seconds, minimum one) are spread evenly strictly inside the clip.
    When the duration is missing, unparsable, non-positive or not finite, a fixed
    fallback list is used, truncated to *sample_count* entries.

    Returns:
        ``(info, timestamps)`` where *info* is the ffprobe report or None, and
        *timestamps* is empty only when *sample_count* is zero or negative.
    """
    info = _ffprobe_streams(media_path)

    duration = 0.0
    if info and "format" in info and "duration" in info["format"]:
        try:
            duration = float(info["format"]["duration"])
        except (TypeError, ValueError):
            # Null or non-numeric duration (e.g. "N/A"): treat as unknown.
            duration = 0.0

    timestamps: List[float] = []
    # The chained comparison is False for NaN and infinity as well as for zero/negative values.
    if 0 < duration < float("inf") and sample_count > 0:
        actual_sample_count = min(sample_count, max(1, int(duration)))
        step = duration / (actual_sample_count + 1)
        timestamps = [step * (i + 1) for i in range(actual_sample_count)]

    if not timestamps:
        # Fallback for very short videos or if duration couldn't be determined.
        # Clamp so a negative count cannot turn into a "drop the last N" slice.
        timestamps = [0.5, 1.0, 2.0, 3.0, 4.0, 5.0][:max(0, sample_count)]
    return info, timestamps
def extract_frames_for_model_and_gallery(media_path: str, sample_count: int = 5, timeout_extract: int = 15, gallery_base_h: int = 1080, model_base_h: int = 1024, progress=None) -> Tuple[List[bytes], List[str]]:
    """Grab still frames from a video for model input and for the gallery.

    Each selected timestamp is extracted with ffmpeg into a scratch PNG, which is then
    converted twice to JPEG: once at *model_base_h* (kept in memory) and once at
    *gallery_base_h* (written to a tracked temporary file). Failures on a single frame
    are printed and skipped.

    Args:
        media_path: Local path of the video.
        sample_count: Requested number of frames; fewer may be produced.
        timeout_extract: Per-frame ffmpeg timeout in seconds.
        gallery_base_h: Height of the gallery JPEGs.
        model_base_h: Height of the model-input JPEGs.
        progress: Optional callable ``progress(fraction, desc=...)``.

    Returns:
        ``(list of JPEG bytes for the model, list of gallery JPEG file paths)``;
        both are empty when ffmpeg or the media file is missing.
    """
    frames_for_model: List[bytes] = []
    frame_paths_for_gallery: List[str] = []

    if not FFMPEG_BIN:
        print(f"Warning: FFMPEG not found. Cannot extract frames for {media_path}.")
        return frames_for_model, frame_paths_for_gallery
    if not os.path.exists(media_path):
        print(f"Warning: Media path does not exist: {media_path}. Cannot extract frames.")
        return frames_for_model, frame_paths_for_gallery

    if progress is not None:
        progress(0.05, desc="Preparing frame extraction...")

    _, timestamps = _get_video_info_and_timestamps(media_path, sample_count)
    if not timestamps:
        print(f"Warning: No valid timestamps generated for {media_path}. Cannot extract frames.")
        return frames_for_model, frame_paths_for_gallery

    # Report against the frames that will actually be attempted; this can be fewer
    # than sample_count for short clips.
    total = len(timestamps)
    for i, t in enumerate(timestamps):
        if progress is not None:
            progress(0.1 + (i / total) * 0.2, desc=f"Extracting frame {i+1}/{total} at {t:.1f}s...")

        fd_raw, tmp_png_path = tempfile.mkstemp(suffix=f"_frame_{i}.png")
        os.close(fd_raw)
        cmd_extract = [
            FFMPEG_BIN, "-nostdin", "-y", "-ss", str(t), "-i", media_path,
            "-frames:v", "1", "-pix_fmt", "rgb24", tmp_png_path,
        ]
        try:
            subprocess.run(cmd_extract, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, timeout=timeout_extract)
            # mkstemp pre-creates an empty file, so a failed extraction shows up as size 0.
            if os.path.exists(tmp_png_path) and os.path.getsize(tmp_png_path) > 0:
                with open(tmp_png_path, "rb") as f:
                    raw_frame_bytes = f.read()

                jpeg_model_bytes = convert_to_jpeg_bytes(raw_frame_bytes, base_h=model_base_h)
                if jpeg_model_bytes:
                    frames_for_model.append(jpeg_model_bytes)
                else:
                    print(f"Warning: Failed to convert extracted frame {i+1} to JPEG for model input.")

                jpeg_gallery_bytes = convert_to_jpeg_bytes(raw_frame_bytes, base_h=gallery_base_h)
                if jpeg_gallery_bytes:
                    temp_jpeg_path = _temp_file(jpeg_gallery_bytes, suffix=f"_gallery_{i}.jpg")
                    if temp_jpeg_path:
                        frame_paths_for_gallery.append(temp_jpeg_path)
                else:
                    print(f"Warning: Failed to convert extracted frame {i+1} to JPEG for gallery.")
            else:
                print(f"Warning: Extracted frame {i+1} was empty or non-existent at {tmp_png_path}.")
        except Exception as e:
            print(f"Error processing frame {i+1} for model/gallery: {e}")
        finally:
            # The PNG is only an intermediate; always remove it.
            if os.path.exists(tmp_png_path):
                try:
                    os.remove(tmp_png_path)
                except Exception:
                    pass

    if progress is not None:
        progress(0.45, desc=f"Extracted {len(frames_for_model)} frames for analysis and gallery")
    return frames_for_model, frame_paths_for_gallery
def chat_complete(client: "Mistral", model: str, messages, timeout: int = 120, progress=None) -> str:
    """Send *messages* to the chat completion API and return the reply text.

    Retries with exponential backoff (1s, 2s, 4s, ...) on HTTP 429 and on
    ``RequestException``, up to five attempts in total. Failures are reported as a
    returned string starting with ``"Error:"`` rather than raised.

    Args:
        client: Object exposing ``chat.complete(...)``.
        model: Model identifier.
        messages: Chat messages in the format the client expects.
        timeout: Request timeout in seconds (sent as ``timeout_ms``).
        progress: Optional callable ``progress(fraction, desc=...)``.

    Returns:
        The stripped reply text; ``"Empty response from model: ..."`` when the reply
        has no choices or no content; or an ``"Error: ..."`` description.
    """
    max_retries = 5
    initial_delay = 1.0
    for attempt in range(max_retries):
        try:
            if progress is not None:
                progress(0.6 + 0.01 * attempt, desc=f"Sending request to model (attempt {attempt+1}/{max_retries})...")
            res = client.chat.complete(model=model, messages=messages, stream=False, timeout_ms=timeout * 1000)
            if progress is not None:
                progress(0.8, desc="Model responded, parsing...")

            choices = getattr(res, "choices", None) or []
            if not choices:
                return f"Empty response from model: {res}"
            msg = getattr(choices[0], "message", None)
            content = getattr(msg, "content", None)

            if isinstance(content, str):
                return content.strip()
            if content is None:
                # Previously this surfaced to the user as the literal text "None".
                return f"Empty response from model: {res}"
            if isinstance(content, (list, tuple)):
                # Content delivered as chunks: join the textual parts instead of
                # returning a Python repr. Assumes text chunks expose a "text"
                # attribute or key - TODO confirm against the SDK's chunk types.
                pieces = []
                for chunk in content:
                    text = chunk.get("text") if isinstance(chunk, dict) else getattr(chunk, "text", None)
                    if isinstance(text, str):
                        pieces.append(text)
                if pieces:
                    return "".join(pieces).strip()
            return str(content)
        except Exception as e:  # Catch all exceptions, including API errors raised by the client
            status_code = getattr(e, "status_code", None)
            message = getattr(e, "message", str(e))  # Default to str(e) if no .message attribute
            if status_code == 429 and attempt < max_retries - 1:
                delay = initial_delay * (2 ** attempt)
                print(f"Mistral API: Rate limit exceeded (429). Retrying in {delay:.2f}s...")
                time.sleep(delay)
            # NOTE(review): RequestException comes from `requests`; if the client uses a
            # different HTTP library this branch may never match - confirm.
            elif isinstance(e, RequestException) and attempt < max_retries - 1:
                delay = initial_delay * (2 ** attempt)
                print(f"Network/API request failed: {e}. Retrying in {delay:.2f}s...")
                time.sleep(delay)
            else:
                # Not retryable, or retries exhausted: report it.
                error_type = "Mistral API" if status_code else type(e).__name__
                return f"Error: {error_type} error occurred ({status_code if status_code else 'unknown'}): {message}"
    return "Error: Maximum retries reached for API call."
def upload_file_to_mistral(client: Mistral, path: str, purpose: str = "batch", timeout: int = 120, progress=None) -> str:
    """Upload a local file through ``client.files.upload`` and return its file ID.

    Retries with exponential backoff on HTTP 429 and on ``RequestException``, up to
    three attempts in total.

    Args:
        client: Mistral client instance.
        path: Local path of the file to upload.
        purpose: Purpose string forwarded to the API.
        timeout: Request timeout in seconds (sent as ``timeout_ms``).
        progress: Optional callable ``progress(fraction, desc=...)``.

    Returns:
        The ``id`` attribute of the upload response.

    Raises:
        RuntimeError: if the upload fails, the response carries no ID, or the
            retries are exhausted. The original exception is chained.
    """
    max_retries = 3
    initial_delay = 1.0
    for attempt in range(max_retries):
        try:
            if progress is not None:
                progress(0.5 + 0.01 * attempt, desc=f"Uploading file to model service (attempt {attempt+1}/{max_retries})...")
            # The SDK takes the file as a name/content mapping, not a bare path string.
            # The handle is reopened on each attempt so a retry starts from byte 0.
            with open(path, "rb") as fh:
                res = client.files.upload(
                    file={"file_name": os.path.basename(path), "content": fh},
                    purpose=purpose,
                    timeout_ms=timeout * 1000,
                )
            fid = getattr(res, "id", None)
            if not fid:
                raise RuntimeError(f"Mistral API upload response missing file ID: {res}")
            if progress is not None:
                progress(0.6, desc="Upload complete")
            return fid
        except Exception as e:  # Catch all exceptions, including API errors raised by the client
            status_code = getattr(e, "status_code", None)
            message = getattr(e, "message", str(e))
            if status_code == 429 and attempt < max_retries - 1:
                delay = initial_delay * (2 ** attempt)
                print(f"Mistral API: Upload rate limit exceeded (429). Retrying in {delay:.2f}s...")
                time.sleep(delay)
            elif isinstance(e, RequestException) and attempt < max_retries - 1:
                delay = initial_delay * (2 ** attempt)
                print(f"Upload network/API request failed: {e}. Retrying in {delay:.2f}s...")
                time.sleep(delay)
            else:
                error_type = "Mistral API" if status_code else type(e).__name__
                raise RuntimeError(f"{error_type} file upload failed with status {status_code}: {message}") from e
    raise RuntimeError("File upload failed: Maximum retries reached.")
def determine_media_type(src: str, progress=None) -> Tuple[bool, bool]:
    """Guess whether *src* is an image or a video.

    The file extension gives the first hint. For remote sources a HEAD request is
    made, and an ``image/*`` or ``video/*`` Content-Type overrides the extension.

    Returns:
        ``(is_image, is_video)``; both are False when nothing matched.
    """
    extension = ext_from_src(src)
    is_image = extension in IMAGE_EXTENSIONS
    is_video = (not is_image) and extension in VIDEO_EXTENSIONS

    if is_remote(src):
        head = safe_head(src)
        if head:
            content_type = (head.headers.get("content-type") or "").lower()
            if content_type.startswith("image/"):
                is_image, is_video = True, False
            elif content_type.startswith("video/"):
                is_image, is_video = False, True

    if progress is not None:
        progress(0.02, desc="Determined media type (initial hint)")
    return is_image, is_video
def analyze_image_structured(client: Mistral, img_bytes: bytes, prompt: str, progress=None) -> str:
    """Send one image plus *prompt* to ``PIXTRAL_MODEL`` and return the reply text.

    The image is re-encoded as a 1024-pixel-high JPEG and embedded as a data URL.
    Problems are reported as a returned ``"Error..."`` string, never raised.
    """
    try:
        if progress is not None:
            progress(0.3, desc="Preparing image for analysis...")

        jpeg = convert_to_jpeg_bytes(img_bytes, base_h=1024)
        if not jpeg:
            return "Error: Could not convert image for analysis."

        user_parts = [
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": b64_bytes(jpeg, mime="image/jpeg")},
        ]
        system_turn = {"role": "system", "content": SYSTEM_INSTRUCTION}
        user_turn = {"role": "user", "content": user_parts}
        return chat_complete(client, PIXTRAL_MODEL, [system_turn, user_turn], progress=progress)
    except UnidentifiedImageError:
        return "Error: provided file is not a valid image."
    except Exception as e:
        return f"Error analyzing image: {e}"
def analyze_video_cohesive(client: Mistral, video_path: str, prompt: str, progress=None) -> Tuple[str, List[str]]:
    """Analyze a video, falling back to sampled frames when the direct route fails.

    The video is first uploaded and sent to ``VIDEO_MODEL``. If the upload raises, or
    the model call comes back as an error/empty-response string, six frames are
    extracted and sent to ``PIXTRAL_MODEL`` instead. Gallery frames are extracted in
    both cases.

    Args:
        client: Mistral client instance.
        video_path: Local path of the video.
        prompt: User prompt appended to the instruction text.
        progress: Optional callable ``progress(fraction, desc=...)``.

    Returns:
        ``(analysis text or error message, list of gallery frame paths)``.
    """
    gallery_frame_paths: List[str] = []
    if not FFMPEG_BIN:
        return "Error: FFmpeg is not found in your system PATH. Video analysis and preview are unavailable.", []

    try:
        if progress is not None:
            progress(0.3, desc="Uploading video for full analysis...")
        file_id = upload_file_to_mistral(client, video_path, purpose="batch", progress=progress)
        messages = [
            {"role": "system", "content": SYSTEM_INSTRUCTION},
            {"role": "user", "content": [
                {"type": "video", "id": file_id},
                {"type": "text", "text": f"Instruction: Analyze the entire video and produce a single cohesive narrative describing consistent observations.\n\n{prompt}"},
            ]},
        ]
        result = chat_complete(client, VIDEO_MODEL, messages, progress=progress)
        # chat_complete reports failures as text instead of raising; turn them into an
        # exception here so the frame-based fallback below actually runs.
        if result.startswith(("Error:", "Empty response from model")):
            raise RuntimeError(result)

        # Always extract frames for gallery, even if full analysis worked
        _, gallery_frame_paths = extract_frames_for_model_and_gallery(
            video_path, sample_count=6, gallery_base_h=1080, model_base_h=1024, progress=progress
        )
        return result, gallery_frame_paths
    except Exception as e:
        print(f"Warning: Video upload/full analysis failed ({type(e).__name__}: {e}). Extracting frames as fallback...")
        if progress is not None:
            progress(0.35, desc=f"Video upload failed ({type(e).__name__}). Extracting frames as fallback...")

        frames_for_model_bytes, gallery_frame_paths = extract_frames_for_model_and_gallery(
            video_path, sample_count=6, gallery_base_h=1080, model_base_h=1024, progress=progress
        )
        if not frames_for_model_bytes:
            return f"Error: could not upload video and no frames could be extracted for fallback. ({type(e).__name__}: {e})", []

        frame_total = len(frames_for_model_bytes)
        image_entries = []
        for i, fb in enumerate(frames_for_model_bytes, start=1):
            if progress is not None:
                progress(0.4 + (i / frame_total) * 0.2, desc=f"Adding frame {i}/{frame_total} to model input...")
            image_entries.append(
                {
                    "type": "image_url",
                    "image_url": b64_bytes(fb, mime="image/jpeg"),
                    "meta": {"frame_index": i},
                }
            )

        content = [{"type": "text", "text": prompt + "\n\nPlease consolidate observations across these frames into a single cohesive narrative."}] + image_entries
        messages = [
            {"role": "system", "content": SYSTEM_INSTRUCTION},
            {"role": "user", "content": content},
        ]
        result = chat_complete(client, PIXTRAL_MODEL, messages, progress=progress)
        return result, gallery_frame_paths
| # --- FFmpeg Helpers for Preview --- | |
def _convert_video_for_preview_if_needed(path: str) -> str:
    """Return a path to a browser-friendly MP4 version of the video at *path*.

    A ``.mp4``/``.m4v`` file whose video codec is in the accepted list and whose audio
    (if any) is AAC is returned unchanged. Anything else is transcoded with ffmpeg to
    H.264 (plus AAC when an audio stream was detected) into a tracked temporary file.

    Returns:
        The converted file's path, or the original *path* when ffmpeg is missing, the
        source does not exist, no conversion is needed, or the conversion fails.
    """
    if not FFMPEG_BIN or not os.path.exists(path):
        return path

    # Probe once and reuse the result for both the fast-path check and the audio decision.
    info = _ffprobe_streams(path)
    streams = info.get("streams", []) if info else []
    video_streams = [s for s in streams if s.get("codec_type") == "video"]
    audio_streams = [s for s in streams if s.get("codec_type") == "audio"]

    # Check if it's already a web-friendly MP4
    if info and path.lower().endswith((".mp4", ".m4v")):
        # NOTE(review): ffprobe typically reports H.265 as "hevc", which is not in this
        # list, so such files are transcoded - confirm whether that is intended.
        has_accepted_video = any(s.get("codec_name") in ("h264", "h265", "avc1") for s in video_streams)
        has_aac_audio = any(s.get("codec_name") == "aac" for s in audio_streams)
        if has_accepted_video and (not audio_streams or has_aac_audio):  # If no audio, still good.
            return path

    # Reserve the output file directly: _temp_file() returns "" for empty data, so it
    # cannot be used to create an empty placeholder.
    try:
        fd, out_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)
    except OSError:
        print(f"Error: Could not create temporary file for video conversion from {path}.")
        return path
    _temp_files_to_delete.add(out_path)

    audio_codec_args = ["-c:a", "aac", "-b:a", "128k"] if audio_streams else []
    # Output options must come before the output path; anything after it is ignored by ffmpeg.
    cmd = [
        FFMPEG_BIN, "-nostdin", "-y", "-i", path,
        "-c:v", "libx264", "-preset", "veryfast", "-crf", "28",
        *audio_codec_args,
        "-map_metadata", "-1",
        "-movflags", "+faststart",
        out_path,
    ]
    try:
        proc = subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, timeout=60)
        # A non-zero exit can leave a partial file behind, so check both.
        if proc.returncode == 0 and os.path.exists(out_path) and os.path.getsize(out_path) > 0:
            return out_path
        print(f"Warning: FFMPEG conversion to {out_path} failed (exit code {proc.returncode}) or resulted in an empty file. Using original path.")
    except Exception as e:
        print(f"Error converting video for preview: {e}")

    # Conversion did not produce a usable file: drop the placeholder and keep the original.
    _temp_files_to_delete.discard(out_path)
    try:
        os.remove(out_path)
    except Exception:
        pass
    return path
| # --- Preview Generation Logic --- | |
def _get_playable_preview_path_from_raw(src_url: str, raw_bytes: bytes, is_image_hint: bool, is_video_hint: bool) -> str:
    """Create a preview file (JPEG for images, MP4 for videos) from downloaded bytes.

    Byte inspection with Pillow takes priority: data that verifies as an image is always
    treated as one. Otherwise the video hint is honoured, then the image hint.

    Returns:
        Path of the preview file, or ``""`` when no preview could be produced.
    """
    if not raw_bytes:
        print(f"Error: No raw bytes provided for preview generation of {src_url}.")
        return ""

    verified_as_image = False
    try:
        probe = Image.open(BytesIO(raw_bytes))
        probe.verify()  # Verify if it's a valid image
        verified_as_image = True
        probe.close()  # Close to release file handle
    except Exception:
        pass

    # Image route: confirmed by Pillow, or hinted as an image with no competing video hint.
    if verified_as_image or (is_image_hint and not is_video_hint):
        jpeg_bytes = convert_to_jpeg_bytes(raw_bytes, base_h=1024)
        if not jpeg_bytes:
            return ""
        return _temp_file(jpeg_bytes, suffix=".jpg")

    if is_video_hint:
        raw_video_path = _temp_file(raw_bytes, suffix=ext_from_src(src_url) or ".mp4")
        if not raw_video_path:
            print(f"Error: Failed to create temporary raw video file for {src_url}.")
            return ""
        return _convert_video_for_preview_if_needed(raw_video_path)

    print(f"Error: No playable preview path generated for {src_url} based on hints and byte inspection.")
    return ""
| # --- Gradio Interface Logic --- | |
# Stylesheet handed to gr.Blocks: centres preview media and styles the status footer.
GRADIO_CSS = "\n".join([
    "",
    ".preview_media img, .preview_media video {",
    "max-width: 100%;",
    "height: auto;",
    "border-radius: 6px;",
    "margin: 0 auto; /* Center image/video */",
    "display: block; /* Ensure margin auto works */",
    "}",
    ".status_footer {",
    "opacity: 0.7;",
    "font-size: 0.8em;",
    "text-align: right;",
    "margin-top: 20px;",
    "}",
    "",
])
def _get_button_label_for_status(status: str) -> str:
    """Map a processing status to the submit button's label; unknown statuses give "Submit"."""
    if status == "busy":
        return "Processing…"
    if status == "done":
        return "Done!"
    if status == "error":
        return "Retry"
    # "idle" and anything unrecognised share the default label.
    return "Submit"
| def create_demo(): | |
| """Creates the Gradio interface for Flux Multimodal analysis.""" | |
| ffmpeg_status_message = "" | |
| if not FFMPEG_BIN: | |
| ffmpeg_status_message = "🔴 FFmpeg not found! Video analysis and preview will be limited/unavailable." | |
| else: | |
| ffmpeg_status_message = "🟢 FFmpeg found. Video features enabled." | |
| with gr.Blocks(title="Flux Multimodal", css=GRADIO_CSS) as demo: | |
| gr.Markdown("# Flux Multimodal AI Assistant") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| preview_image = gr.Image(label="Preview Image", type="filepath", elem_classes="preview_media", visible=False) | |
| preview_video = gr.Video(label="Preview Video", elem_classes="preview_media", visible=False, format="mp4") | |
| # CHANGE: Set columns to 6 to display all 6 extracted frames without scrolling | |
| screenshot_gallery = gr.Gallery(label="Extracted Screenshots", columns=6, rows=1, height="auto", object_fit="contain", visible=False) | |
| # Initially hidden, will become visible when a preview status is set | |
| preview_status_text = gr.Textbox(label="Preview Status", interactive=False, lines=1, value="", visible=False) | |
| with gr.Column(scale=2): | |
| url_input = gr.Textbox(label="Image / Video URL", placeholder="https://...", lines=1) | |
| with gr.Accordion("Prompt (optional)", open=False): | |
| custom_prompt = gr.Textbox(label="Prompt", lines=4, value="") | |
| with gr.Accordion("Mistral API Key (optional)", open=False): | |
| api_key_input = gr.Textbox(label="Mistral API Key", type="password", max_lines=1) | |
| with gr.Row(): | |
| submit_btn = gr.Button("Submit") | |
| clear_btn = gr.Button("Clear") | |
| # Progress and Output below the buttons | |
| progress_markdown = gr.Markdown("Idle") | |
| output_markdown = gr.Markdown("Enter a URL to analyze an image or video, then click Submit.") | |
| status_state = gr.State("idle") | |
| main_preview_path_state = gr.State("") | |
| screenshot_paths_state = gr.State([]) | |
| raw_media_path_state = gr.State("") | |
| # Moved status messages to the bottom | |
| gr.Markdown(f"🟢 Mistral AI client found.<br>{ffmpeg_status_message}", elem_classes="status_footer") | |
| def clear_all_ui_and_files_handler(): | |
| """ | |
| Cleans up all tracked temporary files and resets all relevant UI components and states. | |
| """ | |
| for f_path in list(_temp_files_to_delete): | |
| if os.path.exists(f_path): | |
| try: | |
| os.remove(f_path) | |
| _temp_files_to_delete.discard(f_path) | |
| except Exception as e: | |
| print(f"Error during proactive cleanup of {f_path}: {e}") | |
| _temp_files_to_delete.clear() | |
| return "", \ | |
| gr.update(value=None, visible=False), \ | |
| gr.update(value=None, visible=False), \ | |
| gr.update(value=[], visible=False), \ | |
| "idle", \ | |
| "Idle", \ | |
| "Enter a URL to analyze an image or video, then click Submit.", \ | |
| "", \ | |
| [], \ | |
| gr.update(value="", visible=False), \ | |
| "" | |
| clear_btn.click( | |
| fn=clear_all_ui_and_files_handler, | |
| inputs=[], | |
| outputs=[ | |
| url_input, | |
| preview_image, | |
| preview_video, | |
| screenshot_gallery, | |
| status_state, | |
| progress_markdown, | |
| output_markdown, | |
| main_preview_path_state, | |
| screenshot_paths_state, | |
| preview_status_text, # Ensure this is updated to hidden | |
| raw_media_path_state | |
| ], | |
| queue=False | |
| ) | |
| def load_main_preview_and_setup_for_analysis( | |
| url: str, | |
| current_main_preview_path: str, | |
| current_raw_media_path: str, | |
| current_screenshot_paths: List[str], | |
| progress=gr.Progress() | |
| ): | |
| """ | |
| Loads media from URL, generates a preview, and sets up temporary files for analysis. | |
| Also handles cleanup of previously loaded media. | |
| """ | |
| if current_main_preview_path and os.path.exists(current_main_preview_path): | |
| _temp_files_to_delete.discard(current_main_preview_path) | |
| try: os.remove(current_main_preview_path) | |
| except Exception as e: print(f"Error cleaning up old temp file {current_main_preview_path}: {e}") | |
| if current_raw_media_path and os.path.exists(current_raw_media_path): | |
| _temp_files_to_delete.discard(current_raw_media_path) | |
| try: os.remove(current_raw_media_path) | |
| except Exception as e: print(f"Error cleaning up old temp file {current_raw_media_path}: {e}") | |
| for path in current_screenshot_paths: | |
| if path and os.path.exists(path): | |
| _temp_files_to_delete.discard(path) | |
| try: os.remove(path) | |
| except Exception as e: print(f"Error cleaning up old temp file {path}: {e}") | |
| img_update_clear = gr.update(value=None, visible=False) | |
| video_update_clear = gr.update(value=None, visible=False) | |
| gallery_update_clear = gr.update(value=[], visible=False) | |
| preview_status_clear = gr.update(value="", visible=False) # Keep hidden on clear | |
| main_path_clear = "" | |
| screenshot_paths_clear = [] | |
| raw_media_path_clear = "" | |
| progress_markdown_update_clear = gr.update(value="Idle") | |
| if not url: | |
| return img_update_clear, video_update_clear, gallery_update_clear, \ | |
| preview_status_clear, main_path_clear, raw_media_path_clear, \ | |
| screenshot_paths_clear, progress_markdown_update_clear | |
| temp_raw_path_for_analysis = "" | |
| try: | |
| progress(0.01, desc="Downloading media for preview and analysis...") | |
| raw_bytes_for_analysis = fetch_bytes(url, timeout=60, progress=progress) | |
| if not raw_bytes_for_analysis: | |
| return img_update_clear, video_update_clear, gallery_update_clear, \ | |
| gr.update(value="Preview load failed: No media bytes fetched.", visible=True), \ | |
| main_path_clear, raw_media_path_clear, screenshot_paths_clear, \ | |
| gr.update(value="Preview load failed (Error)") | |
| temp_raw_path_for_analysis = _temp_file(raw_bytes_for_analysis, suffix=ext_from_src(url) or ".tmp") | |
| if not temp_raw_path_for_analysis: | |
| return img_update_clear, video_update_clear, gallery_update_clear, \ | |
| gr.update(value="Preview load failed: Could not save raw media to temp file.", visible=True), \ | |
| main_path_clear, raw_media_path_clear, screenshot_paths_clear, \ | |
| gr.update(value="Preview load failed (Error)") | |
| progress(0.25, desc="Generating playable preview...") | |
| is_img_initial, is_vid_initial = determine_media_type(url) | |
| local_playable_path = _get_playable_preview_path_from_raw(url, raw_bytes_for_analysis, is_img_initial, is_vid_initial) | |
| if not local_playable_path: | |
| _temp_files_to_delete.discard(temp_raw_path_for_analysis) | |
| try: os.remove(temp_raw_path_for_analysis) | |
| except Exception as e: print(f"Error during cleanup of raw temp file {temp_raw_path_for_analysis}: {e}") | |
| return img_update_clear, video_update_clear, gallery_update_clear, \ | |
| gr.update(value="Preview load failed: could not make content playable.", visible=True), \ | |
| main_path_clear, raw_media_path_clear, screenshot_paths_clear, \ | |
| gr.update(value="Preview load failed (Error)") | |
| ext = ext_from_src(local_playable_path) | |
| is_img_preview = ext in IMAGE_EXTENSIONS | |
| is_vid_preview = ext in VIDEO_EXTENSIONS | |
| if is_img_preview: | |
| return gr.update(value=local_playable_path, visible=True), gr.update(value=None, visible=False), \ | |
| gallery_update_clear, gr.update(value="Image preview loaded.", visible=True), \ | |
| local_playable_path, temp_raw_path_for_analysis, screenshot_paths_clear, \ | |
| gr.update(value="Preview ready") | |
| elif is_vid_preview: | |
| return gr.update(value=None, visible=False), gr.update(value=local_playable_path, visible=True), \ | |
| gallery_update_clear, gr.update(value="Video preview loaded.", visible=True), \ | |
| local_playable_path, temp_raw_path_for_analysis, screenshot_paths_clear, \ | |
| gr.update(value="Preview ready") | |
| else: | |
| _temp_files_to_delete.discard(local_playable_path) | |
| try: os.remove(local_playable_path) | |
| except Exception as e: print(f"Error during cleanup of unplayable temp file {local_playable_path}: {e}") | |
| _temp_files_to_delete.discard(temp_raw_path_for_analysis) | |
| try: os.remove(temp_raw_path_for_analysis) | |
| except Exception as e: print(f"Error during cleanup of raw temp file {temp_raw_path_for_analysis}: {e}") | |
| return img_update_clear, video_update_clear, gallery_update_clear, \ | |
| gr.update(value="Preview load failed: unknown playable format.", visible=True), \ | |
| main_path_clear, raw_media_path_clear, screenshot_paths_clear, \ | |
| gr.update(value="Preview load failed (Error)") | |
| except Exception as e: | |
| if os.path.exists(temp_raw_path_for_analysis): | |
| _temp_files_to_delete.discard(temp_raw_path_for_analysis) | |
| try: os.remove(temp_raw_path_for_analysis) | |
| except Exception as ex: print(f"Error during cleanup of raw temp file {temp_raw_path_for_analysis} on error: {ex}") | |
| return img_update_clear, video_update_clear, gallery_update_clear, \ | |
| gr.update(value=f"Preview load failed: {type(e).__name__}: {e}", visible=True), \ | |
| main_path_clear, raw_media_path_clear, screenshot_paths_clear, \ | |
| gr.update(value="Preview load failed (Error)") | |
# Reload the preview whenever the URL field changes. The current preview, raw
# media and screenshot paths are passed in alongside the URL, and the loader
# returns replacement values for those three states together with the visible
# preview components, the preview status text and the progress text.
url_input.change(
    fn=load_main_preview_and_setup_for_analysis,
    inputs=[
        url_input,
        main_preview_path_state,
        raw_media_path_state,
        screenshot_paths_state,
    ],
    outputs=[
        preview_image,
        preview_video,
        screenshot_gallery,
        preview_status_text,
        main_preview_path_state,
        raw_media_path_state,
        screenshot_paths_state,
        progress_markdown,
    ],
)
def worker(url: str, prompt: str, key: str, raw_media_path: str, progress=gr.Progress()):
    """Analyze the locally stored raw media file with the Mistral models.

    The file is first probed with Pillow. If it verifies as an image it is
    analyzed as an image; otherwise it is treated as a video only when its
    extension is one of ``VIDEO_EXTENSIONS``.

    Args:
        url: Source URL from the UI. Part of the click-handler signature; this
            function works on ``raw_media_path`` and does not read ``url``.
        prompt: Prompt text forwarded to the analysis helper.
        key: API key passed to ``get_client``.
        raw_media_path: Path of the raw media file to analyze.
        progress: Gradio progress tracker used to report stage updates.

    Returns:
        A 4-tuple ``(status, result_markdown, screenshot_paths, preview_update)``.
        ``status`` is ``"done"`` or ``"error"``. ``screenshot_paths`` is the list
        returned by the video analysis (empty for images and on errors).
        ``preview_update`` is always an empty ``gr.update()``, intended to leave
        the main preview path state unchanged.

    No exception escapes: ``MistralClientError`` and any other ``Exception`` are
    converted into an ``"error"`` result tuple.
    """

    def _fail(message: str):
        # Every error path returns the same shape; only the message differs.
        return "error", message, [], gr.update()

    generated_screenshot_paths: List[str] = []
    result_text = ""
    try:
        if not raw_media_path or not os.path.exists(raw_media_path):
            return _fail("**Error:** No raw media file available for analysis. Please load a URL first.")
        if not FFMPEG_BIN and ext_from_src(raw_media_path) in VIDEO_EXTENSIONS:
            return _fail("**Error:** FFmpeg is not found in your system PATH. Video analysis is unavailable. Please install FFmpeg.")
        # Check emptiness via the file size rather than reading the whole file:
        # a video is passed on by path, so its bytes are never needed here.
        if os.path.getsize(raw_media_path) == 0:
            return _fail("**Error:** Raw media file is empty for analysis.")

        progress(0.01, desc="Starting media analysis...")

        # Probe the content with Pillow straight from disk; the context manager
        # guarantees the file handle is released whether or not verify() raises.
        is_image = False
        try:
            with Image.open(raw_media_path) as probe:
                probe.verify()
            is_image = True
        except UnidentifiedImageError:
            pass  # Not an image; fall through to the extension check below.
        except Exception as e:
            print(f"Warning: PIL error during image verification for raw analysis media ({raw_media_path}): {e}. Checking for video extension.")

        # Only files that failed image verification are considered as video.
        is_video = (not is_image) and ext_from_src(raw_media_path) in VIDEO_EXTENSIONS

        # The client is created before the media-type branch so that a
        # client/key problem is reported ahead of an "unknown media" error.
        client = get_client(key)
        if is_video:
            progress(0.25, desc="Running full-video analysis")
            result_text, generated_screenshot_paths = analyze_video_cohesive(client, raw_media_path, prompt, progress=progress)
        elif is_image:
            progress(0.20, desc="Running image analysis")
            # Image analysis takes raw bytes, so load them only on this branch.
            with open(raw_media_path, "rb") as f:
                image_bytes = f.read()
            result_text = analyze_image_structured(client, image_bytes, prompt, progress=progress)
        else:
            return _fail("Error: Could not definitively determine media type for analysis after byte inspection and extension check. Please check the URL/file content.")

        # Analysis helpers signal failure with text that starts with "error".
        looks_like_error = isinstance(result_text, str) and result_text.lower().startswith("error")
        status = "error" if looks_like_error else "done"
        return status, result_text, generated_screenshot_paths, gr.update()
    except MistralClientError as e:
        return _fail(f"**Mistral API Key Error:** {e.message}")
    except Exception as exc:
        # Last-resort boundary so the UI always receives a well-formed result.
        return _fail(f"**Unexpected worker error:** {type(exc).__name__}: {exc}")
# Start the analysis when the submit button is clicked. The worker's four
# return values are written to the status state, the output markdown, the
# screenshot paths state and the main preview path state, in that order.
submit_btn.click(
    fn=worker,
    inputs=[
        url_input,
        custom_prompt,
        api_key_input,
        raw_media_path_state,
    ],
    outputs=[
        status_state,
        output_markdown,
        screenshot_paths_state,
        main_preview_path_state,
    ],
    show_progress="full",
    show_progress_on=progress_markdown,
)

# Refresh the submit button from the status value whenever the status changes.
status_state.change(
    fn=_get_button_label_for_status,
    inputs=[status_state],
    outputs=[submit_btn],
    queue=False,
)
| def _status_to_progress_text(s): | |
| """Converts internal status to user-friendly progress text.""" | |
| return {"idle": "Idle", "busy": "Processing…", "done": "Completed", "error": "Error — see output"}.get(s, s) | |
# Mirror every status change into the progress markdown as readable text.
status_state.change(
    fn=_status_to_progress_text,
    inputs=[status_state],
    outputs=[progress_markdown],
    queue=False,
)
def _update_preview_components(current_main_preview_path: str, current_screenshot_paths: List[str]):
    """Build updates for the image preview, video preview and screenshot gallery.

    Args:
        current_main_preview_path: Path of the main preview file; may be empty.
        current_screenshot_paths: Screenshot paths to show in the gallery.

    Returns:
        A 3-tuple ``(image_update, video_update, gallery_update)``. At most one
        of the image/video components is made visible, chosen by the extension
        of the preview path; the gallery is visible only when there are
        screenshots.
    """
    image_target = None
    video_target = None

    if current_main_preview_path:
        suffix = ext_from_src(current_main_preview_path)
        if suffix in IMAGE_EXTENSIONS:
            image_target = current_main_preview_path
        elif suffix in VIDEO_EXTENSIONS:
            video_target = current_main_preview_path
        else:
            # Neither an image nor a video extension: keep both previews hidden.
            print(f"Warning: Unknown media type for main preview path: {current_main_preview_path}")

    image_component = gr.update(value=image_target, visible=image_target is not None)
    video_component = gr.update(value=video_target, visible=video_target is not None)
    gallery_component = gr.update(
        value=current_screenshot_paths,
        visible=bool(current_screenshot_paths),
    )
    return image_component, video_component, gallery_component
# A change to either the main preview path or the screenshot list triggers the
# same refresh of the image, video and gallery components, so the identical
# listener is registered once per state (main preview first, then screenshots).
for _preview_source_state in (main_preview_path_state, screenshot_paths_state):
    _preview_source_state.change(
        fn=_update_preview_components,
        inputs=[main_preview_path_state, screenshot_paths_state],
        outputs=[preview_image, preview_video, screenshot_gallery],
        queue=False,
    )
| return demo | |
if __name__ == "__main__":
    # Build the UI and serve it on all interfaces at port 7860, with no public
    # share link and at most 8 worker threads.
    app = create_demo()
    app.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        max_threads=8,
    )