#!/usr/bin/env python
# -*- coding: utf-8 -*-
import spaces # MUST BE THE ABSOLUTE FIRST IMPORT FOR ZEROGPU EMULATION
import gradio as gr
from gradio import Server
from gradio.data_classes import FileData
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
import cv2
import numpy as np
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import tempfile
import re
import time
import base64
import gc
import io
import json
import uuid
from pathlib import Path
from typing import Any
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, AutoModel, AutoTokenizer
# Location of the bundled LXGW WenKai font: an "assets" directory next to this file.
_FONT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "LXGWWenKai-Bold.ttf")
# Retrieve optional HF Token from typical env variables.
# The first non-empty value among the three variables wins; the result is falsy when none is set.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("MODEL_HF_TOKEN")
def _load_font(size=20):
    """Return a font of the given size, preferring the bundled CJK typeface.

    Candidates are tried in order: the LXGW WenKai file at ``_FONT_PATH``
    (it has to be placed in the ``assets/`` directory beforehand and is only
    attempted when the file exists), then ``DejaVuSans-Bold.ttf``. When no
    candidate can be loaded, Pillow's built-in default font is returned.
    """
    candidates = []
    if os.path.exists(_FONT_PATH):
        candidates.append(_FONT_PATH)
    candidates.append("DejaVuSans-Bold.ttf")

    for candidate in candidates:
        try:
            return ImageFont.truetype(candidate, size)
        except Exception:
            # Best effort: an unusable font simply means the next one is tried.
            continue
    return ImageFont.load_default()
# ============================================================
# Colors / parsing / drawing
# ============================================================
def get_color_for_label(label):
    """Pick a stable RGB colour for ``label`` from a fixed six-entry palette.

    The palette index is the sum of the label's character code points modulo
    the palette size, so equal labels always receive the same colour.
    """
    palette = (
        (8, 145, 178),
        (220, 38, 38),
        (22, 163, 74),
        (37, 99, 235),
        (217, 119, 6),
        (147, 51, 234),
    )
    code_point_total = 0
    for ch in label:
        code_point_total += ord(ch)
    return palette[code_point_total % len(palette)]
# A single coordinate token such as ``<240>`` or ``< 12.5 >``.
_COORD_TOKEN_RE = re.compile(r"<\s*([0-9]+(?:\.[0-9]+)?)\s*>")

# One pass over the model output recognises, in order of preference:
#   * a label wrapped in <ref>...</ref>
#   * a label wrapped in square brackets
#   * a coordinate group wrapped in <box>...</box>
#   * a bare run of coordinate tokens with no wrapper
_RESULT_TOKEN_RE = re.compile(
    r"<ref>(?P<ref>.*?)</ref>"
    r"|\[(?P<bracket>[^\[\]]*)\]"
    r"|<box>(?P<box>.*?)</box>"
    r"|(?P<bare>(?:<\s*[0-9]+(?:\.[0-9]+)?\s*>\s*)+)",
    flags=re.IGNORECASE | re.DOTALL,
)

# Category lists may be separated by commas or by any tag-like token.
_CATEGORY_SEPARATOR_RE = re.compile(r"<[^<>]*>|,")


