Spaces:
Running on Zero
Running on Zero
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| import spaces # MUST BE THE ABSOLUTE FIRST IMPORT FOR ZEROGPU EMULATION | |
| import gradio as gr | |
| from gradio import Server | |
| from gradio.data_classes import FileData | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| import cv2 | |
| import numpy as np | |
| import os | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| import tempfile | |
| import re | |
| import time | |
| import base64 | |
| import gc | |
| import io | |
| import json | |
| import uuid | |
| from pathlib import Path | |
| from typing import Any | |
| import torch | |
| from PIL import Image, ImageDraw, ImageFont | |
| from transformers import AutoProcessor, AutoModel, AutoTokenizer | |
# Absolute path of the bundled label font, resolved relative to this file's
# directory (assets/LXGWWenKai-Bold.ttf).
_FONT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "LXGWWenKai-Bold.ttf")
# Retrieve optional HF Token from typical env variables: the first of HF_TOKEN,
# HUGGINGFACE_HUB_TOKEN, MODEL_HF_TOKEN that is set to a non-empty value wins.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("MODEL_HF_TOKEN")
def _load_font(size=20):
    """Load the label font (LXGW WenKai, which covers Chinese glyphs).

    The font file has to be placed in the ``assets/`` directory beforehand.
    Candidates are tried in order: the bundled font (only when the file
    exists), then ``DejaVuSans-Bold.ttf``; if neither can be loaded, Pillow's
    built-in default font is returned.

    Args:
        size: Requested font size passed to ``ImageFont.truetype``.

    Returns:
        A Pillow font object.
    """
    candidates = []
    if os.path.exists(_FONT_PATH):
        candidates.append(_FONT_PATH)
    candidates.append("DejaVuSans-Bold.ttf")

    for candidate in candidates:
        try:
            return ImageFont.truetype(candidate, size)
        except Exception:
            # Best effort: an unreadable font just means "try the next one".
            continue
    return ImageFont.load_default()
# ============================================================
# Colors / parsing / drawing
# ============================================================
def get_color_for_label(label):
    """Pick a stable RGB colour for a label.

    The colour is chosen from a fixed six-entry palette using the sum of the
    label's character code points, so the same label always gets the same
    colour (different labels may share one).

    Args:
        label: Label text; an empty string maps to the first palette entry.

    Returns:
        An ``(r, g, b)`` tuple of ints.
    """
    palette = (
        (8, 145, 178),
        (220, 38, 38),
        (22, 163, 74),
        (37, 99, 235),
        (217, 119, 6),
        (147, 51, 234),
    )
    checksum = 0
    for ch in label:
        checksum += ord(ch)
    return palette[checksum % len(palette)]
def parse_mixed_results(text, category_str=""):
    """Extract box / point detections from model output text.

    Two strategies are used:

    1. Structured pass: walk ``<ref>…</ref>`` and ``<box>…</box>`` tokens in
       order (case-insensitive, may span lines). A non-empty ``<ref>`` sets the
       label used by the boxes that follow it. If at least one ``<box>``
       carried coordinates, the results of this pass are returned.
    2. Fallback pass: split on ``<box>…</box>`` and label each box with the
       first expected category found in the text preceding it.

    Coordinates are the ``<number>`` items inside a box: four numbers make a
    ``"box"`` entry, two make a ``"point"`` entry, any other count is dropped.

    Args:
        text: Raw model output.
        category_str: Expected categories joined with ``"</c>"``; the first one
            (lower-cased) is the default label, ``"object"`` if there is none.

    Returns:
        A list of ``{"type", "coords", "label"}`` dicts with float coords.
    """
    coord_pattern = r"<\s*([0-9]+(?:\.[0-9]+)?)\s*>"
    known_cats = [part.strip().lower() for part in category_str.split("</c>") if part.strip()]
    default_label = known_cats[0] if known_cats else "object"

    def _to_entry(values, label):
        # Only 4-number (box) and 2-number (point) groups are meaningful.
        if len(values) == 4:
            return {"type": "box", "coords": values, "label": label}
        if len(values) == 2:
            return {"type": "point", "coords": values, "label": label}
        return None

    results = []
    active_label = None
    saw_box_with_coords = False

    # --- Structured pass -------------------------------------------------
    token_iter = re.finditer(
        r"(<ref>.*?</ref>)|(<box>.*?</box>)", text, flags=re.IGNORECASE | re.DOTALL
    )
    for match in token_iter:
        ref_token = match.group(1)
        if ref_token is not None:
            name = re.sub(r"</?ref>", "", ref_token, flags=re.IGNORECASE).strip()
            if name:
                active_label = name
            continue

        inner = re.sub(r"</?box>", "", match.group(2), flags=re.IGNORECASE)
        values = [float(v) for v in re.findall(coord_pattern, inner)]
        if not values:
            continue
        saw_box_with_coords = True
        entry = _to_entry(values, default_label if active_label is None else active_label)
        if entry is not None:
            results.append(entry)

    if saw_box_with_coords:
        return results

    # --- Fallback pass ---------------------------------------------------
    # re.split with one capture group alternates [text, box body, text, ...].
    pieces = re.split(r"<box>(.*?)</box>", text)
    for before, inner in zip(pieces[0::2], pieces[1::2]):
        before_lower = before.lower()
        label = next((cat for cat in known_cats if cat in before_lower), default_label)
        values = [float(v) for v in re.findall(coord_pattern, inner)]
        entry = _to_entry(values, label)
        if entry is not None:
            results.append(entry)
    return results
