Initial upload: HintsPrediction
Browse files- README.md +113 -0
- checkpoints/checkpoint_best.pth +3 -0
- classification_test_set_single.json +0 -0
- classification_val_set_single.json +1688 -0
- config.yaml +48 -0
- evaluate.py +236 -0
- infer_single_case.py +244 -0
- models/__init__.py +13 -0
- models/aslloss.py +115 -0
- models/transformer.py +374 -0
- models/transmil_q2l.py +589 -0
- requirements.txt +12 -0
- scripts/evaluate.sh +6 -0
- scripts/infer_single.sh +7 -0
- scripts/train.sh +5 -0
- thyroid_dataset.py +285 -0
- thyroid_multilabel_annotations.csv +0 -0
- train_hybrid.py +439 -0
README.md
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 超声提示多标签分类模型
|
| 2 |
+
|
| 3 |
+
基于 **TransMIL + Query2Label** 混合架构的甲状腺超声图像多标签分类模型。
|
| 4 |
+
|
| 5 |
+
## 模型架构
|
| 6 |
+
|
| 7 |
+
- **Backbone**: ResNet-50 (预训练)
|
| 8 |
+
- **特征聚合**: TransMIL (Nystrom Attention)
|
| 9 |
+
- **多标签分类**: Query2Label (Transformer Decoder)
|
| 10 |
+
- **损失函数**: Asymmetric Loss (处理类别不平衡)
|
| 11 |
+
|
| 12 |
+
## 17类标签
|
| 13 |
+
|
| 14 |
+
| 序号 | 标签 | 序号 | 标签 |
|
| 15 |
+
|:---:|:---|:---:|:---|
|
| 16 |
+
| 1 | TI-RADS 1级 | 10 | 囊肿 |
|
| 17 |
+
| 2 | TI-RADS 2级 | 11 | 淋巴结 |
|
| 18 |
+
| 3 | TI-RADS 3级 | 12 | 胶质潴留 |
|
| 19 |
+
| 4 | TI-RADS 4a级 | 13 | 弥漫性病变 |
|
| 20 |
+
| 5 | TI-RADS 4b级 | 14 | 结节性甲状腺肿 |
|
| 21 |
+
| 6 | TI-RADS 4c级 | 15 | 桥本氏甲状腺炎 |
|
| 22 |
+
| 7 | TI-RADS 5级 | 16 | 反应性 |
|
| 23 |
+
| 8 | 钙化 | 17 | 转移性 |
|
| 24 |
+
| 9 | 甲亢 | | |
|
| 25 |
+
|
| 26 |
+
## 目录结构
|
| 27 |
+
|
| 28 |
+
```
|
| 29 |
+
HintsPrediction/
|
| 30 |
+
├── README.md # 本文件
|
| 31 |
+
├── requirements.txt # 依赖列表
|
| 32 |
+
├── config.yaml # 配置文件(需修改路径)
|
| 33 |
+
├── models/ # 模型代码
|
| 34 |
+
│ ├── __init__.py
|
| 35 |
+
│ ├── transmil_q2l.py # 主模型架构
|
| 36 |
+
│ ├── transformer.py # Transformer 组件
|
| 37 |
+
│ └── aslloss.py # 损失函数
|
| 38 |
+
├── checkpoints/
|
| 39 |
+
│ └── checkpoint_best.pth # 最佳模型权重
|
| 40 |
+
├── scripts/ # 懒人脚本
|
| 41 |
+
│ ├── train.sh
|
| 42 |
+
│ ├── evaluate.sh
|
| 43 |
+
│ └── infer_single.sh
|
| 44 |
+
├── train_hybrid.py # 训练代码
|
| 45 |
+
├── evaluate.py # 评估代码
|
| 46 |
+
├── thyroid_dataset.py # 数据集加载
|
| 47 |
+
└── infer_single_case.py # 单步推理代码
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
## 快速开始
|
| 51 |
+
|
| 52 |
+
### 1. 环境配置
|
| 53 |
+
|
| 54 |
+
```bash
|
| 55 |
+
# 安装依赖
|
| 56 |
+
pip install -r requirements.txt
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
### 2. 单步推理
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
用法:
|
| 63 |
+
# 指定多个图像文件
|
| 64 |
+
python infer_single_case.py --images /path/to/img1.png /path/to/img2.png --threshold 0.5
|
| 65 |
+
|
| 66 |
+
# 指定图像文件夹
|
| 67 |
+
python infer_single_case.py --image_dir /path/to/case_folder/ --threshold 0.5
|
| 68 |
+
|
| 69 |
+
# 或使用脚本
|
| 70 |
+
用法1: bash scripts/infer_single.sh /path/to/image1.png /path/to/image2.png ...
|
| 71 |
+
用法2: bash scripts/infer_single.sh --image_dir /path/to/case_folder/
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
### 3. 评估模型
|
| 75 |
+
|
| 76 |
+
```bash
|
| 77 |
+
# 先修改 config.yaml 中的数据路径
|
| 78 |
+
# 然后运行评估
|
| 79 |
+
python evaluate.py
|
| 80 |
+
# 或
|
| 81 |
+
bash scripts/evaluate.sh
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
### 4. 训练模型
|
| 85 |
+
|
| 86 |
+
```bash
|
| 87 |
+
# 先修改 config.yaml 中的数据路径
|
| 88 |
+
python train_hybrid.py --config config.yaml
|
| 89 |
+
# 或
|
| 90 |
+
bash scripts/train.sh
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
## 配置说明
|
| 94 |
+
|
| 95 |
+
使用前请修改 `config.yaml` 中的数据路径:
|
| 96 |
+
|
| 97 |
+
```yaml
|
| 98 |
+
data:
|
| 99 |
+
data_root: "/path/to/your/ReportData_ROI/"
|
| 100 |
+
annotation_csv: "/path/to/your/thyroid_multilabel_annotations.csv"
|
| 101 |
+
val_json: "/path/to/your/classification_val_set_single.json"
|
| 102 |
+
test_json: "/path/to/your/classification_test_set_single.json"
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
## 性能指标
|
| 106 |
+
|
| 107 |
+
在测试集上的性能(请参考 `checkpoints/evaluation_report.csv`)
|
| 108 |
+
|
| 109 |
+
## 注意事项
|
| 110 |
+
|
| 111 |
+
1. 推理时需要 GPU(推荐),CPU 也可运行但较慢
|
| 112 |
+
2. 单病例可输入多张图像,模型会自动聚合特征
|
| 113 |
+
3. 默认阈值为 0.5,可根据需要调整
|
checkpoints/checkpoint_best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fcbcfb29005eaf6d02bcb8c94be860d885f6fecf9cdcea40be38e7d96050f44b
|
| 3 |
+
size 416241673
|
classification_test_set_single.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
classification_val_set_single.json
ADDED
|
@@ -0,0 +1,1688 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"type": "single",
|
| 4 |
+
"id": "Batch10_20250506_P323",
|
| 5 |
+
"rel_path": "Batch10/20250506_P323",
|
| 6 |
+
"score": 1.0
|
| 7 |
+
},
|
| 8 |
+
{
|
| 9 |
+
"type": "single",
|
| 10 |
+
"id": "Batch10_20250506_P424",
|
| 11 |
+
"rel_path": "Batch10/20250506_P424",
|
| 12 |
+
"score": 1.0
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"type": "single",
|
| 16 |
+
"id": "Batch10_20250504_P199",
|
| 17 |
+
"rel_path": "Batch10/20250504_P199",
|
| 18 |
+
"score": 1.0
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"type": "single",
|
| 22 |
+
"id": "Batch10_20250508_P829",
|
| 23 |
+
"rel_path": "Batch10/20250508_P829",
|
| 24 |
+
"score": 1.0
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"type": "single",
|
| 28 |
+
"id": "Batch10_20250501_P34",
|
| 29 |
+
"rel_path": "Batch10/20250501_P34",
|
| 30 |
+
"score": 1.0
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"type": "single",
|
| 34 |
+
"id": "Batch10_20250512_P1380",
|
| 35 |
+
"rel_path": "Batch10/20250512_P1380",
|
| 36 |
+
"score": 1.0
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"type": "single",
|
| 40 |
+
"id": "Batch10_20250506_P405",
|
| 41 |
+
"rel_path": "Batch10/20250506_P405",
|
| 42 |
+
"score": 1.0
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"type": "single",
|
| 46 |
+
"id": "Batch10_20250502_P48",
|
| 47 |
+
"rel_path": "Batch10/20250502_P48",
|
| 48 |
+
"score": 1.0
|
| 49 |
+
},
|
| 50 |
+
{
|
| 51 |
+
"type": "single",
|
| 52 |
+
"id": "Batch10_20250510_P1172",
|
| 53 |
+
"rel_path": "Batch10/20250510_P1172",
|
| 54 |
+
"score": 1.0
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"type": "single",
|
| 58 |
+
"id": "Batch10_20250504_P119",
|
| 59 |
+
"rel_path": "Batch10/20250504_P119",
|
| 60 |
+
"score": 1.0
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"type": "single",
|
| 64 |
+
"id": "Batch10_20250504_P140",
|
| 65 |
+
"rel_path": "Batch10/20250504_P140",
|
| 66 |
+
"score": 1.0
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"type": "single",
|
| 70 |
+
"id": "Batch10_20250509_P1013",
|
| 71 |
+
"rel_path": "Batch10/20250509_P1013",
|
| 72 |
+
"score": 1.0
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"type": "single",
|
| 76 |
+
"id": "Batch10_20250508_P727",
|
| 77 |
+
"rel_path": "Batch10/20250508_P727",
|
| 78 |
+
"score": 1.0
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"type": "single",
|
| 82 |
+
"id": "Batch10_20250507_P656",
|
| 83 |
+
"rel_path": "Batch10/20250507_P656",
|
| 84 |
+
"score": 1.0
|
| 85 |
+
},
|
| 86 |
+
{
|
| 87 |
+
"type": "single",
|
| 88 |
+
"id": "Batch10_20250508_P720",
|
| 89 |
+
"rel_path": "Batch10/20250508_P720",
|
| 90 |
+
"score": 1.0
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"type": "single",
|
| 94 |
+
"id": "Batch10_20250507_P647",
|
| 95 |
+
"rel_path": "Batch10/20250507_P647",
|
| 96 |
+
"score": 1.0
|
| 97 |
+
},
|
| 98 |
+
{
|
| 99 |
+
"type": "single",
|
| 100 |
+
"id": "Batch10_20250510_P1122",
|
| 101 |
+
"rel_path": "Batch10/20250510_P1122",
|
| 102 |
+
"score": 1.0
|
| 103 |
+
},
|
| 104 |
+
{
|
| 105 |
+
"type": "single",
|
| 106 |
+
"id": "Batch10_20250510_P1182",
|
| 107 |
+
"rel_path": "Batch10/20250510_P1182",
|
| 108 |
+
"score": 1.0
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"type": "single",
|
| 112 |
+
"id": "Batch10_20250501_P40",
|
| 113 |
+
"rel_path": "Batch10/20250501_P40",
|
| 114 |
+
"score": 1.0
|
| 115 |
+
},
|
| 116 |
+
{
|
| 117 |
+
"type": "single",
|
| 118 |
+
"id": "Batch10_20250505_P254",
|
| 119 |
+
"rel_path": "Batch10/20250505_P254",
|
| 120 |
+
"score": 1.0
|
| 121 |
+
},
|
| 122 |
+
{
|
| 123 |
+
"type": "single",
|
| 124 |
+
"id": "Batch10_20250508_P810",
|
| 125 |
+
"rel_path": "Batch10/20250508_P810",
|
| 126 |
+
"score": 1.0
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"type": "single",
|
| 130 |
+
"id": "Batch10_20250507_P636",
|
| 131 |
+
"rel_path": "Batch10/20250507_P636",
|
| 132 |
+
"score": 1.0
|
| 133 |
+
},
|
| 134 |
+
{
|
| 135 |
+
"type": "single",
|
| 136 |
+
"id": "Batch10_20250507_P557",
|
| 137 |
+
"rel_path": "Batch10/20250507_P557",
|
| 138 |
+
"score": 1.0
|
| 139 |
+
},
|
| 140 |
+
{
|
| 141 |
+
"type": "single",
|
| 142 |
+
"id": "Batch10_20250512_P1272",
|
| 143 |
+
"rel_path": "Batch10/20250512_P1272",
|
| 144 |
+
"score": 1.0
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"type": "single",
|
| 148 |
+
"id": "Batch10_20250512_P1306",
|
| 149 |
+
"rel_path": "Batch10/20250512_P1306",
|
| 150 |
+
"score": 1.0
|
| 151 |
+
},
|
| 152 |
+
{
|
| 153 |
+
"type": "single",
|
| 154 |
+
"id": "Batch10_20250506_P311",
|
| 155 |
+
"rel_path": "Batch10/20250506_P311",
|
| 156 |
+
"score": 1.0
|
| 157 |
+
},
|
| 158 |
+
{
|
| 159 |
+
"type": "single",
|
| 160 |
+
"id": "Batch10_20250505_P258",
|
| 161 |
+
"rel_path": "Batch10/20250505_P258",
|
| 162 |
+
"score": 1.0
|
| 163 |
+
},
|
| 164 |
+
{
|
| 165 |
+
"type": "single",
|
| 166 |
+
"id": "Batch10_20250501_P25",
|
| 167 |
+
"rel_path": "Batch10/20250501_P25",
|
| 168 |
+
"score": 1.0
|
| 169 |
+
},
|
| 170 |
+
{
|
| 171 |
+
"type": "single",
|
| 172 |
+
"id": "Batch10_20250512_P1273",
|
| 173 |
+
"rel_path": "Batch10/20250512_P1273",
|
| 174 |
+
"score": 1.0
|
| 175 |
+
},
|
| 176 |
+
{
|
| 177 |
+
"type": "single",
|
| 178 |
+
"id": "Batch10_20250512_P1277",
|
| 179 |
+
"rel_path": "Batch10/20250512_P1277",
|
| 180 |
+
"score": 1.0
|
| 181 |
+
},
|
| 182 |
+
{
|
| 183 |
+
"type": "single",
|
| 184 |
+
"id": "Batch10_20250506_P350",
|
| 185 |
+
"rel_path": "Batch10/20250506_P350",
|
| 186 |
+
"score": 1.0
|
| 187 |
+
},
|
| 188 |
+
{
|
| 189 |
+
"type": "single",
|
| 190 |
+
"id": "Batch10_20250512_P1289",
|
| 191 |
+
"rel_path": "Batch10/20250512_P1289",
|
| 192 |
+
"score": 1.0
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
"type": "single",
|
| 196 |
+
"id": "Batch10_20250512_P1386",
|
| 197 |
+
"rel_path": "Batch10/20250512_P1386",
|
| 198 |
+
"score": 1.0
|
| 199 |
+
},
|
| 200 |
+
{
|
| 201 |
+
"type": "single",
|
| 202 |
+
"id": "Batch10_20250509_P933",
|
| 203 |
+
"rel_path": "Batch10/20250509_P933",
|
| 204 |
+
"score": 1.0
|
| 205 |
+
},
|
| 206 |
+
{
|
| 207 |
+
"type": "single",
|
| 208 |
+
"id": "Batch10_20250508_P842",
|
| 209 |
+
"rel_path": "Batch10/20250508_P842",
|
| 210 |
+
"score": 1.0
|
| 211 |
+
},
|
| 212 |
+
{
|
| 213 |
+
"type": "single",
|
| 214 |
+
"id": "Batch10_20250509_P974",
|
| 215 |
+
"rel_path": "Batch10/20250509_P974",
|
| 216 |
+
"score": 1.0
|
| 217 |
+
},
|
| 218 |
+
{
|
| 219 |
+
"type": "single",
|
| 220 |
+
"id": "Batch10_20250511_P1235",
|
| 221 |
+
"rel_path": "Batch10/20250511_P1235",
|
| 222 |
+
"score": 1.0
|
| 223 |
+
},
|
| 224 |
+
{
|
| 225 |
+
"type": "single",
|
| 226 |
+
"id": "Batch10_20250501_P28",
|
| 227 |
+
"rel_path": "Batch10/20250501_P28",
|
| 228 |
+
"score": 1.0
|
| 229 |
+
},
|
| 230 |
+
{
|
| 231 |
+
"type": "single",
|
| 232 |
+
"id": "Batch10_20250511_P1230",
|
| 233 |
+
"rel_path": "Batch10/20250511_P1230",
|
| 234 |
+
"score": 1.0
|
| 235 |
+
},
|
| 236 |
+
{
|
| 237 |
+
"type": "single",
|
| 238 |
+
"id": "Batch10_20250505_P279",
|
| 239 |
+
"rel_path": "Batch10/20250505_P279",
|
| 240 |
+
"score": 1.0
|
| 241 |
+
},
|
| 242 |
+
{
|
| 243 |
+
"type": "single",
|
| 244 |
+
"id": "Batch10_20250511_P1228",
|
| 245 |
+
"rel_path": "Batch10/20250511_P1228",
|
| 246 |
+
"score": 1.0
|
| 247 |
+
},
|
| 248 |
+
{
|
| 249 |
+
"type": "single",
|
| 250 |
+
"id": "Batch10_20250506_P351",
|
| 251 |
+
"rel_path": "Batch10/20250506_P351",
|
| 252 |
+
"score": 1.0
|
| 253 |
+
},
|
| 254 |
+
{
|
| 255 |
+
"type": "single",
|
| 256 |
+
"id": "Batch10_20250504_P195",
|
| 257 |
+
"rel_path": "Batch10/20250504_P195",
|
| 258 |
+
"score": 1.0
|
| 259 |
+
},
|
| 260 |
+
{
|
| 261 |
+
"type": "single",
|
| 262 |
+
"id": "Batch10_20250508_P825",
|
| 263 |
+
"rel_path": "Batch10/20250508_P825",
|
| 264 |
+
"score": 1.0
|
| 265 |
+
},
|
| 266 |
+
{
|
| 267 |
+
"type": "single",
|
| 268 |
+
"id": "Batch10_20250507_P596",
|
| 269 |
+
"rel_path": "Batch10/20250507_P596",
|
| 270 |
+
"score": 1.0
|
| 271 |
+
},
|
| 272 |
+
{
|
| 273 |
+
"type": "single",
|
| 274 |
+
"id": "Batch10_20250507_P539",
|
| 275 |
+
"rel_path": "Batch10/20250507_P539",
|
| 276 |
+
"score": 1.0
|
| 277 |
+
},
|
| 278 |
+
{
|
| 279 |
+
"type": "single",
|
| 280 |
+
"id": "Batch10_20250512_P1333",
|
| 281 |
+
"rel_path": "Batch10/20250512_P1333",
|
| 282 |
+
"score": 1.0
|
| 283 |
+
},
|
| 284 |
+
{
|
| 285 |
+
"type": "single",
|
| 286 |
+
"id": "Batch10_20250508_P706",
|
| 287 |
+
"rel_path": "Batch10/20250508_P706",
|
| 288 |
+
"score": 1.0
|
| 289 |
+
},
|
| 290 |
+
{
|
| 291 |
+
"type": "single",
|
| 292 |
+
"id": "Batch10_20250509_P902",
|
| 293 |
+
"rel_path": "Batch10/20250509_P902",
|
| 294 |
+
"score": 1.0
|
| 295 |
+
},
|
| 296 |
+
{
|
| 297 |
+
"type": "single",
|
| 298 |
+
"id": "Batch10_20250512_P1303",
|
| 299 |
+
"rel_path": "Batch10/20250512_P1303",
|
| 300 |
+
"score": 1.0
|
| 301 |
+
},
|
| 302 |
+
{
|
| 303 |
+
"type": "single",
|
| 304 |
+
"id": "Batch10_20250509_P913",
|
| 305 |
+
"rel_path": "Batch10/20250509_P913",
|
| 306 |
+
"score": 1.0
|
| 307 |
+
},
|
| 308 |
+
{
|
| 309 |
+
"type": "single",
|
| 310 |
+
"id": "Batch10_20250506_P300",
|
| 311 |
+
"rel_path": "Batch10/20250506_P300",
|
| 312 |
+
"score": 1.0
|
| 313 |
+
},
|
| 314 |
+
{
|
| 315 |
+
"type": "single",
|
| 316 |
+
"id": "Batch10_20250502_P66",
|
| 317 |
+
"rel_path": "Batch10/20250502_P66",
|
| 318 |
+
"score": 1.0
|
| 319 |
+
},
|
| 320 |
+
{
|
| 321 |
+
"type": "single",
|
| 322 |
+
"id": "Batch10_20250508_P726",
|
| 323 |
+
"rel_path": "Batch10/20250508_P726",
|
| 324 |
+
"score": 1.0
|
| 325 |
+
},
|
| 326 |
+
{
|
| 327 |
+
"type": "single",
|
| 328 |
+
"id": "Batch10_20250505_P222",
|
| 329 |
+
"rel_path": "Batch10/20250505_P222",
|
| 330 |
+
"score": 1.0
|
| 331 |
+
},
|
| 332 |
+
{
|
| 333 |
+
"type": "single",
|
| 334 |
+
"id": "Batch10_20250510_P1165",
|
| 335 |
+
"rel_path": "Batch10/20250510_P1165",
|
| 336 |
+
"score": 1.0
|
| 337 |
+
},
|
| 338 |
+
{
|
| 339 |
+
"type": "single",
|
| 340 |
+
"id": "Batch10_20250506_P348",
|
| 341 |
+
"rel_path": "Batch10/20250506_P348",
|
| 342 |
+
"score": 1.0
|
| 343 |
+
},
|
| 344 |
+
{
|
| 345 |
+
"type": "single",
|
| 346 |
+
"id": "Batch10_20250507_P678",
|
| 347 |
+
"rel_path": "Batch10/20250507_P678",
|
| 348 |
+
"score": 1.0
|
| 349 |
+
},
|
| 350 |
+
{
|
| 351 |
+
"type": "single",
|
| 352 |
+
"id": "Batch10_20250504_P169",
|
| 353 |
+
"rel_path": "Batch10/20250504_P169",
|
| 354 |
+
"score": 1.0
|
| 355 |
+
},
|
| 356 |
+
{
|
| 357 |
+
"type": "single",
|
| 358 |
+
"id": "Batch10_20250506_P473",
|
| 359 |
+
"rel_path": "Batch10/20250506_P473",
|
| 360 |
+
"score": 1.0
|
| 361 |
+
},
|
| 362 |
+
{
|
| 363 |
+
"type": "single",
|
| 364 |
+
"id": "Batch10_20250506_P427",
|
| 365 |
+
"rel_path": "Batch10/20250506_P427",
|
| 366 |
+
"score": 1.0
|
| 367 |
+
},
|
| 368 |
+
{
|
| 369 |
+
"type": "single",
|
| 370 |
+
"id": "Batch10_20250508_P711",
|
| 371 |
+
"rel_path": "Batch10/20250508_P711",
|
| 372 |
+
"score": 1.0
|
| 373 |
+
},
|
| 374 |
+
{
|
| 375 |
+
"type": "single",
|
| 376 |
+
"id": "Batch10_20250508_P738",
|
| 377 |
+
"rel_path": "Batch10/20250508_P738",
|
| 378 |
+
"score": 1.0
|
| 379 |
+
},
|
| 380 |
+
{
|
| 381 |
+
"type": "single",
|
| 382 |
+
"id": "Batch10_20250507_P655",
|
| 383 |
+
"rel_path": "Batch10/20250507_P655",
|
| 384 |
+
"score": 1.0
|
| 385 |
+
},
|
| 386 |
+
{
|
| 387 |
+
"type": "single",
|
| 388 |
+
"id": "Batch10_20250508_P765",
|
| 389 |
+
"rel_path": "Batch10/20250508_P765",
|
| 390 |
+
"score": 1.0
|
| 391 |
+
},
|
| 392 |
+
{
|
| 393 |
+
"type": "single",
|
| 394 |
+
"id": "Batch10_20250507_P658",
|
| 395 |
+
"rel_path": "Batch10/20250507_P658",
|
| 396 |
+
"score": 1.0
|
| 397 |
+
},
|
| 398 |
+
{
|
| 399 |
+
"type": "single",
|
| 400 |
+
"id": "Batch10_20250510_P1109",
|
| 401 |
+
"rel_path": "Batch10/20250510_P1109",
|
| 402 |
+
"score": 1.0
|
| 403 |
+
},
|
| 404 |
+
{
|
| 405 |
+
"type": "single",
|
| 406 |
+
"id": "Batch10_20250507_P629",
|
| 407 |
+
"rel_path": "Batch10/20250507_P629",
|
| 408 |
+
"score": 1.0
|
| 409 |
+
},
|
| 410 |
+
{
|
| 411 |
+
"type": "single",
|
| 412 |
+
"id": "Batch10_20250512_P1377",
|
| 413 |
+
"rel_path": "Batch10/20250512_P1377",
|
| 414 |
+
"score": 1.0
|
| 415 |
+
},
|
| 416 |
+
{
|
| 417 |
+
"type": "single",
|
| 418 |
+
"id": "Batch10_20250512_P1401",
|
| 419 |
+
"rel_path": "Batch10/20250512_P1401",
|
| 420 |
+
"score": 1.0
|
| 421 |
+
},
|
| 422 |
+
{
|
| 423 |
+
"type": "single",
|
| 424 |
+
"id": "Batch10_20250502_P62",
|
| 425 |
+
"rel_path": "Batch10/20250502_P62",
|
| 426 |
+
"score": 1.0
|
| 427 |
+
},
|
| 428 |
+
{
|
| 429 |
+
"type": "single",
|
| 430 |
+
"id": "Batch10_20250507_P659",
|
| 431 |
+
"rel_path": "Batch10/20250507_P659",
|
| 432 |
+
"score": 1.0
|
| 433 |
+
},
|
| 434 |
+
{
|
| 435 |
+
"type": "single",
|
| 436 |
+
"id": "Batch10_20250505_P275",
|
| 437 |
+
"rel_path": "Batch10/20250505_P275",
|
| 438 |
+
"score": 1.0
|
| 439 |
+
},
|
| 440 |
+
{
|
| 441 |
+
"type": "single",
|
| 442 |
+
"id": "Batch10_20250501_P21",
|
| 443 |
+
"rel_path": "Batch10/20250501_P21",
|
| 444 |
+
"score": 1.0
|
| 445 |
+
},
|
| 446 |
+
{
|
| 447 |
+
"type": "single",
|
| 448 |
+
"id": "Batch10_20250512_P1412",
|
| 449 |
+
"rel_path": "Batch10/20250512_P1412",
|
| 450 |
+
"score": 1.0
|
| 451 |
+
},
|
| 452 |
+
{
|
| 453 |
+
"type": "single",
|
| 454 |
+
"id": "Batch10_20250507_P514",
|
| 455 |
+
"rel_path": "Batch10/20250507_P514",
|
| 456 |
+
"score": 1.0
|
| 457 |
+
},
|
| 458 |
+
{
|
| 459 |
+
"type": "single",
|
| 460 |
+
"id": "Batch10_20250504_P179",
|
| 461 |
+
"rel_path": "Batch10/20250504_P179",
|
| 462 |
+
"score": 1.0
|
| 463 |
+
},
|
| 464 |
+
{
|
| 465 |
+
"type": "single",
|
| 466 |
+
"id": "Batch10_20250504_P188",
|
| 467 |
+
"rel_path": "Batch10/20250504_P188",
|
| 468 |
+
"score": 1.0
|
| 469 |
+
},
|
| 470 |
+
{
|
| 471 |
+
"type": "single",
|
| 472 |
+
"id": "Batch10_20250512_P1249",
|
| 473 |
+
"rel_path": "Batch10/20250512_P1249",
|
| 474 |
+
"score": 1.0
|
| 475 |
+
},
|
| 476 |
+
{
|
| 477 |
+
"type": "single",
|
| 478 |
+
"id": "Batch10_20250509_P1069",
|
| 479 |
+
"rel_path": "Batch10/20250509_P1069",
|
| 480 |
+
"score": 1.0
|
| 481 |
+
},
|
| 482 |
+
{
|
| 483 |
+
"type": "single",
|
| 484 |
+
"id": "Batch10_20250512_P1391",
|
| 485 |
+
"rel_path": "Batch10/20250512_P1391",
|
| 486 |
+
"score": 1.0
|
| 487 |
+
},
|
| 488 |
+
{
|
| 489 |
+
"type": "single",
|
| 490 |
+
"id": "Batch10_20250506_P480",
|
| 491 |
+
"rel_path": "Batch10/20250506_P480",
|
| 492 |
+
"score": 1.0
|
| 493 |
+
},
|
| 494 |
+
{
|
| 495 |
+
"type": "single",
|
| 496 |
+
"id": "Batch10_20250504_P167",
|
| 497 |
+
"rel_path": "Batch10/20250504_P167",
|
| 498 |
+
"score": 1.0
|
| 499 |
+
},
|
| 500 |
+
{
|
| 501 |
+
"type": "single",
|
| 502 |
+
"id": "Batch10_20250503_P74",
|
| 503 |
+
"rel_path": "Batch10/20250503_P74",
|
| 504 |
+
"score": 1.0
|
| 505 |
+
},
|
| 506 |
+
{
|
| 507 |
+
"type": "single",
|
| 508 |
+
"id": "Batch10_20250505_P223",
|
| 509 |
+
"rel_path": "Batch10/20250505_P223",
|
| 510 |
+
"score": 1.0
|
| 511 |
+
},
|
| 512 |
+
{
|
| 513 |
+
"type": "single",
|
| 514 |
+
"id": "Batch10_20250508_P833",
|
| 515 |
+
"rel_path": "Batch10/20250508_P833",
|
| 516 |
+
"score": 1.0
|
| 517 |
+
},
|
| 518 |
+
{
|
| 519 |
+
"type": "single",
|
| 520 |
+
"id": "Batch10_20250512_P1336",
|
| 521 |
+
"rel_path": "Batch10/20250512_P1336",
|
| 522 |
+
"score": 1.0
|
| 523 |
+
},
|
| 524 |
+
{
|
| 525 |
+
"type": "single",
|
| 526 |
+
"id": "Batch10_20250508_P775",
|
| 527 |
+
"rel_path": "Batch10/20250508_P775",
|
| 528 |
+
"score": 1.0
|
| 529 |
+
},
|
| 530 |
+
{
|
| 531 |
+
"type": "single",
|
| 532 |
+
"id": "Batch10_20250507_P553",
|
| 533 |
+
"rel_path": "Batch10/20250507_P553",
|
| 534 |
+
"score": 1.0
|
| 535 |
+
},
|
| 536 |
+
{
|
| 537 |
+
"type": "single",
|
| 538 |
+
"id": "Batch10_20250502_P59",
|
| 539 |
+
"rel_path": "Batch10/20250502_P59",
|
| 540 |
+
"score": 1.0
|
| 541 |
+
},
|
| 542 |
+
{
|
| 543 |
+
"type": "single",
|
| 544 |
+
"id": "Batch10_20250512_P1295",
|
| 545 |
+
"rel_path": "Batch10/20250512_P1295",
|
| 546 |
+
"score": 1.0
|
| 547 |
+
},
|
| 548 |
+
{
|
| 549 |
+
"type": "single",
|
| 550 |
+
"id": "Batch10_20250505_P220",
|
| 551 |
+
"rel_path": "Batch10/20250505_P220",
|
| 552 |
+
"score": 1.0
|
| 553 |
+
},
|
| 554 |
+
{
|
| 555 |
+
"type": "single",
|
| 556 |
+
"id": "Batch10_20250503_P77",
|
| 557 |
+
"rel_path": "Batch10/20250503_P77",
|
| 558 |
+
"score": 1.0
|
| 559 |
+
},
|
| 560 |
+
{
|
| 561 |
+
"type": "single",
|
| 562 |
+
"id": "Batch10_20250511_P1231",
|
| 563 |
+
"rel_path": "Batch10/20250511_P1231",
|
| 564 |
+
"score": 1.0
|
| 565 |
+
},
|
| 566 |
+
{
|
| 567 |
+
"type": "single",
|
| 568 |
+
"id": "Batch10_20250506_P389",
|
| 569 |
+
"rel_path": "Batch10/20250506_P389",
|
| 570 |
+
"score": 1.0
|
| 571 |
+
},
|
| 572 |
+
{
|
| 573 |
+
"type": "single",
|
| 574 |
+
"id": "Batch10_20250507_P573",
|
| 575 |
+
"rel_path": "Batch10/20250507_P573",
|
| 576 |
+
"score": 1.0
|
| 577 |
+
},
|
| 578 |
+
{
|
| 579 |
+
"type": "single",
|
| 580 |
+
"id": "Batch10_20250506_P385",
|
| 581 |
+
"rel_path": "Batch10/20250506_P385",
|
| 582 |
+
"score": 1.0
|
| 583 |
+
},
|
| 584 |
+
{
|
| 585 |
+
"type": "single",
|
| 586 |
+
"id": "Batch10_20250508_P782",
|
| 587 |
+
"rel_path": "Batch10/20250508_P782",
|
| 588 |
+
"score": 1.0
|
| 589 |
+
},
|
| 590 |
+
{
|
| 591 |
+
"type": "single",
|
| 592 |
+
"id": "Batch10_20250508_P851",
|
| 593 |
+
"rel_path": "Batch10/20250508_P851",
|
| 594 |
+
"score": 1.0
|
| 595 |
+
},
|
| 596 |
+
{
|
| 597 |
+
"type": "single",
|
| 598 |
+
"id": "Batch10_20250506_P481",
|
| 599 |
+
"rel_path": "Batch10/20250506_P481",
|
| 600 |
+
"score": 1.0
|
| 601 |
+
},
|
| 602 |
+
{
|
| 603 |
+
"type": "single",
|
| 604 |
+
"id": "Batch10_20250511_P1243",
|
| 605 |
+
"rel_path": "Batch10/20250511_P1243",
|
| 606 |
+
"score": 1.0
|
| 607 |
+
},
|
| 608 |
+
{
|
| 609 |
+
"type": "single",
|
| 610 |
+
"id": "Batch10_20250509_P956",
|
| 611 |
+
"rel_path": "Batch10/20250509_P956",
|
| 612 |
+
"score": 1.0
|
| 613 |
+
},
|
| 614 |
+
{
|
| 615 |
+
"type": "single",
|
| 616 |
+
"id": "Batch10_20250506_P341",
|
| 617 |
+
"rel_path": "Batch10/20250506_P341",
|
| 618 |
+
"score": 1.0
|
| 619 |
+
},
|
| 620 |
+
{
|
| 621 |
+
"type": "single",
|
| 622 |
+
"id": "Batch10_20250508_P752",
|
| 623 |
+
"rel_path": "Batch10/20250508_P752",
|
| 624 |
+
"score": 1.0
|
| 625 |
+
},
|
| 626 |
+
{
|
| 627 |
+
"type": "single",
|
| 628 |
+
"id": "Batch10_20250509_P861",
|
| 629 |
+
"rel_path": "Batch10/20250509_P861",
|
| 630 |
+
"score": 1.0
|
| 631 |
+
},
|
| 632 |
+
{
|
| 633 |
+
"type": "single",
|
| 634 |
+
"id": "Batch10_20250512_P1320",
|
| 635 |
+
"rel_path": "Batch10/20250512_P1320",
|
| 636 |
+
"score": 1.0
|
| 637 |
+
},
|
| 638 |
+
{
|
| 639 |
+
"type": "single",
|
| 640 |
+
"id": "Batch10_20250501_P12",
|
| 641 |
+
"rel_path": "Batch10/20250501_P12",
|
| 642 |
+
"score": 1.0
|
| 643 |
+
},
|
| 644 |
+
{
|
| 645 |
+
"type": "single",
|
| 646 |
+
"id": "Batch10_20250512_P1312",
|
| 647 |
+
"rel_path": "Batch10/20250512_P1312",
|
| 648 |
+
"score": 1.0
|
| 649 |
+
},
|
| 650 |
+
{
|
| 651 |
+
"type": "single",
|
| 652 |
+
"id": "Batch10_20250507_P588",
|
| 653 |
+
"rel_path": "Batch10/20250507_P588",
|
| 654 |
+
"score": 1.0
|
| 655 |
+
},
|
| 656 |
+
{
|
| 657 |
+
"type": "single",
|
| 658 |
+
"id": "Batch10_20250509_P904",
|
| 659 |
+
"rel_path": "Batch10/20250509_P904",
|
| 660 |
+
"score": 1.0
|
| 661 |
+
},
|
| 662 |
+
{
|
| 663 |
+
"type": "single",
|
| 664 |
+
"id": "Batch10_20250512_P1370",
|
| 665 |
+
"rel_path": "Batch10/20250512_P1370",
|
| 666 |
+
"score": 1.0
|
| 667 |
+
},
|
| 668 |
+
{
|
| 669 |
+
"type": "single",
|
| 670 |
+
"id": "Batch10_20250507_P565",
|
| 671 |
+
"rel_path": "Batch10/20250507_P565",
|
| 672 |
+
"score": 1.0
|
| 673 |
+
},
|
| 674 |
+
{
|
| 675 |
+
"type": "single",
|
| 676 |
+
"id": "Batch10_20250506_P412",
|
| 677 |
+
"rel_path": "Batch10/20250506_P412",
|
| 678 |
+
"score": 1.0
|
| 679 |
+
},
|
| 680 |
+
{
|
| 681 |
+
"type": "single",
|
| 682 |
+
"id": "Batch10_20250505_P267",
|
| 683 |
+
"rel_path": "Batch10/20250505_P267",
|
| 684 |
+
"score": 1.0
|
| 685 |
+
},
|
| 686 |
+
{
|
| 687 |
+
"type": "single",
|
| 688 |
+
"id": "Batch10_20250508_P849",
|
| 689 |
+
"rel_path": "Batch10/20250508_P849",
|
| 690 |
+
"score": 1.0
|
| 691 |
+
},
|
| 692 |
+
{
|
| 693 |
+
"type": "single",
|
| 694 |
+
"id": "Batch10_20250509_P921",
|
| 695 |
+
"rel_path": "Batch10/20250509_P921",
|
| 696 |
+
"score": 1.0
|
| 697 |
+
},
|
| 698 |
+
{
|
| 699 |
+
"type": "single",
|
| 700 |
+
"id": "Batch10_20250505_P257",
|
| 701 |
+
"rel_path": "Batch10/20250505_P257",
|
| 702 |
+
"score": 1.0
|
| 703 |
+
},
|
| 704 |
+
{
|
| 705 |
+
"type": "single",
|
| 706 |
+
"id": "Batch10_20250511_P1242",
|
| 707 |
+
"rel_path": "Batch10/20250511_P1242",
|
| 708 |
+
"score": 1.0
|
| 709 |
+
},
|
| 710 |
+
{
|
| 711 |
+
"type": "single",
|
| 712 |
+
"id": "Batch10_20250506_P453",
|
| 713 |
+
"rel_path": "Batch10/20250506_P453",
|
| 714 |
+
"score": 1.0
|
| 715 |
+
},
|
| 716 |
+
{
|
| 717 |
+
"type": "single",
|
| 718 |
+
"id": "Batch10_20250509_P892",
|
| 719 |
+
"rel_path": "Batch10/20250509_P892",
|
| 720 |
+
"score": 1.0
|
| 721 |
+
},
|
| 722 |
+
{
|
| 723 |
+
"type": "single",
|
| 724 |
+
"id": "Batch10_20250504_P116",
|
| 725 |
+
"rel_path": "Batch10/20250504_P116",
|
| 726 |
+
"score": 1.0
|
| 727 |
+
},
|
| 728 |
+
{
|
| 729 |
+
"type": "single",
|
| 730 |
+
"id": "Batch10_20250507_P531",
|
| 731 |
+
"rel_path": "Batch10/20250507_P531",
|
| 732 |
+
"score": 1.0
|
| 733 |
+
},
|
| 734 |
+
{
|
| 735 |
+
"type": "single",
|
| 736 |
+
"id": "Batch10_20250503_P98",
|
| 737 |
+
"rel_path": "Batch10/20250503_P98",
|
| 738 |
+
"score": 1.0
|
| 739 |
+
},
|
| 740 |
+
{
|
| 741 |
+
"type": "single",
|
| 742 |
+
"id": "Batch10_20250509_P926",
|
| 743 |
+
"rel_path": "Batch10/20250509_P926",
|
| 744 |
+
"score": 1.0
|
| 745 |
+
},
|
| 746 |
+
{
|
| 747 |
+
"type": "single",
|
| 748 |
+
"id": "Batch10_20250510_P1119",
|
| 749 |
+
"rel_path": "Batch10/20250510_P1119",
|
| 750 |
+
"score": 1.0
|
| 751 |
+
},
|
| 752 |
+
{
|
| 753 |
+
"type": "single",
|
| 754 |
+
"id": "Batch10_20250505_P240",
|
| 755 |
+
"rel_path": "Batch10/20250505_P240",
|
| 756 |
+
"score": 1.0
|
| 757 |
+
},
|
| 758 |
+
{
|
| 759 |
+
"type": "single",
|
| 760 |
+
"id": "Batch10_20250504_P129",
|
| 761 |
+
"rel_path": "Batch10/20250504_P129",
|
| 762 |
+
"score": 1.0
|
| 763 |
+
},
|
| 764 |
+
{
|
| 765 |
+
"type": "single",
|
| 766 |
+
"id": "Batch10_20250508_P710",
|
| 767 |
+
"rel_path": "Batch10/20250508_P710",
|
| 768 |
+
"score": 1.0
|
| 769 |
+
},
|
| 770 |
+
{
|
| 771 |
+
"type": "single",
|
| 772 |
+
"id": "Batch10_20250511_P1220",
|
| 773 |
+
"rel_path": "Batch10/20250511_P1220",
|
| 774 |
+
"score": 1.0
|
| 775 |
+
},
|
| 776 |
+
{
|
| 777 |
+
"type": "single",
|
| 778 |
+
"id": "Batch10_20250503_P103",
|
| 779 |
+
"rel_path": "Batch10/20250503_P103",
|
| 780 |
+
"score": 1.0
|
| 781 |
+
},
|
| 782 |
+
{
|
| 783 |
+
"type": "single",
|
| 784 |
+
"id": "Batch10_20250509_P878",
|
| 785 |
+
"rel_path": "Batch10/20250509_P878",
|
| 786 |
+
"score": 1.0
|
| 787 |
+
},
|
| 788 |
+
{
|
| 789 |
+
"type": "single",
|
| 790 |
+
"id": "Batch10_20250504_P172",
|
| 791 |
+
"rel_path": "Batch10/20250504_P172",
|
| 792 |
+
"score": 1.0
|
| 793 |
+
},
|
| 794 |
+
{
|
| 795 |
+
"type": "single",
|
| 796 |
+
"id": "Batch10_20250512_P1318",
|
| 797 |
+
"rel_path": "Batch10/20250512_P1318",
|
| 798 |
+
"score": 1.0
|
| 799 |
+
},
|
| 800 |
+
{
|
| 801 |
+
"type": "single",
|
| 802 |
+
"id": "Batch10_20250507_P530",
|
| 803 |
+
"rel_path": "Batch10/20250507_P530",
|
| 804 |
+
"score": 1.0
|
| 805 |
+
},
|
| 806 |
+
{
|
| 807 |
+
"type": "single",
|
| 808 |
+
"id": "Batch10_20250511_P1247",
|
| 809 |
+
"rel_path": "Batch10/20250511_P1247",
|
| 810 |
+
"score": 1.0
|
| 811 |
+
},
|
| 812 |
+
{
|
| 813 |
+
"type": "single",
|
| 814 |
+
"id": "Batch10_20250507_P677",
|
| 815 |
+
"rel_path": "Batch10/20250507_P677",
|
| 816 |
+
"score": 1.0
|
| 817 |
+
},
|
| 818 |
+
{
|
| 819 |
+
"type": "single",
|
| 820 |
+
"id": "Batch10_20250508_P694",
|
| 821 |
+
"rel_path": "Batch10/20250508_P694",
|
| 822 |
+
"score": 1.0
|
| 823 |
+
},
|
| 824 |
+
{
|
| 825 |
+
"type": "single",
|
| 826 |
+
"id": "Batch10_20250505_P246",
|
| 827 |
+
"rel_path": "Batch10/20250505_P246",
|
| 828 |
+
"score": 1.0
|
| 829 |
+
},
|
| 830 |
+
{
|
| 831 |
+
"type": "single",
|
| 832 |
+
"id": "Batch10_20250512_P1339",
|
| 833 |
+
"rel_path": "Batch10/20250512_P1339",
|
| 834 |
+
"score": 1.0
|
| 835 |
+
},
|
| 836 |
+
{
|
| 837 |
+
"type": "single",
|
| 838 |
+
"id": "Batch10_20250507_P666",
|
| 839 |
+
"rel_path": "Batch10/20250507_P666",
|
| 840 |
+
"score": 1.0
|
| 841 |
+
},
|
| 842 |
+
{
|
| 843 |
+
"type": "single",
|
| 844 |
+
"id": "Batch10_20250506_P433",
|
| 845 |
+
"rel_path": "Batch10/20250506_P433",
|
| 846 |
+
"score": 1.0
|
| 847 |
+
},
|
| 848 |
+
{
|
| 849 |
+
"type": "single",
|
| 850 |
+
"id": "Batch10_20250506_P467",
|
| 851 |
+
"rel_path": "Batch10/20250506_P467",
|
| 852 |
+
"score": 1.0
|
| 853 |
+
},
|
| 854 |
+
{
|
| 855 |
+
"type": "single",
|
| 856 |
+
"id": "Batch10_20250506_P451",
|
| 857 |
+
"rel_path": "Batch10/20250506_P451",
|
| 858 |
+
"score": 1.0
|
| 859 |
+
},
|
| 860 |
+
{
|
| 861 |
+
"type": "single",
|
| 862 |
+
"id": "Batch10_20250507_P552",
|
| 863 |
+
"rel_path": "Batch10/20250507_P552",
|
| 864 |
+
"score": 1.0
|
| 865 |
+
},
|
| 866 |
+
{
|
| 867 |
+
"type": "single",
|
| 868 |
+
"id": "Batch10_20250506_P431",
|
| 869 |
+
"rel_path": "Batch10/20250506_P431",
|
| 870 |
+
"score": 1.0
|
| 871 |
+
},
|
| 872 |
+
{
|
| 873 |
+
"type": "single",
|
| 874 |
+
"id": "Batch10_20250506_P291",
|
| 875 |
+
"rel_path": "Batch10/20250506_P291",
|
| 876 |
+
"score": 1.0
|
| 877 |
+
},
|
| 878 |
+
{
|
| 879 |
+
"type": "single",
|
| 880 |
+
"id": "Batch10_20250509_P987",
|
| 881 |
+
"rel_path": "Batch10/20250509_P987",
|
| 882 |
+
"score": 1.0
|
| 883 |
+
},
|
| 884 |
+
{
|
| 885 |
+
"type": "single",
|
| 886 |
+
"id": "Batch10_20250510_P1198",
|
| 887 |
+
"rel_path": "Batch10/20250510_P1198",
|
| 888 |
+
"score": 1.0
|
| 889 |
+
},
|
| 890 |
+
{
|
| 891 |
+
"type": "single",
|
| 892 |
+
"id": "Batch10_20250509_P1040",
|
| 893 |
+
"rel_path": "Batch10/20250509_P1040",
|
| 894 |
+
"score": 1.0
|
| 895 |
+
},
|
| 896 |
+
{
|
| 897 |
+
"type": "single",
|
| 898 |
+
"id": "Batch10_20250506_P483",
|
| 899 |
+
"rel_path": "Batch10/20250506_P483",
|
| 900 |
+
"score": 1.0
|
| 901 |
+
},
|
| 902 |
+
{
|
| 903 |
+
"type": "single",
|
| 904 |
+
"id": "Batch10_20250504_P183",
|
| 905 |
+
"rel_path": "Batch10/20250504_P183",
|
| 906 |
+
"score": 1.0
|
| 907 |
+
},
|
| 908 |
+
{
|
| 909 |
+
"type": "single",
|
| 910 |
+
"id": "Batch10_20250507_P569",
|
| 911 |
+
"rel_path": "Batch10/20250507_P569",
|
| 912 |
+
"score": 1.0
|
| 913 |
+
},
|
| 914 |
+
{
|
| 915 |
+
"type": "single",
|
| 916 |
+
"id": "Batch10_20250506_P485",
|
| 917 |
+
"rel_path": "Batch10/20250506_P485",
|
| 918 |
+
"score": 1.0
|
| 919 |
+
},
|
| 920 |
+
{
|
| 921 |
+
"type": "single",
|
| 922 |
+
"id": "Batch10_20250506_P363",
|
| 923 |
+
"rel_path": "Batch10/20250506_P363",
|
| 924 |
+
"score": 1.0
|
| 925 |
+
},
|
| 926 |
+
{
|
| 927 |
+
"type": "single",
|
| 928 |
+
"id": "Batch10_20250509_P945",
|
| 929 |
+
"rel_path": "Batch10/20250509_P945",
|
| 930 |
+
"score": 1.0
|
| 931 |
+
},
|
| 932 |
+
{
|
| 933 |
+
"type": "single",
|
| 934 |
+
"id": "Batch10_20250510_P1161",
|
| 935 |
+
"rel_path": "Batch10/20250510_P1161",
|
| 936 |
+
"score": 1.0
|
| 937 |
+
},
|
| 938 |
+
{
|
| 939 |
+
"type": "single",
|
| 940 |
+
"id": "Batch10_20250509_P976",
|
| 941 |
+
"rel_path": "Batch10/20250509_P976",
|
| 942 |
+
"score": 1.0
|
| 943 |
+
},
|
| 944 |
+
{
|
| 945 |
+
"type": "single",
|
| 946 |
+
"id": "Batch10_20250508_P809",
|
| 947 |
+
"rel_path": "Batch10/20250508_P809",
|
| 948 |
+
"score": 1.0
|
| 949 |
+
},
|
| 950 |
+
{
|
| 951 |
+
"type": "single",
|
| 952 |
+
"id": "Batch10_20250506_P489",
|
| 953 |
+
"rel_path": "Batch10/20250506_P489",
|
| 954 |
+
"score": 1.0
|
| 955 |
+
},
|
| 956 |
+
{
|
| 957 |
+
"type": "single",
|
| 958 |
+
"id": "Batch10_20250508_P687",
|
| 959 |
+
"rel_path": "Batch10/20250508_P687",
|
| 960 |
+
"score": 1.0
|
| 961 |
+
},
|
| 962 |
+
{
|
| 963 |
+
"type": "single",
|
| 964 |
+
"id": "Batch10_20250507_P603",
|
| 965 |
+
"rel_path": "Batch10/20250507_P603",
|
| 966 |
+
"score": 1.0
|
| 967 |
+
},
|
| 968 |
+
{
|
| 969 |
+
"type": "single",
|
| 970 |
+
"id": "Batch10_20250502_P69",
|
| 971 |
+
"rel_path": "Batch10/20250502_P69",
|
| 972 |
+
"score": 1.0
|
| 973 |
+
},
|
| 974 |
+
{
|
| 975 |
+
"type": "single",
|
| 976 |
+
"id": "Batch10_20250509_P1062",
|
| 977 |
+
"rel_path": "Batch10/20250509_P1062",
|
| 978 |
+
"score": 1.0
|
| 979 |
+
},
|
| 980 |
+
{
|
| 981 |
+
"type": "single",
|
| 982 |
+
"id": "Batch10_20250506_P306",
|
| 983 |
+
"rel_path": "Batch10/20250506_P306",
|
| 984 |
+
"score": 1.0
|
| 985 |
+
},
|
| 986 |
+
{
|
| 987 |
+
"type": "single",
|
| 988 |
+
"id": "Batch10_20250507_P543",
|
| 989 |
+
"rel_path": "Batch10/20250507_P543",
|
| 990 |
+
"score": 1.0
|
| 991 |
+
},
|
| 992 |
+
{
|
| 993 |
+
"type": "single",
|
| 994 |
+
"id": "Batch10_20250507_P623",
|
| 995 |
+
"rel_path": "Batch10/20250507_P623",
|
| 996 |
+
"score": 1.0
|
| 997 |
+
},
|
| 998 |
+
{
|
| 999 |
+
"type": "single",
|
| 1000 |
+
"id": "Batch10_20250506_P395",
|
| 1001 |
+
"rel_path": "Batch10/20250506_P395",
|
| 1002 |
+
"score": 1.0
|
| 1003 |
+
},
|
| 1004 |
+
{
|
| 1005 |
+
"type": "single",
|
| 1006 |
+
"id": "Batch10_20250508_P795",
|
| 1007 |
+
"rel_path": "Batch10/20250508_P795",
|
| 1008 |
+
"score": 1.0
|
| 1009 |
+
},
|
| 1010 |
+
{
|
| 1011 |
+
"type": "single",
|
| 1012 |
+
"id": "Batch10_20250509_P857",
|
| 1013 |
+
"rel_path": "Batch10/20250509_P857",
|
| 1014 |
+
"score": 1.0
|
| 1015 |
+
},
|
| 1016 |
+
{
|
| 1017 |
+
"type": "single",
|
| 1018 |
+
"id": "Batch10_20250503_P106",
|
| 1019 |
+
"rel_path": "Batch10/20250503_P106",
|
| 1020 |
+
"score": 1.0
|
| 1021 |
+
},
|
| 1022 |
+
{
|
| 1023 |
+
"type": "single",
|
| 1024 |
+
"id": "Batch10_20250506_P512",
|
| 1025 |
+
"rel_path": "Batch10/20250506_P512",
|
| 1026 |
+
"score": 1.0
|
| 1027 |
+
},
|
| 1028 |
+
{
|
| 1029 |
+
"type": "single",
|
| 1030 |
+
"id": "Batch10_20250512_P1507",
|
| 1031 |
+
"rel_path": "Batch10/20250512_P1507",
|
| 1032 |
+
"score": 1.0
|
| 1033 |
+
},
|
| 1034 |
+
{
|
| 1035 |
+
"type": "single",
|
| 1036 |
+
"id": "Batch10_20250512_P1459",
|
| 1037 |
+
"rel_path": "Batch10/20250512_P1459",
|
| 1038 |
+
"score": 1.0
|
| 1039 |
+
},
|
| 1040 |
+
{
|
| 1041 |
+
"type": "single",
|
| 1042 |
+
"id": "Batch10_20250507_P682",
|
| 1043 |
+
"rel_path": "Batch10/20250507_P682",
|
| 1044 |
+
"score": 1.0
|
| 1045 |
+
},
|
| 1046 |
+
{
|
| 1047 |
+
"type": "single",
|
| 1048 |
+
"id": "Batch10_20250508_P764",
|
| 1049 |
+
"rel_path": "Batch10/20250508_P764",
|
| 1050 |
+
"score": 1.0
|
| 1051 |
+
},
|
| 1052 |
+
{
|
| 1053 |
+
"type": "single",
|
| 1054 |
+
"id": "Batch10_20250502_P57",
|
| 1055 |
+
"rel_path": "Batch10/20250502_P57",
|
| 1056 |
+
"score": 1.0
|
| 1057 |
+
},
|
| 1058 |
+
{
|
| 1059 |
+
"type": "single",
|
| 1060 |
+
"id": "Batch10_20250510_P1189",
|
| 1061 |
+
"rel_path": "Batch10/20250510_P1189",
|
| 1062 |
+
"score": 1.0
|
| 1063 |
+
},
|
| 1064 |
+
{
|
| 1065 |
+
"type": "single",
|
| 1066 |
+
"id": "Batch10_20250506_P321",
|
| 1067 |
+
"rel_path": "Batch10/20250506_P321",
|
| 1068 |
+
"score": 1.0
|
| 1069 |
+
},
|
| 1070 |
+
{
|
| 1071 |
+
"type": "single",
|
| 1072 |
+
"id": "Batch10_20250511_P1225",
|
| 1073 |
+
"rel_path": "Batch10/20250511_P1225",
|
| 1074 |
+
"score": 1.0
|
| 1075 |
+
},
|
| 1076 |
+
{
|
| 1077 |
+
"type": "single",
|
| 1078 |
+
"id": "Batch10_20250508_P838",
|
| 1079 |
+
"rel_path": "Batch10/20250508_P838",
|
| 1080 |
+
"score": 1.0
|
| 1081 |
+
},
|
| 1082 |
+
{
|
| 1083 |
+
"type": "single",
|
| 1084 |
+
"id": "Batch10_20250502_P58",
|
| 1085 |
+
"rel_path": "Batch10/20250502_P58",
|
| 1086 |
+
"score": 1.0
|
| 1087 |
+
},
|
| 1088 |
+
{
|
| 1089 |
+
"type": "single",
|
| 1090 |
+
"id": "Batch10_20250511_P1245",
|
| 1091 |
+
"rel_path": "Batch10/20250511_P1245",
|
| 1092 |
+
"score": 1.0
|
| 1093 |
+
},
|
| 1094 |
+
{
|
| 1095 |
+
"type": "single",
|
| 1096 |
+
"id": "Batch10_20250501_P35",
|
| 1097 |
+
"rel_path": "Batch10/20250501_P35",
|
| 1098 |
+
"score": 1.0
|
| 1099 |
+
},
|
| 1100 |
+
{
|
| 1101 |
+
"type": "single",
|
| 1102 |
+
"id": "Batch10_20250504_P196",
|
| 1103 |
+
"rel_path": "Batch10/20250504_P196",
|
| 1104 |
+
"score": 1.0
|
| 1105 |
+
},
|
| 1106 |
+
{
|
| 1107 |
+
"type": "single",
|
| 1108 |
+
"id": "Batch10_20250512_P1497",
|
| 1109 |
+
"rel_path": "Batch10/20250512_P1497",
|
| 1110 |
+
"score": 1.0
|
| 1111 |
+
},
|
| 1112 |
+
{
|
| 1113 |
+
"type": "single",
|
| 1114 |
+
"id": "Batch10_20250508_P816",
|
| 1115 |
+
"rel_path": "Batch10/20250508_P816",
|
| 1116 |
+
"score": 1.0
|
| 1117 |
+
},
|
| 1118 |
+
{
|
| 1119 |
+
"type": "single",
|
| 1120 |
+
"id": "Batch10_20250512_P1420",
|
| 1121 |
+
"rel_path": "Batch10/20250512_P1420",
|
| 1122 |
+
"score": 1.0
|
| 1123 |
+
},
|
| 1124 |
+
{
|
| 1125 |
+
"type": "single",
|
| 1126 |
+
"id": "Batch10_20250506_P286",
|
| 1127 |
+
"rel_path": "Batch10/20250506_P286",
|
| 1128 |
+
"score": 1.0
|
| 1129 |
+
},
|
| 1130 |
+
{
|
| 1131 |
+
"type": "single",
|
| 1132 |
+
"id": "Batch10_20250507_P580",
|
| 1133 |
+
"rel_path": "Batch10/20250507_P580",
|
| 1134 |
+
"score": 1.0
|
| 1135 |
+
},
|
| 1136 |
+
{
|
| 1137 |
+
"type": "single",
|
| 1138 |
+
"id": "Batch10_20250502_P47",
|
| 1139 |
+
"rel_path": "Batch10/20250502_P47",
|
| 1140 |
+
"score": 1.0
|
| 1141 |
+
},
|
| 1142 |
+
{
|
| 1143 |
+
"type": "single",
|
| 1144 |
+
"id": "Batch10_20250505_P211",
|
| 1145 |
+
"rel_path": "Batch10/20250505_P211",
|
| 1146 |
+
"score": 1.0
|
| 1147 |
+
},
|
| 1148 |
+
{
|
| 1149 |
+
"type": "single",
|
| 1150 |
+
"id": "Batch10_20250509_P1031",
|
| 1151 |
+
"rel_path": "Batch10/20250509_P1031",
|
| 1152 |
+
"score": 1.0
|
| 1153 |
+
},
|
| 1154 |
+
{
|
| 1155 |
+
"type": "single",
|
| 1156 |
+
"id": "Batch10_20250502_P44",
|
| 1157 |
+
"rel_path": "Batch10/20250502_P44",
|
| 1158 |
+
"score": 1.0
|
| 1159 |
+
},
|
| 1160 |
+
{
|
| 1161 |
+
"type": "single",
|
| 1162 |
+
"id": "Batch10_20250506_P288",
|
| 1163 |
+
"rel_path": "Batch10/20250506_P288",
|
| 1164 |
+
"score": 1.0
|
| 1165 |
+
},
|
| 1166 |
+
{
|
| 1167 |
+
"type": "single",
|
| 1168 |
+
"id": "Batch10_20250509_P928",
|
| 1169 |
+
"rel_path": "Batch10/20250509_P928",
|
| 1170 |
+
"score": 1.0
|
| 1171 |
+
},
|
| 1172 |
+
{
|
| 1173 |
+
"type": "single",
|
| 1174 |
+
"id": "Batch10_20250503_P83",
|
| 1175 |
+
"rel_path": "Batch10/20250503_P83",
|
| 1176 |
+
"score": 1.0
|
| 1177 |
+
},
|
| 1178 |
+
{
|
| 1179 |
+
"type": "single",
|
| 1180 |
+
"id": "Batch10_20250506_P362",
|
| 1181 |
+
"rel_path": "Batch10/20250506_P362",
|
| 1182 |
+
"score": 1.0
|
| 1183 |
+
},
|
| 1184 |
+
{
|
| 1185 |
+
"type": "single",
|
| 1186 |
+
"id": "Batch10_20250503_P82",
|
| 1187 |
+
"rel_path": "Batch10/20250503_P82",
|
| 1188 |
+
"score": 1.0
|
| 1189 |
+
},
|
| 1190 |
+
{
|
| 1191 |
+
"type": "single",
|
| 1192 |
+
"id": "Batch10_20250506_P511",
|
| 1193 |
+
"rel_path": "Batch10/20250506_P511",
|
| 1194 |
+
"score": 1.0
|
| 1195 |
+
},
|
| 1196 |
+
{
|
| 1197 |
+
"type": "single",
|
| 1198 |
+
"id": "Batch10_20250507_P525",
|
| 1199 |
+
"rel_path": "Batch10/20250507_P525",
|
| 1200 |
+
"score": 1.0
|
| 1201 |
+
},
|
| 1202 |
+
{
|
| 1203 |
+
"type": "single",
|
| 1204 |
+
"id": "Batch10_20250512_P1302",
|
| 1205 |
+
"rel_path": "Batch10/20250512_P1302",
|
| 1206 |
+
"score": 1.0
|
| 1207 |
+
},
|
| 1208 |
+
{
|
| 1209 |
+
"type": "single",
|
| 1210 |
+
"id": "Batch10_20250506_P283",
|
| 1211 |
+
"rel_path": "Batch10/20250506_P283",
|
| 1212 |
+
"score": 1.0
|
| 1213 |
+
},
|
| 1214 |
+
{
|
| 1215 |
+
"type": "single",
|
| 1216 |
+
"id": "Batch10_20250508_P733",
|
| 1217 |
+
"rel_path": "Batch10/20250508_P733",
|
| 1218 |
+
"score": 1.0
|
| 1219 |
+
},
|
| 1220 |
+
{
|
| 1221 |
+
"type": "single",
|
| 1222 |
+
"id": "Batch10_20250506_P462",
|
| 1223 |
+
"rel_path": "Batch10/20250506_P462",
|
| 1224 |
+
"score": 1.0
|
| 1225 |
+
},
|
| 1226 |
+
{
|
| 1227 |
+
"type": "single",
|
| 1228 |
+
"id": "Batch10_20250511_P1215",
|
| 1229 |
+
"rel_path": "Batch10/20250511_P1215",
|
| 1230 |
+
"score": 1.0
|
| 1231 |
+
},
|
| 1232 |
+
{
|
| 1233 |
+
"type": "single",
|
| 1234 |
+
"id": "Batch10_20250504_P197",
|
| 1235 |
+
"rel_path": "Batch10/20250504_P197",
|
| 1236 |
+
"score": 1.0
|
| 1237 |
+
},
|
| 1238 |
+
{
|
| 1239 |
+
"type": "single",
|
| 1240 |
+
"id": "Batch10_20250501_P13",
|
| 1241 |
+
"rel_path": "Batch10/20250501_P13",
|
| 1242 |
+
"score": 1.0
|
| 1243 |
+
},
|
| 1244 |
+
{
|
| 1245 |
+
"type": "single",
|
| 1246 |
+
"id": "Batch10_20250506_P407",
|
| 1247 |
+
"rel_path": "Batch10/20250506_P407",
|
| 1248 |
+
"score": 1.0
|
| 1249 |
+
},
|
| 1250 |
+
{
|
| 1251 |
+
"type": "single",
|
| 1252 |
+
"id": "Batch10_20250506_P400",
|
| 1253 |
+
"rel_path": "Batch10/20250506_P400",
|
| 1254 |
+
"score": 1.0
|
| 1255 |
+
},
|
| 1256 |
+
{
|
| 1257 |
+
"type": "single",
|
| 1258 |
+
"id": "Batch10_20250508_P836",
|
| 1259 |
+
"rel_path": "Batch10/20250508_P836",
|
| 1260 |
+
"score": 1.0
|
| 1261 |
+
},
|
| 1262 |
+
{
|
| 1263 |
+
"type": "single",
|
| 1264 |
+
"id": "Batch10_20250506_P369",
|
| 1265 |
+
"rel_path": "Batch10/20250506_P369",
|
| 1266 |
+
"score": 1.0
|
| 1267 |
+
},
|
| 1268 |
+
{
|
| 1269 |
+
"type": "single",
|
| 1270 |
+
"id": "Batch10_20250507_P570",
|
| 1271 |
+
"rel_path": "Batch10/20250507_P570",
|
| 1272 |
+
"score": 1.0
|
| 1273 |
+
},
|
| 1274 |
+
{
|
| 1275 |
+
"type": "single",
|
| 1276 |
+
"id": "Batch10_20250505_P269",
|
| 1277 |
+
"rel_path": "Batch10/20250505_P269",
|
| 1278 |
+
"score": 1.0
|
| 1279 |
+
},
|
| 1280 |
+
{
|
| 1281 |
+
"type": "single",
|
| 1282 |
+
"id": "Batch10_20250505_P248",
|
| 1283 |
+
"rel_path": "Batch10/20250505_P248",
|
| 1284 |
+
"score": 1.0
|
| 1285 |
+
},
|
| 1286 |
+
{
|
| 1287 |
+
"type": "single",
|
| 1288 |
+
"id": "Batch10_20250501_P37",
|
| 1289 |
+
"rel_path": "Batch10/20250501_P37",
|
| 1290 |
+
"score": 1.0
|
| 1291 |
+
},
|
| 1292 |
+
{
|
| 1293 |
+
"type": "single",
|
| 1294 |
+
"id": "Batch10_20250512_P1492",
|
| 1295 |
+
"rel_path": "Batch10/20250512_P1492",
|
| 1296 |
+
"score": 1.0
|
| 1297 |
+
},
|
| 1298 |
+
{
|
| 1299 |
+
"type": "single",
|
| 1300 |
+
"id": "Batch10_20250507_P679",
|
| 1301 |
+
"rel_path": "Batch10/20250507_P679",
|
| 1302 |
+
"score": 1.0
|
| 1303 |
+
},
|
| 1304 |
+
{
|
| 1305 |
+
"type": "single",
|
| 1306 |
+
"id": "Batch10_20250506_P491",
|
| 1307 |
+
"rel_path": "Batch10/20250506_P491",
|
| 1308 |
+
"score": 1.0
|
| 1309 |
+
},
|
| 1310 |
+
{
|
| 1311 |
+
"type": "single",
|
| 1312 |
+
"id": "Batch10_20250506_P289",
|
| 1313 |
+
"rel_path": "Batch10/20250506_P289",
|
| 1314 |
+
"score": 1.0
|
| 1315 |
+
},
|
| 1316 |
+
{
|
| 1317 |
+
"type": "single",
|
| 1318 |
+
"id": "Batch10_20250509_P1001",
|
| 1319 |
+
"rel_path": "Batch10/20250509_P1001",
|
| 1320 |
+
"score": 1.0
|
| 1321 |
+
},
|
| 1322 |
+
{
|
| 1323 |
+
"type": "single",
|
| 1324 |
+
"id": "Batch10_20250507_P522",
|
| 1325 |
+
"rel_path": "Batch10/20250507_P522",
|
| 1326 |
+
"score": 1.0
|
| 1327 |
+
},
|
| 1328 |
+
{
|
| 1329 |
+
"type": "single",
|
| 1330 |
+
"id": "Batch10_20250509_P1034",
|
| 1331 |
+
"rel_path": "Batch10/20250509_P1034",
|
| 1332 |
+
"score": 1.0
|
| 1333 |
+
},
|
| 1334 |
+
{
|
| 1335 |
+
"type": "single",
|
| 1336 |
+
"id": "Batch10_20250505_P280",
|
| 1337 |
+
"rel_path": "Batch10/20250505_P280",
|
| 1338 |
+
"score": 1.0
|
| 1339 |
+
},
|
| 1340 |
+
{
|
| 1341 |
+
"type": "single",
|
| 1342 |
+
"id": "Batch10_20250501_P30",
|
| 1343 |
+
"rel_path": "Batch10/20250501_P30",
|
| 1344 |
+
"score": 1.0
|
| 1345 |
+
},
|
| 1346 |
+
{
|
| 1347 |
+
"type": "single",
|
| 1348 |
+
"id": "Batch10_20250512_P1411",
|
| 1349 |
+
"rel_path": "Batch10/20250512_P1411",
|
| 1350 |
+
"score": 1.0
|
| 1351 |
+
},
|
| 1352 |
+
{
|
| 1353 |
+
"type": "single",
|
| 1354 |
+
"id": "Batch10_20250503_P78",
|
| 1355 |
+
"rel_path": "Batch10/20250503_P78",
|
| 1356 |
+
"score": 1.0
|
| 1357 |
+
},
|
| 1358 |
+
{
|
| 1359 |
+
"type": "single",
|
| 1360 |
+
"id": "Batch10_20250512_P1404",
|
| 1361 |
+
"rel_path": "Batch10/20250512_P1404",
|
| 1362 |
+
"score": 1.0
|
| 1363 |
+
},
|
| 1364 |
+
{
|
| 1365 |
+
"type": "single",
|
| 1366 |
+
"id": "Batch10_20250501_P11",
|
| 1367 |
+
"rel_path": "Batch10/20250501_P11",
|
| 1368 |
+
"score": 1.0
|
| 1369 |
+
},
|
| 1370 |
+
{
|
| 1371 |
+
"type": "single",
|
| 1372 |
+
"id": "Batch10_20250512_P1363",
|
| 1373 |
+
"rel_path": "Batch10/20250512_P1363",
|
| 1374 |
+
"score": 1.0
|
| 1375 |
+
},
|
| 1376 |
+
{
|
| 1377 |
+
"type": "single",
|
| 1378 |
+
"id": "Batch10_20250502_P55",
|
| 1379 |
+
"rel_path": "Batch10/20250502_P55",
|
| 1380 |
+
"score": 1.0
|
| 1381 |
+
},
|
| 1382 |
+
{
|
| 1383 |
+
"type": "single",
|
| 1384 |
+
"id": "Batch10_20250506_P409",
|
| 1385 |
+
"rel_path": "Batch10/20250506_P409",
|
| 1386 |
+
"score": 1.0
|
| 1387 |
+
},
|
| 1388 |
+
{
|
| 1389 |
+
"type": "single",
|
| 1390 |
+
"id": "Batch10_20250502_P53",
|
| 1391 |
+
"rel_path": "Batch10/20250502_P53",
|
| 1392 |
+
"score": 1.0
|
| 1393 |
+
},
|
| 1394 |
+
{
|
| 1395 |
+
"type": "single",
|
| 1396 |
+
"id": "Batch10_20250510_P1197",
|
| 1397 |
+
"rel_path": "Batch10/20250510_P1197",
|
| 1398 |
+
"score": 1.0
|
| 1399 |
+
},
|
| 1400 |
+
{
|
| 1401 |
+
"type": "single",
|
| 1402 |
+
"id": "Batch10_20250512_P1454",
|
| 1403 |
+
"rel_path": "Batch10/20250512_P1454",
|
| 1404 |
+
"score": 1.0
|
| 1405 |
+
},
|
| 1406 |
+
{
|
| 1407 |
+
"type": "single",
|
| 1408 |
+
"id": "Batch10_20250512_P1406",
|
| 1409 |
+
"rel_path": "Batch10/20250512_P1406",
|
| 1410 |
+
"score": 1.0
|
| 1411 |
+
},
|
| 1412 |
+
{
|
| 1413 |
+
"type": "single",
|
| 1414 |
+
"id": "Batch10_20250504_P148",
|
| 1415 |
+
"rel_path": "Batch10/20250504_P148",
|
| 1416 |
+
"score": 1.0
|
| 1417 |
+
},
|
| 1418 |
+
{
|
| 1419 |
+
"type": "single",
|
| 1420 |
+
"id": "Batch10_20250507_P660",
|
| 1421 |
+
"rel_path": "Batch10/20250507_P660",
|
| 1422 |
+
"score": 1.0
|
| 1423 |
+
},
|
| 1424 |
+
{
|
| 1425 |
+
"type": "single",
|
| 1426 |
+
"id": "Batch10_20250509_P856",
|
| 1427 |
+
"rel_path": "Batch10/20250509_P856",
|
| 1428 |
+
"score": 1.0
|
| 1429 |
+
},
|
| 1430 |
+
{
|
| 1431 |
+
"type": "single",
|
| 1432 |
+
"id": "Batch10_20250501_P26",
|
| 1433 |
+
"rel_path": "Batch10/20250501_P26",
|
| 1434 |
+
"score": 1.0
|
| 1435 |
+
},
|
| 1436 |
+
{
|
| 1437 |
+
"type": "single",
|
| 1438 |
+
"id": "Batch10_20250506_P387",
|
| 1439 |
+
"rel_path": "Batch10/20250506_P387",
|
| 1440 |
+
"score": 1.0
|
| 1441 |
+
},
|
| 1442 |
+
{
|
| 1443 |
+
"type": "single",
|
| 1444 |
+
"id": "Batch10_20250508_P686",
|
| 1445 |
+
"rel_path": "Batch10/20250508_P686",
|
| 1446 |
+
"score": 1.0
|
| 1447 |
+
},
|
| 1448 |
+
{
|
| 1449 |
+
"type": "single",
|
| 1450 |
+
"id": "Batch10_20250509_P1003",
|
| 1451 |
+
"rel_path": "Batch10/20250509_P1003",
|
| 1452 |
+
"score": 1.0
|
| 1453 |
+
},
|
| 1454 |
+
{
|
| 1455 |
+
"type": "single",
|
| 1456 |
+
"id": "Batch10_20250512_P1496",
|
| 1457 |
+
"rel_path": "Batch10/20250512_P1496",
|
| 1458 |
+
"score": 1.0
|
| 1459 |
+
},
|
| 1460 |
+
{
|
| 1461 |
+
"type": "single",
|
| 1462 |
+
"id": "Batch10_20250508_P744",
|
| 1463 |
+
"rel_path": "Batch10/20250508_P744",
|
| 1464 |
+
"score": 1.0
|
| 1465 |
+
},
|
| 1466 |
+
{
|
| 1467 |
+
"type": "single",
|
| 1468 |
+
"id": "Batch10_20250512_P1512",
|
| 1469 |
+
"rel_path": "Batch10/20250512_P1512",
|
| 1470 |
+
"score": 1.0
|
| 1471 |
+
},
|
| 1472 |
+
{
|
| 1473 |
+
"type": "single",
|
| 1474 |
+
"id": "Batch10_20250504_P117",
|
| 1475 |
+
"rel_path": "Batch10/20250504_P117",
|
| 1476 |
+
"score": 1.0
|
| 1477 |
+
},
|
| 1478 |
+
{
|
| 1479 |
+
"type": "single",
|
| 1480 |
+
"id": "Batch10_20250505_P228",
|
| 1481 |
+
"rel_path": "Batch10/20250505_P228",
|
| 1482 |
+
"score": 1.0
|
| 1483 |
+
},
|
| 1484 |
+
{
|
| 1485 |
+
"type": "single",
|
| 1486 |
+
"id": "Batch10_20250512_P1509",
|
| 1487 |
+
"rel_path": "Batch10/20250512_P1509",
|
| 1488 |
+
"score": 1.0
|
| 1489 |
+
},
|
| 1490 |
+
{
|
| 1491 |
+
"type": "single",
|
| 1492 |
+
"id": "Batch10_20250512_P1327",
|
| 1493 |
+
"rel_path": "Batch10/20250512_P1327",
|
| 1494 |
+
"score": 1.0
|
| 1495 |
+
},
|
| 1496 |
+
{
|
| 1497 |
+
"type": "single",
|
| 1498 |
+
"id": "Batch10_20250511_P1223",
|
| 1499 |
+
"rel_path": "Batch10/20250511_P1223",
|
| 1500 |
+
"score": 1.0
|
| 1501 |
+
},
|
| 1502 |
+
{
|
| 1503 |
+
"type": "single",
|
| 1504 |
+
"id": "Batch10_20250507_P575",
|
| 1505 |
+
"rel_path": "Batch10/20250507_P575",
|
| 1506 |
+
"score": 1.0
|
| 1507 |
+
},
|
| 1508 |
+
{
|
| 1509 |
+
"type": "single",
|
| 1510 |
+
"id": "Batch10_20250509_P1032",
|
| 1511 |
+
"rel_path": "Batch10/20250509_P1032",
|
| 1512 |
+
"score": 1.0
|
| 1513 |
+
},
|
| 1514 |
+
{
|
| 1515 |
+
"type": "single",
|
| 1516 |
+
"id": "Batch10_20250511_P1217",
|
| 1517 |
+
"rel_path": "Batch10/20250511_P1217",
|
| 1518 |
+
"score": 1.0
|
| 1519 |
+
},
|
| 1520 |
+
{
|
| 1521 |
+
"type": "single",
|
| 1522 |
+
"id": "Batch10_20250506_P406",
|
| 1523 |
+
"rel_path": "Batch10/20250506_P406",
|
| 1524 |
+
"score": 1.0
|
| 1525 |
+
},
|
| 1526 |
+
{
|
| 1527 |
+
"type": "single",
|
| 1528 |
+
"id": "Batch10_20250509_P992",
|
| 1529 |
+
"rel_path": "Batch10/20250509_P992",
|
| 1530 |
+
"score": 1.0
|
| 1531 |
+
},
|
| 1532 |
+
{
|
| 1533 |
+
"type": "single",
|
| 1534 |
+
"id": "Batch10_20250501_P07",
|
| 1535 |
+
"rel_path": "Batch10/20250501_P07",
|
| 1536 |
+
"score": 1.0
|
| 1537 |
+
},
|
| 1538 |
+
{
|
| 1539 |
+
"type": "single",
|
| 1540 |
+
"id": "Batch10_20250509_P888",
|
| 1541 |
+
"rel_path": "Batch10/20250509_P888",
|
| 1542 |
+
"score": 1.0
|
| 1543 |
+
},
|
| 1544 |
+
{
|
| 1545 |
+
"type": "single",
|
| 1546 |
+
"id": "Batch10_20250509_P894",
|
| 1547 |
+
"rel_path": "Batch10/20250509_P894",
|
| 1548 |
+
"score": 1.0
|
| 1549 |
+
},
|
| 1550 |
+
{
|
| 1551 |
+
"type": "single",
|
| 1552 |
+
"id": "Batch10_20250507_P513",
|
| 1553 |
+
"rel_path": "Batch10/20250507_P513",
|
| 1554 |
+
"score": 1.0
|
| 1555 |
+
},
|
| 1556 |
+
{
|
| 1557 |
+
"type": "single",
|
| 1558 |
+
"id": "Batch10_20250508_P811",
|
| 1559 |
+
"rel_path": "Batch10/20250508_P811",
|
| 1560 |
+
"score": 1.0
|
| 1561 |
+
},
|
| 1562 |
+
{
|
| 1563 |
+
"type": "single",
|
| 1564 |
+
"id": "Batch10_20250508_P763",
|
| 1565 |
+
"rel_path": "Batch10/20250508_P763",
|
| 1566 |
+
"score": 1.0
|
| 1567 |
+
},
|
| 1568 |
+
{
|
| 1569 |
+
"type": "single",
|
| 1570 |
+
"id": "Batch10_20250507_P548",
|
| 1571 |
+
"rel_path": "Batch10/20250507_P548",
|
| 1572 |
+
"score": 1.0
|
| 1573 |
+
},
|
| 1574 |
+
{
|
| 1575 |
+
"type": "single",
|
| 1576 |
+
"id": "Batch10_20250507_P652",
|
| 1577 |
+
"rel_path": "Batch10/20250507_P652",
|
| 1578 |
+
"score": 1.0
|
| 1579 |
+
},
|
| 1580 |
+
{
|
| 1581 |
+
"type": "single",
|
| 1582 |
+
"id": "Batch10_20250512_P1403",
|
| 1583 |
+
"rel_path": "Batch10/20250512_P1403",
|
| 1584 |
+
"score": 1.0
|
| 1585 |
+
},
|
| 1586 |
+
{
|
| 1587 |
+
"type": "single",
|
| 1588 |
+
"id": "Batch10_20250506_P461",
|
| 1589 |
+
"rel_path": "Batch10/20250506_P461",
|
| 1590 |
+
"score": 1.0
|
| 1591 |
+
},
|
| 1592 |
+
{
|
| 1593 |
+
"type": "single",
|
| 1594 |
+
"id": "Batch10_20250511_P1224",
|
| 1595 |
+
"rel_path": "Batch10/20250511_P1224",
|
| 1596 |
+
"score": 1.0
|
| 1597 |
+
},
|
| 1598 |
+
{
|
| 1599 |
+
"type": "single",
|
| 1600 |
+
"id": "Batch10_20250512_P1465",
|
| 1601 |
+
"rel_path": "Batch10/20250512_P1465",
|
| 1602 |
+
"score": 1.0
|
| 1603 |
+
},
|
| 1604 |
+
{
|
| 1605 |
+
"type": "single",
|
| 1606 |
+
"id": "Batch10_20250504_P193",
|
| 1607 |
+
"rel_path": "Batch10/20250504_P193",
|
| 1608 |
+
"score": 1.0
|
| 1609 |
+
},
|
| 1610 |
+
{
|
| 1611 |
+
"type": "single",
|
| 1612 |
+
"id": "Batch10_20250507_P626",
|
| 1613 |
+
"rel_path": "Batch10/20250507_P626",
|
| 1614 |
+
"score": 1.0
|
| 1615 |
+
},
|
| 1616 |
+
{
|
| 1617 |
+
"type": "single",
|
| 1618 |
+
"id": "Batch10_20250506_P327",
|
| 1619 |
+
"rel_path": "Batch10/20250506_P327",
|
| 1620 |
+
"score": 1.0
|
| 1621 |
+
},
|
| 1622 |
+
{
|
| 1623 |
+
"type": "single",
|
| 1624 |
+
"id": "Batch10_20250508_P837",
|
| 1625 |
+
"rel_path": "Batch10/20250508_P837",
|
| 1626 |
+
"score": 1.0
|
| 1627 |
+
},
|
| 1628 |
+
{
|
| 1629 |
+
"type": "single",
|
| 1630 |
+
"id": "Batch10_20250507_P563",
|
| 1631 |
+
"rel_path": "Batch10/20250507_P563",
|
| 1632 |
+
"score": 1.0
|
| 1633 |
+
},
|
| 1634 |
+
{
|
| 1635 |
+
"type": "single",
|
| 1636 |
+
"id": "Batch10_20250512_P1344",
|
| 1637 |
+
"rel_path": "Batch10/20250512_P1344",
|
| 1638 |
+
"score": 1.0
|
| 1639 |
+
},
|
| 1640 |
+
{
|
| 1641 |
+
"type": "single",
|
| 1642 |
+
"id": "Batch10_20250501_P41",
|
| 1643 |
+
"rel_path": "Batch10/20250501_P41",
|
| 1644 |
+
"score": 1.0
|
| 1645 |
+
},
|
| 1646 |
+
{
|
| 1647 |
+
"type": "single",
|
| 1648 |
+
"id": "Batch10_20250508_P698",
|
| 1649 |
+
"rel_path": "Batch10/20250508_P698",
|
| 1650 |
+
"score": 1.0
|
| 1651 |
+
},
|
| 1652 |
+
{
|
| 1653 |
+
"type": "single",
|
| 1654 |
+
"id": "Batch10_20250504_P163",
|
| 1655 |
+
"rel_path": "Batch10/20250504_P163",
|
| 1656 |
+
"score": 1.0
|
| 1657 |
+
},
|
| 1658 |
+
{
|
| 1659 |
+
"type": "single",
|
| 1660 |
+
"id": "Batch10_20250506_P484",
|
| 1661 |
+
"rel_path": "Batch10/20250506_P484",
|
| 1662 |
+
"score": 1.0
|
| 1663 |
+
},
|
| 1664 |
+
{
|
| 1665 |
+
"type": "single",
|
| 1666 |
+
"id": "Batch10_20250512_P1396",
|
| 1667 |
+
"rel_path": "Batch10/20250512_P1396",
|
| 1668 |
+
"score": 1.0
|
| 1669 |
+
},
|
| 1670 |
+
{
|
| 1671 |
+
"type": "single",
|
| 1672 |
+
"id": "Batch10_20250510_P1173",
|
| 1673 |
+
"rel_path": "Batch10/20250510_P1173",
|
| 1674 |
+
"score": 1.0
|
| 1675 |
+
},
|
| 1676 |
+
{
|
| 1677 |
+
"type": "single",
|
| 1678 |
+
"id": "Batch10_20250507_P583",
|
| 1679 |
+
"rel_path": "Batch10/20250507_P583",
|
| 1680 |
+
"score": 1.0
|
| 1681 |
+
},
|
| 1682 |
+
{
|
| 1683 |
+
"type": "single",
|
| 1684 |
+
"id": "Batch10_20250503_P108",
|
| 1685 |
+
"rel_path": "Batch10/20250503_P108",
|
| 1686 |
+
"score": 1.0
|
| 1687 |
+
}
|
| 1688 |
+
]
|
config.yaml
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 超声提示多标签分类模型配置文件
|
| 2 |
+
# TransMIL + Query2Label Hybrid Model
|
| 3 |
+
|
| 4 |
+
data:
|
| 5 |
+
# 【需要修改】数据根目录(包含 Report_XXX 文件夹的目录)
|
| 6 |
+
data_root: "/path/to/your/ReportData_ROI/"
|
| 7 |
+
# 【需要修改】多标签注释 CSV 文件路径
|
| 8 |
+
annotation_csv: "/path/to/your/ReportData_ROI/thyroid_multilabel_annotations.csv"
|
| 9 |
+
# 【需要修改】验证集 JSON 文件路径
|
| 10 |
+
val_json: "/path/to/your/ReportData_ROI/classification_val_set_single.json"
|
| 11 |
+
# 【需要修改】测试集 JSON 文件路径
|
| 12 |
+
test_json: "/path/to/your/ReportData_ROI/classification_test_set_single.json"
|
| 13 |
+
img_size: 224
|
| 14 |
+
max_images_per_case: 20
|
| 15 |
+
num_workers: 8
|
| 16 |
+
|
| 17 |
+
model:
|
| 18 |
+
num_class: 17 # 17类标签(已删除"切除术后")
|
| 19 |
+
hidden_dim: 512
|
| 20 |
+
nheads: 8
|
| 21 |
+
num_decoder_layers: 2
|
| 22 |
+
pretrained_resnet: True
|
| 23 |
+
use_ppeg: False
|
| 24 |
+
|
| 25 |
+
training:
|
| 26 |
+
batch_size: 4
|
| 27 |
+
epochs: 50
|
| 28 |
+
lr: 0.0001
|
| 29 |
+
weight_decay: 0.0001
|
| 30 |
+
optimizer: "AdamW"
|
| 31 |
+
|
| 32 |
+
# Asymmetric Loss 参数(处理多标签不平衡)
|
| 33 |
+
gamma_neg: 4
|
| 34 |
+
gamma_pos: 1
|
| 35 |
+
clip: 0.05
|
| 36 |
+
|
| 37 |
+
# 内存优化策略
|
| 38 |
+
use_amp: true # 混合精度训练
|
| 39 |
+
gradient_accumulation_steps: 4 # 有效 batch_size = 4 * 4 = 16
|
| 40 |
+
gradient_checkpointing: true
|
| 41 |
+
|
| 42 |
+
# 学习率调度器
|
| 43 |
+
scheduler: "cosine"
|
| 44 |
+
warmup_epochs: 5
|
| 45 |
+
|
| 46 |
+
# 模型保存
|
| 47 |
+
save_dir: "checkpoints/"
|
| 48 |
+
save_freq: 5
|
evaluate.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import yaml
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from sklearn.metrics import (
|
| 8 |
+
average_precision_score,
|
| 9 |
+
roc_auc_score,
|
| 10 |
+
f1_score,
|
| 11 |
+
precision_score,
|
| 12 |
+
recall_score,
|
| 13 |
+
accuracy_score
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
# 引入你的模型和数据加载器
|
| 17 |
+
from models.transmil_q2l import TransMIL_Query2Label_E2E
|
| 18 |
+
from thyroid_dataset import create_dataloaders, TARGET_CLASSES
|
| 19 |
+
'''
|
| 20 |
+
# 18类标签定义 (与训练时保持一致)
|
| 21 |
+
TARGET_CLASSES = [
|
| 22 |
+
"TI-RADS 1级", "TI-RADS 2级", "TI-RADS 3级", "TI-RADS 4a级",
|
| 23 |
+
"TI-RADS 4b级", "TI-RADS 4c级", "TI-RADS 5级",
|
| 24 |
+
"钙化", "甲亢", "囊肿", "淋巴结", "胶质潴留", "切除术后",
|
| 25 |
+
"弥漫性病变", "结节性甲状腺肿", "桥本氏甲状腺炎", "反应性", "转移性"
|
| 26 |
+
]
|
| 27 |
+
'''
|
| 28 |
+
|
| 29 |
+
def get_best_checkpoint_path(save_dir):
|
| 30 |
+
"""自动寻找 best checkpoint"""
|
| 31 |
+
best_path = os.path.join(save_dir, 'checkpoint_best.pth')
|
| 32 |
+
if os.path.exists(best_path):
|
| 33 |
+
return best_path
|
| 34 |
+
# 如果没找到 best,找 latest
|
| 35 |
+
latest_path = os.path.join(save_dir, 'checkpoint_latest.pth')
|
| 36 |
+
if os.path.exists(latest_path):
|
| 37 |
+
print(f"Warning: 'checkpoint_best.pth' not found. Using '{latest_path}' instead.")
|
| 38 |
+
return latest_path
|
| 39 |
+
raise FileNotFoundError(f"No checkpoints found in {save_dir}")
|
| 40 |
+
|
| 41 |
+
def compute_metrics(y_true, y_pred_probs, threshold=0.5):
|
| 42 |
+
"""
|
| 43 |
+
计算全面的多标签指标
|
| 44 |
+
y_true: [N, num_classes] (0 or 1)
|
| 45 |
+
y_pred_probs: [N, num_classes] (0.0 ~ 1.0)
|
| 46 |
+
"""
|
| 47 |
+
metrics = {}
|
| 48 |
+
|
| 49 |
+
# 1. 二值化预测
|
| 50 |
+
y_pred_binary = (y_pred_probs >= threshold).astype(int)
|
| 51 |
+
|
| 52 |
+
# 2. 全局指标 (Global Metrics)
|
| 53 |
+
# mAP (mean Average Precision) - 最重要的多标签指标
|
| 54 |
+
metrics['mAP'] = average_precision_score(y_true, y_pred_probs, average='macro')
|
| 55 |
+
metrics['weighted_mAP'] = average_precision_score(y_true, y_pred_probs, average='weighted')
|
| 56 |
+
|
| 57 |
+
# AUROC (Macro & Micro)
|
| 58 |
+
try:
|
| 59 |
+
metrics['macro_auroc'] = roc_auc_score(y_true, y_pred_probs, average='macro')
|
| 60 |
+
metrics['micro_auroc'] = roc_auc_score(y_true, y_pred_probs, average='micro')
|
| 61 |
+
except ValueError:
|
| 62 |
+
metrics['macro_auroc'] = 0.0
|
| 63 |
+
metrics['micro_auroc'] = 0.0
|
| 64 |
+
|
| 65 |
+
# F1 Score
|
| 66 |
+
metrics['micro_f1'] = f1_score(y_true, y_pred_binary, average='micro')
|
| 67 |
+
metrics['macro_f1'] = f1_score(y_true, y_pred_binary, average='macro')
|
| 68 |
+
|
| 69 |
+
# Exact Match Ratio (Subset Accuracy) - 全对才算对
|
| 70 |
+
metrics['subset_accuracy'] = accuracy_score(y_true, y_pred_binary)
|
| 71 |
+
|
| 72 |
+
# 3. 每类详细指标 (Per-class Metrics)
|
| 73 |
+
class_metrics = []
|
| 74 |
+
for i, class_name in enumerate(TARGET_CLASSES):
|
| 75 |
+
# 提取当前类的真实标签和预测概率
|
| 76 |
+
yt = y_true[:, i]
|
| 77 |
+
yp = y_pred_probs[:, i]
|
| 78 |
+
yb = y_pred_binary[:, i]
|
| 79 |
+
|
| 80 |
+
# 样本数
|
| 81 |
+
support = int(yt.sum())
|
| 82 |
+
|
| 83 |
+
# 如果该类没有正样本,部分指标无法计算
|
| 84 |
+
if support > 0:
|
| 85 |
+
ap = average_precision_score(yt, yp)
|
| 86 |
+
try:
|
| 87 |
+
auroc = roc_auc_score(yt, yp)
|
| 88 |
+
except ValueError:
|
| 89 |
+
auroc = 0.5 # 只有一个类别存在时无法计算AUC
|
| 90 |
+
|
| 91 |
+
f1 = f1_score(yt, yb)
|
| 92 |
+
rec = recall_score(yt, yb)
|
| 93 |
+
prec = precision_score(yt, yb, zero_division=0)
|
| 94 |
+
else:
|
| 95 |
+
ap, auroc, f1, rec, prec = 0.0, 0.5, 0.0, 0.0, 0.0
|
| 96 |
+
|
| 97 |
+
class_metrics.append({
|
| 98 |
+
"Class": class_name,
|
| 99 |
+
"Support": support,
|
| 100 |
+
"AP": ap,
|
| 101 |
+
"AUROC": auroc,
|
| 102 |
+
"F1": f1,
|
| 103 |
+
"Precision": prec,
|
| 104 |
+
"Recall": rec
|
| 105 |
+
})
|
| 106 |
+
|
| 107 |
+
return metrics, pd.DataFrame(class_metrics)
|
| 108 |
+
|
| 109 |
+
def main():
|
| 110 |
+
# 1. 加载配置
|
| 111 |
+
config_path = 'config.yaml' # 确保这里路径正确
|
| 112 |
+
with open(config_path, 'r') as f:
|
| 113 |
+
config = yaml.safe_load(f)
|
| 114 |
+
|
| 115 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 116 |
+
print(f"Evaluating on {device}")
|
| 117 |
+
|
| 118 |
+
# 2. 准备数据加载器
|
| 119 |
+
print("Loading Test Data...")
|
| 120 |
+
_, _, test_loader = create_dataloaders(config)
|
| 121 |
+
|
| 122 |
+
# 3. 初始化模型
|
| 123 |
+
print("Initializing Model...")
|
| 124 |
+
model = TransMIL_Query2Label_E2E(
|
| 125 |
+
num_class=config['model']['num_class'],
|
| 126 |
+
hidden_dim=config['model']['hidden_dim'],
|
| 127 |
+
nheads=config['model']['nheads'],
|
| 128 |
+
num_decoder_layers=config['model']['num_decoder_layers'],
|
| 129 |
+
pretrained_resnet=False, # 推理时不需要下载预训练权重,直接加载我们自己的权重
|
| 130 |
+
use_checkpointing=False, # 推理时不需要 checkpointing
|
| 131 |
+
use_ppeg=config['model'].get('use_ppeg', False)
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
# 4. 加载权重
|
| 135 |
+
ckpt_path = get_best_checkpoint_path(config['training']['save_dir'])
|
| 136 |
+
print(f"Loading checkpoint from: {ckpt_path}")
|
| 137 |
+
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
|
| 138 |
+
|
| 139 |
+
# 处理 state_dict 键名可能不匹配的问题 (如 module. 前缀)
|
| 140 |
+
state_dict = checkpoint['model_state_dict']
|
| 141 |
+
new_state_dict = {}
|
| 142 |
+
for k, v in state_dict.items():
|
| 143 |
+
name = k.replace("module.", "")
|
| 144 |
+
new_state_dict[name] = v
|
| 145 |
+
model.load_state_dict(new_state_dict)
|
| 146 |
+
|
| 147 |
+
model.to(device)
|
| 148 |
+
model.eval()
|
| 149 |
+
|
| 150 |
+
# 5. 推理循环
|
| 151 |
+
print("Running Inference...")
|
| 152 |
+
all_preds = []
|
| 153 |
+
all_targets = []
|
| 154 |
+
|
| 155 |
+
with torch.no_grad():
|
| 156 |
+
for batch in tqdm(test_loader):
|
| 157 |
+
images = batch['images'].to(device)
|
| 158 |
+
num_instances = batch['num_instances_per_case']
|
| 159 |
+
labels = batch['labels'].numpy() # CPU numpy
|
| 160 |
+
|
| 161 |
+
# Forward
|
| 162 |
+
logits = model(images, num_instances)
|
| 163 |
+
probs = torch.sigmoid(logits).cpu().numpy()
|
| 164 |
+
|
| 165 |
+
all_preds.append(probs)
|
| 166 |
+
all_targets.append(labels)
|
| 167 |
+
|
| 168 |
+
# 拼接
|
| 169 |
+
y_pred_probs = np.concatenate(all_preds, axis=0)
|
| 170 |
+
y_true = np.concatenate(all_targets, axis=0)
|
| 171 |
+
|
| 172 |
+
# 6. 计算指标
|
| 173 |
+
print("\nComputing Metrics...")
|
| 174 |
+
global_metrics, class_df = compute_metrics(y_true, y_pred_probs)
|
| 175 |
+
|
| 176 |
+
# 7. 打印结果
|
| 177 |
+
print("\n" + "="*60)
|
| 178 |
+
print(" GLOBAL PERFORMANCE SUMMARY ")
|
| 179 |
+
print("="*60)
|
| 180 |
+
print(f" mAP (Macro) : {global_metrics['mAP']:.4f}")
|
| 181 |
+
print(f" mAP (Weighted): {global_metrics['weighted_mAP']:.4f}")
|
| 182 |
+
print(f" AUROC (Macro) : {global_metrics['macro_auroc']:.4f}")
|
| 183 |
+
print(f" AUROC (Micro) : {global_metrics['micro_auroc']:.4f}")
|
| 184 |
+
print(f" F1 (Micro) : {global_metrics['micro_f1']:.4f}")
|
| 185 |
+
print(f" F1 (Macro) : {global_metrics['macro_f1']:.4f}")
|
| 186 |
+
print(f" Subset Acc : {global_metrics['subset_accuracy']:.4f}")
|
| 187 |
+
print("-" * 60)
|
| 188 |
+
|
| 189 |
+
print("\n" + "="*100)
|
| 190 |
+
print(" PER-CLASS PERFORMANCE DETAILS (Sorted by Support) ")
|
| 191 |
+
print("="*100)
|
| 192 |
+
|
| 193 |
+
# 按样本数量排序
|
| 194 |
+
class_df = class_df.sort_values(by='Support', ascending=False)
|
| 195 |
+
|
| 196 |
+
# --- 开始修改:手动格式化打印 ---
|
| 197 |
+
# 定义表头
|
| 198 |
+
# 中文字符宽度处理技巧:给 Class 列预留足够大的空间 (比如30)
|
| 199 |
+
# {:<N} 左对齐, {:>N} 右对齐
|
| 200 |
+
|
| 201 |
+
headers = ["Class", "Support", "AP", "AUROC", "F1", "Precision", "Recall"]
|
| 202 |
+
|
| 203 |
+
# 打印表头
|
| 204 |
+
# {0:<24} 表示第一列左对齐占24格
|
| 205 |
+
head_fmt = "{:<24} {:>8} {:>10} {:>10} {:>10} {:>12} {:>10}"
|
| 206 |
+
print(head_fmt.format(*headers))
|
| 207 |
+
print("-" * 100)
|
| 208 |
+
|
| 209 |
+
# 打印每一行
|
| 210 |
+
row_fmt = "{:<24} {:>8d} {:>10.4f} {:>10.4f} {:>10.4f} {:>12.4f} {:>10.4f}"
|
| 211 |
+
|
| 212 |
+
for _, row in class_df.iterrows():
|
| 213 |
+
cls_name = row['Class']
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
display_width = len(cls_name.encode('gbk'))
|
| 217 |
+
|
| 218 |
+
# 计算需要填充的空格数
|
| 219 |
+
# 目标宽度 24 - 实际显示宽度
|
| 220 |
+
target_width = 24
|
| 221 |
+
padding = target_width - display_width
|
| 222 |
+
|
| 223 |
+
# 构造对齐后的字符串
|
| 224 |
+
aligned_name = cls_name + " " * padding
|
| 225 |
+
|
| 226 |
+
print(f"{aligned_name} {int(row['Support']):>8d} {row['AP']:>10.4f} {row['AUROC']:>10.4f} {row['F1']:>10.4f} {row['Precision']:>12.4f} {row['Recall']:>10.4f}")
|
| 227 |
+
|
| 228 |
+
print("="*100)
|
| 229 |
+
|
| 230 |
+
# 保存结果到 CSV
|
| 231 |
+
result_csv = os.path.join(config['training']['save_dir'], 'evaluation_report.csv')
|
| 232 |
+
class_df.to_csv(result_csv, index=False, encoding='utf-8-sig')
|
| 233 |
+
print(f"\nDetailed report saved to: {result_csv}")
|
| 234 |
+
|
| 235 |
+
if __name__ == "__main__":
|
| 236 |
+
main()
|
infer_single_case.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
单步推理脚本 - 超声提示多标签分类模型
|
| 3 |
+
Single Case Inference for TransMIL + Query2Label Hybrid Model
|
| 4 |
+
|
| 5 |
+
用法:
|
| 6 |
+
# 指定多个图像文件
|
| 7 |
+
python infer_single_case.py --images /path/to/img1.png /path/to/img2.png --threshold 0.5
|
| 8 |
+
|
| 9 |
+
# 指定图像文件夹
|
| 10 |
+
python infer_single_case.py --image_dir /path/to/case_folder/ --threshold 0.5
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
import argparse
|
| 16 |
+
import torch
|
| 17 |
+
import numpy as np
|
| 18 |
+
from PIL import Image
|
| 19 |
+
from torchvision import transforms
|
| 20 |
+
|
| 21 |
+
# 添加当前目录到路径
|
| 22 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 23 |
+
|
| 24 |
+
from models.transmil_q2l import TransMIL_Query2Label_E2E
|
| 25 |
+
|
| 26 |
+
# 17类标签定义
|
| 27 |
+
TARGET_CLASSES = [
|
| 28 |
+
"TI-RADS 1级", "TI-RADS 2级", "TI-RADS 3级", "TI-RADS 4a级",
|
| 29 |
+
"TI-RADS 4b级", "TI-RADS 4c级", "TI-RADS 5级",
|
| 30 |
+
"钙化", "甲亢", "囊肿", "淋巴结", "胶质潴留",
|
| 31 |
+
"弥漫性病变", "结节性甲状腺肿", "桥本氏甲状腺炎", "反应性", "转移性"
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def load_model(checkpoint_path: str, device: torch.device):
|
| 36 |
+
"""加载预训练模型"""
|
| 37 |
+
print(f"Loading model from: {checkpoint_path}")
|
| 38 |
+
|
| 39 |
+
# 初始化模型
|
| 40 |
+
model = TransMIL_Query2Label_E2E(
|
| 41 |
+
num_class=17,
|
| 42 |
+
hidden_dim=512,
|
| 43 |
+
nheads=8,
|
| 44 |
+
num_decoder_layers=2,
|
| 45 |
+
pretrained_resnet=False, # 推理时不需要下载预训练权重
|
| 46 |
+
use_checkpointing=False, # 推理时不需要 checkpointing
|
| 47 |
+
use_ppeg=False
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# 加载权重
|
| 51 |
+
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
| 52 |
+
state_dict = checkpoint['model_state_dict']
|
| 53 |
+
|
| 54 |
+
# 处理 state_dict 键名可能不匹配的问题 (如 module. 前缀)
|
| 55 |
+
new_state_dict = {}
|
| 56 |
+
for k, v in state_dict.items():
|
| 57 |
+
name = k.replace("module.", "")
|
| 58 |
+
new_state_dict[name] = v
|
| 59 |
+
model.load_state_dict(new_state_dict)
|
| 60 |
+
|
| 61 |
+
model.to(device)
|
| 62 |
+
model.eval()
|
| 63 |
+
print("Model loaded successfully!")
|
| 64 |
+
return model
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def preprocess_images(image_paths: list, img_size: int = 224):
|
| 68 |
+
"""预处理图像"""
|
| 69 |
+
transform = transforms.Compose([
|
| 70 |
+
transforms.Resize((img_size, img_size)),
|
| 71 |
+
transforms.ToTensor(),
|
| 72 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 73 |
+
std=[0.229, 0.224, 0.225])
|
| 74 |
+
])
|
| 75 |
+
|
| 76 |
+
images = []
|
| 77 |
+
valid_paths = []
|
| 78 |
+
|
| 79 |
+
for path in image_paths:
|
| 80 |
+
if not os.path.exists(path):
|
| 81 |
+
print(f"Warning: Image not found: {path}")
|
| 82 |
+
continue
|
| 83 |
+
try:
|
| 84 |
+
img = Image.open(path).convert('RGB')
|
| 85 |
+
img_tensor = transform(img)
|
| 86 |
+
images.append(img_tensor)
|
| 87 |
+
valid_paths.append(path)
|
| 88 |
+
except Exception as e:
|
| 89 |
+
print(f"Warning: Failed to load image {path}: {e}")
|
| 90 |
+
continue
|
| 91 |
+
|
| 92 |
+
if len(images) == 0:
|
| 93 |
+
raise ValueError("No valid images found!")
|
| 94 |
+
|
| 95 |
+
# Stack to batch: [N, C, H, W] - 模型期望直接的图像堆叠,不需要额外的batch维度
|
| 96 |
+
images_batch = torch.stack(images, dim=0)
|
| 97 |
+
|
| 98 |
+
return images_batch, valid_paths
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def predict(model, images_batch: torch.Tensor, num_images: int,
|
| 102 |
+
device: torch.device, threshold: float = 0.5):
|
| 103 |
+
"""执行推理"""
|
| 104 |
+
images_batch = images_batch.to(device)
|
| 105 |
+
|
| 106 |
+
with torch.no_grad():
|
| 107 |
+
# Forward pass
|
| 108 |
+
logits = model(images_batch, [num_images])
|
| 109 |
+
probs = torch.sigmoid(logits).cpu().numpy()[0] # [num_class]
|
| 110 |
+
|
| 111 |
+
# 根据阈值获取预测标签
|
| 112 |
+
predictions = (probs >= threshold).astype(int)
|
| 113 |
+
|
| 114 |
+
return probs, predictions
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def format_results(probs: np.ndarray, predictions: np.ndarray, threshold: float):
|
| 118 |
+
"""格式化输出结果"""
|
| 119 |
+
print("\n" + "=" * 60)
|
| 120 |
+
print(" 超声提示多标签分类结果")
|
| 121 |
+
print("=" * 60)
|
| 122 |
+
print(f" 阈值 (Threshold): {threshold}")
|
| 123 |
+
print("-" * 60)
|
| 124 |
+
|
| 125 |
+
# 按概率排序
|
| 126 |
+
sorted_indices = np.argsort(probs)[::-1]
|
| 127 |
+
|
| 128 |
+
print(f"\n{'类别':<20} {'概率':>10} {'预测':>8}")
|
| 129 |
+
print("-" * 40)
|
| 130 |
+
|
| 131 |
+
predicted_labels = []
|
| 132 |
+
for idx in sorted_indices:
|
| 133 |
+
class_name = TARGET_CLASSES[idx]
|
| 134 |
+
prob = probs[idx]
|
| 135 |
+
pred = "✓" if predictions[idx] == 1 else ""
|
| 136 |
+
|
| 137 |
+
# 使用 GBK 编码计算显示宽度
|
| 138 |
+
try:
|
| 139 |
+
display_width = len(class_name.encode('gbk'))
|
| 140 |
+
except:
|
| 141 |
+
display_width = len(class_name) * 2
|
| 142 |
+
|
| 143 |
+
padding = 20 - display_width
|
| 144 |
+
aligned_name = class_name + " " * max(0, padding)
|
| 145 |
+
|
| 146 |
+
print(f"{aligned_name} {prob:>10.4f} {pred:>8}")
|
| 147 |
+
|
| 148 |
+
if predictions[idx] == 1:
|
| 149 |
+
predicted_labels.append(class_name)
|
| 150 |
+
|
| 151 |
+
print("\n" + "=" * 60)
|
| 152 |
+
print(" 预测标签汇总")
|
| 153 |
+
print("=" * 60)
|
| 154 |
+
|
| 155 |
+
if predicted_labels:
|
| 156 |
+
for label in predicted_labels:
|
| 157 |
+
print(f" • {label}")
|
| 158 |
+
else:
|
| 159 |
+
print(" 无预测标签(所有类别概率均低于阈值)")
|
| 160 |
+
|
| 161 |
+
print("=" * 60 + "\n")
|
| 162 |
+
|
| 163 |
+
return predicted_labels
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def main():
|
| 167 |
+
parser = argparse.ArgumentParser(description='超声提示多标签分类 - 单步推理')
|
| 168 |
+
parser.add_argument('--images', nargs='*', default=None,
|
| 169 |
+
help='图像路径列表 (支持多个图像)')
|
| 170 |
+
parser.add_argument('--image_dir', type=str, default=None,
|
| 171 |
+
help='图像文件夹路径 (自动加载文件夹内所有图像)')
|
| 172 |
+
parser.add_argument('--checkpoint', type=str,
|
| 173 |
+
default='checkpoints/checkpoint_best.pth',
|
| 174 |
+
help='模型权重路径')
|
| 175 |
+
parser.add_argument('--threshold', type=float, default=0.5,
|
| 176 |
+
help='分类阈值 (default: 0.5)')
|
| 177 |
+
parser.add_argument('--device', type=str, default='auto',
|
| 178 |
+
help='设备: auto, cuda, cpu')
|
| 179 |
+
|
| 180 |
+
args = parser.parse_args()
|
| 181 |
+
|
| 182 |
+
# 收集图像路径
|
| 183 |
+
image_paths = []
|
| 184 |
+
|
| 185 |
+
# 从 --images 参数收集
|
| 186 |
+
if args.images:
|
| 187 |
+
image_paths.extend(args.images)
|
| 188 |
+
|
| 189 |
+
# 从 --image_dir 参数收集
|
| 190 |
+
if args.image_dir:
|
| 191 |
+
if not os.path.isdir(args.image_dir):
|
| 192 |
+
print(f"Error: Image directory not found: {args.image_dir}")
|
| 193 |
+
sys.exit(1)
|
| 194 |
+
|
| 195 |
+
# 支持的图像格式
|
| 196 |
+
image_extensions = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif'}
|
| 197 |
+
|
| 198 |
+
for filename in sorted(os.listdir(args.image_dir)):
|
| 199 |
+
ext = os.path.splitext(filename)[1].lower()
|
| 200 |
+
if ext in image_extensions:
|
| 201 |
+
image_paths.append(os.path.join(args.image_dir, filename))
|
| 202 |
+
|
| 203 |
+
print(f"Found {len(image_paths)} images in {args.image_dir}")
|
| 204 |
+
|
| 205 |
+
# 检查是否有图像输入
|
| 206 |
+
if not image_paths:
|
| 207 |
+
print("Error: No images specified. Use --images or --image_dir")
|
| 208 |
+
parser.print_help()
|
| 209 |
+
sys.exit(1)
|
| 210 |
+
|
| 211 |
+
# 设置设备
|
| 212 |
+
if args.device == 'auto':
|
| 213 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 214 |
+
else:
|
| 215 |
+
device = torch.device(args.device)
|
| 216 |
+
print(f"Using device: {device}")
|
| 217 |
+
|
| 218 |
+
# 处理相对路径
|
| 219 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 220 |
+
checkpoint_path = args.checkpoint
|
| 221 |
+
if not os.path.isabs(checkpoint_path):
|
| 222 |
+
checkpoint_path = os.path.join(script_dir, checkpoint_path)
|
| 223 |
+
|
| 224 |
+
# 加载模型
|
| 225 |
+
model = load_model(checkpoint_path, device)
|
| 226 |
+
|
| 227 |
+
# 预处理图像
|
| 228 |
+
print(f"\nProcessing {len(image_paths)} image(s)...")
|
| 229 |
+
images_batch, valid_paths = preprocess_images(image_paths)
|
| 230 |
+
print(f"Successfully loaded {len(valid_paths)} image(s)")
|
| 231 |
+
|
| 232 |
+
# 推理
|
| 233 |
+
probs, predictions = predict(model, images_batch, len(valid_paths),
|
| 234 |
+
device, args.threshold)
|
| 235 |
+
|
| 236 |
+
# 输出结果
|
| 237 |
+
predicted_labels = format_results(probs, predictions, args.threshold)
|
| 238 |
+
|
| 239 |
+
# 返回预测标签列表(供程序调用)
|
| 240 |
+
return predicted_labels
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
if __name__ == "__main__":
|
| 244 |
+
main()
|
models/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Thyroid Ultrasound Hint Multi-Label Classification
|
| 2 |
+
|
| 3 |
+
from models.transmil_q2l import TransMIL_Query2Label_E2E
|
| 4 |
+
from models.aslloss import AsymmetricLoss, AsymmetricLossOptimized
|
| 5 |
+
from models.transformer import TransformerDecoder, TransformerDecoderLayer
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
'TransMIL_Query2Label_E2E',
|
| 9 |
+
'AsymmetricLoss',
|
| 10 |
+
'AsymmetricLossOptimized',
|
| 11 |
+
'TransformerDecoder',
|
| 12 |
+
'TransformerDecoderLayer',
|
| 13 |
+
]
|
models/aslloss.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Most borrow from: https://github.com/Alibaba-MIIL/ASL
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class AsymmetricLoss(nn.Module):
|
| 9 |
+
def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
|
| 10 |
+
super(AsymmetricLoss, self).__init__()
|
| 11 |
+
|
| 12 |
+
self.gamma_neg = gamma_neg
|
| 13 |
+
self.gamma_pos = gamma_pos
|
| 14 |
+
self.clip = clip
|
| 15 |
+
self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
|
| 16 |
+
self.eps = eps
|
| 17 |
+
|
| 18 |
+
def forward(self, x, y):
|
| 19 |
+
""""
|
| 20 |
+
Parameters
|
| 21 |
+
----------
|
| 22 |
+
x: input logits
|
| 23 |
+
y: targets (multi-label binarized vector)
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
# Calculating Probabilities
|
| 27 |
+
x_sigmoid = torch.sigmoid(x)
|
| 28 |
+
xs_pos = x_sigmoid
|
| 29 |
+
xs_neg = 1 - x_sigmoid
|
| 30 |
+
|
| 31 |
+
# Asymmetric Clipping
|
| 32 |
+
if self.clip is not None and self.clip > 0:
|
| 33 |
+
xs_neg = (xs_neg + self.clip).clamp(max=1)
|
| 34 |
+
|
| 35 |
+
# Basic CE calculation
|
| 36 |
+
los_pos = y * torch.log(xs_pos.clamp(min=self.eps, max=1-self.eps))
|
| 37 |
+
los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps, max=1-self.eps))
|
| 38 |
+
loss = los_pos + los_neg
|
| 39 |
+
|
| 40 |
+
# Asymmetric Focusing
|
| 41 |
+
if self.gamma_neg > 0 or self.gamma_pos > 0:
|
| 42 |
+
if self.disable_torch_grad_focal_loss:
|
| 43 |
+
torch._C.set_grad_enabled(False)
|
| 44 |
+
pt0 = xs_pos * y
|
| 45 |
+
pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
|
| 46 |
+
pt = pt0 + pt1
|
| 47 |
+
one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
|
| 48 |
+
one_sided_w = torch.pow(1 - pt, one_sided_gamma)
|
| 49 |
+
if self.disable_torch_grad_focal_loss:
|
| 50 |
+
torch._C.set_grad_enabled(True)
|
| 51 |
+
loss *= one_sided_w
|
| 52 |
+
|
| 53 |
+
return -loss.sum()
|
| 54 |
+
|
| 55 |
+
class AsymmetricLossOptimized(nn.Module):
|
| 56 |
+
''' Notice - optimized version, minimizes memory allocation and gpu uploading,
|
| 57 |
+
favors inplace operations'''
|
| 58 |
+
|
| 59 |
+
def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-5, disable_torch_grad_focal_loss=False):
|
| 60 |
+
super(AsymmetricLossOptimized, self).__init__()
|
| 61 |
+
|
| 62 |
+
self.gamma_neg = gamma_neg
|
| 63 |
+
self.gamma_pos = gamma_pos
|
| 64 |
+
self.clip = clip
|
| 65 |
+
self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
|
| 66 |
+
self.eps = eps
|
| 67 |
+
|
| 68 |
+
self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None
|
| 69 |
+
|
| 70 |
+
def forward(self, x, y):
|
| 71 |
+
""""
|
| 72 |
+
Parameters
|
| 73 |
+
----------
|
| 74 |
+
x: input logits
|
| 75 |
+
y: targets (multi-label binarized vector)
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
self.targets = y
|
| 79 |
+
self.anti_targets = 1 - y
|
| 80 |
+
|
| 81 |
+
# Calculating Probabilities
|
| 82 |
+
self.xs_pos = torch.sigmoid(x)
|
| 83 |
+
self.xs_neg = 1.0 - self.xs_pos
|
| 84 |
+
|
| 85 |
+
# Asymmetric Clipping
|
| 86 |
+
if self.clip is not None and self.clip > 0:
|
| 87 |
+
self.xs_neg.add_(self.clip).clamp_(max=1)
|
| 88 |
+
|
| 89 |
+
# Basic CE calculation
|
| 90 |
+
self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
|
| 91 |
+
self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))
|
| 92 |
+
|
| 93 |
+
# Asymmetric Focusing
|
| 94 |
+
if self.gamma_neg > 0 or self.gamma_pos > 0:
|
| 95 |
+
if self.disable_torch_grad_focal_loss:
|
| 96 |
+
with torch.no_grad():
|
| 97 |
+
# if self.disable_torch_grad_focal_loss:
|
| 98 |
+
# torch._C.set_grad_enabled(False)
|
| 99 |
+
self.xs_pos = self.xs_pos * self.targets
|
| 100 |
+
self.xs_neg = self.xs_neg * self.anti_targets
|
| 101 |
+
self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
|
| 102 |
+
self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
|
| 103 |
+
# if self.disable_torch_grad_focal_loss:
|
| 104 |
+
# torch._C.set_grad_enabled(True)
|
| 105 |
+
self.loss *= self.asymmetric_w
|
| 106 |
+
else:
|
| 107 |
+
self.xs_pos = self.xs_pos * self.targets
|
| 108 |
+
self.xs_neg = self.xs_neg * self.anti_targets
|
| 109 |
+
self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
|
| 110 |
+
self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
|
| 111 |
+
self.loss *= self.asymmetric_w
|
| 112 |
+
_loss = - self.loss.sum() / x.size(0)
|
| 113 |
+
_loss = _loss / y.size(1) * 1000
|
| 114 |
+
|
| 115 |
+
return _loss
|
models/transformer.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
"""
|
| 3 |
+
Q2L Transformer class.
|
| 4 |
+
|
| 5 |
+
Most borrow from DETR except:
|
| 6 |
+
* remove self-attention by default.
|
| 7 |
+
|
| 8 |
+
Copy-paste from torch.nn.Transformer with modifications:
|
| 9 |
+
* positional encodings are passed in MHattention
|
| 10 |
+
* extra LN at the end of encoder is removed
|
| 11 |
+
* decoder returns a stack of activations from all decoding layers
|
| 12 |
+
* using modified multihead attention from nn_multiheadattention.py
|
| 13 |
+
"""
|
| 14 |
+
import copy
|
| 15 |
+
from typing import Optional, List
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from torch import nn, Tensor
|
| 20 |
+
from torch.nn import MultiheadAttention
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Transformer(nn.Module):
|
| 25 |
+
|
| 26 |
+
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
|
| 27 |
+
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
|
| 28 |
+
activation="relu", normalize_before=False,
|
| 29 |
+
return_intermediate_dec=False,
|
| 30 |
+
rm_self_attn_dec=True, rm_first_self_attn=True,
|
| 31 |
+
):
|
| 32 |
+
super().__init__()
|
| 33 |
+
|
| 34 |
+
self.num_encoder_layers = num_encoder_layers
|
| 35 |
+
if num_decoder_layers > 0:
|
| 36 |
+
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
|
| 37 |
+
dropout, activation, normalize_before)
|
| 38 |
+
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
| 39 |
+
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
| 40 |
+
|
| 41 |
+
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
|
| 42 |
+
dropout, activation, normalize_before)
|
| 43 |
+
decoder_norm = nn.LayerNorm(d_model)
|
| 44 |
+
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
|
| 45 |
+
return_intermediate=return_intermediate_dec)
|
| 46 |
+
|
| 47 |
+
self._reset_parameters()
|
| 48 |
+
|
| 49 |
+
self.d_model = d_model
|
| 50 |
+
self.nhead = nhead
|
| 51 |
+
self.rm_self_attn_dec = rm_self_attn_dec
|
| 52 |
+
self.rm_first_self_attn = rm_first_self_attn
|
| 53 |
+
|
| 54 |
+
if self.rm_self_attn_dec or self.rm_first_self_attn:
|
| 55 |
+
self.rm_self_attn_dec_func()
|
| 56 |
+
|
| 57 |
+
# self.debug_mode = False
|
| 58 |
+
# self.set_debug_mode(self.debug_mode)
|
| 59 |
+
|
| 60 |
+
def rm_self_attn_dec_func(self):
|
| 61 |
+
total_modifie_layer_num = 0
|
| 62 |
+
rm_list = []
|
| 63 |
+
for idx, layer in enumerate(self.decoder.layers):
|
| 64 |
+
if idx == 0 and not self.rm_first_self_attn:
|
| 65 |
+
continue
|
| 66 |
+
if idx != 0 and not self.rm_self_attn_dec:
|
| 67 |
+
continue
|
| 68 |
+
|
| 69 |
+
layer.omit_selfattn = True
|
| 70 |
+
del layer.self_attn
|
| 71 |
+
del layer.dropout1
|
| 72 |
+
del layer.norm1
|
| 73 |
+
|
| 74 |
+
total_modifie_layer_num += 1
|
| 75 |
+
rm_list.append(idx)
|
| 76 |
+
# remove some self-attention layer
|
| 77 |
+
# print("rm {} layer: {}".format(total_modifie_layer_num, rm_list))
|
| 78 |
+
|
| 79 |
+
def set_debug_mode(self, status):
|
| 80 |
+
print("set debug mode to {}!!!".format(status))
|
| 81 |
+
self.debug_mode = status
|
| 82 |
+
if hasattr(self, 'encoder'):
|
| 83 |
+
for idx, layer in enumerate(self.encoder.layers):
|
| 84 |
+
layer.debug_mode = status
|
| 85 |
+
layer.debug_name = str(idx)
|
| 86 |
+
if hasattr(self, 'decoder'):
|
| 87 |
+
for idx, layer in enumerate(self.decoder.layers):
|
| 88 |
+
layer.debug_mode = status
|
| 89 |
+
layer.debug_name = str(idx)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _reset_parameters(self):
|
| 93 |
+
for p in self.parameters():
|
| 94 |
+
if p.dim() > 1:
|
| 95 |
+
nn.init.xavier_uniform_(p)
|
| 96 |
+
|
| 97 |
+
def forward(self, src, query_embed, pos_embed, mask=None):
|
| 98 |
+
# flatten NxCxHxW to HWxNxC
|
| 99 |
+
bs, c, h, w = src.shape
|
| 100 |
+
src = src.flatten(2).permute(2, 0, 1)
|
| 101 |
+
pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
|
| 102 |
+
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
| 103 |
+
if mask is not None:
|
| 104 |
+
mask = mask.flatten(1)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
if self.num_encoder_layers > 0:
|
| 108 |
+
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
|
| 109 |
+
else:
|
| 110 |
+
memory = src
|
| 111 |
+
|
| 112 |
+
tgt = torch.zeros_like(query_embed)
|
| 113 |
+
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
|
| 114 |
+
pos=pos_embed, query_pos=query_embed)
|
| 115 |
+
|
| 116 |
+
return hs.transpose(1, 2), memory[:h*w].permute(1, 2, 0).view(bs, c, h, w)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class TransformerEncoder(nn.Module):
|
| 120 |
+
|
| 121 |
+
def __init__(self, encoder_layer, num_layers, norm=None):
|
| 122 |
+
super().__init__()
|
| 123 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
| 124 |
+
self.num_layers = num_layers
|
| 125 |
+
self.norm = norm
|
| 126 |
+
|
| 127 |
+
def forward(self, src,
|
| 128 |
+
mask: Optional[Tensor] = None,
|
| 129 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
| 130 |
+
pos: Optional[Tensor] = None):
|
| 131 |
+
output = src
|
| 132 |
+
|
| 133 |
+
for layer in self.layers:
|
| 134 |
+
output = layer(output, src_mask=mask,
|
| 135 |
+
src_key_padding_mask=src_key_padding_mask, pos=pos)
|
| 136 |
+
|
| 137 |
+
if self.norm is not None:
|
| 138 |
+
output = self.norm(output)
|
| 139 |
+
|
| 140 |
+
return output
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class TransformerDecoder(nn.Module):
|
| 144 |
+
|
| 145 |
+
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
| 146 |
+
super().__init__()
|
| 147 |
+
self.layers = _get_clones(decoder_layer, num_layers)
|
| 148 |
+
self.num_layers = num_layers
|
| 149 |
+
self.norm = norm
|
| 150 |
+
self.return_intermediate = return_intermediate
|
| 151 |
+
|
| 152 |
+
def forward(self, tgt, memory,
|
| 153 |
+
tgt_mask: Optional[Tensor] = None,
|
| 154 |
+
memory_mask: Optional[Tensor] = None,
|
| 155 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
| 156 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
| 157 |
+
pos: Optional[Tensor] = None,
|
| 158 |
+
query_pos: Optional[Tensor] = None):
|
| 159 |
+
output = tgt
|
| 160 |
+
|
| 161 |
+
intermediate = []
|
| 162 |
+
|
| 163 |
+
for layer in self.layers:
|
| 164 |
+
output = layer(output, memory, tgt_mask=tgt_mask,
|
| 165 |
+
memory_mask=memory_mask,
|
| 166 |
+
tgt_key_padding_mask=tgt_key_padding_mask,
|
| 167 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
| 168 |
+
pos=pos, query_pos=query_pos)
|
| 169 |
+
if self.return_intermediate:
|
| 170 |
+
intermediate.append(self.norm(output))
|
| 171 |
+
|
| 172 |
+
if self.norm is not None:
|
| 173 |
+
output = self.norm(output)
|
| 174 |
+
if self.return_intermediate:
|
| 175 |
+
intermediate.pop()
|
| 176 |
+
intermediate.append(output)
|
| 177 |
+
|
| 178 |
+
if self.return_intermediate:
|
| 179 |
+
return torch.stack(intermediate)
|
| 180 |
+
|
| 181 |
+
return output.unsqueeze(0)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class TransformerEncoderLayer(nn.Module):
|
| 185 |
+
|
| 186 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
| 187 |
+
activation="relu", normalize_before=False):
|
| 188 |
+
super().__init__()
|
| 189 |
+
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
|
| 190 |
+
# Implementation of Feedforward model
|
| 191 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
| 192 |
+
self.dropout = nn.Dropout(dropout)
|
| 193 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
| 194 |
+
|
| 195 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 196 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 197 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 198 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 199 |
+
|
| 200 |
+
self.activation = _get_activation_fn(activation)
|
| 201 |
+
self.normalize_before = normalize_before
|
| 202 |
+
|
| 203 |
+
self.debug_mode = False
|
| 204 |
+
self.debug_name = None
|
| 205 |
+
|
| 206 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
| 207 |
+
return tensor if pos is None else tensor + pos
|
| 208 |
+
|
| 209 |
+
def forward_post(self,
|
| 210 |
+
src,
|
| 211 |
+
src_mask: Optional[Tensor] = None,
|
| 212 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
| 213 |
+
pos: Optional[Tensor] = None):
|
| 214 |
+
q = k = self.with_pos_embed(src, pos)
|
| 215 |
+
src2, corr = self.self_attn(q, k, value=src, attn_mask=src_mask,
|
| 216 |
+
key_padding_mask=src_key_padding_mask)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
src = src + self.dropout1(src2)
|
| 220 |
+
src = self.norm1(src)
|
| 221 |
+
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
| 222 |
+
src = src + self.dropout2(src2)
|
| 223 |
+
src = self.norm2(src)
|
| 224 |
+
return src
|
| 225 |
+
|
| 226 |
+
def forward_pre(self, src,
|
| 227 |
+
src_mask: Optional[Tensor] = None,
|
| 228 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
| 229 |
+
pos: Optional[Tensor] = None):
|
| 230 |
+
src2 = self.norm1(src)
|
| 231 |
+
q = k = self.with_pos_embed(src2, pos)
|
| 232 |
+
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
|
| 233 |
+
key_padding_mask=src_key_padding_mask)[0]
|
| 234 |
+
|
| 235 |
+
src = src + self.dropout1(src2)
|
| 236 |
+
src2 = self.norm2(src)
|
| 237 |
+
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
| 238 |
+
src = src + self.dropout2(src2)
|
| 239 |
+
return src
|
| 240 |
+
|
| 241 |
+
def forward(self, src,
|
| 242 |
+
src_mask: Optional[Tensor] = None,
|
| 243 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
| 244 |
+
pos: Optional[Tensor] = None):
|
| 245 |
+
if self.normalize_before:
|
| 246 |
+
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
| 247 |
+
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class TransformerDecoderLayer(nn.Module):
|
| 251 |
+
|
| 252 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
| 253 |
+
activation="relu", normalize_before=False):
|
| 254 |
+
super().__init__()
|
| 255 |
+
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
|
| 256 |
+
self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
|
| 257 |
+
# Implementation of Feedforward model
|
| 258 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
| 259 |
+
self.dropout = nn.Dropout(dropout)
|
| 260 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
| 261 |
+
|
| 262 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 263 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 264 |
+
self.norm3 = nn.LayerNorm(d_model)
|
| 265 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 266 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 267 |
+
self.dropout3 = nn.Dropout(dropout)
|
| 268 |
+
|
| 269 |
+
self.activation = _get_activation_fn(activation)
|
| 270 |
+
self.normalize_before = normalize_before
|
| 271 |
+
|
| 272 |
+
self.debug_mode = False
|
| 273 |
+
self.debug_name = None
|
| 274 |
+
self.omit_selfattn = False
|
| 275 |
+
|
| 276 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
| 277 |
+
return tensor if pos is None else tensor + pos
|
| 278 |
+
|
| 279 |
+
def forward_post(self, tgt, memory,
|
| 280 |
+
tgt_mask: Optional[Tensor] = None,
|
| 281 |
+
memory_mask: Optional[Tensor] = None,
|
| 282 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
| 283 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
| 284 |
+
pos: Optional[Tensor] = None,
|
| 285 |
+
query_pos: Optional[Tensor] = None):
|
| 286 |
+
q = k = self.with_pos_embed(tgt, query_pos)
|
| 287 |
+
|
| 288 |
+
if not self.omit_selfattn:
|
| 289 |
+
tgt2, sim_mat_1 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
|
| 290 |
+
key_padding_mask=tgt_key_padding_mask)
|
| 291 |
+
|
| 292 |
+
tgt = tgt + self.dropout1(tgt2)
|
| 293 |
+
tgt = self.norm1(tgt)
|
| 294 |
+
|
| 295 |
+
tgt2, sim_mat_2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
|
| 296 |
+
key=self.with_pos_embed(memory, pos),
|
| 297 |
+
value=memory, attn_mask=memory_mask,
|
| 298 |
+
key_padding_mask=memory_key_padding_mask)
|
| 299 |
+
|
| 300 |
+
tgt = tgt + self.dropout2(tgt2)
|
| 301 |
+
tgt = self.norm2(tgt)
|
| 302 |
+
|
| 303 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
| 304 |
+
tgt = tgt + self.dropout3(tgt2)
|
| 305 |
+
tgt = self.norm3(tgt)
|
| 306 |
+
return tgt
|
| 307 |
+
|
| 308 |
+
def forward_pre(self, tgt, memory,
|
| 309 |
+
tgt_mask: Optional[Tensor] = None,
|
| 310 |
+
memory_mask: Optional[Tensor] = None,
|
| 311 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
| 312 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
| 313 |
+
pos: Optional[Tensor] = None,
|
| 314 |
+
query_pos: Optional[Tensor] = None):
|
| 315 |
+
tgt2 = self.norm1(tgt)
|
| 316 |
+
q = k = self.with_pos_embed(tgt2, query_pos)
|
| 317 |
+
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
| 318 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
| 319 |
+
|
| 320 |
+
tgt = tgt + self.dropout1(tgt2)
|
| 321 |
+
tgt2 = self.norm2(tgt)
|
| 322 |
+
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
|
| 323 |
+
key=self.with_pos_embed(memory, pos),
|
| 324 |
+
value=memory, attn_mask=memory_mask,
|
| 325 |
+
key_padding_mask=memory_key_padding_mask)[0]
|
| 326 |
+
|
| 327 |
+
tgt = tgt + self.dropout2(tgt2)
|
| 328 |
+
tgt2 = self.norm3(tgt)
|
| 329 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
| 330 |
+
tgt = tgt + self.dropout3(tgt2)
|
| 331 |
+
return tgt
|
| 332 |
+
|
| 333 |
+
def forward(self, tgt, memory,
|
| 334 |
+
tgt_mask: Optional[Tensor] = None,
|
| 335 |
+
memory_mask: Optional[Tensor] = None,
|
| 336 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
| 337 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
| 338 |
+
pos: Optional[Tensor] = None,
|
| 339 |
+
query_pos: Optional[Tensor] = None):
|
| 340 |
+
if self.normalize_before:
|
| 341 |
+
return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
|
| 342 |
+
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
|
| 343 |
+
return self.forward_post(tgt, memory, tgt_mask, memory_mask,
|
| 344 |
+
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def _get_clones(module, N):
|
| 348 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def build_transformer(args):
|
| 352 |
+
return Transformer(
|
| 353 |
+
d_model=args.hidden_dim,
|
| 354 |
+
dropout=args.dropout,
|
| 355 |
+
nhead=args.nheads,
|
| 356 |
+
dim_feedforward=args.dim_feedforward,
|
| 357 |
+
num_encoder_layers=args.enc_layers,
|
| 358 |
+
num_decoder_layers=args.dec_layers,
|
| 359 |
+
normalize_before=args.pre_norm,
|
| 360 |
+
return_intermediate_dec=False,
|
| 361 |
+
rm_self_attn_dec=not args.keep_other_self_attn_dec,
|
| 362 |
+
rm_first_self_attn=not args.keep_first_self_attn_dec,
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def _get_activation_fn(activation):
|
| 367 |
+
"""Return an activation function given a string"""
|
| 368 |
+
if activation == "relu":
|
| 369 |
+
return F.relu
|
| 370 |
+
if activation == "gelu":
|
| 371 |
+
return F.gelu
|
| 372 |
+
if activation == "glu":
|
| 373 |
+
return F.glu
|
| 374 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
models/transmil_q2l.py
ADDED
|
@@ -0,0 +1,589 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hybrid TransMIL + Query2Label Architecture
|
| 3 |
+
|
| 4 |
+
Combines:
|
| 5 |
+
- TransMIL's instance-level feature aggregation (with Nystrom attention)
|
| 6 |
+
- Query2Label's learnable label queries with cross-attention decoder
|
| 7 |
+
- End-to-end training with ResNet-50 backbone
|
| 8 |
+
|
| 9 |
+
Key Innovation: Extract sequence features from TransMIL BEFORE CLS aggregation,
|
| 10 |
+
allowing Q2L label queries to cross-attend across all ultrasound images per case.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
|
| 16 |
+
# Add models directory to path for local imports
|
| 17 |
+
_models_dir = os.path.dirname(os.path.abspath(__file__))
|
| 18 |
+
if _models_dir not in sys.path:
|
| 19 |
+
sys.path.insert(0, _models_dir)
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
import torchvision
|
| 25 |
+
import numpy as np
|
| 26 |
+
from torch.utils.checkpoint import checkpoint_sequential
|
| 27 |
+
|
| 28 |
+
# Import TransMIL components (from nystrom-attention package)
|
| 29 |
+
from nystrom_attention import NystromAttention
|
| 30 |
+
|
| 31 |
+
# Import Q2L Transformer components from local transformer.py
|
| 32 |
+
try:
|
| 33 |
+
from models.transformer import TransformerDecoder, TransformerDecoderLayer
|
| 34 |
+
except ImportError:
|
| 35 |
+
try:
|
| 36 |
+
from transformer import TransformerDecoder, TransformerDecoderLayer
|
| 37 |
+
except ImportError:
|
| 38 |
+
print("Warning: Could not import Q2L Transformer components.")
|
| 39 |
+
|
| 40 |
+
# ============================================================================
|
| 41 |
+
# TransMIL Components (Modified)
|
| 42 |
+
# ============================================================================
|
| 43 |
+
|
| 44 |
+
class TransLayer(nn.Module):
|
| 45 |
+
"""Transformer layer with Nystrom attention (from TransMIL)"""
|
| 46 |
+
|
| 47 |
+
def __init__(self, norm_layer=nn.LayerNorm, dim=512):
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.norm = norm_layer(dim)
|
| 50 |
+
self.attn = NystromAttention(
|
| 51 |
+
dim=dim,
|
| 52 |
+
dim_head=dim // 8,
|
| 53 |
+
heads=8,
|
| 54 |
+
num_landmarks=dim // 2,
|
| 55 |
+
pinv_iterations=6,
|
| 56 |
+
residual=True,
|
| 57 |
+
dropout=0.1
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
def forward(self, x):
|
| 61 |
+
x = x + self.attn(self.norm(x))
|
| 62 |
+
return x
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class TransMILFeatureExtractor(nn.Module):
|
| 66 |
+
"""
|
| 67 |
+
Modified TransMIL that outputs sequence features instead of aggregated CLS token.
|
| 68 |
+
|
| 69 |
+
Based on TransMIL.py but extracts features BEFORE CLS aggregation (line 83 output).
|
| 70 |
+
Uses learned 1D position encoding instead of PPEG for simplicity.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
input_dim: Dimension of input features (2048 for ResNet-50)
|
| 74 |
+
hidden_dim: Dimension of hidden features (512 default)
|
| 75 |
+
use_ppeg: Whether to use PPEG (2D positional encoding) or learned 1D encoding
|
| 76 |
+
max_seq_len: Maximum sequence length for position encoding
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
def __init__(self, input_dim=2048, hidden_dim=512, use_ppeg=False, max_seq_len=1024):
|
| 80 |
+
super().__init__()
|
| 81 |
+
|
| 82 |
+
# Feature projection (TransMIL line 50)
|
| 83 |
+
self.fc1 = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
|
| 84 |
+
|
| 85 |
+
# Learnable CLS token (TransMIL line 51)
|
| 86 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
|
| 87 |
+
|
| 88 |
+
# Transformer layers (TransMIL lines 53-54)
|
| 89 |
+
self.layer1 = TransLayer(dim=hidden_dim)
|
| 90 |
+
self.layer2 = TransLayer(dim=hidden_dim)
|
| 91 |
+
|
| 92 |
+
# LayerNorm (TransMIL line 55)
|
| 93 |
+
self.norm = nn.LayerNorm(hidden_dim)
|
| 94 |
+
|
| 95 |
+
# Position encoding
|
| 96 |
+
self.use_ppeg = use_ppeg
|
| 97 |
+
if not use_ppeg:
|
| 98 |
+
# Learned 1D position encoding (simpler than PPEG)
|
| 99 |
+
self.pos_embedding = nn.Parameter(torch.randn(1, max_seq_len, hidden_dim))
|
| 100 |
+
else:
|
| 101 |
+
# PPEG: Position-aware Patch Embedding Generator (requires 2D reshaping)
|
| 102 |
+
self.pos_layer = PPEG(dim=hidden_dim)
|
| 103 |
+
|
| 104 |
+
def forward(self, features, mask=None):
|
| 105 |
+
"""
|
| 106 |
+
Args:
|
| 107 |
+
features: [B, N, input_dim] - Instance features (e.g., from ResNet-50)
|
| 108 |
+
mask: [B, N] - Padding mask (True = valid instance, False = padded)
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
seq_features: [B, 1+N, hidden_dim] - Sequence features (CLS + instances)
|
| 112 |
+
attn_mask: [B, 1+N] - Attention mask for decoder
|
| 113 |
+
"""
|
| 114 |
+
B, N, _ = features.shape
|
| 115 |
+
|
| 116 |
+
# Project features (TransMIL line 63)
|
| 117 |
+
h = self.fc1(features) # [B, N, hidden_dim]
|
| 118 |
+
|
| 119 |
+
# Handle PPEG padding if needed
|
| 120 |
+
if self.use_ppeg:
|
| 121 |
+
# Pad to nearest square for PPEG (TransMIL lines 65-69)
|
| 122 |
+
H = h.shape[1]
|
| 123 |
+
_H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
|
| 124 |
+
add_length = _H * _W - H
|
| 125 |
+
if add_length > 0:
|
| 126 |
+
h = torch.cat([h, h[:, :add_length, :]], dim=1) # [B, N_padded, hidden_dim]
|
| 127 |
+
|
| 128 |
+
# Update mask
|
| 129 |
+
if mask is not None:
|
| 130 |
+
pad_mask = torch.zeros(B, add_length, dtype=torch.bool, device=mask.device)
|
| 131 |
+
mask = torch.cat([mask, pad_mask], dim=1)
|
| 132 |
+
|
| 133 |
+
# Add CLS token (TransMIL lines 72-74)
|
| 134 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
| 135 |
+
h = torch.cat([cls_tokens, h], dim=1) # [B, 1+N, hidden_dim]
|
| 136 |
+
|
| 137 |
+
# Update mask to include CLS (always valid)
|
| 138 |
+
if mask is not None:
|
| 139 |
+
cls_mask = torch.ones(B, 1, dtype=torch.bool, device=mask.device)
|
| 140 |
+
attn_mask = torch.cat([cls_mask, mask], dim=1) # [B, 1+N]
|
| 141 |
+
else:
|
| 142 |
+
attn_mask = torch.ones(B, h.shape[1], dtype=torch.bool, device=h.device)
|
| 143 |
+
|
| 144 |
+
# TransLayer 1 (TransMIL line 77)
|
| 145 |
+
h = self.layer1(h) # [B, 1+N, hidden_dim]
|
| 146 |
+
|
| 147 |
+
# Position encoding
|
| 148 |
+
if self.use_ppeg:
|
| 149 |
+
# PPEG (TransMIL line 80)
|
| 150 |
+
h = self.pos_layer(h, _H, _W)
|
| 151 |
+
else:
|
| 152 |
+
# Learned 1D position encoding
|
| 153 |
+
seq_len = h.shape[1]
|
| 154 |
+
h = h + self.pos_embedding[:, :seq_len, :]
|
| 155 |
+
|
| 156 |
+
# TransLayer 2 (TransMIL line 83)
|
| 157 |
+
h = self.layer2(h) # [B, 1+N, hidden_dim]
|
| 158 |
+
|
| 159 |
+
# LayerNorm (TransMIL line 86, but keep full sequence)
|
| 160 |
+
h = self.norm(h) # [B, 1+N, hidden_dim]
|
| 161 |
+
|
| 162 |
+
# CRITICAL: Return full sequence, not just CLS token
|
| 163 |
+
return h, attn_mask
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class PPEG(nn.Module):
|
| 167 |
+
"""
|
| 168 |
+
Position-aware Patch Embedding Generator (from TransMIL)
|
| 169 |
+
Uses 2D depthwise convolutions to inject spatial positional information.
|
| 170 |
+
"""
|
| 171 |
+
|
| 172 |
+
def __init__(self, dim=512):
|
| 173 |
+
super().__init__()
|
| 174 |
+
self.proj = nn.Conv2d(dim, dim, 7, 1, 7 // 2, groups=dim)
|
| 175 |
+
self.proj1 = nn.Conv2d(dim, dim, 5, 1, 5 // 2, groups=dim)
|
| 176 |
+
self.proj2 = nn.Conv2d(dim, dim, 3, 1, 3 // 2, groups=dim)
|
| 177 |
+
|
| 178 |
+
def forward(self, x, H, W):
|
| 179 |
+
"""
|
| 180 |
+
Args:
|
| 181 |
+
x: [B, 1+N, C] - Token sequence (CLS + instances)
|
| 182 |
+
H, W: Grid dimensions (H * W >= N)
|
| 183 |
+
"""
|
| 184 |
+
B, _, C = x.shape
|
| 185 |
+
|
| 186 |
+
# Separate CLS token and feature tokens
|
| 187 |
+
cls_token, feat_token = x[:, 0], x[:, 1:]
|
| 188 |
+
|
| 189 |
+
# Reshape to 2D grid
|
| 190 |
+
cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
|
| 191 |
+
|
| 192 |
+
# Apply 2D convolutions
|
| 193 |
+
x = self.proj(cnn_feat) + cnn_feat + self.proj1(cnn_feat) + self.proj2(cnn_feat)
|
| 194 |
+
|
| 195 |
+
# Flatten back to sequence
|
| 196 |
+
x = x.flatten(2).transpose(1, 2)
|
| 197 |
+
|
| 198 |
+
# Concatenate CLS token back
|
| 199 |
+
x = torch.cat((cls_token.unsqueeze(1), x), dim=1)
|
| 200 |
+
|
| 201 |
+
return x
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# ============================================================================
|
| 205 |
+
# Query2Label Components (Adapted for Sequences)
|
| 206 |
+
# ============================================================================
|
| 207 |
+
|
| 208 |
+
class GroupWiseLinear(nn.Module):
|
| 209 |
+
"""
|
| 210 |
+
Group-wise linear layer for per-class classification (from Q2L).
|
| 211 |
+
Applies a separate linear transformation for each class.
|
| 212 |
+
"""
|
| 213 |
+
|
| 214 |
+
def __init__(self, num_class, hidden_dim, bias=True):
|
| 215 |
+
super().__init__()
|
| 216 |
+
self.num_class = num_class
|
| 217 |
+
self.hidden_dim = hidden_dim
|
| 218 |
+
self.bias = bias
|
| 219 |
+
|
| 220 |
+
self.W = nn.Parameter(torch.Tensor(1, num_class, hidden_dim))
|
| 221 |
+
if bias:
|
| 222 |
+
self.b = nn.Parameter(torch.Tensor(1, num_class))
|
| 223 |
+
self.reset_parameters()
|
| 224 |
+
|
| 225 |
+
def reset_parameters(self):
|
| 226 |
+
import math
|
| 227 |
+
stdv = 1. / math.sqrt(self.W.size(2))
|
| 228 |
+
for i in range(self.num_class):
|
| 229 |
+
self.W[0][i].data.uniform_(-stdv, stdv)
|
| 230 |
+
if self.bias:
|
| 231 |
+
for i in range(self.num_class):
|
| 232 |
+
self.b[0][i].data.uniform_(-stdv, stdv)
|
| 233 |
+
|
| 234 |
+
def forward(self, x):
|
| 235 |
+
"""
|
| 236 |
+
Args:
|
| 237 |
+
x: [B, num_class, hidden_dim]
|
| 238 |
+
Returns:
|
| 239 |
+
logits: [B, num_class]
|
| 240 |
+
"""
|
| 241 |
+
# Element-wise multiplication and sum over hidden_dim
|
| 242 |
+
x = (self.W * x).sum(-1) # [B, num_class]
|
| 243 |
+
if self.bias:
|
| 244 |
+
x = x + self.b
|
| 245 |
+
return x
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class HybridQuery2Label(nn.Module):
|
| 249 |
+
"""
|
| 250 |
+
Query2Label decoder adapted for sequence inputs (not spatial features).
|
| 251 |
+
|
| 252 |
+
Uses learnable label queries to cross-attend to instance sequence from TransMIL.
|
| 253 |
+
Based on query2label.py but modified to accept [B, 1+N, hidden_dim] sequences
|
| 254 |
+
instead of [B, C, H, W] spatial features.
|
| 255 |
+
|
| 256 |
+
Args:
|
| 257 |
+
num_class: Number of label classes
|
| 258 |
+
hidden_dim: Dimension of features (512)
|
| 259 |
+
nheads: Number of attention heads
|
| 260 |
+
num_decoder_layers: Number of transformer decoder layers
|
| 261 |
+
dim_feedforward: Dimension of feedforward network
|
| 262 |
+
dropout: Dropout rate
|
| 263 |
+
"""
|
| 264 |
+
|
| 265 |
+
def __init__(
|
| 266 |
+
self,
|
| 267 |
+
num_class,
|
| 268 |
+
hidden_dim=512,
|
| 269 |
+
nheads=8,
|
| 270 |
+
num_decoder_layers=2,
|
| 271 |
+
dim_feedforward=2048,
|
| 272 |
+
dropout=0.1,
|
| 273 |
+
normalize_before=False
|
| 274 |
+
):
|
| 275 |
+
super().__init__()
|
| 276 |
+
self.num_class = num_class
|
| 277 |
+
self.hidden_dim = hidden_dim
|
| 278 |
+
|
| 279 |
+
# Label query embeddings (Q2L line 68)
|
| 280 |
+
self.query_embed = nn.Embedding(num_class, hidden_dim)
|
| 281 |
+
|
| 282 |
+
# Transformer decoder (Q2L uses transformer.py)
|
| 283 |
+
decoder_layer = TransformerDecoderLayer(
|
| 284 |
+
d_model=hidden_dim,
|
| 285 |
+
nhead=nheads,
|
| 286 |
+
dim_feedforward=dim_feedforward,
|
| 287 |
+
dropout=dropout,
|
| 288 |
+
normalize_before=normalize_before
|
| 289 |
+
)
|
| 290 |
+
decoder_norm = nn.LayerNorm(hidden_dim)
|
| 291 |
+
self.decoder = TransformerDecoder(
|
| 292 |
+
decoder_layer,
|
| 293 |
+
num_decoder_layers,
|
| 294 |
+
decoder_norm,
|
| 295 |
+
return_intermediate=False
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
# Group-wise linear classifier (Q2L line 69)
|
| 299 |
+
self.fc = GroupWiseLinear(num_class, hidden_dim, bias=True)
|
| 300 |
+
|
| 301 |
+
def forward(self, sequence_features, memory_key_padding_mask=None):
|
| 302 |
+
"""
|
| 303 |
+
Args:
|
| 304 |
+
sequence_features: [B, 1+N, hidden_dim] - Sequence from TransMIL
|
| 305 |
+
memory_key_padding_mask: [B, 1+N] - Padding mask (True = ignore, False = valid)
|
| 306 |
+
NOTE: PyTorch convention is inverted!
|
| 307 |
+
|
| 308 |
+
Returns:
|
| 309 |
+
logits: [B, num_class] - Multi-label classification logits
|
| 310 |
+
"""
|
| 311 |
+
B = sequence_features.shape[0]
|
| 312 |
+
|
| 313 |
+
# Transpose for decoder: expects [seq_len, B, hidden_dim]
|
| 314 |
+
memory = sequence_features.permute(1, 0, 2) # [1+N, B, hidden_dim]
|
| 315 |
+
|
| 316 |
+
# Label queries (Q2L line 77)
|
| 317 |
+
query_embed = self.query_embed.weight # [num_class, hidden_dim]
|
| 318 |
+
query_embed = query_embed.unsqueeze(1).repeat(1, B, 1) # [num_class, B, hidden_dim]
|
| 319 |
+
|
| 320 |
+
# Initialize target (zero tensor)
|
| 321 |
+
tgt = torch.zeros_like(query_embed) # [num_class, B, hidden_dim]
|
| 322 |
+
|
| 323 |
+
# Cross-attention decoder (Q2L line 78)
|
| 324 |
+
# Queries attend to instance sequence via cross-attention
|
| 325 |
+
hs = self.decoder(
|
| 326 |
+
tgt=tgt,
|
| 327 |
+
memory=memory,
|
| 328 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
| 329 |
+
pos=None, # No positional encoding (already in TransMIL)
|
| 330 |
+
query_pos=query_embed
|
| 331 |
+
) # [1, num_class, B, hidden_dim] if return_intermediate=False
|
| 332 |
+
|
| 333 |
+
# Handle output shape
|
| 334 |
+
if hs.dim() == 4:
|
| 335 |
+
hs = hs[-1] # Take last layer: [num_class, B, hidden_dim]
|
| 336 |
+
|
| 337 |
+
# Transpose to [B, num_class, hidden_dim]
|
| 338 |
+
hs = hs.permute(1, 0, 2) # [B, num_class, hidden_dim]
|
| 339 |
+
|
| 340 |
+
# Group-wise linear classification (Q2L line 79)
|
| 341 |
+
logits = self.fc(hs) # [B, num_class]
|
| 342 |
+
|
| 343 |
+
return logits
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
# ============================================================================
|
| 347 |
+
# ResNet-50 Backbone
|
| 348 |
+
# ============================================================================
|
| 349 |
+
|
| 350 |
+
class ResNet50Backbone(nn.Module):
|
| 351 |
+
"""
|
| 352 |
+
ResNet-50 feature extractor with Global Average Pooling.
|
| 353 |
+
|
| 354 |
+
Extracts 2048-dimensional features from images for TransMIL input.
|
| 355 |
+
Supports gradient checkpointing for memory efficiency.
|
| 356 |
+
|
| 357 |
+
Args:
|
| 358 |
+
pretrained: Use ImageNet pre-trained weights
|
| 359 |
+
use_checkpointing: Enable gradient checkpointing (saves memory)
|
| 360 |
+
"""
|
| 361 |
+
|
| 362 |
+
def __init__(self, pretrained=True, use_checkpointing=False):
|
| 363 |
+
super().__init__()
|
| 364 |
+
|
| 365 |
+
# Load ResNet-50
|
| 366 |
+
resnet = torchvision.models.resnet50(pretrained=pretrained)
|
| 367 |
+
|
| 368 |
+
# Remove final FC layer and avgpool
|
| 369 |
+
# Output of layer4: [B, 2048, 7, 7] for 224x224 input
|
| 370 |
+
self.features = nn.Sequential(*list(resnet.children())[:-2])
|
| 371 |
+
|
| 372 |
+
# Global Average Pooling
|
| 373 |
+
self.gap = nn.AdaptiveAvgPool2d(1)
|
| 374 |
+
|
| 375 |
+
self.use_checkpointing = use_checkpointing
|
| 376 |
+
|
| 377 |
+
def forward(self, images):
|
| 378 |
+
"""
|
| 379 |
+
Args:
|
| 380 |
+
images: [B*N, 3, H, W] - Batch of images (flattened across cases)
|
| 381 |
+
|
| 382 |
+
Returns:
|
| 383 |
+
features: [B*N, 2048] - Instance features
|
| 384 |
+
"""
|
| 385 |
+
if self.training and self.use_checkpointing:
|
| 386 |
+
# Gradient checkpointing: segment backbone into chunks
|
| 387 |
+
# Trades compute for memory (recomputes activations during backward)
|
| 388 |
+
x = checkpoint_sequential(self.features, segments=4, input=images)
|
| 389 |
+
else:
|
| 390 |
+
x = self.features(images) # [B*N, 2048, 7, 7]
|
| 391 |
+
|
| 392 |
+
x = self.gap(x) # [B*N, 2048, 1, 1]
|
| 393 |
+
x = x.flatten(1) # [B*N, 2048]
|
| 394 |
+
|
| 395 |
+
return x
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
# ============================================================================
|
| 399 |
+
# Complete End-to-End Model
|
| 400 |
+
# ============================================================================
|
| 401 |
+
|
| 402 |
+
class TransMIL_Query2Label_E2E(nn.Module):
|
| 403 |
+
"""
|
| 404 |
+
Complete end-to-end model: Images → ResNet-50 → TransMIL → Q2L → Logits
|
| 405 |
+
|
| 406 |
+
Pipeline:
|
| 407 |
+
1. ResNet-50 extracts features from each ultrasound image
|
| 408 |
+
2. TransMIL aggregates variable-length instance sequences with attention
|
| 409 |
+
3. Query2Label decoder performs multi-label classification via cross-attention
|
| 410 |
+
|
| 411 |
+
Args:
|
| 412 |
+
num_class: Number of label classes (default 30)
|
| 413 |
+
hidden_dim: Hidden dimension for TransMIL and Q2L (default 512)
|
| 414 |
+
nheads: Number of attention heads in Q2L decoder
|
| 415 |
+
num_decoder_layers: Number of Q2L decoder layers
|
| 416 |
+
pretrained_resnet: Use ImageNet pre-trained ResNet-50
|
| 417 |
+
use_checkpointing: Enable gradient checkpointing for ResNet-50
|
| 418 |
+
use_ppeg: Use PPEG position encoding (vs learned 1D)
|
| 419 |
+
"""
|
| 420 |
+
|
| 421 |
+
def __init__(
|
| 422 |
+
self,
|
| 423 |
+
num_class=30,
|
| 424 |
+
hidden_dim=512,
|
| 425 |
+
nheads=8,
|
| 426 |
+
num_decoder_layers=2,
|
| 427 |
+
pretrained_resnet=True,
|
| 428 |
+
use_checkpointing=False,
|
| 429 |
+
use_ppeg=False
|
| 430 |
+
):
|
| 431 |
+
super().__init__()
|
| 432 |
+
|
| 433 |
+
# ResNet-50 backbone
|
| 434 |
+
self.backbone = ResNet50Backbone(
|
| 435 |
+
pretrained=pretrained_resnet,
|
| 436 |
+
use_checkpointing=use_checkpointing
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
# TransMIL feature extractor (no PPEG by default, learned 1D position encoding)
|
| 440 |
+
self.feature_extractor = TransMILFeatureExtractor(
|
| 441 |
+
input_dim=2048,
|
| 442 |
+
hidden_dim=hidden_dim,
|
| 443 |
+
use_ppeg=use_ppeg
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
# Query2Label decoder
|
| 447 |
+
self.q2l_decoder = HybridQuery2Label(
|
| 448 |
+
num_class=num_class,
|
| 449 |
+
hidden_dim=hidden_dim,
|
| 450 |
+
nheads=nheads,
|
| 451 |
+
num_decoder_layers=num_decoder_layers
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
def forward(self, images, num_instances_per_case):
|
| 455 |
+
"""
|
| 456 |
+
Args:
|
| 457 |
+
images: [B*N_total, 3, H, W] - All images flattened across batch
|
| 458 |
+
num_instances_per_case: [B] or list - Number of images per case
|
| 459 |
+
|
| 460 |
+
Returns:
|
| 461 |
+
logits: [B, num_class] - Multi-label classification logits
|
| 462 |
+
"""
|
| 463 |
+
# Convert to tensor if list
|
| 464 |
+
if isinstance(num_instances_per_case, list):
|
| 465 |
+
num_instances_per_case = torch.tensor(num_instances_per_case, device=images.device)
|
| 466 |
+
|
| 467 |
+
B = len(num_instances_per_case)
|
| 468 |
+
|
| 469 |
+
# Step 1: Extract features from all images
|
| 470 |
+
all_features = self.backbone(images) # [B*N_total, 2048]
|
| 471 |
+
|
| 472 |
+
# Step 2: Reshape to [B, max_N, 2048] with padding
|
| 473 |
+
max_N = int(num_instances_per_case.max().item())
|
| 474 |
+
features_padded = torch.zeros(B, max_N, 2048, device=images.device)
|
| 475 |
+
masks = torch.zeros(B, max_N, dtype=torch.bool, device=images.device)
|
| 476 |
+
|
| 477 |
+
idx = 0
|
| 478 |
+
for i, n in enumerate(num_instances_per_case):
|
| 479 |
+
n = int(n.item()) if torch.is_tensor(n) else int(n)
|
| 480 |
+
features_padded[i, :n] = all_features[idx:idx+n]
|
| 481 |
+
masks[i, :n] = True # True = valid instance
|
| 482 |
+
idx += n
|
| 483 |
+
|
| 484 |
+
# Step 3: TransMIL sequence features
|
| 485 |
+
seq_features, attn_mask = self.feature_extractor(features_padded, masks)
|
| 486 |
+
# seq_features: [B, 1+max_N, 512]
|
| 487 |
+
# attn_mask: [B, 1+max_N] where True = valid, False = padded
|
| 488 |
+
|
| 489 |
+
# Step 4: Q2L decoder
|
| 490 |
+
# IMPORTANT: PyTorch MultiheadAttention uses inverted mask convention!
|
| 491 |
+
# memory_key_padding_mask: True = ignore, False = attend
|
| 492 |
+
# So we need to invert our mask
|
| 493 |
+
decoder_mask = ~attn_mask # Invert: True = padded (ignore)
|
| 494 |
+
logits = self.q2l_decoder(seq_features, memory_key_padding_mask=decoder_mask)
|
| 495 |
+
# logits: [B, num_class]
|
| 496 |
+
|
| 497 |
+
return logits
|
| 498 |
+
|
| 499 |
+
def freeze_backbone(self):
|
| 500 |
+
"""Freeze ResNet-50 backbone for training only TransMIL+Q2L"""
|
| 501 |
+
for param in self.backbone.parameters():
|
| 502 |
+
param.requires_grad = False
|
| 503 |
+
|
| 504 |
+
def unfreeze_backbone(self):
|
| 505 |
+
"""Unfreeze ResNet-50 for end-to-end fine-tuning"""
|
| 506 |
+
for param in self.backbone.parameters():
|
| 507 |
+
param.requires_grad = True
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
# ============================================================================
|
| 511 |
+
# Testing
|
| 512 |
+
# ============================================================================
|
| 513 |
+
|
| 514 |
+
if __name__ == "__main__":
|
| 515 |
+
print("Testing TransMIL_Query2Label_E2E model...")
|
| 516 |
+
|
| 517 |
+
# Model config
|
| 518 |
+
num_class = 30
|
| 519 |
+
batch_size = 2
|
| 520 |
+
num_instances = [8, 12] # Variable N per case
|
| 521 |
+
img_size = 224
|
| 522 |
+
|
| 523 |
+
# Create model
|
| 524 |
+
model = TransMIL_Query2Label_E2E(
|
| 525 |
+
num_class=num_class,
|
| 526 |
+
hidden_dim=512,
|
| 527 |
+
nheads=8,
|
| 528 |
+
num_decoder_layers=2,
|
| 529 |
+
pretrained_resnet=False, # Faster for testing
|
| 530 |
+
use_checkpointing=False,
|
| 531 |
+
use_ppeg=False
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
# Create dummy data
|
| 535 |
+
total_images = sum(num_instances)
|
| 536 |
+
images = torch.randn(total_images, 3, img_size, img_size)
|
| 537 |
+
|
| 538 |
+
print(f"\nInput shapes:")
|
| 539 |
+
print(f" Images: {images.shape}")
|
| 540 |
+
print(f" Num instances per case: {num_instances}")
|
| 541 |
+
|
| 542 |
+
# Forward pass
|
| 543 |
+
model.eval()
|
| 544 |
+
with torch.no_grad():
|
| 545 |
+
logits = model(images, num_instances)
|
| 546 |
+
|
| 547 |
+
print(f"\nOutput shape:")
|
| 548 |
+
print(f" Logits: {logits.shape}")
|
| 549 |
+
print(f" Expected: [{batch_size}, {num_class}]")
|
| 550 |
+
|
| 551 |
+
assert logits.shape == (batch_size, num_class), "Output shape mismatch!"
|
| 552 |
+
print("\n✓ Model test passed!")
|
| 553 |
+
|
| 554 |
+
# Test individual components
|
| 555 |
+
print("\n" + "="*60)
|
| 556 |
+
print("Testing individual components...")
|
| 557 |
+
print("="*60)
|
| 558 |
+
|
| 559 |
+
# Test TransMILFeatureExtractor
|
| 560 |
+
print("\n1. TransMILFeatureExtractor")
|
| 561 |
+
feature_extractor = TransMILFeatureExtractor(input_dim=2048, hidden_dim=512)
|
| 562 |
+
features = torch.randn(2, 10, 2048)
|
| 563 |
+
mask = torch.ones(2, 10, dtype=torch.bool)
|
| 564 |
+
seq_features, attn_mask = feature_extractor(features, mask)
|
| 565 |
+
print(f" Input: {features.shape}, Output: {seq_features.shape}")
|
| 566 |
+
assert seq_features.shape == (2, 11, 512) # 1 CLS + 10 instances
|
| 567 |
+
print(" ✓ Passed")
|
| 568 |
+
|
| 569 |
+
# Test HybridQuery2Label
|
| 570 |
+
print("\n2. HybridQuery2Label")
|
| 571 |
+
decoder = HybridQuery2Label(num_class=30, hidden_dim=512)
|
| 572 |
+
seq_features = torch.randn(2, 11, 512)
|
| 573 |
+
logits = decoder(seq_features)
|
| 574 |
+
print(f" Input: {seq_features.shape}, Output: {logits.shape}")
|
| 575 |
+
assert logits.shape == (2, 30)
|
| 576 |
+
print(" ✓ Passed")
|
| 577 |
+
|
| 578 |
+
# Test ResNet50Backbone
|
| 579 |
+
print("\n3. ResNet50Backbone")
|
| 580 |
+
backbone = ResNet50Backbone(pretrained=False)
|
| 581 |
+
images = torch.randn(20, 3, 224, 224)
|
| 582 |
+
features = backbone(images)
|
| 583 |
+
print(f" Input: {images.shape}, Output: {features.shape}")
|
| 584 |
+
assert features.shape == (20, 2048)
|
| 585 |
+
print(" ✓ Passed")
|
| 586 |
+
|
| 587 |
+
print("\n" + "="*60)
|
| 588 |
+
print("All tests passed! ✓")
|
| 589 |
+
print("="*60)
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 超声提示多标签分类模型 - 依赖列表
|
| 2 |
+
# TransMIL + Query2Label Hybrid Model Requirements
|
| 3 |
+
|
| 4 |
+
torch>=1.10.0
|
| 5 |
+
torchvision>=0.11.0
|
| 6 |
+
numpy>=1.19.0
|
| 7 |
+
pandas>=1.3.0
|
| 8 |
+
Pillow>=8.0.0
|
| 9 |
+
scikit-learn>=0.24.0
|
| 10 |
+
tqdm>=4.60.0
|
| 11 |
+
PyYAML>=5.4.0
|
| 12 |
+
nystrom-attention>=0.0.11
|
scripts/evaluate.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# 评估脚本 - 在测试集上评估模型性能
|
| 3 |
+
# 请先修改 config.yaml 中的数据路径
|
| 4 |
+
|
| 5 |
+
cd "$(dirname "$0")/.."
|
| 6 |
+
python evaluate.py
|
scripts/infer_single.sh
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# 单步推理脚本 - 对单个病例进行推理
|
| 3 |
+
# 用法1: ./infer_single.sh /path/to/image1.png /path/to/image2.png ...
|
| 4 |
+
# 用法2: ./infer_single.sh --image_dir /path/to/case_folder/
|
| 5 |
+
|
| 6 |
+
cd "$(dirname "$0")/.."
|
| 7 |
+
python infer_single_case.py "$@" --threshold 0.5
|
scripts/train.sh
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# 训练脚本 - TransMIL + Query2Label Hybrid Model
|
| 3 |
+
# 请先修改 config.yaml 中的数据路径
|
| 4 |
+
|
| 5 |
+
python train_hybrid.py --config config.yaml
|
thyroid_dataset.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import torch
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import List, Dict, Optional
|
| 9 |
+
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
|
| 10 |
+
from torchvision import transforms
|
| 11 |
+
|
| 12 |
+
# 18类标签定义 (必须与CSV列顺序严格一致)
|
| 13 |
+
'''
|
| 14 |
+
TARGET_CLASSES = [
|
| 15 |
+
"TI-RADS 1级", "TI-RADS 2级", "TI-RADS 3级", "TI-RADS 4a级",
|
| 16 |
+
"TI-RADS 4b级", "TI-RADS 4c级", "TI-RADS 5级",
|
| 17 |
+
"钙化", "甲亢", "囊肿", "淋巴结", "胶质潴留", "切除术后",
|
| 18 |
+
"弥漫性病变", "结节性甲状腺肿", "桥本氏甲状腺炎", "反应性", "转移性"
|
| 19 |
+
]
|
| 20 |
+
'''
|
| 21 |
+
#17类标签定义,去除切除术后
|
| 22 |
+
TARGET_CLASSES = [
|
| 23 |
+
"TI-RADS 1级", "TI-RADS 2级", "TI-RADS 3级", "TI-RADS 4a级",
|
| 24 |
+
"TI-RADS 4b级", "TI-RADS 4c级", "TI-RADS 5级",
|
| 25 |
+
"钙化", "甲亢", "囊肿", "淋巴结", "胶质潴留",
|
| 26 |
+
"弥漫性病变", "结节性甲状腺肿", "桥本氏甲状腺炎", "反应性", "转移性"
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
# 定义稀有/困难类别索引 (用于重采样)
|
| 30 |
+
# 对应: 4b(4), 4c(5), 5(6), 切除(12), 转移(17)
|
| 31 |
+
#RARE_CLASS_INDICES = [4, 5, 6, 12, 17]
|
| 32 |
+
|
| 33 |
+
RARE_CLASS_INDICES = [4, 5, 6, 16] #17类标签
|
| 34 |
+
|
| 35 |
+
class ThyroidMultiLabelDataset(Dataset):
|
| 36 |
+
def __init__(self,
|
| 37 |
+
data_root: str,
|
| 38 |
+
annotation_csv: str,
|
| 39 |
+
split_json: Optional[str] = None,
|
| 40 |
+
split_type: str = 'train', # 'train', 'val', 'test'
|
| 41 |
+
val_json_path: Optional[str] = None, # 仅当 split_type='train' 时需要,用于排除验证集
|
| 42 |
+
test_json_path: Optional[str] = None, # 仅当 split_type='train' 时需要,用于排除测试集
|
| 43 |
+
img_size: int = 224,
|
| 44 |
+
max_images_per_case: int = 20,
|
| 45 |
+
transform=None):
|
| 46 |
+
|
| 47 |
+
self.data_root = Path(data_root)
|
| 48 |
+
self.img_size = img_size
|
| 49 |
+
self.max_images_per_case = max_images_per_case
|
| 50 |
+
self.split_type = split_type
|
| 51 |
+
|
| 52 |
+
# 1. 读取所有标签
|
| 53 |
+
self.df_labels = pd.read_csv(annotation_csv)
|
| 54 |
+
# 将 case_path 设为索引,方便查询
|
| 55 |
+
self.df_labels.set_index('case_path', inplace=True)
|
| 56 |
+
|
| 57 |
+
# 2. 确定数据集包含的 case_list
|
| 58 |
+
self.case_list = self._get_split_cases(split_json, val_json_path, test_json_path)
|
| 59 |
+
|
| 60 |
+
# 3. 定义数据增强
|
| 61 |
+
if transform:
|
| 62 |
+
self.transform = transform
|
| 63 |
+
elif split_type == 'train':
|
| 64 |
+
self.transform = transforms.Compose([
|
| 65 |
+
transforms.Resize((img_size, img_size)),
|
| 66 |
+
transforms.RandomHorizontalFlip(p=0.5),
|
| 67 |
+
transforms.RandomVerticalFlip(p=0.5), # 超声可以上下翻转
|
| 68 |
+
transforms.RandomRotation(15),
|
| 69 |
+
transforms.ColorJitter(brightness=0.2, contrast=0.2),
|
| 70 |
+
transforms.ToTensor(),
|
| 71 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 72 |
+
])
|
| 73 |
+
else:
|
| 74 |
+
self.transform = transforms.Compose([
|
| 75 |
+
transforms.Resize((img_size, img_size)),
|
| 76 |
+
transforms.ToTensor(),
|
| 77 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 78 |
+
])
|
| 79 |
+
|
| 80 |
+
print(f"[{split_type.upper()}] Loaded {len(self.case_list)} cases.")
|
| 81 |
+
|
| 82 |
+
def _get_split_cases(self, split_json, val_json_path, test_json_path):
|
| 83 |
+
"""
|
| 84 |
+
根据 JSON 文件划分数据集
|
| 85 |
+
"""
|
| 86 |
+
all_cases_in_csv = set(self.df_labels.index.tolist())
|
| 87 |
+
|
| 88 |
+
# 读取指定的 split json (如果是 val 或 test)
|
| 89 |
+
target_cases = []
|
| 90 |
+
if split_json:
|
| 91 |
+
with open(split_json, 'r') as f:
|
| 92 |
+
data = json.load(f)
|
| 93 |
+
# JSON 里的 rel_path 对应 CSV 里的 case_path
|
| 94 |
+
target_cases = [item['rel_path'] for item in data]
|
| 95 |
+
|
| 96 |
+
# 过滤掉 CSV 里没有的 case (以防万一)
|
| 97 |
+
valid_cases = [c for c in target_cases if c in all_cases_in_csv]
|
| 98 |
+
return valid_cases
|
| 99 |
+
|
| 100 |
+
# 如果是 Train,逻辑是:所有 CSV 里的 case 减去 Val 和 Test 的 case
|
| 101 |
+
elif self.split_type == 'train':
|
| 102 |
+
exclude_cases = set()
|
| 103 |
+
|
| 104 |
+
if val_json_path:
|
| 105 |
+
with open(val_json_path, 'r') as f:
|
| 106 |
+
exclude_cases.update([item['rel_path'] for item in json.load(f)])
|
| 107 |
+
|
| 108 |
+
if test_json_path:
|
| 109 |
+
with open(test_json_path, 'r') as f:
|
| 110 |
+
exclude_cases.update([item['rel_path'] for item in json.load(f)])
|
| 111 |
+
|
| 112 |
+
train_cases = list(all_cases_in_csv - exclude_cases)
|
| 113 |
+
return sorted(train_cases) # 排序保证确定性
|
| 114 |
+
|
| 115 |
+
else:
|
| 116 |
+
return []
|
| 117 |
+
|
| 118 |
+
def __len__(self):
|
| 119 |
+
return len(self.case_list)
|
| 120 |
+
|
| 121 |
+
def __getitem__(self, idx):
|
| 122 |
+
case_rel_path = self.case_list[idx]
|
| 123 |
+
|
| 124 |
+
# 1. 拼接图片目录路径: data_root / BatchX/CaseID / Images
|
| 125 |
+
img_dir = self.data_root / case_rel_path / "Images"
|
| 126 |
+
|
| 127 |
+
# 2. 获取标签
|
| 128 |
+
# df.loc[index] 返回 Series,转 numpy
|
| 129 |
+
try:
|
| 130 |
+
label_vec = self.df_labels.loc[case_rel_path, TARGET_CLASSES].values.astype(np.float32)
|
| 131 |
+
label_tensor = torch.tensor(label_vec)
|
| 132 |
+
except KeyError:
|
| 133 |
+
print(f"Warning: Label for {case_rel_path} not found in CSV. Using zeros.")
|
| 134 |
+
label_tensor = torch.zeros(len(TARGET_CLASSES))
|
| 135 |
+
|
| 136 |
+
# 3. 读取图片
|
| 137 |
+
image_files = sorted(list(img_dir.glob("*.jpg")) + list(img_dir.glob("*.png")) + list(img_dir.glob("*.bmp")))
|
| 138 |
+
|
| 139 |
+
# 采样逻辑 (Train: 随机采; Val/Test: 取前N张)
|
| 140 |
+
if self.max_images_per_case and len(image_files) > self.max_images_per_case:
|
| 141 |
+
if self.split_type == 'train':
|
| 142 |
+
# 训练时随机采样,增加多样性
|
| 143 |
+
image_files = np.random.choice(image_files, self.max_images_per_case, replace=False)
|
| 144 |
+
else:
|
| 145 |
+
image_files = image_files[:self.max_images_per_case]
|
| 146 |
+
|
| 147 |
+
images = []
|
| 148 |
+
for img_path in image_files:
|
| 149 |
+
try:
|
| 150 |
+
img = Image.open(img_path).convert('RGB')
|
| 151 |
+
if self.transform:
|
| 152 |
+
img = self.transform(img)
|
| 153 |
+
images.append(img)
|
| 154 |
+
except Exception as e:
|
| 155 |
+
pass
|
| 156 |
+
|
| 157 |
+
if len(images) == 0:
|
| 158 |
+
# 异常处理:生成全黑图
|
| 159 |
+
images = [torch.zeros(3, self.img_size, self.img_size)]
|
| 160 |
+
|
| 161 |
+
images_stack = torch.stack(images) # [N, 3, H, W]
|
| 162 |
+
|
| 163 |
+
return {
|
| 164 |
+
'images': images_stack,
|
| 165 |
+
'labels': label_tensor,
|
| 166 |
+
'num_images': len(images),
|
| 167 |
+
'case_id': case_rel_path
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
def get_sampler_weights(self):
|
| 171 |
+
"""
|
| 172 |
+
计算采样权重:包含稀有类别的样本权重 = 10,其他 = 1
|
| 173 |
+
"""
|
| 174 |
+
weights = []
|
| 175 |
+
for case_rel_path in self.case_list:
|
| 176 |
+
label_vec = self.df_labels.loc[case_rel_path, TARGET_CLASSES].values
|
| 177 |
+
|
| 178 |
+
# 检查是否有稀有类别
|
| 179 |
+
is_rare = False
|
| 180 |
+
for idx in RARE_CLASS_INDICES:
|
| 181 |
+
if label_vec[idx] == 1:
|
| 182 |
+
is_rare = True
|
| 183 |
+
break
|
| 184 |
+
|
| 185 |
+
if is_rare:
|
| 186 |
+
weights.append(10.0) # 稀有样本采样概率翻10倍
|
| 187 |
+
else:
|
| 188 |
+
weights.append(1.0)
|
| 189 |
+
|
| 190 |
+
return torch.tensor(weights, dtype=torch.double)
|
| 191 |
+
|
| 192 |
+
def collate_fn(batch):
|
| 193 |
+
images_list = []
|
| 194 |
+
labels_list = []
|
| 195 |
+
num_instances_list = []
|
| 196 |
+
case_ids = []
|
| 197 |
+
|
| 198 |
+
for item in batch:
|
| 199 |
+
images_list.append(item['images'])
|
| 200 |
+
labels_list.append(item['labels'])
|
| 201 |
+
num_instances_list.append(item['num_images'])
|
| 202 |
+
case_ids.append(item['case_id'])
|
| 203 |
+
|
| 204 |
+
all_images = torch.cat(images_list, dim=0)
|
| 205 |
+
labels = torch.stack(labels_list)
|
| 206 |
+
num_instances_per_case = torch.tensor(num_instances_list, dtype=torch.long)
|
| 207 |
+
|
| 208 |
+
return {
|
| 209 |
+
'images': all_images,
|
| 210 |
+
'labels': labels,
|
| 211 |
+
'num_instances_per_case': num_instances_per_case,
|
| 212 |
+
'case_ids': case_ids
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
def create_dataloaders(config):
|
| 216 |
+
data_root = config['data']['data_root']
|
| 217 |
+
csv_path = config['data']['annotation_csv']
|
| 218 |
+
val_json = config['data']['val_json']
|
| 219 |
+
test_json = config['data']['test_json']
|
| 220 |
+
|
| 221 |
+
# Train Dataset
|
| 222 |
+
train_dataset = ThyroidMultiLabelDataset(
|
| 223 |
+
data_root=data_root,
|
| 224 |
+
annotation_csv=csv_path,
|
| 225 |
+
split_type='train',
|
| 226 |
+
val_json_path=val_json,
|
| 227 |
+
test_json_path=test_json,
|
| 228 |
+
img_size=config['data']['img_size'],
|
| 229 |
+
max_images_per_case=config['data']['max_images_per_case']
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
# 计算采样权重并创建 Sampler
|
| 233 |
+
print("Calculating sampler weights for class balance...")
|
| 234 |
+
train_weights = train_dataset.get_sampler_weights()
|
| 235 |
+
sampler = WeightedRandomSampler(train_weights, len(train_weights))
|
| 236 |
+
|
| 237 |
+
train_loader = DataLoader(
|
| 238 |
+
train_dataset,
|
| 239 |
+
batch_size=config['training']['batch_size'],
|
| 240 |
+
sampler=sampler, # 使用 sampler 时不要 shuffle=True
|
| 241 |
+
num_workers=config['data']['num_workers'],
|
| 242 |
+
collate_fn=collate_fn,
|
| 243 |
+
pin_memory=True,
|
| 244 |
+
drop_last=True
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
# Val Dataset
|
| 248 |
+
val_dataset = ThyroidMultiLabelDataset(
|
| 249 |
+
data_root=data_root,
|
| 250 |
+
annotation_csv=csv_path,
|
| 251 |
+
split_type='val',
|
| 252 |
+
split_json=val_json,
|
| 253 |
+
img_size=config['data']['img_size'],
|
| 254 |
+
max_images_per_case=config['data']['max_images_per_case']
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
val_loader = DataLoader(
|
| 258 |
+
val_dataset,
|
| 259 |
+
batch_size=config['training']['batch_size'],
|
| 260 |
+
shuffle=False,
|
| 261 |
+
num_workers=config['data']['num_workers'],
|
| 262 |
+
collate_fn=collate_fn,
|
| 263 |
+
pin_memory=True
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
# Test Dataset
|
| 267 |
+
test_dataset = ThyroidMultiLabelDataset(
|
| 268 |
+
data_root=data_root,
|
| 269 |
+
annotation_csv=csv_path,
|
| 270 |
+
split_type='test',
|
| 271 |
+
split_json=test_json,
|
| 272 |
+
img_size=config['data']['img_size'],
|
| 273 |
+
max_images_per_case=None # 测试时尽可能用所有图
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
test_loader = DataLoader(
|
| 277 |
+
test_dataset,
|
| 278 |
+
batch_size=config['training']['batch_size'],
|
| 279 |
+
shuffle=False,
|
| 280 |
+
num_workers=config['data']['num_workers'],
|
| 281 |
+
collate_fn=collate_fn,
|
| 282 |
+
pin_memory=True
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
return train_loader, val_loader, test_loader
|
thyroid_multilabel_annotations.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
train_hybrid.py
ADDED
|
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training Script for TransMIL + Query2Label Hybrid Model
|
| 3 |
+
|
| 4 |
+
Supports:
|
| 5 |
+
- End-to-end training with ResNet-50 backbone
|
| 6 |
+
- Mixed precision training (AMP) for memory efficiency
|
| 7 |
+
- Gradient accumulation for larger effective batch size
|
| 8 |
+
- Gradient checkpointing for ResNet-50
|
| 9 |
+
- AsymmetricLoss for multi-label imbalance
|
| 10 |
+
- Multi-label evaluation metrics (mAP, per-class AP, F1)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import sys
|
| 14 |
+
#sys.path.append('query2labels/lib/models')
|
| 15 |
+
#sys.path.append('/XYFS01/HDD_POOL/sysu_gbli2/sysu_gbli2xy_1/chenshiyu/ThyroidAgent/ThyroidRegion/HintsVer3/query2labels/lib/')
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import argparse
|
| 19 |
+
import yaml
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from datetime import datetime
|
| 22 |
+
import json
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
import torch.optim as optim
|
| 27 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 28 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 29 |
+
import numpy as np
|
| 30 |
+
from tqdm import tqdm
|
| 31 |
+
from sklearn.metrics import average_precision_score, f1_score
|
| 32 |
+
|
| 33 |
+
# Import model and dataset
|
| 34 |
+
from models.transmil_q2l import TransMIL_Query2Label_E2E
|
| 35 |
+
from thyroid_dataset import create_dataloaders
|
| 36 |
+
|
| 37 |
+
# Import AsymmetricLoss
|
| 38 |
+
try:
|
| 39 |
+
from models.aslloss import AsymmetricLossOptimized
|
| 40 |
+
except ImportError:
|
| 41 |
+
print("Warning: Could not import AsymmetricLoss.")
|
| 42 |
+
AsymmetricLossOptimized = None
|
| 43 |
+
'''
|
| 44 |
+
try:
|
| 45 |
+
#from aslloss import AsymmetricLossOptimized
|
| 46 |
+
from models.aslloss import AsymmetricLossOptimized
|
| 47 |
+
except ImportError:
|
| 48 |
+
print("Warning: Could not import AsymmetricLoss from query2labels.")
|
| 49 |
+
print("Make sure query2labels/lib/models/aslloss.py is in Python path.")
|
| 50 |
+
AsymmetricLossOptimized = None
|
| 51 |
+
|
| 52 |
+
'''
|
| 53 |
+
# ============================================================================
|
| 54 |
+
# Metrics
|
| 55 |
+
# ============================================================================
|
| 56 |
+
|
| 57 |
+
def compute_multilabel_metrics(preds, targets, threshold=0.5):
|
| 58 |
+
"""
|
| 59 |
+
Compute multi-label classification metrics.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
preds: [N, num_class] numpy array of probabilities
|
| 63 |
+
targets: [N, num_class] numpy array of binary labels
|
| 64 |
+
threshold: Classification threshold for F1 score
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
dict with mAP, per-class AP, F1 scores
|
| 68 |
+
"""
|
| 69 |
+
metrics = {}
|
| 70 |
+
|
| 71 |
+
# Mean Average Precision (mAP)
|
| 72 |
+
aps = []
|
| 73 |
+
for i in range(targets.shape[1]):
|
| 74 |
+
if targets[:, i].sum() > 0: # Skip classes with no positive samples
|
| 75 |
+
ap = average_precision_score(targets[:, i], preds[:, i])
|
| 76 |
+
aps.append(ap)
|
| 77 |
+
else:
|
| 78 |
+
aps.append(np.nan)
|
| 79 |
+
|
| 80 |
+
metrics['mAP'] = np.nanmean(aps)
|
| 81 |
+
metrics['per_class_AP'] = aps
|
| 82 |
+
|
| 83 |
+
# F1 Score at threshold
|
| 84 |
+
preds_binary = (preds >= threshold).astype(int)
|
| 85 |
+
f1_micro = f1_score(targets, preds_binary, average='micro', zero_division=0)
|
| 86 |
+
f1_macro = f1_score(targets, preds_binary, average='macro', zero_division=0)
|
| 87 |
+
f1_samples = f1_score(targets, preds_binary, average='samples', zero_division=0)
|
| 88 |
+
|
| 89 |
+
metrics['F1_micro'] = f1_micro
|
| 90 |
+
metrics['F1_macro'] = f1_macro
|
| 91 |
+
metrics['F1_samples'] = f1_samples
|
| 92 |
+
|
| 93 |
+
return metrics
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# ============================================================================
|
| 97 |
+
# Training Functions
|
| 98 |
+
# ============================================================================
|
| 99 |
+
|
| 100 |
+
def train_epoch(model, dataloader, criterion, optimizer, scaler, device, config, epoch):
|
| 101 |
+
"""
|
| 102 |
+
Train for one epoch with gradient accumulation and mixed precision.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
model: TransMIL_Query2Label_E2E model
|
| 106 |
+
dataloader: Training dataloader
|
| 107 |
+
criterion: AsymmetricLoss
|
| 108 |
+
optimizer: AdamW optimizer
|
| 109 |
+
scaler: GradScaler for AMP
|
| 110 |
+
device: torch.device
|
| 111 |
+
config: Config dict
|
| 112 |
+
epoch: Current epoch number
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
Average loss for epoch
|
| 116 |
+
"""
|
| 117 |
+
model.train()
|
| 118 |
+
|
| 119 |
+
total_loss = 0.0
|
| 120 |
+
accumulation_steps = config['training']['gradient_accumulation_steps']
|
| 121 |
+
use_amp = config['training']['use_amp']
|
| 122 |
+
|
| 123 |
+
# Progress bar
|
| 124 |
+
pbar = tqdm(dataloader, desc=f"Epoch {epoch}")
|
| 125 |
+
|
| 126 |
+
optimizer.zero_grad()
|
| 127 |
+
|
| 128 |
+
for i, batch in enumerate(pbar):
|
| 129 |
+
images = batch['images'].to(device) # [B*N_total, 3, H, W]
|
| 130 |
+
labels = batch['labels'].to(device) # [B, num_class]
|
| 131 |
+
num_instances_per_case = batch['num_instances_per_case'] # [B]
|
| 132 |
+
|
| 133 |
+
# Mixed precision forward pass
|
| 134 |
+
if use_amp:
|
| 135 |
+
with autocast():
|
| 136 |
+
logits = model(images, num_instances_per_case)
|
| 137 |
+
loss = criterion(logits, labels)
|
| 138 |
+
loss = loss / accumulation_steps # Scale loss for accumulation
|
| 139 |
+
else:
|
| 140 |
+
logits = model(images, num_instances_per_case)
|
| 141 |
+
loss = criterion(logits, labels)
|
| 142 |
+
loss = loss / accumulation_steps
|
| 143 |
+
|
| 144 |
+
# Backward pass
|
| 145 |
+
if use_amp:
|
| 146 |
+
scaler.scale(loss).backward()
|
| 147 |
+
else:
|
| 148 |
+
loss.backward()
|
| 149 |
+
|
| 150 |
+
# Optimizer step every accumulation_steps
|
| 151 |
+
if (i + 1) % accumulation_steps == 0:
|
| 152 |
+
if use_amp:
|
| 153 |
+
scaler.step(optimizer)
|
| 154 |
+
scaler.update()
|
| 155 |
+
else:
|
| 156 |
+
optimizer.step()
|
| 157 |
+
optimizer.zero_grad()
|
| 158 |
+
|
| 159 |
+
# Track loss
|
| 160 |
+
total_loss += loss.item() * accumulation_steps
|
| 161 |
+
pbar.set_postfix({'loss': loss.item() * accumulation_steps})
|
| 162 |
+
|
| 163 |
+
return total_loss / len(dataloader)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
@torch.no_grad()
|
| 167 |
+
def validate(model, dataloader, criterion, device, config):
|
| 168 |
+
"""
|
| 169 |
+
Validate model with multi-label metrics.
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
model: TransMIL_Query2Label_E2E model
|
| 173 |
+
dataloader: Validation dataloader
|
| 174 |
+
criterion: AsymmetricLoss
|
| 175 |
+
device: torch.device
|
| 176 |
+
config: Config dict
|
| 177 |
+
|
| 178 |
+
Returns:
|
| 179 |
+
dict with loss and metrics (mAP, F1, etc.)
|
| 180 |
+
"""
|
| 181 |
+
model.eval()
|
| 182 |
+
|
| 183 |
+
total_loss = 0.0
|
| 184 |
+
all_preds = []
|
| 185 |
+
all_targets = []
|
| 186 |
+
|
| 187 |
+
for batch in tqdm(dataloader, desc="Validating"):
|
| 188 |
+
images = batch['images'].to(device)
|
| 189 |
+
labels = batch['labels'].to(device)
|
| 190 |
+
num_instances_per_case = batch['num_instances_per_case']
|
| 191 |
+
|
| 192 |
+
# Forward pass
|
| 193 |
+
logits = model(images, num_instances_per_case)
|
| 194 |
+
loss = criterion(logits, labels)
|
| 195 |
+
|
| 196 |
+
# Sigmoid for multi-label probabilities
|
| 197 |
+
preds = torch.sigmoid(logits)
|
| 198 |
+
|
| 199 |
+
# Store predictions and targets
|
| 200 |
+
all_preds.append(preds.cpu().numpy())
|
| 201 |
+
all_targets.append(labels.cpu().numpy())
|
| 202 |
+
|
| 203 |
+
total_loss += loss.item()
|
| 204 |
+
|
| 205 |
+
# Concatenate all batches
|
| 206 |
+
all_preds = np.concatenate(all_preds, axis=0)
|
| 207 |
+
all_targets = np.concatenate(all_targets, axis=0)
|
| 208 |
+
|
| 209 |
+
# Compute metrics
|
| 210 |
+
metrics = compute_multilabel_metrics(all_preds, all_targets)
|
| 211 |
+
metrics['loss'] = total_loss / len(dataloader)
|
| 212 |
+
|
| 213 |
+
return metrics
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
# ============================================================================
|
| 217 |
+
# Main Training Loop
|
| 218 |
+
# ============================================================================
|
| 219 |
+
|
| 220 |
+
def train(config, resume_from=None):
|
| 221 |
+
"""
|
| 222 |
+
Main training function.
|
| 223 |
+
|
| 224 |
+
Args:
|
| 225 |
+
config: Config dictionary from YAML
|
| 226 |
+
resume_from: Optional checkpoint path to resume training
|
| 227 |
+
"""
|
| 228 |
+
# Setup device
|
| 229 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 230 |
+
print(f"\nUsing device: {device}")
|
| 231 |
+
if torch.cuda.is_available():
|
| 232 |
+
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
| 233 |
+
print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
|
| 234 |
+
|
| 235 |
+
# Create save directory
|
| 236 |
+
save_dir = Path(config['training']['save_dir'])
|
| 237 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
| 238 |
+
|
| 239 |
+
# Create tensorboard writer
|
| 240 |
+
log_dir = save_dir / 'logs' / datetime.now().strftime('%Y%m%d_%H%M%S')
|
| 241 |
+
writer = SummaryWriter(log_dir)
|
| 242 |
+
|
| 243 |
+
# Save config
|
| 244 |
+
with open(save_dir / 'config.yaml', 'w') as f:
|
| 245 |
+
yaml.dump(config, f)
|
| 246 |
+
|
| 247 |
+
# Create dataloaders
|
| 248 |
+
print("\nCreating dataloaders...")
|
| 249 |
+
train_loader, val_loader, test_loader = create_dataloaders(config)
|
| 250 |
+
|
| 251 |
+
# Create model
|
| 252 |
+
print("\nCreating model...")
|
| 253 |
+
model = TransMIL_Query2Label_E2E(
|
| 254 |
+
num_class=config['model']['num_class'],
|
| 255 |
+
hidden_dim=config['model']['hidden_dim'],
|
| 256 |
+
nheads=config['model']['nheads'],
|
| 257 |
+
num_decoder_layers=config['model']['num_decoder_layers'],
|
| 258 |
+
pretrained_resnet=config['model']['pretrained_resnet'],
|
| 259 |
+
use_checkpointing=config['training']['gradient_checkpointing'],
|
| 260 |
+
use_ppeg=config['model'].get('use_ppeg', False)
|
| 261 |
+
)
|
| 262 |
+
model = model.to(device)
|
| 263 |
+
|
| 264 |
+
# Print model stats
|
| 265 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 266 |
+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 267 |
+
print(f"Total parameters: {total_params:,}")
|
| 268 |
+
print(f"Trainable parameters: {trainable_params:,}")
|
| 269 |
+
|
| 270 |
+
# Create optimizer
|
| 271 |
+
optimizer = optim.AdamW(
|
| 272 |
+
model.parameters(),
|
| 273 |
+
lr=config['training']['lr'],
|
| 274 |
+
weight_decay=config['training']['weight_decay']
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
# Create scheduler
|
| 278 |
+
scheduler_type = config['training'].get('scheduler', 'cosine')
|
| 279 |
+
if scheduler_type == 'cosine':
|
| 280 |
+
scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
| 281 |
+
optimizer,
|
| 282 |
+
T_max=config['training']['epochs'],
|
| 283 |
+
eta_min=1e-6
|
| 284 |
+
)
|
| 285 |
+
elif scheduler_type == 'onecycle':
|
| 286 |
+
scheduler = optim.lr_scheduler.OneCycleLR(
|
| 287 |
+
optimizer,
|
| 288 |
+
max_lr=config['training']['lr'],
|
| 289 |
+
epochs=config['training']['epochs'],
|
| 290 |
+
steps_per_epoch=len(train_loader)
|
| 291 |
+
)
|
| 292 |
+
else:
|
| 293 |
+
scheduler = None
|
| 294 |
+
|
| 295 |
+
# Create loss function
|
| 296 |
+
if AsymmetricLossOptimized is not None:
|
| 297 |
+
criterion = AsymmetricLossOptimized(
|
| 298 |
+
gamma_neg=config['training']['gamma_neg'],
|
| 299 |
+
gamma_pos=config['training']['gamma_pos'],
|
| 300 |
+
clip=config['training']['clip'],
|
| 301 |
+
eps=1e-5
|
| 302 |
+
)
|
| 303 |
+
else:
|
| 304 |
+
# Fallback to BCEWithLogitsLoss
|
| 305 |
+
print("Warning: Using BCEWithLogitsLoss instead of AsymmetricLoss")
|
| 306 |
+
criterion = nn.BCEWithLogitsLoss()
|
| 307 |
+
|
| 308 |
+
# Mixed precision scaler
|
| 309 |
+
scaler = GradScaler() if config['training']['use_amp'] else None
|
| 310 |
+
|
| 311 |
+
# Resume from checkpoint if specified
|
| 312 |
+
start_epoch = 0
|
| 313 |
+
best_map = 0.0
|
| 314 |
+
|
| 315 |
+
if resume_from is not None and Path(resume_from).exists():
|
| 316 |
+
print(f"\nResuming from {resume_from}")
|
| 317 |
+
checkpoint = torch.load(resume_from, map_location=device)
|
| 318 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
| 319 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 320 |
+
start_epoch = checkpoint['epoch'] + 1
|
| 321 |
+
best_map = checkpoint.get('best_map', 0.0)
|
| 322 |
+
if scheduler is not None and 'scheduler_state_dict' in checkpoint:
|
| 323 |
+
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
| 324 |
+
print(f"Resumed from epoch {start_epoch}, best mAP: {best_map:.4f}")
|
| 325 |
+
|
| 326 |
+
# Training loop
|
| 327 |
+
print(f"\nStarting training for {config['training']['epochs']} epochs...")
|
| 328 |
+
print("="*80)
|
| 329 |
+
|
| 330 |
+
for epoch in range(start_epoch, config['training']['epochs']):
|
| 331 |
+
# Train
|
| 332 |
+
train_loss = train_epoch(model, train_loader, criterion, optimizer, scaler, device, config, epoch)
|
| 333 |
+
|
| 334 |
+
# Validate
|
| 335 |
+
val_metrics = validate(model, val_loader, criterion, device, config)
|
| 336 |
+
|
| 337 |
+
# Update scheduler
|
| 338 |
+
if scheduler is not None:
|
| 339 |
+
if scheduler_type == 'onecycle':
|
| 340 |
+
pass # OneCycleLR updates per step, not per epoch
|
| 341 |
+
else:
|
| 342 |
+
scheduler.step()
|
| 343 |
+
|
| 344 |
+
# Log metrics
|
| 345 |
+
current_lr = optimizer.param_groups[0]['lr']
|
| 346 |
+
writer.add_scalar('Loss/train', train_loss, epoch)
|
| 347 |
+
writer.add_scalar('Loss/val', val_metrics['loss'], epoch)
|
| 348 |
+
writer.add_scalar('Metrics/mAP', val_metrics['mAP'], epoch)
|
| 349 |
+
writer.add_scalar('Metrics/F1_micro', val_metrics['F1_micro'], epoch)
|
| 350 |
+
writer.add_scalar('Metrics/F1_macro', val_metrics['F1_macro'], epoch)
|
| 351 |
+
writer.add_scalar('LR', current_lr, epoch)
|
| 352 |
+
|
| 353 |
+
# Print epoch summary
|
| 354 |
+
print(f"\nEpoch {epoch}/{config['training']['epochs']}")
|
| 355 |
+
print(f" Train Loss: {train_loss:.4f}")
|
| 356 |
+
print(f" Val Loss: {val_metrics['loss']:.4f}")
|
| 357 |
+
print(f" mAP: {val_metrics['mAP']:.4f}")
|
| 358 |
+
print(f" F1 (micro): {val_metrics['F1_micro']:.4f}")
|
| 359 |
+
print(f" F1 (macro): {val_metrics['F1_macro']:.4f}")
|
| 360 |
+
print(f" LR: {current_lr:.6f}")
|
| 361 |
+
|
| 362 |
+
# Save checkpoint
|
| 363 |
+
is_best = val_metrics['mAP'] > best_map
|
| 364 |
+
if is_best:
|
| 365 |
+
best_map = val_metrics['mAP']
|
| 366 |
+
|
| 367 |
+
if (epoch + 1) % config['training']['save_freq'] == 0 or is_best:
|
| 368 |
+
checkpoint = {
|
| 369 |
+
'epoch': epoch,
|
| 370 |
+
'model_state_dict': model.state_dict(),
|
| 371 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 372 |
+
'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
|
| 373 |
+
'train_loss': train_loss,
|
| 374 |
+
'val_metrics': val_metrics,
|
| 375 |
+
'best_map': best_map,
|
| 376 |
+
'config': config
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
# Save latest checkpoint
|
| 380 |
+
torch.save(checkpoint, save_dir / 'checkpoint_latest.pth')
|
| 381 |
+
|
| 382 |
+
# Save best checkpoint
|
| 383 |
+
if is_best:
|
| 384 |
+
torch.save(checkpoint, save_dir / 'checkpoint_best.pth')
|
| 385 |
+
print(f" ✓ Saved best model (mAP: {best_map:.4f})")
|
| 386 |
+
|
| 387 |
+
# Save periodic checkpoint
|
| 388 |
+
if (epoch + 1) % config['training']['save_freq'] == 0:
|
| 389 |
+
torch.save(checkpoint, save_dir / f'checkpoint_epoch_{epoch}.pth')
|
| 390 |
+
|
| 391 |
+
print("\n" + "="*80)
|
| 392 |
+
print(f"Training completed! Best mAP: {best_map:.4f}")
|
| 393 |
+
print(f"Checkpoints saved to: {save_dir}")
|
| 394 |
+
|
| 395 |
+
writer.close()
|
| 396 |
+
|
| 397 |
+
# Final test evaluation
|
| 398 |
+
print("\nEvaluating on test set...")
|
| 399 |
+
test_metrics = validate(model, test_loader, criterion, device, config)
|
| 400 |
+
print(f"\nTest Results:")
|
| 401 |
+
print(f" mAP: {test_metrics['mAP']:.4f}")
|
| 402 |
+
print(f" F1 (micro): {test_metrics['F1_micro']:.4f}")
|
| 403 |
+
print(f" F1 (macro): {test_metrics['F1_macro']:.4f}")
|
| 404 |
+
|
| 405 |
+
# Save test results
|
| 406 |
+
with open(save_dir / 'test_results.json', 'w') as f:
|
| 407 |
+
json.dump({k: float(v) if not isinstance(v, list) else v
|
| 408 |
+
for k, v in test_metrics.items()}, f, indent=2)
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
# ============================================================================
|
| 412 |
+
# Main
|
| 413 |
+
# ============================================================================
|
| 414 |
+
|
| 415 |
+
def main():
|
| 416 |
+
parser = argparse.ArgumentParser(description='Train TransMIL + Query2Label Hybrid Model')
|
| 417 |
+
parser.add_argument('--config', type=str, default='hybrid_model/config.yaml',
|
| 418 |
+
help='Path to config file')
|
| 419 |
+
parser.add_argument('--resume', type=str, default=None,
|
| 420 |
+
help='Path to checkpoint to resume from')
|
| 421 |
+
args = parser.parse_args()
|
| 422 |
+
|
| 423 |
+
# Load config
|
| 424 |
+
with open(args.config, 'r') as f:
|
| 425 |
+
config = yaml.safe_load(f)
|
| 426 |
+
|
| 427 |
+
print("="*80)
|
| 428 |
+
print("TransMIL + Query2Label Hybrid Model Training")
|
| 429 |
+
print("="*80)
|
| 430 |
+
print(f"\nConfig: {args.config}")
|
| 431 |
+
if args.resume:
|
| 432 |
+
print(f"Resume from: {args.resume}")
|
| 433 |
+
|
| 434 |
+
# Train
|
| 435 |
+
train(config, resume_from=args.resume)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
if __name__ == "__main__":
|
| 439 |
+
main()
|