chris-propeller commited on
Commit
c5fb163
·
1 Parent(s): 4f603ce
Files changed (1) hide show
  1. app.py +10 -21
app.py CHANGED
@@ -11,32 +11,21 @@ from transformers import Sam3Model, Sam3Processor
11
  import warnings
12
  warnings.filterwarnings("ignore")
13
 
14
- # Global variables for lazy initialization
15
- _model = None
16
- _processor = None
17
- _device = None
18
-
19
- def get_model_and_processor():
20
- """Lazy initialization of model and processor"""
21
- global _model, _processor, _device
22
- if _model is None:
23
- _device = "cuda" if torch.cuda.is_available() else "cpu"
24
- _model = Sam3Model.from_pretrained(
25
- "facebook/sam3",
26
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
27
- ).to(_device)
28
- _processor = Sam3Processor.from_pretrained("facebook/sam3")
29
- print(f"Model loaded on device: {_device}")
30
- return _model, _processor, _device
31
-
32
  @spaces.GPU
33
  def sam3_inference(image, text_prompt, confidence_threshold=0.5):
34
  """
35
- Standalone GPU function with lazy model initialization for Spaces Stateless GPU
 
36
  """
37
  try:
38
- # Initialize model inside GPU function (required for Stateless GPU)
39
- model, processor, device = get_model_and_processor()
 
 
 
 
 
 
40
 
41
  # Handle base64 input (for API)
42
  if isinstance(image, str):
 
11
  import warnings
12
  warnings.filterwarnings("ignore")
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  @spaces.GPU
15
  def sam3_inference(image, text_prompt, confidence_threshold=0.5):
16
  """
17
+ Standalone GPU function with model initialization for Spaces Stateless GPU
18
+ All CUDA operations must happen inside this decorated function
19
  """
20
  try:
21
+ # Initialize model and processor inside GPU function (required for Stateless GPU)
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+ model = Sam3Model.from_pretrained(
24
+ "facebook/sam3",
25
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
26
+ ).to(device)
27
+ processor = Sam3Processor.from_pretrained("facebook/sam3")
28
+ print(f"Model loaded on device: {device}")
29
 
30
  # Handle base64 input (for API)
31
  if isinstance(image, str):