File size: 3,921 Bytes
12510fb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 | import os
import tempfile
from pathlib import Path
import torch
import gradio as gr
from PIL import Image
from config import CONFIG
from model import load_from_checkpoint, MobileNetUNet
from inference_utils import full_inference, make_side_by_side, tiled_inference
# Lazily populated by load_model(); None means no model is available yet.
MODEL = None

# Run on the GPU whenever PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def load_model():
    """Populate the module-level ``MODEL`` from the configured checkpoint.

    Reads ``CONFIG["checkpoint_path"]``. When that path exists, the model is
    built with ``load_from_checkpoint`` (given ``DEVICE`` and
    ``CONFIG["width_mult"]``), moved to ``DEVICE`` and put in eval mode.
    When it does not exist, ``MODEL`` is reset to ``None`` and a message is
    printed. No exception is raised in that case.
    """
    global MODEL
    checkpoint = CONFIG["checkpoint_path"]

    # Guard clause: without a checkpoint file there is nothing to load.
    if not os.path.exists(checkpoint):
        print(f"Checkpoint not found at {checkpoint}. Using ONNX is not yet supported in this app.")
        MODEL = None
        return

    print(f"Loading checkpoint from {checkpoint}...")
    MODEL = load_from_checkpoint(
        checkpoint,
        device=DEVICE,
        width_mult=CONFIG["width_mult"],
    )
    # Possibly redundant if the loader already honours `device`. Kept so that
    # device placement and inference mode are guaranteed regardless.
    MODEL.to(DEVICE)
    MODEL.eval()
    print("Model loaded successfully.")
def process_image(input_img, use_tiled, tile_overlap):
    """Run the model on one image and build the side-by-side visualisation.

    Args:
        input_img: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user clicked without uploading anything.
        use_tiled: When truthy, use ``tiled_inference``; otherwise
            ``full_inference``.
        tile_overlap: Tile overlap in pixels; only used for tiled inference.

    Returns:
        ``(result, outputs)``. ``result`` is what ``make_side_by_side``
        returns, and ``outputs`` is the second value returned by the chosen
        inference helper.

    Raises:
        gr.Error: If no image was provided, or if no model could be loaded.
    """
    # Fail with a user-facing message rather than an AttributeError on None.
    if input_img is None:
        raise gr.Error("Please upload an input image first.")

    # Load lazily so the handler still works if the app was started without
    # the __main__ preload (e.g. imported by another launcher).
    if MODEL is None:
        load_model()
    if MODEL is None:
        # Report the path load_model() actually checked, not a hard-coded one.
        raise gr.Error(f"No model loaded. Please ensure {CONFIG['checkpoint_path']} exists.")

    img = input_img.convert("RGB")
    with torch.no_grad():
        if use_tiled:
            # Slider values may arrive as floats depending on the Gradio
            # version; a pixel overlap must be an integer.
            inp_img, outputs = tiled_inference(
                MODEL, img, CONFIG["image_size"], DEVICE, overlap=int(tile_overlap),
            )
        else:
            inp_img, outputs = full_inference(MODEL, img, CONFIG["image_size"], DEVICE)

    result = make_side_by_side(inp_img, outputs)
    return result, outputs
def run_and_save(input_img, use_tiled, tile_overlap):
    """Run the pipeline and write the side-by-side image to a temporary PNG.

    The file is created with ``delete=False`` so it outlives this call.
    Removing it afterwards is the caller's responsibility.

    Returns:
        Filesystem path (``str``) of the written PNG.
    """
    result, _ = process_image(input_img, use_tiled, tile_overlap)
    # Only reserve a unique path here, then close the handle before writing.
    # Keeping it open would leak the file object, and on Windows an open
    # NamedTemporaryFile cannot be reopened by name for writing.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        path = tmp.name
    # The ".png" suffix is what lets `save` infer the output format.
    result.save(path)
    return path
