|
|
"""
|
|
|
单步推理脚本 - 超声提示多标签分类模型
|
|
|
Single Case Inference for TransMIL + Query2Label Hybrid Model
|
|
|
|
|
|
用法:
|
|
|
# 指定多个图像文件
|
|
|
python infer_single_case.py --images /path/to/img1.png /path/to/img2.png --threshold 0.5
|
|
|
|
|
|
# 指定图像文件夹
|
|
|
python infer_single_case.py --image_dir /path/to/case_folder/ --threshold 0.5
|
|
|
"""
|
|
|
|
|
|
import argparse
import os
import sys
import unicodedata

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
|
|
|
|
|
|
|
|
|
# Put this script's own directory first on sys.path so the project-local
# `models` package imported below resolves no matter what the current
# working directory is.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
|
|
from models.transmil_q2l import TransMIL_Query2Label_E2E
|
|
|
|
|
|
|
|
|
# Human-readable label for each output column of the model: probability
# index i is displayed as TARGET_CLASSES[i], so this order must match the
# model's output order (presumably the label order used during training).
TARGET_CLASSES = [
    # TI-RADS categories
    "TI-RADS 1级",
    "TI-RADS 2级",
    "TI-RADS 3级",
    "TI-RADS 4a级",
    "TI-RADS 4b级",
    "TI-RADS 4c级",
    "TI-RADS 5级",
    # Other findings
    "钙化",
    "甲亢",
    "囊肿",
    "淋巴结",
    "胶质潴留",
    "弥漫性病变",
    "结节性甲状腺肿",
    "桥本氏甲状腺炎",
    "反应性",
    "转移性",
]
|
|
|
|
|
|
|
|
|
def load_model(checkpoint_path: str, device: torch.device):
    """Build the TransMIL + Query2Label model and load trained weights into it.

    Args:
        checkpoint_path: Path to a file written by ``torch.save``. Either a dict
            holding the weights under ``'model_state_dict'`` or a bare state
            dict is accepted.
        device: Device the tensors are mapped to and the model is moved to.

    Returns:
        The model, on ``device`` and in eval mode.
    """
    print(f"Loading model from: {checkpoint_path}")

    # These hyper-parameters must match the trained checkpoint; the strict
    # load_state_dict() below rejects any mismatch in parameter names/shapes.
    model = TransMIL_Query2Label_E2E(
        num_class=len(TARGET_CLASSES),  # one output per displayed label (17)
        hidden_dim=512,
        nheads=8,
        num_decoder_layers=2,
        pretrained_resnet=False,  # weights come from the checkpoint instead
        use_checkpointing=False,
        use_ppeg=False,
    )

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code. Only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        # Assume the file already is a bare state dict.
        state_dict = checkpoint

    # Weights saved from a DataParallel/DDP wrapper typically carry a leading
    # "module." on every key. Strip only that leading prefix: str.replace would
    # also mangle keys that merely contain "module." (e.g. "submodule.x").
    prefix = "module."
    new_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    model.load_state_dict(new_state_dict)

    model.to(device)
    model.eval()
    print("Model loaded successfully!")
    return model
|
|
|
|
|
|
|
|
|
def preprocess_images(image_paths: list, img_size: int = 224):
    """Load and preprocess the images of one case into a single tensor.

    Each image is converted to RGB and resized to ``img_size`` x ``img_size``.
    It is then normalized with the standard ImageNet channel mean/std values.
    Missing or unreadable files are reported with a warning and skipped.

    Args:
        image_paths: Paths of the images to load.
        img_size: Side length the images are resized to.

    Returns:
        A tuple ``(images_batch, valid_paths)``. ``images_batch`` is the stacked
        tensor of shape (num_valid, 3, img_size, img_size). ``valid_paths``
        lists the paths that loaded successfully, in input order.

    Raises:
        ValueError: If none of the images could be loaded.
    """
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    images = []
    valid_paths = []
    for path in image_paths:
        if not os.path.exists(path):
            print(f"Warning: Image not found: {path}")
            continue
        try:
            # The context manager closes the file handle even if decoding
            # fails. convert() forces the pixel data to be read while the
            # file is still open.
            with Image.open(path) as img:
                img_tensor = transform(img.convert('RGB'))
        except Exception as e:
            # Best effort: skip corrupt/unsupported files instead of aborting
            # the whole case.
            print(f"Warning: Failed to load image {path}: {e}")
            continue
        images.append(img_tensor)
        valid_paths.append(path)

    if not images:
        raise ValueError("No valid images found!")

    images_batch = torch.stack(images, dim=0)
    return images_batch, valid_paths
