Spaces:
Build error
Build error
| #!/usr/bin/env python | |
| """ | |
| TransNormal - Hugging Face Spaces Zero GPU Version | |
| Surface Normal Estimation for Transparent Objects | |
| """ | |
| import os | |
| import spaces | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| from huggingface_hub import snapshot_download | |
| from transnormal import TransNormalPipeline, create_dino_encoder | |
# ------------------------------------------------------------------
# Hugging Face Hub repositories the model weights are pulled from
# ------------------------------------------------------------------
TRANSNORMAL_REPO = "Longxiang-ai/TransNormal"
DINO_REPO = "facebook/dinov3-vith16plus-pretrain-lvd1689m"

# ------------------------------------------------------------------
# Process-wide state
# ------------------------------------------------------------------
# Set to True once both weight snapshots have been fetched.
weights_downloaded = False
# Cached pipeline instance; stays None until the first successful load.
pipe = None
def download_weights():
    """Fetch the TransNormal and DINOv3 weight snapshots from the Hub.

    The download happens at most once per process: after the first
    successful call the module-level ``weights_downloaded`` flag is set and
    later calls return the local directory names without contacting the Hub.

    Returns:
        A ``(transnormal_path, dino_path)`` tuple. On the first call these
        are the values returned by ``snapshot_download``; on later calls
        they are the ``./weights/...`` directory literals.
    """
    global weights_downloaded

    transnormal_dir = "./weights/transnormal"
    dino_dir = "./weights/dinov3_vith16plus"

    if not weights_downloaded:
        print("[TransNormal] Downloading TransNormal weights...")
        transnormal_path = snapshot_download(TRANSNORMAL_REPO, local_dir=transnormal_dir)

        print("[TransNormal] Downloading DINOv3 weights...")
        dino_path = snapshot_download(DINO_REPO, local_dir=dino_dir)

        # Flip the flag only after both snapshots finished without raising.
        weights_downloaded = True
        print("[TransNormal] Weights downloaded successfully!")
        return transnormal_path, dino_path

    return transnormal_dir, dino_dir
def load_pipeline():
    """Return the shared TransNormal pipeline, building it on first use.

    On the first call this selects CUDA with bfloat16 when
    ``torch.cuda.is_available()`` is true (CPU with float32 otherwise),
    obtains the weight directories from ``download_weights()``, builds the
    DINO encoder and the pipeline, moves the pipeline to the selected
    device and caches it in the module-level ``pipe``. Later calls return
    the cached object unchanged.

    Returns:
        The cached pipeline object.

    Raises:
        Whatever the download, encoder construction, ``from_pretrained`` or
        ``.to(device)`` raise; in that case ``pipe`` is left as ``None`` so
        the next call retries the whole load.
    """
    global pipe
    if pipe is not None:
        return pipe

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    print(f"[TransNormal] Loading model on {device} with {dtype}...")

    # Download weights
    transnormal_path, dino_path = download_weights()
    projector_path = os.path.join(transnormal_path, "cross_attention_projector.pt")

    # Load DINO encoder
    dino_encoder = create_dino_encoder(
        model_name="dinov3_vith16plus",
        cross_attention_dim=1024,
        weights_path=dino_path,
        projector_path=projector_path,
        device=device,
        dtype=dtype,
        freeze_encoder=True,
    )

    # Load pipeline into a local first: the global is only assigned once the
    # device move has succeeded, so a failed ``.to()`` can never leave an
    # un-moved pipeline cached for subsequent calls.
    loaded = TransNormalPipeline.from_pretrained(
        transnormal_path,
        dino_encoder=dino_encoder,
        torch_dtype=dtype,
    )
    pipe = loaded.to(device)

    print("[TransNormal] Model loaded successfully!")
    return pipe
# On ZeroGPU Spaces a GPU is only attached while a function decorated with
# @spaces.GPU is running, so the inference entry point must carry it.
# NOTE(review): duration=120 is sized for the lazy first-call model load the
# UI advertises ("~1-2 minutes") — tune once real load times are known.
@spaces.GPU(duration=120)
def predict_normal(image: Image.Image, processing_res: int = 768) -> Image.Image:
    """
    Predict surface normal from input image using Zero GPU.

    Args:
        image: Input RGB image. ``None`` (no upload) is passed through.
        processing_res: Processing resolution forwarded to the pipeline.
            Coerced to ``int`` because UI sliders may deliver floats.

    Returns:
        Normal map as PIL Image, or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None

    # Load pipeline (uses the GPU allocated by @spaces.GPU); cached after
    # the first call.
    pipeline = load_pipeline()

    # Run inference without autograd bookkeeping.
    with torch.no_grad():
        normal_map = pipeline(
            image=image,
            processing_res=int(processing_res),
            output_type="pil",
        )
    return normal_map
# ============== Gradio Interface ==============
# Cosmetic overrides layered on top of the Soft theme.
custom_css = """
.gradio-container {
    font-family: 'Segoe UI', 'Helvetica Neue', Arial, sans-serif !important;
}
h1 {
    font-weight: 600 !important;
}
"""

with gr.Blocks(
    title="TransNormal",
    theme=gr.themes.Soft(),
    css=custom_css,
) as demo:
    # Header. Paragraphs are separated by blank lines so Markdown renders
    # them as distinct paragraphs rather than one run-on block.
    gr.Markdown(
        """
        # 🔮 TransNormal
        ### Surface Normal Estimation for Transparent Objects

        Upload an image to estimate surface normals. Particularly effective for **transparent objects** like glass and plastic.

        **Normal Convention:** Red=X (Left) | Green=Y (Up) | Blue=Z (Out)

        > ⏱️ First inference may take ~1-2 minutes to load model weights.
        """
    )

    # Two-column layout: inputs and controls on the left, result on the right.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                height=400,
            )
            processing_res = gr.Slider(
                minimum=256,
                maximum=1024,
                value=768,
                step=64,
                label="Processing Resolution (higher = better quality but slower)",
            )
            submit_btn = gr.Button("🚀 Estimate Normal", variant="primary", size="lg")
        with gr.Column():
            output_image = gr.Image(
                label="Normal Map",
                type="pil",
                height=400,
            )

    # Event handlers: the button feeds the image and resolution to
    # predict_normal and shows its return value in the output image.
    submit_btn.click(
        fn=predict_normal,
        inputs=[input_image, processing_res],
        outputs=output_image,
    )

    # Footer
    gr.Markdown(
        """
        ---

        **Paper:** [TransNormal: Dense Visual Semantics for Diffusion-based Transparent Object Normal Estimation](https://longxiang-ai.github.io/TransNormal/)

        **Authors:** Mingwei Li, Hehe Fan, Yi Yang (Zhejiang University)

        **Code:** [GitHub](https://github.com/longxiang-ai/TransNormal)
        """
    )
# Launch
if __name__ == "__main__":
    # Serve on all network interfaces at port 7860 when run as a script.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)