Spaces:
Running on Zero
Running on Zero
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| import spaces # MUST BE THE ABSOLUTE FIRST IMPORT FOR ZEROGPU EMULATION | |
| import gradio as gr | |
| from gradio import Server | |
| from gradio.data_classes import FileData | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| import cv2 | |
| import numpy as np | |
| import os | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| import tempfile | |
| import re | |
| import time | |
| import base64 | |
| import gc | |
| import io | |
| import json | |
| import uuid | |
| from pathlib import Path | |
| from typing import Any | |
| import torch | |
| from PIL import Image, ImageDraw, ImageFont | |
| from transformers import AutoProcessor, AutoModel, AutoTokenizer | |
| from huggingface_hub import CommitScheduler | |
| _FONT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "LXGWWenKai-Bold.ttf") | |
| def _get_first_env(*names): | |
| for name in names: | |
| value = os.environ.get(name) | |
| if value and value.strip(): | |
| return value.strip() | |
| return None | |
| def _configure_hf_auth(): | |
| model_token = _get_first_env( | |
| "MODEL_HF_TOKEN", | |
| "LOG_HF_TOKEN", | |
| "HF_TOKEN", | |
| "HUGGINGFACE_HUB_TOKEN", | |
| "HUGGINGFACEHUB_API_TOKEN", | |
| ) | |
| log_token = _get_first_env( | |
| "LOG_HF_TOKEN", | |
| "MODEL_HF_TOKEN", | |
| "HF_TOKEN", | |
| "HUGGINGFACE_HUB_TOKEN", | |
| "HUGGINGFACEHUB_API_TOKEN", | |
| ) | |
| shared_token = model_token or log_token | |
| if shared_token: | |
| for name in ("HF_TOKEN", "HUGGINGFACE_HUB_TOKEN", "HUGGINGFACEHUB_API_TOKEN"): | |
| os.environ[name] = shared_token | |
| return model_token, log_token | |
| MODEL_HF_TOKEN, LOG_HF_TOKEN = _configure_hf_auth() | |
| HF_TOKEN = MODEL_HF_TOKEN | |
| def _load_font(size=20): | |
| """加载中文字体(LXGW WenKai),需提前放置到 assets/ 目录""" | |
| if os.path.exists(_FONT_PATH): | |
| try: | |
| return ImageFont.truetype(_FONT_PATH, size) | |
| except Exception: | |
| pass | |
| try: | |
| return ImageFont.truetype("DejaVuSans-Bold.ttf", size) | |
| except Exception: | |
| return ImageFont.load_default() | |
| # ============================================================ | |
| # 颜色 / 解析 / 绘制 | |
| # ============================================================ | |
| def get_color_for_label(label): | |
| colors = [ | |
| (8, 145, 178), (220, 38, 38), (22, 163, 74), (37, 99, 235), | |
| (217, 119, 6), (147, 51, 234), | |
| ] | |
| idx = sum(ord(c) for c in label) | |
| return colors[idx % len(colors)] | |
| def parse_mixed_results(text, category_str=""): | |
| results = [] | |
| expected_cats = [c.strip().lower() for c in category_str.split("</c>") if c.strip()] | |
| ref_box_pattern = r"(<ref>.*?</ref>)|(<box>.*?</box>)" | |
| current_label = None | |
| found_structured = False | |
| for m in re.finditer(ref_box_pattern, text, flags=re.IGNORECASE | re.DOTALL): | |
| token = m.group(0) | |
| if token.lower().startswith("<ref>"): | |
| label_raw = re.sub(r"</?ref>", "", token, flags=re.IGNORECASE).strip() | |
| if label_raw: | |
| current_label = label_raw | |
| else: | |
| content = re.sub(r"</?box>", "", token, flags=re.IGNORECASE) | |
| nums = re.findall(r"<\s*([0-9]+(?:\.[0-9]+)?)\s*>", content) | |
| coords = [float(n) for n in nums] | |
| if not coords: | |
| continue | |
| label = current_label | |
| if label is None: | |
| label = expected_cats[0] if expected_cats else "object" | |
| if len(coords) == 4: | |
| results.append({"type": "box", "coords": coords, "label": label}) | |
| elif len(coords) == 2: | |
| results.append({"type": "point", "coords": coords, "label": label}) | |
| found_structured = True | |
| if found_structured: | |
| return results | |
| box_pattern = r"<box>(.*?)</box>" | |
| parts = re.split(box_pattern, text) | |
| for i in range(1, len(parts), 2): | |
| preceding_text = parts[i - 1].lower() | |
| content = parts[i] | |
| label = expected_cats[0] if expected_cats else "object" | |
| for cat in expected_cats: | |
| if cat in preceding_text: | |
| label = cat | |
| break | |
| nums = re.findall(r"<\s*([0-9]+(?:\.[0-9]+)?)\s*>", content) | |
| coords = [float(n) for n in nums] | |
| if len(coords) == 4: | |
| results.append({"type": "box", "coords": coords, "label": label}) | |
| elif len(coords) == 2: | |
| results.append({"type": "point", "coords": coords, "label": label}) | |
| return results | |
| def resize_image_short_side(image, short_side_size): | |
| w, h = image.size | |
| if w <= h: | |
| new_w = short_side_size | |
| scale_factor = new_w / w | |
| new_h = int(h * scale_factor) | |
| else: | |
| new_h = short_side_size | |
| scale_factor = new_h / h | |
| new_w = int(w * scale_factor) | |
| resized_image = image.resize((new_w, new_h), Image.BILINEAR) | |
| return resized_image, scale_factor | |
| def draw_on_frame(frame_bgr, results, draw_label=True): | |
| pil_img = Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)) | |
| img_draw = pil_img.convert("RGBA") | |
| overlay = Image.new("RGBA", img_draw.size, (255, 255, 255, 0)) | |
| draw = ImageDraw.Draw(overlay) | |
| font = _load_font(20) | |
| w_img, h_img = pil_img.size | |
| parsed = [] | |
| for res in results: | |
| label = res.get("label", "object") | |
| color = get_color_for_label(label) | |
| if res.get("type") == "point": | |
| c = res["coords"] | |
| cx = max(0, min(w_img, c[0] * w_img / 1000)) | |
| cy = max(0, min(h_img, c[1] * h_img / 1000)) | |
| parsed.append(("point", label, color, cx, cy)) | |
| continue | |
| if "is_pixel" in res: | |
| x1, y1, bw, bh = res["coords"] | |
| x2, y2 = x1 + bw, y1 + bh | |
| else: | |
| c = res["coords"] | |
| if len(c) < 4: | |
| continue | |
| x1 = c[0] * w_img / 1000 | |
| y1 = c[1] * h_img / 1000 | |
| x2 = c[2] * w_img / 1000 | |
| y2 = c[3] * h_img / 1000 | |
| x1, y1, x2, y2 = max(0, x1), max(0, y1), min(w_img, x2), min(h_img, y2) | |
| x1, x2 = min(x1, x2), max(x1, x2) | |
| y1, y2 = min(y1, y2), max(y1, y2) | |
| parsed.append(("box", label, color, x1, y1, x2, y2)) | |
| for item in parsed: | |
| if item[0] == "box": | |
| _, _, color, x1, y1, x2, y2 = item | |
| fill_color = color + (65,) | |
| draw.rectangle([x1, y1, x2, y2], fill=fill_color, outline=color, width=4) | |
| elif item[0] == "point": | |
| _, _, color, cx, cy = item | |
| r = 10 | |
| draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=color, outline="white", width=2) | |
| if draw_label: | |
| for item in parsed: | |
| if item[0] == "box": | |
| _, label, color, x1, y1, x2, y2 = item | |
| if not label: | |
| continue | |
| t_box = draw.textbbox((0, 0), label, font=font) | |
| th = t_box[3] - t_box[1] | |
| tw = t_box[2] - t_box[0] | |
| pad_x, pad_y = 7, 4 | |
| tag_h = th + pad_y * 2 | |
| tag_w = tw + pad_x * 2 | |
| tag_y = y1 - tag_h - 2 | |
| if tag_y < 0: | |
| tag_y = y2 + 2 | |
| draw.rectangle([x1, tag_y, x1 + tag_w, tag_y + tag_h], fill=color) | |
| draw.text((x1 + pad_x, tag_y + pad_y), label, fill="white", font=font) | |
| elif item[0] == "point": | |
| _, label, color, cx, cy = item | |
| if not label: | |
| continue | |
| t_box = draw.textbbox((0, 0), label, font=font) | |
| th, tw = t_box[3] - t_box[1], t_box[2] - t_box[0] | |
| tx, ty = cx + 14, cy - th // 2 | |
| draw.rectangle([tx - 2, ty - 2, tx + tw + 6, ty + th + 4], fill=color) | |
| draw.text((tx + 2, ty), label, fill="white", font=font) | |
| combined = Image.alpha_composite(img_draw, overlay).convert("RGB") | |
| return cv2.cvtColor(np.array(combined), cv2.COLOR_RGB2BGR) | |
| # ============================================================ | |
| # 模型 | |
| # ============================================================ | |
| class EagleWorker: | |
| def __init__(self, model_path, device="cuda", generation_mode: str = "hybrid"): | |
| self.model_id = model_path | |
| self.device = device | |
| self.dtype = torch.bfloat16 | |
| self.generation_mode = generation_mode | |
| self.tokenizer = AutoTokenizer.from_pretrained( | |
| model_path, | |
| trust_remote_code=True, | |
| token=HF_TOKEN if HF_TOKEN else None, | |
| ) | |
| self.processor = AutoProcessor.from_pretrained( | |
| model_path, | |
| trust_remote_code=True, | |
| token=HF_TOKEN if HF_TOKEN else None, | |
| ) | |
| self.model = AutoModel.from_pretrained( | |
| model_path, | |
| torch_dtype=self.dtype, | |
| _attn_implementation="sdpa", | |
| trust_remote_code=True, | |
| token=HF_TOKEN if HF_TOKEN else None, | |
| ).to(device).eval() | |
| print("Model Loaded Successfully!") | |
| def build_messages(self, image, categories, question_override=None): | |
| if question_override is not None: | |
| user_text = question_override | |
| else: | |
| category_set_str = "</c>".join(categories) | |
| user_text = f"Locate all the instances that matches the following description: {category_set_str}." | |
| return [{"role": "user", "content": [ | |
| {"type": "image", "image": image}, | |
| {"type": "text", "text": user_text}, | |
| ]}] | |
| def generate(self, image, categories, generation_mode=None, | |
| max_new_tokens=4096, temp=0.7, top_p=0.9, top_k=50, | |
| question_override=None): | |
| messages = self.build_messages(image, categories, question_override=question_override) | |
| text = self.processor.py_apply_chat_template(messages, tokenize=False, add_generation_prompt=True) | |
| images, videos = self.processor.process_vision_info(messages) | |
| inputs = self.processor(text=[text], images=images, videos=videos, return_tensors="pt").to(self.device) | |
| pixel_values = inputs["pixel_values"].to(self.dtype) | |
| input_ids = inputs["input_ids"] | |
| attention_mask = inputs["attention_mask"] | |
| image_grid_hws = inputs.get("image_grid_hws", None) | |
| result = self.model.generate( | |
| pixel_values=pixel_values, input_ids=input_ids, | |
| attention_mask=attention_mask, image_grid_hws=image_grid_hws, | |
| tokenizer=self.tokenizer, max_new_tokens=max_new_tokens, | |
| use_cache=True, | |
| generation_mode=generation_mode if generation_mode is not None else self.generation_mode, | |
| temperature=temp, do_sample=True, top_p=top_p, | |
| repetition_penalty=1.1, verbose=True, | |
| ) | |
| token_sequence, out_info, output_text = [], "", "" | |
| if isinstance(result, tuple) and len(result) >= 3: | |
| output_text, token_sequence, out_info = result | |
| if generation_mode == "slow": | |
| token_sequence[-1] = ("ar", token_sequence[-1][1]) | |
| else: | |
| output_text = result | |
| return output_text, token_sequence, out_info | |
| # ============================================================ | |
| # 后处理 | |
| # ============================================================ | |
| def _postprocess_detections(detections, w, h): | |
| valid = [] | |
| for det in detections: | |
| if det["type"] == "box": | |
| c = det["coords"] | |
| rx1 = max(0, min(w - 1, int(c[0] * w / 1000))) | |
| ry1 = max(0, min(h - 1, int(c[1] * h / 1000))) | |
| rx2 = max(0, min(w - 1, int(c[2] * w / 1000))) | |
| ry2 = max(0, min(h - 1, int(c[3] * h / 1000))) | |
| box_w, box_h = rx2 - rx1, ry2 - ry1 | |
| if box_w <= 0 or box_h <= 0: | |
| continue | |
| valid.append({"type": "box", "coords": [rx1, ry1, box_w, box_h], | |
| "is_pixel": True, "label": det["label"]}) | |
| elif det["type"] == "point": | |
| valid.append(det) | |
| return valid | |
| def _parse_out_info_dict(out_info: str) -> dict: | |
| stats = {} | |
| if not out_info: | |
| return stats | |
| cleaned = re.sub(r"^[Ss]tast?ic\s*[Ii]nfo\s*,?\s*", "", out_info.strip()) | |
| for part in cleaned.split(";"): | |
| part = part.strip() | |
| if "=" in part: | |
| k, v = part.split("=", 1) | |
| stats[k.strip()] = v.strip() | |
| return stats | |
| def generate_dynamic_html(token_sequence, out_info, raw_text): | |
| uid = f"a{int(time.time() * 1000)}" | |
| css = f""" | |
| <style> | |
| .dc-root-{uid} {{ | |
| font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; | |
| border: 1px solid rgba(118, 185, 0, 0.25); border-radius: 12px; | |
| background: rgba(0, 0, 0, 0.55); overflow: visible; | |
| }} | |
| .dc-header-{uid} {{ | |
| display: flex; align-items: center; justify-content: space-between; flex-wrap: wrap; gap: 8px; | |
| padding: 10px 14px; | |
| background: linear-gradient(135deg, rgba(118, 185, 0, 0.25) 0%, rgba(63, 98, 0, 0.35) 100%); | |
| border-bottom: 1px solid rgba(118, 185, 0, 0.2); | |
| }} | |
| .dc-header-title-{uid} {{ font-weight: 700; font-size: 0.82em; color: #d9f99d; letter-spacing: 0.04em; text-transform: uppercase; }} | |
| .dc-legend-{uid} {{ display: flex; gap: 12px; align-items: center; flex-wrap: wrap; }} | |
| .dc-legend-item-{uid} {{ display: flex; align-items: center; gap: 5px; font-size: 0.72em; color: rgba(226, 232, 240, 0.85); }} | |
| .dc-legend-dot-{uid} {{ width: 8px; height: 8px; border-radius: 2px; display: inline-block; }} | |
| .dc-row-{uid} {{ display: flex; gap: 10px; padding: 12px 14px; border-bottom: 1px solid rgba(255,255,255,0.05); }} | |
| .dc-row-{uid}:last-child {{ border-bottom: none; }} | |
| .dc-val-{uid} {{ | |
| flex: 1; line-height: 1.9; word-wrap: break-word; color: #cbd5e1; font-size: 0.85em; | |
| display: flex; flex-wrap: wrap; gap: 4px; align-items: center; align-content: flex-start; | |
| }} | |
| @keyframes tk-{uid} {{ | |
| 0% {{ opacity: 0; transform: translateY(8px) scale(0.92); }} | |
| 60% {{ opacity: 1; transform: translateY(-2px) scale(1.02); }} | |
| 100% {{ opacity: 1; transform: translateY(0) scale(1); }} | |
| }} | |
| .tk-mtp-{uid}, .tk-ar-{uid} {{ | |
| opacity: 0; animation: tk-{uid} 0.35s ease-out forwards; | |
| border-radius: 5px; padding: 2px 7px; margin: 0; | |
| display: inline-block; max-width: 100%; | |
| font-size: 0.78em; font-weight: 600; | |
| font-family: 'Fira Code', Consolas, monospace; | |
| white-space: normal; word-break: break-all; | |
| }} | |
| .tk-mtp-{uid} {{ background: rgba(118, 185, 0, 0.15); border: 1px solid rgba(118, 185, 0, 0.55); color: #bbf7d0; }} | |
| .tk-ar-{uid} {{ background: rgba(230, 81, 0, 0.15); border: 1px solid rgba(230, 81, 0, 0.55); color: #fed7aa; }} | |
| .tk-stat-{uid} {{ | |
| opacity: 0; animation: tk-{uid} 0.4s ease-out forwards; | |
| background: rgba(118, 185, 0, 0.12); border: 1px solid rgba(118, 185, 0, 0.35); border-radius: 6px; | |
| padding: 4px 12px; display: inline-block; font-size: 0.78em; color: #d9f99d; font-weight: 600; | |
| }} | |
| .dc-raw-{uid} {{ padding: 0 14px 12px; }} | |
| .dc-raw-{uid} summary {{ cursor: pointer; color: #94a3b8; font-size: 0.78em; user-select: none; }} | |
| .dc-raw-{uid} summary:hover {{ color: #76b900; }} | |
| .dc-raw-pre-{uid} {{ | |
| background: rgba(0,0,0,0.45); border: 1px solid rgba(255,255,255,0.08); border-radius: 6px; | |
| padding: 10px; margin-top: 8px; | |
| font-family: 'Fira Code', Consolas, monospace; | |
| font-size: 0.74em; color: #cbd5e1; white-space: pre-wrap; word-break: break-all; | |
| }} | |
| </style> | |
| """ | |
| h = css + f'<div class="dc-root-{uid}">' | |
| h += (f'<div class="dc-header-{uid}">' | |
| f'<span class="dc-header-title-{uid}">Decoding Trace</span>' | |
| f'<div class="dc-legend-{uid}">' | |
| f'<div class="dc-legend-item-{uid}"><span class="dc-legend-dot-{uid}" style="background:#76b900;"></span>MTP Parallel</div>' | |
| f'<div class="dc-legend-item-{uid}"><span class="dc-legend-dot-{uid}" style="background:#e65100;"></span>AR Fallback</div>' | |
| f'</div></div>') | |
| tok_idx = 0 | |
| if out_info: | |
| stats = _parse_out_info_dict(out_info) | |
| bits = [] | |
| if "forward_step" in stats: | |
| bits.append(f"{stats['forward_step']} steps") | |
| if "num_tokens" in stats: | |
| bits.append(f"{stats['num_tokens']} tokens") | |
| if "num_boxes" in stats: | |
| bits.append(f"{stats['num_boxes']} boxes") | |
| if "switch_to_ar" in stats: | |
| n = stats["switch_to_ar"] | |
| bits.append(f"{n} AR fallback{'s' if n != '1' else ''}") | |
| if "tps" in stats: | |
| bits.append(f"{stats['tps']} tok/s") | |
| if "bps" in stats: | |
| bits.append(f"{stats['bps']} box/s") | |
| summary = " · ".join(bits) if bits else out_info.strip() | |
| h += (f'<div class="dc-row-{uid}" style="justify-content:flex-start;padding-top:8px;padding-bottom:4px;border-bottom:none;">' | |
| f'<span class="tk-stat-{uid}" style="animation-delay:0s">{summary}</span></div>') | |
| h += f'<div class="dc-row-{uid}"><div class="dc-val-{uid}">' | |
| if token_sequence: | |
| for item in token_sequence: | |
| if not isinstance(item, (list, tuple)) or len(item) < 2: | |
| continue | |
| decode_type = str(item[0]).lower() | |
| text = str(item[1]) | |
| safe = text.replace("<", "<").replace(">", ">") | |
| delay = f"{tok_idx * 0.06:.2f}s" | |
| cls = f"tk-ar-{uid}" if decode_type == "ar" else f"tk-mtp-{uid}" | |
| h += f'<span class="{cls}" style="animation-delay:{delay}">{safe}</span> ' | |
| tok_idx += 1 | |
| h += '</div></div>' | |
| if raw_text: | |
| safe_raw = raw_text.replace("<", "<").replace(">", ">") | |
| h += (f'<div class="dc-raw-{uid}"><details open><summary>Raw Response</summary>' | |
| f'<div class="dc-raw-pre-{uid}">{safe_raw}</div></details></div>') | |
| h += '</div>' | |
| return h | |
| def generate_raw_prompt(task_type, category): | |
| if not category: | |
| category = "objects" | |
| cats = "</c>".join(c.strip() for c in category.split(",") if c.strip()) | |
| if task_type == "Detection": | |
| return f"Locate all the instances that matches the following description: {cats}." | |
| elif task_type == "Grounding": | |
| return f"Locate all the instances that match the following description: {cats}." | |
| elif task_type == "OCR": | |
| return "Detect all the text in box format." | |
| elif task_type == "GUI": | |
| return f"Locate the region that matches the following description: {cats}." | |
| elif task_type == "Pointing": | |
| return f"Point to: {cats}." | |
| else: | |
| return f"Locate all the instances that matches the following description: {cats}." | |
| # ============================================================ | |
| # 模型初始化 | |
| # ============================================================ | |
| GLOBAL_WORKER = None | |
| def get_worker(): | |
| global GLOBAL_WORKER | |
| if GLOBAL_WORKER is None: | |
| try: | |
| MODEL_PATH = os.environ.get("MODEL_PATH", "nvidia/LocateAnything-3B") | |
| print(f"Loading model inside @spaces.GPU context: {MODEL_PATH}") | |
| GLOBAL_WORKER = EagleWorker(MODEL_PATH) | |
| except Exception as e: | |
| print(f"Failed to load model: {e}. Will run in Mock Mode.") | |
| GLOBAL_WORKER = None | |
| return GLOBAL_WORKER | |
| def _prepare_image_for_model(pil_img, short_size): | |
| process_img = pil_img.copy() | |
| if short_size is not None and short_size > 0: | |
| process_img, _ = resize_image_short_side(process_img, min(int(short_size), 1024)) | |
| else: | |
| if min(process_img.size) > 1024: | |
| process_img, _ = resize_image_short_side(process_img, 1024) | |
| return process_img | |
| # ============================================================ | |
| # 用户数据收集(HuggingFace Public Dataset) | |
| # 策略:one-record-per-file,配合按日目录 + 容器级 SESSION_ID | |
| # 每条记录:data/<YYYY-MM-DD>/<session_id>__<entry_id>.jsonl | |
| # CommitScheduler 只会新增文件,不会覆盖其它 session 的数据 | |
| # ============================================================ | |
| LOG_DATASET_REPO = os.environ.get("LOG_DATASET_REPO", "woshichaoren123/log") | |
| _LOG_DIR = Path(tempfile.mkdtemp(prefix="hf_log_")) | |
| _SESSION_ID = uuid.uuid4().hex[:8] | |
| _log_scheduler = None | |
| if LOG_DATASET_REPO and LOG_HF_TOKEN: | |
| try: | |
| _log_scheduler = CommitScheduler( | |
| repo_id=LOG_DATASET_REPO, | |
| repo_type="dataset", | |
| folder_path=str(_LOG_DIR), | |
| path_in_repo="data", | |
| every=3, | |
| token=LOG_HF_TOKEN, | |
| squash_history=False, | |
| ) | |
| print(f"[LOG] Dataset logging enabled -> {LOG_DATASET_REPO} " | |
| f"(session={_SESSION_ID}, dir={_LOG_DIR})") | |
| except Exception as e: | |
| _log_scheduler = None | |
| print(f"[LOG] Dataset logging disabled: {e}") | |
| else: | |
| print("[LOG] Dataset logging disabled (LOG_HF_TOKEN not set)") | |
| def _pil_to_b64(pil_img): | |
| buf = io.BytesIO() | |
| pil_img.save(buf, "PNG") | |
| return base64.b64encode(buf.getvalue()).decode("ascii") | |
| def _atomic_write_text(path: Path, text: str): | |
| tmp_path = path.with_name(path.name + ".tmp") | |
| with open(tmp_path, "w", encoding="utf-8") as f: | |
| f.write(text) | |
| os.replace(tmp_path, path) | |
| def _log_to_dataset( | |
| input_type, category, model_mode, raw_prompt, | |
| output_text="", input_image=None, output_image=None, | |
| extra=None, | |
| ): | |
| if _log_scheduler is None: | |
| return | |
| try: | |
| entry_id = f"{int(time.time())}_{uuid.uuid4().hex[:6]}" | |
| ts = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) | |
| date_str = time.strftime("%Y-%m-%d", time.gmtime()) | |
| input_b64 = None | |
| if input_image is not None and isinstance(input_image, Image.Image): | |
| input_b64 = _pil_to_b64(input_image) | |
| output_b64 = None | |
| if output_image is not None and isinstance(output_image, Image.Image): | |
| output_b64 = _pil_to_b64(output_image) | |
| record = { | |
| "id": entry_id, | |
| "session_id": _SESSION_ID, | |
| "timestamp": ts, | |
| "input_type": input_type, | |
| "category": category, | |
| "model_mode": model_mode, | |
| "raw_prompt": raw_prompt, | |
| "output_text": output_text, | |
| "input_image_b64": input_b64, | |
| "output_image_b64": output_b64, | |
| } | |
| if extra: | |
| record.update(extra) | |
| day_dir = _LOG_DIR / date_str | |
| day_dir.mkdir(parents=True, exist_ok=True) | |
| log_file = day_dir / f"{_SESSION_ID}__{entry_id}.jsonl" | |
| payload = json.dumps(record, ensure_ascii=False) + "\n" | |
| with _log_scheduler.lock: | |
| _atomic_write_text(log_file, payload) | |
| except Exception as e: | |
| print(f"[LOG] Failed to log to dataset: {e}") | |
| def _maybe_log_inference( | |
| input_type: str, | |
| category: str, | |
| model_mode: str, | |
| raw_prompt: str, | |
| output_text: str, | |
| input_path: str | None = None, | |
| output_path: str | None = None, | |
| extra: dict | None = None, | |
| ): | |
| try: | |
| input_image = None | |
| output_image = None | |
| if input_path and os.path.exists(input_path): | |
| if input_type == "image": | |
| input_image = Image.open(input_path).convert("RGB") | |
| elif input_type == "video": | |
| cap = cv2.VideoCapture(input_path) | |
| ret, frame = cap.read() | |
| cap.release() | |
| if ret: | |
| input_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) | |
| if output_path and os.path.exists(output_path) and input_type == "image": | |
| output_image = Image.open(output_path).convert("RGB") | |
| categories_list = [c.strip() for c in category.split(",") if c.strip()] | |
| _log_to_dataset( | |
| input_type=input_type, | |
| category=", ".join(categories_list) if categories_list else category, | |
| model_mode=model_mode, | |
| raw_prompt=raw_prompt, | |
| output_text=output_text, | |
| input_image=input_image, | |
| output_image=output_image, | |
| extra=extra, | |
| ) | |
| except Exception as e: | |
| print(f"[LOG] Failed to prepare log record: {e}") | |
| # ============================================================ | |
| # GPU 时间预算与推理保护(按模式区分) | |
| # ============================================================ | |
| GPU_HARD_LIMIT_IMAGE = 30 | |
| GPU_HARD_LIMIT_VIDEO = 240 | |
| PHASE2_RESERVE = 55 | |
| SAFETY_MARGIN = 25 | |
| EST_SECONDS_PER_FRAME = 20 | |
| def run_image_gpu_api( | |
| image_path: str, category: str, model_mode: str, temp: float, top_p: float, top_k: int, | |
| short_size: int | None, question_override: str | None | |
| ): | |
| image_in = Image.open(image_path).convert("RGB") | |
| categories_list = [c.strip() for c in category.split(",") if c.strip()] | |
| category_str = "</c>".join(categories_list) | |
| process_img = _prepare_image_for_model(image_in, short_size) | |
| worker = get_worker() | |
| if worker: | |
| output_text, token_sequence, out_info = worker.generate( | |
| process_img, categories_list, model_mode, | |
| temp=temp, top_p=top_p, top_k=top_k, | |
| question_override=question_override, | |
| ) | |
| else: | |
| # Mock mode fallback | |
| output_text = "Mock detection: <ref>sushi</ref><box><240><480><620><940></box> and <ref>book</ref><box><50><120><400><380></box>" | |
| token_sequence = [] | |
| out_info = "forward_step=1;num_tokens=18;num_boxes=2;tps=45;bps=15" | |
| detections = parse_mixed_results(output_text, category_str) | |
| frame_bgr = cv2.cvtColor(np.array(image_in), cv2.COLOR_RGB2BGR) | |
| out_img_bgr = draw_on_frame(frame_bgr, detections, draw_label=True) | |
| output_image = Image.fromarray(cv2.cvtColor(out_img_bgr, cv2.COLOR_BGR2RGB)) | |
| # Save to temp file | |
| temp_dir = tempfile.mkdtemp() | |
| out_img_path = os.path.join(temp_dir, "output.png") | |
| output_image.save(out_img_path) | |
| stats = _parse_out_info_dict(out_info) | |
| # Simplified summary lists | |
| detections_summary = [] | |
| for det in detections: | |
| detections_summary.append({ | |
| "label": det.get("label", "object"), | |
| "type": det.get("type", "box"), | |
| "coords": [round(c, 2) for c in det.get("coords", [])] | |
| }) | |
| html = generate_dynamic_html(token_sequence, out_info, output_text) | |
| return out_img_path, stats, output_text, detections_summary, html | |
| def run_video_gpu_api( | |
| video_path: str, category: str, model_mode: str, temp: float, top_p: float, top_k: int, | |
| short_size: int | None, question_override: str | None, max_video_frames: int | |
| ): | |
| import subprocess as _sp | |
| total_start = time.time() | |
| max_frames = int(max_video_frames) if max_video_frames else 4 | |
| categories_list = [c.strip() for c in category.split(",") if c.strip()] | |
| category_str = "</c>".join(categories_list) | |
| cap = cv2.VideoCapture(video_path) | |
| fps = cap.get(cv2.CAP_PROP_FPS) | |
| vid_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) | |
| vid_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) | |
| all_frames = [] | |
| while cap.isOpened(): | |
| ret, frame = cap.read() | |
| if not ret: | |
| break | |
| all_frames.append(frame) | |
| cap.release() | |
| total = len(all_frames) | |
| if total == 0: | |
| raise ValueError("Failed to read any frames from the video.") | |
| # Sample frames | |
| if total <= max_frames: | |
| sample_indices = list(range(total)) | |
| else: | |
| sample_indices = [int(round(i * (total - 1) / (max_frames - 1))) for i in range(max_frames)] | |
| sampled_frames = [all_frames[i] for i in sample_indices] | |
| n_sampled = len(sampled_frames) | |
| # Budget check | |
| time_already_used = time.time() - total_start | |
| available_for_inference = GPU_HARD_LIMIT_VIDEO - time_already_used - PHASE2_RESERVE - SAFETY_MARGIN | |
| estimated_inference_time = n_sampled * EST_SECONDS_PER_FRAME | |
| if estimated_inference_time > available_for_inference: | |
| max_feasible = max(1, int(available_for_inference // EST_SECONDS_PER_FRAME)) | |
| if total <= max_feasible: | |
| sample_indices = list(range(total)) | |
| else: | |
| sample_indices = [int(round(i * (total - 1) / (max_feasible - 1))) for i in range(max_feasible)] | |
| sampled_frames = [all_frames[i] for i in sample_indices] | |
| n_sampled = len(sampled_frames) | |
| out_fps = max(1.0, n_sampled / (total / fps)) if fps > 0 else 5.0 | |
| del all_frames | |
| gc.collect() | |
| inference_results = [] | |
| processed_count = 0 | |
| early_stopped = False | |
| early_stop_reason = "" | |
| for i, frame in enumerate(sampled_frames): | |
| elapsed_since_start = time.time() - total_start | |
| remaining_total = GPU_HARD_LIMIT_VIDEO - elapsed_since_start | |
| if remaining_total < PHASE2_RESERVE + SAFETY_MARGIN: | |
| early_stopped = True | |
| early_stop_reason = f"GPU time budget running out. Only {remaining_total:.0f}s left." | |
| break | |
| pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) | |
| process_img = _prepare_image_for_model(pil_img, short_size) | |
| worker = get_worker() | |
| if worker: | |
| output_text, _, _ = worker.generate( | |
| process_img, categories_list, model_mode, | |
| temp=temp, top_p=top_p, top_k=top_k, | |
| question_override=question_override, | |
| ) | |
| else: | |
| output_text = f"Mock video detection: <ref>person</ref><box><100><150><800><900></box>" | |
| inference_results.append(output_text) | |
| processed_count += 1 | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| gc.collect() | |
| if processed_count == 0: | |
| raise RuntimeError("GPU budget exceeded before processing any frames.") | |
| sampled_frames_for_draw = sampled_frames[:processed_count] | |
| inference_results_for_draw = inference_results[:processed_count] | |
| tmp_raw = tempfile.mktemp(suffix=".raw.mp4") | |
| out_video_path = tempfile.mktemp(suffix=".mp4") | |
| out = cv2.VideoWriter(tmp_raw, cv2.VideoWriter_fourcc(*"mp4v"), out_fps, (vid_w, vid_h)) | |
| detections_summary = [] | |
| for i, (frame, output_text) in enumerate(zip(sampled_frames_for_draw, inference_results_for_draw)): | |
| detections = parse_mixed_results(output_text, category_str) | |
| valid_results = _postprocess_detections(detections, vid_w, vid_h) | |
| frame_to_draw = draw_on_frame(frame, valid_results, draw_label=True) | |
| out.write(frame_to_draw) | |
| for det in valid_results: | |
| detections_summary.append({ | |
| "frame": i + 1, | |
| "label": det.get("label", "object"), | |
| "type": det.get("type", "box"), | |
| "coords": det.get("coords", []) | |
| }) | |
| out.release() | |
| # ffmpeg re-encode | |
| elapsed_now = time.time() - total_start | |
| remaining_now = GPU_HARD_LIMIT_VIDEO - elapsed_now | |
| if remaining_now > 15: | |
| try: | |
| ffmpeg_timeout = max(10, int(remaining_now - 5)) | |
| _sp.run( | |
| ["ffmpeg", "-y", "-i", tmp_raw, "-c:v", "libx264", | |
| "-preset", "ultrafast", "-crf", "23", "-pix_fmt", "yuv420p", | |
| "-movflags", "+faststart", out_video_path], | |
| check=True, capture_output=True, timeout=ffmpeg_timeout, | |
| ) | |
| os.remove(tmp_raw) | |
| except Exception: | |
| if os.path.exists(tmp_raw): | |
| os.replace(tmp_raw, out_video_path) | |
| else: | |
| os.replace(tmp_raw, out_video_path) | |
| total_time = time.time() - total_start | |
| stats = { | |
| "total_frames": total, | |
| "sampled_frames": n_sampled, | |
| "processed_frames": processed_count, | |
| "total_time_seconds": round(total_time, 2), | |
| "early_stopped": early_stopped, | |
| "early_stop_reason": early_stop_reason | |
| } | |
| raw_combined = "\n---\n".join(inference_results_for_draw) | |
| timing_summary = ( | |
| f"Processed {processed_count}/{n_sampled} sampled frames " | |
| f"({total} total) in {total_time:.1f}s" | |
| ) | |
| if early_stopped: | |
| timing_summary += f" — {early_stop_reason}" | |
| html = generate_dynamic_html([], "", timing_summary + "\n\n" + raw_combined) | |
| return out_video_path, stats, raw_combined, detections_summary, html | |
| # ============================================================ | |
| # GRADIO SERVER APP | |
| # ============================================================ | |
| app = Server() | |
| # Serve static assets folder | |
| assets_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") | |
| if os.path.exists(assets_dir): | |
| app.mount("/assets", StaticFiles(directory=assets_dir), name="assets") | |
| async def homepage(): | |
| html_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html") | |
| if os.path.exists(html_path): | |
| with open(html_path, "r", encoding="utf-8") as f: | |
| return HTMLResponse(f.read()) | |
| return HTMLResponse("<h1 style='color: #ef4444; font-family: Inter, sans-serif; text-align: center; margin-top: 100px;'>index.html is missing</h1>") | |
| def run_inference_api( | |
| input_type: str, | |
| image_file: Any = None, | |
| video_file: Any = None, | |
| task_type: str = "Detection", | |
| category: str = "objects", | |
| model_mode: str = "hybrid", | |
| temp: float = 0.7, | |
| top_p: float = 0.9, | |
| top_k: int = 20, | |
| short_size: int | None = None, | |
| question_override: str | None = None, | |
| max_video_frames: int = 4 | |
| ) -> tuple[FileData | None, FileData | None, dict]: | |
| """Exposed Gradio Queueing Endpoint for custom frontend interactions. | |
| ZeroGPU allocation is triggered directly at this endpoint boundary. | |
| Supports both FileData dict (from web uploads) and local strings (for examples). | |
| """ | |
| try: | |
| if not category: | |
| category = "objects" | |
| final_prompt = question_override | |
| if not final_prompt: | |
| final_prompt = generate_raw_prompt(task_type, category) | |
| if input_type == "Image": | |
| if not image_file: | |
| return None, None, {"success": False, "error": "Please upload an image."} | |
| # Resolve image path (from either FileData upload or local example string) | |
| if isinstance(image_file, str): | |
| img_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), image_file) | |
| elif isinstance(image_file, dict): | |
| img_path = image_file.get("path") | |
| else: | |
| img_path = getattr(image_file, "path", None) | |
| if not img_path or not os.path.exists(img_path): | |
| return None, None, {"success": False, "error": f"Invalid image file path: {img_path}"} | |
| out_img_path, stats, raw_text, detections, html = run_image_gpu_api( | |
| img_path, category, model_mode, temp, top_p, top_k, short_size, final_prompt | |
| ) | |
| meta = { | |
| "success": True, | |
| "input_type": "Image", | |
| "stats": stats, | |
| "raw_text": raw_text, | |
| "detections": detections, | |
| "final_prompt": final_prompt, | |
| "html": html, | |
| } | |
| _maybe_log_inference( | |
| input_type="image", | |
| category=category, | |
| model_mode=model_mode, | |
| raw_prompt=final_prompt, | |
| output_text=raw_text, | |
| input_path=img_path, | |
| output_path=out_img_path, | |
| extra={"task_type": task_type, "detections": detections, "stats": stats}, | |
| ) | |
| return FileData(path=out_img_path), None, meta | |
| else: | |
| if not video_file: | |
| return None, None, {"success": False, "error": "Please upload a video."} | |
| # Resolve video path (from either FileData upload or local example string) | |
| if isinstance(video_file, str): | |
| vid_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), video_file) | |
| elif isinstance(video_file, dict): | |
| vid_path = video_file.get("path") | |
| else: | |
| vid_path = getattr(video_file, "path", None) | |
| if not vid_path or not os.path.exists(vid_path): | |
| return None, None, {"success": False, "error": f"Invalid video file path: {vid_path}"} | |
| out_vid_path, stats, raw_text, detections, html = run_video_gpu_api( | |
| vid_path, category, model_mode, temp, top_p, top_k, short_size, final_prompt, max_video_frames | |
| ) | |
| meta = { | |
| "success": True, | |
| "input_type": "Video", | |
| "stats": stats, | |
| "raw_text": raw_text, | |
| "detections": detections, | |
| "final_prompt": final_prompt, | |
| "html": html, | |
| } | |
| _maybe_log_inference( | |
| input_type="video", | |
| category=category, | |
| model_mode=model_mode, | |
| raw_prompt=final_prompt, | |
| output_text=raw_text, | |
| input_path=vid_path, | |
| extra={ | |
| "task_type": task_type, | |
| "detections": detections, | |
| "stats": stats, | |
| "video_total_frames": stats.get("total_frames"), | |
| "video_sampled_frames": stats.get("sampled_frames"), | |
| "video_processed_frames": stats.get("processed_frames"), | |
| }, | |
| ) | |
| return None, FileData(path=out_vid_path), meta | |
| except Exception as e: | |
| import traceback | |
| traceback.print_exc() | |
| return None, None, {"success": False, "error": str(e)} | |
| if __name__ == "__main__": | |
| app.launch(show_error=True) |