|
|
|
|
|
|
|
|
|
def predict(model, images_batch: torch.Tensor, num_images: int,
            device: torch.device, threshold: float = 0.5):
    """Run one gradient-free forward pass for a single case.

    Args:
        model: Callable invoked as ``model(images, [num_images])``. Its output
            is treated as logits, and only the first row is kept (presumably
            shape (1, num_classes) for one case; confirm against the model).
        images_batch: Stacked, preprocessed images of the case.
        num_images: Image count, forwarded to the model as a one-element list.
        device: Device the images are moved to before the forward pass.
        threshold: A class is predicted when its probability is >= this value.

    Returns:
        A tuple ``(probs, predictions)``. ``probs`` holds the sigmoid
        probabilities of the first output row as a NumPy array.
        ``predictions`` is the matching 0/1 integer array.
    """
    inputs = images_batch.to(device)
    with torch.no_grad():
        raw_logits = model(inputs, [num_images])
        case_probs = torch.sigmoid(raw_logits).cpu().numpy()[0]
    is_positive = case_probs >= threshold
    return case_probs, is_positive.astype(int)
|
|
|
|
|
|
|
|
|
def _display_width(text: str) -> int:
    """Return how many terminal columns *text* occupies.

    Characters whose Unicode East Asian Width is Wide or Fullwidth (e.g. CJK
    ideographs) count as two columns; every other character counts as one.
    """
    return sum(2 if unicodedata.east_asian_width(ch) in ('W', 'F') else 1
               for ch in text)


def _pad(text: str, width: int, align: str = '<') -> str:
    """Pad *text* with spaces to *width* display columns.

    ``align='<'`` left-aligns (pads on the right); any other value right-aligns.
    Text already wider than *width* is returned unchanged.
    """
    fill = " " * max(0, width - _display_width(text))
    return text + fill if align == '<' else fill + text


def format_results(probs: np.ndarray, predictions: np.ndarray, threshold: float,
                   class_names=None):
    """Print a per-class probability table and a summary of predicted labels.

    Args:
        probs: Per-class probabilities, one entry per class.
        predictions: 0/1 array aligned with ``probs``; 1 marks a predicted class.
        threshold: Threshold that produced ``predictions`` (display only).
        class_names: Label for each class index; defaults to TARGET_CLASSES.

    Returns:
        Names of the predicted classes, ordered by descending probability.

    Raises:
        ValueError: If the number of class names differs from ``len(probs)``.
    """
    if class_names is None:
        class_names = TARGET_CLASSES
    if len(class_names) != len(probs):
        raise ValueError(
            f"Got {len(probs)} probabilities but {len(class_names)} class names")

    # Column widths, measured in terminal display columns.
    name_w, prob_w, pred_w = 20, 10, 8

    print("\n" + "=" * 60)
    print(" 超声提示多标签分类结果")
    print("=" * 60)
    print(f" 阈值 (Threshold): {threshold}")
    print("-" * 60)

    # Highest-probability classes first.
    sorted_indices = np.argsort(probs)[::-1]

    # str.format pads by code points, but CJK glyphs are two columns wide, so
    # pad the header and the names by display width to keep the columns aligned.
    print("\n" + _pad('类别', name_w) + " " + _pad('概率', prob_w, '>')
          + " " + _pad('预测', pred_w, '>'))
    print("-" * (name_w + prob_w + pred_w + 2))

    predicted_labels = []
    for idx in sorted_indices:
        class_name = class_names[idx]
        is_predicted = predictions[idx] == 1
        mark = "✓" if is_predicted else ""
        print(f"{_pad(class_name, name_w)} {probs[idx]:>{prob_w}.4f} {mark:>{pred_w}}")
        if is_predicted:
            predicted_labels.append(class_name)

    print("\n" + "=" * 60)
    print(" 预测标签汇总")
    print("=" * 60)
    if predicted_labels:
        for label in predicted_labels:
            print(f" • {label}")
    else:
        print(" 无预测标签(所有类别概率均低于阈值)")
    print("=" * 60 + "\n")

    return predicted_labels
