"""Gradio demo app for ShadeNet.

Loads a checkpoint (path taken from ``CONFIG["checkpoint_path"]``), runs
either full-image or tiled inference on an uploaded picture, and shows the
side-by-side result together with the individual output maps.
"""

import os
import tempfile
from pathlib import Path

import torch
import gradio as gr
from PIL import Image

from config import CONFIG
from model import load_from_checkpoint, MobileNetUNet
from inference_utils import full_inference, make_side_by_side, tiled_inference

# Lazily populated by load_model(); stays None when no checkpoint is found.
MODEL = None
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def load_model() -> None:
    """Load the checkpoint named in CONFIG into the module-level ``MODEL``.

    When the checkpoint file does not exist, ``MODEL`` is set to ``None`` and
    a message is printed; no exception is raised here so that the UI can
    still start and report the problem on the first request.
    """
    global MODEL
    ckpt_path = CONFIG["checkpoint_path"]
    if os.path.exists(ckpt_path):
        print(f"Loading checkpoint from {ckpt_path}...")
        MODEL = load_from_checkpoint(ckpt_path, device=DEVICE, width_mult=CONFIG["width_mult"])
        MODEL.to(DEVICE)
        MODEL.eval()
        print("Model loaded successfully.")
    else:
        print(f"Checkpoint not found at {ckpt_path}. Using ONNX is not yet supported in this app.")
        MODEL = None


def process_image(input_img, use_tiled, tile_overlap):
    """Run inference on ``input_img`` and build the side-by-side preview.

    Args:
        input_img: Image delivered by the ``gr.Image(type="pil")`` input;
            ``None`` when the user has not supplied one.
        use_tiled: When truthy, ``tiled_inference`` is used instead of
            ``full_inference``.
        tile_overlap: Forwarded as ``overlap=`` to ``tiled_inference``
            (the UI labels it in pixels).

    Returns:
        ``(result, outputs)`` where ``result`` is whatever
        ``make_side_by_side`` returns and ``outputs`` is the second value
        returned by the inference helper (callers in this file use it as a
        mapping with ``.get``).

    Raises:
        gr.Error: If no image was provided or no model could be loaded.
    """
    # Without this guard an empty input crashes on ``None.convert`` and the
    # user only sees a generic error.
    if input_img is None:
        raise gr.Error("Please upload an image first.")

    # Retry loading on demand so the app recovers once a checkpoint appears.
    if MODEL is None:
        load_model()
    if MODEL is None:
        raise gr.Error("No model loaded. Please ensure checkpoints/last.ckpt exists.")

    img = input_img.convert("RGB")

    with torch.no_grad():
        if use_tiled:
            inp_img, outputs = tiled_inference(
                MODEL,
                img,
                CONFIG["image_size"],
                DEVICE,
                overlap=tile_overlap,
            )
        else:
            inp_img, outputs = full_inference(MODEL, img, CONFIG["image_size"], DEVICE)

    result = make_side_by_side(inp_img, outputs)
    return result, outputs


def run_and_save(input_img, use_tiled, tile_overlap):
    """Run inference and write the side-by-side result to a temporary PNG.

    Returns:
        Path (str) of the written ``.png`` file. The file is created with
        ``delete=False``, so removing it is the caller's responsibility.
    """
    result, _ = process_image(input_img, use_tiled, tile_overlap)
    # Reserve a unique file name, then close the handle before saving: this
    # avoids leaking a descriptor per call and lets the save reopen the path
    # on platforms that forbid opening an already-open temporary file.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        out_path = tmp.name
    result.save(out_path)
    return out_path


def run_and_return_maps(input_img, use_tiled, tile_overlap):
    """Run inference and return the preview plus the individual maps.

    Returns:
        A 5-tuple ``(result, basecolor, normal, rmd, rgb)``; any map that is
        absent from the inference output is returned as ``None``.
    """
    result, outputs = process_image(input_img, use_tiled, tile_overlap)
    return (
        result,
        outputs.get("basecolor"),
        outputs.get("normal"),
        outputs.get("rmd"),
        outputs.get("rgb"),
    )


with gr.Blocks(title="ShadeNet", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ShadeNet

        Lightweight inverse rendering model that decomposes an RGB image into **basecolor**, **normal**,
        **roughness/metallic/depth** maps, then recombines them into a reconstructed RGB.
        """
    )

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input Image")
            with gr.Row():
                use_tiled = gr.Checkbox(label="Tiled Inference", value=False)
                tile_overlap = gr.Slider(minimum=0, maximum=128, value=16, step=1, label="Tile Overlap (px)")
            run_btn = gr.Button("Decompose", variant="primary")

        with gr.Column():
            output_img = gr.Image(type="pil", label="Side-by-Side Result")

    with gr.Row():
        gr.Markdown("### Individual Output Maps")

    with gr.Row():
        basecolor_out = gr.Image(type="pil", label="Basecolor")
        normal_out = gr.Image(type="pil", label="Normal")
        rmd_out = gr.Image(type="pil", label="RMD (R=roughness, G=metallic, B=depth)")
        rgb_out = gr.Image(type="pil", label="Reconstructed RGB")

    # Output order must match the tuple returned by run_and_return_maps.
    run_btn.click(
        fn=run_and_return_maps,
        inputs=[input_img, use_tiled, tile_overlap],
        outputs=[output_img, basecolor_out, normal_out, rmd_out, rgb_out],
    )

    gr.Examples(
        examples=[],
        inputs=[input_img],
    )

    gr.Markdown(
        """
        ---

        ### Model Details

        - **Architecture**: ShadeNet 28M — MobileNetV2 backbone + Parallel Encoder + UNet decoder with attention
        - **Input**: RGB image (resized to 512×512 center crop)
        - **Outputs**: Basecolor (3ch), Normal (3ch), RMD (3ch = roughness, metallic, depth), Reconstructed RGB (3ch)
        - **ONNX models**: Available in the `onnx/` folder for CPU/edge deployment
        """
    )


if __name__ == "__main__":
    # Load eagerly so the first request does not pay the checkpoint cost.
    load_model()
    demo.launch(server_name="0.0.0.0")