# FinalVision / inference_count.py
# Provenance: GitHub user phoebehxf, commit 06244eb ("update model")
# inference_count.py
# 计数模型推理模块 - 独立版本
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import tempfile
import os
from huggingface_hub import hf_hub_download
from counting import CountingModule
# Module-level singletons populated by load_model(); MODEL stays None until
# a successful load, DEVICE is upgraded to CUDA there when available.
MODEL = None
DEVICE = torch.device("cpu")
def load_model(use_box=False):
    """
    Load the counting model and its pretrained weights.

    Downloads the checkpoint from the Hugging Face Hub (served from the
    local HF cache after the first call) and moves the model to CUDA when
    available, CPU otherwise. Updates the module-level MODEL/DEVICE
    singletons as a side effect.

    Args:
        use_box: whether the model should be built in box-conditioned mode.

    Returns:
        tuple: (model, device) on success, (None, torch.device("cpu")) on
        any failure (errors are printed, never raised).
    """
    global MODEL, DEVICE
    try:
        print("🔄 Loading counting model...")
        # Instantiate the architecture before loading weights into it.
        MODEL = CountingModule(use_box=use_box)

        # Fetch the checkpoint; force_download=False reuses the HF cache.
        ckpt_path = hf_hub_download(
            repo_id="phoebe777777/111",
            filename="microscopy_matching_cnt.pth",
            token=None,
            force_download=False
        )
        print(f"✅ Checkpoint downloaded: {ckpt_path}")

        # NOTE(security): torch.load uses pickle under the hood — only load
        # checkpoints from trusted repos (consider weights_only=True if the
        # checkpoint is a plain state_dict).
        MODEL.load_state_dict(
            torch.load(ckpt_path, map_location="cpu"),
            strict=True
        )
        MODEL.eval()

        # Single device-selection path (the original duplicated the
        # move_to_device call in separate cuda/cpu branches).
        DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        MODEL.move_to_device(DEVICE)
        if DEVICE.type == "cuda":
            print("✅ Model moved to CUDA")
        else:
            print("✅ Model on CPU")

        print("✅ Counting model loaded successfully")
        return MODEL, DEVICE
    except Exception as e:
        # Best-effort loader: report and fall back to (None, cpu) so the
        # caller can degrade gracefully.
        print(f"❌ Error loading counting model: {e}")
        import traceback
        traceback.print_exc()
        return None, torch.device("cpu")
@torch.no_grad()
def run(model, img_path, box=None, device="cpu", visualize=True):
    """
    Run counting inference on a single image.

    Args:
        model: counting model (as returned by load_model), or None.
        img_path: path to the input image.
        box: bounding boxes [[x1, y1, x2, y2], ...] or None; their
            presence toggles the model's box-conditioned mode.
        device: device to run on.
        visualize: accepted for API compatibility; visualization is
            currently disabled (see commented block below).

    Returns:
        dict: {
            'density_map': numpy array (or None on failure),
            'count': float (0 on failure),
            'visualized_path': None,
            'error': str (only present on failure)
        }
    """
    # BUGFIX: guard BEFORE touching the model — the original checked
    # `model is None` only after model.move_to_device(), so a missing
    # model raised AttributeError instead of returning the error dict.
    if model is None:
        return {
            'density_map': None,
            'count': 0,
            'visualized_path': None,
            'error': 'Model not loaded'
        }

    print("DEVICE:", device)
    model.move_to_device(device)
    model.eval()
    # Box-conditioned mode is driven purely by whether boxes were supplied.
    model.use_box = box is not None

    try:
        print(f"🔄 Running counting inference on {img_path}")
        # Gradients are already disabled by the @torch.no_grad() decorator,
        # so no inner no_grad context is needed. forward -> (density, count).
        density_map, count = model(img_path, box)
        print(f"✅ Counting result: {count:.1f} objects")
        result = {
            'density_map': density_map,
            'count': count,
            'visualized_path': None
        }
        # Visualization intentionally disabled; re-enable when needed:
        # if visualize:
        #     result['visualized_path'] = visualize_result(img_path, density_map, count)
        return result
    except Exception as e:
        # Best-effort: report the failure and return a structured error.
        print(f"❌ Counting inference error: {e}")
        import traceback
        traceback.print_exc()
        return {
            'density_map': None,
            'count': 0,
            'visualized_path': None,
            'error': str(e)
        }
def visualize_result(image_path, density_map, count):
    """
    Render the predicted density map as a jet-colored overlay on the image.

    Args:
        image_path: path to the original image.
        density_map: predicted density map (numpy array).
        count: predicted count (kept for API compatibility; only used by
            the commented-out title line).

    Returns:
        str: path to a temporary PNG with the overlay, or image_path
        unchanged if anything fails (best-effort fallback).
    """
    try:
        import skimage.io as io
        img = io.imread(image_path)
        # Drop alpha/extra channels; promote grayscale to 3-channel RGB.
        if len(img.shape) == 3 and img.shape[2] > 3:
            img = img[:, :, :3]
        if len(img.shape) == 2:
            img = np.stack([img] * 3, axis=-1)
        img_show = img.squeeze()
        density_map_show = density_map.squeeze()
        # Min-max normalize for display; epsilon guards constant images.
        img_show = (img_show - np.min(img_show)) / (np.max(img_show) - np.min(img_show) + 1e-8)

        fig, ax = plt.subplots(figsize=(8, 6))
        ax.imshow(img_show)
        ax.imshow(density_map_show, cmap='jet', alpha=0.5)
        ax.axis('off')
        # ax.set_title(f"Predicted density map, count: {count:.1f}")
        plt.tight_layout()

        # BUGFIX: the original NamedTemporaryFile(delete=False) left its
        # file handle open (fd leak) while savefig wrote to the same path.
        # mkstemp + immediate close hands a clean path to savefig.
        fd, out_path = tempfile.mkstemp(suffix='.png')
        os.close(fd)
        plt.savefig(out_path, dpi=300)
        plt.close(fig)  # close the specific figure we created
        print(f"✅ Visualization saved to {out_path}")
        return out_path
    except Exception as e:
        # Best-effort: fall back to the original image path on any failure
        # (including a missing skimage dependency).
        print(f"❌ Visualization error: {e}")
        import traceback
        traceback.print_exc()
        return image_path
# ===== Standalone smoke test =====
if __name__ == "__main__":
    separator = "=" * 60
    print(separator)
    print("Testing Counting Model")
    print(separator)

    # Load the model first; every later step depends on it.
    model, device = load_model(use_box=False)
    if model is None:
        print("\n❌ Model loading failed")
    else:
        print("\n" + separator)
        print("Model loaded successfully, testing inference...")
        print(separator)

        test_image = "example_imgs/1977_Well_F-5_Field_1.png"
        if not os.path.exists(test_image):
            print(f"\n⚠️ Test image not found: {test_image}")
        else:
            # Exercise the full inference path on the sample image.
            result = run(
                model,
                test_image,
                box=None,
                device=device,
                visualize=True
            )
            if 'error' in result:
                print(f"\n❌ Inference failed: {result['error']}")
            else:
                print("\n" + separator)
                print("Inference Results:")
                print(separator)
                print(f"Count: {result['count']:.1f}")
                print(f"Density map shape: {result['density_map'].shape}")
                if result['visualized_path']:
                    print(f"Visualization saved to: {result['visualized_path']}")