# WeDetect-demo / app.py
# Hugging Face Space page header ("mrdbourke's picture / Update app.py /
# cbfe9bd verified") — converted to a comment so the module parses as Python.
"""
WeDetect: Open-Vocabulary Object Detection Demo
HuggingFace Spaces Application
This app provides an interactive interface for WeDetect, a fast open-vocabulary
object detection model that uses Chinese class names internally.
Features:
- Upload any image for object detection
- Enter class names in English OR Chinese
- Automatic English→Chinese translation with editable preview
- Adjustable confidence threshold
- Visual results with bounding boxes
Compatible with:
- Gradio 5.50.0+ / 6.x
- huggingface_hub 1.x
"""
import os
import sys
import subprocess
# ============================================================================
# INSTALL MMCV/MMDET/MMENGINE WITH CUDA EXTENSIONS
# ============================================================================
# mmcv needs pre-built CUDA extensions. We must install from OpenMMLab's
# wheel index with the correct CUDA and PyTorch version.
# ============================================================================
def get_torch_cuda_version():
    """Return (PyTorch "major.minor", CUDA wheel tag) for mmcv wheel selection.

    The tag is ``'cuXYZ'`` (e.g. ``'cu121'``) when a CUDA build of PyTorch is
    active, otherwise ``'cpu'``.
    """
    import torch

    # Drop any local build suffix ("2.1.0+cu121" -> "2.1.0"), keep "2.1".
    bare_version = torch.__version__.split('+')[0]
    major_minor = '.'.join(bare_version.split('.')[:2])

    if not torch.cuda.is_available():
        return major_minor, 'cpu'

    # "12.1" -> "cu121": strip the dot, keep the first three digits.
    tag = 'cu' + torch.version.cuda.replace('.', '')[:3]
    return major_minor, tag
def install_mm_packages():
    """Install mmengine, mmcv and mmdet, with CUDA extensions for mmcv.

    mmengine and mmdet are pure Python and come straight from PyPI, but mmcv
    ships compiled CUDA ops and must be fetched from the OpenMMLab wheel
    index matching the active PyTorch/CUDA combination. Each package is
    skipped if it already imports cleanly.

    Fix over the original: pip return codes for mmengine/mmdet were captured
    but never checked, so a failed install still printed "✅ installed".
    Failures are now surfaced with the captured stderr.
    """

    def _pip(*args):
        """Run `python -m pip <args>` quietly; return the CompletedProcess."""
        return subprocess.run(
            [sys.executable, "-m", "pip", *args],
            capture_output=True, text=True
        )

    # --- mmengine (no CUDA extensions needed) ------------------------------
    try:
        import mmengine
        print(f"✅ mmengine already installed: {mmengine.__version__}")
    except ImportError:
        print("📦 Installing mmengine...")
        result = _pip("install", "mmengine==0.10.7")
        if result.returncode != 0:
            print(f"⚠️ mmengine install failed: {result.stderr}")
        else:
            print("✅ mmengine installed")

    # --- mmcv (needs pre-built CUDA extensions) ----------------------------
    try:
        import mmcv
        from mmcv.ops import roi_align  # Probe that the compiled ops load
        print(f"✅ mmcv already installed with extensions: {mmcv.__version__}")
    except (ImportError, ModuleNotFoundError) as e:
        print(f"📦 Installing mmcv with CUDA extensions... (reason: {e})")
        # Get versions for wheel selection
        torch_version, cuda_tag = get_torch_cuda_version()
        print(f" Detected: PyTorch {torch_version}, CUDA tag: {cuda_tag}")
        # OpenMMLab wheel index URL
        wheel_index = f"https://download.openmmlab.com/mmcv/dist/{cuda_tag}/torch{torch_version}/index.html"
        print(f" Wheel index: {wheel_index}")
        # Remove any existing broken mmcv before installing from the index.
        _pip("uninstall", "mmcv", "-y")
        result = _pip("install", "mmcv==2.1.0", "-f", wheel_index)
        if result.returncode != 0:
            print(f"⚠️ First attempt failed: {result.stderr}")
            # Fall back through other CUDA builds: nearby CUDA minor
            # versions often have compatible pre-built wheels.
            for alt_cuda in ["cu121", "cu118", "cu117"]:
                if alt_cuda == cuda_tag:
                    continue
                alt_wheel_index = f"https://download.openmmlab.com/mmcv/dist/{alt_cuda}/torch{torch_version}/index.html"
                print(f" Trying alternative: {alt_wheel_index}")
                result = _pip("install", "mmcv==2.1.0", "-f", alt_wheel_index)
                if result.returncode == 0:
                    break
        if result.returncode != 0:
            print(f"⚠️ mmcv install failed on all wheel indexes: {result.stderr}")
        else:
            print("✅ mmcv installed")

    # --- mmdet -------------------------------------------------------------
    try:
        import mmdet
        print(f"✅ mmdet already installed: {mmdet.__version__}")
    except ImportError:
        print("📦 Installing mmdet...")
        result = _pip("install", "mmdet==3.3.0")
        if result.returncode != 0:
            print(f"⚠️ mmdet install failed: {result.stderr}")
        else:
            print("✅ mmdet installed")
