#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Source: eweffeffew/app.py — "Update app.py" by hrmndev (commit b97e553, verified).
import gradio as gr
import cv2
import numpy as np
import os
import tempfile
import re
import time
import base64
import gc
import io
import json
import uuid
from pathlib import Path
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, AutoModel, AutoTokenizer
from huggingface_hub import CommitScheduler
import spaces
# Bundled bold label font, looked up in an ``assets`` folder next to this script.
# _load_font falls back to other fonts when this file is missing or unreadable.
_FONT_PATH = os.path.join(
    os.path.split(os.path.abspath(__file__))[0],  # directory containing this script
    "assets",
    "LXGWWenKai-Bold.ttf",
)
def _get_first_env(*names):
    """Return the first usable value among the environment variables *names*.

    Variables are checked in the order given. Unset, empty and whitespace-only
    values are skipped. The returned value is stripped of surrounding whitespace.

    Returns:
        The stripped value, or None when no listed variable holds a usable value.
    """
    candidates = (os.environ.get(var) for var in names)
    return next((raw.strip() for raw in candidates if raw and raw.strip()), None)
def _configure_hf_auth():
    """Resolve Hugging Face tokens for model download and for dataset logging.

    Both tokens search the same five environment variables. The model token
    prefers ``MODEL_HF_TOKEN`` and the log token prefers ``LOG_HF_TOKEN``.
    When a token is found, it is also exported under the three standard hub
    variable names (``HF_TOKEN``, ``HUGGINGFACE_HUB_TOKEN``,
    ``HUGGINGFACEHUB_API_TOKEN``), presumably so libraries that read those
    variables pick it up.

    Returns:
        ``(model_token, log_token)``; either may be None.
    """
    hub_aliases = ("HF_TOKEN", "HUGGINGFACE_HUB_TOKEN", "HUGGINGFACEHUB_API_TOKEN")
    model_token = _get_first_env("MODEL_HF_TOKEN", "LOG_HF_TOKEN", *hub_aliases)
    log_token = _get_first_env("LOG_HF_TOKEN", "MODEL_HF_TOKEN", *hub_aliases)
    shared_token = model_token or log_token
    if shared_token:
        os.environ.update(dict.fromkeys(hub_aliases, shared_token))
    return model_token, log_token


# Resolved once at import time; read by EagleWorker (model) and the log scheduler.
MODEL_HF_TOKEN, LOG_HF_TOKEN = _configure_hf_auth()
def _load_font(size=20):
    """Load a TrueType font of the given point *size* for drawing labels.

    Sources are tried in order:
    1. the bundled ``_FONT_PATH`` (only if the file exists);
    2. ``DejaVuSans-Bold.ttf`` from the system font search path;
    3. Pillow's built-in default font.

    Any error while loading 1 or 2 just moves on to the next option.
    """
    candidates = [_FONT_PATH] if os.path.exists(_FONT_PATH) else []
    candidates.append("DejaVuSans-Bold.ttf")
    for font_file in candidates:
        try:
            return ImageFont.truetype(font_file, size)
        except Exception:
            continue
    return ImageFont.load_default()
# ============================================================
# Color / Parsing / Rendering Operations
# ============================================================
def get_color_for_label(label):
    """Map *label* to a stable RGB color from a fixed six-color palette.

    The palette index is the sum of the label's code points modulo the palette
    size, so identical labels always get the same color.
    """
    palette = (
        (8, 145, 178), (220, 38, 38), (22, 163, 74),
        (37, 99, 235), (217, 119, 6), (147, 51, 234),
    )
    return palette[sum(map(ord, label)) % len(palette)]
def parse_mixed_results(text, category_str=""):
    """Extract box and point detections from the model's text output.

    Two layouts are recognised:

    1. Structured: ``<ref>label</ref>`` followed by one or more
       ``<box><x1><y1><x2><y2></box>`` (box) or ``<box><x><y></box>`` (point).
       A ``<ref>`` label applies to every following box until the next ``<ref>``.
       Tags are matched case-insensitively and may span lines.
    2. Fallback, used only when no structured box containing numbers was found.
       Each ``<box>...</box>`` is labelled with the first expected category that
       occurs (case-insensitively) in the text preceding it.

    Coordinates are returned as emitted, converted to float. The rest of this file
    treats them as normalized to 0..1000.

    Args:
        text: raw model output; None is treated as "".
        category_str: expected categories, separated by ``</c>`` (prompt format)
            and/or commas (the UI textbox format). The first category,
            lower-cased, is the default label for unlabelled boxes ("object" when
            empty). All categories are used for fallback matching.

    Returns:
        List of ``{"type": "box"|"point", "coords": [float, ...], "label": str}``.
        Boxes whose number count is neither 4 nor 2 are dropped.
    """
    text = text or ""
    # Callers pass the raw comma-separated textbox value ("car, pedestrian").
    # Splitting only on "</c>" turned that whole string into one bogus label,
    # so accept both separators.
    expected_cats = [
        c.strip().lower()
        for c in re.split(r"</c>|,", category_str or "")
        if c.strip()
    ]
    default_label = expected_cats[0] if expected_cats else "object"
    flags = re.IGNORECASE | re.DOTALL
    number_re = re.compile(r"<\s*([0-9]+(?:\.[0-9]+)?)\s*>")

    def _make_detection(coords, label):
        # 4 numbers -> box, 2 numbers -> point, anything else is not drawable.
        if len(coords) == 4:
            return {"type": "box", "coords": coords, "label": label}
        if len(coords) == 2:
            return {"type": "point", "coords": coords, "label": label}
        return None

    results = []
    current_label = None
    found_structured = False
    for m in re.finditer(r"(<ref>.*?</ref>)|(<box>.*?</box>)", text, flags=flags):
        token = m.group(0)
        if token.lower().startswith("<ref>"):
            label_raw = re.sub(r"</?ref>", "", token, flags=re.IGNORECASE).strip()
            if label_raw:
                current_label = label_raw
            continue
        content = re.sub(r"</?box>", "", token, flags=re.IGNORECASE)
        coords = [float(n) for n in number_re.findall(content)]
        if not coords:
            continue
        label = current_label if current_label is not None else default_label
        det = _make_detection(coords, label)
        if det is not None:
            results.append(det)
        found_structured = True
    if found_structured:
        return results

    # Fallback. re.split with one capture group yields
    # [text, box_content, text, box_content, ...].
    # Use the same flags as the structured pass for consistency.
    parts = re.split(r"<box>(.*?)</box>", text, flags=flags)
    for i in range(1, len(parts), 2):
        preceding_text = parts[i - 1].lower()
        label = next((cat for cat in expected_cats if cat in preceding_text), default_label)
        det = _make_detection([float(n) for n in number_re.findall(parts[i])], label)
        if det is not None:
            results.append(det)
    return results
