File size: 8,144 Bytes
343e05c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
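"""
单步推理脚本 - 超声提示多标签分类模型
Single-case inference for the TransMIL + Query2Label hybrid multi-label model.

用法 (Usage):
    # 指定多个图像文件 (pass individual image files of one case)
    python infer_single_case.py --images /path/to/img1.png /path/to/img2.png --threshold 0.5

    # 指定图像文件夹 (pass a folder containing the images of one case)
    python infer_single_case.py --image_dir /path/to/case_folder/ --threshold 0.5

    # Optional: --checkpoint PATH (default: checkpoints/checkpoint_best.pth,
    #           resolved relative to this script), --device auto|cuda|cpu
"""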

import argparse
import os
import sys
import unicodedata

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# 添加当前目录到路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from models.transmil_q2l import TransMIL_Query2Label_E2E

# The 17 output label names, listed in model output order: the i-th predicted
# probability is reported as TARGET_CLASSES[i]. Presumably this order must
# match the label mapping used at training time — TODO confirm.
TARGET_CLASSES = [
    # TI-RADS grades
    "TI-RADS 1级",
    "TI-RADS 2级",
    "TI-RADS 3级",
    "TI-RADS 4a级",
    "TI-RADS 4b级",
    "TI-RADS 4c级",
    "TI-RADS 5级",
    # Findings / diagnoses
    "钙化",
    "甲亢",
    "囊肿",
    "淋巴结",
    "胶质潴留",
    "弥漫性病变",
    "结节性甲状腺肿",
    "桥本氏甲状腺炎",
    "反应性",
    "转移性",
]


def load_model(checkpoint_path: str, device: torch.device):
    """Build the TransMIL + Query2Label model and load trained weights into it.

    Args:
        checkpoint_path: Path to a file written by ``torch.save``. Either a dict
            holding the weights under ``'model_state_dict'`` (training-checkpoint
            layout) or a bare state_dict.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The model, moved to *device* and switched to eval mode.

    Raises:
        RuntimeError: if the weights do not match the architecture
            (``load_state_dict`` is called in its default strict mode).
    """
    print(f"Loading model from: {checkpoint_path}")

    # Hyper-parameters presumably must match the training configuration.
    # num_class is tied to TARGET_CLASSES so the label list and head cannot drift.
    model = TransMIL_Query2Label_E2E(
        num_class=len(TARGET_CLASSES),
        hidden_dim=512,
        nheads=8,
        num_decoder_layers=2,
        pretrained_resnet=False,  # weights come from the checkpoint; no download needed
        use_checkpointing=False,  # gradient checkpointing is useless without backward
        use_ppeg=False
    )

    # SECURITY: weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Accept both a full training checkpoint and a bare state_dict.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    # Strip a leading "module." wrapper prefix (e.g. from DataParallel-style
    # wrapping during training). Only the *leading* prefix is removed:
    # str.replace would also mangle keys that contain "module." elsewhere.
    prefix = "module."
    new_state_dict = {}
    for key, value in state_dict.items():
        while key.startswith(prefix):
            key = key[len(prefix):]
        new_state_dict[key] = value
    model.load_state_dict(new_state_dict)

    model.to(device)
    model.eval()
    print("Model loaded successfully!")
    return model


def preprocess_images(image_paths: list, img_size: int = 224):
    """Load, resize and normalize the images of one case.

    Missing or unreadable files are skipped with a printed warning rather than
    aborting the whole case.

    Args:
        image_paths: Paths of the images belonging to one case.
        img_size: Side length every image is resized to (square; aspect ratio
            is not preserved).

    Returns:
        tuple: ``(images_batch, valid_paths)``. *images_batch* is a tensor of
        shape ``[N, 3, img_size, img_size]`` stacking the N images that loaded
        successfully. *valid_paths* lists their paths in the same order.

    Raises:
        ValueError: if none of the paths could be loaded.
    """
    # Mean/std are the ImageNet statistics; presumably they match the
    # normalization used during training — TODO confirm.
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    images = []
    valid_paths = []

    for path in image_paths:
        if not os.path.exists(path):
            print(f"Warning: Image not found: {path}")
            continue
        try:
            # Context manager releases the file handle right away. convert()
            # returns a new in-memory image, so closing the source is safe.
            with Image.open(path) as raw_img:
                img = raw_img.convert('RGB')
            img_tensor = transform(img)
        except Exception as e:  # best-effort: skip any image that fails to decode
            print(f"Warning: Failed to load image {path}: {e}")
            continue
        images.append(img_tensor)
        valid_paths.append(path)

    if not images:
        raise ValueError("No valid images found!")

    # Stack to [N, C, H, W]. The model takes the plain image stack for the case;
    # no extra batch dimension is added.
    images_batch = torch.stack(images, dim=0)

    return images_batch, valid_paths


def predict(model, images_batch: torch.Tensor, num_images: int,
            device: torch.device, threshold: float = 0.5):
    """Run one forward pass for a single case and threshold the class scores.

    Args:
        model: Callable invoked as ``model(images, [num_images])``, returning
            logits whose first row is taken as this case's per-class scores.
        images_batch: Stacked images of the case; moved to *device* here.
        num_images: Bag size passed to the model as a one-element list.
            Presumably equal to ``images_batch.shape[0]`` — TODO confirm.
        device: Target device for the inputs.
        threshold: Probability at or above which a class counts as predicted.

    Returns:
        tuple: ``(probs, predictions)``. Both are NumPy arrays: the sigmoid
        probabilities and the 0/1 integer decisions for the first row.
    """
    bag_sizes = [num_images]
    inputs = images_batch.to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        scores = torch.sigmoid(model(inputs, bag_sizes))
        probs = scores.cpu().numpy()[0]

    above = probs >= threshold
    predictions = above.astype(int)
    return probs, predictions


def display_width(text: str) -> int:
    """Return the number of terminal columns *text* occupies.

    East Asian wide ('W') and fullwidth ('F') characters, such as the CJK
    characters in the class names, take two columns; all others take one.
    """
    return sum(
        2 if unicodedata.east_asian_width(char) in ('W', 'F') else 1
        for char in text
    )


def pad_display(text: str, width: int, align: str = 'left') -> str:
    """Pad *text* with spaces to *width* terminal columns.

    f-string / ``str.ljust`` alignment counts code points, so mixed CJK/ASCII
    columns drift. This pads by display width instead. Text already wider than
    *width* is returned unchanged.

    Args:
        text: String to pad.
        width: Target width in terminal columns.
        align: ``'left'`` pads on the right; ``'right'`` pads on the left.
    """
    padding = " " * max(0, width - display_width(text))
    return padding + text if align == 'right' else text + padding


def format_results(probs: np.ndarray, predictions: np.ndarray, threshold: float,
                   class_names=None):
    """Print a per-class probability table and a summary of predicted labels.

    Args:
        probs: Per-class probabilities, indexed like *class_names*.
        predictions: Per-class 0/1 decisions aligned with *probs*.
        threshold: Threshold that produced *predictions*; only displayed.
        class_names: Label name per output index. Defaults to ``TARGET_CLASSES``.

    Returns:
        list[str]: Names of the classes whose prediction is 1, ordered by
        descending probability.
    """
    if class_names is None:
        class_names = TARGET_CLASSES

    print("\n" + "=" * 60)
    print(" 超声提示多标签分类结果")
    print("=" * 60)
    print(f" 阈值 (Threshold): {threshold}")
    print("-" * 60)

    # Highest probability first
    sorted_indices = np.argsort(probs)[::-1]

    # Pad the header by display width as well, so the 2-column CJK headings
    # line up with the data columns below.
    header = (f"{pad_display('类别', 20)} "
              f"{pad_display('概率', 10, 'right')} "
              f"{pad_display('预测', 8, 'right')}")
    print(f"\n{header}")
    print("-" * 40)

    predicted_labels = []
    for idx in sorted_indices:
        class_name = class_names[idx]
        is_positive = predictions[idx] == 1
        mark = "✓" if is_positive else ""

        print(f"{pad_display(class_name, 20)} {probs[idx]:>10.4f} {mark:>8}")

        if is_positive:
            predicted_labels.append(class_name)

    print("\n" + "=" * 60)
    print(" 预测标签汇总")
    print("=" * 60)

    if predicted_labels:
        for label in predicted_labels:
            print(f"  • {label}")
    else:
        print("  无预测标签(所有类别概率均低于阈值)")

    print("=" * 60 + "\n")

    return predicted_labels


def main(argv=None):
    """Command-line entry point: classify one case from its ultrasound images.

    Images come from ``--images`` and/or every supported file directly inside
    ``--image_dir``. Results are printed to stdout.

    Args:
        argv: Argument list to parse. ``None`` (the default) means ``sys.argv[1:]``;
            passing a list allows ``main([...])`` to be called from Python.

    Returns:
        list[str]: The predicted label names (for programmatic callers).

    Exits with status 1 when the image directory or checkpoint is missing, or
    when none of the given images could be loaded.
    """
    parser = argparse.ArgumentParser(description='超声提示多标签分类 - 单步推理')
    parser.add_argument('--images', nargs='*', default=None,
                        help='图像路径列表 (支持多个图像)')
    parser.add_argument('--image_dir', type=str, default=None,
                        help='图像文件夹路径 (自动加载文件夹内所有图像)')
    parser.add_argument('--checkpoint', type=str,
                        default='checkpoints/checkpoint_best.pth',
                        help='模型权重路径')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='分类阈值 (default: 0.5)')
    parser.add_argument('--device', type=str, default='auto',
                        help='设备: auto, cuda, cpu')

    args = parser.parse_args(argv)

    # Collect image paths: explicit --images first, then the directory contents.
    image_paths = list(args.images or [])

    if args.image_dir:
        if not os.path.isdir(args.image_dir):
            print(f"Error: Image directory not found: {args.image_dir}")
            sys.exit(1)

        # Supported image formats (suffix matched case-insensitively)
        image_extensions = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif'}

        dir_images = []
        for filename in sorted(os.listdir(args.image_dir)):
            full_path = os.path.join(args.image_dir, filename)
            ext = os.path.splitext(filename)[1].lower()
            # Skip sub-directories that happen to carry an image-like suffix.
            if ext in image_extensions and os.path.isfile(full_path):
                dir_images.append(full_path)
        image_paths.extend(dir_images)

        # Count only the directory's images, not the --images entries as well.
        print(f"Found {len(dir_images)} images in {args.image_dir}")

    if not image_paths:
        print("Error: No images specified. Use --images or --image_dir")
        parser.print_help()
        sys.exit(1)

    # Select device
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)
    print(f"Using device: {device}")

    # A relative checkpoint path is resolved against the script directory
    # (original behavior). It falls back to the CWD only when the file exists
    # there but not next to the script.
    checkpoint_path = args.checkpoint
    if not os.path.isabs(checkpoint_path):
        script_dir = os.path.dirname(os.path.abspath(__file__))
        script_relative = os.path.join(script_dir, checkpoint_path)
        if os.path.exists(script_relative) or not os.path.exists(checkpoint_path):
            checkpoint_path = script_relative
    if not os.path.isfile(checkpoint_path):
        print(f"Error: Checkpoint not found: {checkpoint_path}")
        sys.exit(1)

    # Preprocess before loading the (expensive) model so a case with no
    # readable images fails fast with a clear message instead of a traceback.
    print(f"\nProcessing {len(image_paths)} image(s)...")
    try:
        images_batch, valid_paths = preprocess_images(image_paths)
    except ValueError as err:
        print(f"Error: {err}")
        sys.exit(1)
    print(f"Successfully loaded {len(valid_paths)} image(s)")

    model = load_model(checkpoint_path, device)

    probs, predictions = predict(model, images_batch, len(valid_paths),
                                 device, args.threshold)

    predicted_labels = format_results(probs, predictions, args.threshold)

    # Returned for programmatic callers; ignored when run as a script.
    return predicted_labels


if __name__ == "__main__":
    # Script entry point. main() returns the predicted labels, but the return
    # value is discarded here, so the process exit status stays 0 on success.
    main()