# Run installation before other imports — the mm* packages must be importable
# before the model-loading code below is reached.
print("🔧 Setting up MM packages with CUDA extensions...")
install_mm_packages()
# Verify installation: importing a compiled op proves mmcv._ext loaded.
print("🔍 Verifying mmcv extensions...")
try:
    from mmcv.ops import roi_align
    print("✅ mmcv._ext loaded successfully!")
except Exception as e:
    # Broad catch is deliberate: a missing extension should not crash the
    # Space at startup — the error will resurface (with context) at inference.
    print(f"⚠️ Warning: mmcv extensions may not be fully loaded: {e}")
# ============================================================================
# STANDARD IMPORTS (after MM packages are installed)
# ============================================================================
import tempfile
import colorsys
from typing import List, Tuple, Optional
import gradio as gr
import spaces
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from huggingface_hub import hf_hub_download
# ============================================================================
# CONFIGURATION
# ============================================================================
# Model variant served by default. Options: "tiny", "base", "large"
DEFAULT_MODEL = "large"
# HuggingFace Hub repository hosting the pretrained checkpoints.
REPO_ID = "fushh7/WeDetect"
# Per-size checkpoint filename (on the Hub) and config stem (in the cloned
# GitHub repo, resolved by get_model()).
MODEL_INFO = {
    "tiny": {"file": "wedetect_tiny.pth", "config": "wedetect_tiny"},
    "base": {"file": "wedetect_base.pth", "config": "wedetect_base"},
    "large": {"file": "wedetect_large.pth", "config": "wedetect_large"},
}
# ============================================================================
# ENGLISH TO CHINESE DICTIONARY (~200 common objects)
# ============================================================================
# Maps lowercase English object names to the Chinese class names WeDetect
# expects internally. Lookups are done on lowercased input (see
# translate_to_chinese); covers roughly 200 everyday categories.
ENGLISH_TO_CHINESE = {
    # People
    "person": "人", "man": "男人", "woman": "女人", "child": "儿童", "kid": "小孩",
    "baby": "婴儿", "boy": "男孩", "girl": "女孩", "people": "人", "human": "人",
    # Animals
    "dog": "狗", "cat": "猫", "bird": "鸟", "fish": "鱼", "horse": "马",
    "cow": "牛", "sheep": "羊", "pig": "猪", "chicken": "鸡", "duck": "鸭",
    "elephant": "大象", "bear": "熊", "zebra": "斑马", "giraffe": "长颈鹿",
    "lion": "狮子", "tiger": "老虎", "monkey": "猴子", "rabbit": "兔子",
    # NOTE: the animal entry "mouse": "老鼠" was removed — the literal also
    # defined "mouse" under Electronics, and Python keeps only the LAST
    # duplicate key, so the effective translation was always 鼠标 (computer
    # mouse). Removing the dead entry preserves behavior.
    "snake": "蛇", "turtle": "乌龟", "frog": "青蛙",
    "butterfly": "蝴蝶", "bee": "蜜蜂", "spider": "蜘蛛", "ant": "蚂蚁",
    # Vehicles
    "car": "车", "truck": "卡车", "bus": "公交车", "train": "火车",
    "airplane": "飞机", "plane": "飞机", "boat": "船", "ship": "船",
    "bicycle": "自行车", "bike": "自行车", "motorcycle": "摩托车",
    "helicopter": "直升机", "taxi": "出租车", "ambulance": "救护车",
    "fire truck": "消防车", "police car": "警车", "van": "面包车",
    # Furniture
    "chair": "椅子", "table": "桌子", "desk": "书桌", "bed": "床",
    "couch": "沙发", "sofa": "沙发", "bench": "长凳", "cabinet": "柜子",
    "shelf": "架子", "drawer": "抽屉", "wardrobe": "衣柜", "mirror": "镜子",
    # Electronics
    "tv": "电视", "television": "电视", "computer": "电脑", "laptop": "笔记本电脑",
    "phone": "手机", "cell phone": "手机", "mobile phone": "手机",
    "tablet": "平板电脑", "keyboard": "键盘", "mouse": "鼠标",
    "monitor": "显示器", "screen": "屏幕", "camera": "相机", "speaker": "音箱",
    "headphones": "耳机", "microphone": "麦克风", "remote": "遥控器",
    # Kitchen items
    "refrigerator": "冰箱", "fridge": "冰箱", "oven": "烤箱",
    "microwave": "微波炉", "toaster": "烤面包机", "blender": "搅拌机",
    "kettle": "水壶", "pot": "锅", "pan": "平底锅", "bowl": "碗",
    "plate": "盘子", "cup": "杯子", "mug": "马克杯", "glass": "玻璃杯",
    "bottle": "瓶子", "fork": "叉子", "knife": "刀", "spoon": "勺子",
    "chopsticks": "筷子",
    # Food
    "apple": "苹果", "banana": "香蕉", "orange": "橙子", "grape": "葡萄",
    "strawberry": "草莓", "watermelon": "西瓜", "pizza": "披萨",
    "hamburger": "汉堡", "sandwich": "三明治", "hot dog": "热狗",
    "cake": "蛋糕", "bread": "面包", "rice": "米饭", "noodles": "面条",
    "egg": "鸡蛋", "meat": "肉", "vegetable": "蔬菜", "fruit": "水果",
    # Clothing
    "shirt": "衬衫", "pants": "裤子", "dress": "连衣裙", "skirt": "裙子",
    "jacket": "夹克", "coat": "外套", "sweater": "毛衣", "hat": "帽子",
    "cap": "帽子", "shoe": "鞋", "shoes": "鞋", "boot": "靴子",
    "sock": "袜子", "glove": "手套", "scarf": "围巾", "tie": "领带",
    "belt": "腰带", "bag": "包", "backpack": "背包", "purse": "钱包",
    "wallet": "钱包", "watch": "手表", "glasses": "眼镜", "sunglasses": "太阳镜",
    # Sports
    "ball": "球", "football": "足球", "soccer ball": "足球",
    "basketball": "篮球", "baseball": "棒球", "tennis ball": "网球",
    "golf ball": "高尔夫球", "volleyball": "排球",
    "tennis racket": "网球拍", "skateboard": "滑板", "surfboard": "冲浪板",
    "ski": "滑雪板", "snowboard": "单板滑雪", "frisbee": "飞盘",
    # Office/School
    "book": "书", "notebook": "笔记本", "pen": "笔", "pencil": "铅笔",
    "paper": "纸", "scissors": "剪刀", "ruler": "尺子", "eraser": "橡皮",
    "stapler": "订书机", "calculator": "计算器", "clock": "时钟",
    "calendar": "日历", "folder": "文件夹",
    # Outdoor
    "tree": "树", "flower": "花", "grass": "草", "plant": "植物",
    "leaf": "叶子", "rock": "石头", "mountain": "山", "river": "河",
    "lake": "湖", "ocean": "海洋", "beach": "海滩", "sky": "天空",
    "cloud": "云", "sun": "太阳", "moon": "月亮", "star": "星星",
    # Buildings/Structures
    "house": "房子", "building": "建筑", "door": "门", "window": "窗户",
    "wall": "墙", "roof": "屋顶", "floor": "地板", "stairs": "楼梯",
    "fence": "栅栏", "bridge": "桥", "road": "道路", "street": "街道",
    "traffic light": "红绿灯", "stop sign": "停止标志",
    # Household
    "lamp": "灯", "light": "灯", "fan": "风扇", "air conditioner": "空调",
    "pillow": "枕头", "blanket": "毯子", "towel": "毛巾", "soap": "肥皂",
    "toothbrush": "牙刷", "toilet": "马桶", "sink": "水槽", "bathtub": "浴缸",
    "shower": "淋浴", "curtain": "窗帘", "carpet": "地毯", "rug": "地毯",
    # Misc
    "umbrella": "雨伞", "key": "钥匙", "lock": "锁", "box": "盒子",
    "basket": "篮子", "vase": "花瓶", "candle": "蜡烛", "picture": "图片",
    "painting": "画", "photo": "照片", "frame": "框架", "toy": "玩具",
    "teddy bear": "泰迪熊", "doll": "娃娃", "robot": "机器人",
    "kite": "风筝", "balloon": "气球", "flag": "旗帜",
}
# ============================================================================
# TRANSLATION FUNCTIONS
# ============================================================================
def translate_to_chinese(text: str) -> str:
    """Map one English word/phrase to Chinese via the built-in dictionary.

    Falls back to the input unchanged when no entry matches — the text may
    already be Chinese.
    """
    key = text.lower().strip()

    # Candidate lookups in priority order: exact match, then singular forms
    # obtained by stripping a trailing 's' and a trailing 'es'.
    candidates = [key]
    if key.endswith('s'):
        candidates.append(key[:-1])
    if key.endswith('es'):
        candidates.append(key[:-2])

    for candidate in candidates:
        if candidate in ENGLISH_TO_CHINESE:
            return ENGLISH_TO_CHINESE[candidate]
    return text