def resize_image_short_side(image, short_side_size):
    """Scale *image* so its shorter side equals *short_side_size*, keeping aspect ratio.

    The longer side is scaled by the same factor and truncated to an int.
    Square images are treated as portrait (width is the short side). The image
    is scaled up or down as needed, using bilinear resampling.

    Returns:
        ``(resized_image, scale_factor)``, where ``scale_factor`` is
        ``short_side_size / original_short_side``.
    """
    width, height = image.size
    width_is_short = width <= height
    scale = short_side_size / (width if width_is_short else height)
    if width_is_short:
        target_size = (short_side_size, int(height * scale))
    else:
        target_size = (int(width * scale), short_side_size)
    return image.resize(target_size, Image.BILINEAR), scale
def draw_on_frame(frame_bgr, results, draw_label=True):
    """Render detections onto a BGR frame and return the annotated BGR frame.

    Args:
        frame_bgr: image array in OpenCV BGR channel order.
        results: detection dicts. ``coords`` is interpreted per item:
            * ``{"type": "point", "coords": [x, y]}``: normalized to 0..1000;
            * ``{"type": "box", "coords": [x1, y1, x2, y2]}``: normalized to 0..1000;
            * ``{"coords": [x, y, w, h], "is_pixel": ...}``: pixel space.
            Any item whose type is not "point" is drawn as a box. ``"label"``
            defaults to "object".
        draw_label: when True, draw a filled tag with the label text per item.

    Returns:
        A BGR array of the same size as the input. Boxes are semi-transparent
        (fill alpha 65), point markers are opaque, and label tags are drawn on top.

    All shapes are drawn before any label, so tags are never hidden by another
    detection's fill. Label tags are kept on-canvas:
    * a box tag goes above the box, else below it, else inside its top edge,
      and is shifted left if it would overrun the right border;
    * a point tag goes right of the marker, or is mirrored to the left when it
      would leave the image.
    """
    pil_img = Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
    img_draw = pil_img.convert("RGBA")
    # Transparent layer receiving all shapes and labels; composited once at the end.
    overlay = Image.new("RGBA", img_draw.size, (255, 255, 255, 0))
    draw = ImageDraw.Draw(overlay)
    font = _load_font(20)
    w_img, h_img = pil_img.size

    # Pass 1: convert detections into pixel-space drawing primitives.
    parsed = []
    for res in results:
        label = res.get("label", "object")
        color = get_color_for_label(label)
        if res.get("type") == "point":
            c = res["coords"]
            cx = max(0, min(w_img, c[0] * w_img / 1000))
            cy = max(0, min(h_img, c[1] * h_img / 1000))
            parsed.append(("point", label, color, cx, cy))
            continue
        if "is_pixel" in res:
            # Pixel-space [x, y, width, height] (as built by _postprocess_detections).
            x1, y1, bw, bh = res["coords"]
            x2, y2 = x1 + bw, y1 + bh
        else:
            c = res["coords"]
            if len(c) < 4:
                continue
            # Normalized 0..1000 corners -> pixels.
            x1 = c[0] * w_img / 1000
            y1 = c[1] * h_img / 1000
            x2 = c[2] * w_img / 1000
            y2 = c[3] * h_img / 1000
        x1, y1, x2, y2 = max(0, x1), max(0, y1), min(w_img, x2), min(h_img, y2)
        # Order corners so PIL never receives an inverted rectangle.
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)
        parsed.append(("box", label, color, x1, y1, x2, y2))

    # Pass 2: shapes.
    for item in parsed:
        if item[0] == "box":
            _, _, color, x1, y1, x2, y2 = item
            fill_color = color + (65,)
            draw.rectangle([x1, y1, x2, y2], fill=fill_color, outline=color, width=4)
        elif item[0] == "point":
            _, _, color, cx, cy = item
            r = 10
            draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=color, outline="white", width=2)

    # Pass 3: label tags, drawn last so they sit above every shape.
    if draw_label:
        for item in parsed:
            if item[0] == "box":
                _, label, color, x1, y1, x2, y2 = item
                if not label:
                    continue
                t_box = draw.textbbox((0, 0), label, font=font)
                th = t_box[3] - t_box[1]
                tw = t_box[2] - t_box[0]
                pad_x, pad_y = 7, 4
                tag_h = th + pad_y * 2
                tag_w = tw + pad_x * 2
                tag_y = y1 - tag_h - 2  # preferred: just above the box
                if tag_y < 0:
                    tag_y = y2 + 2  # no room above: just below the box
                    if tag_y + tag_h > h_img:
                        # No room below either (box spans almost the full height):
                        # previously the tag was drawn off-canvas and lost, so
                        # place it inside the box along its top edge.
                        tag_y = y1
                # Keep the tag from running past the right border (never past x=0).
                tag_x = max(0, min(x1, w_img - tag_w))
                draw.rectangle([tag_x, tag_y, tag_x + tag_w, tag_y + tag_h], fill=color)
                draw.text((tag_x + pad_x, tag_y + pad_y), label, fill="white", font=font)
            elif item[0] == "point":
                _, label, color, cx, cy = item
                if not label:
                    continue
                t_box = draw.textbbox((0, 0), label, font=font)
                th, tw = t_box[3] - t_box[1], t_box[2] - t_box[0]
                tx, ty = cx + 14, cy - th // 2
                if tx + tw + 6 > w_img:
                    # Tag would leave the image on the right: mirror it to the marker's left.
                    tx = max(2, cx - 14 - tw - 6)
                draw.rectangle([tx - 2, ty - 2, tx + tw + 6, ty + th + 4], fill=color)
                draw.text((tx + 2, ty), label, fill="white", font=font)

    combined = Image.alpha_composite(img_draw, overlay).convert("RGB")
    return cv2.cvtColor(np.array(combined), cv2.COLOR_RGB2BGR)