def resize_image_short_side(image, short_side_size):
    """Resize an image so its shorter side equals ``short_side_size``.

    The longer side is scaled by the same factor and truncated to an int.
    When width and height are equal, width is treated as the short side.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize`` method (a PIL image in this file).
        short_side_size: Target length of the shorter side.

    Returns:
        ``(resized_image, scale_factor)`` where ``scale_factor`` is the ratio
        applied to the original dimensions.
    """
    width, height = image.size
    if width <= height:
        scale_factor = short_side_size / width
        target_size = (short_side_size, int(height * scale_factor))
    else:
        scale_factor = short_side_size / height
        target_size = (int(width * scale_factor), short_side_size)
    return image.resize(target_size, Image.BILINEAR), scale_factor
def draw_on_frame(frame_bgr, results, draw_label=True):
    """Draw detection boxes / points (and optionally labels) on a BGR frame.

    Each entry of ``results`` is a dict with ``"coords"`` and, optionally,
    ``"label"`` (default ``"object"``) and ``"type"``:

    * ``type == "point"``: ``coords`` are ``(x, y)`` scaled by image size / 1000.
    * entries containing the key ``"is_pixel"``: ``coords`` are
      ``(x, y, width, height)`` already in pixels.
    * anything else: ``coords`` are ``(x1, y1, x2, y2)`` scaled by
      image size / 1000; entries with fewer than four numbers are skipped.

    All shapes are painted first and all labels afterwards, so a label is
    never covered by a later shape.

    Args:
        frame_bgr: Image array in BGR channel order.
        results: Iterable of detection dicts as described above.
        draw_label: Whether to render label tags next to the shapes.

    Returns:
        A new BGR array with the annotations composited on top.
    """
    rgb_image = Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
    canvas = rgb_image.convert("RGBA")
    # Shapes go on a transparent layer so the box fill can be translucent.
    overlay = Image.new("RGBA", canvas.size, (255, 255, 255, 0))
    painter = ImageDraw.Draw(overlay)
    font = _load_font(20)
    width, height = rgb_image.size

    # Pass 1: turn every result into clamped pixel-space geometry.
    shapes = []
    for res in results:
        label = res.get("label", "object")
        color = get_color_for_label(label)
        is_point = res.get("type") == "point"
        coords = res["coords"]

        if is_point:
            px = max(0, min(width, coords[0] * width / 1000))
            py = max(0, min(height, coords[1] * height / 1000))
            shapes.append(("point", label, color, px, py))
            continue

        if "is_pixel" in res:
            left, top, box_w, box_h = coords
            right, bottom = left + box_w, top + box_h
        elif len(coords) >= 4:
            left = coords[0] * width / 1000
            top = coords[1] * height / 1000
            right = coords[2] * width / 1000
            bottom = coords[3] * height / 1000
        else:
            continue

        # Clamp to the image, then make sure the corners are ordered.
        left, top = max(0, left), max(0, top)
        right, bottom = min(width, right), min(height, bottom)
        left, right = min(left, right), max(left, right)
        top, bottom = min(top, bottom), max(top, bottom)
        shapes.append(("box", label, color, left, top, right, bottom))

    # Pass 2: paint the shapes.
    for shape in shapes:
        kind, color = shape[0], shape[2]
        if kind == "box":
            x1, y1, x2, y2 = shape[3:]
            painter.rectangle([x1, y1, x2, y2], fill=color + (65,), outline=color, width=4)
        elif kind == "point":
            cx, cy = shape[3:]
            radius = 10
            painter.ellipse(
                [cx - radius, cy - radius, cx + radius, cy + radius],
                fill=color, outline="white", width=2,
            )

    # Pass 3: paint the label tags on top of every shape.
    if draw_label:
        def _text_extent(text):
            bbox = painter.textbbox((0, 0), text, font=font)
            return bbox[2] - bbox[0], bbox[3] - bbox[1]

        for shape in shapes:
            kind, label, color = shape[:3]
            if kind == "box":
                x1, y1, _, y2 = shape[3:]
                if not label:
                    continue
                text_w, text_h = _text_extent(label)
                pad_x, pad_y = 7, 4
                tag_w = text_w + pad_x * 2
                tag_h = text_h + pad_y * 2
                # Prefer a tag above the box; fall back to below when it
                # would start above the top edge of the image.
                tag_y = y1 - tag_h - 2
                if tag_y < 0:
                    tag_y = y2 + 2
                painter.rectangle([x1, tag_y, x1 + tag_w, tag_y + tag_h], fill=color)
                painter.text((x1 + pad_x, tag_y + pad_y), label, fill="white", font=font)
            elif kind == "point":
                cx, cy = shape[3:]
                if not label:
                    continue
                text_w, text_h = _text_extent(label)
                tx, ty = cx + 14, cy - text_h // 2
                painter.rectangle([tx - 2, ty - 2, tx + text_w + 6, ty + text_h + 4], fill=color)
                painter.text((tx + 2, ty), label, fill="white", font=font)

    merged = Image.alpha_composite(canvas, overlay).convert("RGB")
    return cv2.cvtColor(np.array(merged), cv2.COLOR_RGB2BGR)
