| import os |
| import tempfile |
| from pathlib import Path |
|
|
| import torch |
| import gradio as gr |
| from PIL import Image |
|
|
| from config import CONFIG |
| from model import load_from_checkpoint, MobileNetUNet |
| from inference_utils import full_inference, make_side_by_side, tiled_inference |
|
|
|
|
# Lazily populated model singleton; set by load_model() and read by process_image().
MODEL = None

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
|
|
def load_model():
    """Load the ShadeNet checkpoint into the module-level ``MODEL`` global.

    The checkpoint path is read from ``CONFIG["checkpoint_path"]``. If the file
    exists, it is restored with ``load_from_checkpoint``, moved to ``DEVICE`` and
    switched to eval mode, and only then stored in ``MODEL``. If the file is
    missing, ``MODEL`` is reset to ``None`` and a message is printed.

    Returns:
        The loaded model, or ``None`` when no checkpoint file was found.
    """
    global MODEL
    ckpt_path = CONFIG["checkpoint_path"]
    if not os.path.exists(ckpt_path):
        print(f"Checkpoint not found at {ckpt_path}. Using ONNX is not yet supported in this app.")
        MODEL = None
        return None

    print(f"Loading checkpoint from {ckpt_path}...")
    model = load_from_checkpoint(ckpt_path, device=DEVICE, width_mult=CONFIG["width_mult"])
    # Finish device placement and the eval-mode switch on a local first. If either
    # step raises, the global is left untouched, so process_image() will retry the
    # load instead of silently using a half-initialised model.
    model.to(DEVICE)
    model.eval()
    MODEL = model
    print("Model loaded successfully.")
    return model
|
|
|
|
def process_image(input_img, use_tiled, tile_overlap):
    """Run ShadeNet on one image and build the side-by-side preview.

    Loads the model on first use if it is not loaded yet.

    Args:
        input_img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user has not uploaded anything.
        use_tiled: If true, use ``tiled_inference``; otherwise ``full_inference``.
        tile_overlap: Tile overlap in pixels (only used for tiled inference).
            Coerced to ``int`` because Gradio sliders deliver float values.

    Returns:
        ``(result, outputs)``: the image from ``make_side_by_side`` and the second
        value returned by the inference helper (callers use it as a mapping of
        output maps).

    Raises:
        gr.Error: If no input image was given or no model could be loaded.
    """
    # Without this guard a click before uploading crashes on None.convert().
    if input_img is None:
        raise gr.Error("Please upload an input image first.")

    if MODEL is None:
        load_model()
    if MODEL is None:
        # Report the path actually configured, not a hard-coded default.
        raise gr.Error(
            f"No model loaded. Please ensure {CONFIG['checkpoint_path']} exists."
        )

    img = input_img.convert("RGB")

    with torch.no_grad():
        if use_tiled:
            inp_img, outputs = tiled_inference(
                MODEL, img, CONFIG["image_size"], DEVICE, overlap=int(tile_overlap),
            )
        else:
            inp_img, outputs = full_inference(MODEL, img, CONFIG["image_size"], DEVICE)

    result = make_side_by_side(inp_img, outputs)
    return result, outputs
|
|
|
|
def run_and_save(input_img, use_tiled, tile_overlap):
    """Run the model and write the side-by-side result to a temporary PNG.

    The file is created with ``delete=False`` so the path stays valid after this
    function returns. Removing it afterwards is the caller's responsibility.

    Returns:
        Filesystem path of the saved PNG.
    """
    result, _outputs = process_image(input_img, use_tiled, tile_overlap)

    # Close the handle before PIL reopens the path. The original left it open,
    # which leaked a file descriptor per call. On Windows, an open
    # NamedTemporaryFile cannot be opened a second time by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        out_path = tmp.name
    result.save(out_path)
    return out_path
|
|
|
|
def run_and_return_maps(input_img, use_tiled, tile_overlap):
    """Return the preview image followed by each individual output map.

    The order is: side-by-side preview, basecolor, normal, RMD, reconstructed
    RGB. This matches the ``outputs=`` list of the Decompose button. A key
    missing from the model outputs yields ``None`` in its slot.
    """
    preview, maps = process_image(input_img, use_tiled, tile_overlap)
    map_keys = ("basecolor", "normal", "rmd", "rgb")
    return (preview, *(maps.get(key) for key in map_keys))
|
|
|
|
with gr.Blocks(title="ShadeNet", theme=gr.themes.Soft()) as demo:
    # Header / description.
    gr.Markdown(
        """
        # ShadeNet
        Lightweight inverse rendering model that decomposes an RGB image into **basecolor**, **normal**, **roughness/metallic/depth** maps,
        then recombines them into a reconstructed RGB.
        """
    )

    # Left column: input and inference options. Right column: side-by-side preview.
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input Image")
            with gr.Row():
                use_tiled = gr.Checkbox(label="Tiled Inference", value=False)
                tile_overlap = gr.Slider(minimum=0, maximum=128, value=16, step=1, label="Tile Overlap (px)")
            run_btn = gr.Button("Decompose", variant="primary")

        with gr.Column():
            output_img = gr.Image(type="pil", label="Side-by-Side Result")

    # One image slot per decomposed map; order must match run_and_return_maps.
    with gr.Row():
        gr.Markdown("### Individual Output Maps")
    with gr.Row():
        basecolor_out = gr.Image(type="pil", label="Basecolor")
        normal_out = gr.Image(type="pil", label="Normal")
        rmd_out = gr.Image(type="pil", label="RMD (R=roughness, G=metallic, B=depth)")
        rgb_out = gr.Image(type="pil", label="Reconstructed RGB")

    run_btn.click(
        fn=run_and_return_maps,
        inputs=[input_img, use_tiled, tile_overlap],
        outputs=[output_img, basecolor_out, normal_out, rmd_out, rgb_out],
    )

    # Add example image paths here to populate an examples gallery. The
    # component is only created when there is something to show, so the page
    # no longer contains an Examples section with zero samples.
    EXAMPLE_IMAGES = []
    if EXAMPLE_IMAGES:
        gr.Examples(
            examples=EXAMPLE_IMAGES,
            inputs=[input_img],
        )

    # Footer with model details.
    gr.Markdown(
        """
        ---
        ### Model Details
        - **Architecture**: ShadeNet 28M — MobileNetV2 backbone + Parallel Encoder + UNet decoder with attention
        - **Input**: RGB image (resized to 512×512 center crop)
        - **Outputs**: Basecolor (3ch), Normal (3ch), RMD (3ch = roughness, metallic, depth), Reconstructed RGB (3ch)
        - **ONNX models**: Available in the `onnx/` folder for CPU/edge deployment
        """
    )
|
|
|
|
if __name__ == "__main__":
    # Load eagerly so the first request does not pay the checkpoint-load cost.
    load_model()
    # An explicit server_name argument overrides Gradio's GRADIO_SERVER_NAME
    # environment variable, so read it here. The default stays "0.0.0.0"
    # (listen on all interfaces), as before.
    demo.launch(server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"))
|
|