Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from model import load_model, predict | |
| # ------------------------------------------------- | |
| # 1. Device Setup (CPU is fine for deployment) | |
| # ------------------------------------------------- | |
# Prefer the GPU when PyTorch reports one is available; otherwise fall back
# to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
| # ------------------------------------------------- | |
| # 2. Load Model Once (important for speed) | |
| # ------------------------------------------------- | |
# Load the checkpoint a single time at import so that every request reuses
# the same model object instead of reloading it per call.
# NOTE(review): the checkpoint path is relative; presumably it is resolved
# against the process working directory -- confirm against load_model.
model = load_model(
    "best_unet_model.pth",
    device,
)
| # ------------------------------------------------- | |
| # 3. Inference Wrapper for Gradio | |
| # ------------------------------------------------- | |
def segment_fundus(image):
    """Run the segmentation model on an uploaded image and return its mask.

    Args:
        image: PIL image supplied by the Gradio input component (which is
            configured with ``type="pil"``), or ``None`` when the user
            submits without uploading anything.

    Returns:
        The second element of the ``(original, mask)`` pair returned by
        ``predict``, or ``None`` when no image was provided.
        NOTE(review): presumably a NumPy array, since the output component
        is declared with ``type="numpy"`` -- confirm against ``predict``.
    """
    import os
    import tempfile

    # Submitting an empty form yields None; return an empty result instead
    # of crashing on ``None.convert``.
    if image is None:
        return None

    # ``predict`` is called with a file path, so the image has to be written
    # to disk first. A private per-call directory keeps concurrent requests
    # from overwriting each other's input and guarantees the file is removed
    # afterwards (the previous fixed "temp_input.png" did neither).
    # NOTE(review): assumes predict() has finished reading the file by the
    # time it returns -- TODO confirm.
    with tempfile.TemporaryDirectory() as tmp_dir:
        input_path = os.path.join(tmp_dir, "temp_input.png")
        image.convert("RGB").save(input_path)
        _original, mask = predict(model, input_path, device)

    return mask
| # ------------------------------------------------- | |
| # 4. Build Gradio Interface | |
| # ------------------------------------------------- | |
# One-input / one-output UI: the uploaded picture is handed to
# segment_fundus as a PIL image, and whatever it returns is rendered by the
# output image component.
interface = gr.Interface(
    title="Retinal Vessel Segmentation - U-Net",
    description="Upload a retinal fundus image and get the segmented vessel mask.",
    fn=segment_fundus,
    inputs=gr.Image(label="Upload Fundus Image", type="pil"),
    outputs=gr.Image(label="Segmented Blood Vessel Mask", type="numpy"),
)
| # ------------------------------------------------- | |
| # 5. Launch App | |
| # ------------------------------------------------- | |
# Start the Gradio server only when this file is executed as a script, so
# that merely importing the module does not launch the app.
if __name__ == "__main__":
    interface.launch()