def translate_class_list(classes_text: str, input_mode: str) -> str:
    """
    Normalize a comma-separated class list, translating English to Chinese.

    Args:
        classes_text: Comma-separated class names
        input_mode: "English" or "Chinese (中文)"
    Returns:
        Comma-separated Chinese class names (", "-joined)
    """
    if not classes_text.strip():
        return ""

    # Split, trim whitespace, drop empty tokens (e.g. from "a,,b").
    names = [token.strip() for token in classes_text.split(',') if token.strip()]

    if input_mode == "English":
        names = [translate_to_chinese(name) for name in names]
    # Chinese input passes through untranslated, re-joined with ", ".
    return ', '.join(names)
# ============================================================================
# MODEL LOADING
# ============================================================================
# Global model cache: model size name ("tiny"/"base"/"large") -> loaded detector.
_model_cache = {}
# Path of the cloned WeDetect repo once setup_repo() has run; None until then.
_repo_path = None
def setup_repo():
    """Ensure the WeDetect source tree is cloned and importable.

    Clones the GitHub repository into /tmp on first use, prepends it to
    ``sys.path`` so its modules resolve, and caches the location in the
    module-level ``_repo_path``.

    Returns:
        The local path of the repository checkout.

    Raises:
        subprocess.CalledProcessError: If `git clone` fails.
    """
    global _repo_path

    # Fast path: a previous call already cloned and registered the repo.
    if _repo_path is not None and os.path.exists(_repo_path):
        return _repo_path

    checkout = "/tmp/WeDetect"
    if not os.path.exists(checkout):
        print("📥 Cloning WeDetect repository...")
        try:
            subprocess.run(
                ["git", "clone", "--depth", "1", "https://github.com/WeChatCV/WeDetect.git", checkout],
                check=True,
                capture_output=True,
                text=True
            )
        except subprocess.CalledProcessError as e:
            print(f"❌ Failed to clone repository: {e.stderr}")
            raise
        print("✅ Repository cloned!")

    # Make the checkout importable; insert at the front so it wins lookups.
    if checkout not in sys.path:
        sys.path.insert(0, checkout)

    _repo_path = checkout
    return checkout
