Shengxiao0709 commited on
Commit
27c851a
·
verified ·
1 Parent(s): a4c9958

Update inference_seg.py

Browse files
Files changed (1) hide show
  1. inference_seg.py +30 -0
inference_seg.py CHANGED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from huggingface_hub import hf_hub_download
4
+ from inference_script import CountingModule
5
+
6
+ MODEL = None
7
+ DEVICE = torch.device("cpu")
8
+
9
+ def load_model(use_box=False):
10
+ global MODEL, DEVICE
11
+ MODEL = CountingModule(use_box=use_box)
12
+
13
+ ckpt_path = hf_hub_download(
14
+ repo_id="Shengxiao0709/cellsegmodel",
15
+ filename="microscopy_matching_seg.pth",
16
+ token=None,
17
+ force_download=False
18
+ )
19
+ MODEL.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
20
+ MODEL.eval()
21
+ DEVICE = torch.device("cpu")
22
+ return MODEL, DEVICE
23
+
24
+
25
+ @torch.no_grad()
26
+ def run(model, img_path, box=None, device="cpu"):
27
+ output = model(img_path, box=box)
28
+ mask = output["pred"]
29
+ mask = (mask > 0).astype(np.uint8)
30
+ return mask