# ============================================================
# Model
# ============================================================
class EagleWorker:
    """Thin wrapper around a tokenizer, processor and model loaded from one path.

    All three are loaded through the ``transformers`` ``Auto*`` classes with
    ``trust_remote_code=True``; ``generate`` relies on processor methods
    (``py_apply_chat_template``, ``process_vision_info``) and on extra
    ``model.generate`` keyword arguments that presumably come from that remote
    code — verify against the model repository.
    """

    def __init__(self, model_path, device="cuda", generation_mode: str = "hybrid"):
        """Load tokenizer, processor and model and move the model to ``device``.

        Args:
            model_path: Model identifier or local path passed to
                ``from_pretrained``.
            device: Device the model and the inputs are moved to.
            generation_mode: Default mode used by ``generate`` when the caller
                does not pass one.
        """
        self.model_id = model_path
        self.device = device
        self.dtype = torch.bfloat16
        self.generation_mode = generation_mode

        # An empty token is normalised to None (i.e. "no token").
        hub_token = HF_TOKEN or None
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            token=hub_token,
        )
        self.processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
            token=hub_token,
        )
        self.model = AutoModel.from_pretrained(
            model_path,
            torch_dtype=self.dtype,
            _attn_implementation="sdpa",
            trust_remote_code=True,
            token=hub_token,
        ).to(device).eval()
        print("Model Loaded Successfully!")

    def build_messages(self, image, categories, question_override=None):
        """Build the single-turn chat message list for one image.

        Args:
            image: Image object placed in the message as-is.
            categories: Category strings; joined with ``"</c>"`` into the
                default localisation prompt.
            question_override: When not ``None``, used verbatim as the user
                text instead of the default prompt.

        Returns:
            A one-element list holding a ``user`` message with an image part
            and a text part.
        """
        if question_override is not None:
            user_text = question_override
        else:
            category_set_str = "</c>".join(categories)
            user_text = f"Locate all the instances that matches the following description: {category_set_str}."
        return [{"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": user_text},
        ]}]

    def generate(self, image, categories, generation_mode=None,
                 max_new_tokens=4096, temp=0.7, top_p=0.9, top_k=50,
                 question_override=None):
        """Run the model on one image and return its text output.

        Args:
            image: Image passed to ``build_messages``.
            categories: Categories passed to ``build_messages``.
            generation_mode: Mode forwarded to ``model.generate``; ``None``
                means "use the instance default".
            max_new_tokens: Generation length limit.
            temp: Sampling temperature.
            top_p: Nucleus sampling threshold.
            top_k: Accepted for interface compatibility but NOT forwarded to
                ``model.generate``. NOTE(review): confirm whether the model's
                ``generate`` supports ``top_k`` before wiring it through.
            question_override: Optional prompt replacing the default one.

        Returns:
            ``(output_text, token_sequence, out_info)``. When the model does
            not return a tuple of at least three items, the raw result is
            returned as ``output_text`` with an empty ``token_sequence`` and an
            empty ``out_info``.
        """
        # Resolve the mode once so the call and the post-processing below
        # always agree, even when the caller relies on the instance default.
        effective_mode = generation_mode if generation_mode is not None else self.generation_mode

        messages = self.build_messages(image, categories, question_override=question_override)
        text = self.processor.py_apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        images, videos = self.processor.process_vision_info(messages)
        inputs = self.processor(text=[text], images=images, videos=videos, return_tensors="pt").to(self.device)
        pixel_values = inputs["pixel_values"].to(self.dtype)
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        image_grid_hws = inputs.get("image_grid_hws", None)

        result = self.model.generate(
            pixel_values=pixel_values, input_ids=input_ids,
            attention_mask=attention_mask, image_grid_hws=image_grid_hws,
            tokenizer=self.tokenizer, max_new_tokens=max_new_tokens,
            use_cache=True,
            generation_mode=effective_mode,
            temperature=temp, do_sample=True, top_p=top_p,
            repetition_penalty=1.1, verbose=True,
        )

        token_sequence, out_info, output_text = [], "", ""
        if isinstance(result, tuple) and len(result) >= 3:
            # Only the first three items are used; slicing keeps tuples with
            # extra trailing items from breaking the unpacking.
            output_text, token_sequence, out_info = result[:3]
            # In "slow" mode the last token is re-tagged as "ar"; skip when
            # the sequence is empty so there is nothing to index.
            if effective_mode == "slow" and token_sequence:
                token_sequence[-1] = ("ar", token_sequence[-1][1])
        else:
            output_text = result
        return output_text, token_sequence, out_info
