"""
Single-case inference script for the ultrasound-finding multi-label classifier.

Single Case Inference for TransMIL + Query2Label Hybrid Model

Usage:
    # Pass several image files explicitly
    python infer_single_case.py --images /path/to/img1.png /path/to/img2.png --threshold 0.5

    # Pass a folder containing the images of one case
    python infer_single_case.py --image_dir /path/to/case_folder/ --threshold 0.5
"""

import os
import sys
import argparse
import torch
import numpy as np
from PIL import Image
from torchvision import transforms

# Make the script's own directory importable so the local `models` package
# resolves regardless of the current working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from models.transmil_q2l import TransMIL_Query2Label_E2E

# The 17 target labels; index i of the model output is reported as label i.
TARGET_CLASSES = [
    "TI-RADS 1级", "TI-RADS 2级", "TI-RADS 3级", "TI-RADS 4a级",
    "TI-RADS 4b级", "TI-RADS 4c级", "TI-RADS 5级",
    "钙化", "甲亢", "囊肿", "淋巴结", "胶质潴留",
    "弥漫性病变", "结节性甲状腺肿", "桥本氏甲状腺炎",
    "反应性", "转移性"
]

# File extensions (lower-case) picked up when scanning --image_dir.
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif'}

# Prefix found on state_dict keys of some saved checkpoints; the original
# author's note mentions it as the expected key mismatch.
_MODULE_PREFIX = "module."


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from keys.

    Only a leading prefix is stripped.  A blanket ``str.replace`` would also
    corrupt keys that merely contain the substring (e.g. ``attn_module.weight``).
    """
    return {
        (key[len(_MODULE_PREFIX):] if key.startswith(_MODULE_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def load_model(checkpoint_path: str, device: torch.device):
    """Build the model, load weights from a checkpoint and switch to eval mode.

    Args:
        checkpoint_path: Path to a checkpoint file that contains a
            ``'model_state_dict'`` entry.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The model instance, on *device*, in evaluation mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    print(f"Loading model from: {checkpoint_path}")

    # Initialise the model.
    model = TransMIL_Query2Label_E2E(
        num_class=len(TARGET_CLASSES),
        hidden_dim=512,
        nheads=8,
        num_decoder_layers=2,
        pretrained_resnet=False,   # no need to download pretrained weights for inference
        use_checkpointing=False,   # checkpointing is not needed for inference
        use_ppeg=False
    )

    # Load the weights.
    # NOTE(security): weights_only=False performs full unpickling, which can
    # execute arbitrary code.  Only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = checkpoint['model_state_dict']

    # Handle possibly mismatched state_dict key names (a "module." prefix).
    model.load_state_dict(_strip_module_prefix(state_dict))
    model.to(device)
    model.eval()

    print("Model loaded successfully!")
    return model


def preprocess_images(image_paths: list, img_size: int = 224):
    """Load images from disk and turn them into one normalised tensor stack.

    Each image is converted to RGB, resized to ``img_size`` x ``img_size``,
    converted to a tensor and normalised with the mean/std constants below.
    Missing or unreadable files are reported with a warning and skipped.

    Args:
        image_paths: Paths of the image files to load.
        img_size: Side length, in pixels, the images are resized to.

    Returns:
        A tuple ``(images_batch, valid_paths)``: the stacked image tensor and
        the paths that were loaded successfully, in the same order.

    Raises:
        ValueError: If none of the paths could be loaded.
    """
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    images = []
    valid_paths = []

    for path in image_paths:
        if not os.path.exists(path):
            print(f"Warning: Image not found: {path}")
            continue

        try:
            # The context manager closes the underlying file handle promptly.
            with Image.open(path) as img:
                img_tensor = transform(img.convert('RGB'))
        except Exception as e:
            # Deliberately broad: one bad file must not abort the whole case.
            print(f"Warning: Failed to load image {path}: {e}")
            continue

        images.append(img_tensor)
        valid_paths.append(path)

    if not images:
        raise ValueError("No valid images found!")

    # Stack to [N, C, H, W]; per the original author's note the model expects
    # the plain image stack, without an extra batch dimension.
    images_batch = torch.stack(images, dim=0)

    return images_batch, valid_paths


def predict(model, images_batch: torch.Tensor, num_images: int,
            device: torch.device, threshold: float = 0.5):
    """Run one forward pass and threshold the sigmoid probabilities.

    Args:
        model: Callable invoked as ``model(images_batch, [num_images])``.
        images_batch: Stacked image tensor; it is moved to *device* first.
        num_images: Number of images in the stack, passed to the model as a
            one-element list.
        device: Device on which inference runs.
        threshold: A label is predicted when its probability is >= this value.

    Returns:
        A tuple ``(probs, predictions)`` of NumPy arrays: the probabilities of
        the first output row and the corresponding 0/1 decisions.
    """
    images_batch = images_batch.to(device)

    with torch.no_grad():
        # Forward pass.
        logits = model(images_batch, [num_images])
        # First row of the output -- presumably shape [num_class]; confirm
        # against the model implementation.
        probs = torch.sigmoid(logits).cpu().numpy()[0]

    # Derive the predicted labels from the threshold.
    predictions = (probs >= threshold).astype(int)

    return probs, predictions