def get_model(model_size: str = DEFAULT_MODEL):
    """Load (and cache) a WeDetect detector of the requested size.

    Args:
        model_size: One of "tiny", "base", "large" (keys of MODEL_INFO).

    Returns:
        An initialized mmdet detector, placed on GPU when available.

    Raises:
        FileNotFoundError: If the config file is missing from the cloned repo.
        KeyError: If ``model_size`` is not a key of MODEL_INFO.
    """
    global _model_cache
    if model_size in _model_cache:
        return _model_cache[model_size]

    # Deferred imports: these packages are installed at startup by
    # install_mm_packages(), after this module begins executing.
    # (Fix: the unused `from mmengine.config import Config` import was removed.)
    import torch
    from mmdet.apis import init_detector

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"🚀 Loading WeDetect-{model_size.capitalize()} on {device}...")

    # Config files live in the cloned GitHub repo, not on the HF Hub.
    repo_dir = setup_repo()
    config_path = os.path.join(repo_dir, "config", f"wedetect_{model_size}.py")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config not found: {config_path}")

    # Checkpoint weights come from the HF Hub and are cached locally.
    checkpoint_file = MODEL_INFO[model_size]["file"]
    print(f"📥 Downloading checkpoint: {checkpoint_file}...")
    checkpoint_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=checkpoint_file,
        cache_dir="./models"
    )

    # Initialize model and memoize it for subsequent requests.
    print("🔧 Initializing model...")
    model = init_detector(config_path, checkpoint_path, device=device)
    _model_cache[model_size] = model
    print(f"✅ WeDetect-{model_size.capitalize()} loaded successfully!")
    return model