# ============================================================
# Post-processing
# ============================================================
def _postprocess_detections(detections, w, h):
    """Convert parsed detections into drawable, pixel-space entries.

    Box coordinates ``(x1, y1, x2, y2)`` are scaled by ``w / 1000`` and
    ``h / 1000``, truncated to ints and clamped to ``[0, w - 1]`` /
    ``[0, h - 1]``. Boxes whose clamped width or height is not positive are
    dropped. Surviving boxes are emitted as ``(x, y, width, height)`` with
    ``"is_pixel": True``. Point entries are passed through unchanged; entries
    of any other type are dropped.

    Args:
        detections: Iterable of dicts with ``"type"``, ``"coords"`` and
            ``"label"`` keys.
        w: Target width in pixels.
        h: Target height in pixels.

    Returns:
        A new list of valid detection dicts.
    """
    def _to_pixel(value, extent):
        return max(0, min(extent - 1, int(value * extent / 1000)))

    kept = []
    for det in detections:
        kind = det["type"]
        if kind == "point":
            kept.append(det)
            continue
        if kind != "box":
            continue

        coords = det["coords"]
        left, top = _to_pixel(coords[0], w), _to_pixel(coords[1], h)
        right, bottom = _to_pixel(coords[2], w), _to_pixel(coords[3], h)
        span_w, span_h = right - left, bottom - top
        if span_w > 0 and span_h > 0:
            kept.append({
                "type": "box",
                "coords": [left, top, span_w, span_h],
                "is_pixel": True,
                "label": det["label"],
            })
    return kept
def _parse_out_info_dict(out_info: str) -> dict:
    """Parse a ``key=value;key=value`` statistics string into a dict.

    A leading ``"Stastic info,"``-style prefix (matched by the regex below,
    case-tolerant on the first letters) is stripped first. Segments without an
    ``=`` are ignored; only the first ``=`` of a segment separates key from
    value. Keys and values are whitespace-trimmed strings.

    Args:
        out_info: Raw statistics text; empty or ``None`` yields ``{}``.

    Returns:
        Mapping of statistic name to its string value.
    """
    if not out_info:
        return {}

    body = re.sub(r"^[Ss]tast?ic\s*[Ii]nfo\s*,?\s*", "", out_info.strip())
    parsed = {}
    for chunk in body.split(";"):
        key, separator, value = chunk.strip().partition("=")
        if separator:
            parsed[key.strip()] = value.strip()
    return parsed
