# HintsPredictionModel / infer_single_case.py
# Doul0414's picture
# Initial upload: HintsPrediction
# 343e05c verified
"""
单步推理脚本 - 超声提示多标签分类模型
Single Case Inference for TransMIL + Query2Label Hybrid Model
用法:
# 指定多个图像文件
python infer_single_case.py --images /path/to/img1.png /path/to/img2.png --threshold 0.5
# 指定图像文件夹
python infer_single_case.py --image_dir /path/to/case_folder/ --threshold 0.5
"""
import os
import sys
import argparse
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
# 添加当前目录到路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from models.transmil_q2l import TransMIL_Query2Label_E2E
# Names of the 17 output classes. A class index is mapped to its display name
# by position in this list, so the order must not be changed.
# NOTE(review): presumably this order matches the order of the model's output
# logits used at training time -- confirm against the training code.
TARGET_CLASSES = [
    "TI-RADS 1级", "TI-RADS 2级", "TI-RADS 3级", "TI-RADS 4a级",
    "TI-RADS 4b级", "TI-RADS 4c级", "TI-RADS 5级",
    "钙化", "甲亢", "囊肿", "淋巴结", "胶质潴留",
    "弥漫性病变", "结节性甲状腺肿", "桥本氏甲状腺炎", "反应性", "转移性"
]
def load_model(checkpoint_path: str, device: torch.device):
    """Build the TransMIL + Query2Label model and load weights from a checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint file. The loaded object must
            expose the weights under the ``'model_state_dict'`` key.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: If the loaded checkpoint dict has no ``'model_state_dict'``
            entry.
    """
    print(f"Loading model from: {checkpoint_path}")

    # The number of classes follows the label list so the two cannot drift.
    model = TransMIL_Query2Label_E2E(
        num_class=len(TARGET_CLASSES),
        hidden_dim=512,
        nheads=8,
        num_decoder_layers=2,
        pretrained_resnet=False,  # weights come from the checkpoint; no download needed
        use_checkpointing=False,  # gradient checkpointing is not needed for inference
        use_ppeg=False
    )

    # SECURITY: weights_only=False makes torch.load unpickle arbitrary Python
    # objects, which can execute code. Only load checkpoints from a trusted
    # source; switch to weights_only=True if the checkpoint format allows it.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = checkpoint['model_state_dict']

    # Strip only a *leading* "module." (added when the model was saved through
    # a wrapper). A plain str.replace would also rewrite keys that merely
    # contain "module." somewhere in the middle and corrupt them.
    prefix = "module."
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    model.load_state_dict(new_state_dict)
    model.to(device)
    model.eval()
    print("Model loaded successfully!")
    return model
def preprocess_images(image_paths: list, img_size: int = 224):
    """Load images from disk and stack them into one normalized tensor.

    Each readable image is converted to RGB, resized to
    ``img_size x img_size``, converted to a tensor and normalized with the
    per-channel mean/std given below. Paths that do not exist or fail to load
    are skipped with a warning instead of aborting the whole case.

    Args:
        image_paths: Paths of the image files to load.
        img_size: Side length, in pixels, that every image is resized to.

    Returns:
        A tuple ``(images_batch, valid_paths)``. ``images_batch`` stacks the
        loaded images along a new first dimension (``[N, C, H, W]``);
        ``valid_paths`` lists the paths that were loaded, in the same order.

    Raises:
        ValueError: If no image could be loaded.
    """
    # NOTE(review): the mean/std values are presumably the statistics used at
    # training time -- confirm before changing them.
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    images = []
    valid_paths = []
    for path in image_paths:
        if not os.path.exists(path):
            print(f"Warning: Image not found: {path}")
            continue
        try:
            # The context manager releases the file handle deterministically;
            # a bare Image.open() can keep it open (e.g. multi-frame files).
            with Image.open(path) as img:
                img_tensor = transform(img.convert('RGB'))
        except Exception as e:
            # Best effort by design: one unreadable file must not abort the case.
            print(f"Warning: Failed to load image {path}: {e}")
            continue
        images.append(img_tensor)
        valid_paths.append(path)

    if not images:
        raise ValueError("No valid images found!")

    # The model is fed the stacked images directly, without an extra batch
    # dimension in front.
    images_batch = torch.stack(images, dim=0)
    return images_batch, valid_paths
def predict(model, images_batch: torch.Tensor, num_images: int,
            device: torch.device, threshold: float = 0.5):
    """Run one forward pass and threshold the per-class sigmoid scores.

    Args:
        model: Callable invoked as ``model(images, [num_images])``.
        images_batch: Stacked image tensor; moved to ``device`` before the call.
        num_images: Image count, handed to the model as a one-element list.
        device: Device ``images_batch`` is moved to.
        threshold: Scores greater than or equal to this value are flagged 1.

    Returns:
        A tuple ``(probs, predictions)``: the first row of the sigmoid output
        as a NumPy array, and a same-length integer array of 0/1 flags.
    """
    on_device = images_batch.to(device)
    bag_sizes = [num_images]

    # Gradients are not needed for inference.
    with torch.no_grad():
        raw_scores = model(on_device, bag_sizes)
        scores = torch.sigmoid(raw_scores)

    # Only the first row of the output is used.
    probs = scores.cpu().numpy()[0]
    predictions = np.greater_equal(probs, threshold).astype(int)
    return probs, predictions
