| |
| |
| import spaces |
|
|
| import gradio as gr |
| from gradio import Server |
| from gradio.data_classes import FileData |
| from fastapi.responses import HTMLResponse |
| from fastapi.staticfiles import StaticFiles |
|
|
| import cv2 |
| import numpy as np |
| import os |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" |
| import tempfile |
| import re |
| import time |
| import base64 |
| import gc |
| import io |
| import json |
| import uuid |
| from pathlib import Path |
| from typing import Any |
|
|
| import torch |
| from PIL import Image, ImageDraw, ImageFont |
| from transformers import AutoProcessor, AutoModel, AutoTokenizer |
| from huggingface_hub import CommitScheduler |
|
|
# Absolute path of the bundled LXGW WenKai Bold font, resolved relative to
# this file's directory (assets/LXGWWenKai-Bold.ttf). _load_font() tries it
# first and falls back to other fonts when it is missing or unreadable.
_FONT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "LXGWWenKai-Bold.ttf")
|
|
|
|
def _get_first_env(*names):
    """Return the first non-blank value among the named environment variables.

    Variables are consulted in the order given. The returned value has its
    surrounding whitespace stripped. Returns None when every variable is
    unset, empty, or whitespace-only (or when no names are passed).
    """
    raw_values = (os.environ.get(name) for name in names)
    return next(
        (raw.strip() for raw in raw_values if raw and raw.strip()),
        None,
    )
|
|
|
|
def _configure_hf_auth():
    """Resolve the Hugging Face tokens and export one under the standard names.

    Two tokens are looked up: one for model access (preferring
    ``MODEL_HF_TOKEN``) and one for log uploads (preferring ``LOG_HF_TOKEN``).
    Each falls back to the other's variable and then to the generic hub
    variables. When any token is found it is written to ``HF_TOKEN``,
    ``HUGGINGFACE_HUB_TOKEN`` and ``HUGGINGFACEHUB_API_TOKEN`` in
    ``os.environ``.

    Returns:
        A ``(model_token, log_token)`` tuple; either element may be None.
    """
    hub_vars = ("HF_TOKEN", "HUGGINGFACE_HUB_TOKEN", "HUGGINGFACEHUB_API_TOKEN")

    # Both lookups run before the environment is modified below.
    model_token = _get_first_env("MODEL_HF_TOKEN", "LOG_HF_TOKEN", *hub_vars)
    log_token = _get_first_env("LOG_HF_TOKEN", "MODEL_HF_TOKEN", *hub_vars)

    shared_token = model_token or log_token
    if shared_token:
        os.environ.update(dict.fromkeys(hub_vars, shared_token))
    return model_token, log_token
|
|
|
|
# Resolve tokens once at import time. This call also exports the resolved
# token to the standard hub environment variables (see _configure_hf_auth).
MODEL_HF_TOKEN, LOG_HF_TOKEN = _configure_hf_auth()
# Token passed to the from_pretrained() calls in EagleWorker.
HF_TOKEN = MODEL_HF_TOKEN
|
|
def _load_font(size=20):
    """Load the label font at the given size, with fallbacks.

    Tries the bundled Chinese-capable font (LXGW WenKai, which must be placed
    in the assets/ directory beforehand), then "DejaVuSans-Bold.ttf", and
    finally Pillow's built-in default font.
    """
    font_files = []
    if os.path.exists(_FONT_PATH):
        font_files.append(_FONT_PATH)
    font_files.append("DejaVuSans-Bold.ttf")

    for font_file in font_files:
        try:
            return ImageFont.truetype(font_file, size)
        except Exception:
            # Best effort: an unreadable font just moves on to the next one.
            continue
    return ImageFont.load_default()
|
|
|
|
| |
| |
| |
def get_color_for_label(label):
    """Pick a stable RGB colour for a label.

    The colour is chosen from a fixed six-entry palette using the sum of the
    label's character code points, so the same label always maps to the same
    colour.
    """
    palette = (
        (8, 145, 178),
        (220, 38, 38),
        (22, 163, 74),
        (37, 99, 235),
        (217, 119, 6),
        (147, 51, 234),
    )
    checksum = 0
    for ch in label:
        checksum += ord(ch)
    return palette[checksum % len(palette)]
|
|
|
|
def parse_mixed_results(text, category_str=""):
    """Extract box and point detections from model output text.

    Two passes are attempted:

    1. Structured: ``<ref>label</ref>`` and ``<box>...</box>`` tokens are read
       in order; each box takes the most recent ref label (or the first
       expected category, or "object").
    2. Fallback (only when the first pass saw no box with coordinates): each
       ``<box>...</box>`` is labelled by the first expected category found in
       the text preceding it.

    A box token holding four numbers becomes a ``"box"`` result, two numbers a
    ``"point"`` result; other counts are ignored.

    Args:
        text: Raw model output.
        category_str: Expected categories joined by ``</c>``.

    Returns:
        A list of dicts with keys ``type``, ``coords`` (floats) and ``label``.
    """
    number_re = r"<\s*([0-9]+(?:\.[0-9]+)?)\s*>"
    kind_by_count = {4: "box", 2: "point"}

    expected_cats = [c.strip().lower() for c in category_str.split("</c>") if c.strip()]
    default_label = expected_cats[0] if expected_cats else "object"

    def _numbers_in(fragment):
        return [float(n) for n in re.findall(number_re, fragment)]

    results = []
    current_label = None
    found_structured = False

    token_iter = re.finditer(
        r"(<ref>.*?</ref>)|(<box>.*?</box>)", text, flags=re.IGNORECASE | re.DOTALL
    )
    for match in token_iter:
        token = match.group(0)
        if match.group(1) is not None:
            # <ref> token: remember its label for the boxes that follow.
            label_raw = re.sub(r"</?ref>", "", token, flags=re.IGNORECASE).strip()
            if label_raw:
                current_label = label_raw
            continue

        coords = _numbers_in(re.sub(r"</?box>", "", token, flags=re.IGNORECASE))
        if not coords:
            continue
        kind = kind_by_count.get(len(coords))
        if kind is not None:
            label = default_label if current_label is None else current_label
            results.append({"type": kind, "coords": coords, "label": label})
        found_structured = True

    if found_structured:
        return results

    # Fallback: re.split with one capture group alternates
    # [text, box-content, text, box-content, ..., text].
    parts = re.split(r"<box>(.*?)</box>", text)
    for before, inner in zip(parts[0::2], parts[1::2]):
        lowered = before.lower()
        label = next((cat for cat in expected_cats if cat in lowered), default_label)
        coords = _numbers_in(inner)
        kind = kind_by_count.get(len(coords))
        if kind is not None:
            results.append({"type": kind, "coords": coords, "label": label})

    return results
