whitney0507 commited on
Commit
36ea6ba
·
verified ·
1 Parent(s): dc8f612

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +83 -17
handler.py CHANGED
@@ -1,6 +1,3 @@
1
- # handler.py
2
- # Save only the weights
3
-
4
  import torch
5
  import torch.nn as nn
6
  from torchvision import transforms
@@ -9,24 +6,95 @@ from huggingface_hub import hf_hub_download
9
  import io
10
  import base64
11
 
12
- model = UNet()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- torch.save(model.state_dict(), "UNet_Model.pth")
 
15
 
16
- # Define your UNet class here (shortened version for example)
17
  class UNet(nn.Module):
18
- def __init__(self): # Add your actual init params
19
- super(UNet, self).__init__()
20
- # Define layers...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def forward(self, x):
23
- # Implement forward pass
24
- return x
 
 
 
 
 
 
 
 
 
25
 
 
26
  class EndpointHandler:
27
  def __init__(self, path=""):
28
  model_path = hf_hub_download(repo_id="whitney0507/unet-model", filename="UNet_Model.pth")
29
- self.model = UNet() # Instantiate model
30
  self.model.load_state_dict(torch.load(model_path, map_location="cpu"))
31
  self.model.eval()
32
  self.transform = transforms.Compose([
@@ -41,11 +109,9 @@ class EndpointHandler:
41
 
42
  with torch.no_grad():
43
  output = self.model(input_tensor)
44
- pred = output.argmax(dim=1).squeeze().byte().cpu().numpy()
45
 
46
- # Convert to base64
47
- output_img = Image.fromarray(pred * 255)
48
  buffer = io.BytesIO()
49
- output_img.save(buffer, format="PNG")
50
  return {"prediction": base64.b64encode(buffer.getvalue()).decode("utf-8")}
51
-
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  from torchvision import transforms
 
6
  import io
7
  import base64
8
 
9
+ # --- Basic UNet Components ---
10
+ class DoubleConv(nn.Module):
11
+ def __init__(self, in_channels, out_channels):
12
+ super().__init__()
13
+ self.double_conv = nn.Sequential(
14
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
15
+ nn.ReLU(inplace=True),
16
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
17
+ nn.ReLU(inplace=True)
18
+ )
19
+
20
+ def forward(self, x):
21
+ return self.double_conv(x)
22
+
23
+ class Down(nn.Module):
24
+ def __init__(self, in_channels, out_channels):
25
+ super().__init__()
26
+ self.maxpool_conv = nn.Sequential(
27
+ nn.MaxPool2d(2),
28
+ DoubleConv(in_channels, out_channels)
29
+ )
30
+
31
+ def forward(self, x):
32
+ return self.maxpool_conv(x)
33
+
34
+ class Up(nn.Module):
35
+ def __init__(self, in_channels, out_channels, bilinear=True):
36
+ super().__init__()
37
+ if bilinear:
38
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
39
+ else:
40
+ self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
41
+ self.conv = DoubleConv(in_channels, out_channels)
42
+
43
+ def forward(self, x1, x2):
44
+ x1 = self.up(x1)
45
+ diffY = x2.size()[2] - x1.size()[2]
46
+ diffX = x2.size()[3] - x1.size()[3]
47
+ x1 = nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
48
+ diffY // 2, diffY - diffY // 2])
49
+ x = torch.cat([x2, x1], dim=1)
50
+ return self.conv(x)
51
+
52
+ class OutConv(nn.Module):
53
+ def __init__(self, in_channels, out_channels):
54
+ super(OutConv, self).__init__()
55
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
56
 
57
+ def forward(self, x):
58
+ return self.conv(x)
59
 
60
+ # --- UNet Architecture ---
61
  class UNet(nn.Module):
62
+ def __init__(self, n_channels=3, n_classes=1, bilinear=True):
63
+ super().__init__()
64
+ self.n_channels = n_channels
65
+ self.n_classes = n_classes
66
+ self.bilinear = bilinear
67
+
68
+ self.inc = DoubleConv(n_channels, 64)
69
+ self.down1 = Down(64, 128)
70
+ self.down2 = Down(128, 256)
71
+ self.down3 = Down(256, 512)
72
+ factor = 2 if bilinear else 1
73
+ self.down4 = Down(512, 1024 // factor)
74
+ self.up1 = Up(1024, 512 // factor, bilinear)
75
+ self.up2 = Up(512, 256 // factor, bilinear)
76
+ self.up3 = Up(256, 128 // factor, bilinear)
77
+ self.up4 = Up(128, 64, bilinear)
78
+ self.outc = OutConv(64, n_classes)
79
 
80
  def forward(self, x):
81
+ x1 = self.inc(x)
82
+ x2 = self.down1(x1)
83
+ x3 = self.down2(x2)
84
+ x4 = self.down3(x3)
85
+ x5 = self.down4(x4)
86
+ x = self.up1(x5, x4)
87
+ x = self.up2(x, x3)
88
+ x = self.up3(x, x2)
89
+ x = self.up4(x, x1)
90
+ logits = self.outc(x)
91
+ return torch.sigmoid(logits)
92
 
93
+ # --- Endpoint Handler ---
94
  class EndpointHandler:
95
  def __init__(self, path=""):
96
  model_path = hf_hub_download(repo_id="whitney0507/unet-model", filename="UNet_Model.pth")
97
+ self.model = UNet()
98
  self.model.load_state_dict(torch.load(model_path, map_location="cpu"))
99
  self.model.eval()
100
  self.transform = transforms.Compose([
 
109
 
110
  with torch.no_grad():
111
  output = self.model(input_tensor)
112
+ mask = (output > 0.5).int().squeeze().byte().cpu().numpy()
113
 
114
+ result_img = Image.fromarray(mask * 255)
 
115
  buffer = io.BytesIO()
116
+ result_img.save(buffer, format="PNG")
117
  return {"prediction": base64.b64encode(buffer.getvalue()).decode("utf-8")}