Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import torchvision.transforms as T | |
| import segmentation_models_pytorch as smp | |
# Choose where inference runs: a CUDA GPU when torch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Build the U-Net (ResNet34 encoder) with the same shape the checkpoint was
# trained with: one input channel (grayscale) and one output channel (a single
# binary-mask logit map).
# encoder_weights=None: the strict load_state_dict below overwrites every
# parameter, so fetching ImageNet weights here would only add a network
# download at startup without changing the final model.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=1,  # grayscale input
    classes=1,  # binary output
).to(device)

# Load the trained weights. The file is consumed as a plain state_dict, so
# weights_only=True restricts unpickling to tensors and primitive containers
# instead of executing arbitrary pickled objects.
model.load_state_dict(
    torch.load(
        "./trained-models/unet_fibril_seg_model.pth",
        map_location=device,
        weights_only=True,
    )
)
# Switch layers such as BatchNorm to inference behavior.
model.eval()
# Preprocessing applied to every uploaded (already grayscale) image:
#   1. resize to the fixed 512x512 size the network is fed,
#   2. convert the PIL image to a float tensor in [0, 1],
#   3. map the single channel to [-1, 1] via (x - 0.5) / 0.5.
transform = T.Compose(
    transforms=[
        T.Resize(size=(512, 512)),
        T.ToTensor(),
        T.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# Inference function
def segment_fibrils(image):
    """Predict a binary fibril segmentation mask for one uploaded image.

    Args:
        image: PIL image from the Gradio upload component, or None when the
            user submits without uploading anything.

    Returns:
        A single-channel PIL image with the same size as ``image``, where
        predicted fibril pixels are 255 and background pixels are 0; or None
        when no image was provided.
    """
    if image is None:
        # Gradio passes None for an empty upload; return an empty output
        # instead of failing on .convert().
        return None

    original_size = image.size  # (width, height), used to undo the resize
    gray = image.convert("L")  # the model takes one input channel
    input_tensor = transform(gray).unsqueeze(0).to(device)  # add batch dim

    with torch.no_grad():
        logits = model(input_tensor)
    # Batch size is 1 and the model has a single output class, so take the
    # 2-D map explicitly instead of relying on squeeze().
    probs = torch.sigmoid(logits)[0, 0].cpu().numpy()

    # Threshold at 0.5 and scale to 0/255 so the mask displays as black/white.
    mask = (probs > 0.5).astype(np.uint8) * 255
    # The network always sees a 512x512 input; resize the mask back so it
    # lines up with the uploaded image. NEAREST keeps the mask strictly binary.
    return Image.fromarray(mask).resize(original_size, Image.NEAREST)
# Gradio UI: one image upload in, the predicted mask image out.
demo = gr.Interface(
    fn=segment_fibrils,
    inputs=gr.Image(type="pil", label="Upload Fibril Image"),
    outputs=gr.Image(type="pil", label="Predicted Segmentation Mask"),
    title="Fibril Segmentation Encoder (ResNet34) and Decoder (UNet)",
    description="Upload a grayscale fibril image to get the segmentation mask.",
)

if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # (e.g. from tests or Gradio's reload mode) builds `demo` without launching.
    # 0.0.0.0:7860 listens on all interfaces on Gradio's default port.
    # NOTE: share=True asks for a public gradio.live tunnel. Gradio ignores it
    # on Hugging Face Spaces, but it exposes the app publicly when run locally.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)