model111 / IMAGE_TOKEN_TRUNCATION_WARNING.md
LCZZZZ's picture
Upload MemGen code and data
e34b94f verified

⚠️ 图像Token截断问题说明

问题描述

在 GRPO 训练中,当 prompt 长度超过 max_prompt_length (默认512) 时,代码会从右侧截取 tokens:

prompts = prompts[:, -self.max_prompt_length :]

这意味着左侧的 tokens 会被截断

风险

对于包含图像的 prompt,图像 tokens 通常位于 prompt 的开头或中间位置:

<|im_start|>user
🖼️<|vision_start|><|image_pad|><|vision_end|>🖼️
这是一道数学题,请解答...
<|im_end|><|im_start|>assistant

如果 prompt 太长被截断,图像 tokens 会丢失,导致:

  • ❌ 模型无法看到图像信息
  • ❌ 只能基于文本部分回答
  • ❌ 对于视觉问答任务,性能会严重下降

检测机制

我已经在 log_prompt_truncation 函数中添加了自动检测:

1. 检测的 Vision Token 类型

  • <|vision_start|>
  • <|vision_end|>
  • <|image_pad|>
  • <|video_pad|>
  • <|vision_pad|>

2. 日志输出

如果图像 tokens 被截断,会显示警告:

⚠️  WARNING: IMAGE/VISION TOKENS WERE TRUNCATED!
⚠️  Lost vision token IDs: {151652, 151653, 151654}
⚠️  Vision tokens before: {151652, 151653, 151654, 151655}
⚠️  Vision tokens after: {151655}
⚠️  The model will NOT see the image information!

如果图像 tokens 保留完整

✓ Vision tokens preserved: {151652, 151653, 151654, 151655}

3. Token 可视化

在 token 列表中,vision tokens 会用 🖼️ 标记:

[BEFORE TRUNCATION]
Tokens: <|im_start|> [1587:'user'] 🖼️<|vision_start|>🖼️ 🖼️<|image_pad|>🖼️ 🖼️<|vision_end|>🖼️ ...

[AFTER TRUNCATION]
Tokens: [2768:' following'] [7033:' math'] [3575:' problem'] ...

解决方案

方案 1:增加 max_prompt_length(推荐)

修改配置文件 configs/latent_memory/mm_math.yaml

generation:
  max_start_length: 1024  # 从 512 增加到 1024 或更大
  max_prompt_length: 4096

方案 2:优化 Prompt 长度

  • 简化问题描述
  • 移除不必要的上下文
  • 确保图像相关的核心内容在前 512 tokens 内

方案 3:修改截断策略(需要代码修改)

当前是从右侧截取(保留后面的内容):

prompts = prompts[:, -self.max_prompt_length :]  # 保留右侧

可以改为从左侧截取(保留前面的内容,包括图像):

prompts = prompts[:, :self.max_prompt_length]  # 保留左侧

但这样会丢失问题的后半部分,需要权衡。

方案 4:智能截断(最佳但复杂)

检测 vision tokens 的位置,确保它们不被截断:

  1. 找到 vision tokens 的位置
  2. 如果需要截断,从 vision tokens 之后开始截取
  3. 保证图像信息始终保留

当前状态

已添加检测和警告机制

  • 自动检测 vision tokens 是否被截断
  • 在日志中高亮显示 vision tokens
  • 如果截断发生会显著警告

⚠️ 需要手动配置

  • 根据实际数据集调整 max_start_length
  • 监控日志中的截断警告
  • 如果频繁出现截断,增加长度限制

日志输出层次

现在有三个层次的日志输出:

1. PROMPT INFO - 基本信息

[PROMPT INFO] Original prompt shape: torch.Size([8, 650]), max_prompt_length: 512
[PROMPT INFO] After truncation shape: torch.Size([8, 512])
[PROMPT INFO] Truncation detected: 650 -> 512

2. PROMPT TRUNCATION - 截断详情(如果发生截断)

================================================================================
[PROMPT TRUNCATION] Sample 0
Length before truncation: 650
Length after truncation: 512
⚠️  WARNING: IMAGE/VISION TOKENS WERE TRUNCATED!
...
================================================================================

3. ROLLOUT INPUT - 传给 trainer 的输入

================================================================================
[ROLLOUT INPUT] Sample 0
Prompt length: 512 tokens
✓ Contains vision tokens: {151652, 151653}
[INPUT TOKENS]
Tokens: <|im_start|> 🖼️<|vision_start|>🖼️ ...
================================================================================

4. MODEL.GENERATE INPUT - 实际传给模型的输入

================================================================================
[MODEL.GENERATE INPUT] Sample 0
Input length: 512 tokens
✓ Contains vision tokens: {151652, 151653}
[TOKENS TO MODEL]
Tokens: <|im_start|> 🖼️<|vision_start|>🖼️ ...
================================================================================

查看日志

运行训练后,检查是否有截断警告:

# 查看所有 prompt 信息
grep "\[PROMPT INFO\]" test_output/mm_math/logs/log.txt

# 查看截断警告
grep "WARNING.*VISION.*TRUNCATED" test_output/mm_math/logs/log.txt

# 查看详细的截断日志
grep -A 30 "\[PROMPT TRUNCATION\]" test_output/mm_math/logs/log.txt

# 查看 rollout 输入
grep -A 15 "\[ROLLOUT INPUT\]" test_output/mm_math/logs/log.txt

# 查看实际传给模型的输入
grep -A 15 "\[MODEL.GENERATE INPUT\]" test_output/mm_math/logs/log.txt

建议

  1. 训练前:先运行一个 epoch,检查日志中是否有 vision token 截断警告
  2. 如果有警告:立即增加 max_start_length,重新开始训练
  3. 监控:定期检查日志,确保图像信息没有丢失
  4. 数据统计:统计数据集中 prompt 长度分布,设置合适的 max_start_length

示例输出

正常情况(无截断)

[PROMPT INFO] Original prompt shape: torch.Size([8, 450]), max_prompt_length: 512
[PROMPT INFO] After truncation shape: torch.Size([8, 450])
[PROMPT INFO] No truncation needed: length 450 <= max 512

有截断但保留图像

[PROMPT INFO] Truncation detected: 650 -> 512
================================================================================
[PROMPT TRUNCATION] Sample 0
Length before truncation: 650
Length after truncation: 512
✓ Vision tokens preserved: {151652, 151653, 151654}
[BEFORE TRUNCATION]
Tokens: <|im_start|> 🖼️<|vision_start|>🖼️ 🖼️<|image_pad|>🖼️ ...
[AFTER TRUNCATION]
Tokens: 🖼️<|vision_start|>🖼️ 🖼️<|image_pad|>🖼️ ...
================================================================================

危险情况(图像被截断)⚠️

[PROMPT INFO] Truncation detected: 650 -> 512
================================================================================
⚠️  WARNING: IMAGE/VISION TOKENS WERE TRUNCATED!
⚠️  Lost vision token IDs: {151652, 151653}
⚠️  The model will NOT see the image information!
[BEFORE TRUNCATION]
Tokens: <|im_start|> 🖼️<|vision_start|>🖼️ 🖼️<|image_pad|>🖼️ ...
[AFTER TRUNCATION]
Tokens: [2768:' following'] [7033:' math'] ... (无图像 tokens)
================================================================================

如果看到这种警告,必须立即调整配置!