# ShadeNet / app.py
# singam96's picture
# Initial
# 12510fb
# Raw
# History Blame Contribute Delete
# 3.92 kB
import os
import tempfile
from pathlib import Path
import torch
import gradio as gr
from PIL import Image
from config import CONFIG
from model import load_from_checkpoint, MobileNetUNet
from inference_utils import full_inference, make_side_by_side, tiled_inference
# Module-level model handle. Starts empty; load_model() assigns it and
# process_image() triggers a load when it is still None.
MODEL = None
# Device string handed to the checkpoint loader and the inference helpers:
# "cuda" when torch reports an available CUDA device, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def load_model():
    """Fill the module-level ``MODEL`` from the configured checkpoint.

    The checkpoint location is read from ``CONFIG["checkpoint_path"]``.
    When that path exists, the model is built with ``load_from_checkpoint``
    (using ``DEVICE`` and ``CONFIG["width_mult"]``), moved to ``DEVICE`` and
    switched to eval mode. When it does not exist, ``MODEL`` is reset to
    ``None`` and a notice is printed; no exception is raised, so callers
    must check ``MODEL`` afterwards.
    """
    global MODEL
    ckpt_path = CONFIG["checkpoint_path"]

    # Guard clause: without a checkpoint there is nothing to load.
    if not os.path.exists(ckpt_path):
        print(f"Checkpoint not found at {ckpt_path}. Using ONNX is not yet supported in this app.")
        MODEL = None
        return

    print(f"Loading checkpoint from {ckpt_path}...")
    MODEL = load_from_checkpoint(
        ckpt_path,
        device=DEVICE,
        width_mult=CONFIG["width_mult"],
    )
    MODEL.to(DEVICE)
    MODEL.eval()
    print("Model loaded successfully.")
def process_image(input_img, use_tiled, tile_overlap):
    """Run the model on one image and build the side-by-side preview.

    Args:
        input_img: Image from the Gradio input component. It must provide
            ``convert("RGB")`` (the UI is configured with ``type="pil"``).
            ``None`` means the user has not supplied an image.
        use_tiled: When truthy, run ``tiled_inference``; otherwise run
            ``full_inference``.
        tile_overlap: Passed to ``tiled_inference`` as ``overlap``; unused
            for full-image inference.

    Returns:
        A ``(result, outputs)`` tuple: ``result`` is the value returned by
        ``make_side_by_side`` and ``outputs`` is the second value returned
        by the inference helper.

    Raises:
        gr.Error: If no image was supplied, or if no model could be loaded.
    """
    # An empty image component is delivered as None; report that to the user
    # instead of failing with AttributeError on `.convert` below.
    if input_img is None:
        raise gr.Error("No input image provided. Please upload an image first.")

    # Load on first use so the app still works when load_model() was not
    # called at start-up.
    if MODEL is None:
        load_model()
    if MODEL is None:
        raise gr.Error("No model loaded. Please ensure checkpoints/last.ckpt exists.")

    img = input_img.convert("RGB")

    # Inference only: gradient tracking is not needed.
    with torch.no_grad():
        if use_tiled:
            inp_img, outputs = tiled_inference(
                MODEL, img, CONFIG["image_size"], DEVICE, overlap=tile_overlap,
            )
        else:
            inp_img, outputs = full_inference(MODEL, img, CONFIG["image_size"], DEVICE)

    result = make_side_by_side(inp_img, outputs)
    return result, outputs
def run_and_save(input_img, use_tiled, tile_overlap):
    """Run inference and write the side-by-side result to a temporary PNG.

    Args:
        input_img: Image forwarded unchanged to ``process_image``.
        use_tiled: Forwarded unchanged to ``process_image``.
        tile_overlap: Forwarded unchanged to ``process_image``.

    Returns:
        Path (``str``) of the ``.png`` file that was written. The file is
        created with ``delete=False``, so the caller is responsible for
        removing it.

    Raises:
        gr.Error: Propagated from ``process_image``.
    """
    result, _ = process_image(input_img, use_tiled, tile_overlap)

    # Reserve a unique path, then close the handle before writing to it.
    # Leaving the handle open leaks a file descriptor on every call and
    # prevents reopening the path by name on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        out_path = tmp.name

    result.save(out_path)
    return out_path
