File size: 6,595 Bytes
d6ea8cb
 
 
 
 
 
 
 
 
 
a3ef2be
d6ea8cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06244eb
d6ea8cb
 
 
 
 
 
 
 
 
 
 
 
 
 
047ae7d
 
 
 
 
 
 
 
d6ea8cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
047ae7d
 
 
 
 
 
 
 
 
d6ea8cb
 
 
 
 
 
 
 
 
 
 
 
047ae7d
 
d6ea8cb
 
 
 
 
 
 
 
 
 
047ae7d
 
 
d6ea8cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
047ae7d
d6ea8cb
 
047ae7d
 
 
 
d6ea8cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
# inference_count.py
# 计数模型推理模块 - 独立版本

import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import tempfile
import os
from huggingface_hub import hf_hub_download
from counting import CountingModule

# Module-level handle to the most recently loaded counting model.
# Stays None until load_model() assigns it.
MODEL = None
# Device associated with MODEL; starts as CPU and is reassigned by load_model().
DEVICE = torch.device("cpu")

def load_model(use_box=False):
    """Load the counting model and its pretrained weights.

    The checkpoint ``microscopy_matching_cnt.pth`` is fetched from the
    Hugging Face Hub repository ``phoebe777777/111`` and loaded with
    ``strict=True``. The model is switched to eval mode and moved to CUDA
    when available, otherwise to CPU.

    The module-level ``MODEL`` and ``DEVICE`` are updated only after every
    step has succeeded, so a failed load never leaves a model without
    weights in module state.

    Args:
        use_box: Forwarded to ``CountingModule(use_box=...)``.

    Returns:
        ``(model, device)`` on success, or ``(None, torch.device("cpu"))``
        if any step fails (the error and traceback are printed).
    """
    global MODEL, DEVICE

    try:
        print("🔄 Loading counting model...")

        # Build into locals first; the globals are committed only on success.
        model = CountingModule(use_box=use_box)

        # Download the weights from the Hugging Face Hub (cached locally).
        ckpt_path = hf_hub_download(
            repo_id="phoebe777777/111",
            filename="microscopy_matching_cnt.pth",
            token=None,
            force_download=False,
        )

        print(f"✅ Checkpoint downloaded: {ckpt_path}")

        # NOTE(security): torch.load unpickles the checkpoint. Only load
        # checkpoints from a trusted repository; consider weights_only=True
        # if the installed torch version and checkpoint format allow it.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        model.eval()

        if torch.cuda.is_available():
            device = torch.device("cuda")
            model.move_to_device(device)
            print("✅ Model moved to CUDA")
        else:
            device = torch.device("cpu")
            model.move_to_device(device)
            print("✅ Model on CPU")

    except Exception as e:
        # Best-effort boundary: report the failure and signal it with None.
        print(f"❌ Error loading counting model: {e}")
        import traceback
        traceback.print_exc()
        return None, torch.device("cpu")

    MODEL, DEVICE = model, device
    print("✅ Counting model loaded successfully")
    return MODEL, DEVICE


@torch.no_grad()
def run(model, img_path, box=None, device="cpu", visualize=True):
    """Run counting inference on a single image.

    Args:
        model: Counting model. It must provide ``move_to_device`` and
            ``eval`` and is invoked as ``model(img_path, box)``. ``None``
            yields an error result instead of raising.
        img_path: Image path, passed to the model unchanged.
        box: Bounding boxes ``[[x1, y1, x2, y2], ...]`` or ``None``.
            ``model.use_box`` is set to ``True`` exactly when this is not
            ``None``.
        device: Device forwarded to ``model.move_to_device``.
        visualize: Accepted for API compatibility but currently ignored;
            no visualization is generated by this function.

    Returns:
        dict with keys:
            'density_map': first value returned by the model, or ``None``
                on failure.
            'count': second value returned by the model, or ``0`` on
                failure.
            'visualized_path': always ``None``.
            'error': present only on failure; a description of the problem.
    """
    # Validate before touching the model so a missing model produces the
    # documented error result rather than an AttributeError.
    if model is None:
        return {
            'density_map': None,
            'count': 0,
            'visualized_path': None,
            'error': 'Model not loaded'
        }

    try:
        print("DEVICE:", device)
        model.move_to_device(device)
        model.eval()
        model.use_box = box is not None

        print(f"🔄 Running counting inference on {img_path}")

        # Gradients are already disabled by the @torch.no_grad() decorator.
        density_map, count = model(img_path, box)

        print(f"✅ Counting result: {count:.1f} objects")

        # Visualization is currently disabled, so 'visualized_path' is
        # always None regardless of `visualize`.
        return {
            'density_map': density_map,
            'count': count,
            'visualized_path': None
        }

    except Exception as e:
        # Best-effort boundary: report the failure through the result dict.
        print(f"❌ Counting inference error: {e}")
        import traceback
        traceback.print_exc()
        return {
            'density_map': None,
            'count': 0,
            'visualized_path': None,
            'error': str(e)
        }


