CodeJackR committed
Commit · 0e71822
Parent(s): 233c56f
Manage image resizing
handler.py  CHANGED  (+7, -7)
@@ -5,7 +5,7 @@ import base64
 import numpy as np
 from PIL import Image
 import torch
-from transformers import SamModel,
+from transformers import SamModel, SamImageProcessor
 from typing import Dict, List, Any
 import torch.nn.functional as F
 
@@ -21,13 +21,13 @@ class EndpointHandler():
         try:
             # Load the model and processor from the local path
             self.model = SamModel.from_pretrained(path).to(device)
-            self.processor =
+            self.processor = SamImageProcessor.from_pretrained(path, do_resize=False)
         except Exception as e:
             # Fallback to loading from a known SAM model if local loading fails
             print("Failed to load from local path: {}".format(e))
             print("Attempting to load from facebook/sam-vit-base")
             self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
-            self.processor =
+            self.processor = SamImageProcessor.from_pretrained("facebook/sam-vit-base", do_resize=False)
 
     def __call__(self, data):
         """

@@ -52,11 +52,11 @@ class EndpointHandler():
             raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")
 
         # 2. Prepare prompts and process the image
-        height, width = img.size[1], img.size[0]
-        input_points = [[[width // 2, height // 2]]]
-        input_labels = [[1]]
+        # height, width = img.size[1], img.size[0]
+        # input_points = [[[width // 2, height // 2]]]
+        # input_labels = [[1]]
 
-        inputs = self.processor(img,
+        inputs = self.processor(img, return_tensors="pt").to(device)
 
         # 3. Generate masks
         with torch.no_grad():