| def generate_raw_prompt(task_type, category): | |
| if not category: | |
| category = "objects" | |
| cats = "</c>".join(c.strip() for c in category.split(",") if c.strip()) | |
| if task_type == "Detection": | |
| return f"Locate all the instances that matches the following description: {cats}." | |
| elif task_type == "Grounding": | |
| return f"Locate all the instances that match the following description: {cats}." | |
| elif task_type == "OCR": | |
| return "Detect all the text in box format." | |
| elif task_type == "GUI": | |
| return f"Locate the region that matches the following description: {cats}." | |
| elif task_type == "Pointing": | |
| return f"Point to: {cats}." | |
| else: | |
| return f"Locate all the instances that matches the following description: {cats}." | |
# ============================================================
# Model initialization
# ============================================================
# Process-wide worker instance, created on first use by get_worker().
GLOBAL_WORKER = None


def get_worker():
    """Return the shared ``EagleWorker``, loading it on first use.

    The model path comes from the ``MODEL_PATH`` environment variable and
    defaults to ``"nvidia/LocateAnything-3B"``. If loading fails for any
    reason the error is printed and ``None`` is returned, which callers treat
    as "mock mode"; because nothing is cached in that case, the next call
    attempts the load again.

    Returns:
        The loaded worker, or ``None`` when loading failed.
    """
    global GLOBAL_WORKER
    if GLOBAL_WORKER is not None:
        return GLOBAL_WORKER

    try:
        model_path = os.environ.get("MODEL_PATH", "nvidia/LocateAnything-3B")
        print(f"Loading model inside @spaces.GPU context: {model_path}")
        GLOBAL_WORKER = EagleWorker(model_path)
    except Exception as e:
        print(f"Failed to load model: {e}. Will run in Mock Mode.")
        GLOBAL_WORKER = None
    return GLOBAL_WORKER
def _prepare_image_for_model(pil_img, short_size):
    """Return a copy of the image sized for model input.

    * ``short_size`` given and positive: the short side is resized to
      ``min(int(short_size), 1024)`` (this may enlarge a small image).
    * otherwise: the image is only shrunk, to a short side of 1024, when its
      short side exceeds 1024.

    Args:
        pil_img: Source image; it is copied, never modified.
        short_size: Requested short-side length, or ``None`` / non-positive
            for the automatic behaviour.

    Returns:
        The (possibly resized) image copy.
    """
    working = pil_img.copy()
    if short_size is not None and short_size > 0:
        target = min(int(short_size), 1024)
    elif min(working.size) > 1024:
        target = 1024
    else:
        return working

    resized, _ = resize_image_short_side(working, target)
    return resized
# ============================================================
# GPU time budget and inference guards (per mode)
# ============================================================
# All values are in seconds (they are compared against time.time() deltas).
# Not referenced in the visible part of this file.
GPU_HARD_LIMIT_IMAGE = 30
# Total wall-clock budget for one video request.
GPU_HARD_LIMIT_VIDEO = 240
# Time held back from the video budget for the work after inference
# (presumably drawing + encoding — confirm).
PHASE2_RESERVE = 55
# Extra slack subtracted from the video budget together with PHASE2_RESERVE.
SAFETY_MARGIN = 25
# Per-frame inference estimate used to decide how many frames fit the budget.
EST_SECONDS_PER_FRAME = 20
def run_image_gpu_api(
    image_path: str, category: str, model_mode: str, temp: float, top_p: float, top_k: int,
    short_size: int | None, question_override: str | None
):
    """Run detection on one image file and save an annotated copy.

    When no worker is available (``get_worker()`` returns ``None``) a fixed
    mock response is used instead of real inference.

    Args:
        image_path: Path of the input image.
        category: Comma-separated category names.
        model_mode: Generation mode forwarded to the worker.
        temp: Sampling temperature forwarded to the worker.
        top_p: Nucleus sampling threshold forwarded to the worker.
        top_k: Top-k value forwarded to the worker.
        short_size: Optional short-side size for the model input.
        question_override: Optional prompt replacing the default one.

    Returns:
        ``(out_img_path, stats, output_text, detections_summary)`` where
        ``out_img_path`` is a PNG in a fresh temporary directory, ``stats`` is
        the parsed statistics dict, ``output_text`` is the raw model text and
        ``detections_summary`` lists label / type / rounded coords.
    """
    # Image.open keeps the underlying file open until the image is closed;
    # convert() returns an independent in-memory image, so the file handle
    # can be released right away.
    with Image.open(image_path) as source:
        image_in = source.convert("RGB")

    categories_list = [c.strip() for c in category.split(",") if c.strip()]
    category_str = "</c>".join(categories_list)
    process_img = _prepare_image_for_model(image_in, short_size)

    worker = get_worker()
    if worker:
        output_text, token_sequence, out_info = worker.generate(
            process_img, categories_list, model_mode,
            temp=temp, top_p=top_p, top_k=top_k,
            question_override=question_override,
        )
    else:
        # Mock mode fallback
        output_text = "Mock detection: <ref>sweet</ref><box><240><480><620><940></box> and <ref>book</ref><box><50><120><400><380></box>"
        token_sequence = []
        out_info = "forward_step=1;num_tokens=18;num_boxes=2;tps=45;bps=15"

    # Detections are drawn on the original image, not the resized model input.
    detections = parse_mixed_results(output_text, category_str)
    frame_bgr = cv2.cvtColor(np.array(image_in), cv2.COLOR_RGB2BGR)
    out_img_bgr = draw_on_frame(frame_bgr, detections, draw_label=True)
    output_image = Image.fromarray(cv2.cvtColor(out_img_bgr, cv2.COLOR_BGR2RGB))

    # Save to temp file. The directory is intentionally not removed here: the
    # returned path must stay valid for the caller.
    temp_dir = tempfile.mkdtemp()
    out_img_path = os.path.join(temp_dir, "output.png")
    output_image.save(out_img_path)

    stats = _parse_out_info_dict(out_info)

    # Simplified summary lists
    detections_summary = [
        {
            "label": det.get("label", "object"),
            "type": det.get("type", "box"),
            "coords": [round(c, 2) for c in det.get("coords", [])],
        }
        for det in detections
    ]
    return out_img_path, stats, output_text, detections_summary