|
|
|
|
def resize_image_short_side(image, short_side_size):
    """Resize an image so its shorter side equals ``short_side_size``.

    The longer side is scaled by the same factor and truncated to an int.
    Resampling is bilinear.

    Returns:
        A ``(resized_image, scale_factor)`` tuple.
    """
    width, height = image.size
    width_is_short = width <= height
    scale_factor = short_side_size / (width if width_is_short else height)
    if width_is_short:
        target = (short_side_size, int(height * scale_factor))
    else:
        target = (int(width * scale_factor), short_side_size)
    return image.resize(target, Image.BILINEAR), scale_factor
|
|
|
|
def draw_on_frame(frame_bgr, results, draw_label=True):
    """Render detections onto a BGR frame and return the annotated BGR frame.

    Each result is a dict with ``coords`` and optionally ``type`` and
    ``label``:

    * ``type == "point"``: ``coords`` are ``[x, y]`` on a 0-1000 scale.
    * results carrying an ``is_pixel`` key: ``coords`` are pixel
      ``[x, y, width, height]``.
    * otherwise: ``coords`` are ``[x1, y1, x2, y2]`` on a 0-1000 scale
      (entries with fewer than four values are skipped).

    Shapes are drawn first on a transparent overlay, then (when
    ``draw_label`` is true) the label tags on top, and the overlay is
    alpha-composited over the frame.
    """
    base_rgb = Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
    canvas = base_rgb.convert("RGBA")
    overlay = Image.new("RGBA", canvas.size, (255, 255, 255, 0))
    pen = ImageDraw.Draw(overlay)
    font = _load_font(20)
    w_img, h_img = base_rgb.size

    # Pass 1: convert every result into pixel-space geometry.
    shapes = []
    for res in results:
        label = res.get("label", "object")
        color = get_color_for_label(label)

        if res.get("type") == "point":
            pt = res["coords"]
            cx = max(0, min(w_img, pt[0] * w_img / 1000))
            cy = max(0, min(h_img, pt[1] * h_img / 1000))
            shapes.append({"kind": "point", "label": label, "color": color, "geom": (cx, cy)})
            continue

        if "is_pixel" in res:
            left, top, bw, bh = res["coords"]
            right, bottom = left + bw, top + bh
        else:
            norm = res["coords"]
            if len(norm) < 4:
                continue
            left = norm[0] * w_img / 1000
            top = norm[1] * h_img / 1000
            right = norm[2] * w_img / 1000
            bottom = norm[3] * h_img / 1000

        # Clamp to the image, then order the corners.
        left, top = max(0, left), max(0, top)
        right, bottom = min(w_img, right), min(h_img, bottom)
        left, right = min(left, right), max(left, right)
        top, bottom = min(top, bottom), max(top, bottom)
        shapes.append({"kind": "box", "label": label, "color": color,
                       "geom": (left, top, right, bottom)})

    # Pass 2: filled shapes.
    for shape in shapes:
        color = shape["color"]
        if shape["kind"] == "box":
            left, top, right, bottom = shape["geom"]
            pen.rectangle([left, top, right, bottom], fill=color + (65,), outline=color, width=4)
        elif shape["kind"] == "point":
            cx, cy = shape["geom"]
            radius = 10
            pen.ellipse([cx - radius, cy - radius, cx + radius, cy + radius],
                        fill=color, outline="white", width=2)

    # Pass 3: label tags, drawn last so they sit above every shape.
    if draw_label:
        for shape in shapes:
            label, color = shape["label"], shape["color"]
            if not label:
                continue
            bbox = pen.textbbox((0, 0), label, font=font)
            tw = bbox[2] - bbox[0]
            th = bbox[3] - bbox[1]
            if shape["kind"] == "box":
                left, top, right, bottom = shape["geom"]
                pad_x, pad_y = 7, 4
                tag_w = tw + pad_x * 2
                tag_h = th + pad_y * 2
                tag_y = top - tag_h - 2
                if tag_y < 0:
                    # No room above the box: place the tag below it.
                    tag_y = bottom + 2
                pen.rectangle([left, tag_y, left + tag_w, tag_y + tag_h], fill=color)
                pen.text((left + pad_x, tag_y + pad_y), label, fill="white", font=font)
            elif shape["kind"] == "point":
                cx, cy = shape["geom"]
                tx, ty = cx + 14, cy - th // 2
                pen.rectangle([tx - 2, ty - 2, tx + tw + 6, ty + th + 4], fill=color)
                pen.text((tx + 2, ty), label, fill="white", font=font)

    combined = Image.alpha_composite(canvas, overlay).convert("RGB")
    return cv2.cvtColor(np.array(combined), cv2.COLOR_RGB2BGR)
