John Ho commited on
Commit
18aa4a5
·
1 Parent(s): 2e155e5

set device to cuda

Browse files
Files changed (1) hide show
  1. app.py +10 -12
app.py CHANGED
@@ -31,8 +31,17 @@ DTYPE = (
31
  if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
32
  else torch.float16
33
  )
34
- DEVICE = "auto"
35
  logger.info(f"Device: {DEVICE}, dtype: {DTYPE}")
 
 
 
 
 
 
 
 
 
36
 
37
 
38
  def apply_mask_overlay(base_image, mask_data, object_ids=None, opacity=0.5):
@@ -105,17 +114,6 @@ def apply_mask_overlay(base_image, mask_data, object_ids=None, opacity=0.5):
105
  return Image.alpha_composite(base_image, composite_layer).convert("RGB")
106
 
107
 
108
- logger.info("Loading Models and Processors...")
109
- try:
110
- VID_MODEL = Sam3VideoModel.from_pretrained("facebook/sam3").to(DEVICE, dtype=DTYPE)
111
- VID_PROCESSOR = Sam3VideoProcessor.from_pretrained("facebook/sam3")
112
- logger.success("Models and Processors Loaded!")
113
- except Exception as e:
114
- logger.error(f"❌ CRITICAL ERROR LOADING VIDEO MODELS: {e}")
115
- VID_MODEL = None
116
- VID_PROCESSOR = None
117
-
118
-
119
  # Our Inference Function
120
  @spaces.GPU(duration=120)
121
  def video_inference(input_video, prompt: str):
 
31
  if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
32
  else torch.float16
33
  )
34
+ DEVICE = "cuda"
35
  logger.info(f"Device: {DEVICE}, dtype: {DTYPE}")
36
+ logger.info("Loading Models and Processors...")
37
+ try:
38
+ VID_MODEL = Sam3VideoModel.from_pretrained("facebook/sam3").to(DEVICE, dtype=DTYPE)
39
+ VID_PROCESSOR = Sam3VideoProcessor.from_pretrained("facebook/sam3")
40
+ logger.success("Models and Processors Loaded!")
41
+ except Exception as e:
42
+ logger.error(f"❌ CRITICAL ERROR LOADING VIDEO MODELS: {e}")
43
+ VID_MODEL = None
44
+ VID_PROCESSOR = None
45
 
46
 
47
  def apply_mask_overlay(base_image, mask_data, object_ids=None, opacity=0.5):
 
114
  return Image.alpha_composite(base_image, composite_layer).convert("RGB")
115
 
116
 
 
 
 
 
 
 
 
 
 
 
 
117
  # Our Inference Function
118
  @spaces.GPU(duration=120)
119
  def video_inference(input_video, prompt: str):