def parse_mixed_results(text, category_str=""):
    """Extract labelled boxes and points from raw model output.

    NOTE(review): the previous implementation could not run. It split the
    category string on an empty separator (``ValueError``) and used
    ``r"?ref>"`` / ``r"?box>"`` as patterns (``re.error``); the tag literals
    appear to have been lost from the source. The ``<ref>`` / ``<box>`` tag
    names used here are reconstructed from those fragments -- confirm them
    against the model's real output format.

    Args:
        text: Raw text produced by the model.
        category_str: Expected categories, separated by commas or by tag-like
            tokens. Used only to pick a label when the text carries none.

    Returns:
        A list of dicts with keys ``type`` (``"box"`` for four coordinates,
        ``"point"`` for two), ``coords`` (list of floats, unscaled) and
        ``label``. Coordinate groups of any other length are ignored.

    Label resolution for each coordinate group, first match wins:
        1. the most recent explicit label (``<ref>`` or ``[...]``), which
           stays in effect for all following groups;
        2. an expected category mentioned in the text since the previous
           coordinate group;
        3. the first expected category;
        4. the literal ``"object"``.
    """
    text = text or ""
    expected_cats = [
        cat.strip().lower()
        for cat in _CATEGORY_SEPARATOR_RE.split(category_str or "")
        if cat.strip()
    ]
    default_label = expected_cats[0] if expected_cats else "object"

    results = []
    current_label = None
    last_group_end = 0
    for match in _RESULT_TOKEN_RE.finditer(text):
        label_raw = match.group("ref")
        if label_raw is None:
            label_raw = match.group("bracket")
        if label_raw is not None:
            # Label token: remember it (empty labels are ignored) and move on.
            label_raw = label_raw.strip()
            if label_raw:
                current_label = label_raw
            continue

        content = match.group("box")
        if content is None:
            content = match.group("bare")
        preceding_text = text[last_group_end:match.start()].lower()
        last_group_end = match.end()

        coords = [float(num) for num in _COORD_TOKEN_RE.findall(content)]
        if len(coords) == 4:
            kind = "box"
        elif len(coords) == 2:
            kind = "point"
        else:
            continue

        label = current_label
        if label is None:
            label = next(
                (cat for cat in expected_cats if cat in preceding_text),
                default_label,
            )
        results.append({"type": kind, "coords": coords, "label": label})
    return results
def resize_image_short_side(image, short_side_size):
    """Resize ``image`` so that its shorter side equals ``short_side_size``.

    The longer side is scaled by the same factor and truncated to an integer.
    Bilinear resampling is used.

    Returns:
        A ``(resized_image, scale_factor)`` tuple, where ``scale_factor`` is
        the ratio of the new short side to the original short side.
    """
    width, height = image.size
    if width > height:
        # Landscape: the height is the short side.
        scale_factor = short_side_size / height
        target_size = (int(width * scale_factor), short_side_size)
    else:
        # Portrait or square: the width is the short side.
        scale_factor = short_side_size / width
        target_size = (short_side_size, int(height * scale_factor))
    return image.resize(target_size, Image.BILINEAR), scale_factor
def draw_on_frame(frame_bgr, results, draw_label=True):
    """Render detection boxes and points on top of a BGR frame.

    Args:
        frame_bgr: Image array in BGR channel order.
        results: Iterable of dicts with ``coords`` and optionally ``label``,
            ``type`` and ``is_pixel``. Point coordinates, and box coordinates
            without ``is_pixel``, are scaled by ``image_size / 1000``. When the
            ``is_pixel`` key is present the box is read as
            ``[x, y, width, height]`` in pixels.
        draw_label: When true, a filled tag with the label text is drawn for
            every shape that has a non-empty label.

    Returns:
        A new BGR array with the annotations composited onto the frame.
    """
    base_rgb = Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
    canvas = base_rgb.convert("RGBA")
    overlay = Image.new("RGBA", canvas.size, (255, 255, 255, 0))
    draw = ImageDraw.Draw(overlay)
    font = _load_font(20)
    w_img, h_img = base_rgb.size

    # First pass: convert every result into pixel-space shape tuples.
    shapes = []
    for res in results:
        label = res.get("label", "object")
        color = get_color_for_label(label)
        coords = res["coords"]

        if res.get("type") == "point":
            px = max(0, min(w_img, coords[0] * w_img / 1000))
            py = max(0, min(h_img, coords[1] * h_img / 1000))
            shapes.append(("point", label, color, px, py))
            continue

        if "is_pixel" in res:
            left, top, box_w, box_h = coords
            right, bottom = left + box_w, top + box_h
        elif len(coords) >= 4:
            left = coords[0] * w_img / 1000
            top = coords[1] * h_img / 1000
            right = coords[2] * w_img / 1000
            bottom = coords[3] * h_img / 1000
        else:
            continue

        # Clamp to the image, then make sure the corners are ordered.
        left, top = max(0, left), max(0, top)
        right, bottom = min(w_img, right), min(h_img, bottom)
        left, right = min(left, right), max(left, right)
        top, bottom = min(top, bottom), max(top, bottom)
        shapes.append(("box", label, color, left, top, right, bottom))

    # Second pass: geometry (translucent boxes, solid points).
    for shape in shapes:
        kind, color = shape[0], shape[2]
        if kind == "box":
            draw.rectangle(list(shape[3:]), fill=color + (65,), outline=color, width=4)
        elif kind == "point":
            px, py = shape[3:]
            radius = 10
            draw.ellipse(
                [px - radius, py - radius, px + radius, py + radius],
                fill=color, outline="white", width=2,
            )

    # Third pass: label tags, drawn last so they sit above all geometry.
    if draw_label:
        for shape in shapes:
            kind, label, color = shape[:3]
            if kind == "box":
                if not label:
                    continue
                left, top, right, bottom = shape[3:]
                text_box = draw.textbbox((0, 0), label, font=font)
                th = text_box[3] - text_box[1]
                tw = text_box[2] - text_box[0]
                pad_x, pad_y = 7, 4
                tag_h = th + pad_y * 2
                tag_w = tw + pad_x * 2
                tag_y = top - tag_h - 2
                if tag_y < 0:
                    # No room above the box: place the tag below it instead.
                    tag_y = bottom + 2
                draw.rectangle([left, tag_y, left + tag_w, tag_y + tag_h], fill=color)
                draw.text((left + pad_x, tag_y + pad_y), label, fill="white", font=font)
            elif kind == "point":
                if not label:
                    continue
                px, py = shape[3:]
                text_box = draw.textbbox((0, 0), label, font=font)
                th = text_box[3] - text_box[1]
                tw = text_box[2] - text_box[0]
                tx = px + 14
                ty = py - th // 2
                draw.rectangle([tx - 2, ty - 2, tx + tw + 6, ty + th + 4], fill=color)
                draw.text((tx + 2, ty), label, fill="white", font=font)

    combined = Image.alpha_composite(canvas, overlay).convert("RGB")
    return cv2.cvtColor(np.array(combined), cv2.COLOR_RGB2BGR)
