# Source: Hugging Face Space page (Space status at capture time: "Sleeping").
# File size: 2,542 bytes (87 lines on the original page).
# Blame commits: 27c851a, f409b59, 04954a4, 06244eb, 047ae7d, f4db8ac.
import torch
import numpy as np
from huggingface_hub import hf_hub_download
from segmentation import SegmentationModule
# Module-level handles that load_model() fills in. DEVICE starts out as the
# CPU and is replaced once load_model() has chosen CUDA or CPU.
MODEL, DEVICE = None, torch.device("cpu")
def load_model(use_box=False, *, repo_id="phoebe777777/111",
               filename="microscopy_matching_seg.pth"):
    """Build the segmentation model, load its checkpoint, and place it on a device.

    The checkpoint is fetched with ``hf_hub_download`` (``force_download=False``,
    so an already-cached copy is reused). Weights are loaded non-strictly, and
    any missing or unexpected keys are reported instead of being silently
    dropped. The model is put in eval mode and moved to CUDA when
    ``torch.cuda.is_available()``; otherwise it stays on the CPU.

    The module-level ``MODEL``/``DEVICE`` globals are updated only after every
    step has succeeded. A failed download or load therefore does not leave
    ``MODEL`` pointing at a model with uninitialised weights.

    Args:
        use_box: Forwarded to ``SegmentationModule(use_box=...)``.
        repo_id: Hugging Face Hub repository that holds the checkpoint.
        filename: Name of the checkpoint file inside ``repo_id``.

    Returns:
        ``(model, device)``: the loaded model (also stored in ``MODEL``) and
        the ``torch.device`` it was moved to (also stored in ``DEVICE``).
    """
    global MODEL, DEVICE

    # Build into a local; the globals are only published at the very end.
    model = SegmentationModule(use_box=use_box)
    ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        token=None,
        force_download=False,
    )
    # NOTE(review): torch.load unpickles the file, and this checkpoint comes
    # from a remote Hub repo, so only point repo_id at repositories you trust.
    # Consider weights_only=True once the checkpoint is confirmed to hold
    # only tensors.
    state_dict = torch.load(ckpt_path, map_location="cpu")

    # strict=False tolerates key mismatches. Surface them so that a checkpoint
    # which does not match the architecture is noticed. getattr keeps this
    # safe in case SegmentationModule's load_state_dict returns something
    # other than torch's (missing_keys, unexpected_keys) result.
    load_result = model.load_state_dict(state_dict, strict=False)
    missing = getattr(load_result, "missing_keys", None) or []
    unexpected = getattr(load_result, "unexpected_keys", None) or []
    if missing or unexpected:
        print(
            f"⚠️ Checkpoint key mismatch: {len(missing)} missing, "
            f"{len(unexpected)} unexpected"
        )
    model.eval()

    # Choose the device once, then move the model a single time.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.move_to_device(device)
    print("✅ Model moved to CUDA" if device.type == "cuda" else "✅ Model on CPU")

    MODEL, DEVICE = model, device
    return MODEL, DEVICE
@torch.no_grad()
def run(model, img_path, box=None, device="cpu"):
    """Run one segmentation forward pass and return the model's raw output.

    On every call the model is moved to ``device``, switched to eval mode,
    and has its ``use_box`` flag set according to whether a box prompt was
    given. Gradient tracking is disabled for the whole call by the
    ``torch.no_grad()`` decorator.

    Args:
        model: An object exposing ``move_to_device``, ``eval``, a writable
            ``use_box`` attribute, and ``__call__(img_path, box=...)``.
            Presumably this is the ``SegmentationModule`` returned by
            ``load_model``.
        img_path: Passed unchanged to ``model``.
        box: Optional box prompt. ``None`` turns box mode off.
        device: Device handed to ``model.move_to_device``. It defaults to
            ``"cpu"``, so pass the device returned by ``load_model``.
            Otherwise a CUDA-placed model is presumably moved back to CPU.

    Returns:
        Whatever ``model(img_path, box=box)`` returns, unchanged.
    """
    print("DEVICE:", device)
    model.move_to_device(device)
    model.eval()
    # Box mode is enabled exactly when a box prompt was supplied.
    model.use_box = box is not None
    return model(img_path, box=box)
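# Previous implementation, kept for reference: it cached the checkpoint under
# /data and returned a binarized uint8 mask built from output["pred"].
# import os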
# import torch
# import numpy as np
# from huggingface_hub import hf_hub_download
# from segmentation import SegmentationModule
# MODEL = None
# DEVICE = torch.device("cpu")
# def load_model(use_box=False):
# global MODEL, DEVICE
# # === Optimization 1: cache the model under /data to avoid writing to .cache ===
# cache_dir = "/data/cellseg_model_cache"
# os.makedirs(cache_dir, exist_ok=True)
# ckpt_path = hf_hub_download(
# repo_id="Shengxiao0709/cellsegmodel",
# filename="microscopy_matching_seg.pth",
# token=None,
# local_dir=cache_dir, # ✅ download into /data
# local_dir_use_symlinks=False, # ✅ avoid symlink issues
# force_download=False # ✅ skip the download if the file already exists
# )
# # === Optimization 2: load the model ===
# MODEL = SegmentationModule(use_box=use_box)
# state_dict = torch.load(ckpt_path, map_location="cpu")
# MODEL.load_state_dict(state_dict, strict=False)
# MODEL.eval()
# DEVICE = torch.device("cpu")
# print(f"✅ Model loaded from {ckpt_path}")
# return MODEL, DEVICE
# @torch.no_grad()
# def run(model, img_path, box=None, device="cpu"):
# output = model(img_path, box=box)
# mask = output["pred"]
# mask = (mask > 0).astype(np.uint8)
# return mask