|
|
|
|
| |
| |
| |
class EagleWorker:
    """Wraps the tokenizer, processor and model used for grounded generation.

    All three components are loaded from ``model_path`` with
    ``trust_remote_code=True``; the model is loaded in bfloat16, moved to
    ``device`` and put in eval mode.
    """

    def __init__(self, model_path, device="cuda", generation_mode: str = "hybrid"):
        """Load tokenizer, processor and model.

        Args:
            model_path: Hub id or local path passed to ``from_pretrained``.
            device: Device the model and inputs are moved to.
            generation_mode: Default mode used when ``generate`` is called
                without an explicit ``generation_mode``.
        """
        self.model_id = model_path
        self.device = device
        self.dtype = torch.bfloat16
        self.generation_mode = generation_mode
        hub_token = HF_TOKEN if HF_TOKEN else None
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            token=hub_token,
        )
        self.processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
            token=hub_token,
        )
        self.model = AutoModel.from_pretrained(
            model_path,
            torch_dtype=self.dtype,
            _attn_implementation="sdpa",
            trust_remote_code=True,
            token=hub_token,
        ).to(device).eval()
        print("Model Loaded Successfully!")

    def build_messages(self, image, categories, question_override=None):
        """Build the single-turn chat message list for one image.

        When ``question_override`` is given it is used verbatim as the user
        text; otherwise a locate-all prompt is built from ``categories``
        joined by ``</c>``.
        """
        if question_override is not None:
            user_text = question_override
        else:
            category_set_str = "</c>".join(categories)
            # Prompt wording is kept verbatim (including "matches").
            user_text = f"Locate all the instances that matches the following description: {category_set_str}."
        return [{"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": user_text},
        ]}]

    @staticmethod
    def _unpack_generation_result(result, generation_mode):
        """Normalise the model's return value to ``(text, token_sequence, info)``.

        A tuple with at least three elements is read as
        ``(output_text, token_sequence, out_info, ...)``; anything else is
        treated as the bare output text. In "slow" mode the last trace entry
        is relabelled as "ar" when a trace is present.
        """
        if isinstance(result, tuple) and len(result) >= 3:
            # Slice so tuples longer than three do not break unpacking.
            output_text, token_sequence, out_info = result[:3]
            # Guard against an empty trace before touching the last entry.
            if generation_mode == "slow" and token_sequence:
                token_sequence[-1] = ("ar", token_sequence[-1][1])
            return output_text, token_sequence, out_info
        return result, [], ""

    @torch.no_grad()
    def generate(self, image, categories, generation_mode=None,
                 max_new_tokens=4096, temp=0.7, top_p=0.9, top_k=50,
                 question_override=None):
        """Run generation for one image.

        Args:
            image: Image placed in the chat message.
            categories: Category names used to build the default prompt.
            generation_mode: Overrides the instance default when not None.
            max_new_tokens: Forwarded to ``model.generate``.
            temp: Sampling temperature, forwarded as ``temperature``.
            top_p: Forwarded to ``model.generate``.
            top_k: Accepted for interface compatibility.
                NOTE(review): it is not forwarded to ``model.generate``;
                confirm whether the remote-code model supports it before
                wiring it through.
            question_override: Optional full replacement for the user text.

        Returns:
            ``(output_text, token_sequence, out_info)``.
        """
        effective_mode = generation_mode if generation_mode is not None else self.generation_mode

        messages = self.build_messages(image, categories, question_override=question_override)
        text = self.processor.py_apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        images, videos = self.processor.process_vision_info(messages)
        inputs = self.processor(text=[text], images=images, videos=videos, return_tensors="pt").to(self.device)

        pixel_values = inputs["pixel_values"].to(self.dtype)
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        image_grid_hws = inputs.get("image_grid_hws", None)

        result = self.model.generate(
            pixel_values=pixel_values, input_ids=input_ids,
            attention_mask=attention_mask, image_grid_hws=image_grid_hws,
            tokenizer=self.tokenizer, max_new_tokens=max_new_tokens,
            use_cache=True,
            generation_mode=effective_mode,
            temperature=temp, do_sample=True, top_p=top_p,
            repetition_penalty=1.1, verbose=True,
        )

        # Use the effective mode so the instance default is honoured too.
        return self._unpack_generation_result(result, effective_mode)
|
|
|
|
| |
| |
| |
def _postprocess_detections(detections, w, h):
    """Convert 0-1000 scaled boxes to pixel boxes and drop empty ones.

    Box detections are rewritten as ``[x, y, width, height]`` in pixels
    (corners clamped to ``[0, w-1]`` / ``[0, h-1]``) and tagged with
    ``is_pixel=True``; boxes whose pixel width or height is not positive are
    discarded. Point detections pass through unchanged. Any other type is
    dropped.
    """
    def _to_pixel(value, extent):
        return max(0, min(extent - 1, int(value * extent / 1000)))

    kept = []
    for det in detections:
        kind = det["type"]
        if kind == "point":
            kept.append(det)
            continue
        if kind != "box":
            continue

        c = det["coords"]
        left, top = _to_pixel(c[0], w), _to_pixel(c[1], h)
        right, bottom = _to_pixel(c[2], w), _to_pixel(c[3], h)
        box_w = right - left
        box_h = bottom - top
        if box_w > 0 and box_h > 0:
            kept.append({
                "type": "box",
                "coords": [left, top, box_w, box_h],
                "is_pixel": True,
                "label": det["label"],
            })
    return kept
|
|
|
|
def _parse_out_info_dict(out_info: str) -> dict:
    """Parse a ``key=value;key=value`` statistics string into a dict.

    A leading "Stastic info," style prefix (as matched by the regex below) is
    removed first. Segments without ``=`` are ignored; keys and values are
    whitespace-stripped strings, and only the first ``=`` in a segment
    separates key from value. Empty input yields an empty dict.
    """
    if not out_info:
        return {}
    body = re.sub(r"^[Ss]tast?ic\s*[Ii]nfo\s*,?\s*", "", out_info.strip())
    triples = (segment.strip().partition("=") for segment in body.split(";"))
    return {key.strip(): value.strip() for key, sep, value in triples if sep}
