chris-propeller commited on
Commit
bda3ba8
·
1 Parent(s): ff6286a

runtime error

Browse files
Files changed (1) hide show
  1. app.py +14 -3
app.py CHANGED
@@ -8,18 +8,21 @@ import io
8
  import cv2
9
  from typing import Dict, Any, List, Optional
10
  from transformers import Sam3Model, Sam3Processor
 
 
11
 
12
  class SAM3Handler:
13
  """SAM3 handler for both UI and API access"""
14
 
15
  def __init__(self):
16
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
17
- print(f"Loading SAM3 model on device: {self.device}")
 
18
 
19
  # Load SAM3 model and processor
20
  self.model = Sam3Model.from_pretrained(
21
  "facebook/sam3",
22
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
23
  ).to(self.device)
24
 
25
  self.processor = Sam3Processor.from_pretrained("facebook/sam3")
@@ -51,7 +54,15 @@ class SAM3Handler:
51
  images=image,
52
  text=text_prompt,
53
  return_tensors="pt"
54
- ).to(self.device)
 
 
 
 
 
 
 
 
55
 
56
  with torch.no_grad():
57
  outputs = self.model(**inputs, multimask_output=True)
 
8
  import cv2
9
  from typing import Dict, Any, List, Optional
10
  from transformers import Sam3Model, Sam3Processor
11
+ import warnings
12
+ warnings.filterwarnings("ignore")
13
 
14
  class SAM3Handler:
15
  """SAM3 handler for both UI and API access"""
16
 
17
  def __init__(self):
18
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ self.dtype = torch.float16 if torch.cuda.is_available() else torch.float32
20
+ print(f"Loading SAM3 model on device: {self.device} with dtype: {self.dtype}")
21
 
22
  # Load SAM3 model and processor
23
  self.model = Sam3Model.from_pretrained(
24
  "facebook/sam3",
25
+ torch_dtype=self.dtype
26
  ).to(self.device)
27
 
28
  self.processor = Sam3Processor.from_pretrained("facebook/sam3")
 
54
  images=image,
55
  text=text_prompt,
56
  return_tensors="pt"
57
+ )
58
+
59
+ # Move inputs to device
60
+ inputs = inputs.to(self.device)
61
+
62
+ # Convert dtype to match model (following working space pattern)
63
+ for key in inputs:
64
+ if inputs[key].dtype == torch.float32:
65
+ inputs[key] = inputs[key].to(self.model.dtype)
66
 
67
  with torch.no_grad():
68
  outputs = self.model(**inputs, multimask_output=True)