# ============================================================
# Model
# ============================================================
class EagleWorker:
    """Bundle the tokenizer, processor and model behind one ``generate`` call.

    All three components are loaded from ``model_path`` with
    ``trust_remote_code=True``. The processor and model methods used here
    (``py_apply_chat_template``, ``process_vision_info`` and the extended
    ``generate`` keywords) are not standard ``transformers`` API, so they are
    presumably supplied by the checkpoint's remote code.
    """

    def __init__(self, model_path, device="cuda", generation_mode: str = "hybrid"):
        """Load tokenizer, processor and model, and move the model to ``device``.

        Args:
            model_path: Hub id or local path of the checkpoint.
            device: Target device for the model and the processed inputs.
            generation_mode: Mode used by ``generate`` when the caller does
                not pass one explicitly.
        """
        self.model_id = model_path
        self.device = device
        self.dtype = torch.bfloat16
        self.generation_mode = generation_mode

        # An empty token is normalised to None so anonymous access is used.
        hub_token = HF_TOKEN if HF_TOKEN else None
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            token=hub_token,
        )
        self.processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
            token=hub_token,
        )
        self.model = AutoModel.from_pretrained(
            model_path,
            torch_dtype=self.dtype,
            _attn_implementation="sdpa",
            trust_remote_code=True,
            token=hub_token,
        ).to(device).eval()
        print("Model Loaded Successfully!")

    def build_messages(self, image, categories, question_override=None):
        """Build the single-turn chat payload (one image plus one text part).

        ``question_override`` is used verbatim whenever it is not ``None``
        (an empty string counts as an override). Otherwise a localisation
        prompt is built from ``categories``.
        """
        if question_override is not None:
            user_text = question_override
        else:
            # NOTE(review): categories are joined with an empty separator, so
            # several categories are fused into one word. Confirm whether a
            # separator token is expected by the model.
            category_set_str = "".join(categories)
            user_text = f"Locate all the instances that matches the following description: {category_set_str}."
        return [{"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": user_text},
        ]}]

    @torch.no_grad()
    def generate(self, image, categories, generation_mode=None,
                 max_new_tokens=4096, temp=0.7, top_p=0.9, top_k=50,
                 question_override=None):
        """Run one sampled generation pass for ``image``.

        Args:
            image: Image handed to the processor through the chat message.
            categories: Category names used when no override is given.
            generation_mode: Overrides ``self.generation_mode`` when not None.
            max_new_tokens: Generation length limit.
            temp: Sampling temperature.
            top_p: Nucleus sampling threshold.
            top_k: Accepted for interface compatibility.
                NOTE(review): this value is not forwarded to
                ``model.generate`` -- confirm whether that is intentional.
            question_override: Prompt text that replaces the built-in prompt.

        Returns:
            ``(output_text, token_sequence, out_info)``. When the model does
            not return a tuple of at least three items, its result is returned
            as ``output_text`` with an empty list and an empty string.
        """
        # Resolve the mode once so the model call and the post-processing
        # below always agree on which mode was actually used.
        effective_mode = generation_mode if generation_mode is not None else self.generation_mode

        messages = self.build_messages(image, categories, question_override=question_override)
        text = self.processor.py_apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        images, videos = self.processor.process_vision_info(messages)
        inputs = self.processor(text=[text], images=images, videos=videos, return_tensors="pt").to(self.device)
        pixel_values = inputs["pixel_values"].to(self.dtype)
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        image_grid_hws = inputs.get("image_grid_hws", None)

        result = self.model.generate(
            pixel_values=pixel_values, input_ids=input_ids,
            attention_mask=attention_mask, image_grid_hws=image_grid_hws,
            tokenizer=self.tokenizer, max_new_tokens=max_new_tokens,
            use_cache=True,
            generation_mode=effective_mode,
            temperature=temp, do_sample=True, top_p=top_p,
            repetition_penalty=1.1, verbose=True,
        )

        if isinstance(result, tuple) and len(result) >= 3:
            # Slice so that a longer tuple does not break the unpacking.
            output_text, token_sequence, out_info = result[:3]
            if effective_mode == "slow" and token_sequence:
                # In slow mode the last entry is re-tagged as "ar"
                # (presumably "autoregressive" -- confirm). The emptiness
                # check avoids an IndexError when nothing was generated.
                token_sequence[-1] = ("ar", token_sequence[-1][1])
            return output_text, token_sequence, out_info
        return result, [], ""