|
|
|
|
def generate_dynamic_html(token_sequence, out_info, raw_text):
    """Build the "Decoding Trace" HTML panel for one inference run.

    Args:
        token_sequence: Iterable of ``(decode_type, text)`` pairs; entries
            that are not a list/tuple of length >= 2 are skipped. A decode
            type of "ar" (case-insensitive) is styled as AR fallback, anything
            else as MTP.
        out_info: Statistics string understood by ``_parse_out_info_dict``;
            when it yields no known keys the stripped string itself is shown.
        raw_text: Raw model response, shown in a collapsible section.

    Returns:
        A self-contained HTML fragment (scoped ``<style>`` plus markup).

    All model-derived text is HTML-escaped before insertion. The model emits
    literal ``<ref>``/``<box>`` tags, which would otherwise be parsed as
    markup (and hidden) by the browser, and arbitrary output would be
    injected as HTML.
    """
    from html import escape as _escape  # local import keeps this block self-contained

    # Per-call suffix that scopes the CSS class names of this fragment.
    uid = f"a{int(time.time() * 1000)}"
    css = f"""
<style>
.dc-root-{uid} {{
font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
border: 1px solid rgba(118, 185, 0, 0.25); border-radius: 12px;
background: rgba(0, 0, 0, 0.55); overflow: visible;
}}
.dc-header-{uid} {{
display: flex; align-items: center; justify-content: space-between; flex-wrap: wrap; gap: 8px;
padding: 10px 14px;
background: linear-gradient(135deg, rgba(118, 185, 0, 0.25) 0%, rgba(63, 98, 0, 0.35) 100%);
border-bottom: 1px solid rgba(118, 185, 0, 0.2);
}}
.dc-header-title-{uid} {{ font-weight: 700; font-size: 0.82em; color: #d9f99d; letter-spacing: 0.04em; text-transform: uppercase; }}
.dc-legend-{uid} {{ display: flex; gap: 12px; align-items: center; flex-wrap: wrap; }}
.dc-legend-item-{uid} {{ display: flex; align-items: center; gap: 5px; font-size: 0.72em; color: rgba(226, 232, 240, 0.85); }}
.dc-legend-dot-{uid} {{ width: 8px; height: 8px; border-radius: 2px; display: inline-block; }}
.dc-row-{uid} {{ display: flex; gap: 10px; padding: 12px 14px; border-bottom: 1px solid rgba(255,255,255,0.05); }}
.dc-row-{uid}:last-child {{ border-bottom: none; }}
.dc-val-{uid} {{
flex: 1; line-height: 1.9; word-wrap: break-word; color: #cbd5e1; font-size: 0.85em;
display: flex; flex-wrap: wrap; gap: 4px; align-items: center; align-content: flex-start;
}}
@keyframes tk-{uid} {{
0% {{ opacity: 0; transform: translateY(8px) scale(0.92); }}
60% {{ opacity: 1; transform: translateY(-2px) scale(1.02); }}
100% {{ opacity: 1; transform: translateY(0) scale(1); }}
}}
.tk-mtp-{uid}, .tk-ar-{uid} {{
opacity: 0; animation: tk-{uid} 0.35s ease-out forwards;
border-radius: 5px; padding: 2px 7px; margin: 0;
display: inline-block; max-width: 100%;
font-size: 0.78em; font-weight: 600;
font-family: 'Fira Code', Consolas, monospace;
white-space: normal; word-break: break-all;
}}
.tk-mtp-{uid} {{ background: rgba(118, 185, 0, 0.15); border: 1px solid rgba(118, 185, 0, 0.55); color: #bbf7d0; }}
.tk-ar-{uid} {{ background: rgba(230, 81, 0, 0.15); border: 1px solid rgba(230, 81, 0, 0.55); color: #fed7aa; }}
.tk-stat-{uid} {{
opacity: 0; animation: tk-{uid} 0.4s ease-out forwards;
background: rgba(118, 185, 0, 0.12); border: 1px solid rgba(118, 185, 0, 0.35); border-radius: 6px;
padding: 4px 12px; display: inline-block; font-size: 0.78em; color: #d9f99d; font-weight: 600;
}}
.dc-raw-{uid} {{ padding: 0 14px 12px; }}
.dc-raw-{uid} summary {{ cursor: pointer; color: #94a3b8; font-size: 0.78em; user-select: none; }}
.dc-raw-{uid} summary:hover {{ color: #76b900; }}
.dc-raw-pre-{uid} {{
background: rgba(0,0,0,0.45); border: 1px solid rgba(255,255,255,0.08); border-radius: 6px;
padding: 10px; margin-top: 8px;
font-family: 'Fira Code', Consolas, monospace;
font-size: 0.74em; color: #cbd5e1; white-space: pre-wrap; word-break: break-all;
}}
</style>
"""
    parts = [
        css,
        f'<div class="dc-root-{uid}">',
        f'<div class="dc-header-{uid}">'
        f'<span class="dc-header-title-{uid}">Decoding Trace</span>'
        f'<div class="dc-legend-{uid}">'
        f'<div class="dc-legend-item-{uid}"><span class="dc-legend-dot-{uid}" style="background:#76b900;"></span>MTP Parallel</div>'
        f'<div class="dc-legend-item-{uid}"><span class="dc-legend-dot-{uid}" style="background:#e65100;"></span>AR Fallback</div>'
        f'</div></div>',
    ]

    if out_info:
        stats = _parse_out_info_dict(out_info)
        bits = []
        if "forward_step" in stats:
            bits.append(f"{stats['forward_step']} steps")
        if "num_tokens" in stats:
            bits.append(f"{stats['num_tokens']} tokens")
        if "num_boxes" in stats:
            bits.append(f"{stats['num_boxes']} boxes")
        if "switch_to_ar" in stats:
            n = stats["switch_to_ar"]
            # Values are strings, so pluralise on the literal "1".
            bits.append(f"{n} AR fallback{'s' if n != '1' else ''}")
        if "tps" in stats:
            bits.append(f"{stats['tps']} tok/s")
        if "bps" in stats:
            bits.append(f"{stats['bps']} box/s")
        summary = " · ".join(bits) if bits else out_info.strip()
        parts.append(
            f'<div class="dc-row-{uid}" style="justify-content:flex-start;padding-top:8px;padding-bottom:4px;border-bottom:none;">'
            f'<span class="tk-stat-{uid}" style="animation-delay:0s">{_escape(summary)}</span></div>'
        )

    parts.append(f'<div class="dc-row-{uid}"><div class="dc-val-{uid}">')
    tok_idx = 0
    for item in token_sequence or ():
        if not isinstance(item, (list, tuple)) or len(item) < 2:
            continue
        decode_type = str(item[0]).lower()
        safe = _escape(str(item[1]))
        # Stagger the fade-in so tokens appear one after another.
        delay = f"{tok_idx * 0.06:.2f}s"
        cls = f"tk-ar-{uid}" if decode_type == "ar" else f"tk-mtp-{uid}"
        parts.append(f'<span class="{cls}" style="animation-delay:{delay}">{safe}</span> ')
        tok_idx += 1
    parts.append('</div></div>')

    if raw_text:
        safe_raw = _escape(raw_text)
        parts.append(
            f'<div class="dc-raw-{uid}"><details open><summary>Raw Response</summary>'
            f'<div class="dc-raw-pre-{uid}">{safe_raw}</div></details></div>'
        )
    parts.append('</div>')
    return "".join(parts)
