TransNormal / app.py
Longxiang-ai's picture
Use Docker with Python 3.10 for compatibility
7a2a3df
#!/usr/bin/env python
"""
TransNormal - Hugging Face Spaces Zero GPU Version
Surface Normal Estimation for Transparent Objects
"""
import os
import spaces
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import snapshot_download
from transnormal import TransNormalPipeline, create_dino_encoder
# ============== Model Paths ==============
# Hugging Face Hub repository IDs for the TransNormal weights and the
# DINOv3 backbone used as the dense semantic encoder.
TRANSNORMAL_REPO = "Longxiang-ai/TransNormal"
DINO_REPO = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
# =========================================

# Module-level caches: `pipe` holds the loaded pipeline (None until
# load_pipeline() builds it) and `weights_downloaded` records whether
# download_weights() has already fetched the Hub snapshots.
pipe, weights_downloaded = None, False
# Local directories the Hub snapshots are written into. Defined once so the
# download branch and the early-return branch cannot drift apart.
_TRANSNORMAL_LOCAL_DIR = "./weights/transnormal"
_DINO_LOCAL_DIR = "./weights/dinov3_vith16plus"
# (transnormal_path, dino_path) as returned by the first successful download.
_weight_paths = None


def download_weights():
    """Download model weights from HuggingFace Hub.

    Fetches the TransNormal and DINOv3 repositories with ``snapshot_download``
    into fixed local directories. Once a download has succeeded, later calls
    return the recorded paths without calling ``snapshot_download`` again.

    Returns:
        Tuple ``(transnormal_path, dino_path)`` of local directories.
    """
    global weights_downloaded, _weight_paths
    if weights_downloaded:
        # Prefer the paths snapshot_download actually returned (their form,
        # e.g. absolute vs. relative, may differ from the literal local_dir).
        # Fall back to the configured directories if the flag was set without
        # this function recording paths, which preserves the previous behaviour.
        return _weight_paths or (_TRANSNORMAL_LOCAL_DIR, _DINO_LOCAL_DIR)

    print("[TransNormal] Downloading TransNormal weights...")
    transnormal_path = snapshot_download(
        TRANSNORMAL_REPO,
        local_dir=_TRANSNORMAL_LOCAL_DIR,
    )
    print("[TransNormal] Downloading DINOv3 weights...")
    dino_path = snapshot_download(
        DINO_REPO,
        local_dir=_DINO_LOCAL_DIR,
    )

    _weight_paths = (transnormal_path, dino_path)
    weights_downloaded = True
    print("[TransNormal] Weights downloaded successfully!")
    return _weight_paths
def load_pipeline():
    """Load the TransNormal pipeline, building it on first use.

    Downloads the weights if needed, builds the DINOv3 encoder and the
    TransNormal pipeline, moves it to CUDA when available (else CPU), and
    caches it in the module-level ``pipe``.

    Returns:
        The cached ``TransNormalPipeline`` instance.
    """
    global pipe
    if pipe is not None:
        return pipe

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # bfloat16 on CUDA; CPU falls back to full float32 precision.
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    print(f"[TransNormal] Loading model on {device} with {dtype}...")

    # Download weights
    transnormal_path, dino_path = download_weights()
    projector_path = os.path.join(transnormal_path, "cross_attention_projector.pt")

    # Load DINO encoder
    dino_encoder = create_dino_encoder(
        model_name="dinov3_vith16plus",
        cross_attention_dim=1024,
        weights_path=dino_path,
        projector_path=projector_path,
        device=device,
        dtype=dtype,
        freeze_encoder=True,
    )

    # Load pipeline into a local first: if from_pretrained or .to(device)
    # raises, the global cache stays None and the next call retries cleanly
    # instead of reusing a half-initialised pipeline on the wrong device.
    pipeline = TransNormalPipeline.from_pretrained(
        transnormal_path,
        dino_encoder=dino_encoder,
        torch_dtype=dtype,
    )
    pipeline = pipeline.to(device)

    pipe = pipeline
    print("[TransNormal] Model loaded successfully!")
    return pipe
@spaces.GPU(duration=120)
def predict_normal(image: Image.Image, processing_res: int = 768) -> Image.Image:
    """
    Predict surface normal from input image using Zero GPU.

    Args:
        image: Input image; converted to RGB first if it is in another mode.
        processing_res: Processing resolution. Coerced to ``int`` because the
            Gradio Slider passes its value into the function as a float.

    Returns:
        The pipeline output for ``output_type="pil"`` (presumably the normal
        map as a PIL Image; TODO confirm against TransNormalPipeline), or
        ``None`` when no image was supplied.
    """
    if image is None:
        return None

    # Enforce the RGB contract for callers that bypass gr.Image preprocessing
    # (e.g. RGBA or grayscale images passed directly).
    if image.mode != "RGB":
        image = image.convert("RGB")
    # The slider value arrives as a float (e.g. 768.0); pass an integer.
    processing_res = int(processing_res)

    # Load pipeline (will use GPU allocated by @spaces.GPU)
    pipeline = load_pipeline()

    # Run inference
    with torch.no_grad():
        normal_map = pipeline(
            image=image,
            processing_res=processing_res,
            output_type="pil",
        )
    return normal_map
# ============== Gradio Interface ==============
# Extra CSS injected into the page: a system font stack for the whole Gradio
# container and font-weight 600 for top-level (h1) headings.
custom_css = (
    "\n"
    ".gradio-container {\n"
    "font-family: 'Segoe UI', 'Helvetica Neue', Arial, sans-serif !important;\n"
    "}\n"
    "h1 {\n"
    "font-weight: 600 !important;\n"
    "}\n"
)
# Markdown shown above the input/output panels.
_HEADER_MARKDOWN = """
# 🔮 TransNormal
### Surface Normal Estimation for Transparent Objects
Upload an image to estimate surface normals. Particularly effective for **transparent objects** like glass and plastic.
**Normal Convention:** Red=X (Left) | Green=Y (Up) | Blue=Z (Out)
> ⏱️ First inference may take ~1-2 minutes to load model weights.
"""

# Markdown shown below the panels: paper, authors and code links.
_FOOTER_MARKDOWN = """
---
**Paper:** [TransNormal: Dense Visual Semantics for Diffusion-based Transparent Object Normal Estimation](https://longxiang-ai.github.io/TransNormal/)
**Authors:** Mingwei Li, Hehe Fan, Yi Yang (Zhejiang University)
**Code:** [GitHub](https://github.com/longxiang-ai/TransNormal)
"""

# Two-column layout: image + resolution slider + button on the left,
# predicted normal map on the right; the button runs predict_normal.
with gr.Blocks(title="TransNormal", theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.Markdown(_HEADER_MARKDOWN)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image", height=400)
            processing_res = gr.Slider(
                label="Processing Resolution (higher = better quality but slower)",
                minimum=256,
                maximum=1024,
                step=64,
                value=768,
            )
            submit_btn = gr.Button("🚀 Estimate Normal", size="lg", variant="primary")

        with gr.Column():
            output_image = gr.Image(type="pil", label="Normal Map", height=400)

    submit_btn.click(
        fn=predict_normal,
        inputs=[input_image, processing_res],
        outputs=output_image,
    )

    gr.Markdown(_FOOTER_MARKDOWN)
# Launch
def _main():
    """Serve the Gradio demo on all network interfaces (0.0.0.0), port 7860."""
    demo.launch(server_port=7860, server_name="0.0.0.0")


if __name__ == "__main__":
    _main()