def _display_width(text: str) -> int:
    """Approximate the terminal column width of *text*.

    The GBK byte length is used as the width.  Text that GBK cannot encode
    falls back to two columns per character.
    """
    try:
        return len(text.encode('gbk'))
    except UnicodeEncodeError:
        return len(text) * 2


def format_results(probs: np.ndarray, predictions: np.ndarray, threshold: float):
    """Print a per-class table sorted by probability and a label summary.

    Args:
        probs: Per-class probabilities, indexed like ``TARGET_CLASSES``.
        predictions: Per-class 0/1 decisions, indexed like ``TARGET_CLASSES``.
        threshold: Threshold that was applied; printed for reference only.

    Returns:
        The names of the predicted labels, ordered by descending probability.
    """
    name_column_width = 20

    print("\n" + "=" * 60)
    print(" 超声提示多标签分类结果")
    print("=" * 60)
    print(f" 阈值 (Threshold): {threshold}")
    print("-" * 60)

    # Sort classes by descending probability.
    sorted_indices = np.argsort(probs)[::-1]

    print(f"\n{'类别':<20} {'概率':>10} {'预测':>8}")
    print("-" * 40)

    predicted_labels = []
    for idx in sorted_indices:
        class_name = TARGET_CLASSES[idx]
        prob = probs[idx]
        is_predicted = predictions[idx] == 1
        pred = "✓" if is_predicted else ""

        # Pad by display width rather than len() so CJK names line up.
        padding = name_column_width - _display_width(class_name)
        aligned_name = class_name + " " * max(0, padding)
        print(f"{aligned_name} {prob:>10.4f} {pred:>8}")

        if is_predicted:
            predicted_labels.append(class_name)

    print("\n" + "=" * 60)
    print(" 预测标签汇总")
    print("=" * 60)
    if predicted_labels:
        for label in predicted_labels:
            print(f" • {label}")
    else:
        print(" 无预测标签(所有类别概率均低于阈值)")
    print("=" * 60 + "\n")

    return predicted_labels


def _list_dir_images(image_dir: str) -> list:
    """Return the sorted paths of files in *image_dir* with an image extension."""
    return [
        os.path.join(image_dir, filename)
        for filename in sorted(os.listdir(image_dir))
        if os.path.splitext(filename)[1].lower() in IMAGE_EXTENSIONS
    ]


def main(argv=None):
    """Command-line entry point: collect images, run the model, print results.

    Args:
        argv: Optional list of command-line arguments.  ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        The list of predicted label names, for programmatic callers.

    Exits with status 1 when ``--image_dir`` is not a directory or when no
    image was specified at all.
    """
    parser = argparse.ArgumentParser(description='超声提示多标签分类 - 单步推理')
    parser.add_argument('--images', nargs='*', default=None,
                        help='图像路径列表 (支持多个图像)')
    parser.add_argument('--image_dir', type=str, default=None,
                        help='图像文件夹路径 (自动加载文件夹内所有图像)')
    parser.add_argument('--checkpoint', type=str,
                        default='checkpoints/checkpoint_best.pth',
                        help='模型权重路径')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='分类阈值 (default: 0.5)')
    parser.add_argument('--device', type=str, default='auto',
                        help='设备: auto, cuda, cpu')

    args = parser.parse_args(argv)

    # Collect the image paths.
    image_paths = []

    # From --images.
    if args.images:
        image_paths.extend(args.images)

    # From --image_dir.
    if args.image_dir:
        if not os.path.isdir(args.image_dir):
            print(f"Error: Image directory not found: {args.image_dir}")
            sys.exit(1)

        dir_images = _list_dir_images(args.image_dir)
        image_paths.extend(dir_images)
        # Report only what the directory contributed, not the --images paths.
        print(f"Found {len(dir_images)} images in {args.image_dir}")

    # Make sure there is at least one image to work on.
    if not image_paths:
        print("Error: No images specified. Use --images or --image_dir")
        parser.print_help()
        sys.exit(1)

    # Select the device.
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)
    print(f"Using device: {device}")

    # A relative checkpoint path is resolved against the script's directory,
    # not the current working directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    checkpoint_path = args.checkpoint
    if not os.path.isabs(checkpoint_path):
        checkpoint_path = os.path.join(script_dir, checkpoint_path)

    # Load the model.
    model = load_model(checkpoint_path, device)

    # Preprocess the images.
    print(f"\nProcessing {len(image_paths)} image(s)...")
    images_batch, valid_paths = preprocess_images(image_paths)
    print(f"Successfully loaded {len(valid_paths)} image(s)")

    # Inference.
    probs, predictions = predict(model, images_batch, len(valid_paths),
                                 device, args.threshold)

    # Print the results.
    predicted_labels = format_results(probs, predictions, args.threshold)

    # Return the predicted labels for programmatic callers.
    return predicted_labels


if __name__ == "__main__":
    main()