def run_and_return_maps(input_img, use_tiled, tile_overlap):
    """Run inference and unpack the result for the five output components.

    Returns a 5-tuple: the side-by-side result followed by the entries of
    the outputs mapping under ``"basecolor"``, ``"normal"``, ``"rmd"`` and
    ``"rgb"``, in that order. A key that is absent yields ``None``.
    """
    side_by_side, maps = process_image(input_img, use_tiled, tile_overlap)
    # Order must match the `outputs=[...]` list wired to the button.
    map_keys = ("basecolor", "normal", "rmd", "rgb")
    return (side_by_side, *(maps.get(key) for key in map_keys))
# Gradio UI definition. `demo` is the Blocks app launched by the script
# entry point.
# NOTE(review): the container nesting below was reconstructed from a copy
# whose indentation was lost — confirm it against the rendered layout.
with gr.Blocks(title="ShadeNet", theme=gr.themes.Soft()) as demo:
    # Page title and short description.
    gr.Markdown(
        """
# ShadeNet
Lightweight inverse rendering model that decomposes an RGB image into **basecolor**, **normal**, **roughness/metallic/depth** maps,
then recombines them into a reconstructed RGB.
"""
    )
    with gr.Row():
        # Left column: input image, inference options and the run button.
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input Image")
            with gr.Row():
                use_tiled = gr.Checkbox(label="Tiled Inference", value=False)
                tile_overlap = gr.Slider(minimum=0, maximum=128, value=16, step=1, label="Tile Overlap (px)")
            run_btn = gr.Button("Decompose", variant="primary")
        # Right column: combined side-by-side preview.
        with gr.Column():
            output_img = gr.Image(type="pil", label="Side-by-Side Result")
    with gr.Row():
        gr.Markdown("### Individual Output Maps")
    # One image component per map returned by run_and_return_maps.
    with gr.Row():
        basecolor_out = gr.Image(type="pil", label="Basecolor")
        normal_out = gr.Image(type="pil", label="Normal")
        rmd_out = gr.Image(type="pil", label="RMD (R=roughness, G=metallic, B=depth)")
        rgb_out = gr.Image(type="pil", label="Reconstructed RGB")
    # The order of `outputs` must match the tuple returned by
    # run_and_return_maps.
    run_btn.click(
        fn=run_and_return_maps,
        inputs=[input_img, use_tiled, tile_overlap],
        outputs=[output_img, basecolor_out, normal_out, rmd_out, rgb_out],
    )
    # NOTE(review): the examples list is empty — confirm whether sample
    # images were meant to be listed here.
    gr.Examples(
        examples=[],
        inputs=[input_img],
    )
    # Static footer describing the model.
    gr.Markdown(
        """
---
### Model Details
- **Architecture**: ShadeNet 28M — MobileNetV2 backbone + Parallel Encoder + UNet decoder with attention
- **Input**: RGB image (resized to 512×512 center crop)
- **Outputs**: Basecolor (3ch), Normal (3ch), RMD (3ch = roughness, metallic, depth), Reconstructed RGB (3ch)
- **ONNX models**: Available in the `onnx/` folder for CPU/edge deployment
"""
    )
def main():
    """Load the model once, then start the Gradio server."""
    # Load before launching so a missing checkpoint is reported in the
    # console at start-up rather than on the first request.
    load_model()
    # Bind to all interfaces so the app is reachable from outside the host
    # or container it runs in.
    demo.launch(server_name="0.0.0.0")


if __name__ == "__main__":
    main()