|
|
|
|
def generate_raw_prompt(task_type, category):
    """Return the text prompt for a task type and comma-separated categories.

    ``category`` defaults to "objects" when empty; its comma-separated items
    are stripped and joined with ``</c>``. Unknown task types use the same
    wording as "Detection". The "OCR" prompt ignores the categories.
    """
    category = category or "objects"
    cats = "</c>".join(filter(None, (piece.strip() for piece in category.split(","))))

    # Prompt wording is reproduced verbatim, including the
    # "matches"/"match" difference between Detection and Grounding.
    detection = "Locate all the instances that matches the following description: {cats}."
    templates = {
        "Detection": detection,
        "Grounding": "Locate all the instances that match the following description: {cats}.",
        "OCR": "Detect all the text in box format.",
        "GUI": "Locate the region that matches the following description: {cats}.",
        "Pointing": "Point to: {cats}.",
    }
    return templates.get(task_type, detection).replace("{cats}", cats)
|
|
|
|
| |
| |
| |
# Process-wide model wrapper; stays None until the first successful load.
GLOBAL_WORKER = None


def get_worker():
    """Return the shared EagleWorker, loading it on first use.

    The model path comes from the ``MODEL_PATH`` environment variable
    (default "nvidia/LocateAnything-3B"). If construction fails, the error is
    printed and None is returned; callers treat that as "mock mode". A later
    call will attempt the load again.
    """
    global GLOBAL_WORKER
    if GLOBAL_WORKER is not None:
        return GLOBAL_WORKER

    try:
        model_path = os.environ.get("MODEL_PATH", "nvidia/LocateAnything-3B")
        print(f"Loading model inside @spaces.GPU context: {model_path}")
        GLOBAL_WORKER = EagleWorker(model_path)
    except Exception as e:
        print(f"Failed to load model: {e}. Will run in Mock Mode.")
        GLOBAL_WORKER = None
    return GLOBAL_WORKER
|
|
|
|
def _prepare_image_for_model(pil_img, short_size):
    """Return a copy of the image sized for the model.

    With a positive ``short_size`` the short side is resized to
    ``min(int(short_size), 1024)``. Otherwise the copy is only resized (to a
    short side of 1024) when its short side exceeds 1024.
    """
    process_img = pil_img.copy()
    if short_size is not None and short_size > 0:
        target = min(int(short_size), 1024)
    elif min(process_img.size) > 1024:
        target = 1024
    else:
        return process_img

    resized, _ = resize_image_short_side(process_img, target)
    return resized
|
|
|
|
| |
| |
| |
| |
| |
| |
# --- Dataset logging setup (runs at import time) ---------------------------
# Target dataset repo for inference logs; overridable via LOG_DATASET_REPO.
LOG_DATASET_REPO = os.environ.get("LOG_DATASET_REPO", "woshichaoren123/log")
# Fresh temporary directory that the scheduler watches and uploads from.
_LOG_DIR = Path(tempfile.mkdtemp(prefix="hf_log_"))
# Short random id (8 hex chars) identifying this process in log records.
_SESSION_ID = uuid.uuid4().hex[:8]
# Stays None when logging is disabled; _log_to_dataset() checks this.
_log_scheduler = None


# Logging is enabled only when both a repo id and a log token are available.
if LOG_DATASET_REPO and LOG_HF_TOKEN:
    try:
        _log_scheduler = CommitScheduler(
            repo_id=LOG_DATASET_REPO,
            repo_type="dataset",
            folder_path=str(_LOG_DIR),
            path_in_repo="data",
            # NOTE(review): presumably minutes between commits per the
            # huggingface_hub CommitScheduler API — confirm.
            every=3,
            token=LOG_HF_TOKEN,
            squash_history=False,
        )
        print(f"[LOG] Dataset logging enabled -> {LOG_DATASET_REPO} "
              f"(session={_SESSION_ID}, dir={_LOG_DIR})")
    except Exception as e:
        # Logging is best effort: a scheduler failure must not stop the app.
        _log_scheduler = None
        print(f"[LOG] Dataset logging disabled: {e}")
else:
    print("[LOG] Dataset logging disabled (LOG_HF_TOKEN not set)")
|
|
|
|
def _pil_to_b64(pil_img):
    """Encode an image as PNG and return the bytes as an ASCII base64 string."""
    with io.BytesIO() as stream:
        pil_img.save(stream, "PNG")
        png_bytes = stream.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("ascii")
