""" WeDetect: Open-Vocabulary Object Detection Demo HuggingFace Spaces Application This app provides an interactive interface for WeDetect, a fast open-vocabulary object detection model that uses Chinese class names internally. Features: - Upload any image for object detection - Enter class names in English OR Chinese - Automatic English→Chinese translation with editable preview - Adjustable confidence threshold - Visual results with bounding boxes Compatible with: - Gradio 5.50.0+ / 6.x - huggingface_hub 1.x """ import os import sys import subprocess # ============================================================================ # INSTALL MMCV/MMDET/MMENGINE WITH CUDA EXTENSIONS # ============================================================================ # mmcv needs pre-built CUDA extensions. We must install from OpenMMLab's # wheel index with the correct CUDA and PyTorch version. # ============================================================================ def get_torch_cuda_version(): """Detect PyTorch and CUDA versions for wheel selection.""" import torch torch_version = torch.__version__.split('+')[0] # e.g., "2.1.0" torch_major_minor = '.'.join(torch_version.split('.')[:2]) # e.g., "2.1" if torch.cuda.is_available(): cuda_version = torch.version.cuda # e.g., "12.1" cuda_tag = 'cu' + cuda_version.replace('.', '')[:3] # e.g., "cu121" else: cuda_tag = 'cpu' return torch_major_minor, cuda_tag def install_mm_packages(): """Install mmcv, mmdet, mmengine with proper CUDA extensions.""" # First install mmengine (no CUDA extensions needed) try: import mmengine print(f"✅ mmengine already installed: {mmengine.__version__}") except ImportError: print("📦 Installing mmengine...") subprocess.run( [sys.executable, "-m", "pip", "install", "mmengine==0.10.7"], capture_output=True, text=True ) print("✅ mmengine installed") # Install mmcv with CUDA extensions from OpenMMLab wheel index try: import mmcv from mmcv.ops import roi_align # Test if extensions work print(f"✅ mmcv 
already installed with extensions: {mmcv.__version__}") except (ImportError, ModuleNotFoundError) as e: print(f"📦 Installing mmcv with CUDA extensions... (reason: {e})") # Get versions for wheel selection torch_version, cuda_tag = get_torch_cuda_version() print(f" Detected: PyTorch {torch_version}, CUDA tag: {cuda_tag}") # OpenMMLab wheel index URL wheel_index = f"https://download.openmmlab.com/mmcv/dist/{cuda_tag}/torch{torch_version}/index.html" print(f" Wheel index: {wheel_index}") # Uninstall any existing broken mmcv subprocess.run( [sys.executable, "-m", "pip", "uninstall", "mmcv", "-y"], capture_output=True, text=True ) # Install from OpenMMLab wheel index result = subprocess.run( [sys.executable, "-m", "pip", "install", "mmcv==2.1.0", "-f", wheel_index], capture_output=True, text=True ) if result.returncode != 0: print(f"⚠️ First attempt failed: {result.stderr}") # Try alternative CUDA versions for alt_cuda in ["cu121", "cu118", "cu117"]: if alt_cuda == cuda_tag: continue alt_wheel_index = f"https://download.openmmlab.com/mmcv/dist/{alt_cuda}/torch{torch_version}/index.html" print(f" Trying alternative: {alt_wheel_index}") result = subprocess.run( [sys.executable, "-m", "pip", "install", "mmcv==2.1.0", "-f", alt_wheel_index], capture_output=True, text=True ) if result.returncode == 0: break print("✅ mmcv installed") # Install mmdet try: import mmdet print(f"✅ mmdet already installed: {mmdet.__version__}") except ImportError: print("📦 Installing mmdet...") subprocess.run( [sys.executable, "-m", "pip", "install", "mmdet==3.3.0"], capture_output=True, text=True ) print("✅ mmdet installed") # Run installation before other imports print("🔧 Setting up MM packages with CUDA extensions...") install_mm_packages() # Verify installation print("🔍 Verifying mmcv extensions...") try: from mmcv.ops import roi_align print("✅ mmcv._ext loaded successfully!") except Exception as e: print(f"⚠️ Warning: mmcv extensions may not be fully loaded: {e}") # 
# ============================================================================
# STANDARD IMPORTS (after MM packages are installed)
# ============================================================================
import tempfile
import colorsys
from typing import List, Tuple, Optional