# ============================================================
# Model Runner Component
# ============================================================
class EagleWorker:
    """Wrapper around a ``trust_remote_code`` vision-language model used for grounding.

    Loads the tokenizer, processor and model from ``model_path`` once.
    :meth:`generate` turns an image plus a category list (or a custom question)
    into the model's text output and decoding trace.
    """

    def __init__(self, model_path, device="cuda", generation_mode: str = "hybrid"):
        """Load tokenizer, processor and bfloat16 weights, and move the model to *device*.

        Args:
            model_path: Hub repo id or local path passed to every ``from_pretrained``.
            device: torch device for the model and its inputs.
            generation_mode: default decoding mode forwarded to ``model.generate``
                (the UI offers "hybrid", "fast" and "slow").
        """
        self.model_id = model_path
        self.device = device
        self.dtype = torch.bfloat16
        self.generation_mode = generation_mode
        self.hf_token = MODEL_HF_TOKEN
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            token=self.hf_token,
        )
        self.processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
            token=self.hf_token,
        )
        self.model = AutoModel.from_pretrained(
            model_path,
            torch_dtype=self.dtype,
            _attn_implementation="sdpa",
            trust_remote_code=True,
            token=self.hf_token,
        ).to(device).eval()
        print("Model Engine Loaded Safely.")

    def build_messages(self, image, categories, question_override=None):
        """Build the single-turn chat message (image + text) sent to the processor.

        Without *question_override*, the text asks the model to locate all
        *categories*, joined with the ``</c>`` separator.
        """
        if question_override is not None:
            user_text = question_override
        else:
            category_set_str = "</c>".join(categories)
            user_text = f"Locate all the instances that matches the following description: {category_set_str}."
        return [{"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": user_text},
        ]}]

    def generate(self, image, categories, generation_mode=None,
                 max_new_tokens=4096, temp=0.7, top_p=0.9, top_k=50,
                 question_override=None):
        """Run sampled generation and return ``(output_text, token_sequence, out_info)``.

        Args:
            image: image placed in the chat message (see :meth:`build_messages`).
            categories: category names used when *question_override* is None.
            generation_mode: decoding mode; None means ``self.generation_mode``.
            max_new_tokens, temp, top_p: forwarded to ``model.generate``.
            top_k: accepted for interface compatibility but NOT forwarded.
                NOTE(review): confirm the remote model's ``generate`` supports
                ``top_k`` before wiring it through.
            question_override: custom prompt text replacing the category prompt.

        Returns:
            When ``model.generate`` returns a tuple of 3 or more items, the first
            three are ``(output_text, token_sequence, out_info)``. In "slow" mode
            the last token's decode type is relabelled ``"ar"`` (on a copy of the
            sequence). Any other return value is treated as plain output text,
            giving ``(result, [], "")``.
        """
        mode = generation_mode if generation_mode is not None else self.generation_mode
        messages = self.build_messages(image, categories, question_override=question_override)
        text = self.processor.py_apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        images, videos = self.processor.process_vision_info(messages)
        inputs = self.processor(text=[text], images=images, videos=videos, return_tensors="pt").to(self.device)
        pixel_values = inputs["pixel_values"].to(self.dtype)
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        image_grid_hws = inputs.get("image_grid_hws", None)
        # NOTE(review): do_sample=True with temp=0.0 (allowed by the UI slider) may be
        # rejected or produce invalid scores depending on the remote generate; confirm.
        with torch.inference_mode():
            result = self.model.generate(
                pixel_values=pixel_values, input_ids=input_ids,
                attention_mask=attention_mask, image_grid_hws=image_grid_hws,
                tokenizer=self.tokenizer, max_new_tokens=max_new_tokens,
                use_cache=True,
                generation_mode=mode,
                temperature=temp, do_sample=True, top_p=top_p,
                repetition_penalty=1.1, verbose=True,
            )
        if isinstance(result, tuple) and len(result) >= 3:
            # Slice so extra trailing items cannot break the 3-way unpack.
            output_text, token_sequence, out_info = result[:3]
            # Check the effective mode (not just the argument), and skip empty traces.
            if mode == "slow" and token_sequence:
                token_sequence = list(token_sequence)
                token_sequence[-1] = ("ar", token_sequence[-1][1])
        else:
            output_text, token_sequence, out_info = result, [], ""
        return output_text, token_sequence, out_info
