Shengxiao0709 commited on
Commit
d6ea8cb
·
verified ·
1 Parent(s): 27c851a

Create inference_cout.py

Browse files
Files changed (1) hide show
  1. inference_cout.py +225 -0
inference_cout.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inference_count.py
2
+ # 计数模型推理模块 - 独立版本
3
+
4
+ import torch
5
+ import numpy as np
6
+ from PIL import Image
7
+ import matplotlib.pyplot as plt
8
+ import tempfile
9
+ import os
10
+ from huggingface_hub import hf_hub_download
11
+ from counting_model import CountingModule
12
+
13
+ MODEL = None
14
+ DEVICE = torch.device("cpu")
15
+
16
+ def load_model(use_box=False):
17
+ """
18
+ 加载计数模型
19
+
20
+ Args:
21
+ use_box: 是否使用边界框
22
+
23
+ Returns:
24
+ model: 加载的模型
25
+ device: 设备
26
+ """
27
+ global MODEL, DEVICE
28
+
29
+ try:
30
+ print("🔄 Loading counting model...")
31
+
32
+ # 初始化模型
33
+ MODEL = CountingModule(use_box=use_box)
34
+
35
+ # 从 Hugging Face Hub 下载权重
36
+ ckpt_path = hf_hub_download(
37
+ repo_id="Shengxiao0709/cellsegmodel",
38
+ filename="microscopy_matching_cnt.pth",
39
+ token=None,
40
+ force_download=False
41
+ )
42
+
43
+ print(f"✅ Checkpoint downloaded: {ckpt_path}")
44
+
45
+ # 加载权重
46
+ MODEL.load_state_dict(
47
+ torch.load(ckpt_path, map_location="cpu"),
48
+ strict=True
49
+ )
50
+ MODEL.eval()
51
+
52
+ DEVICE = torch.device("cpu")
53
+
54
+ print("✅ Counting model loaded successfully")
55
+ return MODEL, DEVICE
56
+
57
+ except Exception as e:
58
+ print(f"❌ Error loading counting model: {e}")
59
+ import traceback
60
+ traceback.print_exc()
61
+ return None, torch.device("cpu")
62
+
63
+
64
+ @torch.no_grad()
65
+ def run(model, img_path, box=None, device="cpu", visualize=True):
66
+ """
67
+ 运行计数推理
68
+
69
+ Args:
70
+ model: 计数模型
71
+ img_path: 图像路径
72
+ box: 边界框 [[x1, y1, x2, y2], ...] 或 None
73
+ device: 设备
74
+ visualize: 是否生成可视化
75
+
76
+ Returns:
77
+ result_dict: {
78
+ 'density_map': numpy array,
79
+ 'count': float,
80
+ 'visualized_path': str (如果 visualize=True)
81
+ }
82
+ """
83
+ if model is None:
84
+ return {
85
+ 'density_map': None,
86
+ 'count': 0,
87
+ 'visualized_path': None,
88
+ 'error': 'Model not loaded'
89
+ }
90
+
91
+ try:
92
+ print(f"🔄 Running counting inference on {img_path}")
93
+
94
+ # 运行推理 (调用你的模型的 forward 方法)
95
+ density_map, count = model(img_path, box)
96
+
97
+ print(f"✅ Counting result: {count:.1f} objects")
98
+
99
+ result = {
100
+ 'density_map': density_map,
101
+ 'count': count,
102
+ 'visualized_path': None
103
+ }
104
+
105
+ # 可视化
106
+ if visualize:
107
+ viz_path = visualize_result(img_path, density_map, count)
108
+ result['visualized_path'] = viz_path
109
+
110
+ return result
111
+
112
+ except Exception as e:
113
+ print(f"❌ Counting inference error: {e}")
114
+ import traceback
115
+ traceback.print_exc()
116
+ return {
117
+ 'density_map': None,
118
+ 'count': 0,
119
+ 'visualized_path': None,
120
+ 'error': str(e)
121
+ }
122
+
123
+
124
+ def visualize_result(image_path, density_map, count):
125
+ """
126
+ 可视化计数结果 (与你原来的可视化代码一致)
127
+
128
+ Args:
129
+ image_path: 原始图像路径
130
+ density_map: 密度图 (numpy array)
131
+ count: 计数值
132
+
133
+ Returns:
134
+ output_path: 可视化结果的临时文件路径
135
+ """
136
+ try:
137
+ import skimage.io as io
138
+
139
+ # 读取原始图像
140
+ img = io.imread(image_path)
141
+
142
+ # 处理不同格式的图像
143
+ if len(img.shape) == 3 and img.shape[2] > 3:
144
+ img = img[:, :, :3]
145
+ if len(img.shape) == 2:
146
+ img = np.stack([img]*3, axis=-1)
147
+
148
+ # 归一化显示
149
+ img_show = img.squeeze()
150
+ density_map_show = density_map.squeeze()
151
+
152
+ # 归一化图像
153
+ img_show = (img_show - np.min(img_show)) / (np.max(img_show) - np.min(img_show) + 1e-8)
154
+
155
+ # 创建可视化 (与你原来的代码一致)
156
+ fig, ax = plt.subplots(1, 2, figsize=(12, 6))
157
+
158
+ # 左图: 原始图像
159
+ ax[0].imshow(img_show)
160
+ ax[0].axis('off')
161
+ ax[0].set_title(f"Input image")
162
+
163
+ # 右图: 密度图叠加
164
+ ax[1].imshow(img_show)
165
+ ax[1].imshow(density_map_show, cmap='jet', alpha=0.5)
166
+ ax[1].axis('off')
167
+ ax[1].set_title(f"Predicted density map, count: {count:.1f}")
168
+
169
+ plt.tight_layout()
170
+
171
+ # 保存到临时文件
172
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
173
+ plt.savefig(temp_file.name, dpi=300)
174
+ plt.close()
175
+
176
+ print(f"✅ Visualization saved to {temp_file.name}")
177
+ return temp_file.name
178
+
179
+ except Exception as e:
180
+ print(f"❌ Visualization error: {e}")
181
+ import traceback
182
+ traceback.print_exc()
183
+ return image_path
184
+
185
+
186
+ # ===== 测试代码 =====
187
+ if __name__ == "__main__":
188
+ print("="*60)
189
+ print("Testing Counting Model")
190
+ print("="*60)
191
+
192
+ # 测试模型加载
193
+ model, device = load_model(use_box=False)
194
+
195
+ if model is not None:
196
+ print("\n" + "="*60)
197
+ print("Model loaded successfully, testing inference...")
198
+ print("="*60)
199
+
200
+ # 测试推理
201
+ test_image = "example_imgs/1977_Well_F-5_Field_1.png"
202
+
203
+ if os.path.exists(test_image):
204
+ result = run(
205
+ model,
206
+ test_image,
207
+ box=None,
208
+ device=device,
209
+ visualize=True
210
+ )
211
+
212
+ if 'error' not in result:
213
+ print("\n" + "="*60)
214
+ print("Inference Results:")
215
+ print("="*60)
216
+ print(f"Count: {result['count']:.1f}")
217
+ print(f"Density map shape: {result['density_map'].shape}")
218
+ if result['visualized_path']:
219
+ print(f"Visualization saved to: {result['visualized_path']}")
220
+ else:
221
+ print(f"\n❌ Inference failed: {result['error']}")
222
+ else:
223
+ print(f"\n⚠️ Test image not found: {test_image}")
224
+ else:
225
+ print("\n❌ Model loading failed")