⚠️ 图像Token截断问题说明
问题描述
在 GRPO 训练中,当 prompt 长度超过 max_prompt_length (默认512) 时,代码会从右侧截取 tokens:
prompts = prompts[:, -self.max_prompt_length :]
这意味着左侧的 tokens 会被截断。
风险
对于包含图像的 prompt,图像 tokens 通常位于 prompt 的开头或中间位置:
<|im_start|>user
🖼️<|vision_start|><|image_pad|><|vision_end|>🖼️
这是一道数学题,请解答...
<|im_end|><|im_start|>assistant
如果 prompt 太长被截断,图像 tokens 会丢失,导致:
- ❌ 模型无法看到图像信息
- ❌ 只能基于文本部分回答
- ❌ 对于视觉问答任务,性能会严重下降
检测机制
我已经在 log_prompt_truncation 函数中添加了自动检测:
1. 检测的 Vision Token 类型
<|vision_start|><|vision_end|><|image_pad|><|video_pad|><|vision_pad|>
2. 日志输出
如果图像 tokens 被截断,会显示警告:
⚠️ WARNING: IMAGE/VISION TOKENS WERE TRUNCATED!
⚠️ Lost vision token IDs: {151652, 151653, 151654}
⚠️ Vision tokens before: {151652, 151653, 151654, 151655}
⚠️ Vision tokens after: {151655}
⚠️ The model will NOT see the image information!
如果图像 tokens 保留完整:
✓ Vision tokens preserved: {151652, 151653, 151654, 151655}
3. Token 可视化
在 token 列表中,vision tokens 会用 🖼️ 标记:
[BEFORE TRUNCATION]
Tokens: <|im_start|> [1587:'user'] 🖼️<|vision_start|>🖼️ 🖼️<|image_pad|>🖼️ 🖼️<|vision_end|>🖼️ ...
[AFTER TRUNCATION]
Tokens: [2768:' following'] [7033:' math'] [3575:' problem'] ...
解决方案
方案 1:增加 max_prompt_length(推荐)
修改配置文件 configs/latent_memory/mm_math.yaml:
generation:
max_start_length: 1024 # 从 512 增加到 1024 或更大
max_prompt_length: 4096
方案 2:优化 Prompt 长度
- 简化问题描述
- 移除不必要的上下文
- 确保图像相关的核心内容在前 512 tokens 内
方案 3:修改截断策略(需要代码修改)
当前是从右侧截取(保留后面的内容):
prompts = prompts[:, -self.max_prompt_length :] # 保留右侧
可以改为从左侧截取(保留前面的内容,包括图像):
prompts = prompts[:, :self.max_prompt_length] # 保留左侧
但这样会丢失问题的后半部分,需要权衡。
方案 4:智能截断(最佳但复杂)
检测 vision tokens 的位置,确保它们不被截断:
- 找到 vision tokens 的位置
- 如果需要截断,从 vision tokens 之后开始截取
- 保证图像信息始终保留
当前状态
✅ 已添加检测和警告机制
- 自动检测 vision tokens 是否被截断
- 在日志中高亮显示 vision tokens
- 如果截断发生会显著警告
⚠️ 需要手动配置
- 根据实际数据集调整
max_start_length - 监控日志中的截断警告
- 如果频繁出现截断,增加长度限制
日志输出层次
现在有三个层次的日志输出:
1. PROMPT INFO - 基本信息
[PROMPT INFO] Original prompt shape: torch.Size([8, 650]), max_prompt_length: 512
[PROMPT INFO] After truncation shape: torch.Size([8, 512])
[PROMPT INFO] Truncation detected: 650 -> 512
2. PROMPT TRUNCATION - 截断详情(如果发生截断)
================================================================================
[PROMPT TRUNCATION] Sample 0
Length before truncation: 650
Length after truncation: 512
⚠️ WARNING: IMAGE/VISION TOKENS WERE TRUNCATED!
...
================================================================================
3. ROLLOUT INPUT - 传给 trainer 的输入
================================================================================
[ROLLOUT INPUT] Sample 0
Prompt length: 512 tokens
✓ Contains vision tokens: {151652, 151653}
[INPUT TOKENS]
Tokens: <|im_start|> 🖼️<|vision_start|>🖼️ ...
================================================================================
4. MODEL.GENERATE INPUT - 实际传给模型的输入
================================================================================
[MODEL.GENERATE INPUT] Sample 0
Input length: 512 tokens
✓ Contains vision tokens: {151652, 151653}
[TOKENS TO MODEL]
Tokens: <|im_start|> 🖼️<|vision_start|>🖼️ ...
================================================================================
查看日志
运行训练后,检查是否有截断警告:
# 查看所有 prompt 信息
grep "\[PROMPT INFO\]" test_output/mm_math/logs/log.txt
# 查看截断警告
grep "WARNING.*VISION.*TRUNCATED" test_output/mm_math/logs/log.txt
# 查看详细的截断日志
grep -A 30 "\[PROMPT TRUNCATION\]" test_output/mm_math/logs/log.txt
# 查看 rollout 输入
grep -A 15 "\[ROLLOUT INPUT\]" test_output/mm_math/logs/log.txt
# 查看实际传给模型的输入
grep -A 15 "\[MODEL.GENERATE INPUT\]" test_output/mm_math/logs/log.txt
建议
- 训练前:先运行一个 epoch,检查日志中是否有 vision token 截断警告
- 如果有警告:立即增加
max_start_length,重新开始训练 - 监控:定期检查日志,确保图像信息没有丢失
- 数据统计:统计数据集中 prompt 长度分布,设置合适的
max_start_length
示例输出
正常情况(无截断)
[PROMPT INFO] Original prompt shape: torch.Size([8, 450]), max_prompt_length: 512
[PROMPT INFO] After truncation shape: torch.Size([8, 450])
[PROMPT INFO] No truncation needed: length 450 <= max 512
有截断但保留图像
[PROMPT INFO] Truncation detected: 650 -> 512
================================================================================
[PROMPT TRUNCATION] Sample 0
Length before truncation: 650
Length after truncation: 512
✓ Vision tokens preserved: {151652, 151653, 151654}
[BEFORE TRUNCATION]
Tokens: <|im_start|> 🖼️<|vision_start|>🖼️ 🖼️<|image_pad|>🖼️ ...
[AFTER TRUNCATION]
Tokens: 🖼️<|vision_start|>🖼️ 🖼️<|image_pad|>🖼️ ...
================================================================================
危险情况(图像被截断)⚠️
[PROMPT INFO] Truncation detected: 650 -> 512
================================================================================
⚠️ WARNING: IMAGE/VISION TOKENS WERE TRUNCATED!
⚠️ Lost vision token IDs: {151652, 151653}
⚠️ The model will NOT see the image information!
[BEFORE TRUNCATION]
Tokens: <|im_start|> 🖼️<|vision_start|>🖼️ 🖼️<|image_pad|>🖼️ ...
[AFTER TRUNCATION]
Tokens: [2768:' following'] [7033:' math'] ... (无图像 tokens)
================================================================================
如果看到这种警告,必须立即调整配置!