import torch
import numpy as np
from huggingface_hub import hf_hub_download
from segmentation import SegmentationModule

# Module-level handle to the most recently loaded model and the device it was
# placed on. Populated by load_model(); MODEL stays None until a load succeeds.
MODEL = None
DEVICE = torch.device("cpu")


def load_model(use_box=False):
    """Build the segmentation model, load its checkpoint, and place it on a device.

    The checkpoint ``microscopy_matching_seg.pth`` is fetched from the Hugging
    Face Hub repo ``phoebe777777/111`` with ``force_download=False``, so an
    already-cached copy is reused. The model is switched to eval mode and
    moved to CUDA when ``torch.cuda.is_available()``, otherwise to the CPU.

    Args:
        use_box: forwarded to ``SegmentationModule`` as its ``use_box`` flag.

    Returns:
        ``(model, device)``: the loaded model and the ``torch.device`` it was
        moved to. The same pair is also stored in the module globals
        ``MODEL`` and ``DEVICE``.
    """
    global MODEL, DEVICE

    # Work on locals and publish to the globals only at the very end, so a
    # failed download or checkpoint load does not leave a weightless model
    # in MODEL.
    model = SegmentationModule(use_box=use_box)
    ckpt_path = hf_hub_download(
        repo_id="phoebe777777/111",
        filename="microscopy_matching_seg.pth",
        token=None,
        force_download=False,
    )

    # NOTE(security): torch.load unpickles the downloaded file. On torch
    # versions before 2.6 the default weights_only=False lets a tampered
    # checkpoint execute arbitrary code; consider weights_only=True once the
    # checkpoint is confirmed to contain only tensors.
    state_dict = torch.load(ckpt_path, map_location="cpu")
    # strict=False silently ignores missing/unexpected keys. Presumably the
    # checkpoint covers only part of the module (TODO confirm); note that a
    # mismatched checkpoint would also load without any error.
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model.move_to_device(device)
    print("✅ Model moved to CUDA" if use_cuda else "✅ Model on CPU")

    MODEL, DEVICE = model, device
    return MODEL, DEVICE


@torch.no_grad()
def run(model, img_path, box=None, device="cpu"):
    """Run ``model`` on a single input and return its raw output.

    Args:
        model: a model as returned by ``load_model``. It must provide
            ``move_to_device``, ``eval`` and a writable ``use_box`` attribute.
        img_path: passed straight through as the first argument of
            ``model(...)``. This is presumably an image file path, going by
            the name; TODO confirm against SegmentationModule.
        box: optional box prompt, forwarded to the model. ``model.use_box``
            is set to ``box is not None``, and the flag persists on the model
            after this call.
        device: device the model is moved to before inference. With the
            default ``"cpu"``, a model that load_model placed on CUDA is
            presumably moved back to the CPU by ``move_to_device``. Pass
            ``DEVICE`` to keep it where it was loaded.

    Returns:
        Whatever ``model(img_path, box=box)`` returns. No post-processing
        or thresholding is applied here.
    """
    print("DEVICE:", device)
    model.move_to_device(device)
    model.eval()
    # The prompt mode follows the presence of a box.
    model.use_box = box is not None
    # Gradient tracking is already disabled by the @torch.no_grad() decorator.
    return model(img_path, box=box)


# NOTE: an earlier, commented-out revision of this module was removed as dead
# code. It differed from the code above as follows:
#   - it downloaded the checkpoint from repo "Shengxiao0709/cellsegmodel" into
#     /data/cellseg_model_cache (local_dir, no symlinks) instead of the default
#     Hub cache;
#   - it always ran on the CPU;
#   - its run() binarised output["pred"] > 0 into a uint8 mask.