# Source: keysun89/resunet_1 — app.py
# Commit 8c1c2bc ("Update app.py", verified)
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
# 1. Import your custom class from the local file
from residual_unet import ResidualUNet
# === Model Loading ===
# Hugging Face Hub repository that hosts the model weights and config.json.
REPO_ID = "keysun89/resunet_1"

# Load the model with the class's built-in Hub loader, which downloads the
# weights and reads config.json itself, so no separate download call is needed.
# (The previous `hf_hub_download(...)` call referenced a name that was never
# imported, and its result was never used.)
# NOTE(review): assumes ResidualUNet subclasses
# huggingface_hub.PyTorchModelHubMixin (the source of `from_pretrained`) and
# that the repo stores weights in the layout that mixin expects — confirm
# against residual_unet.py and the repo's file list.
model = ResidualUNet.from_pretrained(REPO_ID)
# Inference mode: layers such as dropout / batch-norm (if the network has any)
# switch to their deterministic evaluation behaviour.
model.eval()
# === Preprocessing ===
# Fixed network input resolution; every uploaded image is resized to this
# before inference.
IMG_HEIGHT = 128
IMG_WIDTH = 128

# Resize to (height, width), then convert the PIL image to a (C, H, W) tensor.
_preprocess_steps = [
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)
# === Prediction Function ===
def predict(image):
    """Run the segmentation model on one image and return its mask.

    Args:
        image: PIL image from the Gradio input component, or ``None`` when the
            user submits without uploading anything.

    Returns:
        A single-channel 8-bit PIL image resized back to the input's original
        size, or ``None`` if no image was provided.
    """
    if image is None:
        # Gradio passes None for an empty input; return an empty output
        # instead of raising AttributeError on ``image.size``.
        return None

    orig_w, orig_h = image.size  # PIL reports (width, height)
    batch = transform(image).unsqueeze(0)  # prepend the batch dimension
    with torch.no_grad():
        pred = model(batch)

    # Strip batch and channel dims.
    # NOTE(review): assumes a (1, 1, H, W) output holding probabilities in
    # [0, 1] (i.e. sigmoid applied inside the model) — TODO confirm.
    mask = pred.squeeze(0).squeeze(0).cpu().numpy()
    # Clip before scaling: values outside [0, 1] would fall outside the uint8
    # range and produce garbage pixel values when cast.
    mask = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)
    # NEAREST resampling introduces no new intensity values in the mask.
    return Image.fromarray(mask).resize((orig_w, orig_h), Image.NEAREST)
# === Gradio Interface ===
# One PIL image in; one PIL image (the predicted mask) out.
_input_component = gr.Image(type="pil")
_output_component = gr.Image(type="pil")

demo = gr.Interface(
    fn=predict,
    inputs=_input_component,
    outputs=_output_component,
    title="Residual-UNet Segmentation",
    description="Upload an image to get the predicted segmentation mask.",
)

if __name__ == "__main__":
    # Start the web UI only when executed as a script, not on import.
    demo.launch()