# ============================================================
# Post-processing
# ============================================================
def _postprocess_detections(detections, w, h):
valid = []
for det in detections:
if det["type"] == "box":
c = det["coords"]
rx1 = max(0, min(w - 1, int(c[0] * w / 1000)))
ry1 = max(0, min(h - 1, int(c[1] * h / 1000)))
rx2 = max(0, min(w - 1, int(c[2] * w / 1000)))
ry2 = max(0, min(h - 1, int(c[3] * h / 1000)))
box_w, box_h = rx2 - rx1, ry2 - ry1
if box_w <= 0 or box_h <= 0:
continue
valid.append({"type": "box", "coords": [rx1, ry1, box_w, box_h],
"is_pixel": True, "label": det["label"]})
elif det["type"] == "point":
valid.append(det)
return valid
def _parse_out_info_dict(out_info: str) -> dict:
stats = {}
if not out_info:
return stats
cleaned = re.sub(r"^[Ss]tast?ic\s*[Ii]nfo\s*,?\s*", "", out_info.strip())
for part in cleaned.split(";"):
part = part.strip()
if "=" in part:
k, v = part.split("=", 1)
stats[k.strip()] = v.strip()
return stats
def generate_raw_prompt(task_type, category):
    """Return the prompt text for ``task_type`` and a comma-separated ``category``.

    A falsy ``category`` is replaced by ``"objects"``. Unknown task types use
    the same template as ``"Detection"``. The ``"OCR"`` prompt ignores the
    categories entirely.
    """
    if not category:
        category = "objects"
    stripped_parts = [part.strip() for part in category.split(",")]
    # NOTE(review): the parts are joined with an empty separator, so
    # "cat, dog" becomes "catdog". Confirm whether a separator token is
    # expected by the model.
    cats = "".join(part for part in stripped_parts if part)

    detection_template = "Locate all the instances that matches the following description: {}."
    templates = (
        ("Detection", detection_template),
        ("Grounding", "Locate all the instances that match the following description: {}."),
        ("OCR", "Detect all the text in box format."),
        ("GUI", "Locate the region that matches the following description: {}."),
        ("Pointing", "Point to: {}."),
    )
    for known_task, template in templates:
        if task_type == known_task:
            return template.format(cats)
    return detection_template.format(cats)
