Shengxiao0709 commited on
Commit
a5f49f2
·
verified ·
1 Parent(s): 9ae8560

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -13
app.py CHANGED
@@ -25,21 +25,21 @@ import spaces
25
  # def make_example(path):
26
  # return [path, []]
27
 
 
28
  MODEL = None
29
  DEVICE = torch.device("cpu")
30
  CUDA_READY = False
31
 
32
- # --------- 1) 先在 CPU 上加载权重(禁止这里触发 CUDA)---------
33
- def load_model_cpu(checkpoint_path):
34
- # 请确保 inference.load_model 不里面 .to("cuda")
35
- model, _ = load_model(checkpoint_path) # 如果你里面会 .to(device),请改掉
36
- model = model.to("cpu") # 强制在 CPU
37
- model.eval()
38
- return model
39
 
40
- MODEL = load_model_cpu("medsam_vit_b.pth")
41
 
42
- # --------- 2) 仅在 GPU 函数中迁移到 CUDA(不要在主进程里调用任何 CUDA)---------
43
  @spaces.GPU
44
  def prepare_cuda():
45
  global MODEL, DEVICE, CUDA_READY
@@ -47,8 +47,7 @@ def prepare_cuda():
47
  MODEL.to("cuda")
48
  DEVICE = torch.device("cuda")
49
  CUDA_READY = True
50
- # 可选:做一次极小张量 warmup,仍在 GPU 函数里,安全
51
- _ = torch.zeros(1, device=DEVICE)
52
 
53
  def parse_first_bbox(bboxes):
54
  """
@@ -99,8 +98,8 @@ def segment(annot_value):
99
  box_np = np.array([[xmin, ymin, xmax, ymax]], dtype=float)
100
  box_1024 = box_np / np.array([W, H, W, H]) * 1024.0
101
 
102
- embedding = get_embedding(model, img_np, device)
103
- mask = run(model, embedding, box_1024, H, W) # (H, W) 0/1
104
 
105
  # 黑白 mask(白=前景)
106
  mask_rgb = np.stack([mask * 255] * 3, axis=-1).astype(np.uint8)
 
25
  # def make_example(path):
26
  # return [path, []]
27
 
28
+ # --------- 全局状态 ---------
29
  MODEL = None
30
  DEVICE = torch.device("cpu")
31
  CUDA_READY = False
32
 
33
+ def load_model_cpu(checkpoint_path: str):
34
+ global MODEL, DEVICE
35
+ # 要求 inference.load_model 不内部 .to("cuda")
36
+ MODEL, _ = load_model(checkpoint_path) # 或者直接返回 model
37
+ MODEL = MODEL.to("cpu")
38
+ MODEL.eval()
39
+ DEVICE = torch.device("cpu")
40
 
41
+ load_model_cpu("medsam_vit_b.pth")
42
 
 
43
  @spaces.GPU
44
  def prepare_cuda():
45
  global MODEL, DEVICE, CUDA_READY
 
47
  MODEL.to("cuda")
48
  DEVICE = torch.device("cuda")
49
  CUDA_READY = True
50
+ _ = torch.zeros(1, device=DEVICE) # 可选warmup
 
51
 
52
  def parse_first_bbox(bboxes):
53
  """
 
98
  box_np = np.array([[xmin, ymin, xmax, ymax]], dtype=float)
99
  box_1024 = box_np / np.array([W, H, W, H]) * 1024.0
100
 
101
+ embedding = get_embedding(MODEL, img_np, DEVICE)
102
+ mask = run(MODEL, embedding, box_1024, H, W) # (H, W) 0/1
103
 
104
  # 黑白 mask(白=前景)
105
  mask_rgb = np.stack([mask * 255] * 3, axis=-1).astype(np.uint8)