Doul0414 committed · Commit 343e05c (verified) · 1 Parent(s): facde7c

Initial upload: HintsPrediction
README.md ADDED
# Ultrasound Hints Multi-Label Classification Model

A multi-label classification model for thyroid ultrasound images, built on a hybrid **TransMIL + Query2Label** architecture.

## Model Architecture

- **Backbone**: ResNet-50 (pretrained)
- **Feature aggregation**: TransMIL (Nyström Attention)
- **Multi-label classification**: Query2Label (Transformer decoder)
- **Loss function**: Asymmetric Loss (to handle class imbalance)

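The repo's `aslloss.py` implements the Asymmetric Loss named above. As a reference, here is a minimal per-label sketch of the standard formulation (Ridnik et al.); the hyperparameter values are illustrative defaults, not necessarily the ones used in training:

```python
import math

def asymmetric_loss(p, y, gamma_pos=0.0, gamma_neg=4.0, clip=0.05):
    """Asymmetric Loss for one label.

    p: predicted probability after sigmoid; y: 0/1 target.
    Negatives get a probability shift (clip) and a stronger focusing
    exponent gamma_neg, which down-weights easy negatives so the many
    absent labels do not dominate the gradient.
    """
    eps = 1e-8
    if y == 1:
        # standard focal-style positive term
        return -((1.0 - p) ** gamma_pos) * math.log(p + eps)
    # probability shifting: negatives below `clip` contribute nothing
    pm = max(p - clip, 0.0)
    return -(pm ** gamma_neg) * math.log(1.0 - pm + eps)
```

With `clip=0.05`, a confidently-rejected negative (p = 0.04) incurs zero loss, while a hard negative (p = 0.9) is penalized heavily.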
## The 17 Labels

| # | Label | # | Label |
|:---:|:---|:---:|:---|
| 1 | TI-RADS 1 | 10 | Cyst |
| 2 | TI-RADS 2 | 11 | Lymph node |
| 3 | TI-RADS 3 | 12 | Colloid retention |
| 4 | TI-RADS 4a | 13 | Diffuse lesion |
| 5 | TI-RADS 4b | 14 | Nodular goiter |
| 6 | TI-RADS 4c | 15 | Hashimoto's thyroiditis |
| 7 | TI-RADS 5 | 16 | Reactive |
| 8 | Calcification | 17 | Metastatic |
| 9 | Hyperthyroidism | | |

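At inference time the model emits one sigmoid probability per label, and predicted labels are those above the threshold. A minimal sketch of that decoding step (label names translated from the table above, order assumed; the helper is illustrative, not the repo's actual API):

```python
# The 17 labels from the table above (order assumed; names translated).
LABELS = [
    "TI-RADS 1", "TI-RADS 2", "TI-RADS 3", "TI-RADS 4a", "TI-RADS 4b",
    "TI-RADS 4c", "TI-RADS 5", "Calcification", "Hyperthyroidism",
    "Cyst", "Lymph node", "Colloid retention", "Diffuse lesion",
    "Nodular goiter", "Hashimoto's thyroiditis", "Reactive", "Metastatic",
]

def decode_predictions(probs, threshold=0.5):
    """Map 17 per-label sigmoid probabilities to predicted label names."""
    return [name for name, p in zip(LABELS, probs) if p >= threshold]
```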
## Directory Structure

```
HintsPrediction/
├── README.md                # This file
├── requirements.txt         # Dependency list
├── config.yaml              # Configuration file (edit the data paths)
├── models/                  # Model code
│   ├── __init__.py
│   ├── transmil_q2l.py      # Main model architecture
│   ├── transformer.py       # Transformer components
│   └── aslloss.py           # Loss function
├── checkpoints/
│   └── checkpoint_best.pth  # Best model weights
├── scripts/                 # Convenience scripts
│   ├── train.sh
│   ├── evaluate.sh
│   └── infer_single.sh
├── train_hybrid.py          # Training script
├── evaluate.py              # Evaluation script
├── thyroid_dataset.py       # Dataset loading
└── infer_single_case.py     # Single-case inference script
```

## Quick Start

### 1. Environment Setup

```bash
# Install dependencies
pip install -r requirements.txt
```

### 2. Single-Case Inference

```bash
# Pass one or more image files
python infer_single_case.py --images /path/to/img1.png /path/to/img2.png --threshold 0.5

# Or pass a folder of images
python infer_single_case.py --image_dir /path/to/case_folder/ --threshold 0.5

# Or use the helper script
bash scripts/infer_single.sh /path/to/image1.png /path/to/image2.png ...
bash scripts/infer_single.sh --image_dir /path/to/case_folder/
```

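In `--image_dir` mode the script presumably gathers every image file in the case folder before running the model. A minimal sketch of such collection (the function name and extension list are assumptions, not the script's actual internals):

```python
from pathlib import Path

# Common ultrasound export formats (assumed; adjust to your data).
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp"}

def collect_case_images(image_dir):
    """Return sorted image paths for one case folder."""
    return sorted(
        p for p in Path(image_dir).iterdir()
        if p.suffix.lower() in IMAGE_EXTS
    )
```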
### 3. Evaluate the Model

```bash
# First edit the data paths in config.yaml,
# then run the evaluation
python evaluate.py
# or
bash scripts/evaluate.sh
```

### 4. Train the Model

```bash
# First edit the data paths in config.yaml
python train_hybrid.py --config config.yaml
# or
bash scripts/train.sh
```

## Configuration

Before use, edit the data paths in `config.yaml`:

```yaml
data:
  data_root: "/path/to/your/ReportData_ROI/"
  annotation_csv: "/path/to/your/thyroid_multilabel_annotations.csv"
  val_json: "/path/to/your/classification_val_set_single.json"
  test_json: "/path/to/your/classification_test_set_single.json"
```

## Performance

For test-set metrics, see `checkpoints/evaluation_report.csv`.

## Notes

1. A GPU is recommended for inference; CPU also works, but more slowly.
2. A single case may contain multiple images; the model aggregates their features automatically.
3. The default decision threshold is 0.5 and can be adjusted as needed.
checkpoints/checkpoint_best.pth ADDED
```
version https://git-lfs.github.com/spec/v1
oid sha256:fcbcfb29005eaf6d02bcb8c94be860d885f6fecf9cdcea40be38e7d96050f44b
size 416241673
```
classification_test_set_single.json ADDED
(diff too large to render)
 
classification_val_set_single.json ADDED
(1,688 lines)
```json
[
  { "type": "single", "id": "Batch10_20250506_P323", "rel_path": "Batch10/20250506_P323", "score": 1.0 },
  { "type": "single", "id": "Batch10_20250506_P424", "rel_path": "Batch10/20250506_P424", "score": 1.0 },
  { "type": "single", "id": "Batch10_20250504_P199", "rel_path": "Batch10/20250504_P199", "score": 1.0 }
]
```

(excerpt; the remaining entries follow the same schema)
1163
+ "rel_path": "Batch10/20250506_P288",
1164
+ "score": 1.0
1165
+ },
1166
+ {
1167
+ "type": "single",
1168
+ "id": "Batch10_20250509_P928",
1169
+ "rel_path": "Batch10/20250509_P928",
1170
+ "score": 1.0
1171
+ },
1172
+ {
1173
+ "type": "single",
1174
+ "id": "Batch10_20250503_P83",
1175
+ "rel_path": "Batch10/20250503_P83",
1176
+ "score": 1.0
1177
+ },
1178
+ {
1179
+ "type": "single",
1180
+ "id": "Batch10_20250506_P362",
1181
+ "rel_path": "Batch10/20250506_P362",
1182
+ "score": 1.0
1183
+ },
1184
+ {
1185
+ "type": "single",
1186
+ "id": "Batch10_20250503_P82",
1187
+ "rel_path": "Batch10/20250503_P82",
1188
+ "score": 1.0
1189
+ },
1190
+ {
1191
+ "type": "single",
1192
+ "id": "Batch10_20250506_P511",
1193
+ "rel_path": "Batch10/20250506_P511",
1194
+ "score": 1.0
1195
+ },
1196
+ {
1197
+ "type": "single",
1198
+ "id": "Batch10_20250507_P525",
1199
+ "rel_path": "Batch10/20250507_P525",
1200
+ "score": 1.0
1201
+ },
1202
+ {
1203
+ "type": "single",
1204
+ "id": "Batch10_20250512_P1302",
1205
+ "rel_path": "Batch10/20250512_P1302",
1206
+ "score": 1.0
1207
+ },
1208
+ {
1209
+ "type": "single",
1210
+ "id": "Batch10_20250506_P283",
1211
+ "rel_path": "Batch10/20250506_P283",
1212
+ "score": 1.0
1213
+ },
1214
+ {
1215
+ "type": "single",
1216
+ "id": "Batch10_20250508_P733",
1217
+ "rel_path": "Batch10/20250508_P733",
1218
+ "score": 1.0
1219
+ },
1220
+ {
1221
+ "type": "single",
1222
+ "id": "Batch10_20250506_P462",
1223
+ "rel_path": "Batch10/20250506_P462",
1224
+ "score": 1.0
1225
+ },
1226
+ {
1227
+ "type": "single",
1228
+ "id": "Batch10_20250511_P1215",
1229
+ "rel_path": "Batch10/20250511_P1215",
1230
+ "score": 1.0
1231
+ },
1232
+ {
1233
+ "type": "single",
1234
+ "id": "Batch10_20250504_P197",
1235
+ "rel_path": "Batch10/20250504_P197",
1236
+ "score": 1.0
1237
+ },
1238
+ {
1239
+ "type": "single",
1240
+ "id": "Batch10_20250501_P13",
1241
+ "rel_path": "Batch10/20250501_P13",
1242
+ "score": 1.0
1243
+ },
1244
+ {
1245
+ "type": "single",
1246
+ "id": "Batch10_20250506_P407",
1247
+ "rel_path": "Batch10/20250506_P407",
1248
+ "score": 1.0
1249
+ },
1250
+ {
1251
+ "type": "single",
1252
+ "id": "Batch10_20250506_P400",
1253
+ "rel_path": "Batch10/20250506_P400",
1254
+ "score": 1.0
1255
+ },
1256
+ {
1257
+ "type": "single",
1258
+ "id": "Batch10_20250508_P836",
1259
+ "rel_path": "Batch10/20250508_P836",
1260
+ "score": 1.0
1261
+ },
1262
+ {
1263
+ "type": "single",
1264
+ "id": "Batch10_20250506_P369",
1265
+ "rel_path": "Batch10/20250506_P369",
1266
+ "score": 1.0
1267
+ },
1268
+ {
1269
+ "type": "single",
1270
+ "id": "Batch10_20250507_P570",
1271
+ "rel_path": "Batch10/20250507_P570",
1272
+ "score": 1.0
1273
+ },
1274
+ {
1275
+ "type": "single",
1276
+ "id": "Batch10_20250505_P269",
1277
+ "rel_path": "Batch10/20250505_P269",
1278
+ "score": 1.0
1279
+ },
1280
+ {
1281
+ "type": "single",
1282
+ "id": "Batch10_20250505_P248",
1283
+ "rel_path": "Batch10/20250505_P248",
1284
+ "score": 1.0
1285
+ },
1286
+ {
1287
+ "type": "single",
1288
+ "id": "Batch10_20250501_P37",
1289
+ "rel_path": "Batch10/20250501_P37",
1290
+ "score": 1.0
1291
+ },
1292
+ {
1293
+ "type": "single",
1294
+ "id": "Batch10_20250512_P1492",
1295
+ "rel_path": "Batch10/20250512_P1492",
1296
+ "score": 1.0
1297
+ },
1298
+ {
1299
+ "type": "single",
1300
+ "id": "Batch10_20250507_P679",
1301
+ "rel_path": "Batch10/20250507_P679",
1302
+ "score": 1.0
1303
+ },
1304
+ {
1305
+ "type": "single",
1306
+ "id": "Batch10_20250506_P491",
1307
+ "rel_path": "Batch10/20250506_P491",
1308
+ "score": 1.0
1309
+ },
1310
+ {
1311
+ "type": "single",
1312
+ "id": "Batch10_20250506_P289",
1313
+ "rel_path": "Batch10/20250506_P289",
1314
+ "score": 1.0
1315
+ },
1316
+ {
1317
+ "type": "single",
1318
+ "id": "Batch10_20250509_P1001",
1319
+ "rel_path": "Batch10/20250509_P1001",
1320
+ "score": 1.0
1321
+ },
1322
+ {
1323
+ "type": "single",
1324
+ "id": "Batch10_20250507_P522",
1325
+ "rel_path": "Batch10/20250507_P522",
1326
+ "score": 1.0
1327
+ },
1328
+ {
1329
+ "type": "single",
1330
+ "id": "Batch10_20250509_P1034",
1331
+ "rel_path": "Batch10/20250509_P1034",
1332
+ "score": 1.0
1333
+ },
1334
+ {
1335
+ "type": "single",
1336
+ "id": "Batch10_20250505_P280",
1337
+ "rel_path": "Batch10/20250505_P280",
1338
+ "score": 1.0
1339
+ },
1340
+ {
1341
+ "type": "single",
1342
+ "id": "Batch10_20250501_P30",
1343
+ "rel_path": "Batch10/20250501_P30",
1344
+ "score": 1.0
1345
+ },
1346
+ {
1347
+ "type": "single",
1348
+ "id": "Batch10_20250512_P1411",
1349
+ "rel_path": "Batch10/20250512_P1411",
1350
+ "score": 1.0
1351
+ },
1352
+ {
1353
+ "type": "single",
1354
+ "id": "Batch10_20250503_P78",
1355
+ "rel_path": "Batch10/20250503_P78",
1356
+ "score": 1.0
1357
+ },
1358
+ {
1359
+ "type": "single",
1360
+ "id": "Batch10_20250512_P1404",
1361
+ "rel_path": "Batch10/20250512_P1404",
1362
+ "score": 1.0
1363
+ },
1364
+ {
1365
+ "type": "single",
1366
+ "id": "Batch10_20250501_P11",
1367
+ "rel_path": "Batch10/20250501_P11",
1368
+ "score": 1.0
1369
+ },
1370
+ {
1371
+ "type": "single",
1372
+ "id": "Batch10_20250512_P1363",
1373
+ "rel_path": "Batch10/20250512_P1363",
1374
+ "score": 1.0
1375
+ },
1376
+ {
1377
+ "type": "single",
1378
+ "id": "Batch10_20250502_P55",
1379
+ "rel_path": "Batch10/20250502_P55",
1380
+ "score": 1.0
1381
+ },
1382
+ {
1383
+ "type": "single",
1384
+ "id": "Batch10_20250506_P409",
1385
+ "rel_path": "Batch10/20250506_P409",
1386
+ "score": 1.0
1387
+ },
1388
+ {
1389
+ "type": "single",
1390
+ "id": "Batch10_20250502_P53",
1391
+ "rel_path": "Batch10/20250502_P53",
1392
+ "score": 1.0
1393
+ },
1394
+ {
1395
+ "type": "single",
1396
+ "id": "Batch10_20250510_P1197",
1397
+ "rel_path": "Batch10/20250510_P1197",
1398
+ "score": 1.0
1399
+ },
1400
+ {
1401
+ "type": "single",
1402
+ "id": "Batch10_20250512_P1454",
1403
+ "rel_path": "Batch10/20250512_P1454",
1404
+ "score": 1.0
1405
+ },
1406
+ {
1407
+ "type": "single",
1408
+ "id": "Batch10_20250512_P1406",
1409
+ "rel_path": "Batch10/20250512_P1406",
1410
+ "score": 1.0
1411
+ },
1412
+ {
1413
+ "type": "single",
1414
+ "id": "Batch10_20250504_P148",
1415
+ "rel_path": "Batch10/20250504_P148",
1416
+ "score": 1.0
1417
+ },
1418
+ {
1419
+ "type": "single",
1420
+ "id": "Batch10_20250507_P660",
1421
+ "rel_path": "Batch10/20250507_P660",
1422
+ "score": 1.0
1423
+ },
1424
+ {
1425
+ "type": "single",
1426
+ "id": "Batch10_20250509_P856",
1427
+ "rel_path": "Batch10/20250509_P856",
1428
+ "score": 1.0
1429
+ },
1430
+ {
1431
+ "type": "single",
1432
+ "id": "Batch10_20250501_P26",
1433
+ "rel_path": "Batch10/20250501_P26",
1434
+ "score": 1.0
1435
+ },
1436
+ {
1437
+ "type": "single",
1438
+ "id": "Batch10_20250506_P387",
1439
+ "rel_path": "Batch10/20250506_P387",
1440
+ "score": 1.0
1441
+ },
1442
+ {
1443
+ "type": "single",
1444
+ "id": "Batch10_20250508_P686",
1445
+ "rel_path": "Batch10/20250508_P686",
1446
+ "score": 1.0
1447
+ },
1448
+ {
1449
+ "type": "single",
1450
+ "id": "Batch10_20250509_P1003",
1451
+ "rel_path": "Batch10/20250509_P1003",
1452
+ "score": 1.0
1453
+ },
1454
+ {
1455
+ "type": "single",
1456
+ "id": "Batch10_20250512_P1496",
1457
+ "rel_path": "Batch10/20250512_P1496",
1458
+ "score": 1.0
1459
+ },
1460
+ {
1461
+ "type": "single",
1462
+ "id": "Batch10_20250508_P744",
1463
+ "rel_path": "Batch10/20250508_P744",
1464
+ "score": 1.0
1465
+ },
1466
+ {
1467
+ "type": "single",
1468
+ "id": "Batch10_20250512_P1512",
1469
+ "rel_path": "Batch10/20250512_P1512",
1470
+ "score": 1.0
1471
+ },
1472
+ {
1473
+ "type": "single",
1474
+ "id": "Batch10_20250504_P117",
1475
+ "rel_path": "Batch10/20250504_P117",
1476
+ "score": 1.0
1477
+ },
1478
+ {
1479
+ "type": "single",
1480
+ "id": "Batch10_20250505_P228",
1481
+ "rel_path": "Batch10/20250505_P228",
1482
+ "score": 1.0
1483
+ },
1484
+ {
1485
+ "type": "single",
1486
+ "id": "Batch10_20250512_P1509",
1487
+ "rel_path": "Batch10/20250512_P1509",
1488
+ "score": 1.0
1489
+ },
1490
+ {
1491
+ "type": "single",
1492
+ "id": "Batch10_20250512_P1327",
1493
+ "rel_path": "Batch10/20250512_P1327",
1494
+ "score": 1.0
1495
+ },
1496
+ {
1497
+ "type": "single",
1498
+ "id": "Batch10_20250511_P1223",
1499
+ "rel_path": "Batch10/20250511_P1223",
1500
+ "score": 1.0
1501
+ },
1502
+ {
1503
+ "type": "single",
1504
+ "id": "Batch10_20250507_P575",
1505
+ "rel_path": "Batch10/20250507_P575",
1506
+ "score": 1.0
1507
+ },
1508
+ {
1509
+ "type": "single",
1510
+ "id": "Batch10_20250509_P1032",
1511
+ "rel_path": "Batch10/20250509_P1032",
1512
+ "score": 1.0
1513
+ },
1514
+ {
1515
+ "type": "single",
1516
+ "id": "Batch10_20250511_P1217",
1517
+ "rel_path": "Batch10/20250511_P1217",
1518
+ "score": 1.0
1519
+ },
1520
+ {
1521
+ "type": "single",
1522
+ "id": "Batch10_20250506_P406",
1523
+ "rel_path": "Batch10/20250506_P406",
1524
+ "score": 1.0
1525
+ },
1526
+ {
1527
+ "type": "single",
1528
+ "id": "Batch10_20250509_P992",
1529
+ "rel_path": "Batch10/20250509_P992",
1530
+ "score": 1.0
1531
+ },
1532
+ {
1533
+ "type": "single",
1534
+ "id": "Batch10_20250501_P07",
1535
+ "rel_path": "Batch10/20250501_P07",
1536
+ "score": 1.0
1537
+ },
1538
+ {
1539
+ "type": "single",
1540
+ "id": "Batch10_20250509_P888",
1541
+ "rel_path": "Batch10/20250509_P888",
1542
+ "score": 1.0
1543
+ },
1544
+ {
1545
+ "type": "single",
1546
+ "id": "Batch10_20250509_P894",
1547
+ "rel_path": "Batch10/20250509_P894",
1548
+ "score": 1.0
1549
+ },
1550
+ {
1551
+ "type": "single",
1552
+ "id": "Batch10_20250507_P513",
1553
+ "rel_path": "Batch10/20250507_P513",
1554
+ "score": 1.0
1555
+ },
1556
+ {
1557
+ "type": "single",
1558
+ "id": "Batch10_20250508_P811",
1559
+ "rel_path": "Batch10/20250508_P811",
1560
+ "score": 1.0
1561
+ },
1562
+ {
1563
+ "type": "single",
1564
+ "id": "Batch10_20250508_P763",
1565
+ "rel_path": "Batch10/20250508_P763",
1566
+ "score": 1.0
1567
+ },
1568
+ {
1569
+ "type": "single",
1570
+ "id": "Batch10_20250507_P548",
1571
+ "rel_path": "Batch10/20250507_P548",
1572
+ "score": 1.0
1573
+ },
1574
+ {
1575
+ "type": "single",
1576
+ "id": "Batch10_20250507_P652",
1577
+ "rel_path": "Batch10/20250507_P652",
1578
+ "score": 1.0
1579
+ },
1580
+ {
1581
+ "type": "single",
1582
+ "id": "Batch10_20250512_P1403",
1583
+ "rel_path": "Batch10/20250512_P1403",
1584
+ "score": 1.0
1585
+ },
1586
+ {
1587
+ "type": "single",
1588
+ "id": "Batch10_20250506_P461",
1589
+ "rel_path": "Batch10/20250506_P461",
1590
+ "score": 1.0
1591
+ },
1592
+ {
1593
+ "type": "single",
1594
+ "id": "Batch10_20250511_P1224",
1595
+ "rel_path": "Batch10/20250511_P1224",
1596
+ "score": 1.0
1597
+ },
1598
+ {
1599
+ "type": "single",
1600
+ "id": "Batch10_20250512_P1465",
1601
+ "rel_path": "Batch10/20250512_P1465",
1602
+ "score": 1.0
1603
+ },
1604
+ {
1605
+ "type": "single",
1606
+ "id": "Batch10_20250504_P193",
1607
+ "rel_path": "Batch10/20250504_P193",
1608
+ "score": 1.0
1609
+ },
1610
+ {
1611
+ "type": "single",
1612
+ "id": "Batch10_20250507_P626",
1613
+ "rel_path": "Batch10/20250507_P626",
1614
+ "score": 1.0
1615
+ },
1616
+ {
1617
+ "type": "single",
1618
+ "id": "Batch10_20250506_P327",
1619
+ "rel_path": "Batch10/20250506_P327",
1620
+ "score": 1.0
1621
+ },
1622
+ {
1623
+ "type": "single",
1624
+ "id": "Batch10_20250508_P837",
1625
+ "rel_path": "Batch10/20250508_P837",
1626
+ "score": 1.0
1627
+ },
1628
+ {
1629
+ "type": "single",
1630
+ "id": "Batch10_20250507_P563",
1631
+ "rel_path": "Batch10/20250507_P563",
1632
+ "score": 1.0
1633
+ },
1634
+ {
1635
+ "type": "single",
1636
+ "id": "Batch10_20250512_P1344",
1637
+ "rel_path": "Batch10/20250512_P1344",
1638
+ "score": 1.0
1639
+ },
1640
+ {
1641
+ "type": "single",
1642
+ "id": "Batch10_20250501_P41",
1643
+ "rel_path": "Batch10/20250501_P41",
1644
+ "score": 1.0
1645
+ },
1646
+ {
1647
+ "type": "single",
1648
+ "id": "Batch10_20250508_P698",
1649
+ "rel_path": "Batch10/20250508_P698",
1650
+ "score": 1.0
1651
+ },
1652
+ {
1653
+ "type": "single",
1654
+ "id": "Batch10_20250504_P163",
1655
+ "rel_path": "Batch10/20250504_P163",
1656
+ "score": 1.0
1657
+ },
1658
+ {
1659
+ "type": "single",
1660
+ "id": "Batch10_20250506_P484",
1661
+ "rel_path": "Batch10/20250506_P484",
1662
+ "score": 1.0
1663
+ },
1664
+ {
1665
+ "type": "single",
1666
+ "id": "Batch10_20250512_P1396",
1667
+ "rel_path": "Batch10/20250512_P1396",
1668
+ "score": 1.0
1669
+ },
1670
+ {
1671
+ "type": "single",
1672
+ "id": "Batch10_20250510_P1173",
1673
+ "rel_path": "Batch10/20250510_P1173",
1674
+ "score": 1.0
1675
+ },
1676
+ {
1677
+ "type": "single",
1678
+ "id": "Batch10_20250507_P583",
1679
+ "rel_path": "Batch10/20250507_P583",
1680
+ "score": 1.0
1681
+ },
1682
+ {
1683
+ "type": "single",
1684
+ "id": "Batch10_20250503_P108",
1685
+ "rel_path": "Batch10/20250503_P108",
1686
+ "score": 1.0
1687
+ }
1688
+ ]
config.yaml ADDED
@@ -0,0 +1,48 @@
1
+ # Configuration file for the ultrasound-hint multi-label classification model
2
+ # TransMIL + Query2Label Hybrid Model
3
+
4
+ data:
5
+   # [EDIT ME] Data root (the directory containing the Report_XXX folders)
6
+ data_root: "/path/to/your/ReportData_ROI/"
7
+   # [EDIT ME] Path to the multi-label annotation CSV
8
+ annotation_csv: "/path/to/your/ReportData_ROI/thyroid_multilabel_annotations.csv"
9
+   # [EDIT ME] Path to the validation-set JSON
10
+ val_json: "/path/to/your/ReportData_ROI/classification_val_set_single.json"
11
+   # [EDIT ME] Path to the test-set JSON
12
+ test_json: "/path/to/your/ReportData_ROI/classification_test_set_single.json"
13
+ img_size: 224
14
+ max_images_per_case: 20
15
+ num_workers: 8
16
+
17
+ model:
18
+   num_class: 17 # 17 label classes (the "post-resection" class was removed)
19
+ hidden_dim: 512
20
+ nheads: 8
21
+ num_decoder_layers: 2
22
+ pretrained_resnet: True
23
+ use_ppeg: False
24
+
25
+ training:
26
+ batch_size: 4
27
+ epochs: 50
28
+ lr: 0.0001
29
+ weight_decay: 0.0001
30
+ optimizer: "AdamW"
31
+
32
+   # Asymmetric Loss parameters (handle multi-label class imbalance)
33
+ gamma_neg: 4
34
+ gamma_pos: 1
35
+ clip: 0.05
36
+
37
+   # Memory-optimization settings
+   use_amp: true # mixed-precision training
+   gradient_accumulation_steps: 4 # effective batch_size = 4 * 4 = 16
40
+ gradient_checkpointing: true
41
+
42
+   # Learning-rate scheduler
43
+ scheduler: "cosine"
44
+ warmup_epochs: 5
45
+
46
+   # Checkpoint saving
47
+ save_dir: "checkpoints/"
48
+ save_freq: 5
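The memory-optimization block above pairs AMP with gradient accumulation, so the optimizer steps once per 16 samples even though each forward pass sees only 4. A minimal sketch of that arithmetic, with the values copied from this config (plain dict stand-in for the parsed YAML):

```python
# Training section mirroring config.yaml above (values copied from the config)
training = {
    "batch_size": 4,
    "gradient_accumulation_steps": 4,
    "use_amp": True,
}

# Gradient accumulation multiplies the effective batch size:
# the optimizer steps once per (batch_size * accumulation_steps) samples
effective_bs = training["batch_size"] * training["gradient_accumulation_steps"]
print(effective_bs)  # 16
```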
evaluate.py ADDED
@@ -0,0 +1,236 @@
1
+ import os
2
+ import yaml
3
+ import torch
4
+ import numpy as np
5
+ import pandas as pd
6
+ from tqdm import tqdm
7
+ from sklearn.metrics import (
8
+ average_precision_score,
9
+ roc_auc_score,
10
+ f1_score,
11
+ precision_score,
12
+ recall_score,
13
+ accuracy_score
14
+ )
15
+
16
+ # Import your model and data loaders
17
+ from models.transmil_q2l import TransMIL_Query2Label_E2E
18
+ from thyroid_dataset import create_dataloaders, TARGET_CLASSES
19
+ '''
20
+ # 18-class label definition (kept consistent with training)
21
+ TARGET_CLASSES = [
22
+ "TI-RADS 1级", "TI-RADS 2级", "TI-RADS 3级", "TI-RADS 4a级",
23
+ "TI-RADS 4b级", "TI-RADS 4c级", "TI-RADS 5级",
24
+ "钙化", "甲亢", "囊肿", "淋巴结", "胶质潴留", "切除术后",
25
+ "弥漫性病变", "结节性甲状腺肿", "桥本氏甲状腺炎", "反应性", "转移性"
26
+ ]
27
+ '''
28
+
29
+ def get_best_checkpoint_path(save_dir):
30
+     """Automatically locate the best checkpoint."""
31
+ best_path = os.path.join(save_dir, 'checkpoint_best.pth')
32
+ if os.path.exists(best_path):
33
+ return best_path
34
+     # Fall back to the latest checkpoint if best is missing
35
+ latest_path = os.path.join(save_dir, 'checkpoint_latest.pth')
36
+ if os.path.exists(latest_path):
37
+ print(f"Warning: 'checkpoint_best.pth' not found. Using '{latest_path}' instead.")
38
+ return latest_path
39
+ raise FileNotFoundError(f"No checkpoints found in {save_dir}")
40
+
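The best-then-latest fallback above can be exercised without any real training artifacts; a small self-contained sketch using a temporary directory and empty stand-in files (the helper name `pick_checkpoint` is illustrative, not from the repo):

```python
import os
import tempfile

def pick_checkpoint(save_dir):
    # Same fallback order as get_best_checkpoint_path: best first, then latest
    for name in ("checkpoint_best.pth", "checkpoint_latest.pth"):
        path = os.path.join(save_dir, name)
        if os.path.exists(path):
            return path
    raise FileNotFoundError(f"No checkpoints found in {save_dir}")

with tempfile.TemporaryDirectory() as d:
    # Only a "latest" file exists, so the fallback path is taken
    open(os.path.join(d, "checkpoint_latest.pth"), "w").close()
    chosen = os.path.basename(pick_checkpoint(d))
print(chosen)  # checkpoint_latest.pth
```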
41
+ def compute_metrics(y_true, y_pred_probs, threshold=0.5):
42
+ """
43
+     Compute a comprehensive set of multi-label metrics.
44
+ y_true: [N, num_classes] (0 or 1)
45
+ y_pred_probs: [N, num_classes] (0.0 ~ 1.0)
46
+ """
47
+ metrics = {}
48
+
49
+     # 1. Binarize predictions
50
+ y_pred_binary = (y_pred_probs >= threshold).astype(int)
51
+
52
+     # 2. Global metrics
+     # mAP (mean Average Precision) - the primary multi-label metric
54
+ metrics['mAP'] = average_precision_score(y_true, y_pred_probs, average='macro')
55
+ metrics['weighted_mAP'] = average_precision_score(y_true, y_pred_probs, average='weighted')
56
+
57
+ # AUROC (Macro & Micro)
58
+ try:
59
+ metrics['macro_auroc'] = roc_auc_score(y_true, y_pred_probs, average='macro')
60
+ metrics['micro_auroc'] = roc_auc_score(y_true, y_pred_probs, average='micro')
61
+ except ValueError:
62
+ metrics['macro_auroc'] = 0.0
63
+ metrics['micro_auroc'] = 0.0
64
+
65
+ # F1 Score
66
+ metrics['micro_f1'] = f1_score(y_true, y_pred_binary, average='micro')
67
+ metrics['macro_f1'] = f1_score(y_true, y_pred_binary, average='macro')
68
+
69
+     # Exact Match Ratio (subset accuracy) - a sample counts only if every label is correct
70
+ metrics['subset_accuracy'] = accuracy_score(y_true, y_pred_binary)
71
+
72
+     # 3. Per-class metrics
73
+ class_metrics = []
74
+ for i, class_name in enumerate(TARGET_CLASSES):
75
+         # Ground-truth labels and predicted probabilities for this class
76
+ yt = y_true[:, i]
77
+ yp = y_pred_probs[:, i]
78
+ yb = y_pred_binary[:, i]
79
+
80
+         # Positive-sample count (support)
81
+ support = int(yt.sum())
82
+
83
+         # Some metrics are undefined when the class has no positive samples
84
+ if support > 0:
85
+ ap = average_precision_score(yt, yp)
86
+ try:
87
+ auroc = roc_auc_score(yt, yp)
88
+ except ValueError:
89
+                 auroc = 0.5 # AUC is undefined when only one class is present
90
+
91
+ f1 = f1_score(yt, yb)
92
+ rec = recall_score(yt, yb)
93
+ prec = precision_score(yt, yb, zero_division=0)
94
+ else:
95
+ ap, auroc, f1, rec, prec = 0.0, 0.5, 0.0, 0.0, 0.0
96
+
97
+ class_metrics.append({
98
+ "Class": class_name,
99
+ "Support": support,
100
+ "AP": ap,
101
+ "AUROC": auroc,
102
+ "F1": f1,
103
+ "Precision": prec,
104
+ "Recall": rec
105
+ })
106
+
107
+ return metrics, pd.DataFrame(class_metrics)
108
+
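Subset accuracy is the strictest of the metrics above: a case only counts as correct when every one of its 17 labels matches after thresholding. A tiny NumPy-only illustration of that thresholding and exact-match step (toy arrays, not real model output):

```python
import numpy as np

y_true = np.array([[1, 0, 1],
                   [0, 1, 0]])
probs = np.array([[0.9, 0.2, 0.7],
                  [0.1, 0.8, 0.4]])

# Binarize at the 0.5 threshold, as compute_metrics does
y_pred = (probs >= 0.5).astype(int)

# Subset accuracy: a sample only counts if every label matches
subset_acc = (y_pred == y_true).all(axis=1).mean()
print(subset_acc)  # both toy samples match exactly -> 1.0
```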
109
+ def main():
110
+     # 1. Load the config
111
+     config_path = 'config.yaml' # make sure this path is correct
112
+ with open(config_path, 'r') as f:
113
+ config = yaml.safe_load(f)
114
+
115
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
116
+ print(f"Evaluating on {device}")
117
+
118
+     # 2. Build the data loaders
119
+ print("Loading Test Data...")
120
+ _, _, test_loader = create_dataloaders(config)
121
+
122
+     # 3. Initialize the model
123
+ print("Initializing Model...")
124
+ model = TransMIL_Query2Label_E2E(
125
+ num_class=config['model']['num_class'],
126
+ hidden_dim=config['model']['hidden_dim'],
127
+ nheads=config['model']['nheads'],
128
+ num_decoder_layers=config['model']['num_decoder_layers'],
129
+         pretrained_resnet=False, # no need to download ImageNet weights at inference; we load our own checkpoint
+         use_checkpointing=False, # gradient checkpointing is unnecessary at inference
131
+ use_ppeg=config['model'].get('use_ppeg', False)
132
+ )
133
+
134
+     # 4. Load the checkpoint
135
+ ckpt_path = get_best_checkpoint_path(config['training']['save_dir'])
136
+ print(f"Loading checkpoint from: {ckpt_path}")
137
+ checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
138
+
139
+     # Handle possible state_dict key mismatches (e.g. a "module." prefix from DataParallel)
140
+ state_dict = checkpoint['model_state_dict']
141
+ new_state_dict = {}
142
+ for k, v in state_dict.items():
143
+ name = k.replace("module.", "")
144
+ new_state_dict[name] = v
145
+ model.load_state_dict(new_state_dict)
146
+
147
+ model.to(device)
148
+ model.eval()
149
+
150
+     # 5. Inference loop
151
+ print("Running Inference...")
152
+ all_preds = []
153
+ all_targets = []
154
+
155
+ with torch.no_grad():
156
+ for batch in tqdm(test_loader):
157
+ images = batch['images'].to(device)
158
+ num_instances = batch['num_instances_per_case']
159
+ labels = batch['labels'].numpy() # CPU numpy
160
+
161
+ # Forward
162
+ logits = model(images, num_instances)
163
+ probs = torch.sigmoid(logits).cpu().numpy()
164
+
165
+ all_preds.append(probs)
166
+ all_targets.append(labels)
167
+
168
+     # Concatenate all batches
169
+ y_pred_probs = np.concatenate(all_preds, axis=0)
170
+ y_true = np.concatenate(all_targets, axis=0)
171
+
172
+     # 6. Compute metrics
173
+ print("\nComputing Metrics...")
174
+ global_metrics, class_df = compute_metrics(y_true, y_pred_probs)
175
+
176
+     # 7. Print results
177
+ print("\n" + "="*60)
178
+ print(" GLOBAL PERFORMANCE SUMMARY ")
179
+ print("="*60)
180
+ print(f" mAP (Macro) : {global_metrics['mAP']:.4f}")
181
+ print(f" mAP (Weighted): {global_metrics['weighted_mAP']:.4f}")
182
+ print(f" AUROC (Macro) : {global_metrics['macro_auroc']:.4f}")
183
+ print(f" AUROC (Micro) : {global_metrics['micro_auroc']:.4f}")
184
+ print(f" F1 (Micro) : {global_metrics['micro_f1']:.4f}")
185
+ print(f" F1 (Macro) : {global_metrics['macro_f1']:.4f}")
186
+ print(f" Subset Acc : {global_metrics['subset_accuracy']:.4f}")
187
+ print("-" * 60)
188
+
189
+ print("\n" + "="*100)
190
+ print(" PER-CLASS PERFORMANCE DETAILS (Sorted by Support) ")
191
+ print("="*100)
192
+
193
+     # Sort classes by support (positive-sample count)
194
+ class_df = class_df.sort_values(by='Support', ascending=False)
195
+
196
+     # Manually formatted table printing.
+     # CJK characters occupy two display columns, so the Class column is given
+     # a generous fixed width; {:<N} left-aligns, {:>N} right-aligns.
200
+
201
+ headers = ["Class", "Support", "AP", "AUROC", "F1", "Precision", "Recall"]
202
+
203
+     # Print the header row
+     # {:<24} left-aligns the first column in a 24-column field
205
+ head_fmt = "{:<24} {:>8} {:>10} {:>10} {:>10} {:>12} {:>10}"
206
+ print(head_fmt.format(*headers))
207
+ print("-" * 100)
208
+
209
+     # Print one row per class
210
+ row_fmt = "{:<24} {:>8d} {:>10.4f} {:>10.4f} {:>10.4f} {:>12.4f} {:>10.4f}"
211
+
212
+ for _, row in class_df.iterrows():
213
+         cls_name = row['Class']
+
+         # GBK encodes CJK characters as 2 bytes, matching their 2-column display width
+         try:
+             display_width = len(cls_name.encode('gbk'))
+         except UnicodeEncodeError:
+             display_width = len(cls_name) * 2
+
+         # Pad the name out to the 24-column target width
+         target_width = 24
+         padding = max(0, target_width - display_width)
+         aligned_name = cls_name + " " * padding
225
+
226
+ print(f"{aligned_name} {int(row['Support']):>8d} {row['AP']:>10.4f} {row['AUROC']:>10.4f} {row['F1']:>10.4f} {row['Precision']:>12.4f} {row['Recall']:>10.4f}")
227
+
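The GBK-length trick used for alignment works because GBK stores CJK characters in 2 bytes, matching their 2-column terminal width. An alternative sketch using only the standard library's `unicodedata` (the helper name `display_width` is hypothetical, not from the repo):

```python
import unicodedata

def display_width(s: str) -> int:
    # Fullwidth ('F') and Wide ('W') characters occupy two terminal columns;
    # everything else is counted as one column here
    return sum(2 if unicodedata.east_asian_width(ch) in ('F', 'W') else 1
               for ch in s)

print(display_width("钙化"))     # two CJK characters -> 4 columns
print(display_width("TI-RADS"))  # seven ASCII characters -> 7 columns
```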
228
+ print("="*100)
229
+
230
+     # Save the detailed report to CSV
231
+ result_csv = os.path.join(config['training']['save_dir'], 'evaluation_report.csv')
232
+ class_df.to_csv(result_csv, index=False, encoding='utf-8-sig')
233
+ print(f"\nDetailed report saved to: {result_csv}")
234
+
235
+ if __name__ == "__main__":
236
+ main()
infer_single_case.py ADDED
@@ -0,0 +1,244 @@
1
+ """
2
+ Single-case inference script - ultrasound-hint multi-label classification model
3
+ Single Case Inference for TransMIL + Query2Label Hybrid Model
4
+
5
+ Usage:
6
+     # Pass one or more image files
7
+ python infer_single_case.py --images /path/to/img1.png /path/to/img2.png --threshold 0.5
8
+
9
+     # Pass a folder of images
10
+ python infer_single_case.py --image_dir /path/to/case_folder/ --threshold 0.5
11
+ """
12
+
13
+ import os
14
+ import sys
15
+ import argparse
16
+ import torch
17
+ import numpy as np
18
+ from PIL import Image
19
+ from torchvision import transforms
20
+
21
+ # Add the script directory to the import path
22
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
23
+
24
+ from models.transmil_q2l import TransMIL_Query2Label_E2E
25
+
26
+ # The 17 label classes
27
+ TARGET_CLASSES = [
28
+ "TI-RADS 1级", "TI-RADS 2级", "TI-RADS 3级", "TI-RADS 4a级",
29
+ "TI-RADS 4b级", "TI-RADS 4c级", "TI-RADS 5级",
30
+ "钙化", "甲亢", "囊肿", "淋巴结", "胶质潴留",
31
+ "弥漫性病变", "结节性甲状腺肿", "桥本氏甲状腺炎", "反应性", "转移性"
32
+ ]
33
+
34
+
35
+ def load_model(checkpoint_path: str, device: torch.device):
36
+     """Load the trained model from a checkpoint."""
37
+ print(f"Loading model from: {checkpoint_path}")
38
+
39
+     # Initialize the model
40
+ model = TransMIL_Query2Label_E2E(
41
+ num_class=17,
42
+ hidden_dim=512,
43
+ nheads=8,
44
+ num_decoder_layers=2,
45
+         pretrained_resnet=False, # no need to download ImageNet weights at inference
+         use_checkpointing=False, # gradient checkpointing is unnecessary at inference
47
+ use_ppeg=False
48
+ )
49
+
50
+     # Load the checkpoint weights
51
+ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
52
+ state_dict = checkpoint['model_state_dict']
53
+
54
+     # Handle possible state_dict key mismatches (e.g. a "module." prefix from DataParallel)
55
+ new_state_dict = {}
56
+ for k, v in state_dict.items():
57
+ name = k.replace("module.", "")
58
+ new_state_dict[name] = v
59
+ model.load_state_dict(new_state_dict)
60
+
61
+ model.to(device)
62
+ model.eval()
63
+ print("Model loaded successfully!")
64
+ return model
65
+
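The key-renaming loop in load_model exists because a model saved while wrapped in DataParallel prefixes every state_dict key with "module.". A plain-dict illustration of the same stripping (toy keys, not the real state_dict; `replace` is given a count of 1 so only a leading prefix is touched):

```python
# Toy stand-in for a DataParallel-saved state_dict
state_dict = {
    "module.backbone.conv1.weight": 0,
    "module.classifier.bias": 1,
}

# Strip only a leading "module." prefix from each key; values are kept as-is
new_state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
print(sorted(new_state_dict))  # ['backbone.conv1.weight', 'classifier.bias']
```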
66
+
67
+ def preprocess_images(image_paths: list, img_size: int = 224):
68
+     """Preprocess a list of images into a normalized tensor stack."""
69
+ transform = transforms.Compose([
70
+ transforms.Resize((img_size, img_size)),
71
+ transforms.ToTensor(),
72
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
73
+ std=[0.229, 0.224, 0.225])
74
+ ])
75
+
76
+ images = []
77
+ valid_paths = []
78
+
79
+ for path in image_paths:
80
+ if not os.path.exists(path):
81
+ print(f"Warning: Image not found: {path}")
82
+ continue
83
+ try:
84
+ img = Image.open(path).convert('RGB')
85
+ img_tensor = transform(img)
86
+ images.append(img_tensor)
87
+ valid_paths.append(path)
88
+ except Exception as e:
89
+ print(f"Warning: Failed to load image {path}: {e}")
90
+ continue
91
+
92
+ if len(images) == 0:
93
+ raise ValueError("No valid images found!")
94
+
95
+     # Stack to [N, C, H, W]: the model expects the raw image stack, with no extra batch dimension
96
+ images_batch = torch.stack(images, dim=0)
97
+
98
+ return images_batch, valid_paths
99
+
100
+
101
+ def predict(model, images_batch: torch.Tensor, num_images: int,
102
+ device: torch.device, threshold: float = 0.5):
103
+     """Run inference on a single case."""
104
+ images_batch = images_batch.to(device)
105
+
106
+ with torch.no_grad():
107
+ # Forward pass
108
+ logits = model(images_batch, [num_images])
109
+ probs = torch.sigmoid(logits).cpu().numpy()[0] # [num_class]
110
+
111
+     # Threshold the probabilities to get binary predictions
112
+ predictions = (probs >= threshold).astype(int)
113
+
114
+ return probs, predictions
115
+
116
+
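predict applies a sigmoid per class and thresholds each class independently; this per-class decision, rather than a softmax over classes, is what makes the task multi-label. A NumPy-only sketch of that step (toy logits, not real model output):

```python
import numpy as np

logits = np.array([2.0, -1.0, 0.1])      # toy per-class logits
probs = 1.0 / (1.0 + np.exp(-logits))    # element-wise sigmoid
preds = (probs >= 0.5).astype(int)       # independent per-class threshold
print(preds.tolist())  # [1, 0, 1]
```

Note that, unlike softmax, any number of classes (including zero) can end up positive for one case.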
117
+ def format_results(probs: np.ndarray, predictions: np.ndarray, threshold: float):
+     """Pretty-print the prediction results."""
+     print("\n" + "=" * 60)
+     print(" Ultrasound-Hint Multi-Label Classification Results")
+     print("=" * 60)
+     print(f" Threshold: {threshold}")
+     print("-" * 60)
+
+     # Sort classes by predicted probability, highest first
+     sorted_indices = np.argsort(probs)[::-1]
+
+     print(f"\n{'Class':<20} {'Prob':>10} {'Pred':>8}")
+     print("-" * 40)
+
+     predicted_labels = []
+     for idx in sorted_indices:
+         class_name = TARGET_CLASSES[idx]
+         prob = probs[idx]
+         pred = "✓" if predictions[idx] == 1 else ""
+
+         # GBK-encoded length as display width (CJK chars are 2 bytes / 2 columns)
+         try:
+             display_width = len(class_name.encode('gbk'))
+         except UnicodeEncodeError:
+             display_width = len(class_name) * 2
+
+         padding = 20 - display_width
+         aligned_name = class_name + " " * max(0, padding)
+
+         print(f"{aligned_name} {prob:>10.4f} {pred:>8}")
+
+         if predictions[idx] == 1:
+             predicted_labels.append(class_name)
+
+     print("\n" + "=" * 60)
+     print(" Predicted Label Summary")
+     print("=" * 60)
+
+     if predicted_labels:
+         for label in predicted_labels:
+             print(f"  • {label}")
+     else:
+         print(" No labels predicted (all class probabilities below the threshold)")
+
+     print("=" * 60 + "\n")
+
+     return predicted_labels
164
+
165
+
166
+ def main():
167
+     parser = argparse.ArgumentParser(description='Ultrasound-hint multi-label classification - single-case inference')
+     parser.add_argument('--images', nargs='*', default=None,
+                         help='List of image paths (multiple images supported)')
+     parser.add_argument('--image_dir', type=str, default=None,
+                         help='Folder of images (all images inside are loaded automatically)')
+     parser.add_argument('--checkpoint', type=str,
+                         default='checkpoints/checkpoint_best.pth',
+                         help='Path to the model checkpoint')
+     parser.add_argument('--threshold', type=float, default=0.5,
+                         help='Classification threshold (default: 0.5)')
+     parser.add_argument('--device', type=str, default='auto',
+                         help='Device: auto, cuda, cpu')
179
+
180
+ args = parser.parse_args()
181
+
182
+     # Collect image paths
183
+ image_paths = []
184
+
185
+     # Collect from --images
186
+ if args.images:
187
+ image_paths.extend(args.images)
188
+
189
+     # Collect from --image_dir
190
+ if args.image_dir:
191
+ if not os.path.isdir(args.image_dir):
192
+ print(f"Error: Image directory not found: {args.image_dir}")
193
+ sys.exit(1)
194
+
195
+         # Supported image extensions
196
+ image_extensions = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif'}
197
+
198
+ for filename in sorted(os.listdir(args.image_dir)):
199
+ ext = os.path.splitext(filename)[1].lower()
200
+ if ext in image_extensions:
201
+ image_paths.append(os.path.join(args.image_dir, filename))
202
+
203
+ print(f"Found {len(image_paths)} images in {args.image_dir}")
204
+
205
+     # Ensure at least one image was provided
206
+ if not image_paths:
207
+ print("Error: No images specified. Use --images or --image_dir")
208
+ parser.print_help()
209
+ sys.exit(1)
210
+
211
+     # Select the device
212
+ if args.device == 'auto':
213
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
214
+ else:
215
+ device = torch.device(args.device)
216
+ print(f"Using device: {device}")
217
+
218
+     # Resolve a relative checkpoint path against the script directory
219
+ script_dir = os.path.dirname(os.path.abspath(__file__))
220
+ checkpoint_path = args.checkpoint
221
+ if not os.path.isabs(checkpoint_path):
222
+ checkpoint_path = os.path.join(script_dir, checkpoint_path)
223
+
224
+     # Load the model
225
+ model = load_model(checkpoint_path, device)
226
+
227
+     # Preprocess the images
228
+ print(f"\nProcessing {len(image_paths)} image(s)...")
229
+ images_batch, valid_paths = preprocess_images(image_paths)
230
+ print(f"Successfully loaded {len(valid_paths)} image(s)")
231
+
232
+     # Run inference
233
+ probs, predictions = predict(model, images_batch, len(valid_paths),
234
+ device, args.threshold)
235
+
236
+     # Print the results
237
+ predicted_labels = format_results(probs, predictions, args.threshold)
238
+
239
+     # Return the predicted label list (for programmatic callers)
240
+ return predicted_labels
241
+
242
+
243
+ if __name__ == "__main__":
244
+ main()
models/__init__.py ADDED
@@ -0,0 +1,13 @@
# Thyroid Ultrasound Hint Multi-Label Classification

from models.transmil_q2l import TransMIL_Query2Label_E2E
from models.aslloss import AsymmetricLoss, AsymmetricLossOptimized
from models.transformer import TransformerDecoder, TransformerDecoderLayer

__all__ = [
    'TransMIL_Query2Label_E2E',
    'AsymmetricLoss',
    'AsymmetricLossOptimized',
    'TransformerDecoder',
    'TransformerDecoderLayer',
]
models/aslloss.py ADDED
@@ -0,0 +1,115 @@
"""
Mostly borrowed from: https://github.com/Alibaba-MIIL/ASL
"""
import torch
import torch.nn as nn


class AsymmetricLoss(nn.Module):
    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
        super(AsymmetricLoss, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        # Calculating probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric clipping (probability shifting for negatives)
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps, max=1 - self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps, max=1 - self.eps))
        loss = los_pos + los_neg

        # Asymmetric focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                torch._C.set_grad_enabled(False)
            pt0 = xs_pos * y
            pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
            pt = pt0 + pt1
            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            if self.disable_torch_grad_focal_loss:
                torch._C.set_grad_enabled(True)
            loss *= one_sided_w

        return -loss.sum()

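The per-label behavior of `AsymmetricLoss` can be sketched with plain scalars (a simplified illustration of the same formula, not the repo's implementation): positives get a mild focal down-weight of (1-p)^gamma_pos, while negatives are first probability-shifted by `clip` and then sharply down-weighted by gamma_neg.

```python
import math

def asl_term(p, y, gamma_pos=1, gamma_neg=4, clip=0.05, eps=1e-8):
    """Asymmetric loss for one label: p is the sigmoid probability, y is 0/1."""
    if y == 1:
        # Positive branch: mild focal down-weighting
        return -((1 - p) ** gamma_pos) * math.log(max(p, eps))
    # Negative branch: shift the probability by `clip`, then focus hard
    p_neg = min(1 - p + clip, 1.0)
    return -((1 - p_neg) ** gamma_neg) * math.log(max(p_neg, eps))
```

Easy negatives whose probability falls below the clip margin contribute exactly zero loss, which is what lets ASL tolerate the heavy negative imbalance typical of multi-label data.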

class AsymmetricLossOptimized(nn.Module):
    '''Notice - optimized version, minimizes memory allocation and GPU uploading,
    favors inplace operations'''

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-5, disable_torch_grad_focal_loss=False):
        super(AsymmetricLossOptimized, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

        # Buffers reused across forward calls to avoid reallocation
        self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        self.targets = y
        self.anti_targets = 1 - y

        # Calculating probabilities
        self.xs_pos = torch.sigmoid(x)
        self.xs_neg = 1.0 - self.xs_pos

        # Asymmetric clipping
        if self.clip is not None and self.clip > 0:
            self.xs_neg.add_(self.clip).clamp_(max=1)

        # Basic CE calculation
        self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
        self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

        # Asymmetric focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                with torch.no_grad():
                    self.xs_pos = self.xs_pos * self.targets
                    self.xs_neg = self.xs_neg * self.anti_targets
                    self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
                                                  self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
                self.loss *= self.asymmetric_w
            else:
                self.xs_pos = self.xs_pos * self.targets
                self.xs_neg = self.xs_neg * self.anti_targets
                self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
                                              self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
                self.loss *= self.asymmetric_w

        # Mean over the batch, normalized by class count, with a constant rescale
        _loss = -self.loss.sum() / x.size(0)
        _loss = _loss / y.size(1) * 1000

        return _loss
models/transformer.py ADDED
@@ -0,0 +1,374 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Q2L Transformer class.

Mostly borrowed from DETR except:
    * self-attention in the decoder is removed by default.

Copy-paste from torch.nn.Transformer with modifications:
    * positional encodings are passed into MHattention
    * extra LN at the end of the encoder is removed
    * decoder can return a stack of activations from all decoding layers
"""
import copy
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.nn import MultiheadAttention


class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False,
                 rm_self_attn_dec=True, rm_first_self_attn=True,
                 ):
        super().__init__()

        self.num_encoder_layers = num_encoder_layers
        if num_decoder_layers > 0:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                    dropout, activation, normalize_before)
            encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead
        self.rm_self_attn_dec = rm_self_attn_dec
        self.rm_first_self_attn = rm_first_self_attn

        if self.rm_self_attn_dec or self.rm_first_self_attn:
            self.rm_self_attn_dec_func()

    def rm_self_attn_dec_func(self):
        """Strip self-attention sublayers from decoder layers according to the flags."""
        total_modified_layer_num = 0
        rm_list = []
        for idx, layer in enumerate(self.decoder.layers):
            if idx == 0 and not self.rm_first_self_attn:
                continue
            if idx != 0 and not self.rm_self_attn_dec:
                continue

            layer.omit_selfattn = True
            del layer.self_attn
            del layer.dropout1
            del layer.norm1

            total_modified_layer_num += 1
            rm_list.append(idx)

    def set_debug_mode(self, status):
        print("set debug mode to {}!!!".format(status))
        self.debug_mode = status
        if hasattr(self, 'encoder'):
            for idx, layer in enumerate(self.encoder.layers):
                layer.debug_mode = status
                layer.debug_name = str(idx)
        if hasattr(self, 'decoder'):
            for idx, layer in enumerate(self.decoder.layers):
                layer.debug_mode = status
                layer.debug_name = str(idx)

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, query_embed, pos_embed, mask=None):
        # Flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        if mask is not None:
            mask = mask.flatten(1)

        if self.num_encoder_layers > 0:
            memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        else:
            memory = src

        tgt = torch.zeros_like(query_embed)
        hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                          pos=pos_embed, query_pos=query_embed)

        return hs.transpose(1, 2), memory[:h * w].permute(1, 2, 0).view(bs, c, h, w)

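The shape gymnastics at the top of `forward` follow the DETR convention: a spatial feature map becomes a sequence of length H*W with the batch dimension second. A tiny sketch of the resulting shapes (illustration only, not repo code):

```python
def flatten_spatial(shape):
    """(N, C, H, W) feature map -> (H*W, N, C) sequence shape (DETR convention)."""
    n, c, h, w = shape
    return (h * w, n, c)
```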
class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src

        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)

class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self.debug_mode = False
        self.debug_name = None

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        src2, corr = self.self_attn(q, k, value=src, attn_mask=src_mask,
                                    key_padding_mask=src_key_padding_mask)

        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]

        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self.debug_mode = False
        self.debug_name = None
        self.omit_selfattn = False

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)

        if not self.omit_selfattn:
            tgt2, sim_mat_1 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                                             key_padding_mask=tgt_key_padding_mask)

            tgt = tgt + self.dropout1(tgt2)
            tgt = self.norm1(tgt)

        tgt2, sim_mat_2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                              key=self.with_pos_embed(memory, pos),
                                              value=memory, attn_mask=memory_mask,
                                              key_padding_mask=memory_key_padding_mask)

        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]

        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]

        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def build_transformer(args):
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=False,
        rm_self_attn_dec=not args.keep_other_self_attn_dec,
        rm_first_self_attn=not args.keep_first_self_attn_dec,
    )


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
models/transmil_q2l.py ADDED
@@ -0,0 +1,589 @@
"""
Hybrid TransMIL + Query2Label Architecture

Combines:
- TransMIL's instance-level feature aggregation (with Nystrom attention)
- Query2Label's learnable label queries with cross-attention decoder
- End-to-end training with ResNet-50 backbone

Key innovation: extract sequence features from TransMIL BEFORE CLS aggregation,
allowing Q2L label queries to cross-attend across all ultrasound images per case.
"""

import os
import sys

# Add the models directory to the path for local imports
_models_dir = os.path.dirname(os.path.abspath(__file__))
if _models_dir not in sys.path:
    sys.path.insert(0, _models_dir)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import numpy as np
from torch.utils.checkpoint import checkpoint_sequential

# TransMIL components (from the nystrom-attention package)
from nystrom_attention import NystromAttention

# Q2L Transformer components from the local transformer.py
try:
    from models.transformer import TransformerDecoder, TransformerDecoderLayer
except ImportError:
    try:
        from transformer import TransformerDecoder, TransformerDecoderLayer
    except ImportError:
        print("Warning: Could not import Q2L Transformer components.")

# ============================================================================
# TransMIL Components (Modified)
# ============================================================================

class TransLayer(nn.Module):
    """Transformer layer with Nystrom attention (from TransMIL)"""

    def __init__(self, norm_layer=nn.LayerNorm, dim=512):
        super().__init__()
        self.norm = norm_layer(dim)
        self.attn = NystromAttention(
            dim=dim,
            dim_head=dim // 8,
            heads=8,
            num_landmarks=dim // 2,
            pinv_iterations=6,
            residual=True,
            dropout=0.1
        )

    def forward(self, x):
        # Pre-norm residual attention block
        x = x + self.attn(self.norm(x))
        return x

+
65
+ class TransMILFeatureExtractor(nn.Module):
66
+ """
67
+ Modified TransMIL that outputs sequence features instead of aggregated CLS token.
68
+
69
+ Based on TransMIL.py but extracts features BEFORE CLS aggregation (line 83 output).
70
+ Uses learned 1D position encoding instead of PPEG for simplicity.
71
+
72
+ Args:
73
+ input_dim: Dimension of input features (2048 for ResNet-50)
74
+ hidden_dim: Dimension of hidden features (512 default)
75
+ use_ppeg: Whether to use PPEG (2D positional encoding) or learned 1D encoding
76
+ max_seq_len: Maximum sequence length for position encoding
77
+ """
78
+
79
+ def __init__(self, input_dim=2048, hidden_dim=512, use_ppeg=False, max_seq_len=1024):
80
+ super().__init__()
81
+
82
+ # Feature projection (TransMIL line 50)
83
+ self.fc1 = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
84
+
85
+ # Learnable CLS token (TransMIL line 51)
86
+ self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
87
+
88
+ # Transformer layers (TransMIL lines 53-54)
89
+ self.layer1 = TransLayer(dim=hidden_dim)
90
+ self.layer2 = TransLayer(dim=hidden_dim)
91
+
92
+ # LayerNorm (TransMIL line 55)
93
+ self.norm = nn.LayerNorm(hidden_dim)
94
+
95
+ # Position encoding
96
+ self.use_ppeg = use_ppeg
97
+ if not use_ppeg:
98
+ # Learned 1D position encoding (simpler than PPEG)
99
+ self.pos_embedding = nn.Parameter(torch.randn(1, max_seq_len, hidden_dim))
100
+ else:
101
+ # PPEG: Position-aware Patch Embedding Generator (requires 2D reshaping)
102
+ self.pos_layer = PPEG(dim=hidden_dim)
103
+
104
+ def forward(self, features, mask=None):
105
+ """
106
+ Args:
107
+ features: [B, N, input_dim] - Instance features (e.g., from ResNet-50)
108
+ mask: [B, N] - Padding mask (True = valid instance, False = padded)
109
+
110
+ Returns:
111
+ seq_features: [B, 1+N, hidden_dim] - Sequence features (CLS + instances)
112
+ attn_mask: [B, 1+N] - Attention mask for decoder
113
+ """
114
+ B, N, _ = features.shape
115
+
116
+ # Project features (TransMIL line 63)
117
+ h = self.fc1(features) # [B, N, hidden_dim]
118
+
119
+ # Handle PPEG padding if needed
120
+ if self.use_ppeg:
121
+ # Pad to nearest square for PPEG (TransMIL lines 65-69)
122
+ H = h.shape[1]
123
+ _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
124
+ add_length = _H * _W - H
125
+ if add_length > 0:
126
+ h = torch.cat([h, h[:, :add_length, :]], dim=1) # [B, N_padded, hidden_dim]
127
+
128
+ # Update mask
129
+ if mask is not None:
130
+ pad_mask = torch.zeros(B, add_length, dtype=torch.bool, device=mask.device)
131
+ mask = torch.cat([mask, pad_mask], dim=1)
132
+
133
+ # Add CLS token (TransMIL lines 72-74)
134
+ cls_tokens = self.cls_token.expand(B, -1, -1)
135
+ h = torch.cat([cls_tokens, h], dim=1) # [B, 1+N, hidden_dim]
136
+
137
+ # Update mask to include CLS (always valid)
138
+ if mask is not None:
139
+ cls_mask = torch.ones(B, 1, dtype=torch.bool, device=mask.device)
140
+ attn_mask = torch.cat([cls_mask, mask], dim=1) # [B, 1+N]
141
+ else:
142
+ attn_mask = torch.ones(B, h.shape[1], dtype=torch.bool, device=h.device)
143
+
144
+ # TransLayer 1 (TransMIL line 77)
145
+ h = self.layer1(h) # [B, 1+N, hidden_dim]
146
+
147
+ # Position encoding
148
+ if self.use_ppeg:
149
+ # PPEG (TransMIL line 80)
150
+ h = self.pos_layer(h, _H, _W)
151
+ else:
152
+ # Learned 1D position encoding
153
+ seq_len = h.shape[1]
154
+ h = h + self.pos_embedding[:, :seq_len, :]
155
+
156
+ # TransLayer 2 (TransMIL line 83)
157
+ h = self.layer2(h) # [B, 1+N, hidden_dim]
158
+
159
+ # LayerNorm (TransMIL line 86, but keep full sequence)
160
+ h = self.norm(h) # [B, 1+N, hidden_dim]
161
+
162
+ # CRITICAL: Return full sequence, not just CLS token
163
+ return h, attn_mask
164
+
165
+
166
+ class PPEG(nn.Module):
167
+ """
168
+ Position-aware Patch Embedding Generator (from TransMIL)
169
+ Uses 2D depthwise convolutions to inject spatial positional information.
170
+ """
171
+
172
+ def __init__(self, dim=512):
173
+ super().__init__()
174
+ self.proj = nn.Conv2d(dim, dim, 7, 1, 7 // 2, groups=dim)
175
+ self.proj1 = nn.Conv2d(dim, dim, 5, 1, 5 // 2, groups=dim)
176
+ self.proj2 = nn.Conv2d(dim, dim, 3, 1, 3 // 2, groups=dim)
177
+
178
+ def forward(self, x, H, W):
179
+ """
180
+ Args:
181
+ x: [B, 1+N, C] - Token sequence (CLS + instances)
182
+ H, W: Grid dimensions (H * W >= N)
183
+ """
184
+ B, _, C = x.shape
185
+
186
+ # Separate CLS token and feature tokens
187
+ cls_token, feat_token = x[:, 0], x[:, 1:]
188
+
189
+ # Reshape to 2D grid
190
+ cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
191
+
192
+ # Apply 2D convolutions
193
+ x = self.proj(cnn_feat) + cnn_feat + self.proj1(cnn_feat) + self.proj2(cnn_feat)
194
+
195
+ # Flatten back to sequence
196
+ x = x.flatten(2).transpose(1, 2)
197
+
198
+ # Concatenate CLS token back
199
+ x = torch.cat((cls_token.unsqueeze(1), x), dim=1)
200
+
201
+ return x
202
+
203
+
204
+ # ============================================================================
205
+ # Query2Label Components (Adapted for Sequences)
206
+ # ============================================================================
207
+
208
+ class GroupWiseLinear(nn.Module):
209
+ """
210
+ Group-wise linear layer for per-class classification (from Q2L).
211
+ Applies a separate linear transformation for each class.
212
+ """
213
+
214
+ def __init__(self, num_class, hidden_dim, bias=True):
215
+ super().__init__()
216
+ self.num_class = num_class
217
+ self.hidden_dim = hidden_dim
218
+ self.bias = bias
219
+
220
+ self.W = nn.Parameter(torch.Tensor(1, num_class, hidden_dim))
221
+ if bias:
222
+ self.b = nn.Parameter(torch.Tensor(1, num_class))
223
+ self.reset_parameters()
224
+
225
+ def reset_parameters(self):
226
+ import math
227
+ stdv = 1. / math.sqrt(self.W.size(2))
228
+ for i in range(self.num_class):
229
+ self.W[0][i].data.uniform_(-stdv, stdv)
230
+ if self.bias:
231
+ for i in range(self.num_class):
232
+ self.b[0][i].data.uniform_(-stdv, stdv)
233
+
234
+ def forward(self, x):
235
+ """
236
+ Args:
237
+ x: [B, num_class, hidden_dim]
238
+ Returns:
239
+ logits: [B, num_class]
240
+ """
241
+ # Element-wise multiplication and sum over hidden_dim
242
+ x = (self.W * x).sum(-1) # [B, num_class]
243
+ if self.bias:
244
+ x = x + self.b
245
+ return x
246
+
247
+
248
+ class HybridQuery2Label(nn.Module):
249
+ """
250
+ Query2Label decoder adapted for sequence inputs (not spatial features).
251
+
252
+ Uses learnable label queries to cross-attend to instance sequence from TransMIL.
253
+ Based on query2label.py but modified to accept [B, 1+N, hidden_dim] sequences
254
+ instead of [B, C, H, W] spatial features.
255
+
256
+ Args:
257
+ num_class: Number of label classes
258
+ hidden_dim: Dimension of features (512)
259
+ nheads: Number of attention heads
260
+ num_decoder_layers: Number of transformer decoder layers
261
+ dim_feedforward: Dimension of feedforward network
262
+ dropout: Dropout rate
263
+ """
264
+
265
+ def __init__(
266
+ self,
267
+ num_class,
268
+ hidden_dim=512,
269
+ nheads=8,
270
+ num_decoder_layers=2,
271
+ dim_feedforward=2048,
272
+ dropout=0.1,
273
+ normalize_before=False
274
+ ):
275
+ super().__init__()
276
+ self.num_class = num_class
277
+ self.hidden_dim = hidden_dim
278
+
279
+ # Label query embeddings (Q2L line 68)
280
+ self.query_embed = nn.Embedding(num_class, hidden_dim)
281
+
282
+ # Transformer decoder (Q2L uses transformer.py)
283
+ decoder_layer = TransformerDecoderLayer(
284
+ d_model=hidden_dim,
285
+ nhead=nheads,
286
+ dim_feedforward=dim_feedforward,
287
+ dropout=dropout,
288
+ normalize_before=normalize_before
289
+ )
290
+ decoder_norm = nn.LayerNorm(hidden_dim)
291
+ self.decoder = TransformerDecoder(
292
+ decoder_layer,
293
+ num_decoder_layers,
294
+ decoder_norm,
295
+ return_intermediate=False
296
+ )
297
+
298
+ # Group-wise linear classifier (Q2L line 69)
299
+ self.fc = GroupWiseLinear(num_class, hidden_dim, bias=True)
300
+
301
+ def forward(self, sequence_features, memory_key_padding_mask=None):
302
+ """
303
+ Args:
304
+ sequence_features: [B, 1+N, hidden_dim] - Sequence from TransMIL
305
+ memory_key_padding_mask: [B, 1+N] - Padding mask (True = ignore, False = valid)
306
+ NOTE: PyTorch convention is inverted!
307
+
308
+ Returns:
309
+ logits: [B, num_class] - Multi-label classification logits
310
+ """
311
+ B = sequence_features.shape[0]
312
+
313
+ # Transpose for decoder: expects [seq_len, B, hidden_dim]
314
+ memory = sequence_features.permute(1, 0, 2) # [1+N, B, hidden_dim]
315
+
316
+ # Label queries (Q2L line 77)
317
+ query_embed = self.query_embed.weight # [num_class, hidden_dim]
318
+ query_embed = query_embed.unsqueeze(1).repeat(1, B, 1) # [num_class, B, hidden_dim]
319
+
320
+ # Initialize target (zero tensor)
321
+ tgt = torch.zeros_like(query_embed) # [num_class, B, hidden_dim]
322
+
323
+ # Cross-attention decoder (Q2L line 78)
324
+ # Queries attend to instance sequence via cross-attention
325
+ hs = self.decoder(
326
+ tgt=tgt,
327
+ memory=memory,
328
+ memory_key_padding_mask=memory_key_padding_mask,
329
+ pos=None, # No positional encoding (already in TransMIL)
330
+ query_pos=query_embed
331
+ ) # [1, num_class, B, hidden_dim] if return_intermediate=False
332
+
333
+ # Handle output shape
334
+ if hs.dim() == 4:
335
+ hs = hs[-1] # Take last layer: [num_class, B, hidden_dim]
336
+
337
+ # Transpose to [B, num_class, hidden_dim]
338
+ hs = hs.permute(1, 0, 2) # [B, num_class, hidden_dim]
339
+
340
+ # Group-wise linear classification (Q2L line 79)
341
+ logits = self.fc(hs) # [B, num_class]
342
+
343
+ return logits
344
+
345
+
346
+ # ============================================================================
347
+ # ResNet-50 Backbone
348
+ # ============================================================================
349
+
350
+ class ResNet50Backbone(nn.Module):
351
+ """
352
+ ResNet-50 feature extractor with Global Average Pooling.
353
+
354
+ Extracts 2048-dimensional features from images for TransMIL input.
355
+ Supports gradient checkpointing for memory efficiency.
356
+
357
+ Args:
358
+ pretrained: Use ImageNet pre-trained weights
359
+ use_checkpointing: Enable gradient checkpointing (saves memory)
360
+ """
361
+
362
+ def __init__(self, pretrained=True, use_checkpointing=False):
363
+ super().__init__()
364
+
365
+ # Load ResNet-50
366
+ resnet = torchvision.models.resnet50(pretrained=pretrained)
367
+
368
+ # Remove final FC layer and avgpool
369
+ # Output of layer4: [B, 2048, 7, 7] for 224x224 input
370
+ self.features = nn.Sequential(*list(resnet.children())[:-2])
371
+
372
+ # Global Average Pooling
373
+ self.gap = nn.AdaptiveAvgPool2d(1)
374
+
375
+ self.use_checkpointing = use_checkpointing
376
+
377
+ def forward(self, images):
378
+ """
379
+ Args:
380
+ images: [B*N, 3, H, W] - Batch of images (flattened across cases)
381
+
382
+ Returns:
383
+ features: [B*N, 2048] - Instance features
384
+ """
385
+ if self.training and self.use_checkpointing:
386
+ # Gradient checkpointing: segment backbone into chunks
387
+ # Trades compute for memory (recomputes activations during backward)
388
+ x = checkpoint_sequential(self.features, segments=4, input=images)
389
+ else:
390
+ x = self.features(images) # [B*N, 2048, 7, 7]
391
+
392
+ x = self.gap(x) # [B*N, 2048, 1, 1]
393
+ x = x.flatten(1) # [B*N, 2048]
394
+
395
+ return x
396
+
397
+
398
+ # ============================================================================
399
+ # Complete End-to-End Model
400
+ # ============================================================================
401
+
402
+ class TransMIL_Query2Label_E2E(nn.Module):
403
+ """
404
+ Complete end-to-end model: Images → ResNet-50 → TransMIL → Q2L → Logits
405
+
406
+ Pipeline:
407
+ 1. ResNet-50 extracts features from each ultrasound image
408
+ 2. TransMIL aggregates variable-length instance sequences with attention
409
+ 3. Query2Label decoder performs multi-label classification via cross-attention
410
+
411
+ Args:
412
+ num_class: Number of label classes (default 30)
413
+ hidden_dim: Hidden dimension for TransMIL and Q2L (default 512)
414
+ nheads: Number of attention heads in Q2L decoder
415
+ num_decoder_layers: Number of Q2L decoder layers
416
+ pretrained_resnet: Use ImageNet pre-trained ResNet-50
417
+ use_checkpointing: Enable gradient checkpointing for ResNet-50
418
+ use_ppeg: Use PPEG position encoding (vs learned 1D)
419
+ """
420
+
421
+ def __init__(
422
+ self,
423
+ num_class=30,
424
+ hidden_dim=512,
425
+ nheads=8,
426
+ num_decoder_layers=2,
427
+ pretrained_resnet=True,
428
+ use_checkpointing=False,
429
+ use_ppeg=False
430
+ ):
431
+ super().__init__()
432
+
433
+ # ResNet-50 backbone
434
+ self.backbone = ResNet50Backbone(
435
+ pretrained=pretrained_resnet,
436
+ use_checkpointing=use_checkpointing
437
+ )
438
+
439
+ # TransMIL feature extractor (no PPEG by default, learned 1D position encoding)
440
+ self.feature_extractor = TransMILFeatureExtractor(
441
+ input_dim=2048,
442
+ hidden_dim=hidden_dim,
443
+ use_ppeg=use_ppeg
444
+ )
445
+
446
+ # Query2Label decoder
447
+ self.q2l_decoder = HybridQuery2Label(
448
+ num_class=num_class,
449
+ hidden_dim=hidden_dim,
450
+ nheads=nheads,
451
+ num_decoder_layers=num_decoder_layers
452
+ )
453
+
454
+ def forward(self, images, num_instances_per_case):
455
+ """
456
+ Args:
457
+ images: [B*N_total, 3, H, W] - All images flattened across batch
458
+ num_instances_per_case: [B] or list - Number of images per case
459
+
460
+ Returns:
461
+ logits: [B, num_class] - Multi-label classification logits
462
+ """
463
+ # Convert to tensor if list
464
+ if isinstance(num_instances_per_case, list):
465
+ num_instances_per_case = torch.tensor(num_instances_per_case, device=images.device)
466
+
467
+ B = len(num_instances_per_case)
468
+
469
+ # Step 1: Extract features from all images
470
+ all_features = self.backbone(images) # [B*N_total, 2048]
471
+
472
+ # Step 2: Reshape to [B, max_N, 2048] with padding
473
+ max_N = int(num_instances_per_case.max().item())
474
+ features_padded = torch.zeros(B, max_N, 2048, device=images.device)
475
+ masks = torch.zeros(B, max_N, dtype=torch.bool, device=images.device)
476
+
477
+ idx = 0
478
+ for i, n in enumerate(num_instances_per_case):
479
+ n = int(n.item()) if torch.is_tensor(n) else int(n)
480
+ features_padded[i, :n] = all_features[idx:idx+n]
481
+ masks[i, :n] = True # True = valid instance
482
+ idx += n
483
+
484
+ # Step 3: TransMIL sequence features
485
+ seq_features, attn_mask = self.feature_extractor(features_padded, masks)
486
+ # seq_features: [B, 1+max_N, 512]
487
+ # attn_mask: [B, 1+max_N] where True = valid, False = padded
488
+
489
+ # Step 4: Q2L decoder
490
+ # IMPORTANT: PyTorch MultiheadAttention uses inverted mask convention!
491
+ # memory_key_padding_mask: True = ignore, False = attend
492
+ # So we need to invert our mask
493
+ decoder_mask = ~attn_mask # Invert: True = padded (ignore)
494
+ logits = self.q2l_decoder(seq_features, memory_key_padding_mask=decoder_mask)
495
+ # logits: [B, num_class]
496
+
497
+ return logits
498
+
499
+ def freeze_backbone(self):
500
+ """Freeze ResNet-50 backbone for training only TransMIL+Q2L"""
501
+ for param in self.backbone.parameters():
502
+ param.requires_grad = False
503
+
504
+ def unfreeze_backbone(self):
505
+ """Unfreeze ResNet-50 for end-to-end fine-tuning"""
506
+ for param in self.backbone.parameters():
507
+ param.requires_grad = True
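The freeze/unfreeze helpers above support a two-stage recipe: train only TransMIL+Q2L on top of frozen features, then unfreeze for end-to-end fine-tuning. A toy illustration (not the repo's classes; `TinyModel` is a stand-in) of what freezing does to the trainable parameter count:

```python
import torch.nn as nn

# Toy stand-in for the real model: a "backbone" and a "head".
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)   # stands in for ResNet-50
        self.head = nn.Linear(4, 2)       # stands in for TransMIL + Q2L

m = TinyModel()

# Stage 1: freeze the backbone, train only the head
for p in m.backbone.parameters():
    p.requires_grad = False
trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
print(trainable)  # -> 10 (head weight 4*2 + bias 2)

# Stage 2: unfreeze everything for end-to-end fine-tuning
for p in m.backbone.parameters():
    p.requires_grad = True
```

Only parameters with `requires_grad=True` receive gradients, so the optimizer should be (re)built, or its param groups filtered, when switching stages.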
508
+
509
+
510
+ # ============================================================================
511
+ # Testing
512
+ # ============================================================================
513
+
514
+ if __name__ == "__main__":
515
+ print("Testing TransMIL_Query2Label_E2E model...")
516
+
517
+ # Model config
518
+ num_class = 30
519
+ batch_size = 2
520
+ num_instances = [8, 12] # Variable N per case
521
+ img_size = 224
522
+
523
+ # Create model
524
+ model = TransMIL_Query2Label_E2E(
525
+ num_class=num_class,
526
+ hidden_dim=512,
527
+ nheads=8,
528
+ num_decoder_layers=2,
529
+ pretrained_resnet=False, # Faster for testing
530
+ use_checkpointing=False,
531
+ use_ppeg=False
532
+ )
533
+
534
+ # Create dummy data
535
+ total_images = sum(num_instances)
536
+ images = torch.randn(total_images, 3, img_size, img_size)
537
+
538
+ print(f"\nInput shapes:")
539
+ print(f" Images: {images.shape}")
540
+ print(f" Num instances per case: {num_instances}")
541
+
542
+ # Forward pass
543
+ model.eval()
544
+ with torch.no_grad():
545
+ logits = model(images, num_instances)
546
+
547
+ print(f"\nOutput shape:")
548
+ print(f" Logits: {logits.shape}")
549
+ print(f" Expected: [{batch_size}, {num_class}]")
550
+
551
+ assert logits.shape == (batch_size, num_class), "Output shape mismatch!"
552
+ print("\n✓ Model test passed!")
553
+
554
+ # Test individual components
555
+ print("\n" + "="*60)
556
+ print("Testing individual components...")
557
+ print("="*60)
558
+
559
+ # Test TransMILFeatureExtractor
560
+ print("\n1. TransMILFeatureExtractor")
561
+ feature_extractor = TransMILFeatureExtractor(input_dim=2048, hidden_dim=512)
562
+ features = torch.randn(2, 10, 2048)
563
+ mask = torch.ones(2, 10, dtype=torch.bool)
564
+ seq_features, attn_mask = feature_extractor(features, mask)
565
+ print(f" Input: {features.shape}, Output: {seq_features.shape}")
566
+ assert seq_features.shape == (2, 11, 512) # 1 CLS + 10 instances
567
+ print(" ✓ Passed")
568
+
569
+ # Test HybridQuery2Label
570
+ print("\n2. HybridQuery2Label")
571
+ decoder = HybridQuery2Label(num_class=30, hidden_dim=512)
572
+ seq_features = torch.randn(2, 11, 512)
573
+ logits = decoder(seq_features)
574
+ print(f" Input: {seq_features.shape}, Output: {logits.shape}")
575
+ assert logits.shape == (2, 30)
576
+ print(" ✓ Passed")
577
+
578
+ # Test ResNet50Backbone
579
+ print("\n3. ResNet50Backbone")
580
+ backbone = ResNet50Backbone(pretrained=False)
581
+ images = torch.randn(20, 3, 224, 224)
582
+ features = backbone(images)
583
+ print(f" Input: {images.shape}, Output: {features.shape}")
584
+ assert features.shape == (20, 2048)
585
+ print(" ✓ Passed")
586
+
587
+ print("\n" + "="*60)
588
+ print("All tests passed! ✓")
589
+ print("="*60)
requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ # Ultrasound hints multi-label classification model - dependency list
2
+ # TransMIL + Query2Label Hybrid Model Requirements
3
+
4
+ torch>=1.10.0
5
+ torchvision>=0.11.0
6
+ numpy>=1.19.0
7
+ pandas>=1.3.0
8
+ Pillow>=8.0.0
9
+ scikit-learn>=0.24.0
10
+ tqdm>=4.60.0
11
+ PyYAML>=5.4.0
12
+ nystrom-attention>=0.0.11
scripts/evaluate.sh ADDED
@@ -0,0 +1,6 @@
1
+ #!/bin/bash
2
+ # Evaluation script - evaluates model performance on the test set
3
+ # Edit the data paths in config.yaml first
4
+
5
+ cd "$(dirname "$0")/.."
6
+ python evaluate.py
scripts/infer_single.sh ADDED
@@ -0,0 +1,7 @@
1
+ #!/bin/bash
2
+ # Single-case inference script - runs inference on one case
3
+ # Usage 1: ./infer_single.sh /path/to/image1.png /path/to/image2.png ...
4
+ # Usage 2: ./infer_single.sh --image_dir /path/to/case_folder/
5
+
6
+ cd "$(dirname "$0")/.."
7
+ python infer_single_case.py "$@" --threshold 0.5
scripts/train.sh ADDED
@@ -0,0 +1,5 @@
1
+ #!/bin/bash
2
+ # Training script - TransMIL + Query2Label Hybrid Model
3
+ # Edit the data paths in config.yaml first
4
+
5
+ python train_hybrid.py --config config.yaml
thyroid_dataset.py ADDED
@@ -0,0 +1,285 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import pandas as pd
5
+ import numpy as np
6
+ from PIL import Image
7
+ from pathlib import Path
8
+ from typing import List, Dict, Optional
9
+ from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
10
+ from torchvision import transforms
11
+
12
+ # 18-class label definitions (must match the CSV column order exactly)
13
+ '''
14
+ TARGET_CLASSES = [
15
+ "TI-RADS 1级", "TI-RADS 2级", "TI-RADS 3级", "TI-RADS 4a级",
16
+ "TI-RADS 4b级", "TI-RADS 4c级", "TI-RADS 5级",
17
+ "钙化", "甲亢", "囊肿", "淋巴结", "胶质潴留", "切除术后",
18
+ "弥漫性病变", "结节性甲状腺肿", "桥本氏甲状腺炎", "反应性", "转移性"
19
+ ]
20
+ '''
21
+ # 17-class label definitions: "切除术后" (post-surgical) removed
22
+ TARGET_CLASSES = [
23
+ "TI-RADS 1级", "TI-RADS 2级", "TI-RADS 3级", "TI-RADS 4a级",
24
+ "TI-RADS 4b级", "TI-RADS 4c级", "TI-RADS 5级",
25
+ "钙化", "甲亢", "囊肿", "淋巴结", "胶质潴留",
26
+ "弥漫性病变", "结节性甲状腺肿", "桥本氏甲状腺炎", "反应性", "转移性"
27
+ ]
28
+
29
+ # Indices of rare/hard classes (used for resampling)
30
+ # In the original 18-class list these were: 4b(4), 4c(5), 5(6), post-surgical(12), metastatic(17)
31
+ #RARE_CLASS_INDICES = [4, 5, 6, 12, 17]
32
+
33
+ RARE_CLASS_INDICES = [4, 5, 6, 16]  # 17-class labels: 4b, 4c, 5, metastatic
34
+
35
+ class ThyroidMultiLabelDataset(Dataset):
36
+ def __init__(self,
37
+ data_root: str,
38
+ annotation_csv: str,
39
+ split_json: Optional[str] = None,
40
+ split_type: str = 'train', # 'train', 'val', 'test'
41
+ val_json_path: Optional[str] = None, # needed only when split_type='train'; used to exclude the validation set
42
+ test_json_path: Optional[str] = None, # needed only when split_type='train'; used to exclude the test set
43
+ img_size: int = 224,
44
+ max_images_per_case: int = 20,
45
+ transform=None):
46
+
47
+ self.data_root = Path(data_root)
48
+ self.img_size = img_size
49
+ self.max_images_per_case = max_images_per_case
50
+ self.split_type = split_type
51
+
52
+ # 1. Load all labels
53
+ self.df_labels = pd.read_csv(annotation_csv)
54
+ # Use case_path as the index for fast lookup
55
+ self.df_labels.set_index('case_path', inplace=True)
56
+
57
+ # 2. Determine the case_list for this split
58
+ self.case_list = self._get_split_cases(split_json, val_json_path, test_json_path)
59
+
60
+ # 3. Define data augmentation
61
+ if transform:
62
+ self.transform = transform
63
+ elif split_type == 'train':
64
+ self.transform = transforms.Compose([
65
+ transforms.Resize((img_size, img_size)),
66
+ transforms.RandomHorizontalFlip(p=0.5),
67
+ transforms.RandomVerticalFlip(p=0.5), # ultrasound images may be flipped vertically
68
+ transforms.RandomRotation(15),
69
+ transforms.ColorJitter(brightness=0.2, contrast=0.2),
70
+ transforms.ToTensor(),
71
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
72
+ ])
73
+ else:
74
+ self.transform = transforms.Compose([
75
+ transforms.Resize((img_size, img_size)),
76
+ transforms.ToTensor(),
77
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
78
+ ])
79
+
80
+ print(f"[{split_type.upper()}] Loaded {len(self.case_list)} cases.")
81
+
82
+ def _get_split_cases(self, split_json, val_json_path, test_json_path):
83
+ """
84
+ Split the dataset according to the JSON split files.
85
+ """
86
+ all_cases_in_csv = set(self.df_labels.index.tolist())
87
+
88
+ # Read the specified split json (for val or test)
89
+ target_cases = []
90
+ if split_json:
91
+ with open(split_json, 'r') as f:
92
+ data = json.load(f)
93
+ # rel_path in the JSON corresponds to case_path in the CSV
94
+ target_cases = [item['rel_path'] for item in data]
95
+
96
+ # Filter out cases missing from the CSV (just in case)
97
+ valid_cases = [c for c in target_cases if c in all_cases_in_csv]
98
+ return valid_cases
99
+
100
+ # For train, the logic is: all cases in the CSV minus the val and test cases
101
+ elif self.split_type == 'train':
102
+ exclude_cases = set()
103
+
104
+ if val_json_path:
105
+ with open(val_json_path, 'r') as f:
106
+ exclude_cases.update([item['rel_path'] for item in json.load(f)])
107
+
108
+ if test_json_path:
109
+ with open(test_json_path, 'r') as f:
110
+ exclude_cases.update([item['rel_path'] for item in json.load(f)])
111
+
112
+ train_cases = list(all_cases_in_csv - exclude_cases)
113
+ return sorted(train_cases) # sort for determinism
114
+
115
+ else:
116
+ return []
117
+
118
+ def __len__(self):
119
+ return len(self.case_list)
120
+
121
+ def __getitem__(self, idx):
122
+ case_rel_path = self.case_list[idx]
123
+
124
+ # 1. Build the image directory path: data_root / BatchX/CaseID / Images
125
+ img_dir = self.data_root / case_rel_path / "Images"
126
+
127
+ # 2. Get labels
128
+ # df.loc[index] returns a Series; convert to numpy
129
+ try:
130
+ label_vec = self.df_labels.loc[case_rel_path, TARGET_CLASSES].values.astype(np.float32)
131
+ label_tensor = torch.tensor(label_vec)
132
+ except KeyError:
133
+ print(f"Warning: Label for {case_rel_path} not found in CSV. Using zeros.")
134
+ label_tensor = torch.zeros(len(TARGET_CLASSES))
135
+
136
+ # 3. Load images
137
+ image_files = sorted(list(img_dir.glob("*.jpg")) + list(img_dir.glob("*.png")) + list(img_dir.glob("*.bmp")))
138
+
139
+ # Sampling logic (train: random sample; val/test: take the first N images)
140
+ if self.max_images_per_case and len(image_files) > self.max_images_per_case:
141
+ if self.split_type == 'train':
142
+ # Random sampling during training for diversity
143
+ image_files = np.random.choice(image_files, self.max_images_per_case, replace=False)
144
+ else:
145
+ image_files = image_files[:self.max_images_per_case]
146
+
147
+ images = []
148
+ for img_path in image_files:
149
+ try:
150
+ img = Image.open(img_path).convert('RGB')
151
+ if self.transform:
152
+ img = self.transform(img)
153
+ images.append(img)
154
+ except Exception as e:
155
+ print(f"Warning: failed to load image {img_path}: {e}")
156
+
157
+ if len(images) == 0:
158
+ # Fallback: use an all-black image
159
+ images = [torch.zeros(3, self.img_size, self.img_size)]
160
+
161
+ images_stack = torch.stack(images) # [N, 3, H, W]
162
+
163
+ return {
164
+ 'images': images_stack,
165
+ 'labels': label_tensor,
166
+ 'num_images': len(images),
167
+ 'case_id': case_rel_path
168
+ }
169
+
170
+ def get_sampler_weights(self):
171
+ """
172
+ Compute sampling weights: cases containing a rare class get weight 10, all others get 1.
173
+ """
174
+ weights = []
175
+ for case_rel_path in self.case_list:
176
+ label_vec = self.df_labels.loc[case_rel_path, TARGET_CLASSES].values
177
+
178
+ # Check whether the case contains a rare class
179
+ is_rare = False
180
+ for idx in RARE_CLASS_INDICES:
181
+ if label_vec[idx] == 1:
182
+ is_rare = True
183
+ break
184
+
185
+ if is_rare:
186
+ weights.append(10.0) # rare cases get 10x sampling probability
187
+ else:
188
+ weights.append(1.0)
189
+
190
+ return torch.tensor(weights, dtype=torch.double)
191
+
192
+ def collate_fn(batch):
193
+ images_list = []
194
+ labels_list = []
195
+ num_instances_list = []
196
+ case_ids = []
197
+
198
+ for item in batch:
199
+ images_list.append(item['images'])
200
+ labels_list.append(item['labels'])
201
+ num_instances_list.append(item['num_images'])
202
+ case_ids.append(item['case_id'])
203
+
204
+ all_images = torch.cat(images_list, dim=0)
205
+ labels = torch.stack(labels_list)
206
+ num_instances_per_case = torch.tensor(num_instances_list, dtype=torch.long)
207
+
208
+ return {
209
+ 'images': all_images,
210
+ 'labels': labels,
211
+ 'num_instances_per_case': num_instances_per_case,
212
+ 'case_ids': case_ids
213
+ }
214
+
215
+ def create_dataloaders(config):
216
+ data_root = config['data']['data_root']
217
+ csv_path = config['data']['annotation_csv']
218
+ val_json = config['data']['val_json']
219
+ test_json = config['data']['test_json']
220
+
221
+ # Train Dataset
222
+ train_dataset = ThyroidMultiLabelDataset(
223
+ data_root=data_root,
224
+ annotation_csv=csv_path,
225
+ split_type='train',
226
+ val_json_path=val_json,
227
+ test_json_path=test_json,
228
+ img_size=config['data']['img_size'],
229
+ max_images_per_case=config['data']['max_images_per_case']
230
+ )
231
+
232
+ # Compute sampling weights and create the sampler
233
+ print("Calculating sampler weights for class balance...")
234
+ train_weights = train_dataset.get_sampler_weights()
235
+ sampler = WeightedRandomSampler(train_weights, len(train_weights))
236
+
237
+ train_loader = DataLoader(
238
+ train_dataset,
239
+ batch_size=config['training']['batch_size'],
240
+ sampler=sampler, # do not set shuffle=True when using a sampler
241
+ num_workers=config['data']['num_workers'],
242
+ collate_fn=collate_fn,
243
+ pin_memory=True,
244
+ drop_last=True
245
+ )
246
+
247
+ # Val Dataset
248
+ val_dataset = ThyroidMultiLabelDataset(
249
+ data_root=data_root,
250
+ annotation_csv=csv_path,
251
+ split_type='val',
252
+ split_json=val_json,
253
+ img_size=config['data']['img_size'],
254
+ max_images_per_case=config['data']['max_images_per_case']
255
+ )
256
+
257
+ val_loader = DataLoader(
258
+ val_dataset,
259
+ batch_size=config['training']['batch_size'],
260
+ shuffle=False,
261
+ num_workers=config['data']['num_workers'],
262
+ collate_fn=collate_fn,
263
+ pin_memory=True
264
+ )
265
+
266
+ # Test Dataset
267
+ test_dataset = ThyroidMultiLabelDataset(
268
+ data_root=data_root,
269
+ annotation_csv=csv_path,
270
+ split_type='test',
271
+ split_json=test_json,
272
+ img_size=config['data']['img_size'],
273
+ max_images_per_case=None # use all images at test time
274
+ )
275
+
276
+ test_loader = DataLoader(
277
+ test_dataset,
278
+ batch_size=config['training']['batch_size'],
279
+ shuffle=False,
280
+ num_workers=config['data']['num_workers'],
281
+ collate_fn=collate_fn,
282
+ pin_memory=True
283
+ )
284
+
285
+ return train_loader, val_loader, test_loader
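The `collate_fn` above flattens variable-length cases into one image tensor plus per-case counts, which the model later uses to re-group instances. A self-contained sketch with two hypothetical cases (3 and 5 images; shapes only, the data is random):

```python
import torch

# Hypothetical batch mirroring what ThyroidMultiLabelDataset.__getitem__ returns
batch = [
    {'images': torch.randn(3, 3, 224, 224), 'labels': torch.zeros(17),
     'num_images': 3, 'case_id': 'Batch1/Case001'},
    {'images': torch.randn(5, 3, 224, 224), 'labels': torch.ones(17),
     'num_images': 5, 'case_id': 'Batch1/Case002'},
]

# The same flattening collate_fn performs: concat images along dim 0,
# stack labels, and record per-case instance counts for later un-flattening.
all_images = torch.cat([b['images'] for b in batch], dim=0)
labels = torch.stack([b['labels'] for b in batch])
num_instances = torch.tensor([b['num_images'] for b in batch])

print(all_images.shape)        # torch.Size([8, 3, 224, 224])
print(num_instances.tolist())  # [3, 5]
```

Flattening lets the ResNet backbone process every image in one forward pass regardless of how many images each case has; padding only happens later, at the much cheaper feature level.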
thyroid_multilabel_annotations.csv ADDED
The diff for this file is too large to render. See raw diff
 
train_hybrid.py ADDED
@@ -0,0 +1,439 @@
1
+ """
2
+ Training Script for TransMIL + Query2Label Hybrid Model
3
+
4
+ Supports:
5
+ - End-to-end training with ResNet-50 backbone
6
+ - Mixed precision training (AMP) for memory efficiency
7
+ - Gradient accumulation for larger effective batch size
8
+ - Gradient checkpointing for ResNet-50
9
+ - AsymmetricLoss for multi-label imbalance
10
+ - Multi-label evaluation metrics (mAP, per-class AP, F1)
11
+ """
12
+
13
16
+
17
+ import os
18
+ import argparse
19
+ import yaml
20
+ from pathlib import Path
21
+ from datetime import datetime
22
+ import json
23
+
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.optim as optim
27
+ from torch.cuda.amp import autocast, GradScaler
28
+ from torch.utils.tensorboard import SummaryWriter
29
+ import numpy as np
30
+ from tqdm import tqdm
31
+ from sklearn.metrics import average_precision_score, f1_score
32
+
33
+ # Import model and dataset
34
+ from models.transmil_q2l import TransMIL_Query2Label_E2E
35
+ from thyroid_dataset import create_dataloaders
36
+
37
+ # Import AsymmetricLoss
38
+ try:
39
+ from models.aslloss import AsymmetricLossOptimized
40
+ except ImportError:
41
+ print("Warning: Could not import AsymmetricLoss.")
42
+ AsymmetricLossOptimized = None
43
53
+ # ============================================================================
54
+ # Metrics
55
+ # ============================================================================
56
+
57
+ def compute_multilabel_metrics(preds, targets, threshold=0.5):
58
+ """
59
+ Compute multi-label classification metrics.
60
+
61
+ Args:
62
+ preds: [N, num_class] numpy array of probabilities
63
+ targets: [N, num_class] numpy array of binary labels
64
+ threshold: Classification threshold for F1 score
65
+
66
+ Returns:
67
+ dict with mAP, per-class AP, F1 scores
68
+ """
69
+ metrics = {}
70
+
71
+ # Mean Average Precision (mAP)
72
+ aps = []
73
+ for i in range(targets.shape[1]):
74
+ if targets[:, i].sum() > 0: # Skip classes with no positive samples
75
+ ap = average_precision_score(targets[:, i], preds[:, i])
76
+ aps.append(ap)
77
+ else:
78
+ aps.append(np.nan)
79
+
80
+ metrics['mAP'] = np.nanmean(aps)
81
+ metrics['per_class_AP'] = aps
82
+
83
+ # F1 Score at threshold
84
+ preds_binary = (preds >= threshold).astype(int)
85
+ f1_micro = f1_score(targets, preds_binary, average='micro', zero_division=0)
86
+ f1_macro = f1_score(targets, preds_binary, average='macro', zero_division=0)
87
+ f1_samples = f1_score(targets, preds_binary, average='samples', zero_division=0)
88
+
89
+ metrics['F1_micro'] = f1_micro
90
+ metrics['F1_macro'] = f1_macro
91
+ metrics['F1_samples'] = f1_samples
92
+
93
+ return metrics
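A tiny made-up example of the metric computation above (3 samples, 2 classes), where the ranking and the 0.5 thresholding both happen to be perfect:

```python
import numpy as np
from sklearn.metrics import average_precision_score, f1_score

preds = np.array([[0.9, 0.2],
                  [0.8, 0.7],
                  [0.1, 0.6]])
targets = np.array([[1, 0],
                    [1, 1],
                    [0, 1]])

# Per-class AP, then mean over classes (same as the function above)
aps = [average_precision_score(targets[:, i], preds[:, i]) for i in range(2)]
mAP = float(np.nanmean(aps))

# Micro-F1 at threshold 0.5
preds_binary = (preds >= 0.5).astype(int)
f1_micro = f1_score(targets, preds_binary, average='micro', zero_division=0)

print(mAP, f1_micro)  # -> 1.0 1.0
```

Note that mAP is threshold-free (it scores the ranking), while the F1 numbers depend on the chosen threshold; the two can diverge substantially on imbalanced labels.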
94
+
95
+
96
+ # ============================================================================
97
+ # Training Functions
98
+ # ============================================================================
99
+
100
+ def train_epoch(model, dataloader, criterion, optimizer, scaler, device, config, epoch):
101
+ """
102
+ Train for one epoch with gradient accumulation and mixed precision.
103
+
104
+ Args:
105
+ model: TransMIL_Query2Label_E2E model
106
+ dataloader: Training dataloader
107
+ criterion: AsymmetricLoss
108
+ optimizer: AdamW optimizer
109
+ scaler: GradScaler for AMP
110
+ device: torch.device
111
+ config: Config dict
112
+ epoch: Current epoch number
113
+
114
+ Returns:
115
+ Average loss for epoch
116
+ """
117
+ model.train()
118
+
119
+ total_loss = 0.0
120
+ accumulation_steps = config['training']['gradient_accumulation_steps']
121
+ use_amp = config['training']['use_amp']
122
+
123
+ # Progress bar
124
+ pbar = tqdm(dataloader, desc=f"Epoch {epoch}")
125
+
126
+ optimizer.zero_grad()
127
+
128
+ for i, batch in enumerate(pbar):
129
+ images = batch['images'].to(device) # [B*N_total, 3, H, W]
130
+ labels = batch['labels'].to(device) # [B, num_class]
131
+ num_instances_per_case = batch['num_instances_per_case'] # [B]
132
+
133
+ # Mixed precision forward pass
134
+ if use_amp:
135
+ with autocast():
136
+ logits = model(images, num_instances_per_case)
137
+ loss = criterion(logits, labels)
138
+ loss = loss / accumulation_steps # Scale loss for accumulation
139
+ else:
140
+ logits = model(images, num_instances_per_case)
141
+ loss = criterion(logits, labels)
142
+ loss = loss / accumulation_steps
143
+
144
+ # Backward pass
145
+ if use_amp:
146
+ scaler.scale(loss).backward()
147
+ else:
148
+ loss.backward()
149
+
150
+ # Optimizer step every accumulation_steps
151
+ if (i + 1) % accumulation_steps == 0:
152
+ if use_amp:
153
+ scaler.step(optimizer)
154
+ scaler.update()
155
+ else:
156
+ optimizer.step()
157
+ optimizer.zero_grad()
158
+
159
+ # Track loss
160
+ total_loss += loss.item() * accumulation_steps
161
+ pbar.set_postfix({'loss': loss.item() * accumulation_steps})
162
+
163
+ return total_loss / len(dataloader)
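The core accumulation pattern in `train_epoch`, stripped of AMP and data loading: the loss is divided by `accumulation_steps` so gradients average over the micro-batches, and the optimizer only steps (and zeroes grads) every `accumulation_steps` iterations. A minimal sketch with a toy model:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
accumulation_steps = 4

optimizer.zero_grad()
steps_taken = 0
for i in range(8):  # 8 micro-batches -> 2 optimizer steps
    x = torch.randn(2, 4)
    # Scale the loss so accumulated gradients equal the large-batch average
    loss = model(x).pow(2).mean() / accumulation_steps
    loss.backward()  # gradients accumulate across iterations
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        steps_taken += 1

print(steps_taken)  # -> 2
```

This gives an effective batch size of `batch_size * accumulation_steps` with the memory footprint of a single micro-batch; for logging, the scaled loss is multiplied back by `accumulation_steps`, as `train_epoch` does.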
164
+
165
+
166
+ @torch.no_grad()
167
+ def validate(model, dataloader, criterion, device, config):
168
+ """
169
+ Validate model with multi-label metrics.
170
+
171
+ Args:
172
+ model: TransMIL_Query2Label_E2E model
173
+ dataloader: Validation dataloader
174
+ criterion: AsymmetricLoss
175
+ device: torch.device
176
+ config: Config dict
177
+
178
+ Returns:
179
+ dict with loss and metrics (mAP, F1, etc.)
180
+ """
181
+ model.eval()
182
+
183
+ total_loss = 0.0
184
+ all_preds = []
185
+ all_targets = []
186
+
187
+ for batch in tqdm(dataloader, desc="Validating"):
188
+ images = batch['images'].to(device)
189
+ labels = batch['labels'].to(device)
190
+ num_instances_per_case = batch['num_instances_per_case']
191
+
192
+ # Forward pass
193
+ logits = model(images, num_instances_per_case)
194
+ loss = criterion(logits, labels)
195
+
196
+ # Sigmoid for multi-label probabilities
197
+ preds = torch.sigmoid(logits)
198
+
199
+ # Store predictions and targets
200
+ all_preds.append(preds.cpu().numpy())
201
+ all_targets.append(labels.cpu().numpy())
202
+
203
+ total_loss += loss.item()
204
+
205
+ # Concatenate all batches
206
+ all_preds = np.concatenate(all_preds, axis=0)
207
+ all_targets = np.concatenate(all_targets, axis=0)
208
+
209
+ # Compute metrics
210
+ metrics = compute_multilabel_metrics(all_preds, all_targets)
211
+ metrics['loss'] = total_loss / len(dataloader)
212
+
213
+ return metrics
214
+
215
+
216
+ # ============================================================================
217
+ # Main Training Loop
218
+ # ============================================================================
219
+
220
+ def train(config, resume_from=None):
221
+ """
222
+ Main training function.
223
+
224
+ Args:
225
+ config: Config dictionary from YAML
226
+ resume_from: Optional checkpoint path to resume training
227
+ """
228
+ # Setup device
229
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
230
+ print(f"\nUsing device: {device}")
231
+ if torch.cuda.is_available():
232
+ print(f"GPU: {torch.cuda.get_device_name(0)}")
233
+ print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
234
+
235
+ # Create save directory
236
+ save_dir = Path(config['training']['save_dir'])
237
+ save_dir.mkdir(parents=True, exist_ok=True)
238
+
239
+ # Create tensorboard writer
240
+ log_dir = save_dir / 'logs' / datetime.now().strftime('%Y%m%d_%H%M%S')
241
+ writer = SummaryWriter(log_dir)
242
+
243
+ # Save config
244
+ with open(save_dir / 'config.yaml', 'w') as f:
245
+ yaml.dump(config, f)
246
+
247
+ # Create dataloaders
248
+ print("\nCreating dataloaders...")
249
+ train_loader, val_loader, test_loader = create_dataloaders(config)
250
+
251
+ # Create model
252
+ print("\nCreating model...")
253
+ model = TransMIL_Query2Label_E2E(
254
+ num_class=config['model']['num_class'],
255
+ hidden_dim=config['model']['hidden_dim'],
256
+ nheads=config['model']['nheads'],
257
+ num_decoder_layers=config['model']['num_decoder_layers'],
258
+ pretrained_resnet=config['model']['pretrained_resnet'],
259
+ use_checkpointing=config['training']['gradient_checkpointing'],
260
+ use_ppeg=config['model'].get('use_ppeg', False)
261
+ )
262
+ model = model.to(device)
263
+
264
+ # Print model stats
265
+ total_params = sum(p.numel() for p in model.parameters())
266
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
267
+ print(f"Total parameters: {total_params:,}")
268
+ print(f"Trainable parameters: {trainable_params:,}")
269
+
270
+ # Create optimizer
271
+ optimizer = optim.AdamW(
272
+ model.parameters(),
273
+ lr=config['training']['lr'],
274
+ weight_decay=config['training']['weight_decay']
275
+ )
276
+
277
+ # Create scheduler
278
+ scheduler_type = config['training'].get('scheduler', 'cosine')
279
+ if scheduler_type == 'cosine':
280
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(
281
+ optimizer,
282
+ T_max=config['training']['epochs'],
283
+ eta_min=1e-6
284
+ )
285
+ elif scheduler_type == 'onecycle':
286
+ scheduler = optim.lr_scheduler.OneCycleLR(
287
+ optimizer,
288
+ max_lr=config['training']['lr'],
289
+ epochs=config['training']['epochs'],
290
+ steps_per_epoch=len(train_loader)
291
+ )
292
+ else:
293
+ scheduler = None
294
+
295
+ # Create loss function
296
+ if AsymmetricLossOptimized is not None:
297
+ criterion = AsymmetricLossOptimized(
298
+ gamma_neg=config['training']['gamma_neg'],
299
+ gamma_pos=config['training']['gamma_pos'],
300
+ clip=config['training']['clip'],
301
+ eps=1e-5
302
+ )
303
+ else:
304
+ # Fallback to BCEWithLogitsLoss
305
+ print("Warning: Using BCEWithLogitsLoss instead of AsymmetricLoss")
306
+ criterion = nn.BCEWithLogitsLoss()
307
+
308
+ # Mixed precision scaler
309
+ scaler = GradScaler() if config['training']['use_amp'] else None
310
+
311
+ # Resume from checkpoint if specified
312
+ start_epoch = 0
313
+ best_map = 0.0
314
+
315
+ if resume_from is not None and Path(resume_from).exists():
316
+ print(f"\nResuming from {resume_from}")
317
+ checkpoint = torch.load(resume_from, map_location=device)
318
+ model.load_state_dict(checkpoint['model_state_dict'])
319
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
320
+ start_epoch = checkpoint['epoch'] + 1
321
+ best_map = checkpoint.get('best_map', 0.0)
322
+ if scheduler is not None and 'scheduler_state_dict' in checkpoint:
323
+ scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
324
+ print(f"Resumed from epoch {start_epoch}, best mAP: {best_map:.4f}")
325
+
326
+ # Training loop
327
+ print(f"\nStarting training for {config['training']['epochs']} epochs...")
328
+ print("="*80)
329
+
330
+ for epoch in range(start_epoch, config['training']['epochs']):
331
+ # Train
332
+ train_loss = train_epoch(model, train_loader, criterion, optimizer, scaler, device, config, epoch)
333
+
334
+ # Validate
335
+ val_metrics = validate(model, val_loader, criterion, device, config)
336
+
337
+ # Update scheduler
338
+ if scheduler is not None:
339
+ if scheduler_type == 'onecycle':
340
+ pass # OneCycleLR should be stepped after each optimizer step inside train_epoch, not per epoch
341
+ else:
342
+ scheduler.step()
343
+
344
+         # Log metrics
+         current_lr = optimizer.param_groups[0]['lr']
+         writer.add_scalar('Loss/train', train_loss, epoch)
+         writer.add_scalar('Loss/val', val_metrics['loss'], epoch)
+         writer.add_scalar('Metrics/mAP', val_metrics['mAP'], epoch)
+         writer.add_scalar('Metrics/F1_micro', val_metrics['F1_micro'], epoch)
+         writer.add_scalar('Metrics/F1_macro', val_metrics['F1_macro'], epoch)
+         writer.add_scalar('LR', current_lr, epoch)
+
+         # Print epoch summary (1-indexed for readability; checkpoints store the 0-indexed epoch)
+         print(f"\nEpoch {epoch + 1}/{config['training']['epochs']}")
+         print(f"  Train Loss: {train_loss:.4f}")
+         print(f"  Val Loss:   {val_metrics['loss']:.4f}")
+         print(f"  mAP:        {val_metrics['mAP']:.4f}")
+         print(f"  F1 (micro): {val_metrics['F1_micro']:.4f}")
+         print(f"  F1 (macro): {val_metrics['F1_macro']:.4f}")
+         print(f"  LR:         {current_lr:.6f}")
+
+         # Save checkpoint
+         is_best = val_metrics['mAP'] > best_map
+         if is_best:
+             best_map = val_metrics['mAP']
+
+         if (epoch + 1) % config['training']['save_freq'] == 0 or is_best:
+             checkpoint = {
+                 'epoch': epoch,
+                 'model_state_dict': model.state_dict(),
+                 'optimizer_state_dict': optimizer.state_dict(),
+                 'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
+                 'train_loss': train_loss,
+                 'val_metrics': val_metrics,
+                 'best_map': best_map,
+                 'config': config
+             }
+
+             # Save latest checkpoint
+             torch.save(checkpoint, save_dir / 'checkpoint_latest.pth')
+
+             # Save best checkpoint
+             if is_best:
+                 torch.save(checkpoint, save_dir / 'checkpoint_best.pth')
+                 print(f"  ✓ Saved best model (mAP: {best_map:.4f})")
+
+             # Save periodic checkpoint
+             if (epoch + 1) % config['training']['save_freq'] == 0:
+                 torch.save(checkpoint, save_dir / f'checkpoint_epoch_{epoch}.pth')
+
+     print("\n" + "="*80)
+     print(f"Training completed! Best mAP: {best_map:.4f}")
+     print(f"Checkpoints saved to: {save_dir}")
+
+     writer.close()
+
+     # Final test evaluation
+     print("\nEvaluating on test set...")
+     test_metrics = validate(model, test_loader, criterion, device, config)
+     print(f"\nTest Results:")
+     print(f"  mAP:        {test_metrics['mAP']:.4f}")
+     print(f"  F1 (micro): {test_metrics['F1_micro']:.4f}")
+     print(f"  F1 (macro): {test_metrics['F1_macro']:.4f}")
+
+     # Save test results
+     with open(save_dir / 'test_results.json', 'w') as f:
+         json.dump({k: float(v) if not isinstance(v, list) else v
+                    for k, v in test_metrics.items()}, f, indent=2)
+
+
+ # ============================================================================
+ # Main
+ # ============================================================================
+
+ def main():
+     parser = argparse.ArgumentParser(description='Train TransMIL + Query2Label Hybrid Model')
+     parser.add_argument('--config', type=str, default='hybrid_model/config.yaml',
+                         help='Path to config file')
+     parser.add_argument('--resume', type=str, default=None,
+                         help='Path to checkpoint to resume from')
+     args = parser.parse_args()
+
+     # Load config
+     with open(args.config, 'r') as f:
+         config = yaml.safe_load(f)
+
+     print("="*80)
+     print("TransMIL + Query2Label Hybrid Model Training")
+     print("="*80)
+     print(f"\nConfig: {args.config}")
+     if args.resume:
+         print(f"Resume from: {args.resume}")
+
+     # Train
+     train(config, resume_from=args.resume)
+
+
+ if __name__ == "__main__":
+     main()
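The resume branch in `train()` restarts at `checkpoint['epoch'] + 1`, falls back to `best_map = 0.0` when the key is absent, and starts a fresh run when the path is missing. A minimal, torch-free sketch of that bookkeeping (`pickle` stands in for `torch.save`/`torch.load`; the helper names are illustrative, not part of this repo):

```python
from pathlib import Path
import pickle


def save_checkpoint(path, epoch, state, best_map):
    # Persist a subset of the fields train_hybrid.py's checkpoint dict carries
    with open(path, 'wb') as f:
        pickle.dump({'epoch': epoch,
                     'model_state_dict': state,
                     'best_map': best_map}, f)


def load_for_resume(path):
    # Mirrors the resume branch: returns (next epoch to run, best mAP so far, state)
    if path is None or not Path(path).exists():
        return 0, 0.0, None  # fresh run: start at epoch 0
    with open(path, 'rb') as f:
        ckpt = pickle.load(f)
    return ckpt['epoch'] + 1, ckpt.get('best_map', 0.0), ckpt['model_state_dict']
```

The `+ 1` is what prevents the last completed epoch from being trained twice after a restart, and the `.get('best_map', 0.0)` keeps resumption working for checkpoints written before that field existed.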