def run_and_return_maps(input_img, use_tiled, tile_overlap):
    """Run the pipeline and return the side-by-side image plus each output map.

    Returns a 5-tuple ``(side_by_side, basecolor, normal, rmd, rgb)``, in the
    same order as the five output components wired to the Decompose button.
    Any map absent from the outputs mapping is returned as ``None``.
    """
    side_by_side, maps = process_image(input_img, use_tiled, tile_overlap)
    per_map = tuple(maps.get(name) for name in ("basecolor", "normal", "rmd", "rgb"))
    return (side_by_side,) + per_map
# Gradio UI definition. Components are laid out (and registered) in the order
# they are created inside the nested Blocks/Row/Column contexts, so the
# statement order below is significant.
with gr.Blocks(title="ShadeNet", theme=gr.themes.Soft()) as demo:
    # Header text shown at the top of the page.
    gr.Markdown(
"""
# ShadeNet
Lightweight inverse rendering model that decomposes an RGB image into **basecolor**, **normal**, **roughness/metallic/depth** maps,
then recombines them into a reconstructed RGB.
"""
    )
    with gr.Row():
        # Left column: input image and inference controls.
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input Image")
            with gr.Row():
                # Both values are forwarded unchanged to process_image().
                use_tiled = gr.Checkbox(label="Tiled Inference", value=False)
                tile_overlap = gr.Slider(minimum=0, maximum=128, value=16, step=1, label="Tile Overlap (px)")
            run_btn = gr.Button("Decompose", variant="primary")
        # Right column: the combined side-by-side visualisation.
        with gr.Column():
            output_img = gr.Image(type="pil", label="Side-by-Side Result")
    with gr.Row():
        gr.Markdown("### Individual Output Maps")
    # One image component per map; the order must match the tuple returned by
    # run_and_return_maps (basecolor, normal, rmd, rgb after the side-by-side).
    with gr.Row():
        basecolor_out = gr.Image(type="pil", label="Basecolor")
        normal_out = gr.Image(type="pil", label="Normal")
        rmd_out = gr.Image(type="pil", label="RMD (R=roughness, G=metallic, B=depth)")
        rgb_out = gr.Image(type="pil", label="Reconstructed RGB")
    # Only run_and_return_maps is wired up in this block.
    # NOTE(review): run_and_save is defined in this file but not connected to
    # any component here; wire it to a download/file output or remove it.
    run_btn.click(
        fn=run_and_return_maps,
        inputs=[input_img, use_tiled, tile_overlap],
        outputs=[output_img, basecolor_out, normal_out, rmd_out, rgb_out],
    )
    # NOTE(review): the examples list is empty, so this renders no usable
    # examples; populate it with sample image paths or drop the component.
    gr.Examples(
        examples=[],
        inputs=[input_img],
    )
    # Static model description at the bottom of the page.
    # NOTE(review): the text claims a fixed 512×512 center crop, but the code
    # passes CONFIG["image_size"] (and offers tiled inference); confirm the
    # description matches the configured behaviour.
    gr.Markdown(
"""
---
### Model Details
- **Architecture**: ShadeNet 28M — MobileNetV2 backbone + Parallel Encoder + UNet decoder with attention
- **Input**: RGB image (resized to 512×512 center crop)
- **Outputs**: Basecolor (3ch), Normal (3ch), RMD (3ch = roughness, metallic, depth), Reconstructed RGB (3ch)
- **ONNX models**: Available in the `onnx/` folder for CPU/edge deployment
"""
    )
if __name__ == "__main__":
    # Load eagerly so the first request does not pay the checkpoint-loading
    # cost. process_image() would otherwise load the model lazily.
    load_model()
    # Default to all interfaces (reachable from other hosts), but honour
    # Gradio's GRADIO_SERVER_NAME variable. An explicit server_name argument
    # would otherwise take precedence over that variable and silently ignore it.
    demo.launch(server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"))
|