# ============================================================
# Model initialization
# ============================================================
# Process-wide model wrapper; stays None until a load attempt succeeds.
GLOBAL_WORKER = None


def get_worker():
    """Return the shared ``EagleWorker``, loading it on first use.

    The checkpoint is taken from the ``MODEL_PATH`` environment variable and
    falls back to ``nvidia/LocateAnything-3B``. If loading raises, the error
    is printed and ``None`` is returned so callers can switch to mock output.
    Because nothing is cached on failure, the next call tries to load again.
    """
    global GLOBAL_WORKER
    if GLOBAL_WORKER is not None:
        return GLOBAL_WORKER

    try:
        model_path = os.environ.get("MODEL_PATH", "nvidia/LocateAnything-3B")
        print(f"Loading model inside @spaces.GPU context: {model_path}")
        GLOBAL_WORKER = EagleWorker(model_path)
    except Exception as e:
        print(f"Failed to load model: {e}. Will run in Mock Mode.")
        GLOBAL_WORKER = None
    return GLOBAL_WORKER
def _prepare_image_for_model(pil_img, short_size):
    """Return a copy of ``pil_img`` sized for inference.

    With a positive ``short_size`` the copy is resized so its short side is
    ``min(int(short_size), 1024)``. Otherwise the copy is only resized (to a
    short side of 1024) when its short side exceeds 1024. The input image is
    never modified.
    """
    process_img = pil_img.copy()
    if short_size is not None and short_size > 0:
        target_short_side = min(int(short_size), 1024)
    elif min(process_img.size) > 1024:
        target_short_side = 1024
    else:
        # Already small enough: hand back the untouched copy.
        return process_img
    resized_img, _ = resize_image_short_side(process_img, target_short_side)
    return resized_img
# ============================================================
# GPU time budget and inference guards (per mode)
# ============================================================
# All values below are in seconds; the video pipeline compares them against time.time() deltas.
# NOTE(review): GPU_HARD_LIMIT_IMAGE is not referenced anywhere in the visible code, and the
# image endpoint requests duration=120 from the GPU decorator -- confirm which value is intended.
GPU_HARD_LIMIT_IMAGE = 30
# Total wall-clock budget assumed for one video request.
GPU_HARD_LIMIT_VIDEO = 240
# Time held back from the inference budget (presumably for drawing/encoding after inference -- confirm).
PHASE2_RESERVE = 55
# Additional slack subtracted from the inference budget.
SAFETY_MARGIN = 25
# Estimated inference cost per sampled frame, used to cap the number of frames up front.
EST_SECONDS_PER_FRAME = 20
@spaces.GPU(duration=120, size="xlarge")
def run_image_gpu_api(
    image_path: str, category: str, model_mode: str, temp: float, top_p: float, top_k: int,
    short_size: int | None, question_override: str | None
):
    """Run detection on one image file and write an annotated PNG.

    The model sees a (possibly resized) copy of the image, while the
    annotations are drawn on the image at its original size. When no worker
    is available, fixed mock output is used instead of a model call.

    Returns:
        ``(out_img_path, stats, output_text, detections_summary)``: the path
        of the PNG inside a fresh temporary directory, the parsed statistics
        dict, the raw model text, and a list of ``label`` / ``type`` /
        ``coords`` dicts with coordinates rounded to two decimals.
    """
    image_in = Image.open(image_path).convert("RGB")
    categories_list = [part.strip() for part in category.split(",") if part.strip()]
    # NOTE(review): joined with an empty separator -- confirm against the parser's expectations.
    category_str = "".join(categories_list)
    process_img = _prepare_image_for_model(image_in, short_size)

    worker = get_worker()
    if not worker:
        # Mock mode fallback
        output_text = "Mock detection: ][sweet]<240><480><620><940> and [book]<50><120><400><380>"
        token_sequence = []
        out_info = "forward_step=1;num_tokens=18;num_boxes=2;tps=45;bps=15"
    else:
        output_text, token_sequence, out_info = worker.generate(
            process_img, categories_list, model_mode,
            temp=temp, top_p=top_p, top_k=top_k,
            question_override=question_override,
        )

    detections = parse_mixed_results(output_text, category_str)

    # Drawing works on BGR arrays, so convert there and back.
    source_bgr = cv2.cvtColor(np.array(image_in), cv2.COLOR_RGB2BGR)
    annotated_bgr = draw_on_frame(source_bgr, detections, draw_label=True)
    output_image = Image.fromarray(cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB))

    # Persist the result in its own temporary directory.
    out_img_path = os.path.join(tempfile.mkdtemp(), "output.png")
    output_image.save(out_img_path)

    stats = _parse_out_info_dict(out_info)

    # Compact per-detection summary for the client.
    detections_summary = [
        {
            "label": det.get("label", "object"),
            "type": det.get("type", "box"),
            "coords": [round(value, 2) for value in det.get("coords", [])],
        }
        for det in detections
    ]
    return out_img_path, stats, output_text, detections_summary
