File size: 606 Bytes
40d87dd | 1 2 3 4 5 6 7 8 9 10 11 12 | ### 1. `train_dflash_lora.py`
* 加了lora,原来是调用小模型,现在是hidden states+lora预测。
* `dflash_lora_mask_fn`函数是在处理预测的那一块草稿Block时,可以同时看到这一块里的所有词。
### 2. OOM优化
* 分片策略ZeRO-3,FSDP切分从`SHARD_GRAD_OP`升级到`FULL_SHARD`。
* `batch-size=1`,`accumulation-steps=8`。
* 参考之前的代码用了FlexAttention(`dflash_lora_mask_fn`)。
* `_chunked_lm_loss()`,把算loss切片成256块来算+梯度检查。
### 运行
* bash /workspace/hanrui/junquan/SpecForge/scripts/run_train_dflash_lora.sh 2 |