def visualize_result(image_path, density_map, count):
    """Overlay a density map on its source image and save the result as a PNG.

    Args:
        image_path: Path of the original image (read with ``skimage.io``).
        density_map: Density map as a numpy array; a ``torch.Tensor`` is
            also accepted and converted to numpy on the CPU.
        count: Predicted count. Currently unused because the figure is
            saved without a title.

    Returns:
        Path of the temporary PNG holding the overlay. The file is not
        deleted automatically. If anything fails, the error is printed and
        ``image_path`` is returned as a fallback.
    """
    try:
        import skimage.io as io

        # Read the original image.
        img = io.imread(image_path)

        # Drop extra channels (e.g. alpha) and expand grayscale to 3 channels.
        if img.ndim == 3 and img.shape[2] > 3:
            img = img[:, :, :3]
        if img.ndim == 2:
            img = np.stack([img] * 3, axis=-1)

        # matplotlib cannot plot tensors that require grad or live on a GPU.
        if isinstance(density_map, torch.Tensor):
            density_map = density_map.detach().cpu().numpy()

        img_show = img.squeeze()
        density_map_show = density_map.squeeze()

        # Min-max normalise the image; the epsilon guards constant images.
        img_show = (img_show - np.min(img_show)) / (np.max(img_show) - np.min(img_show) + 1e-8)

        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            # Image with the density map blended on top.
            ax.imshow(img_show)
            ax.imshow(density_map_show, cmap='jet', alpha=0.5)
            ax.axis('off')

            fig.tight_layout()

            # Reserve a temp-file name and close the handle straight away so
            # no descriptor leaks and savefig can reopen the path everywhere.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
                output_path = tmp.name
            fig.savefig(output_path, dpi=300)
        finally:
            # Always release this specific figure, even if plotting failed.
            plt.close(fig)

        print(f"✅ Visualization saved to {output_path}")
        return output_path

    except Exception as e:
        # Best-effort: fall back to the untouched input image.
        print(f"❌ Visualization error: {e}")
        import traceback
        traceback.print_exc()
        return image_path


# ===== 测试代码 =====
if __name__ == "__main__":
    # Manual smoke test: load the model, then run it once on a sample image.
    rule = "=" * 60

    print(rule)
    print("Testing Counting Model")
    print(rule)

    model, device = load_model(use_box=False)

    if model is None:
        print("\n❌ Model loading failed")
    else:
        print("\n" + rule)
        print("Model loaded successfully, testing inference...")
        print(rule)

        sample_image = "example_imgs/1977_Well_F-5_Field_1.png"

        if not os.path.exists(sample_image):
            print(f"\n⚠️ Test image not found: {sample_image}")
        else:
            result = run(model, sample_image, box=None, device=device, visualize=True)

            if 'error' in result:
                print(f"\n❌ Inference failed: {result['error']}")
            else:
                print("\n" + rule)
                print("Inference Results:")
                print(rule)
                print(f"Count: {result['count']:.1f}")
                print(f"Density map shape: {result['density_map'].shape}")
                if result['visualized_path']:
                    print(f"Visualization saved to: {result['visualized_path']}")