FinalVision / inference_seg.py
Shengxiao0709's picture
Update inference_seg.py
c364f5f verified
raw
history blame
794 Bytes
import torch
import numpy as np
from huggingface_hub import hf_hub_download
from segmentation import SegmentationModule
# Module-level inference state; load_model() rebinds both names via `global`.
DEVICE = torch.device("cpu")  # device handle handed back to callers
MODEL = None  # stays None until load_model() has been called
def load_model(use_box=False):
    """Build the segmentation model and load its checkpoint from the Hub.

    The checkpoint ``microscopy_matching_seg.pth`` is fetched from the
    ``Shengxiao0709/cellsegmodel`` repository (the local Hub cache is reused
    because ``force_download=False``), deserialized onto the CPU and loaded
    into a freshly constructed ``SegmentationModule``.

    Args:
        use_box: Forwarded unchanged to the ``SegmentationModule``
            constructor.

    Returns:
        A ``(model, device)`` tuple: the model in eval mode and the CPU
        ``torch.device``. Both are also stored in the module-level ``MODEL``
        and ``DEVICE`` globals.
    """
    global MODEL, DEVICE

    DEVICE = torch.device("cpu")

    # The original code instantiated `CountingModule`, a name this file never
    # imports; the segmentation checkpoint belongs to `SegmentationModule`.
    MODEL = SegmentationModule(use_box=use_box)

    ckpt_path = hf_hub_download(
        repo_id="Shengxiao0709/cellsegmodel",
        filename="microscopy_matching_seg.pth",
        token=None,
        force_download=False,
    )

    state_dict = torch.load(ckpt_path, map_location="cpu")
    # strict=False: keys missing from / unexpected in the checkpoint are
    # tolerated instead of raising, so a partial match loads silently.
    MODEL.load_state_dict(state_dict, strict=False)
    MODEL.eval()

    return MODEL, DEVICE
@torch.no_grad()
def run(model, img_path, box=None, device="cpu"):
    """Run the model on one image and return a binary mask.

    Args:
        model: Callable invoked as ``model(img_path, box=box)``; it must
            return a mapping with a ``"pred"`` entry.
        img_path: Passed to ``model`` as its first positional argument.
        box: Passed to ``model`` as the ``box`` keyword argument.
        device: Accepted for interface compatibility; not used by this
            function.

    Returns:
        A ``numpy.uint8`` array holding 1 where ``output["pred"]`` is
        strictly positive and 0 elsewhere.
    """
    output = model(img_path, box=box)
    mask = output["pred"]

    if isinstance(mask, torch.Tensor):
        # Tensors have no `.astype`; bring the prediction to host memory as a
        # NumPy array before thresholding.
        mask = mask.detach().cpu().numpy()

    return (mask > 0).astype(np.uint8)