CodeJackR commited on
Commit
ea3669d
·
1 Parent(s): b0c81ad

add point for masking

Browse files
Files changed (1) hide show
  1. handler.py +4 -4
handler.py CHANGED
@@ -52,11 +52,11 @@ class EndpointHandler():
52
  raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")
53
 
54
  # 2. Prepare prompts and process the image
55
- # height, width = img.size[1], img.size[0]
56
- # input_points = [[[width // 2, height // 2]]]
57
- # input_labels = [[1]]
58
 
59
- inputs = self.processor(img, return_tensors="pt").to(device)
60
 
61
  # 3. Generate masks
62
  with torch.no_grad():
 
52
  raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")
53
 
54
  # 2. Prepare prompts and process the image
55
+ height, width = img.size[1], img.size[0]
56
+ input_points = [[[width // 2, height // 2]]]
57
+ input_labels = [[1]]
58
 
59
+ inputs = self.processor(img, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
60
 
61
  # 3. Generate masks
62
  with torch.no_grad():