Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| from huggingface_hub import hf_hub_download | |
| from segmentation import SegmentationModule | |
# Most recently loaded model; stays None until load_model() assigns it.
MODEL = None
# Device paired with MODEL; starts as CPU and is reassigned by load_model().
DEVICE = torch.device("cpu")
def load_model(use_box=False):
    """Build the segmentation model, load its checkpoint, and pick a device.

    A fresh ``SegmentationModule`` is created, the checkpoint
    ``microscopy_matching_seg.pth`` is fetched from the Hugging Face repo
    ``phoebe777777/111`` (``force_download=False``, so an already-downloaded
    copy is reused), and the weights are loaded non-strictly. The model is
    switched to eval mode and moved to CUDA when available, otherwise CPU.

    The module-level ``MODEL`` and ``DEVICE`` globals are updated only after
    every step has succeeded. If the download or the checkpoint load raises,
    the globals keep their previous values instead of exposing a model with
    uninitialised weights.

    Args:
        use_box: Forwarded unchanged to ``SegmentationModule(use_box=...)``.

    Returns:
        A ``(model, device)`` tuple; the same objects are stored in the
        ``MODEL`` and ``DEVICE`` globals.
    """
    global MODEL, DEVICE

    # Work on locals so a failure below cannot leave a half-built model
    # visible through the module globals.
    model = SegmentationModule(use_box=use_box)
    ckpt_path = hf_hub_download(
        repo_id="phoebe777777/111",
        filename="microscopy_matching_seg.pth",
        token=None,
        force_download=False,
    )

    # NOTE(review): torch.load unpickles the downloaded file. If the
    # checkpoint is a plain state dict, consider weights_only=True so no
    # arbitrary objects are deserialised — confirm against the checkpoint.
    state_dict = torch.load(ckpt_path, map_location="cpu")
    # strict=False tolerates missing/unexpected keys in the checkpoint.
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    if torch.cuda.is_available():
        device = torch.device("cuda")
        status = "✅ Model moved to CUDA"
    else:
        device = torch.device("cpu")
        status = "✅ Model on CPU"
    model.move_to_device(device)
    print(status)

    # Publish only the fully initialised model and its device.
    MODEL, DEVICE = model, device
    return MODEL, DEVICE
def run(model, img_path, box=None, device="cpu"):
    """Run a single inference pass and return the model's raw output.

    Args:
        model: Object exposing ``move_to_device``, ``eval``, a writable
            ``use_box`` attribute, and ``__call__(img_path, box=...)``.
        img_path: Passed through to the model as its first argument.
        box: Optional box prompt, passed through as ``box=``. The model's
            ``use_box`` flag is set to True exactly when this is not None.
        device: Target passed to ``model.move_to_device``.

    Returns:
        Whatever ``model(img_path, box=box)`` returns, unmodified.
    """
    print("DEVICE:", device)
    model.move_to_device(device)
    model.eval()

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        model.use_box = box is not None
        return model(img_path, box=box)
| # import os | |
| # import torch | |
| # import numpy as np | |
| # from huggingface_hub import hf_hub_download | |
| # from segmentation import SegmentationModule | |
| # MODEL = None | |
| # DEVICE = torch.device("cpu") | |
| # def load_model(use_box=False): | |
| # global MODEL, DEVICE | |
| # # === Optimization 1: cache the model under /data to avoid writing to .cache === | |
| # cache_dir = "/data/cellseg_model_cache" | |
| # os.makedirs(cache_dir, exist_ok=True) | |
| # ckpt_path = hf_hub_download( | |
| # repo_id="Shengxiao0709/cellsegmodel", | |
| # filename="microscopy_matching_seg.pth", | |
| # token=None, | |
| # local_dir=cache_dir, # ✅ download into /data | |
| # local_dir_use_symlinks=False, # ✅ avoid symlink issues | |
| # force_download=False # ✅ do not re-download when the file already exists | |
| # ) | |
| # # === Optimization 2: load the model === | |
| # MODEL = SegmentationModule(use_box=use_box) | |
| # state_dict = torch.load(ckpt_path, map_location="cpu") | |
| # MODEL.load_state_dict(state_dict, strict=False) | |
| # MODEL.eval() | |
| # DEVICE = torch.device("cpu") | |
| # print(f"✅ Model loaded from {ckpt_path}") | |
| # return MODEL, DEVICE | |
| # @torch.no_grad() | |
| # def run(model, img_path, box=None, device="cpu"): | |
| # output = model(img_path, box=box) | |
| # mask = output["pred"] | |
| # mask = (mask > 0).astype(np.uint8) | |
| # return mask |