import gradio as gr
import spaces
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from huggingface_hub import hf_hub_download

# ============================================================================
# CONFIGURATION
# ============================================================================

DEFAULT_MODEL = "large"  # Options: "tiny", "base", "large"
REPO_ID = "fushh7/WeDetect"

# Checkpoint filename and config stem for each model size.
MODEL_INFO = {
    "tiny": {"file": "wedetect_tiny.pth", "config": "wedetect_tiny"},
    "base": {"file": "wedetect_base.pth", "config": "wedetect_base"},
    "large": {"file": "wedetect_large.pth", "config": "wedetect_large"},
}

# ============================================================================
# ENGLISH TO CHINESE DICTIONARY (~200 common objects)
# ============================================================================

ENGLISH_TO_CHINESE = {
    # People
    "person": "人", "man": "男人", "woman": "女人", "child": "儿童",
    "kid": "小孩", "baby": "婴儿", "boy": "男孩", "girl": "女孩",
    "people": "人", "human": "人",
    # Animals
    # NOTE: "mouse" lives under Electronics (鼠标). The animal sense (老鼠)
    # used to be listed here too, but a duplicate dict key is silently
    # overridden by the later entry, so 鼠标 was the effective mapping anyway.
    "dog": "狗", "cat": "猫", "bird": "鸟", "fish": "鱼", "horse": "马",
    "cow": "牛", "sheep": "羊", "pig": "猪", "chicken": "鸡", "duck": "鸭",
    "elephant": "大象", "bear": "熊", "zebra": "斑马", "giraffe": "长颈鹿",
    "lion": "狮子", "tiger": "老虎", "monkey": "猴子", "rabbit": "兔子",
    "snake": "蛇", "turtle": "乌龟", "frog": "青蛙", "butterfly": "蝴蝶",
    "bee": "蜜蜂", "spider": "蜘蛛", "ant": "蚂蚁",
    # Vehicles
    "car": "车", "truck": "卡车", "bus": "公交车", "train": "火车",
    "airplane": "飞机", "plane": "飞机", "boat": "船", "ship": "船",
    "bicycle": "自行车", "bike": "自行车", "motorcycle": "摩托车",
    "helicopter": "直升机", "taxi": "出租车", "ambulance": "救护车",
    "fire truck": "消防车", "police car": "警车", "van": "面包车",
    # Furniture
    "chair": "椅子", "table": "桌子", "desk": "书桌", "bed": "床",
    "couch": "沙发", "sofa": "沙发", "bench": "长凳", "cabinet": "柜子",
    "shelf": "架子", "drawer": "抽屉", "wardrobe": "衣柜", "mirror": "镜子",
    # Electronics
    "tv": "电视", "television": "电视", "computer": "电脑",
    "laptop": "笔记本电脑", "phone": "手机", "cell phone": "手机",
    "mobile phone": "手机", "tablet": "平板电脑", "keyboard": "键盘",
    "mouse": "鼠标", "monitor": "显示器", "screen": "屏幕",
    "camera": "相机", "speaker": "音箱", "headphones": "耳机",
    "microphone": "麦克风", "remote": "遥控器",
    # Kitchen items
    "refrigerator": "冰箱", "fridge": "冰箱", "oven": "烤箱",
    "microwave": "微波炉", "toaster": "烤面包机", "blender": "搅拌机",
    "kettle": "水壶", "pot": "锅", "pan": "平底锅", "bowl": "碗",
    "plate": "盘子", "cup": "杯子", "mug": "马克杯", "glass": "玻璃杯",
    "bottle": "瓶子", "fork": "叉子", "knife": "刀", "spoon": "勺子",
    "chopsticks": "筷子",
    # Food
    "apple": "苹果", "banana": "香蕉", "orange": "橙子", "grape": "葡萄",
    "strawberry": "草莓", "watermelon": "西瓜", "pizza": "披萨",
    "hamburger": "汉堡", "sandwich": "三明治", "hot dog": "热狗",
    "cake": "蛋糕", "bread": "面包", "rice": "米饭", "noodles": "面条",
    "egg": "鸡蛋", "meat": "肉", "vegetable": "蔬菜", "fruit": "水果",
    # Clothing
    "shirt": "衬衫", "pants": "裤子", "dress": "连衣裙", "skirt": "裙子",
    "jacket": "夹克", "coat": "外套", "sweater": "毛衣", "hat": "帽子",
    "cap": "帽子", "shoe": "鞋", "shoes": "鞋", "boot": "靴子",
    "sock": "袜子", "glove": "手套", "scarf": "围巾", "tie": "领带",
    "belt": "腰带", "bag": "包", "backpack": "背包", "purse": "钱包",
    "wallet": "钱包", "watch": "手表", "glasses": "眼镜",
    "sunglasses": "太阳镜",
    # Sports
    "ball": "球", "football": "足球", "soccer ball": "足球",
    "basketball": "篮球", "baseball": "棒球", "tennis ball": "网球",
    "golf ball": "高尔夫球", "volleyball": "排球",
    "tennis racket": "网球拍", "skateboard": "滑板", "surfboard": "冲浪板",
    "ski": "滑雪板", "snowboard": "单板滑雪", "frisbee": "飞盘",
    # Office/School
    "book": "书", "notebook": "笔记本", "pen": "笔", "pencil": "铅笔",
    "paper": "纸", "scissors": "剪刀", "ruler": "尺子", "eraser": "橡皮",
    "stapler": "订书机", "calculator": "计算器", "clock": "时钟",
    "calendar": "日历", "folder": "文件夹",
    # Outdoor
    "tree": "树", "flower": "花", "grass": "草", "plant": "植物",
    "leaf": "叶子", "rock": "石头", "mountain": "山", "river": "河",
    "lake": "湖", "ocean": "海洋", "beach": "海滩", "sky": "天空",
    "cloud": "云", "sun": "太阳", "moon": "月亮", "star": "星星",
    # Buildings/Structures
    "house": "房子", "building": "建筑", "door": "门", "window": "窗户",
    "wall": "墙", "roof": "屋顶", "floor": "地板", "stairs": "楼梯",
    "fence": "栅栏", "bridge": "桥", "road": "道路", "street": "街道",
    "traffic light": "红绿灯", "stop sign": "停止标志",
    # Household
    "lamp": "灯", "light": "灯", "fan": "风扇", "air conditioner": "空调",
    "pillow": "枕头", "blanket": "毯子", "towel": "毛巾", "soap": "肥皂",
    "toothbrush": "牙刷", "toilet": "马桶", "sink": "水槽",
    "bathtub": "浴缸", "shower": "淋浴", "curtain": "窗帘",
    "carpet": "地毯", "rug": "地毯",
    # Misc
    "umbrella": "雨伞", "key": "钥匙", "lock": "锁", "box": "盒子",
    "basket": "篮子", "vase": "花瓶", "candle": "蜡烛", "picture": "图片",
    "painting": "画", "photo": "照片", "frame": "框架", "toy": "玩具",
    "teddy bear": "泰迪熊", "doll": "娃娃", "robot": "机器人",
    "kite": "风筝", "balloon": "气球", "flag": "旗帜",
}