# ============================================================
# Post-Processing UI Helpers
# ============================================================
def _postprocess_detections(detections, w, h):
    """Convert normalized (0..1000) box detections to pixel ``[x, y, width, height]``.

    Args:
        detections: dicts with ``"type"`` in {"box", "point"}; boxes carry four
            normalized corner coords and a ``"label"``.
        w, h: frame width and height in pixels.

    Returns:
        New list of detections:
        * each box is converted to pixel space, each corner truncated and
          clamped to ``[0, w-1]`` / ``[0, h-1]``, and tagged ``"is_pixel": True``;
          boxes with non-positive width or height are dropped;
        * points are passed through unchanged (same dict objects);
        * any other type is dropped.
    """
    def to_pixel(value, extent):
        return max(0, min(extent - 1, int(value * extent / 1000)))

    kept = []
    for det in detections:
        kind = det["type"]
        if kind == "point":
            kept.append(det)
            continue
        if kind != "box":
            continue
        c = det["coords"]
        left, top = to_pixel(c[0], w), to_pixel(c[1], h)
        right, bottom = to_pixel(c[2], w), to_pixel(c[3], h)
        width, height = right - left, bottom - top
        if width > 0 and height > 0:
            kept.append({"type": "box", "coords": [left, top, width, height],
                         "is_pixel": True, "label": det["label"]})
    return kept
def _parse_out_info_dict(out_info: str) -> dict:
    """Parse a ``key=value; key=value`` statistics string into a dict of strings.

    A leading ``Stastic info`` header (the regex also accepts ``Stasic``, either
    case for the initial S/I, and an optional trailing comma) is stripped first.
    Segments without ``=`` are ignored. Only the first ``=`` splits key from
    value. Keys and values are whitespace-stripped, and later duplicates win.
    Empty or None input gives ``{}``.
    """
    if not out_info:
        return {}
    body = re.sub(r"^[Ss]tast?ic\s*[Ii]nfo\s*,?\s*", "", out_info.strip())
    pairs = (segment.strip().split("=", 1) for segment in body.split(";") if "=" in segment)
    return {key.strip(): value.strip() for key, value in pairs}
