Spaces:
Sleeping
Sleeping
| # inference_count.py | |
| # 计数模型推理模块 - 独立版本 | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import tempfile | |
| import os | |
| from huggingface_hub import hf_hub_download | |
| from counting import CountingModule | |
# Globals written by load_model(): the counting model (None until loaded)
# and the torch device it was placed on.
MODEL, DEVICE = None, torch.device("cpu")
def load_model(use_box=False, repo_id="phoebe777777/111",
               filename="microscopy_matching_cnt.pth"):
    """
    Load the counting model and its pretrained weights.

    The module-level MODEL / DEVICE globals are updated only after the
    weights have been loaded and the model has been moved to its device, so a
    failed load never leaves a half-initialised model in MODEL (any previously
    loaded model stays in place).

    Args:
        use_box: whether the model uses bounding boxes (passed to CountingModule).
        repo_id: Hugging Face Hub repository that holds the checkpoint.
        filename: checkpoint file name inside ``repo_id``.

    Returns:
        (model, device): the model in eval mode and the torch.device it was
        moved to (CUDA when available, otherwise CPU). On any error the
        traceback is printed and ``(None, torch.device("cpu"))`` is returned;
        nothing is raised.
    """
    global MODEL, DEVICE
    try:
        print("🔄 Loading counting model...")
        # Build into a local first; the global is only published on success.
        model = CountingModule(use_box=use_box)

        # Download (or reuse the cached copy of) the weights from the Hub.
        ckpt_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            token=None,
            force_download=False,
        )
        print(f"✅ Checkpoint downloaded: {ckpt_path}")

        # NOTE(review): torch.load unpickles the file, and this checkpoint comes
        # from a remote repository. If the installed torch supports it, consider
        # passing weights_only=True so only tensors/containers are deserialised.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        model.eval()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.move_to_device(device)
        if device.type == "cuda":
            print("✅ Model moved to CUDA")
        else:
            print("✅ Model on CPU")

        MODEL, DEVICE = model, device
        print("✅ Counting model loaded successfully")
        return MODEL, DEVICE
    except Exception as e:
        print(f"❌ Error loading counting model: {e}")
        import traceback
        traceback.print_exc()
        return None, torch.device("cpu")
def run(model, img_path, box=None, device="cpu", visualize=True):
    """
    Run counting inference on a single image.

    Args:
        model: counting model returned by load_model(), or None.
        img_path: path to the input image; passed straight to the model.
        box: bounding boxes [[x1, y1, x2, y2], ...] or None. Box mode
            (``model.use_box``) is enabled exactly when boxes are given.
        device: device the model is moved to before inference.
        visualize: accepted for API compatibility; visualization is currently
            disabled, so 'visualized_path' is always None.

    Returns:
        dict with keys 'density_map', 'count' and 'visualized_path' holding
        what the model returned. On failure (including ``model is None``)
        density_map is None, count is 0, visualized_path is None and an
        extra 'error' key carries the message.
    """
    print("DEVICE:", device)
    # Check for a missing model before touching it: None has no move_to_device().
    if model is None:
        return {
            'density_map': None,
            'count': 0,
            'visualized_path': None,
            'error': 'Model not loaded'
        }
    try:
        model.move_to_device(device)
        model.eval()
        # Box mode is decided per call by whether boxes were supplied.
        model.use_box = box is not None

        print(f"🔄 Running counting inference on {img_path}")
        with torch.no_grad():
            density_map, count = model(img_path, box)
        print(f"✅ Counting result: {count:.1f} objects")

        # NOTE: `visualize` is ignored for now; the visualize_result() call
        # was deliberately disabled in this module.
        return {
            'density_map': density_map,
            'count': count,
            'visualized_path': None
        }
    except Exception as e:
        print(f"❌ Counting inference error: {e}")
        import traceback
        traceback.print_exc()
        return {
            'density_map': None,
            'count': 0,
            'visualized_path': None,
            'error': str(e)
        }
def visualize_result(image_path, density_map, count):
    """
    Overlay a predicted density map on the original image and save it as PNG.

    Args:
        image_path: path to the original image (read with skimage.io).
        density_map: predicted density map; squeezed before display. Presumably
            a numpy array matching the image's spatial size — TODO confirm.
        count: predicted count. Currently unused (the title showing it is disabled).

    Returns:
        Path of a temporary PNG file. It is not deleted automatically, so the
        caller owns it. On any error the traceback is printed and
        ``image_path`` is returned instead.
    """
    fig = None
    try:
        import skimage.io as io

        img = io.imread(image_path)
        # Drop alpha/extra channels; expand grayscale to three channels.
        if img.ndim == 3 and img.shape[2] > 3:
            img = img[:, :, :3]
        if img.ndim == 2:
            img = np.stack([img] * 3, axis=-1)

        img_show = img.squeeze()
        density_map_show = density_map.squeeze()
        # Min-max normalise to [0, 1]; the epsilon avoids dividing by zero on flat images.
        img_show = (img_show - np.min(img_show)) / (np.max(img_show) - np.min(img_show) + 1e-8)

        fig, ax = plt.subplots(figsize=(8, 6))
        # Density map drawn semi-transparently over the image.
        ax.imshow(img_show)
        ax.imshow(density_map_show, cmap='jet', alpha=0.5)
        ax.axis('off')
        fig.tight_layout()

        # Only the file name is needed. Close our handle right away so it does
        # not leak, and so savefig can write to it on platforms that lock open files.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
            output_path = tmp.name
        fig.savefig(output_path, dpi=300)
        print(f"✅ Visualization saved to {output_path}")
        return output_path
    except Exception as e:
        print(f"❌ Visualization error: {e}")
        import traceback
        traceback.print_exc()
        return image_path
    finally:
        # Release this specific figure even if drawing or saving failed.
        if fig is not None:
            plt.close(fig)
# ===== Smoke test =====
if __name__ == "__main__":

    def _smoke_test():
        """Load the model and run one inference on the bundled example image."""
        banner = "=" * 60
        print(banner)
        print("Testing Counting Model")
        print(banner)

        # Model loading
        model, device = load_model(use_box=False)
        if model is None:
            print("\n❌ Model loading failed")
            return

        print("\n" + banner)
        print("Model loaded successfully, testing inference...")
        print(banner)

        # Inference on a sample image expected to exist relative to the CWD
        test_image = "example_imgs/1977_Well_F-5_Field_1.png"
        if not os.path.exists(test_image):
            print(f"\n⚠️ Test image not found: {test_image}")
            return

        result = run(model, test_image, box=None, device=device, visualize=True)
        if 'error' in result:
            print(f"\n❌ Inference failed: {result['error']}")
            return

        print("\n" + banner)
        print("Inference Results:")
        print(banner)
        print(f"Count: {result['count']:.1f}")
        print(f"Density map shape: {result['density_map'].shape}")
        if result['visualized_path']:
            print(f"Visualization saved to: {result['visualized_path']}")

    _smoke_test()