def _gbk_display_width(text: str) -> int:
    """Estimate the terminal column width of *text* via its GBK byte length."""
    try:
        # In GBK, ASCII takes one byte and CJK characters take two, which
        # matches their usual display width.
        return len(text.encode('gbk'))
    except (UnicodeEncodeError, LookupError):
        # Text not representable in GBK (or codec unavailable): fall back to
        # assuming every character is double-width.
        return len(text) * 2


def format_results(probs: np.ndarray, predictions: np.ndarray, threshold: float):
    """Print a per-class probability table and a summary of predicted labels.

    Classes are listed in descending order of probability. Class names are
    padded by estimated display width so the columns line up when the names
    contain CJK characters.

    Args:
        probs: Per-class probabilities, indexed like ``TARGET_CLASSES``.
        predictions: Per-class 0/1 flags, indexed like ``TARGET_CLASSES``.
        threshold: Threshold value, shown in the report header only.

    Returns:
        Names of the classes whose flag is 1, ordered by descending
        probability.
    """
    print("\n" + "=" * 60)
    print(" 超声提示多标签分类结果")
    print("=" * 60)
    print(f" 阈值 (Threshold): {threshold}")
    print("-" * 60)

    # Highest probability first.
    sorted_indices = np.argsort(probs)[::-1]

    print(f"\n{'类别':<20} {'概率':>10} {'预测':>8}")
    print("-" * 40)

    predicted_labels = []
    for idx in sorted_indices:
        class_name = TARGET_CLASSES[idx]
        prob = probs[idx]
        is_predicted = predictions[idx] == 1
        pred = "✓" if is_predicted else ""

        # str.ljust counts code points, not display columns, so pad manually.
        padding = 20 - _gbk_display_width(class_name)
        aligned_name = class_name + " " * max(0, padding)
        print(f"{aligned_name} {prob:>10.4f} {pred:>8}")

        if is_predicted:
            predicted_labels.append(class_name)

    print("\n" + "=" * 60)
    print(" 预测标签汇总")
    print("=" * 60)
    if predicted_labels:
        for label in predicted_labels:
            print(f" • {label}")
    else:
        print(" 无预测标签(所有类别概率均低于阈值)")
    print("=" * 60 + "\n")
    return predicted_labels
def main():
    """Command-line entry point: collect images, load the model, run inference.

    Images come from ``--images``, from every supported image file in
    ``--image_dir``, or from both. The process exits with status 1 when the
    image directory does not exist or when no image path was collected.

    Returns:
        The list of predicted label names, so the function can also be called
        from other code.
    """
    parser = argparse.ArgumentParser(description='超声提示多标签分类 - 单步推理')
    parser.add_argument('--images', nargs='*', default=None,
                        help='图像路径列表 (支持多个图像)')
    parser.add_argument('--image_dir', type=str, default=None,
                        help='图像文件夹路径 (自动加载文件夹内所有图像)')
    parser.add_argument('--checkpoint', type=str,
                        default='checkpoints/checkpoint_best.pth',
                        help='模型权重路径')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='分类阈值 (default: 0.5)')
    parser.add_argument('--device', type=str, default='auto',
                        help='设备: auto, cuda, cpu')
    args = parser.parse_args()

    # Collect image paths from both sources.
    image_paths = []
    if args.images:
        image_paths.extend(args.images)

    if args.image_dir:
        if not os.path.isdir(args.image_dir):
            print(f"Error: Image directory not found: {args.image_dir}")
            sys.exit(1)
        image_extensions = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif'}
        # Sorted so the image order is deterministic across runs.
        dir_images = [
            os.path.join(args.image_dir, filename)
            for filename in sorted(os.listdir(args.image_dir))
            if os.path.splitext(filename)[1].lower() in image_extensions
        ]
        image_paths.extend(dir_images)
        # Count only files found in the directory, not those from --images.
        print(f"Found {len(dir_images)} images in {args.image_dir}")

    if not image_paths:
        print("Error: No images specified. Use --images or --image_dir")
        parser.print_help()
        sys.exit(1)

    # Select the device.
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)
    print(f"Using device: {device}")

    # Resolve a relative checkpoint path. The script directory is tried first
    # (so the bundled default works from any working directory); a path
    # relative to the current working directory is used only when the
    # script-relative one does not exist.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    checkpoint_path = args.checkpoint
    if not os.path.isabs(checkpoint_path):
        script_relative = os.path.join(script_dir, checkpoint_path)
        cwd_relative = os.path.abspath(checkpoint_path)
        if not os.path.exists(script_relative) and os.path.exists(cwd_relative):
            checkpoint_path = cwd_relative
        else:
            checkpoint_path = script_relative

    model = load_model(checkpoint_path, device)

    print(f"\nProcessing {len(image_paths)} image(s)...")
    images_batch, valid_paths = preprocess_images(image_paths)
    print(f"Successfully loaded {len(valid_paths)} image(s)")

    # Pass the number of images that actually loaded, not the number requested.
    probs, predictions = predict(model, images_batch, len(valid_paths),
                                 device, args.threshold)

    predicted_labels = format_results(probs, predictions, args.threshold)

    return predicted_labels
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()