# File size: 8,144 bytes (revision 343e05c)
"""
单步推理脚本 - 超声提示多标签分类模型
Single Case Inference for TransMIL + Query2Label Hybrid Model
用法:
# 指定多个图像文件
python infer_single_case.py --images /path/to/img1.png /path/to/img2.png --threshold 0.5
# 指定图像文件夹
python infer_single_case.py --image_dir /path/to/case_folder/ --threshold 0.5
"""
import os
import sys
import argparse
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
# 添加当前目录到路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from models.transmil_q2l import TransMIL_Query2Label_E2E
# The 17 output labels. The position of a name in this list is the class
# index used when per-class probabilities are mapped back to label names.
TARGET_CLASSES = [
    # TI-RADS grades
    "TI-RADS 1级",
    "TI-RADS 2级",
    "TI-RADS 3级",
    "TI-RADS 4a级",
    "TI-RADS 4b级",
    "TI-RADS 4c级",
    "TI-RADS 5级",
    # Remaining labels
    "钙化",
    "甲亢",
    "囊肿",
    "淋巴结",
    "胶质潴留",
    "弥漫性病变",
    "结节性甲状腺肿",
    "桥本氏甲状腺炎",
    "反应性",
    "转移性",
]
def _strip_module_prefix(state_dict: dict, prefix: str = "module.") -> dict:
    """Return a copy of *state_dict* with a leading *prefix* removed from keys.

    Only a prefix at the very start of a key is dropped, so a key that merely
    contains the text (for example ``"submodule.weight"``) is left intact.
    """
    offset = len(prefix)
    return {
        (key[offset:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model(checkpoint_path: str, device: torch.device):
    """Build the TransMIL + Query2Label model and load checkpoint weights.

    Args:
        checkpoint_path: Path of a checkpoint readable by ``torch.load`` that
            holds the weights under the ``'model_state_dict'`` key.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The model, moved to ``device`` and switched to evaluation mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    print(f"Loading model from: {checkpoint_path}")

    # Build the network; the weights come from the checkpoint below.
    model = TransMIL_Query2Label_E2E(
        num_class=17,
        hidden_dim=512,
        nheads=8,
        num_decoder_layers=2,
        pretrained_resnet=False,  # no need to download pretrained weights for inference
        use_checkpointing=False,  # checkpointing is not needed for inference
        use_ppeg=False
    )

    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, so only checkpoints from a trusted source should be loaded.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Keys may carry a leading "module." (e.g. saved from a wrapped model).
    # Strip it only at the start of the key so the rest of the name is kept.
    state_dict = _strip_module_prefix(checkpoint['model_state_dict'])
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    print("Model loaded successfully!")
    return model
def preprocess_images(image_paths: list, img_size: int = 224):
    """Load image files and stack them into one normalized tensor.

    Paths that do not exist or that fail to load are reported with a warning
    and skipped.

    Args:
        image_paths: Paths of the image files to load.
        img_size: Height and width every image is resized to.

    Returns:
        A tuple ``(images_batch, valid_paths)``: the transformed images
        stacked along dim 0 in input order, and the paths that were loaded.

    Raises:
        ValueError: If no image could be loaded (including empty input).
    """
    if not image_paths:
        # Nothing to load: fail with the same error as the all-invalid case.
        raise ValueError("No valid images found!")

    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        # Per-channel mean/std commonly used for ImageNet-pretrained backbones.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    images = []
    valid_paths = []
    for path in image_paths:
        if not os.path.exists(path):
            print(f"Warning: Image not found: {path}")
            continue
        try:
            # The context manager closes the underlying file as soon as the
            # pixels are converted instead of waiting for garbage collection.
            with Image.open(path) as img:
                rgb_img = img.convert('RGB')
            img_tensor = transform(rgb_img)
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole case.
            print(f"Warning: Failed to load image {path}: {e}")
            continue
        images.append(img_tensor)
        valid_paths.append(path)

    if not images:
        raise ValueError("No valid images found!")

    # Stack to [N, C, H, W]; the model takes the stacked images directly,
    # without an extra batch dimension.
    images_batch = torch.stack(images, dim=0)
    return images_batch, valid_paths
def predict(model, images_batch: torch.Tensor, num_images: int,
            device: torch.device, threshold: float = 0.5):
    """Run the model on one case and threshold the class probabilities.

    Args:
        model: Callable invoked as ``model(images, [num_images])``.
        images_batch: Stacked image tensor for the case.
        num_images: Number of images in the case, passed to the model as a
            one-element list.
        device: Device the images are moved to before the forward pass.
        threshold: A class is predicted when its probability is at least
            this value.

    Returns:
        A tuple ``(probs, predictions)`` of NumPy arrays: the sigmoid of the
        first row of the model output, and the matching 0/1 decisions.
    """
    inputs = images_batch.to(device)
    counts = [num_images]

    # Gradients are not needed for inference.
    with torch.no_grad():
        scores = torch.sigmoid(model(inputs, counts))

    # Keep the first (only) row of the output: one probability per class.
    probs = scores.cpu().numpy()[0]

    # Binarize against the threshold.
    predictions = np.greater_equal(probs, threshold).astype(int)
    return probs, predictions
def _display_width(text: str) -> int:
    """Estimate how many terminal columns *text* occupies.

    The GBK byte length is used as the estimate (one byte per ASCII
    character, two per Chinese character). Text that GBK cannot encode falls
    back to two columns per character.
    """
    try:
        return len(text.encode('gbk'))
    except UnicodeEncodeError:
        return len(text) * 2


def format_results(probs: np.ndarray, predictions: np.ndarray, threshold: float,
                   class_names=None):
    """Print a per-class probability table and return the predicted labels.

    Args:
        probs: Per-class probabilities, indexed by class.
        predictions: Per-class 0/1 decisions, indexed like ``probs``.
        threshold: Threshold that was applied; only shown in the report.
        class_names: Label name for each class index. Defaults to the
            module-level ``TARGET_CLASSES``.

    Returns:
        Names of the classes whose prediction is 1, ordered by descending
        probability.
    """
    if class_names is None:
        class_names = TARGET_CLASSES

    print("\n" + "=" * 60)
    print(" 超声提示多标签分类结果")
    print("=" * 60)
    print(f" 阈值 (Threshold): {threshold}")
    print("-" * 60)

    # Highest probability first.
    sorted_indices = np.argsort(probs)[::-1]
    print(f"\n{'类别':<20} {'概率':>10} {'预测':>8}")
    print("-" * 40)

    predicted_labels = []
    for idx in sorted_indices:
        class_name = class_names[idx]
        prob = probs[idx]
        is_predicted = predictions[idx] == 1
        pred = "✓" if is_predicted else ""

        # Pad by display width rather than character count so that labels
        # mixing ASCII and Chinese characters line up in the name column.
        padding = 20 - _display_width(class_name)
        aligned_name = class_name + " " * max(0, padding)
        print(f"{aligned_name} {prob:>10.4f} {pred:>8}")

        if is_predicted:
            predicted_labels.append(class_name)

    print("\n" + "=" * 60)
    print(" 预测标签汇总")
    print("=" * 60)
    if predicted_labels:
        for label in predicted_labels:
            print(f" • {label}")
    else:
        print(" 无预测标签(所有类别概率均低于阈值)")
    print("=" * 60 + "\n")
    return predicted_labels
def _collect_image_paths(args) -> list:
    """Gather image paths from the ``--images`` and ``--image_dir`` options.

    Exits with status 1 when ``--image_dir`` is not an existing directory.
    """
    image_paths = list(args.images or [])

    if args.image_dir:
        if not os.path.isdir(args.image_dir):
            print(f"Error: Image directory not found: {args.image_dir}")
            sys.exit(1)
        # Supported image formats.
        image_extensions = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif'}
        dir_paths = [
            os.path.join(args.image_dir, filename)
            for filename in sorted(os.listdir(args.image_dir))
            if os.path.splitext(filename)[1].lower() in image_extensions
        ]
        # Count only the files found in the directory, not those that were
        # passed explicitly via --images.
        print(f"Found {len(dir_paths)} images in {args.image_dir}")
        image_paths.extend(dir_paths)

    return image_paths


def _resolve_checkpoint_path(checkpoint_path: str) -> str:
    """Resolve a possibly relative checkpoint path.

    A relative path is looked up next to this script first. If nothing exists
    there but the path exists relative to the current working directory, the
    working-directory location is used instead.
    """
    if os.path.isabs(checkpoint_path):
        return checkpoint_path
    script_dir = os.path.dirname(os.path.abspath(__file__))
    beside_script = os.path.join(script_dir, checkpoint_path)
    if not os.path.exists(beside_script) and os.path.exists(checkpoint_path):
        return os.path.abspath(checkpoint_path)
    return beside_script


def main(argv=None):
    """Command-line entry point: classify one case and print the report.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv``, so the script behaves as before when run
            directly; passing a list allows calling this from other code.

    Returns:
        The list of predicted label names.

    Exits with status 1 when no image is given or ``--image_dir`` is invalid.
    """
    parser = argparse.ArgumentParser(description='超声提示多标签分类 - 单步推理')
    parser.add_argument('--images', nargs='*', default=None,
                        help='图像路径列表 (支持多个图像)')
    parser.add_argument('--image_dir', type=str, default=None,
                        help='图像文件夹路径 (自动加载文件夹内所有图像)')
    parser.add_argument('--checkpoint', type=str,
                        default='checkpoints/checkpoint_best.pth',
                        help='模型权重路径')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='分类阈值 (default: 0.5)')
    parser.add_argument('--device', type=str, default='auto',
                        help='设备: auto, cuda, cpu')
    args = parser.parse_args(argv)

    # Collect the input images and make sure there is at least one.
    image_paths = _collect_image_paths(args)
    if not image_paths:
        print("Error: No images specified. Use --images or --image_dir")
        parser.print_help()
        sys.exit(1)

    # Pick the device; 'auto' prefers CUDA when it is available.
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)
    print(f"Using device: {device}")

    # Load the model.
    checkpoint_path = _resolve_checkpoint_path(args.checkpoint)
    model = load_model(checkpoint_path, device)

    # Preprocess the images.
    print(f"\nProcessing {len(image_paths)} image(s)...")
    images_batch, valid_paths = preprocess_images(image_paths)
    print(f"Successfully loaded {len(valid_paths)} image(s)")

    # Run inference on the images that were actually loaded.
    probs, predictions = predict(model, images_batch, len(valid_paths),
                                 device, args.threshold)

    # Print the report; the labels are returned for programmatic callers.
    predicted_labels = format_results(probs, predictions, args.threshold)
    return predicted_labels


if __name__ == "__main__":
    main()
|