"""Gradio demo that serves a Residual-UNet segmentation model from the Hugging Face Hub."""

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# 1. Import your custom class from the local file
from residual_unet import ResidualUNet

# === Model Loading ===
# Hub repository holding the trained weights (and config.json) for ResidualUNet.
REPO_ID = "keysun89/resunet_1"

# 2. Load the model using the built-in hub method, which downloads the weights
# and reads config.json itself, so no separate hf_hub_download call is needed.
# NOTE(review): assumes ResidualUNet provides ``from_pretrained`` (e.g. via
# huggingface_hub's PyTorchModelHubMixin) and that the repo stores its weights
# under the filename that method looks for (the repo previously referenced
# "best_residual_unet_model.pth") -- confirm against residual_unet.py and the repo.
model = ResidualUNet.from_pretrained(REPO_ID)
model.eval()  # switch layers such as dropout / batch norm to inference behaviour

# === Preprocessing ===
# Resolution the network is fed; predictions are resized back to the input size.
IMG_HEIGHT, IMG_WIDTH = 128, 128
transform = transforms.Compose([
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor(),  # PIL image -> float tensor (C, H, W) scaled to [0, 1]
])


# === Prediction Function ===
def predict(image):
    """Return the predicted segmentation mask for ``image``.

    Args:
        image: The uploaded PIL image, or ``None`` when the user submits
            without uploading anything.

    Returns:
        A PIL image resized to the original ``image`` size with mask values
        in 0..255, or ``None`` if no image was provided.
    """
    if image is None:
        # Gradio passes None for an empty input; avoid crashing on image.size.
        return None

    orig_w, orig_h = image.size
    batch = transform(image).unsqueeze(0)  # add batch dimension -> (1, C, H, W)

    with torch.no_grad():
        pred = model(batch)

    # Drop the batch and channel dims to get a 2-D map.
    # NOTE(review): assumes the model returns a single-channel (1, 1, H, W)
    # output already in [0, 1] (e.g. sigmoid applied inside the model) -- confirm.
    mask = pred.squeeze(0).squeeze(0).cpu().numpy()

    # Clip before scaling so out-of-range values saturate instead of wrapping
    # around in the uint8 cast.
    mask = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)

    # NEAREST keeps the upscaled mask free of interpolated in-between values.
    mask_img = Image.fromarray(mask).resize((orig_w, orig_h), Image.NEAREST)
    return mask_img


# === Gradio Interface ===
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Residual-UNet Segmentation",
    description="Upload an image to get the predicted segmentation mask.",
)

if __name__ == "__main__":
    demo.launch()