File size: 2,542 Bytes
27c851a
 
 
f409b59
27c851a
 
 
 
 
 
04954a4
27c851a
 
06244eb
27c851a
 
f409b59
27c851a
f409b59
27c851a
047ae7d
 
 
 
 
 
 
 
27c851a
 
 
 
 
047ae7d
 
 
 
 
 
 
 
 
 
f4db8ac
f409b59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import numpy as np
from huggingface_hub import hf_hub_download
from segmentation import SegmentationModule  

# Module-level singleton populated by load_model(); None until first load.
MODEL = None
# Device the loaded model lives on; load_model() switches this to CUDA when available.
DEVICE = torch.device("cpu")

def load_model(use_box=False):
    """Download the segmentation checkpoint from the Hub and load it.

    Populates the module-level MODEL/DEVICE globals, moves the model to
    CUDA when available (CPU otherwise), and puts it in eval mode.

    Args:
        use_box: Forwarded to SegmentationModule; enables box-conditioned
            segmentation.

    Returns:
        (MODEL, DEVICE) tuple for convenience; the globals are also set.
    """
    global MODEL, DEVICE
    MODEL = SegmentationModule(use_box=use_box)

    # force_download=False reuses the local HF cache when the file exists.
    ckpt_path = hf_hub_download(
        repo_id="phoebe777777/111",
        filename="microscopy_matching_seg.pth",
        token=None,
        force_download=False
    )
    # NOTE(review): torch.load on a downloaded file without weights_only=True
    # can execute arbitrary pickled code — confirm the checkpoint is trusted,
    # or pass weights_only=True if it is a plain state_dict.
    # NOTE(review): strict=False silently ignores missing/unexpected keys;
    # a partially-loaded model will not error here.
    MODEL.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
    MODEL.eval()

    # Single device-selection path instead of duplicated if/else branches.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MODEL.move_to_device(DEVICE)
    print("✅ Model moved to CUDA" if DEVICE.type == "cuda" else "✅ Model on CPU")
    return MODEL, DEVICE


@torch.no_grad()
def run(model, img_path, box=None, device="cpu"):
    """Run segmentation inference on a single image.

    Args:
        model: Loaded SegmentationModule (or compatible object exposing
            move_to_device/eval/use_box and a callable forward).
        img_path: Path to the input image, passed through to the model.
        box: Optional box prompt; when given, box-conditioned mode is enabled.
        device: Target device for inference (default "cpu").

    Returns:
        The model's raw output (the predicted mask).
    """
    print("DEVICE:", device)
    model.move_to_device(device)
    model.eval()
    # The @torch.no_grad() decorator already disables gradients for the whole
    # call, so no inner no_grad context is needed.
    model.use_box = box is not None
    return model(img_path, box=box)
# import os
# import torch
# import numpy as np
# from huggingface_hub import hf_hub_download
# from segmentation import SegmentationModule

# MODEL = None
# DEVICE = torch.device("cpu")

# def load_model(use_box=False):
#     global MODEL, DEVICE

#     # === 优化1: 使用 /data 缓存模型,避免写入 .cache ===
#     cache_dir = "/data/cellseg_model_cache"
#     os.makedirs(cache_dir, exist_ok=True)

#     ckpt_path = hf_hub_download(
#         repo_id="Shengxiao0709/cellsegmodel",
#         filename="microscopy_matching_seg.pth",
#         token=None,
#         local_dir=cache_dir,              # ✅ 下载到 /data
#         local_dir_use_symlinks=False,     # ✅ 避免软链接问题
#         force_download=False              # ✅ 已存在时不重复下载
#     )

#     # === 优化2: 加载模型 ===
#     MODEL = SegmentationModule(use_box=use_box)
#     state_dict = torch.load(ckpt_path, map_location="cpu")
#     MODEL.load_state_dict(state_dict, strict=False)
#     MODEL.eval()

#     DEVICE = torch.device("cpu")
#     print(f"✅ Model loaded from {ckpt_path}")
#     return MODEL, DEVICE


# @torch.no_grad()
# def run(model, img_path, box=None, device="cpu"):
#     output = model(img_path, box=box)
#     mask = output["pred"]
#     mask = (mask > 0).astype(np.uint8)
#     return mask