# ============================================================================
# VISUALIZATION
# ============================================================================
def generate_colors(n: int) -> List[Tuple[int, int, int]]:
    """Produce visually distinct RGB colors by walking the hue wheel.

    Always yields at least one color, so callers may safely index/modulo
    into the palette even when ``n`` is 0.
    """
    count = max(n, 1)
    palette = []
    for step in range(count):
        # Fixed saturation/value; only the hue varies across the palette.
        r, g, b = colorsys.hsv_to_rgb(step / count, 0.8, 0.9)
        palette.append((int(r * 255), int(g * 255), int(b * 255)))
    return palette
def draw_detections(
    image: Image.Image,
    boxes: np.ndarray,
    scores: np.ndarray,
    labels: np.ndarray,
    class_names: List[str],
    threshold: float
) -> Tuple[Image.Image, int]:
    """Render bounding boxes and class/score captions onto a copy of *image*.

    Only detections with score >= threshold are drawn. Returns the annotated
    copy and the number of boxes actually drawn.
    """
    canvas = image.copy()
    painter = ImageDraw.Draw(canvas)

    # Prefer a CJK-capable system font (labels are Chinese); fall back to
    # Pillow's built-in font if none of the known paths exist.
    candidate_fonts = [
        "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
        "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
        "simsun.ttc",
    ]
    font = None
    for path in candidate_fonts:
        try:
            font = ImageFont.truetype(path, 18)
        except (IOError, OSError):
            continue
        break
    if font is None:
        try:
            font = ImageFont.load_default(size=16)
        except TypeError:
            # Older Pillow versions don't support the size argument.
            font = ImageFont.load_default()

    palette = generate_colors(len(class_names))
    drawn = 0
    for box, score, label_idx in zip(boxes, scores, labels):
        if score < threshold:
            continue
        drawn += 1
        x1, y1, x2, y2 = (int(v) for v in box)
        idx = int(label_idx)
        color = palette[idx % len(palette)]

        # Box outline.
        painter.rectangle([x1, y1, x2, y2], outline=color, width=3)

        # "<class>: <score>" caption; "?" guards out-of-range label indices.
        name = class_names[idx] if idx < len(class_names) else "?"
        caption = f"{name}: {score:.2f}"

        # Measure the caption to size the filled strip drawn above the box.
        left, top, right, bottom = painter.textbbox((x1, y1), caption, font=font)
        text_w = right - left
        text_h = bottom - top
        painter.rectangle(
            [x1, y1 - text_h - 6, x1 + text_w + 6, y1],
            fill=color
        )
        painter.text((x1 + 3, y1 - text_h - 3), caption, fill='white', font=font)

    return canvas, drawn