def _evenly_spaced_indices(total, count):
    """Return up to ``count`` indices spread evenly over ``range(total)``."""
    if total <= count:
        return list(range(total))
    if count <= 1:
        # A single sample cannot span both ends of the range; the general
        # formula below would divide by zero for count == 1.
        return [0] if count == 1 else []
    return [int(round(i * (total - 1) / (count - 1))) for i in range(count)]


def run_video_gpu_api(
    video_path: str, category: str, model_mode: str, temp: float, top_p: float, top_k: int,
    short_size: int | None, question_override: str | None, max_video_frames: int
):
    """Run detection on sampled frames of a video and write an annotated clip.

    Steps: decode all frames, sample up to ``max_video_frames`` of them evenly,
    shrink the sample if the estimated inference time does not fit the GPU time
    budget, run inference frame by frame (stopping early when the budget runs
    low), draw the detections, write an ``mp4v`` file and, if time allows,
    re-encode it with ffmpeg (falling back to the raw file on any failure).

    Args:
        video_path: Path of the input video.
        category: Comma-separated category names.
        model_mode: Generation mode forwarded to the worker.
        temp: Sampling temperature forwarded to the worker.
        top_p: Nucleus sampling threshold forwarded to the worker.
        top_k: Top-k value forwarded to the worker.
        short_size: Optional short-side size for the model input.
        question_override: Optional prompt replacing the default one.
        max_video_frames: Maximum number of frames to sample; a falsy value
            means 4, other values are clamped to at least 1.

    Returns:
        ``(out_video_path, stats, raw_text, detections_summary)``.

    Raises:
        ValueError: No frame could be read from the video.
        RuntimeError: The time budget was exhausted before any frame was
            processed.
    """
    import subprocess as _sp

    total_start = time.time()
    # Clamp to >= 1 so a negative request cannot produce an empty sample.
    max_frames = max(1, int(max_video_frames)) if max_video_frames else 4
    categories_list = [c.strip() for c in category.split(",") if c.strip()]
    category_str = "</c>".join(categories_list)

    # NOTE: every frame is decoded into memory before sampling, so memory use
    # grows with the length of the video.
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        vid_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        vid_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        all_frames = []
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            all_frames.append(frame)
    finally:
        # Release the capture even if decoding raises.
        cap.release()

    total = len(all_frames)
    if total == 0:
        raise ValueError("Failed to read any frames from the video.")

    # Sample frames
    sample_indices = _evenly_spaced_indices(total, max_frames)

    # Budget check: shrink the sample when the estimate does not fit.
    time_already_used = time.time() - total_start
    available_for_inference = GPU_HARD_LIMIT_VIDEO - time_already_used - PHASE2_RESERVE - SAFETY_MARGIN
    estimated_inference_time = len(sample_indices) * EST_SECONDS_PER_FRAME
    if estimated_inference_time > available_for_inference:
        max_feasible = max(1, int(available_for_inference // EST_SECONDS_PER_FRAME))
        sample_indices = _evenly_spaced_indices(total, max_feasible)

    sampled_frames = [all_frames[i] for i in sample_indices]
    n_sampled = len(sampled_frames)
    # Keep the output roughly as long as the source clip; 5 fps when the
    # source frame rate is unknown.
    out_fps = max(1.0, n_sampled / (total / fps)) if fps > 0 else 5.0
    del all_frames
    gc.collect()

    inference_results = []
    processed_count = 0
    early_stopped = False
    early_stop_reason = ""
    # Resolve the worker once per request (after the first budget check), so a
    # failed model load is not re-attempted for every frame.
    worker = None
    worker_resolved = False
    for frame in sampled_frames:
        elapsed_since_start = time.time() - total_start
        remaining_total = GPU_HARD_LIMIT_VIDEO - elapsed_since_start
        if remaining_total < PHASE2_RESERVE + SAFETY_MARGIN:
            early_stopped = True
            early_stop_reason = f"GPU time budget running out. Only {remaining_total:.0f}s left."
            break

        pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        process_img = _prepare_image_for_model(pil_img, short_size)

        if not worker_resolved:
            worker = get_worker()
            worker_resolved = True
        if worker:
            output_text, _, _ = worker.generate(
                process_img, categories_list, model_mode,
                temp=temp, top_p=top_p, top_k=top_k,
                question_override=question_override,
            )
        else:
            output_text = "Mock video detection: <ref>person</ref><box><100><150><800><900></box>"
        inference_results.append(output_text)
        processed_count += 1

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    if processed_count == 0:
        raise RuntimeError("GPU budget exceeded before processing any frames.")

    # mkstemp creates the files atomically (mktemp only returns a name and is
    # deprecated); the descriptors are closed because the writers below open
    # the paths themselves.
    raw_fd, tmp_raw = tempfile.mkstemp(suffix=".raw.mp4")
    os.close(raw_fd)
    out_fd, out_video_path = tempfile.mkstemp(suffix=".mp4")
    os.close(out_fd)

    out = cv2.VideoWriter(tmp_raw, cv2.VideoWriter_fourcc(*"mp4v"), out_fps, (vid_w, vid_h))
    detections_summary = []
    try:
        # zip stops at the shorter list, i.e. only processed frames are drawn.
        for frame_no, (frame, output_text) in enumerate(zip(sampled_frames, inference_results), start=1):
            detections = parse_mixed_results(output_text, category_str)
            valid_results = _postprocess_detections(detections, vid_w, vid_h)
            frame_to_draw = draw_on_frame(frame, valid_results, draw_label=True)
            out.write(frame_to_draw)
            for det in valid_results:
                detections_summary.append({
                    "frame": frame_no,
                    "label": det.get("label", "object"),
                    "type": det.get("type", "box"),
                    "coords": det.get("coords", []),
                })
    finally:
        out.release()

    # ffmpeg re-encode (H.264 / yuv420p / faststart), only when enough time
    # is left; otherwise the raw mp4v file is returned as-is.
    elapsed_now = time.time() - total_start
    remaining_now = GPU_HARD_LIMIT_VIDEO - elapsed_now
    if remaining_now > 15:
        try:
            ffmpeg_timeout = max(10, int(remaining_now - 5))
            _sp.run(
                ["ffmpeg", "-y", "-i", tmp_raw, "-c:v", "libx264",
                 "-preset", "ultrafast", "-crf", "23", "-pix_fmt", "yuv420p",
                 "-movflags", "+faststart", out_video_path],
                check=True, capture_output=True, timeout=ffmpeg_timeout,
            )
            os.remove(tmp_raw)
        except Exception:
            # Best effort: any ffmpeg problem (missing binary, timeout,
            # non-zero exit) falls back to the raw file.
            if os.path.exists(tmp_raw):
                os.replace(tmp_raw, out_video_path)
    else:
        os.replace(tmp_raw, out_video_path)

    total_time = time.time() - total_start
    stats = {
        "total_frames": total,
        "sampled_frames": n_sampled,
        "processed_frames": processed_count,
        "total_time_seconds": round(total_time, 2),
        "early_stopped": early_stopped,
        "early_stop_reason": early_stop_reason
    }
    return out_video_path, stats, "\n---\n".join(inference_results), detections_summary
| # ============================================================ | |
| # GRADIO SERVER APP | |
| # ============================================================ | |
# Application object; routes and mounts below are attached to it.
app = Server()
# Serve static assets folder (only mounted when the directory exists next to this file)
assets_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
if os.path.exists(assets_dir):
    app.mount("/assets", StaticFiles(directory=assets_dir), name="assets")
async def homepage():
    """Return the contents of ``index.html`` located next to this file.

    When the file does not exist, a small HTML error heading is returned
    instead.

    NOTE(review): no route decorator is visible on this function in this
    view — confirm how it is registered with ``app``.
    """
    html_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    if not os.path.exists(html_path):
        return HTMLResponse("<h1 style='color: #ef4444; font-family: Inter, sans-serif; text-align: center; margin-top: 100px;'>index.html is missing</h1>")

    with open(html_path, "r", encoding="utf-8") as page:
        return HTMLResponse(page.read())
def _resolve_media_path(file_obj):
    """Return the filesystem path for an upload object or a local string.

    Strings are joined onto this file's directory (an absolute string is kept
    as-is by ``os.path.join``); dicts use their ``"path"`` entry; any other
    object uses its ``path`` attribute, or ``None`` if it has none.
    """
    # NOTE(review): the path comes from the request and is not restricted to
    # a known directory here — confirm upstream validation is in place.
    if isinstance(file_obj, str):
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), file_obj)
    if isinstance(file_obj, dict):
        return file_obj.get("path")
    return getattr(file_obj, "path", None)