def generate_dynamic_html(token_sequence, out_info, raw_text):
    """Build the HTML "decoding trace" panel shown next to the result.

    Args:
        token_sequence: iterable of ``(decode_type, text)`` pairs. A pair whose
            type is ``"ar"`` (case-insensitive) is styled as autoregressive; any
            other type is styled as MTP. Items that are not lists/tuples with at
            least two elements are skipped.
        out_info: statistics string, parsed with ``_parse_out_info_dict``.
            Known keys are summarised; otherwise the raw string is shown. The
            stats row is omitted when this is empty.
        raw_text: raw model response shown in a collapsible block; omitted when empty.

    Returns:
        A self-contained HTML string (inline ``<style>`` plus markup). All
        model-derived text is HTML-escaped, including ``&``, which the previous
        ``<``/``>``-only replacement missed. The stats summary was previously
        not escaped at all.
    """
    import html  # local import, matching this file's style for single-use modules

    # Per-call suffix for CSS class/keyframe names (presumably so each render
    # gets its own fade-in animation rather than reusing a previous one's).
    uid = f"a{int(time.time() * 1000)}"
    css = f"""
    <style>
    .dc-root {{
    font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
    border: 1px solid #cce875; border-radius: 10px; background: #ffffff; overflow: hidden;
    }}
    .dc-header {{
    display: flex; align-items: center; justify-content: space-between; padding: 12px 18px;
    background: linear-gradient(135deg, #76b900 0%, #649d00 100%); border-bottom: 1px solid #527f00;
    }}
    .dc-header-title {{ font-weight: 700; font-size: 0.95em; color: #ffffff !important; }}
    .dc-legend {{ display: flex; gap: 16px; align-items: center; }}
    .dc-legend-item {{ display: flex; align-items: center; gap: 5px; font-size: 0.78em; color: rgba(255,255,255,0.92); }}
    .dc-legend-dot {{ width: 10px; height: 10px; border-radius: 3px; display: inline-block; }}
    .dc-row {{ display: flex; gap: 10px; padding: 14px 18px; border-bottom: 1px solid #eef7d1; }}
    .dc-row:last-child {{ border-bottom: none; }}
    .dc-val {{ flex: 1; line-height: 2.3; word-wrap: break-word; color: #4b5563; font-size: 0.92em; }}
    @keyframes tk-{uid} {{
    0% {{ opacity: 0; transform: translateY(8px); }}
    100% {{ opacity: 1; transform: translateY(0); }}
    }}
    .tk-mtp-{uid}, .tk-ar-{uid} {{
    opacity: 0; animation: tk-{uid} 0.35s ease-out forwards; border-radius: 5px; padding: 2px 7px; margin: 2px 1px; display: inline-block;
    font-size: 0.80em; font-weight: 600; font-family: monospace; white-space: nowrap;
    }}
    .tk-mtp-{uid} {{ background: #e8f5e9; border: 2px solid #76b900; color: #000000; }}
    .tk-ar-{uid} {{ background: #fff3e0; border: 2px solid #e65100; color: #000000; }}
    .tk-stat-{uid} {{
    opacity: 0; animation: tk-{uid} 0.4s ease-out forwards; background: #f0f9e2; border: 1px solid #a4d422; border-radius: 6px;
    padding: 5px 14px; display: inline-block; font-size: 0.82em; color: #3f6200; font-weight: 600;
    }}
    .dc-raw {{ padding: 0 18px 14px; }}
    .dc-raw summary {{ cursor: pointer; color: #9ca3af; font-size: 0.82em; }}
    .dc-raw-pre {{
    background: #f7fbe8; border: 1px solid #ddf0a3; border-radius: 6px; padding: 12px; margin-top: 8px;
    font-family: monospace; font-size: 0.78em; color: #374151; white-space: pre-wrap; max-height: 200px; overflow-y: auto;
    }}
    </style>
    """
    parts = [
        css,
        '<div class="dc-root">',
        '<div class="dc-header">'
        '<span class="dc-header-title">LocateAnything Decoding Trace</span>'
        '<div class="dc-legend">'
        '<div class="dc-legend-item"><span class="dc-legend-dot" style="background:#76b900;"></span>MTP</div>'
        '<div class="dc-legend-item"><span class="dc-legend-dot" style="background:#e65100;"></span>AR</div>'
        '</div></div>',
        '<div class="dc-row"><div class="dc-val">',
    ]

    # Token chips, each fading in 40 ms after the previous one.
    tok_idx = 0
    for item in token_sequence or ():
        if not isinstance(item, (list, tuple)) or len(item) < 2:
            continue
        decode_type = str(item[0]).lower()
        safe = html.escape(str(item[1]), quote=False)
        delay = f"{tok_idx * 0.04:.2f}s"
        cls = f"tk-ar-{uid}" if decode_type == "ar" else f"tk-mtp-{uid}"
        parts.append(f'<span class="{cls}" style="animation-delay:{delay}">{safe}</span> ')
        tok_idx += 1
    parts.append('</div></div>')

    if out_info:
        stats = _parse_out_info_dict(out_info)
        labelled_keys = [("forward_step", "steps"), ("num_tokens", "tokens"), ("num_boxes", "boxes"),
                         ("ar_step", "AR steps"), ("tps", "tok/s")]
        bits = [f"{html.escape(stats[key], quote=False)} {name}" for key, name in labelled_keys if key in stats]
        # Escape each piece individually so the "&middot;" separator entity survives.
        summary = " &middot; ".join(bits) if bits else html.escape(out_info.strip(), quote=False)
        stat_delay = f"{tok_idx * 0.04 + 0.2:.2f}s"
        parts.append(
            f'<div class="dc-row" style="justify-content:flex-end;padding-top:4px;padding-bottom:10px;border-bottom:none;">'
            f'<span class="tk-stat-{uid}" style="animation-delay:{stat_delay}">⚡ {summary}</span></div>'
        )

    if raw_text:
        safe_raw = html.escape(raw_text, quote=False)
        parts.append(
            f'<div class="dc-raw"><details><summary>📄 Show Raw Response</summary>'
            f'<div class="dc-raw-pre">{safe_raw}</div></details></div>'
        )

    parts.append('</div>')
    return "".join(parts)
def generate_raw_prompt(task_type, category):
    """Build the model prompt for *task_type* from a comma-separated *category* string.

    Categories are stripped, emptied entries are dropped, and the rest are joined
    with ``</c>``. An empty or None *category* becomes "objects". Unknown task
    types fall back to the Detection prompt.

    NOTE: the wording differs per task ("matches" vs "match", "instances" vs
    "region"). These presumably mirror the model's prompt templates, so the
    strings are kept verbatim; do not normalize them without checking.
    """
    category = category or "objects"
    cats = "</c>".join(part.strip() for part in category.split(",") if part.strip())
    detection_prompt = f"Locate all the instances that matches the following description: {cats}."
    prompts = {
        "Detection": detection_prompt,
        "Grounding": f"Locate all the instances that match the following description: {cats}.",
        "OCR": "Detect all the text in box format.",
        "GUI": f"Locate the region that matches the following description: {cats}.",
        "Pointing": f"Point to: {cats}.",
    }
    return prompts.get(task_type, detection_prompt)
# ============================================================
# Model Loading & Optional Dataset Logging Setup
# ============================================================
# Model weights: override the default repo with the MODEL_PATH environment variable.
MODEL_PATH = os.environ.get("MODEL_PATH", "nvidia/LocateAnything-3B")
print(f"Loading Base Weight Layer Model Matrix via: {MODEL_PATH}")
# Single shared worker, loaded once at import time and used by every request.
GLOBAL_WORKER = EagleWorker(MODEL_PATH)

