Shengxiao0709 commited on
Commit
688cd9b
·
verified ·
1 Parent(s): 6359de8

Update inference_seg.py

Browse files
Files changed (1) hide show
  1. inference_seg.py +46 -4
inference_seg.py CHANGED
@@ -1,24 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import numpy as np
3
  from huggingface_hub import hf_hub_download
4
- from segmentation import SegmentationModule
5
 
6
  MODEL = None
7
  DEVICE = torch.device("cpu")
8
 
9
  def load_model(use_box=False):
10
  global MODEL, DEVICE
11
- MODEL = CountingModule(use_box=use_box)
 
 
 
12
 
13
  ckpt_path = hf_hub_download(
14
  repo_id="Shengxiao0709/cellsegmodel",
15
  filename="microscopy_matching_seg.pth",
16
  token=None,
17
- force_download=False
 
 
18
  )
19
- MODEL.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
 
 
 
 
20
  MODEL.eval()
 
21
  DEVICE = torch.device("cpu")
 
22
  return MODEL, DEVICE
23
 
24
 
 
1
+ # import torch
2
+ # import numpy as np
3
+ # from huggingface_hub import hf_hub_download
4
+ # from segmentation import SegmentationModule
5
+
6
+ # MODEL = None
7
+ # DEVICE = torch.device("cpu")
8
+
9
+ # def load_model(use_box=False):
10
+ # global MODEL, DEVICE
11
+ # MODEL = CountingModule(use_box=use_box)
12
+
13
+ # ckpt_path = hf_hub_download(
14
+ # repo_id="Shengxiao0709/cellsegmodel",
15
+ # filename="microscopy_matching_seg.pth",
16
+ # token=None,
17
+ # force_download=False
18
+ # )
19
+ # MODEL.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
20
+ # MODEL.eval()
21
+ # DEVICE = torch.device("cpu")
22
+ # return MODEL, DEVICE
23
+
24
+
25
+ # @torch.no_grad()
26
+ # def run(model, img_path, box=None, device="cpu"):
27
+ # output = model(img_path, box=box)
28
+ # mask = output["pred"]
29
+ # mask = (mask > 0).astype(np.uint8)
30
+ # return mask
31
+ import os
32
  import torch
33
  import numpy as np
34
  from huggingface_hub import hf_hub_download
35
+ from segmentation import SegmentationModule
36
 
37
  MODEL = None
38
  DEVICE = torch.device("cpu")
39
 
40
  def load_model(use_box=False):
41
  global MODEL, DEVICE
42
+
43
+ # === 优化1: 使用 /data 缓存模型,避免写入 .cache ===
44
+ cache_dir = "/data/cellseg_model_cache"
45
+ os.makedirs(cache_dir, exist_ok=True)
46
 
47
  ckpt_path = hf_hub_download(
48
  repo_id="Shengxiao0709/cellsegmodel",
49
  filename="microscopy_matching_seg.pth",
50
  token=None,
51
+ local_dir=cache_dir, # ✅ 下载到 /data
52
+ local_dir_use_symlinks=False, # ✅ 避免软链接问题
53
+ force_download=False # ✅ 已存在时不重复下载
54
  )
55
+
56
+ # === 优化2: 加载模型 ===
57
+ MODEL = SegmentationModule(use_box=use_box)
58
+ state_dict = torch.load(ckpt_path, map_location="cpu")
59
+ MODEL.load_state_dict(state_dict, strict=False)
60
  MODEL.eval()
61
+
62
  DEVICE = torch.device("cpu")
63
+ print(f"✅ Model loaded from {ckpt_path}")
64
  return MODEL, DEVICE
65
 
66