|
|
|
|
def _atomic_write_text(path: Path, text: str):
    """Write ``text`` to ``path`` via a sibling temp file and an atomic rename.

    The text is first written (UTF-8) to ``<name>.tmp`` next to ``path`` and
    then moved into place with ``os.replace``, so readers never observe a
    partially written file.

    If writing or renaming fails, the temp file is removed before the error
    is re-raised, so no stray ``.tmp`` file is left in the directory.

    Raises:
        Whatever the write or rename raised (e.g. OSError,
        UnicodeEncodeError).
    """
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(text)
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the partial temp file; ignore it if it was never created.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
|
|
|
|
def _log_to_dataset(
    input_type, category, model_mode, raw_prompt,
    output_text="", input_image=None, output_image=None,
    extra=None,
):
    """Write one inference record as a JSON line into the log directory.

    Does nothing when dataset logging is disabled. Each record goes to its
    own file, ``<_LOG_DIR>/<UTC date>/<session>__<entry id>.jsonl``, written
    while holding the scheduler's lock. PIL images are embedded as base64
    PNG; non-image values are stored as null. Entries of ``extra`` are merged
    into the record. Any failure is printed and swallowed so logging never
    breaks inference.
    """
    if _log_scheduler is None:
        return

    def _encode_if_image(candidate):
        return _pil_to_b64(candidate) if isinstance(candidate, Image.Image) else None

    try:
        entry_id = f"{int(time.time())}_{uuid.uuid4().hex[:6]}"
        ts = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        date_str = time.strftime("%Y-%m-%d", time.gmtime())

        record = {
            "id": entry_id,
            "session_id": _SESSION_ID,
            "timestamp": ts,
            "input_type": input_type,
            "category": category,
            "model_mode": model_mode,
            "raw_prompt": raw_prompt,
            "output_text": output_text,
            "input_image_b64": _encode_if_image(input_image),
            "output_image_b64": _encode_if_image(output_image),
        }
        if extra:
            record.update(extra)

        day_dir = _LOG_DIR / date_str
        day_dir.mkdir(parents=True, exist_ok=True)
        log_file = day_dir / f"{_SESSION_ID}__{entry_id}.jsonl"

        payload = json.dumps(record, ensure_ascii=False) + "\n"
        # Hold the scheduler lock so a commit does not run mid-write.
        with _log_scheduler.lock:
            _atomic_write_text(log_file, payload)
    except Exception as e:
        print(f"[LOG] Failed to log to dataset: {e}")
|
|
|
|
def _maybe_log_inference(
    input_type: str,
    category: str,
    model_mode: str,
    raw_prompt: str,
    output_text: str,
    input_path: str | None = None,
    output_path: str | None = None,
    extra: dict | None = None,
):
    """Load preview images from disk and hand a record to ``_log_to_dataset``.

    For ``input_type == "image"`` the input file (and the output file, if it
    exists) is loaded as RGB. For ``"video"`` only the first frame of the
    input is captured and no output image is logged. The category string is
    normalised to a ", "-joined list of its non-empty items. Failures are
    printed and swallowed.
    """
    try:
        input_image = None
        output_image = None

        have_input = bool(input_path) and os.path.exists(input_path)
        if have_input and input_type == "image":
            input_image = Image.open(input_path).convert("RGB")
        elif have_input and input_type == "video":
            # Log only the first frame as a preview of the video.
            cap = cv2.VideoCapture(input_path)
            ok, first_frame = cap.read()
            cap.release()
            if ok:
                input_image = Image.fromarray(cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB))

        if input_type == "image" and output_path and os.path.exists(output_path):
            output_image = Image.open(output_path).convert("RGB")

        cleaned = [piece.strip() for piece in category.split(",") if piece.strip()]
        _log_to_dataset(
            input_type=input_type,
            category=", ".join(cleaned) if cleaned else category,
            model_mode=model_mode,
            raw_prompt=raw_prompt,
            output_text=output_text,
            input_image=input_image,
            output_image=output_image,
            extra=extra,
        )
    except Exception as e:
        print(f"[LOG] Failed to prepare log record: {e}")
|
|
|
|
| |
| |
| |
# GPU time-budget constants (seconds), used by run_video_gpu_api.
# NOTE(review): GPU_HARD_LIMIT_IMAGE is not referenced in the visible code and
# differs from the duration=120 on run_image_gpu_api's decorator — confirm.
GPU_HARD_LIMIT_IMAGE = 30
# Total budget for a video request; matches duration=240 on the decorator.
GPU_HARD_LIMIT_VIDEO = 240
# Time held back from the inference loop for drawing and encoding.
PHASE2_RESERVE = 55
# Additional slack subtracted from the budget alongside PHASE2_RESERVE.
SAFETY_MARGIN = 25
# Assumed inference cost per sampled frame when planning how many to process.
EST_SECONDS_PER_FRAME = 20
|
|
|
|
@spaces.GPU(duration=120, size="xlarge")
def run_image_gpu_api(
    image_path: str, category: str, model_mode: str, temp: float, top_p: float, top_k: int,
    short_size: int | None, question_override: str | None
):
    """Run detection on one image inside the GPU context.

    The image is resized for the model, but detections are drawn on the
    original-resolution image. When no model worker is available, a fixed
    mock response is used.

    Returns:
        ``(out_img_path, stats, output_text, detections_summary, html)``
        where ``out_img_path`` is a PNG in a fresh temp directory, ``stats``
        is the parsed statistics dict, and ``detections_summary`` lists
        label/type/coords (coords rounded to 2 decimals).
    """
    source_rgb = Image.open(image_path).convert("RGB")
    categories_list = [piece.strip() for piece in category.split(",") if piece.strip()]
    model_input = _prepare_image_for_model(source_rgb, short_size)

    worker = get_worker()
    if not worker:
        # Mock mode: canned response so the UI pipeline still works.
        output_text = "Mock detection: <ref>sushi</ref><box><240><480><620><940></box> and <ref>book</ref><box><50><120><400><380></box>"
        token_sequence = []
        out_info = "forward_step=1;num_tokens=18;num_boxes=2;tps=45;bps=15"
    else:
        output_text, token_sequence, out_info = worker.generate(
            model_input, categories_list, model_mode,
            temp=temp, top_p=top_p, top_k=top_k,
            question_override=question_override,
        )

    detections = parse_mixed_results(output_text, "</c>".join(categories_list))

    # Draw on the full-resolution source; coordinates are resolution-free.
    annotated_bgr = draw_on_frame(
        cv2.cvtColor(np.array(source_rgb), cv2.COLOR_RGB2BGR),
        detections,
        draw_label=True,
    )
    annotated = Image.fromarray(cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB))

    out_img_path = os.path.join(tempfile.mkdtemp(), "output.png")
    annotated.save(out_img_path)

    stats = _parse_out_info_dict(out_info)
    detections_summary = [
        {
            "label": det.get("label", "object"),
            "type": det.get("type", "box"),
            "coords": [round(value, 2) for value in det.get("coords", [])],
        }
        for det in detections
    ]
    trace_html = generate_dynamic_html(token_sequence, out_info, output_text)
    return out_img_path, stats, output_text, detections_summary, trace_html