# ============================================================================
# TRANSLATION FUNCTIONS
# ============================================================================


def translate_to_chinese(text: str) -> str:
    """Translate a single English word/phrase to Chinese using the dictionary.

    Falls back to singularized forms (strip "s"/"es"); unknown input is
    returned unchanged, which also covers text that is already Chinese.
    """
    text_lower = text.lower().strip()

    # Direct lookup
    if text_lower in ENGLISH_TO_CHINESE:
        return ENGLISH_TO_CHINESE[text_lower]

    # Try without 's' (plurals)
    if text_lower.endswith('s') and text_lower[:-1] in ENGLISH_TO_CHINESE:
        return ENGLISH_TO_CHINESE[text_lower[:-1]]

    # Try without 'es' (plurals)
    if text_lower.endswith('es') and text_lower[:-2] in ENGLISH_TO_CHINESE:
        return ENGLISH_TO_CHINESE[text_lower[:-2]]

    # Return original if not found (might already be Chinese)
    return text


def translate_class_list(classes_text: str, input_mode: str) -> str:
    """
    Translate a comma-separated list of classes.

    Args:
        classes_text: Comma-separated class names
        input_mode: "English" or "Chinese (中文)"

    Returns:
        Comma-separated Chinese class names
    """
    if not classes_text.strip():
        return ""

    classes = [c.strip() for c in classes_text.split(',') if c.strip()]

    if input_mode == "English":
        translated = [translate_to_chinese(c) for c in classes]
        return ', '.join(translated)
    # Already Chinese, return as-is (normalized spacing)
    return ', '.join(classes)