def run_inference_api(
    input_type: str,
    image_file: Any = None,
    video_file: Any = None,
    task_type: str = "Detection",
    category: str = "objects",
    model_mode: str = "hybrid",
    temp: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 20,
    short_size: int | None = None,
    question_override: str | None = None,
    max_video_frames: int = 4
) -> tuple[FileData | None, FileData | None, dict]:
    """Exposed Gradio Queueing Endpoint for custom frontend interactions.

    ZeroGPU allocation is triggered directly at this endpoint boundary.
    Supports both FileData dict (from web uploads) and local strings (for examples).

    Args:
        input_type: ``"Image"`` selects the image pipeline; any other value
            selects the video pipeline.
        image_file: Uploaded image (dict / object with ``path``) or a path
            string relative to this file's directory.
        video_file: Same as ``image_file``, for videos.
        task_type: Task name used to build the default prompt.
        category: Comma-separated categories; empty falls back to ``"objects"``.
        model_mode: Generation mode forwarded to the pipeline.
        temp: Sampling temperature.
        top_p: Nucleus sampling threshold.
        top_k: Top-k value.
        short_size: Optional short-side size for the model input.
        question_override: Prompt used verbatim when non-empty.
        max_video_frames: Maximum number of video frames to process.

    Returns:
        ``(image_result, video_result, meta)``. Exactly one of the first two
        is a ``FileData`` on success; both are ``None`` on failure, in which
        case ``meta`` is ``{"success": False, "error": ...}``. Exceptions are
        not propagated: they are printed and reported through ``meta``.
    """
    try:
        if not category:
            category = "objects"
        final_prompt = question_override
        if not final_prompt:
            final_prompt = generate_raw_prompt(task_type, category)

        # Both media kinds follow the same steps; only the wording of the
        # error messages and the pipeline function differ.
        is_image = input_type == "Image"
        media_file = image_file if is_image else video_file
        if not media_file:
            message = "Please upload an image." if is_image else "Please upload a video."
            return None, None, {"success": False, "error": message}

        # Resolve media path (from either FileData upload or local example string)
        media_path = _resolve_media_path(media_file)
        if not media_path or not os.path.exists(media_path):
            kind_word = "image" if is_image else "video"
            return None, None, {"success": False, "error": f"Invalid {kind_word} file path: {media_path}"}

        if is_image:
            out_path, stats, raw_text, detections = run_image_gpu_api(
                media_path, category, model_mode, temp, top_p, top_k, short_size, final_prompt
            )
        else:
            out_path, stats, raw_text, detections = run_video_gpu_api(
                media_path, category, model_mode, temp, top_p, top_k, short_size, final_prompt, max_video_frames
            )

        meta = {
            "success": True,
            "input_type": "Image" if is_image else "Video",
            "stats": stats,
            "raw_text": raw_text,
            "detections": detections,
            "final_prompt": final_prompt
        }
        result_file = FileData(path=out_path)
        if is_image:
            return result_file, None, meta
        return None, result_file, meta
    except Exception as e:
        # Endpoint boundary: log the traceback and report the failure to the
        # client instead of raising.
        import traceback
        traceback.print_exc()
        return None, None, {"success": False, "error": str(e)}
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    app.launch(show_error=True)