# ============================================================================
# MAIN DETECTION FUNCTION
# ============================================================================
@spaces.GPU
def detect_objects(
    image: Optional[Image.Image],
    chinese_classes: str,
    threshold: float,
    model_size: str
) -> Tuple[Optional[Image.Image], str]:
    """
    Run open-vocabulary object detection on an image.

    Args:
        image: Input PIL Image (None when the user has not uploaded one)
        chinese_classes: Comma-separated Chinese class names
        threshold: Confidence threshold; detections below it are not drawn
        model_size: Model size to use ("tiny", "base" or "large")
    Returns:
        Tuple of (annotated image, status message)
    """
    # Validation failures return status strings instead of raising, so the
    # Gradio UI shows friendly feedback rather than an error trace.
    if image is None:
        return None, "⚠️ Please upload an image"
    if not chinese_classes.strip():
        return image, "⚠️ Please enter class names to detect"
    # Parse class names (drop empty tokens from stray commas)
    class_names = [c.strip() for c in chinese_classes.split(',') if c.strip()]
    if not class_names:
        return image, "⚠️ No valid class names provided"
    try:
        import torch
        from mmdet.apis import inference_detector
        # Load model (cached after the first call for this size)
        model = get_model(model_size)
        # Open-vocabulary: the label space is set per request, not baked into
        # the checkpoint, so overwrite the dataset metadata each time.
        model.dataset_meta['classes'] = class_names
        # Save image temporarily — inference_detector is given a file path.
        with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as f:
            temp_path = f.name
            image.save(temp_path)
        try:
            # Run inference. NOTE(review): `texts` is passed as a nested list
            # (one prompt list per image) — confirm against the WeDetect API.
            results = inference_detector(model, temp_path, texts=[class_names])
            # Extract predictions and move them to CPU numpy arrays
            pred = results.pred_instances
            boxes = pred.bboxes.cpu().numpy()
            scores = pred.scores.cpu().numpy()
            labels = pred.labels.cpu().numpy()
            # Draw results (threshold filtering happens inside draw_detections)
            result_image, count = draw_detections(
                image, boxes, scores, labels, class_names, threshold
            )
            status = f"✅ Found {count} object(s) | Classes: {', '.join(class_names)}"
            return result_image, status
        finally:
            # Cleanup the temp file even when inference raises
            if os.path.exists(temp_path):
                os.unlink(temp_path)
    except Exception as e:
        import traceback
        # Broad catch is deliberate: any failure is logged with a full trace
        # and surfaced to the user as a status message.
        error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return image, f"❌ Error: {str(e)}"
