himanshuch8055 commited on
Commit
0cdca7b
·
1 Parent(s): 80cf578

Implement Unet++ model for fibril segmentation and add Gradio interface; include model weights and requirements

Browse files
app.py CHANGED
@@ -1,7 +1,72 @@
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import albumentations as A
6
+ from albumentations.pytorch import ToTensorV2
7
+ import segmentation_models_pytorch as smp
8
  import gradio as gr
9
 
10
+ # ─── Configuration ─────────────────────────────────────────
11
+ CONFIG = {
12
+ "model_path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
13
+ "img_size": 512
14
+ }
15
 
16
+ # ─── Device Setup ──────────────────────────────────────────
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+ print(f"✅ Using device: {device}")
19
+
20
+ # ─── Load Model ────────────────────────────────────────────
21
+ model = smp.UnetPlusPlus(
22
+ encoder_name='resnet34',
23
+ encoder_depth=5,
24
+ encoder_weights='imagenet',
25
+ decoder_use_norm='batchnorm',
26
+ decoder_channels=(256, 128, 64, 32, 16),
27
+ decoder_attention_type=None,
28
+ decoder_interpolation='nearest',
29
+ in_channels=1,
30
+ classes=1,
31
+ activation=None
32
+ ).to(device)
33
+
34
+ model.load_state_dict(torch.load(CONFIG["model_path"], map_location=device))
35
+ model.eval()
36
+
37
+ # ─── Transform Function ────────────────────────────────────
38
+ def get_transform(size):
39
+ return A.Compose([
40
+ A.Resize(size, size),
41
+ A.Normalize(mean=(0.5,), std=(0.5,)),
42
+ ToTensorV2()
43
+ ])
44
+
45
+ transform = get_transform(CONFIG["img_size"])
46
+
47
+ # ─── Prediction Function ───────────────────────────────────
48
+ def predict(image):
49
+ image = image.convert("L") # Convert to grayscale
50
+ img_np = np.array(image)
51
+ img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
52
+
53
+ with torch.no_grad():
54
+ pred = torch.sigmoid(model(img_tensor))
55
+ mask = (pred > 0.5).float().cpu().squeeze().numpy()
56
+
57
+ mask_img = Image.fromarray((mask * 255).astype(np.uint8))
58
+ return mask_img
59
+
60
+ # ─── Gradio Interface ──────────────────────────────────────
61
+ demo = gr.Interface(
62
+ fn=predict,
63
+ inputs=gr.Image(type="pil", label="Upload Microscopy Image"),
64
+ outputs=gr.Image(type="pil", label="Predicted Segmentation Mask"),
65
+ title="Fibril Segmentation with Unet++",
66
+ description="Upload a grayscale microscopy image to get its predicted segmentation mask.",
67
+ allow_flagging="never",
68
+ live=False
69
+ )
70
+
71
+ if __name__ == "__main__":
72
+ demo.launch()
model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ca490168815735ccf5d296c09f3f8af2cd4de04f04785f9355d62265ac44f8c
3
+ size 104521023
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ numpy
3
+ pillow
4
+ albumentations
5
+ segmentation-models-pytorch
6
+ gradio