|
|
|
|
|
|
|
|
|
def _collect_image_paths(images, image_dir):
    """Return the paths from --images followed by the image files found in --image_dir.

    Exits the process with status 1 if ``image_dir`` is given but is not a
    directory.
    """
    image_paths = list(images or [])
    if image_dir:
        if not os.path.isdir(image_dir):
            print(f"Error: Image directory not found: {image_dir}")
            sys.exit(1)

        image_extensions = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif'}
        dir_images = []
        for filename in sorted(os.listdir(image_dir)):
            full_path = os.path.join(image_dir, filename)
            if (os.path.splitext(filename)[1].lower() in image_extensions
                    and os.path.isfile(full_path)):
                dir_images.append(full_path)

        # Report only what the directory contributed, not the --images entries.
        print(f"Found {len(dir_images)} images in {image_dir}")
        image_paths.extend(dir_images)
    return image_paths


def main():
    """Command-line entry point: run single-case inference and print the results.

    Returns:
        The predicted label names, as returned by format_results().

    Exits with status 1 in these cases: no images were specified, the image
    directory or checkpoint is missing, or none of the images could be loaded.
    """
    parser = argparse.ArgumentParser(description='超声提示多标签分类 - 单步推理')
    parser.add_argument('--images', nargs='*', default=None,
                        help='图像路径列表 (支持多个图像)')
    parser.add_argument('--image_dir', type=str, default=None,
                        help='图像文件夹路径 (自动加载文件夹内所有图像)')
    parser.add_argument('--checkpoint', type=str,
                        default='checkpoints/checkpoint_best.pth',
                        help='模型权重路径')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='分类阈值 (default: 0.5)')
    parser.add_argument('--device', type=str, default='auto',
                        help='设备: auto, cuda, cpu')
    args = parser.parse_args()

    image_paths = _collect_image_paths(args.images, args.image_dir)
    if not image_paths:
        print("Error: No images specified. Use --images or --image_dir")
        parser.print_help()
        sys.exit(1)

    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)
    print(f"Using device: {device}")

    # Relative checkpoint paths are resolved against this script's directory,
    # not the current working directory.
    checkpoint_path = args.checkpoint
    if not os.path.isabs(checkpoint_path):
        script_dir = os.path.dirname(os.path.abspath(__file__))
        checkpoint_path = os.path.join(script_dir, checkpoint_path)
    if not os.path.isfile(checkpoint_path):
        # Fail with a clear message instead of a torch.load traceback.
        print(f"Error: Checkpoint not found: {checkpoint_path}")
        sys.exit(1)

    model = load_model(checkpoint_path, device)

    print(f"\nProcessing {len(image_paths)} image(s)...")
    try:
        images_batch, valid_paths = preprocess_images(image_paths)
    except ValueError as err:
        # Raised when none of the given images could be loaded.
        print(f"Error: {err}")
        sys.exit(1)
    print(f"Successfully loaded {len(valid_paths)} image(s)")

    probs, predictions = predict(model, images_batch, len(valid_paths),
                                 device, args.threshold)

    predicted_labels = format_results(probs, predictions, args.threshold)
    return predicted_labels
|
|
|
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|
|