# Optional remote logging: records written under _LOG_DIR are committed to the
# dataset repo named by LOG_DATASET_REPO (only when a log token is also available).
LOG_DATASET_REPO = os.environ.get("LOG_DATASET_REPO")
_LOG_DIR = Path(tempfile.mkdtemp(prefix="hf_log_"))
_SESSION_ID = uuid.uuid4().hex[:8]


def _create_log_scheduler():
    """Return a CommitScheduler syncing _LOG_DIR to LOG_DATASET_REPO, or None.

    None when the repo name or log token is missing, or when construction raises.
    In the failure case the error is printed and remote logging stays disabled.
    """
    if not (LOG_DATASET_REPO and LOG_HF_TOKEN):
        return None
    try:
        scheduler = CommitScheduler(
            repo_id=LOG_DATASET_REPO,
            repo_type="dataset",
            folder_path=str(_LOG_DIR),
            path_in_repo="data",
            every=5,
            token=LOG_HF_TOKEN,
            squash_history=False,
        )
    except Exception as e:
        print(f"[LOG] Remote logging skipped or unauthorized setup boundary: {e}")
        return None
    print("[LOG] System Scheduler initialized successfully context workspace mapping tracking.")
    return scheduler


_log_scheduler = _create_log_scheduler()
def _pil_to_b64(pil_img):
    """Encode *pil_img* as PNG and return the bytes as a base64 ASCII string."""
    with io.BytesIO() as stream:
        pil_img.save(stream, "PNG")
        png_bytes = stream.getvalue()
    return base64.b64encode(png_bytes).decode("ascii")
def _log_to_dataset(input_type, category, model_mode, raw_prompt, output_text="", input_image=None, output_image=None):
    """Write one inference record as a single-line JSONL file for the log scheduler.

    Does nothing when remote logging is disabled (``_log_scheduler is None``).
    Each call creates ``<session>__<entry>.jsonl`` inside a per-UTC-day folder
    under ``_LOG_DIR``. The file is written while holding the scheduler's lock.
    Images, when given, are stored as base64 PNG. Any exception is printed and
    swallowed, so logging never breaks inference.
    """
    if _log_scheduler is None:
        return
    try:
        entry_id = f"{int(time.time())}_{uuid.uuid4().hex[:6]}"
        timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        input_b64 = _pil_to_b64(input_image) if input_image else None
        output_b64 = _pil_to_b64(output_image) if output_image else None
        record = dict(
            id=entry_id,
            session_id=_SESSION_ID,
            timestamp=timestamp,
            input_type=input_type,
            category=category,
            model_mode=model_mode,
            raw_prompt=raw_prompt,
            output_text=output_text,
            input_image_b64=input_b64,
            output_image_b64=output_b64,
        )
        log_path = _LOG_DIR / time.strftime("%Y-%m-%d", time.gmtime()) / f"{_SESSION_ID}__{entry_id}.jsonl"
        log_path.parent.mkdir(parents=True, exist_ok=True)
        line = json.dumps(record, ensure_ascii=False) + "\n"
        with _log_scheduler.lock, log_path.open("w", encoding="utf-8") as fh:
            fh.write(line)
    except Exception as e:
        print(f"[LOG] Write failure: {e}")
def _prepare_image_for_model(pil_img, short_size):
    """Return a copy of *pil_img* whose shorter side is at most a cap.

    Args:
        pil_img: source image; it is copied, never modified.
        short_size: value of the UI "Max Downscaling Res Constraint" slider.
            * Values > 0 cap the short side at ``min(short_size, 1024)``.
            * 0, None or negative values use the 1024 cap alone.

    Images already within the cap are returned as an unscaled copy. Previously
    any ``short_size > 0`` also upscaled small images to that size, which
    contradicts the "Max Downscaling" control and the 0 branch, which only
    downscaled.
    """
    max_side = 1024
    if short_size and int(short_size) > 0:
        max_side = min(int(short_size), 1024)
    process_img = pil_img.copy()
    if min(process_img.size) > max_side:
        process_img, _ = resize_image_short_side(process_img, max_side)
    return process_img
# ============================================================
# GPU Inference Entry Points (@spaces.GPU)
# ============================================================
@spaces.GPU(duration=45)
def _run_image_inference(image_in, categories_list, category_str, model_mode, temp, top_p, top_k, short_size, question_override):
    """Ground one image and return Gradio updates for (image output, video output, trace HTML).

    The (possibly resized) copy from ``_prepare_image_for_model`` goes to the
    model. Detections are drawn on the original *image_in*. The request is
    logged via ``_log_to_dataset``. ``spaces.GPU(duration=45)`` requests GPU
    time for this call.
    """
    hidden_video = gr.update(value=None, visible=False)
    if image_in is None:
        return gr.update(value=None, visible=True), hidden_video, "<p>⚠️ Upload image.</p>"

    model_input = _prepare_image_for_model(image_in, short_size)
    output_text, token_sequence, out_info = GLOBAL_WORKER.generate(
        model_input, categories_list, model_mode,
        temp=temp, top_p=top_p, top_k=top_k, question_override=question_override,
    )
    detections = parse_mixed_results(output_text, category_str)

    source_bgr = cv2.cvtColor(np.array(image_in), cv2.COLOR_RGB2BGR)
    annotated_bgr = draw_on_frame(source_bgr, detections, draw_label=True)
    annotated = Image.fromarray(cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB))

    _log_to_dataset("image", ", ".join(categories_list), model_mode,
                    question_override or category_str, output_text, image_in, annotated)
    trace_html = generate_dynamic_html(token_sequence, out_info, output_text)
    return gr.update(value=annotated, visible=True), hidden_video, trace_html