def _evenly_spaced_indices(total, count):
    """Return at most ``count`` indices spread evenly over ``range(total)``.

    ``count`` is clamped to at least 1. A single requested index yields the
    first frame, which avoids the division by ``count - 1`` that would
    otherwise be a division by zero.
    """
    count = max(1, int(count))
    if total <= count:
        return list(range(total))
    if count == 1:
        return [0]
    return [int(round(i * (total - 1) / (count - 1))) for i in range(count)]


@spaces.GPU(duration=240, size="xlarge")
def run_video_gpu_api(
    video_path: str, category: str, model_mode: str, temp: float, top_p: float, top_k: int,
    short_size: int | None, question_override: str | None, max_video_frames: int
):
    """Run detection on sampled frames of a video and write an annotated clip.

    Steps: decode every frame, pick evenly spaced samples (reduced further if
    the estimated inference time exceeds the budget), run the model per
    sample while watching the remaining budget, draw the results, and finally
    re-encode the clip with ffmpeg when enough time is left. If re-encoding
    is skipped or fails, the raw OpenCV output is returned instead.

    Args:
        video_path: Path of the input video.
        category: Comma-separated category names.
        model_mode: Generation mode forwarded to the worker.
        temp, top_p, top_k: Sampling parameters forwarded to the worker.
        short_size: Optional short-side size for the model input.
        question_override: Optional prompt replacing the built-in one.
        max_video_frames: Maximum number of frames to sample; a falsy value
            means 4 and values below 1 are treated as 1.

    Returns:
        ``(out_video_path, stats, joined_raw_text, detections_summary)``.

    Raises:
        ValueError: If no frame could be read from the video.
        RuntimeError: If the budget ran out before any frame was processed.
    """
    import subprocess as _sp

    total_start = time.time()
    max_frames = int(max_video_frames) if max_video_frames else 4
    categories_list = [c.strip() for c in category.split(",") if c.strip()]
    # NOTE(review): joined with an empty separator -- confirm against the parser's expectations.
    category_str = "".join(categories_list)

    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        vid_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        vid_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        all_frames = []
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            all_frames.append(frame)
    finally:
        # Release the capture even when decoding raises.
        cap.release()

    total = len(all_frames)
    if total == 0:
        raise ValueError("Failed to read any frames from the video.")
    if vid_w <= 0 or vid_h <= 0:
        # The container did not report a usable size; use the decoded frame's.
        vid_h, vid_w = all_frames[0].shape[:2]

    # Sample frames
    sample_indices = _evenly_spaced_indices(total, max_frames)

    # Budget check: shrink the sample if the estimate does not fit.
    time_already_used = time.time() - total_start
    available_for_inference = GPU_HARD_LIMIT_VIDEO - time_already_used - PHASE2_RESERVE - SAFETY_MARGIN
    estimated_inference_time = len(sample_indices) * EST_SECONDS_PER_FRAME
    if estimated_inference_time > available_for_inference:
        max_feasible = max(1, int(available_for_inference // EST_SECONDS_PER_FRAME))
        sample_indices = _evenly_spaced_indices(total, max_feasible)

    sampled_frames = [all_frames[i] for i in sample_indices]
    n_sampled = len(sampled_frames)
    # Keep the clip duration roughly equal to the source duration.
    out_fps = max(1.0, n_sampled / (total / fps)) if fps > 0 else 5.0

    # Only the sampled frames are needed from here on.
    del all_frames
    gc.collect()

    inference_results = []
    processed_count = 0
    early_stopped = False
    early_stop_reason = ""
    for frame in sampled_frames:
        elapsed_since_start = time.time() - total_start
        remaining_total = GPU_HARD_LIMIT_VIDEO - elapsed_since_start
        if remaining_total < PHASE2_RESERVE + SAFETY_MARGIN:
            early_stopped = True
            early_stop_reason = f"GPU time budget running out. Only {remaining_total:.0f}s left."
            break

        pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        process_img = _prepare_image_for_model(pil_img, short_size)
        # Looked up after the budget check so an exhausted budget never
        # triggers a model load.
        worker = get_worker()
        if worker:
            output_text, _, _ = worker.generate(
                process_img, categories_list, model_mode,
                temp=temp, top_p=top_p, top_k=top_k,
                question_override=question_override,
            )
        else:
            output_text = "Mock video detection: [person]<100><150><800><900>"
        inference_results.append(output_text)
        processed_count += 1
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    if processed_count == 0:
        raise RuntimeError("GPU budget exceeded before processing any frames.")

    sampled_frames_for_draw = sampled_frames[:processed_count]
    inference_results_for_draw = inference_results[:processed_count]

    # mkstemp creates the files atomically; tempfile.mktemp is deprecated
    # because the name it returns can be claimed by another process.
    raw_fd, tmp_raw = tempfile.mkstemp(suffix=".raw.mp4")
    os.close(raw_fd)
    out_fd, out_video_path = tempfile.mkstemp(suffix=".mp4")
    os.close(out_fd)

    out = cv2.VideoWriter(tmp_raw, cv2.VideoWriter_fourcc(*"mp4v"), out_fps, (vid_w, vid_h))
    detections_summary = []
    try:
        for i, (frame, output_text) in enumerate(zip(sampled_frames_for_draw, inference_results_for_draw)):
            detections = parse_mixed_results(output_text, category_str)
            valid_results = _postprocess_detections(detections, vid_w, vid_h)
            frame_to_draw = draw_on_frame(frame, valid_results, draw_label=True)
            out.write(frame_to_draw)
            for det in valid_results:
                detections_summary.append({
                    "frame": i + 1,
                    "label": det.get("label", "object"),
                    "type": det.get("type", "box"),
                    "coords": det.get("coords", [])
                })
    finally:
        out.release()

    # ffmpeg re-encode, attempted only when enough budget is left.
    elapsed_now = time.time() - total_start
    remaining_now = GPU_HARD_LIMIT_VIDEO - elapsed_now
    reencoded = False
    if remaining_now > 15:
        ffmpeg_timeout = max(10, int(remaining_now - 5))
        try:
            _sp.run(
                ["ffmpeg", "-y", "-i", tmp_raw, "-c:v", "libx264",
                 "-preset", "ultrafast", "-crf", "23", "-pix_fmt", "yuv420p",
                 "-movflags", "+faststart", out_video_path],
                check=True, capture_output=True, timeout=ffmpeg_timeout,
            )
            reencoded = True
        except (OSError, _sp.SubprocessError):
            # ffmpeg is missing, failed, or timed out: fall back to the raw file.
            reencoded = False

    if reencoded:
        try:
            os.remove(tmp_raw)
        except OSError:
            # A leftover temp file must never replace the encoded output.
            pass
    elif os.path.exists(tmp_raw):
        os.replace(tmp_raw, out_video_path)

    total_time = time.time() - total_start
    stats = {
        "total_frames": total,
        "sampled_frames": n_sampled,
        "processed_frames": processed_count,
        "total_time_seconds": round(total_time, 2),
        "early_stopped": early_stopped,
        "early_stop_reason": early_stop_reason
    }
    return out_video_path, stats, "\n---\n".join(inference_results_for_draw), detections_summary
# ============================================================
# GRADIO SERVER APP
# ============================================================
app = Server()

# Expose the "assets" directory next to this file under /assets, but only when it exists.
_module_dir = os.path.dirname(os.path.abspath(__file__))
assets_dir = os.path.join(_module_dir, "assets")
if os.path.exists(assets_dir):
    app.mount("/assets", StaticFiles(directory=assets_dir), name="assets")
@app.get("/")
async def homepage():
    """Serve ``index.html`` from this file's directory, or a short notice if it is absent."""
    html_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    if os.path.exists(html_path):
        with open(html_path, "r", encoding="utf-8") as f:
            return HTMLResponse(f.read())
    # NOTE(review): this literal was split across two lines in the source,
    # which is a syntax error. The <h1> wrapper is an assumed reconstruction;
    # confirm the intended markup.
    return HTMLResponse("<h1>index.html is missing</h1>")
def _resolve_media_path(file_ref):
    """Return the filesystem path for an uploaded file or a local example.

    Strings are joined onto this file's directory, dicts are read through
    their ``"path"`` key, and any other object through its ``path`` attribute
    (``None`` when absent).

    NOTE(review): a string may be absolute or contain ``..``, so it is not
    confined to the application directory. Consider restricting this if the
    endpoint receives untrusted strings.
    """
    if isinstance(file_ref, str):
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), file_ref)
    if isinstance(file_ref, dict):
        return file_ref.get("path")
    return getattr(file_ref, "path", None)


