File size: 606 Bytes
40d87dd
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
### 1. `train_dflash_lora.py` 
* 加了lora,原来是调用小模型,现在是hidden states+lora预测。
* `dflash_lora_mask_fn`函数是在处理预测的那一块草稿Block时,可以同时看到这一块里的所有词。

### 2. OOM优化
* 分片策略ZeRO-3,FSDP切分从`SHARD_GRAD_OP`升级到`FULL_SHARD`* `batch-size=1``accumulation-steps=8`* 参考之前的代码用了FlexAttention(`dflash_lora_mask_fn`)。
* `_chunked_lm_loss()`,把算loss切片成256块来算+梯度检查。

### 运行
* bash /workspace/hanrui/junquan/SpecForge/scripts/run_train_dflash_lora.sh 2