whitney0507 commited on
Commit
80d6002
·
verified ·
1 Parent(s): 1a9a99a

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +11 -7
handler.py CHANGED
@@ -5,6 +5,7 @@ from PIL import Image
5
  from huggingface_hub import hf_hub_download
6
  import io
7
  import base64
 
8
 
9
  # --- Basic UNet Components ---
10
  class DoubleConv(nn.Module):
@@ -51,13 +52,13 @@ class Up(nn.Module):
51
 
52
  class OutConv(nn.Module):
53
  def __init__(self, in_channels, out_channels):
54
- super(OutConv, self).__init__()
55
  self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
56
 
57
  def forward(self, x):
58
  return self.conv(x)
59
 
60
- # --- UNet Architecture ---
61
  class UNet(nn.Module):
62
  def __init__(self, n_channels=3, n_classes=1, bilinear=True):
63
  super().__init__()
@@ -90,12 +91,13 @@ class UNet(nn.Module):
90
  logits = self.outc(x)
91
  return torch.sigmoid(logits)
92
 
93
- # --- Endpoint Handler ---
94
  class EndpointHandler:
95
  def __init__(self, path=""):
96
  model_path = hf_hub_download(repo_id="whitney0507/unet-model", filename="UNet_Model.pth")
97
  self.model = UNet()
98
- self.model.load_state_dict(torch.load(model_path, map_location="cpu"))
 
99
  self.model.eval()
100
  self.transform = transforms.Compose([
101
  transforms.Resize((256, 256)),
@@ -109,9 +111,11 @@ class EndpointHandler:
109
 
110
  with torch.no_grad():
111
  output = self.model(input_tensor)
112
- mask = (output > 0.5).int().squeeze().byte().cpu().numpy()
113
 
114
- result_img = Image.fromarray(mask * 255)
 
115
  buffer = io.BytesIO()
116
  result_img.save(buffer, format="PNG")
117
- return {"prediction": base64.b64encode(buffer.getvalue()).decode("utf-8")}
 
 
5
  from huggingface_hub import hf_hub_download
6
  import io
7
  import base64
8
+ import numpy as np
9
 
10
  # --- Basic UNet Components ---
11
  class DoubleConv(nn.Module):
 
52
 
53
  class OutConv(nn.Module):
54
  def __init__(self, in_channels, out_channels):
55
+ super().__init__()
56
  self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
57
 
58
  def forward(self, x):
59
  return self.conv(x)
60
 
61
+ # --- Full UNet ---
62
  class UNet(nn.Module):
63
  def __init__(self, n_channels=3, n_classes=1, bilinear=True):
64
  super().__init__()
 
91
  logits = self.outc(x)
92
  return torch.sigmoid(logits)
93
 
94
+ # --- EndpointHandler for Hugging Face Inference Endpoint ---
95
  class EndpointHandler:
96
  def __init__(self, path=""):
97
  model_path = hf_hub_download(repo_id="whitney0507/unet-model", filename="UNet_Model.pth")
98
  self.model = UNet()
99
+ state_dict = torch.load(model_path, map_location=torch.device("cpu"))
100
+ self.model.load_state_dict(state_dict)
101
  self.model.eval()
102
  self.transform = transforms.Compose([
103
  transforms.Resize((256, 256)),
 
111
 
112
  with torch.no_grad():
113
  output = self.model(input_tensor)
114
+ mask = (output > 0.5).int().squeeze().cpu().numpy()
115
 
116
+ # Ensure mask is in uint8 format for image encoding
117
+ result_img = Image.fromarray((mask * 255).astype(np.uint8))
118
  buffer = io.BytesIO()
119
  result_img.save(buffer, format="PNG")
120
+ encoded_output = base64.b64encode(buffer.getvalue()).decode("utf-8")
121
+ return {"prediction": encoded_output}