# ============================================================================
# MODEL LOADING
# ============================================================================

# Global model cache: model size -> initialized detector
_model_cache = {}
_repo_path = None


def setup_repo():
    """Clone the WeDetect repository (for configs) if not already present.

    Also inserts the clone into ``sys.path`` so its modules are importable.

    Returns:
        Path to the local clone.

    Raises:
        subprocess.CalledProcessError: when ``git clone`` fails.
    """
    global _repo_path

    if _repo_path is not None and os.path.exists(_repo_path):
        return _repo_path

    repo_dir = "/tmp/WeDetect"
    if not os.path.exists(repo_dir):
        print("📥 Cloning WeDetect repository...")
        try:
            subprocess.run(
                ["git", "clone", "--depth", "1",
                 "https://github.com/WeChatCV/WeDetect.git", repo_dir],
                check=True, capture_output=True, text=True
            )
            print("✅ Repository cloned!")
        except subprocess.CalledProcessError as e:
            print(f"❌ Failed to clone repository: {e.stderr}")
            raise

    # Add to Python path for imports
    if repo_dir not in sys.path:
        sys.path.insert(0, repo_dir)

    _repo_path = repo_dir
    return repo_dir


def get_model(model_size: str = DEFAULT_MODEL):
    """Load and cache the WeDetect model for the given size.

    Args:
        model_size: One of "tiny", "base", "large".

    Returns:
        An initialized mmdet detector (cached across calls).

    Raises:
        FileNotFoundError: when the config file is missing from the clone.
    """
    global _model_cache

    if model_size in _model_cache:
        return _model_cache[model_size]

    import torch
    from mmengine.config import Config
    from mmdet.apis import init_detector

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"🚀 Loading WeDetect-{model_size.capitalize()} on {device}...")

    # Setup repository for config files
    repo_dir = setup_repo()

    # Config path from cloned repo
    config_path = os.path.join(repo_dir, "config", f"wedetect_{model_size}.py")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config not found: {config_path}")

    # Download checkpoint from HuggingFace
    checkpoint_file = MODEL_INFO[model_size]["file"]
    print(f"📥 Downloading checkpoint: {checkpoint_file}...")
    checkpoint_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=checkpoint_file,
        cache_dir="./models"
    )

    # Initialize model
    print("🔧 Initializing model...")
    model = init_detector(config_path, checkpoint_path, device=device)

    _model_cache[model_size] = model
    print(f"✅ WeDetect-{model_size.capitalize()} loaded successfully!")
    return model


# ============================================================================
# VISUALIZATION
# ============================================================================


def generate_colors(n: int) -> List[Tuple[int, int, int]]:
    """Generate n visually distinct RGB colors by spacing hues evenly."""
    colors = []
    for i in range(max(n, 1)):
        hue = i / max(n, 1)
        rgb = colorsys.hsv_to_rgb(hue, 0.8, 0.9)
        colors.append(tuple(int(x * 255) for x in rgb))
    return colors


