File size: 3,921 Bytes
12510fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import tempfile
from pathlib import Path

import torch
import gradio as gr
from PIL import Image

from config import CONFIG
from model import load_from_checkpoint, MobileNetUNet
from inference_utils import full_inference, make_side_by_side, tiled_inference


# Process-wide model instance. Starts empty; load_model() assigns it, and
# process_image() triggers a load on demand when it is still None.
MODEL = None
# Run on the GPU when torch reports CUDA is available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def load_model():
    """Populate the module-level ``MODEL`` from the configured checkpoint.

    The checkpoint location is read from ``CONFIG["checkpoint_path"]``.
    When that path exists, the model is built with ``load_from_checkpoint``
    (using ``DEVICE`` and ``CONFIG["width_mult"]``), moved to ``DEVICE`` and
    switched to eval mode. When it does not exist, a notice is printed and
    ``MODEL`` is reset to ``None`` so callers can detect the failure.

    Returns:
        None. The result is communicated through the ``MODEL`` global.
    """
    global MODEL
    ckpt_path = CONFIG["checkpoint_path"]

    # Guard clause: nothing to load, leave MODEL explicitly empty.
    if not os.path.exists(ckpt_path):
        print(f"Checkpoint not found at {ckpt_path}. Using ONNX is not yet supported in this app.")
        MODEL = None
        return

    print(f"Loading checkpoint from {ckpt_path}...")
    MODEL = load_from_checkpoint(
        ckpt_path,
        device=DEVICE,
        width_mult=CONFIG["width_mult"],
    )
    MODEL.to(DEVICE)
    MODEL.eval()
    print("Model loaded successfully.")


def process_image(input_img, use_tiled, tile_overlap):
    """Run the model on one image and build the side-by-side visualisation.

    Args:
        input_img: Image object exposing ``convert("RGB")``; the UI passes a
            PIL image. ``None`` (no image supplied) is rejected with a
            user-facing error.
        use_tiled: When truthy, run ``tiled_inference``; otherwise run
            ``full_inference``.
        tile_overlap: Overlap forwarded to ``tiled_inference`` as
            ``overlap``. Ignored for full-image inference.

    Returns:
        A ``(result, outputs)`` tuple where ``result`` is what
        ``make_side_by_side`` produced and ``outputs`` is the second value
        returned by the inference helper.

    Raises:
        gr.Error: If no input image was provided, or if no model could be
            loaded from the configured checkpoint.
    """
    # Gradio hands over None when the user clicks the button without an
    # image; fail with a readable message instead of an AttributeError.
    if input_img is None:
        raise gr.Error("No input image provided. Please upload an image first.")

    # Lazy load: the model may not have been loaded yet (e.g. when the
    # module is imported rather than run as a script).
    if MODEL is None:
        load_model()
    if MODEL is None:
        # Report the path load_model() actually checked, not a fixed one.
        raise gr.Error(
            f"No model loaded. Please ensure {CONFIG['checkpoint_path']} exists."
        )

    img = input_img.convert("RGB")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        if use_tiled:
            inp_img, outputs = tiled_inference(
                MODEL, img, CONFIG["image_size"], DEVICE, overlap=tile_overlap,
            )
        else:
            inp_img, outputs = full_inference(MODEL, img, CONFIG["image_size"], DEVICE)

    result = make_side_by_side(inp_img, outputs)
    return result, outputs


def run_and_save(input_img, use_tiled, tile_overlap):
    """Run inference and write the side-by-side result to a temporary PNG.

    Args:
        input_img: Input image, forwarded to ``process_image``.
        use_tiled: Whether to use tiled inference, forwarded unchanged.
        tile_overlap: Tile overlap, forwarded unchanged.

    Returns:
        Filesystem path (``str``) of the saved ``.png`` file. The file is
        created with ``delete=False`` and is therefore not removed
        automatically; the caller owns its cleanup.

    Raises:
        gr.Error: Propagated from ``process_image``.
    """
    result, _ = process_image(input_img, use_tiled, tile_overlap)

    # Reserve a unique path, then close the handle before saving: the
    # original left the descriptor open for the life of the process, and an
    # open NamedTemporaryFile cannot be reopened by name on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        out_path = tmp.name

    result.save(out_path)
    return out_path


