#!/usr/bin/env python
"""
TransNormal - Hugging Face Spaces Zero GPU Version
Surface Normal Estimation for Transparent Objects
"""

import os
from typing import Optional

import spaces
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import snapshot_download

from transnormal import TransNormalPipeline, create_dino_encoder

# ============== Model Paths ==============
TRANSNORMAL_REPO = "Longxiang-ai/TransNormal"
DINO_REPO = "facebook/dinov3-vith16plus-pretrain-lvd1689m"

# Local directories the Hub snapshots are downloaded into. Shared by the
# download path and the cached early-return path so the two cannot drift.
TRANSNORMAL_LOCAL_DIR = "./weights/transnormal"
DINO_LOCAL_DIR = "./weights/dinov3_vith16plus"
# =========================================

# Global pipeline, lazily created by load_pipeline().
pipe = None
# Set to True once download_weights() has fetched both snapshots.
weights_downloaded = False


def download_weights():
    """Download model weights from the Hugging Face Hub.

    Fetches the TransNormal and DINOv3 snapshots into fixed local
    directories. Once both downloads have succeeded, later calls skip the
    download and return the local directory paths directly.

    Returns:
        Tuple ``(transnormal_path, dino_path)`` of local weight directories.
    """
    global weights_downloaded
    if weights_downloaded:
        return TRANSNORMAL_LOCAL_DIR, DINO_LOCAL_DIR

    print("[TransNormal] Downloading TransNormal weights...")
    transnormal_path = snapshot_download(
        TRANSNORMAL_REPO,
        local_dir=TRANSNORMAL_LOCAL_DIR,
    )

    print("[TransNormal] Downloading DINOv3 weights...")
    dino_path = snapshot_download(
        DINO_REPO,
        local_dir=DINO_LOCAL_DIR,
    )

    weights_downloaded = True
    print("[TransNormal] Weights downloaded successfully!")
    return transnormal_path, dino_path


def load_pipeline():
    """Load the TransNormal pipeline, creating it on first use.

    Picks CUDA with bfloat16 when available, otherwise CPU with float32,
    downloads weights if needed, builds the frozen DINO encoder and wraps
    it in a ``TransNormalPipeline`` moved to the chosen device.

    NOTE(review): on ZeroGPU the ``@spaces.GPU``-decorated caller may run in
    a separate worker process, in which case this global cache might not
    persist between requests; Hugging Face's ZeroGPU guide recommends
    loading models at module level instead. Confirm before relying on it.

    Returns:
        The cached (or newly created) pipeline.
    """
    global pipe
    if pipe is not None:
        return pipe

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    print(f"[TransNormal] Loading model on {device} with {dtype}...")

    # Download weights
    transnormal_path, dino_path = download_weights()
    projector_path = os.path.join(transnormal_path, "cross_attention_projector.pt")

    # Load DINO encoder
    dino_encoder = create_dino_encoder(
        model_name="dinov3_vith16plus",
        cross_attention_dim=1024,
        weights_path=dino_path,
        projector_path=projector_path,
        device=device,
        dtype=dtype,
        freeze_encoder=True,
    )

    # Load pipeline
    pipe = TransNormalPipeline.from_pretrained(
        transnormal_path,
        dino_encoder=dino_encoder,
        torch_dtype=dtype,
    )
    pipe = pipe.to(device)

    print("[TransNormal] Model loaded successfully!")
    return pipe


@spaces.GPU(duration=120)
def predict_normal(
    image: Optional[Image.Image],
    processing_res: int = 768,
) -> Optional[Image.Image]:
    """
    Predict surface normal from input image using Zero GPU.

    Args:
        image: Input image; non-RGB modes (e.g. RGBA, L, P) are converted
            to RGB first. ``None`` (nothing uploaded) yields ``None``.
        processing_res: Processing resolution; coerced to ``int``.

    Returns:
        Normal map as returned by the pipeline with ``output_type="pil"``,
        or ``None`` when no image was given.
    """
    if image is None:
        return None

    # The contract is an RGB image, but Gradio uploads such as PNGs with an
    # alpha channel arrive in other modes; normalize to 3-channel RGB.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Gradio sliders are typed as float; hand the pipeline an integer.
    processing_res = int(processing_res)

    # Load pipeline (will use GPU allocated by @spaces.GPU)
    pipeline = load_pipeline()

    # Run inference
    with torch.no_grad():
        normal_map = pipeline(
            image=image,
            processing_res=processing_res,
            output_type="pil",
        )

    return normal_map


# ============== Gradio Interface ==============
custom_css = """
.gradio-container {
    font-family: 'Segoe UI', 'Helvetica Neue', Arial, sans-serif !important;
}
h1 {
    font-weight: 600 !important;
}
"""

with gr.Blocks(
    title="TransNormal",
    theme=gr.themes.Soft(),
    css=custom_css,
) as demo:
    gr.Markdown(
        """
# 🔮 TransNormal
### Surface Normal Estimation for Transparent Objects

Upload an image to estimate surface normals. Particularly effective for **transparent objects** like glass and plastic.

**Normal Convention:** Red=X (Left) | Green=Y (Up) | Blue=Z (Out)

> ⏱️ First inference may take ~1-2 minutes to load model weights.
"""
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                height=400,
            )
            processing_res = gr.Slider(
                minimum=256,
                maximum=1024,
                value=768,
                step=64,
                label="Processing Resolution (higher = better quality but slower)",
            )
            submit_btn = gr.Button("🚀 Estimate Normal", variant="primary", size="lg")

        with gr.Column():
            output_image = gr.Image(
                label="Normal Map",
                type="pil",
                height=400,
            )

    # Event handlers
    submit_btn.click(
        fn=predict_normal,
        inputs=[input_image, processing_res],
        outputs=output_image,
    )

    # Footer
    gr.Markdown(
        """
---
**Paper:** [TransNormal: Dense Visual Semantics for Diffusion-based Transparent Object Normal Estimation](https://longxiang-ai.github.io/TransNormal/)

**Authors:** Mingwei Li, Hehe Fan, Yi Yang (Zhejiang University)

**Code:** [GitHub](https://github.com/longxiang-ai/TransNormal)
"""
    )


# Launch
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)