def draw_detections(
    image: Image.Image,
    boxes: np.ndarray,
    scores: np.ndarray,
    labels: np.ndarray,
    class_names: List[str],
    threshold: float
) -> Tuple[Image.Image, int]:
    """Draw bounding boxes and labels on a copy of the image.

    Args:
        image: Source PIL image (not modified).
        boxes: (N, 4) array of x1, y1, x2, y2 pixel coordinates.
        scores: (N,) confidence scores.
        labels: (N,) integer class indices into class_names.
        class_names: Display names for the label indices.
        threshold: Detections below this score are skipped.

    Returns:
        Tuple of (annotated image copy, number of detections drawn).
    """
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)

    # Try to load a Chinese-compatible font (class names may be CJK)
    font = None
    font_paths = [
        "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
        "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
        "simsun.ttc",
    ]
    for font_path in font_paths:
        try:
            font = ImageFont.truetype(font_path, 18)
            break
        except (IOError, OSError):
            continue
    if font is None:
        try:
            font = ImageFont.load_default(size=16)
        except TypeError:
            # Older Pillow versions don't support size argument
            font = ImageFont.load_default()

    colors = generate_colors(len(class_names))
    detection_count = 0

    for box, score, label_idx in zip(boxes, scores, labels):
        if score < threshold:
            continue

        detection_count += 1
        x1, y1, x2, y2 = map(int, box)
        color = colors[int(label_idx) % len(colors)]

        # Draw bounding box
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        # Prepare label text
        class_name = class_names[int(label_idx)] if int(label_idx) < len(class_names) else "?"
        label_text = f"{class_name}: {score:.2f}"

        # Get text bounding box
        bbox = draw.textbbox((x1, y1), label_text, font=font)
        text_w = bbox[2] - bbox[0]
        text_h = bbox[3] - bbox[1]

        # BUGFIX: clamp the label to the canvas — for boxes touching the top
        # edge the label used to be drawn at a negative y and was invisible.
        label_top = max(y1 - text_h - 6, 0)

        # Draw label background
        draw.rectangle(
            [x1, label_top, x1 + text_w + 6, label_top + text_h + 6],
            fill=color
        )

        # Draw label text
        draw.text((x1 + 3, label_top + 3), label_text, fill='white', font=font)

    return img_draw, detection_count


# ============================================================================
# MAIN DETECTION FUNCTION
# ============================================================================


@spaces.GPU
def detect_objects(
    image: Optional[Image.Image],
    chinese_classes: str,
    threshold: float,
    model_size: str
) -> Tuple[Optional[Image.Image], str]:
    """
    Run object detection on an image.

    Args:
        image: Input PIL Image
        chinese_classes: Comma-separated Chinese class names
        threshold: Confidence threshold
        model_size: Model size to use

    Returns:
        Tuple of (annotated image, status message)
    """
    if image is None:
        return None, "⚠️ Please upload an image"

    if not chinese_classes.strip():
        return image, "⚠️ Please enter class names to detect"

    # Parse class names
    class_names = [c.strip() for c in chinese_classes.split(',') if c.strip()]
    if not class_names:
        return image, "⚠️ No valid class names provided"

    try:
        import torch
        from mmdet.apis import inference_detector

        # Load model
        model = get_model(model_size)

        # Update model with class names (open-vocabulary prompt)
        model.dataset_meta['classes'] = class_names

        # Save image temporarily — inference_detector expects a file path
        with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as f:
            temp_path = f.name
            image.save(temp_path)

        try:
            # Run inference
            results = inference_detector(model, temp_path, texts=[class_names])

            # Extract predictions
            pred = results.pred_instances
            boxes = pred.bboxes.cpu().numpy()
            scores = pred.scores.cpu().numpy()
            labels = pred.labels.cpu().numpy()

            # Draw results
            result_image, count = draw_detections(
                image, boxes, scores, labels, class_names, threshold
            )

            status = f"✅ Found {count} object(s) | Classes: {', '.join(class_names)}"
            return result_image, status
        finally:
            # Cleanup the temp file even when inference raises
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    except Exception as e:
        import traceback
        error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return image, f"❌ Error: {str(e)}"


# ============================================================================
# GRADIO INTERFACE
# ============================================================================

# Custom CSS for styling
CUSTOM_CSS = """
.output-class { font-family: 'Noto Sans SC', sans-serif; }
.info-text { color: #666; font-size: 0.9em; }
"""