# ============================================================================
# GRADIO INTERFACE
# ============================================================================
# Custom CSS for styling.
# NOTE(review): CUSTOM_CSS is defined but never passed to gr.Blocks() or
# demo.launch() below, so it is currently unused — confirm whether it
# should be wired up.
CUSTOM_CSS = """
.output-class { font-family: 'Noto Sans SC', sans-serif; }
.info-text { color: #666; font-size: 0.9em; }
"""
def create_demo():
    """Create the Gradio demo interface.

    Layout: a two-column row (inputs on the left, results on the right),
    a collapsible translation reference table, and event wiring that keeps
    the editable Chinese preview in sync with the English input.
    """
    # NOTE: In Gradio 5.50+/6.0, theme and css must be passed to launch(), not Blocks()
    with gr.Blocks() as demo:
        gr.Markdown("""
        # 🔍 WeDetect: Open-Vocabulary Object Detection
        Upload an image and specify what objects to detect. Enter class names in **English** or **Chinese**.
        > **Note:** WeDetect uses Chinese internally. English inputs are automatically translated.
        """)
        with gr.Row():
            # Left column: Inputs
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="📷 Upload Image",
                    type="pil",
                    height=350
                )
                input_mode = gr.Radio(
                    choices=["English", "Chinese (中文)"],
                    value="English",
                    label="🌐 Input Language",
                    info="Choose the language for entering class names"
                )
                classes_input = gr.Textbox(
                    label="🏷️ Classes to Detect",
                    placeholder="person, car, dog, cat",
                    value="person, car, dog",
                    info="Enter class names separated by commas",
                    lines=2
                )
                # The preview stays editable so users can fix bad translations
                # before detection runs; detect_objects reads THIS box.
                chinese_preview = gr.Textbox(
                    label="🀄 Chinese Classes (Editable)",
                    placeholder="人, 车, 狗, 猫",
                    value="人, 车, 狗",
                    info="Edit if translation needs correction",
                    lines=2
                )
                threshold_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.3,
                    step=0.05,
                    label="📊 Confidence Threshold"
                )
                model_dropdown = gr.Dropdown(
                    choices=["large", "base", "tiny"],
                    value=DEFAULT_MODEL,
                    label="🧠 Model Size",
                    info="Large=best quality, Tiny=fastest"
                )
                detect_btn = gr.Button(
                    "🔍 Detect Objects",
                    variant="primary",
                    size="lg"
                )
            # Right column: Output
            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="🎯 Detection Results",
                    type="pil",
                    height=350
                )
                status_text = gr.Textbox(
                    label="📋 Status",
                    interactive=False,
                    lines=2
                )
        # Class name reference
        with gr.Accordion("📚 Common Class Names Reference", open=False):
            gr.Markdown("""
            | English | Chinese | | English | Chinese | | English | Chinese |
            |---------|---------|---|---------|---------|---|---------|---------|
            | person | 人 | | car | 车 | | dog | 狗 |
            | cat | 猫 | | bird | 鸟 | | horse | 马 |
            | bicycle | 自行车 | | motorcycle | 摩托车 | | bus | 公交车 |
            | truck | 卡车 | | chair | 椅子 | | table | 桌子 |
            | bed | 床 | | couch | 沙发 | | tv | 电视 |
            | laptop | 笔记本电脑 | | phone | 手机 | | book | 书 |
            | bottle | 瓶子 | | cup | 杯子 | | shoe | 鞋 |
            | bag | 包 | | umbrella | 雨伞 | | tree | 树 |
            **Example inputs:**
            - English: `person, car, dog, cat, bicycle`
            - Chinese: `人, 车, 狗, 猫, 自行车`
            """)
        # Event handlers
        def update_chinese_preview(classes_text: str, mode: str) -> str:
            return translate_class_list(classes_text, mode)
        # Auto-translate whenever the class list or the language mode changes
        classes_input.change(
            fn=update_chinese_preview,
            inputs=[classes_input, input_mode],
            outputs=chinese_preview
        )
        input_mode.change(
            fn=update_chinese_preview,
            inputs=[classes_input, input_mode],
            outputs=chinese_preview
        )
        # Detection button click — note it reads the (possibly hand-edited)
        # chinese_preview box, not the raw classes_input.
        detect_btn.click(
            fn=detect_objects,
            inputs=[input_image, chinese_preview, threshold_slider, model_dropdown],
            outputs=[output_image, status_text]
        )
        gr.Markdown("""
        ---
        **Credits:** [WeDetect](https://github.com/WeChatCV/WeDetect) by WeChatCV |
        [Paper](https://arxiv.org/abs/2512.12309) |
        [Models](https://huggingface.co/fushh7/WeDetect)
        """)
    return demo
# ============================================================================
# MAIN
# ============================================================================
if __name__ == "__main__":
    # Build the UI once, then hand control to Gradio's server loop.
    demo = create_demo()
    # Pass theme and css to launch() for Gradio 5.50+/6.0 compatibility
    # NOTE(review): launch() is currently called with no arguments, so the
    # CUSTOM_CSS defined above is never applied — confirm intended.
    demo.launch(
    )