Shengxiao0709 commited on
Commit
f409b59
·
verified ·
1 Parent(s): 9cd4782

Update inference_seg.py

Browse files
Files changed (1) hide show
  1. inference_seg.py +47 -47
inference_seg.py CHANGED
@@ -1,66 +1,24 @@
1
- # import torch
2
- # import numpy as np
3
- # from huggingface_hub import hf_hub_download
4
- # from segmentation import SegmentationModule
5
-
6
- # MODEL = None
7
- # DEVICE = torch.device("cpu")
8
-
9
- # def load_model(use_box=False):
10
- # global MODEL, DEVICE
11
- # MODEL = CountingModule(use_box=use_box)
12
-
13
- # ckpt_path = hf_hub_download(
14
- # repo_id="Shengxiao0709/cellsegmodel",
15
- # filename="microscopy_matching_seg.pth",
16
- # token=None,
17
- # force_download=False
18
- # )
19
- # MODEL.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
20
- # MODEL.eval()
21
- # DEVICE = torch.device("cpu")
22
- # return MODEL, DEVICE
23
-
24
-
25
- # @torch.no_grad()
26
- # def run(model, img_path, box=None, device="cpu"):
27
- # output = model(img_path, box=box)
28
- # mask = output["pred"]
29
- # mask = (mask > 0).astype(np.uint8)
30
- # return mask
31
- import os
32
  import torch
33
  import numpy as np
34
  from huggingface_hub import hf_hub_download
35
- from segmentation import SegmentationModule
36
 
37
  MODEL = None
38
  DEVICE = torch.device("cpu")
39
 
40
  def load_model(use_box=False):
41
  global MODEL, DEVICE
42
-
43
- # === 优化1: 使用 /data 缓存模型,避免写入 .cache ===
44
- cache_dir = "/data/cellseg_model_cache"
45
- os.makedirs(cache_dir, exist_ok=True)
46
 
47
  ckpt_path = hf_hub_download(
48
  repo_id="Shengxiao0709/cellsegmodel",
49
  filename="microscopy_matching_seg.pth",
50
  token=None,
51
- local_dir=cache_dir, # ✅ 下载到 /data
52
- local_dir_use_symlinks=False, # ✅ 避免软链接问题
53
- force_download=False # ✅ 已存在时不重复下载
54
  )
55
-
56
- # === 优化2: 加载模型 ===
57
- MODEL = SegmentationModule(use_box=use_box)
58
- state_dict = torch.load(ckpt_path, map_location="cpu")
59
- MODEL.load_state_dict(state_dict, strict=False)
60
  MODEL.eval()
61
-
62
  DEVICE = torch.device("cpu")
63
- print(f"✅ Model loaded from {ckpt_path}")
64
  return MODEL, DEVICE
65
 
66
 
@@ -69,4 +27,46 @@ def run(model, img_path, box=None, device="cpu"):
69
  output = model(img_path, box=box)
70
  mask = output["pred"]
71
  mask = (mask > 0).astype(np.uint8)
72
- return mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import numpy as np
3
  from huggingface_hub import hf_hub_download
4
+ from segmentation import SegmentationModule
5
 
6
  MODEL = None
7
  DEVICE = torch.device("cpu")
8
 
9
  def load_model(use_box=False):
10
  global MODEL, DEVICE
11
+ MODEL = CountingModule(use_box=use_box)
 
 
 
12
 
13
  ckpt_path = hf_hub_download(
14
  repo_id="Shengxiao0709/cellsegmodel",
15
  filename="microscopy_matching_seg.pth",
16
  token=None,
17
+ force_download=False
 
 
18
  )
19
+ MODEL.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
 
 
 
 
20
  MODEL.eval()
 
21
  DEVICE = torch.device("cpu")
 
22
  return MODEL, DEVICE
23
 
24
 
 
27
  output = model(img_path, box=box)
28
  mask = output["pred"]
29
  mask = (mask > 0).astype(np.uint8)
30
+ return mask
31
+ # import os
32
+ # import torch
33
+ # import numpy as np
34
+ # from huggingface_hub import hf_hub_download
35
+ # from segmentation import SegmentationModule
36
+
37
+ # MODEL = None
38
+ # DEVICE = torch.device("cpu")
39
+
40
+ # def load_model(use_box=False):
41
+ # global MODEL, DEVICE
42
+
43
+ # # === 优化1: 使用 /data 缓存模型,避免写入 .cache ===
44
+ # cache_dir = "/data/cellseg_model_cache"
45
+ # os.makedirs(cache_dir, exist_ok=True)
46
+
47
+ # ckpt_path = hf_hub_download(
48
+ # repo_id="Shengxiao0709/cellsegmodel",
49
+ # filename="microscopy_matching_seg.pth",
50
+ # token=None,
51
+ # local_dir=cache_dir, # ✅ 下载到 /data
52
+ # local_dir_use_symlinks=False, # ✅ 避免软链接问题
53
+ # force_download=False # ✅ 已存在时不重复下载
54
+ # )
55
+
56
+ # # === 优化2: 加载模型 ===
57
+ # MODEL = SegmentationModule(use_box=use_box)
58
+ # state_dict = torch.load(ckpt_path, map_location="cpu")
59
+ # MODEL.load_state_dict(state_dict, strict=False)
60
+ # MODEL.eval()
61
+
62
+ # DEVICE = torch.device("cpu")
63
+ # print(f"✅ Model loaded from {ckpt_path}")
64
+ # return MODEL, DEVICE
65
+
66
+
67
+ # @torch.no_grad()
68
+ # def run(model, img_path, box=None, device="cpu"):
69
+ # output = model(img_path, box=box)
70
+ # mask = output["pred"]
71
+ # mask = (mask > 0).astype(np.uint8)
72
+ # return mask