def create_demo():
    """Create the Gradio demo interface.

    Wires the translation preview (classes/input-mode change handlers) and
    the detection button to `detect_objects`.
    """
    # NOTE: In Gradio 5.50+/6.0, theme and css must be passed to launch(), not Blocks()
    with gr.Blocks() as demo:
        gr.Markdown("""
        # 🔍 WeDetect: Open-Vocabulary Object Detection

        Upload an image and specify what objects to detect. Enter class names in **English** or **Chinese**.

        > **Note:** WeDetect uses Chinese internally. English inputs are automatically translated.
        """)

        with gr.Row():
            # Left column: Inputs
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="📷 Upload Image",
                    type="pil",
                    height=350
                )

                input_mode = gr.Radio(
                    choices=["English", "Chinese (中文)"],
                    value="English",
                    label="🌐 Input Language",
                    info="Choose the language for entering class names"
                )

                classes_input = gr.Textbox(
                    label="🏷️ Classes to Detect",
                    placeholder="person, car, dog, cat",
                    value="person, car, dog",
                    info="Enter class names separated by commas",
                    lines=2
                )

                chinese_preview = gr.Textbox(
                    label="🀄 Chinese Classes (Editable)",
                    placeholder="人, 车, 狗, 猫",
                    value="人, 车, 狗",
                    info="Edit if translation needs correction",
                    lines=2
                )

                threshold_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.3,
                    step=0.05,
                    label="📊 Confidence Threshold"
                )

                model_dropdown = gr.Dropdown(
                    choices=["large", "base", "tiny"],
                    value=DEFAULT_MODEL,
                    label="🧠 Model Size",
                    info="Large=best quality, Tiny=fastest"
                )

                detect_btn = gr.Button(
                    "🔍 Detect Objects",
                    variant="primary",
                    size="lg"
                )

            # Right column: Output
            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="🎯 Detection Results",
                    type="pil",
                    height=350
                )

                status_text = gr.Textbox(
                    label="📋 Status",
                    interactive=False,
                    lines=2
                )

        # Class name reference
        with gr.Accordion("📚 Common Class Names Reference", open=False):
            gr.Markdown("""
            | English | Chinese | | English | Chinese | | English | Chinese |
            |---------|---------|---|---------|---------|---|---------|---------|
            | person | 人 | | car | 车 | | dog | 狗 |
            | cat | 猫 | | bird | 鸟 | | horse | 马 |
            | bicycle | 自行车 | | motorcycle | 摩托车 | | bus | 公交车 |
            | truck | 卡车 | | chair | 椅子 | | table | 桌子 |
            | bed | 床 | | couch | 沙发 | | tv | 电视 |
            | laptop | 笔记本电脑 | | phone | 手机 | | book | 书 |
            | bottle | 瓶子 | | cup | 杯子 | | shoe | 鞋 |
            | bag | 包 | | umbrella | 雨伞 | | tree | 树 |

            **Example inputs:**
            - English: `person, car, dog, cat, bicycle`
            - Chinese: `人, 车, 狗, 猫, 自行车`
            """)

        # Event handlers
        def update_chinese_preview(classes_text: str, mode: str) -> str:
            return translate_class_list(classes_text, mode)

        # Auto-translate when input changes
        classes_input.change(
            fn=update_chinese_preview,
            inputs=[classes_input, input_mode],
            outputs=chinese_preview
        )
        input_mode.change(
            fn=update_chinese_preview,
            inputs=[classes_input, input_mode],
            outputs=chinese_preview
        )

        # Detection button click
        detect_btn.click(
            fn=detect_objects,
            inputs=[input_image, chinese_preview, threshold_slider, model_dropdown],
            outputs=[output_image, status_text]
        )

        gr.Markdown("""
        ---
        **Credits:** [WeDetect](https://github.com/WeChatCV/WeDetect) by WeChatCV |
        [Paper](https://arxiv.org/abs/2512.12309) |
        [Models](https://huggingface.co/fushh7/WeDetect)
        """)

    return demo


# ============================================================================
# MAIN
# ============================================================================

if __name__ == "__main__":
    demo = create_demo()
    # Pass theme and css to launch() for Gradio 5.50+/6.0 compatibility
    demo.launch(
    )