File size: 5,178 Bytes
159500c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f64ccd
 
 
 
 
 
 
159500c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a2a3df
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#!/usr/bin/env python
"""
TransNormal - Hugging Face Spaces Zero GPU Version

Surface Normal Estimation for Transparent Objects
"""

import os
import spaces
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import snapshot_download

from transnormal import TransNormalPipeline, create_dino_encoder

# ---------- Hugging Face Hub repository IDs ----------
# These are repo identifiers handed to snapshot_download(), not local paths.
TRANSNORMAL_REPO = "Longxiang-ai/TransNormal"
DINO_REPO = "facebook/dinov3-vith16plus-pretrain-lvd1689m"

# ---------- Lazily initialised module state ----------
# `weights_downloaded` becomes True once download_weights() has fetched both
# snapshots; `pipe` stays None until load_pipeline() has built the pipeline.
weights_downloaded = False
pipe = None


def download_weights():
    """Fetch the TransNormal and DINOv3 snapshots from the Hugging Face Hub.

    Both repositories are downloaded into sub-directories of ``./weights``.
    After one successful run the module-level ``weights_downloaded`` flag is
    set, and later calls return immediately without calling
    ``snapshot_download`` again.

    Returns:
        A ``(transnormal_path, dino_path)`` tuple. On the first call these are
        the values returned by ``snapshot_download``; on later calls they are
        the ``./weights/...`` directory strings themselves.
    """
    global weights_downloaded

    transnormal_dir = "./weights/transnormal"
    dino_dir = "./weights/dinov3_vith16plus"

    if not weights_downloaded:
        print("[TransNormal] Downloading TransNormal weights...")
        transnormal_path = snapshot_download(TRANSNORMAL_REPO, local_dir=transnormal_dir)

        print("[TransNormal] Downloading DINOv3 weights...")
        dino_path = snapshot_download(DINO_REPO, local_dir=dino_dir)

        # Only mark as done once both downloads have returned.
        weights_downloaded = True
        print("[TransNormal] Weights downloaded successfully!")
        return transnormal_path, dino_path

    # Already fetched earlier in this process: reuse the known directories.
    return transnormal_dir, dino_dir


def load_pipeline():
    """Build the TransNormal pipeline once and cache it in the global ``pipe``.

    The first call picks a device ("cuda" when ``torch.cuda.is_available()``,
    otherwise "cpu"), downloads the weights via ``download_weights()``, creates
    the DINO encoder, loads ``TransNormalPipeline`` and moves it to the device.
    Later calls return the cached object.

    The global ``pipe`` is assigned only after the pipeline has been moved to
    the target device. If loading or the device move raises, nothing is cached
    and the next call retries from scratch.

    NOTE(review): if the Space runtime executes ``@spaces.GPU`` functions in a
    separate worker process, this module-level cache may not survive between
    requests -- confirm against the ZeroGPU documentation.

    Returns:
        The cached or newly built ``TransNormalPipeline`` instance.
    """
    global pipe

    if pipe is not None:
        return pipe

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # bfloat16 on GPU, float32 on CPU.
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    print(f"[TransNormal] Loading model on {device} with {dtype}...")

    # Download weights
    transnormal_path, dino_path = download_weights()
    # The projector checkpoint is looked up inside the TransNormal snapshot.
    projector_path = os.path.join(transnormal_path, "cross_attention_projector.pt")

    # Load DINO encoder
    dino_encoder = create_dino_encoder(
        model_name="dinov3_vith16plus",
        cross_attention_dim=1024,
        weights_path=dino_path,
        projector_path=projector_path,
        device=device,
        dtype=dtype,
        freeze_encoder=True,
    )

    # Build into a local first so a failure in from_pretrained()/to() cannot
    # leave a pipeline that never reached `device` cached for later requests.
    loaded = TransNormalPipeline.from_pretrained(
        transnormal_path,
        dino_encoder=dino_encoder,
        torch_dtype=dtype,
    )
    loaded = loaded.to(device)

    # Publish the fully initialised pipeline.
    pipe = loaded

    print("[TransNormal] Model loaded successfully!")
    return pipe


@spaces.GPU(duration=120)
def predict_normal(image: Image.Image, processing_res: int = 768) -> Image.Image:
    """
    Estimate a surface-normal map for ``image``.

    This is the Gradio click handler; it is wrapped with
    ``spaces.GPU(duration=120)``.

    Args:
        image: Input image, or ``None`` when nothing was uploaded.
        processing_res: Resolution value forwarded to the pipeline.

    Returns:
        ``None`` when ``image`` is ``None``; otherwise the result of calling
        the pipeline with ``output_type="pil"``.
    """
    # Nothing uploaded: leave the output component empty.
    if image is None:
        return None

    # Lazily builds (or reuses) the shared pipeline.
    model = load_pipeline()

    # Inference only, so gradient tracking is disabled for the call.
    with torch.no_grad():
        return model(
            image=image,
            processing_res=processing_res,
            output_type="pil",
        )


# ============== Gradio Interface ==============

# Extra CSS handed to gr.Blocks below: overrides the container font family
# and the weight of top-level headings.
custom_css = """
.gradio-container {
    font-family: 'Segoe UI', 'Helvetica Neue', Arial, sans-serif !important;
}
h1 {
    font-weight: 600 !important;
}
"""

# Build the UI. Components are laid out in the order they are created inside
# the context managers: header, a two-column row (inputs | output), footer.
with gr.Blocks(
    title="TransNormal",
    theme=gr.themes.Soft(),
    css=custom_css,
) as demo:
    
    # Header: title, short description and the colour convention of the output.
    gr.Markdown(
        """
        # 🔮 TransNormal
        ### Surface Normal Estimation for Transparent Objects
        
        Upload an image to estimate surface normals. Particularly effective for **transparent objects** like glass and plastic.
        
        **Normal Convention:** Red=X (Left) | Green=Y (Up) | Blue=Z (Out)
        
        > ⏱️ First inference may take ~1-2 minutes to load model weights.
        """
    )
    
    with gr.Row():
        # Left column: image upload, resolution slider and the submit button.
        with gr.Column():
            # type="pil" so the handler receives a PIL image.
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                height=400,
            )
            
            # Default of 768 matches predict_normal's processing_res default.
            processing_res = gr.Slider(
                minimum=256,
                maximum=1024,
                value=768,
                step=64,
                label="Processing Resolution (higher = better quality but slower)",
            )
            
            submit_btn = gr.Button("🚀 Estimate Normal", variant="primary", size="lg")
        
        # Right column: the predicted normal map.
        with gr.Column():
            output_image = gr.Image(
                label="Normal Map",
                type="pil",
                height=400,
            )
    
    # Event handlers: clicking the button runs predict_normal on the uploaded
    # image and slider value, and shows the result in the output image.
    submit_btn.click(
        fn=predict_normal,
        inputs=[input_image, processing_res],
        outputs=output_image,
    )
    
    # Footer: paper, authors and code links.
    gr.Markdown(
        """
        ---
        
        **Paper:** [TransNormal: Dense Visual Semantics for Diffusion-based Transparent Object Normal Estimation](https://longxiang-ai.github.io/TransNormal/)
        
        **Authors:** Mingwei Li, Hehe Fan, Yi Yang (Zhejiang University)
        
        **Code:** [GitHub](https://github.com/longxiang-ai/TransNormal)
        """
    )

# Launch only when executed as a script; listens on all interfaces
# (0.0.0.0) at port 7860.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)