whitney0507 commited on
Commit
ce38d8d
·
verified ·
1 Parent(s): 1306dd3

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +47 -35
handler.py CHANGED
@@ -1,35 +1,47 @@
1
- # handler.py
2
- import torch
3
- from torchvision import transforms
4
- from PIL import Image
5
- from huggingface_hub import hf_hub_download
6
- import io
7
- import base64
8
- import json
9
-
10
- class EndpointHandler:
11
- def __init__(self, path=""):
12
- model_path = hf_hub_download(repo_id="whitney0507/unet-model", filename="UNet_Model.pth")
13
- self.model = torch.load(model_path, map_location="cpu")
14
- self.model.eval()
15
- self.transform = transforms.Compose([
16
- transforms.Resize((256, 256)),
17
- transforms.ToTensor()
18
- ])
19
-
20
- def __call__(self, data):
21
- image_bytes = base64.b64decode(data["inputs"])
22
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
23
- input_tensor = self.transform(image).unsqueeze(0)
24
-
25
- with torch.no_grad():
26
- output = self.model(input_tensor)
27
- pred = output.argmax(dim=1).squeeze().byte().cpu().numpy()
28
-
29
- # Convert output back to base64 image
30
- pred_img = Image.fromarray(pred * 255)
31
- buffer = io.BytesIO()
32
- pred_img.save(buffer, format="PNG")
33
- base64_output = base64.b64encode(buffer.getvalue()).decode("utf-8")
34
-
35
- return { "prediction": base64_output }
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # handler.py
2
+ # Save only the weights
3
+ torch.save(model.state_dict(), "UNet_Model.pth")
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchvision import transforms
7
+ from PIL import Image
8
+ from huggingface_hub import hf_hub_download
9
+ import io
10
+ import base64
11
+
12
+ # Define your UNet class here (shortened version for example)
13
+ class UNet(nn.Module):
14
+ def __init__(self): # Add your actual init params
15
+ super(UNet, self).__init__()
16
+ # Define layers...
17
+
18
+ def forward(self, x):
19
+ # Implement forward pass
20
+ return x
21
+
22
+ class EndpointHandler:
23
+ def __init__(self, path=""):
24
+ model_path = hf_hub_download(repo_id="whitney0507/unet-model", filename="UNet_Model.pth")
25
+ self.model = UNet() # Instantiate model
26
+ self.model.load_state_dict(torch.load(model_path, map_location="cpu"))
27
+ self.model.eval()
28
+ self.transform = transforms.Compose([
29
+ transforms.Resize((256, 256)),
30
+ transforms.ToTensor()
31
+ ])
32
+
33
+ def __call__(self, data):
34
+ image_bytes = base64.b64decode(data["inputs"])
35
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
36
+ input_tensor = self.transform(image).unsqueeze(0)
37
+
38
+ with torch.no_grad():
39
+ output = self.model(input_tensor)
40
+ pred = output.argmax(dim=1).squeeze().byte().cpu().numpy()
41
+
42
+ # Convert to base64
43
+ output_img = Image.fromarray(pred * 255)
44
+ buffer = io.BytesIO()
45
+ output_img.save(buffer, format="PNG")
46
+ return {"prediction": base64.b64encode(buffer.getvalue()).decode("utf-8")}
47
+