@app.api(name="run_inference")
def run_inference_api(
    input_type: str,
    image_file: Any = None,
    video_file: Any = None,
    task_type: str = "Detection",
    category: str = "objects",
    model_mode: str = "hybrid",
    temp: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 20,
    short_size: int | None = None,
    question_override: str | None = None,
    max_video_frames: int = 4
) -> tuple[FileData | None, FileData | None, dict]:
    """Exposed Gradio Queueing Endpoint for custom frontend interactions.
    ZeroGPU allocation is triggered directly at this endpoint boundary.
    Supports both FileData dict (from web uploads) and local strings (for examples).

    ``input_type == "Image"`` selects the image pipeline; any other value
    selects the video pipeline. A non-empty ``question_override`` is used as
    the prompt, otherwise one is generated from ``task_type`` and ``category``.

    Returns:
        ``(image_file_data, video_file_data, meta)``. Exactly one of the first
        two is set on success. On failure both are ``None`` and ``meta`` holds
        ``{"success": False, "error": ...}``; exceptions are logged and
        reported this way instead of being raised.
    """
    try:
        if not category:
            category = "objects"
        final_prompt = question_override
        if not final_prompt:
            final_prompt = generate_raw_prompt(task_type, category)

        is_image = input_type == "Image"
        media_file = image_file if is_image else video_file
        media_kind = "image" if is_image else "video"

        if not media_file:
            article = "an" if is_image else "a"
            return None, None, {"success": False, "error": f"Please upload {article} {media_kind}."}

        # Resolve the path (from either FileData upload or local example string)
        media_path = _resolve_media_path(media_file)
        if not media_path or not os.path.exists(media_path):
            return None, None, {"success": False, "error": f"Invalid {media_kind} file path: {media_path}"}

        if is_image:
            out_path, stats, raw_text, detections = run_image_gpu_api(
                media_path, category, model_mode, temp, top_p, top_k, short_size, final_prompt
            )
        else:
            out_path, stats, raw_text, detections = run_video_gpu_api(
                media_path, category, model_mode, temp, top_p, top_k, short_size, final_prompt, max_video_frames
            )

        meta = {
            "success": True,
            "input_type": "Image" if is_image else "Video",
            "stats": stats,
            "raw_text": raw_text,
            "detections": detections,
            "final_prompt": final_prompt
        }
        if is_image:
            return FileData(path=out_path), None, meta
        return None, FileData(path=out_path), meta
    except Exception as e:
        # Top-level boundary: log the traceback and report the error to the client.
        import traceback
        traceback.print_exc()
        return None, None, {"success": False, "error": str(e)}
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    app.launch(show_error=True)