HintsPredictionModel / config.yaml
Doul0414's picture
Initial upload: HintsPrediction
343e05c verified
# 超声提示多标签分类模型配置文件
# TransMIL + Query2Label Hybrid Model
data:
# 【需要修改】数据根目录(包含 Report_XXX 文件夹的目录)
data_root: "/path/to/your/ReportData_ROI/"
# 【需要修改】多标签注释 CSV 文件路径
annotation_csv: "/path/to/your/ReportData_ROI/thyroid_multilabel_annotations.csv"
# 【需要修改】验证集 JSON 文件路径
val_json: "/path/to/your/ReportData_ROI/classification_val_set_single.json"
# 【需要修改】测试集 JSON 文件路径
test_json: "/path/to/your/ReportData_ROI/classification_test_set_single.json"
img_size: 224
max_images_per_case: 20
num_workers: 8
model:
num_class: 17 # 17类标签(已删除"切除术后")
hidden_dim: 512
nheads: 8
num_decoder_layers: 2
pretrained_resnet: True
use_ppeg: False
training:
batch_size: 4
epochs: 50
lr: 0.0001
weight_decay: 0.0001
optimizer: "AdamW"
# Asymmetric Loss 参数(处理多标签不平衡)
gamma_neg: 4
gamma_pos: 1
clip: 0.05
# 内存优化策略
use_amp: true # 混合精度训练
gradient_accumulation_steps: 4 # 有效 batch_size = 4 * 4 = 16
gradient_checkpointing: true
# 学习率调度器
scheduler: "cosine"
warmup_epochs: 5
# 模型保存
save_dir: "checkpoints/"
save_freq: 5