|
|
|
|
def _evenly_spaced_indices(total, count):
    """Return up to ``count`` indices spread evenly over ``range(total)``.

    All indices are returned when ``total <= count``. A single requested
    index yields ``[0]`` (the naive ``/(count - 1)`` spacing formula would
    divide by zero there).
    """
    if total <= 0 or count <= 0:
        return []
    if total <= count:
        return list(range(total))
    if count == 1:
        return [0]
    return [int(round(i * (total - 1) / (count - 1))) for i in range(count)]


@spaces.GPU(duration=240, size="xlarge")
def run_video_gpu_api(
    video_path: str, category: str, model_mode: str, temp: float, top_p: float, top_k: int,
    short_size: int | None, question_override: str | None, max_video_frames: int
):
    """Run detection on evenly sampled video frames inside the GPU context.

    Steps: read all frames, pick up to ``max_video_frames`` evenly spaced
    ones (fewer if the time budget would be exceeded), run inference on each
    while watching the budget, draw the detections into an output video and,
    if time allows, transcode it to H.264 with ffmpeg.

    Args:
        max_video_frames: Requested number of frames; falsy means 4, values
            below 1 are treated as 1.

    Returns:
        ``(out_video_path, stats, raw_combined, detections_summary, html)``.

    Raises:
        ValueError: If no frame could be read from the video.
        RuntimeError: If the budget ran out before any frame was processed.
    """
    import subprocess as _sp

    total_start = time.time()
    # Clamp to at least one frame; zero or negative counts make no sense.
    max_frames = max(1, int(max_video_frames)) if max_video_frames else 4

    categories_list = [c.strip() for c in category.split(",") if c.strip()]
    category_str = "</c>".join(categories_list)

    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        vid_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        vid_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # The whole clip is buffered because the frame count must be known
        # before evenly spaced indices can be chosen.
        all_frames = []
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            all_frames.append(frame)
    finally:
        cap.release()
    total = len(all_frames)

    if total == 0:
        raise ValueError("Failed to read any frames from the video.")

    # Some containers report no size; fall back to the decoded frame's shape.
    if vid_w <= 0 or vid_h <= 0:
        vid_h, vid_w = all_frames[0].shape[:2]

    sample_indices = _evenly_spaced_indices(total, max_frames)

    # Shrink the sample if the estimated inference time would not fit.
    time_already_used = time.time() - total_start
    available_for_inference = GPU_HARD_LIMIT_VIDEO - time_already_used - PHASE2_RESERVE - SAFETY_MARGIN
    estimated_inference_time = len(sample_indices) * EST_SECONDS_PER_FRAME
    if estimated_inference_time > available_for_inference:
        max_feasible = max(1, int(available_for_inference // EST_SECONDS_PER_FRAME))
        sample_indices = _evenly_spaced_indices(total, max_feasible)

    sampled_frames = [all_frames[i] for i in sample_indices]
    n_sampled = len(sampled_frames)

    # Play the sampled frames over roughly the original clip duration.
    out_fps = max(1.0, n_sampled / (total / fps)) if fps > 0 else 5.0
    del all_frames
    gc.collect()

    inference_results = []
    processed_count = 0
    early_stopped = False
    early_stop_reason = ""

    # Resolve the worker once so a failed load is not retried for every frame.
    worker = get_worker()

    for frame in sampled_frames:
        elapsed_since_start = time.time() - total_start
        remaining_total = GPU_HARD_LIMIT_VIDEO - elapsed_since_start

        if remaining_total < PHASE2_RESERVE + SAFETY_MARGIN:
            early_stopped = True
            early_stop_reason = f"GPU time budget running out. Only {remaining_total:.0f}s left."
            break

        pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        process_img = _prepare_image_for_model(pil_img, short_size)

        if worker:
            output_text, _, _ = worker.generate(
                process_img, categories_list, model_mode,
                temp=temp, top_p=top_p, top_k=top_k,
                question_override=question_override,
            )
        else:
            output_text = "Mock video detection: <ref>person</ref><box><100><150><800><900></box>"

        inference_results.append(output_text)
        processed_count += 1

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    if processed_count == 0:
        raise RuntimeError("GPU budget exceeded before processing any frames.")

    sampled_frames_for_draw = sampled_frames[:processed_count]
    inference_results_for_draw = inference_results[:processed_count]

    # mkstemp creates the files atomically (mktemp is deprecated and racy);
    # both the writer and ffmpeg (-y) overwrite the placeholders.
    raw_fd, tmp_raw = tempfile.mkstemp(suffix=".raw.mp4")
    os.close(raw_fd)
    out_fd, out_video_path = tempfile.mkstemp(suffix=".mp4")
    os.close(out_fd)

    out = cv2.VideoWriter(tmp_raw, cv2.VideoWriter_fourcc(*"mp4v"), out_fps, (vid_w, vid_h))
    detections_summary = []
    try:
        for i, (frame, output_text) in enumerate(zip(sampled_frames_for_draw, inference_results_for_draw)):
            detections = parse_mixed_results(output_text, category_str)
            valid_results = _postprocess_detections(detections, vid_w, vid_h)
            out.write(draw_on_frame(frame, valid_results, draw_label=True))

            for det in valid_results:
                detections_summary.append({
                    "frame": i + 1,
                    "label": det.get("label", "object"),
                    "type": det.get("type", "box"),
                    "coords": det.get("coords", []),
                })
    finally:
        out.release()

    # Transcode to H.264 only when enough budget remains; otherwise (or on
    # failure) return the raw mp4v file.
    elapsed_now = time.time() - total_start
    remaining_now = GPU_HARD_LIMIT_VIDEO - elapsed_now

    transcoded = False
    if remaining_now > 15:
        ffmpeg_timeout = max(10, int(remaining_now - 5))
        try:
            _sp.run(
                ["ffmpeg", "-y", "-i", tmp_raw, "-c:v", "libx264",
                 "-preset", "ultrafast", "-crf", "23", "-pix_fmt", "yuv420p",
                 "-movflags", "+faststart", out_video_path],
                check=True, capture_output=True, timeout=ffmpeg_timeout,
            )
            transcoded = True
        except Exception as e:
            print(f"[VIDEO] ffmpeg transcode failed; returning raw output: {e}")

    if transcoded:
        try:
            os.remove(tmp_raw)
        except OSError:
            pass
    else:
        os.replace(tmp_raw, out_video_path)

    total_time = time.time() - total_start
    stats = {
        "total_frames": total,
        "sampled_frames": n_sampled,
        "processed_frames": processed_count,
        "total_time_seconds": round(total_time, 2),
        "early_stopped": early_stopped,
        "early_stop_reason": early_stop_reason,
    }

    raw_combined = "\n---\n".join(inference_results_for_draw)
    timing_summary = (
        f"Processed {processed_count}/{n_sampled} sampled frames "
        f"({total} total) in {total_time:.1f}s"
    )
    if early_stopped:
        timing_summary += f" — {early_stop_reason}"
    html = generate_dynamic_html([], "", timing_summary + "\n\n" + raw_combined)
    return out_video_path, stats, raw_combined, detections_summary, html
|
|
|
|
| |
| |
| |
# Application object; routes and API endpoints below are registered on it.
app = Server()


# Serve the bundled assets/ directory at /assets, if it exists next to this file.
assets_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
if os.path.exists(assets_dir):
    app.mount("/assets", StaticFiles(directory=assets_dir), name="assets")
|
|
@app.get("/")
async def homepage():
    """Serve index.html from this file's directory, or an error page if absent."""
    html_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    if not os.path.exists(html_path):
        return HTMLResponse("<h1 style='color: #ef4444; font-family: Inter, sans-serif; text-align: center; margin-top: 100px;'>index.html is missing</h1>")

    with open(html_path, "r", encoding="utf-8") as page:
        markup = page.read()
    return HTMLResponse(markup)
|
|
|
|
def _resolve_media_path(file_obj):
    """Turn an uploaded-file value into a filesystem path (or None).

    Strings are joined onto this file's directory, dicts are read via their
    "path" key, and any other object via its ``path`` attribute.
    """
    # NOTE(review): paths come from the client and are not restricted to a
    # safe directory here — confirm this is acceptable for the deployment.
    if isinstance(file_obj, str):
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), file_obj)
    if isinstance(file_obj, dict):
        return file_obj.get("path")
    return getattr(file_obj, "path", None)


