CodeJackR commited on
Commit
0e71822
·
1 Parent(s): 233c56f

Manage image resizing

Browse files
Files changed (1) hide show
  1. handler.py +7 -7
handler.py CHANGED
@@ -5,7 +5,7 @@ import base64
5
  import numpy as np
6
  from PIL import Image
7
  import torch
8
- from transformers import SamModel, SamProcessor
9
  from typing import Dict, List, Any
10
  import torch.nn.functional as F
11
 
@@ -21,13 +21,13 @@ class EndpointHandler():
21
  try:
22
  # Load the model and processor from the local path
23
  self.model = SamModel.from_pretrained(path).to(device)
24
- self.processor = SamProcessor.from_pretrained(path)
25
  except Exception as e:
26
  # Fallback to loading from a known SAM model if local loading fails
27
  print("Failed to load from local path: {}".format(e))
28
  print("Attempting to load from facebook/sam-vit-base")
29
  self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
30
- self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
31
 
32
  def __call__(self, data):
33
  """
@@ -52,11 +52,11 @@ class EndpointHandler():
52
  raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")
53
 
54
  # 2. Prepare prompts and process the image
55
- height, width = img.size[1], img.size[0]
56
- input_points = [[[width // 2, height // 2]]]
57
- input_labels = [[1]]
58
 
59
- inputs = self.processor(img, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
60
 
61
  # 3. Generate masks
62
  with torch.no_grad():
 
5
  import numpy as np
6
  from PIL import Image
7
  import torch
8
+ from transformers import SamModel, SamImageProcessor
9
  from typing import Dict, List, Any
10
  import torch.nn.functional as F
11
 
 
21
  try:
22
  # Load the model and processor from the local path
23
  self.model = SamModel.from_pretrained(path).to(device)
24
+ self.processor = SamImageProcessor.from_pretrained(path, do_resize=False)
25
  except Exception as e:
26
  # Fallback to loading from a known SAM model if local loading fails
27
  print("Failed to load from local path: {}".format(e))
28
  print("Attempting to load from facebook/sam-vit-base")
29
  self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
30
+ self.processor = SamImageProcessor.from_pretrained("facebook/sam-vit-base", do_resize=False)
31
 
32
  def __call__(self, data):
33
  """
 
52
  raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")
53
 
54
  # 2. Prepare prompts and process the image
55
+ # height, width = img.size[1], img.size[0]
56
+ # input_points = [[[width // 2, height // 2]]]
57
+ # input_labels = [[1]]
58
 
59
+ inputs = self.processor(img, return_tensors="pt").to(device)
60
 
61
  # 3. Generate masks
62
  with torch.no_grad():