AndaiMD commited on
Commit
fd1011e
·
verified ·
1 Parent(s): 970421b

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. README.md +12 -0
  2. config.json +5 -0
  3. inference.py +25 -0
  4. unet.py +97 -0
  5. unet_epoch20.pth +3 -0
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Brain U-Net for MRI Segmentation
2
+
3
+ This repository contains a custom U-Net model trained to segment brain MRI images into white matter, gray matter, and cerebrospinal fluid (CSF) compartments.
4
+
5
+ - Format: PyTorch `.pth` state dict
6
+ - Input: Grayscale 256x256 MRI images
7
+ - Output: Segmentation map with 3 classes (WM, GM, CSF)
8
+
9
+ ## Files
10
+ - `unet_epoch20.pth` — model weights
11
+ - `unet.py` — U-Net architecture
12
+ - `inference.py` — handles image uploads and inference for the HF Inference API
config.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "model_type": "unet",
3
+ "framework": "pytorch",
4
+ "task": "image-segmentation"
5
+ }
inference.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from PIL import Image
4
+ import numpy as np
5
+ import torchvision.transforms as T
6
+ from unet import UNet
7
+
8
+ # Define model
9
+ model = UNet(in_channels=1, out_channels=3)
10
+ state_dict = torch.load("unet_epoch20.pth", map_location="cpu")
11
+ model.load_state_dict(state_dict)
12
+ model.eval()
13
+
14
+ transform = T.Compose([
15
+ T.Grayscale(), # In case image is RGB
16
+ T.Resize((256, 256)),
17
+ T.ToTensor(),
18
+ ])
19
+
20
+ def predict(image: Image.Image):
21
+ img_tensor = transform(image).unsqueeze(0)
22
+ with torch.no_grad():
23
+ output = model(img_tensor)
24
+ pred = torch.argmax(F.softmax(output, dim=1), dim=1)
25
+ return pred.squeeze().numpy().tolist() # list is JSON-serializable
unet.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class DoubleConv(nn.Module):
6
+ """
7
+ This is the core building block of the U-Net architecture.
8
+ Use consecutive convolutional layers
9
+ Each followed by batch normalization and ReLU activation
10
+ """
11
+ def __init__(self, in_channels, out_channels):
12
+ super().__init__()
13
+ """
14
+ nn.Conv2d:
15
+ Applies a 2D convolution filter (kernel size 3×3)
16
+ padding=1 ensures the output spatial size stays the same
17
+ First conv changes input channels → output channels
18
+ Second conv keeps it at out_channels
19
+
20
+ nn.BatchNorm2d
21
+ Normalizes activations across the batch and channels
22
+ Helps stabilize and speed up training
23
+ Reduces internal covariate shift
24
+ """
25
+ self.double_conv = nn.Sequential(
26
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
27
+ nn.BatchNorm2d(out_channels),
28
+ nn.ReLU(inplace=True),
29
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
30
+ nn.BatchNorm2d(out_channels),
31
+ nn.ReLU(inplace=True)
32
+ )
33
+
34
+ def forward(self, x):
35
+ return self.double_conv(x)
36
+
37
+ class UNet(nn.Module):
38
+ def __init__(self, in_channels=1, out_channels=3):
39
+ super().__init__()
40
+
41
+ # Encoder
42
+ self.down1 = DoubleConv(in_channels, 64)
43
+ self.pool1 = nn.MaxPool2d(2)
44
+
45
+ self.down2 = DoubleConv(64, 128)
46
+ self.pool2 = nn.MaxPool2d(2)
47
+
48
+ self.down3 = DoubleConv(128, 256)
49
+ self.pool3 = nn.MaxPool2d(2)
50
+
51
+ self.down4 = DoubleConv(256, 512)
52
+ self.pool4 = nn.MaxPool2d(2)
53
+
54
+ # Bottleneck
55
+ self.bottleneck = DoubleConv(512, 1024)
56
+
57
+ # Decoder
58
+ self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
59
+ self.dec4 = DoubleConv(1024, 512)
60
+
61
+ self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
62
+ self.dec3 = DoubleConv(512, 256)
63
+
64
+ self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
65
+ self.dec2 = DoubleConv(256, 128)
66
+
67
+ self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
68
+ self.dec1 = DoubleConv(128, 64)
69
+
70
+ # Final output layer
71
+ self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)
72
+
73
+ def forward(self, x):
74
+ # Encoder
75
+ d1 = self.down1(x)
76
+ d2 = self.down2(self.pool1(d1))
77
+ d3 = self.down3(self.pool2(d2))
78
+ d4 = self.down4(self.pool3(d3))
79
+
80
+ # Bottleneck
81
+ bn = self.bottleneck(self.pool4(d4))
82
+
83
+ # Decoder
84
+ up4 = self.up4(bn)
85
+ dec4 = self.dec4(torch.cat([up4, d4], dim=1))
86
+
87
+ up3 = self.up3(dec4)
88
+ dec3 = self.dec3(torch.cat([up3, d3], dim=1))
89
+
90
+ up2 = self.up2(dec3)
91
+ dec2 = self.dec2(torch.cat([up2, d2], dim=1))
92
+
93
+ up1 = self.up1(dec2)
94
+ dec1 = self.dec1(torch.cat([up1, d1], dim=1))
95
+
96
+ # Output
97
+ return self.out_conv(dec1)
unet_epoch20.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:179e3e327b9ed30ad9a869d97088bda713d9de51ffd2fe4f7f1bccbbb439607e
3
+ size 124262517