@app.api(name="run_inference")
def run_inference_api(
    input_type: str,
    image_file: Any = None,
    video_file: Any = None,
    task_type: str = "Detection",
    category: str = "objects",
    model_mode: str = "hybrid",
    temp: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 20,
    short_size: int | None = None,
    question_override: str | None = None,
    max_video_frames: int = 4
) -> tuple[FileData | None, FileData | None, dict]:
    """API endpoint that runs image or video inference for the custom frontend.

    ``input_type == "Image"`` selects the image pipeline; any other value is
    handled as video. Uploaded files may be FileData-like dicts/objects or
    plain strings (paths relative to this file, used for examples).

    Returns:
        ``(image, video, meta)``: exactly one of the first two is a FileData
        on success; ``meta`` always carries ``success`` and, on failure, an
        ``error`` message. Exceptions are caught, printed and reported via
        ``meta`` rather than raised.
    """
    try:
        category = category or "objects"
        final_prompt = question_override or generate_raw_prompt(task_type, category)

        is_image = input_type == "Image"
        upload = image_file if is_image else video_file
        noun = "image" if is_image else "video"

        if not upload:
            message = "Please upload an image." if is_image else "Please upload a video."
            return None, None, {"success": False, "error": message}

        media_path = _resolve_media_path(upload)
        if not media_path or not os.path.exists(media_path):
            return None, None, {"success": False, "error": f"Invalid {noun} file path: {media_path}"}

        if is_image:
            out_path, stats, raw_text, detections, trace_html = run_image_gpu_api(
                media_path, category, model_mode, temp, top_p, top_k, short_size, final_prompt
            )
        else:
            out_path, stats, raw_text, detections, trace_html = run_video_gpu_api(
                media_path, category, model_mode, temp, top_p, top_k, short_size, final_prompt, max_video_frames
            )

        meta = {
            "success": True,
            "input_type": "Image" if is_image else "Video",
            "stats": stats,
            "raw_text": raw_text,
            "detections": detections,
            "final_prompt": final_prompt,
            "html": trace_html,
        }
        log_extra = {"task_type": task_type, "detections": detections, "stats": stats}

        if is_image:
            _maybe_log_inference(
                input_type="image",
                category=category,
                model_mode=model_mode,
                raw_prompt=final_prompt,
                output_text=raw_text,
                input_path=media_path,
                output_path=out_path,
                extra=log_extra,
            )
            return FileData(path=out_path), None, meta

        log_extra.update(
            video_total_frames=stats.get("total_frames"),
            video_sampled_frames=stats.get("sampled_frames"),
            video_processed_frames=stats.get("processed_frames"),
        )
        _maybe_log_inference(
            input_type="video",
            category=category,
            model_mode=model_mode,
            raw_prompt=final_prompt,
            output_text=raw_text,
            input_path=media_path,
            extra=log_extra,
        )
        return None, FileData(path=out_path), meta

    except Exception as e:
        # Endpoint boundary: report the failure to the client, don't raise.
        import traceback
        traceback.print_exc()
        return None, None, {"success": False, "error": str(e)}
|
|
|
|
# Start the server only when this file is executed directly.
if __name__ == "__main__":
    app.launch(show_error=True)