def run_and_return_maps(input_img, use_tiled, tile_overlap):
    """Run inference and return the composite plus each individual map.

    Args:
        input_img: Input image, forwarded to ``process_image``.
        use_tiled: Whether to use tiled inference, forwarded unchanged.
        tile_overlap: Tile overlap, forwarded unchanged.

    Returns:
        A 5-tuple: the side-by-side result followed by the ``"basecolor"``,
        ``"normal"``, ``"rmd"`` and ``"rgb"`` entries of the model outputs,
        in that order. An entry missing from the outputs comes back as
        ``None``.
    """
    side_by_side, maps = process_image(input_img, use_tiled, tile_overlap)
    # Order must match the component order wired to the button's outputs.
    map_keys = ("basecolor", "normal", "rmd", "rgb")
    return (side_by_side, *(maps.get(key) for key in map_keys))


# Gradio UI definition. Components are laid out in creation order, so the
# statement order and nesting below define the page structure.
with gr.Blocks(title="ShadeNet", theme=gr.themes.Soft()) as demo:
    # Page header / short description.
    gr.Markdown(
        """
        # ShadeNet
        Lightweight inverse rendering model that decomposes an RGB image into **basecolor**, **normal**, **roughness/metallic/depth** maps,
        then recombines them into a reconstructed RGB.
        """
    )

    # Top row: inputs and controls on the left, composite result on the right.
    with gr.Row():
        with gr.Column():
            # type="pil" makes the callback receive a PIL image.
            input_img = gr.Image(type="pil", label="Input Image")
            with gr.Row():
                use_tiled = gr.Checkbox(label="Tiled Inference", value=False)
                # Passed to the callback as tile_overlap; only used by
                # process_image when tiled inference is enabled.
                tile_overlap = gr.Slider(minimum=0, maximum=128, value=16, step=1, label="Tile Overlap (px)")
            run_btn = gr.Button("Decompose", variant="primary")

        with gr.Column():
            output_img = gr.Image(type="pil", label="Side-by-Side Result")

    # Second section: one image component per individual output map.
    with gr.Row():
        gr.Markdown("### Individual Output Maps")
    with gr.Row():
        basecolor_out = gr.Image(type="pil", label="Basecolor")
        normal_out = gr.Image(type="pil", label="Normal")
        rmd_out = gr.Image(type="pil", label="RMD (R=roughness, G=metallic, B=depth)")
        rgb_out = gr.Image(type="pil", label="Reconstructed RGB")

    # The outputs list order must match the 5-tuple returned by
    # run_and_return_maps: composite, basecolor, normal, rmd, rgb.
    run_btn.click(
        fn=run_and_return_maps,
        inputs=[input_img, use_tiled, tile_overlap],
        outputs=[output_img, basecolor_out, normal_out, rmd_out, rgb_out],
    )

    # NOTE(review): the examples list is empty, so no sample inputs are
    # offered — populate it or remove this component.
    gr.Examples(
        examples=[],
        inputs=[input_img],
    )

    # Static footer describing the model.
    # NOTE(review): the 512×512 figure below is hard-coded text, while
    # inference uses CONFIG["image_size"] — confirm the two agree.
    gr.Markdown(
        """
        ---
        ### Model Details
        - **Architecture**: ShadeNet 28M — MobileNetV2 backbone + Parallel Encoder + UNet decoder with attention
        - **Input**: RGB image (resized to 512×512 center crop)
        - **Outputs**: Basecolor (3ch), Normal (3ch), RMD (3ch = roughness, metallic, depth), Reconstructed RGB (3ch)
        - **ONNX models**: Available in the `onnx/` folder for CPU/edge deployment
        """
    )


# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
    # Load the model up front; process_image() would otherwise attempt the
    # load on the first request.
    load_model()
    # server_name="0.0.0.0" listens on all network interfaces, not just
    # localhost.
    demo.launch(server_name="0.0.0.0")