def _sample_frame_indices(total, max_frames):
    """Pick up to *max_frames* indices spread evenly over ``range(total)``.

    Returns every index when ``total <= max_frames``, and ``[]`` when
    ``total <= 0``. With two or more samples, the first and last frames are
    both included. A single sample returns ``[0]``, because the spacing formula
    divides by ``max_frames - 1``.
    """
    if total <= 0:
        return []
    if total <= max_frames:
        return list(range(total))
    if max_frames <= 1:
        return [0]
    return [int(round(i * (total - 1) / (max_frames - 1))) for i in range(max_frames)]


def _read_video_frames(video_path):
    """Decode every frame of *video_path* with OpenCV.

    Returns:
        ``(frames, fps)``: the decoded frames, in OpenCV's native channel order,
        and the container-reported FPS (callers treat values <= 0 as unknown).
        The capture is always released.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        frames = []
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
    finally:
        cap.release()
    return frames, fps


def _transcode_to_h264(src_path, dst_path):
    """Re-encode *src_path* to H.264 / yuv420p at *dst_path* with the ffmpeg CLI.

    Returns True when ffmpeg exits 0 and writes a non-empty file. Otherwise it
    prints the reason (binary missing, non-zero exit, empty output) and
    returns False.
    """
    import subprocess as _sp
    cmd = ["ffmpeg", "-y", "-i", src_path, "-c:v", "libx264", "-preset", "ultrafast",
           "-crf", "23", "-pix_fmt", "yuv420p", dst_path]
    try:
        proc = _sp.run(cmd, capture_output=True)
    except OSError as e:
        print(f"[VIDEO] ffmpeg could not be started: {e}")
        return False
    if proc.returncode != 0 or not os.path.exists(dst_path) or os.path.getsize(dst_path) == 0:
        tail = proc.stderr.decode("utf-8", errors="replace")[-500:]
        print(f"[VIDEO] ffmpeg failed (exit {proc.returncode}): {tail}")
        return False
    return True


@spaces.GPU(duration=180)
def _run_video_inference(video_in, categories_list, category_str, model_mode, temp, top_p, top_k, short_size, question_override, max_video_frames):
    """Ground evenly-sampled frames of a video and return an annotated clip.

    Steps:
    1. Decode all frames and sample up to *max_video_frames* (default 4) evenly.
    2. Run the model on each sampled frame.
    3. Draw detections on each frame and write an mp4v clip whose fps keeps the
       original duration (5 fps when the source fps is unknown).
    4. Re-encode the clip to browser-friendly H.264 with ffmpeg. If ffmpeg
       fails, the mp4v file is returned instead.

    Returns Gradio updates for (image output, video output, trace HTML).
    ``spaces.GPU(duration=180)`` requests GPU time for this call.
    """
    if video_in is None:
        return gr.update(value=None, visible=False), gr.update(value=None, visible=True), "<p>⚠️ Upload video.</p>"

    all_frames, fps = _read_video_frames(video_in)
    total = len(all_frames)
    if total == 0:
        # Nothing decodable; the duration math below would divide by zero.
        return (gr.update(value=None, visible=False), gr.update(value=None, visible=True),
                "<p>⚠️ Could not read any frames from the video.</p>")

    # Clamp so a slider value of 1 (or a negative value) cannot break sampling.
    max_frames = max(1, int(max_video_frames)) if max_video_frames else 4
    sampled_frames = [all_frames[i] for i in _sample_frame_indices(total, max_frames)]
    # Keep the clip's duration: N sampled frames spread over total / fps seconds.
    out_fps = max(1.0, len(sampled_frames) / (total / fps)) if fps > 0 else 5.0
    del all_frames
    gc.collect()
    # Size the writer from the decoded frames rather than CAP_PROP_FRAME_WIDTH/HEIGHT.
    # NOTE(review): the container properties can disagree with decoded frames
    # (e.g. rotated clips), in which case VideoWriter may drop every frame.
    vid_h, vid_w = sampled_frames[0].shape[:2]

    inference_results = []
    for frame in sampled_frames:
        pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        process_img = _prepare_image_for_model(pil_img, short_size)
        output_text, _, _ = GLOBAL_WORKER.generate(process_img, categories_list, model_mode, temp=temp, top_p=top_p, top_k=top_k, question_override=question_override)
        inference_results.append(output_text)

    # mkstemp instead of the deprecated, race-prone tempfile.mktemp.
    fd, tmp_raw = tempfile.mkstemp(suffix=".raw.mp4")
    os.close(fd)
    fd, out_video_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)

    out = cv2.VideoWriter(tmp_raw, cv2.VideoWriter_fourcc(*"mp4v"), out_fps, (vid_w, vid_h))
    try:
        for frame, output_text in zip(sampled_frames, inference_results):
            detections = parse_mixed_results(output_text, category_str)
            valid_results = _postprocess_detections(detections, vid_w, vid_h)
            out.write(draw_on_frame(frame, valid_results, draw_label=True))
    finally:
        out.release()

    if _transcode_to_h264(tmp_raw, out_video_path):
        if os.path.exists(tmp_raw):
            os.remove(tmp_raw)
        result_path = out_video_path
    else:
        # Previously a missing/empty H.264 file was returned; fall back to the
        # mp4v clip so the user still gets a result (browser playback may vary).
        if os.path.exists(out_video_path):
            os.remove(out_video_path)
        result_path = tmp_raw

    combined_raw_text = "\n\n".join([f"--- Frame {i+1} ---\n{t}" for i, t in enumerate(inference_results)])
    return gr.update(value=None, visible=False), gr.update(value=result_path, visible=True), generate_dynamic_html([], "Processed Loop Successful", combined_raw_text)
def run_inference(input_type, image_in, video_in, task_type, category_str, model_mode, temp, top_p, top_k, short_side, question_override, max_video_frames):
    """Gradio click handler: normalize UI inputs and dispatch to the image or video runner.

    * Categories are the comma-separated entries of *category_str*, stripped,
      with empties dropped; ``["object"]`` is used when nothing remains.
    * A blank *question_override* becomes None.
    * ``input_type == "Image"`` goes to ``_run_image_inference``; anything else
      goes to ``_run_video_inference``.
    * *task_type* is not used here; the task is already reflected in the prompt
      passed as *question_override*.
    """
    stripped_parts = (part.strip() for part in category_str.split(","))
    categories_list = [part for part in stripped_parts if part] or ["object"]
    stripped_override = question_override.strip() if question_override else ""
    final_override = stripped_override or None

    shared_args = (categories_list, category_str, model_mode, temp, top_p, top_k, short_side, final_override)
    if input_type == "Image":
        return _run_image_inference(image_in, *shared_args)
    return _run_video_inference(video_in, *shared_args, max_video_frames)
# ============================================================
# GRADIO INTERFACE LAYOUT BUILD
# ============================================================
def build_ui():
    """Assemble and return the Gradio Blocks app (inputs left, results right).

    Event wiring:
    * switching "Input Format" shows either the image or the video uploader;
    * changing the task or categories regenerates the read-only prompt box,
      whose text is also what ``run_inference`` receives as ``question_override``;
    * "Run Inference" disables itself, runs ``run_inference``, then re-enables
      via a chained ``.then`` (NOTE(review): per Gradio docs ``.then`` fires even
      if the previous step raised; verify for the installed version).
    """
    default_prompt = "Locate all the instances that matches the following description: car</c>pedestrian."
    with gr.Blocks(title="LocateAnything Grounding Suite") as demo:
        gr.Markdown("# 🔍 LocateAnything Grounding Studio\nInfer target regions, visual boxes, and point indicators.")
        with gr.Row():
            # Left column: inputs and controls.
            with gr.Column(scale=1):
                input_type = gr.Radio(["Image", "Video"], value="Image", label="Input Format")
                image_input = gr.Image(type="pil", label="Source Image", visible=True)
                video_input = gr.Video(label="Source Video", visible=False)
                task_dropdown = gr.Dropdown(
                    ["Detection", "Grounding", "OCR", "GUI", "Pointing"],
                    value="Detection",
                    label="Goal Context Task",
                )
                category_input = gr.Textbox(
                    label="Categories / Label Targets (comma separated)",
                    value="car, pedestrian",
                )
                raw_prompt_box = gr.Textbox(
                    label="Generated Execution Prompt (Read Only)",
                    value=default_prompt,
                    interactive=False,
                )
                with gr.Accordion("Advanced Parameters", open=False):
                    model_dropdown = gr.Dropdown(["hybrid", "fast", "slow"], value="hybrid", label="Decoding Engine Mode")
                    temp_slider = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="Temperature")
                    top_p_slider = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top P")
                    top_k_slider = gr.Slider(1, 100, value=50, step=1, label="Top K")
                    short_size_input = gr.Slider(0, 1024, value=1024, step=64, label="Max Downscaling Res Constraint (0 for Native)")
                    max_video_frames_slider = gr.Slider(1, 16, value=4, step=1, label="Video Sample Extraction Cap")
                run_btn = gr.Button("Run Inference", variant="primary")
            # Right column: results.
            with gr.Column(scale=1):
                output_image = gr.Image(label="Annotated Image Result", visible=True)
                output_video = gr.Video(label="Annotated Video Result", visible=False)
                raw_output_box = gr.HTML(label="Visual Trace Dashboard")

        input_type.change(
            fn=lambda c: (gr.update(visible=(c == "Image")), gr.update(visible=(c == "Video"))),
            inputs=input_type,
            outputs=[image_input, video_input],
        )

        prompt_sources = (task_dropdown, category_input)
        for source in prompt_sources:
            source.change(fn=generate_raw_prompt, inputs=list(prompt_sources), outputs=raw_prompt_box)

        inference_inputs = [
            input_type, image_input, video_input,
            task_dropdown, category_input, model_dropdown,
            temp_slider, top_p_slider, top_k_slider,
            short_size_input, raw_prompt_box, max_video_frames_slider,
        ]
        inference_outputs = [output_image, output_video, raw_output_box]
        busy_step = run_btn.click(
            fn=lambda: gr.update(interactive=False, value="Processing Tensors..."),
            outputs=[run_btn],
        )
        inference_step = busy_step.then(fn=run_inference, inputs=inference_inputs, outputs=inference_outputs)
        inference_step.then(
            fn=lambda: gr.update(interactive=True, value="Run Inference"),
            outputs=[run_btn],
        )
    return demo
if __name__ == "__main__":
    # Build the UI, enable Gradio's request queue, and start the server.
    # `demo` stays a module-level name for tooling that looks for it.
    demo = build_ui()
    demo.queue()
    demo.launch()