jerpelhan committed on
Commit
852095f
·
1 Parent(s): 3e76066

Remove create model at import time

Browse files
Files changed (1) hide show
  1. demo_gradio.py +46 -17
demo_gradio.py CHANGED
@@ -14,32 +14,61 @@ import numpy as np
14
  import colorsys
15
 
16
 
17
- # Load model (once, to avoid reloading)
18
- @spaces.GPU
19
- def load_model():
20
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
- args = get_argparser().parse_args()
22
- args.zero_shot = True
23
- model = DataParallel(build_model(args).to(device))
24
- WEIGHTS_PATH = hf_hub_download(
25
- repo_id="jerpelhan/geco2-assets",
26
- filename="weights/CNTQG_multitrain_ca44.pth",
27
- repo_type="dataset",
28
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- model.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu", weights_only=True)["model"], strict=False)
31
- model.eval()
32
- return model, device
 
33
 
 
 
34
 
35
- model, device = load_model()
 
 
36
 
37
 
38
  # **Function to Process Image Once**
39
  @spaces.GPU
40
  def process_image_once(inputs, enable_mask):
41
  model.module.return_masks = enable_mask
42
-
 
43
  image = inputs["image"]
44
  drawn_boxes = inputs["points"]
45
  image_tensor = torch.tensor(image).to(device)
 
14
  import colorsys
15
 
16
 
17
+ _MODEL = None
18
+ _ARGS = None
19
+ _WEIGHTS_PATH = None
20
+
21
+ def _get_args():
22
+ global _ARGS
23
+ if _ARGS is None:
24
+ args = get_argparser().parse_args()
25
+ args.zero_shot = True
26
+ _ARGS = args
27
+ return _ARGS
28
+
29
+ def _get_weights_path():
30
+ global _WEIGHTS_PATH
31
+ if _WEIGHTS_PATH is None:
32
+ _WEIGHTS_PATH = hf_hub_download(
33
+ repo_id="jerpelhan/geco2-assets",
34
+ filename="weights/CNTQG_multitrain_ca44.pth",
35
+ repo_type="dataset",
36
+ )
37
+ return _WEIGHTS_PATH
38
+
39
+ def get_model_on_device(device: torch.device):
40
+ """
41
+ Lazily build and load model, then move to the requested device.
42
+ IMPORTANT: model is constructed/loaded without initializing CUDA in the main process.
43
+ This function will be called from inside the @spaces.GPU worker.
44
+ """
45
+ global _MODEL
46
+ if _MODEL is None:
47
+ args = _get_args()
48
+
49
+ # Build on CPU first to avoid CUDA init in the wrong process
50
+ model = build_model(args)
51
+ model = DataParallel(model) # wrap before loading; matches your original
52
 
53
+ weights_path = _get_weights_path()
54
+ ckpt = torch.load(weights_path, map_location="cpu", weights_only=True)
55
+ state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
56
+ model.load_state_dict(state, strict=False)
57
 
58
+ model.eval()
59
+ _MODEL = model
60
 
61
+ # Ensure correct device for this invocation
62
+ _MODEL = _MODEL.to(device)
63
+ return _MODEL
64
 
65
 
66
  # **Function to Process Image Once**
67
  @spaces.GPU
68
  def process_image_once(inputs, enable_mask):
69
  model.module.return_masks = enable_mask
70
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
+ model = get_model_on_device(device)
72
  image = inputs["image"]
73
  drawn_boxes = inputs["points"]
74
  image_tensor = torch.tensor(image).to(device)