LCZZZZ's picture
Upload MemGen code and data
e34b94f verified
import torch
import os
import json
import logging
from typing import List, Dict
# torch.nanstd doesn't exist, so we define it here
def nanstd(tensor: torch.Tensor) -> torch.Tensor:
    """
    Return the sample standard deviation of a 1D tensor, skipping NaN entries.

    Args:
        tensor (`torch.Tensor`):
            Input tensor of shape `(N,)`.

    Returns:
        `torch.Tensor`:
            Standard deviation over the non-NaN entries, with Bessel's
            correction applied (divides by `count - 1`).
    """
    # Number of entries that actually take part in the statistic.
    valid_count = torch.sum(~torch.isnan(tensor))

    # Biased variance: NaN-aware mean of the squared deviations from the NaN-aware mean.
    centered = tensor - torch.nanmean(tensor, keepdim=True)
    variance = torch.nanmean(centered ** 2)

    # Bessel's correction, applied in place so the result keeps the dtype of `variance`.
    variance.mul_(valid_count / (valid_count - 1))
    return torch.sqrt(variance)
def nanmax(tensor: torch.Tensor) -> torch.Tensor:
    """
    Return the largest non-NaN value of a 1D tensor.

    Args:
        tensor (`torch.Tensor`): Input tensor of shape `(N,)`.

    Returns:
        `torch.Tensor`: Maximum over the non-NaN entries. When every entry is NaN,
        a NaN scalar with the input's dtype and device is returned instead.
    """
    nan_mask = torch.isnan(tensor)
    if nan_mask.all():
        # Nothing to take a maximum over.
        return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
    finite_part = tensor[~nan_mask]
    return torch.max(finite_part)
def nanmin(tensor: torch.Tensor) -> torch.Tensor:
    """
    Return the smallest non-NaN value of a 1D tensor.

    Args:
        tensor (`torch.Tensor`): Input tensor of shape `(N,)`.

    Returns:
        `torch.Tensor`: Minimum over the non-NaN entries. When every entry is NaN,
        a NaN scalar with the input's dtype and device is returned instead.
    """
    nan_mask = torch.isnan(tensor)
    if nan_mask.all():
        # Nothing to take a minimum over.
        return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
    finite_part = tensor[~nan_mask]
    return torch.min(finite_part)
def init_grpo_log_files(output_dir: str) -> tuple[str, str]:
    """
    Create (or truncate) the two GRPO log files and return their paths.

    Both files are placed in a `logs` directory that is a sibling of
    `output_dir` (`<output_dir>/../logs`). The text log starts with a banner;
    the JSONL file starts empty.

    Returns:
        `(txt_log_path, jsonl_log_path)`, exactly as joined (the `..` segment
        is not normalised).
    """
    txt_path = os.path.join(output_dir, "../logs/grpo_logs.txt")
    jsonl_path = os.path.join(output_dir, "../logs/grpo_samples.jsonl")

    # `output_dir` is created first so that the `..` in the log path resolves.
    for directory in (output_dir, os.path.dirname(txt_path)):
        os.makedirs(directory, exist_ok=True)

    rule = "=" * 80
    banner = [
        rule + "\n",
        "GRPO Training Logs - WeaverGRPOTrainer\n",
        rule + "\n\n",
    ]
    # Mode "w" discards any log left over from a previous run.
    with open(txt_path, "w", encoding="utf-8") as txt_handle:
        for line in banner:
            txt_handle.write(line)

    # Opening in "w" mode and closing again is enough to create/empty the JSONL file.
    with open(jsonl_path, "w", encoding="utf-8"):
        pass

    return txt_path, jsonl_path
def log_prompt_truncation(
    prompts_before: torch.Tensor,
    prompts_after: torch.Tensor,
    prompt_mask_before: torch.Tensor,
    prompt_mask_after: torch.Tensor,
    processing_class,
    max_prompt_length: int,
    sample_idx: int = 0
) -> None:
    """
    Log one sample's prompt tokens before and after truncation.

    Also checks whether image/vision tokens were truncated. The check compares
    per-ID occurrence counts, so losing only part of a run of image-pad tokens
    is reported as well as losing a token type entirely.

    Args:
        prompts_before: Prompt token IDs before truncation [batch_size, seq_len_before]
        prompts_after: Prompt token IDs after truncation [batch_size, seq_len_after]
        prompt_mask_before: Attention mask before truncation
        prompt_mask_after: Attention mask after truncation
        processing_class: Tokenizer or processor for decoding
        max_prompt_length: Maximum prompt length configured (only reported in the log)
        sample_idx: Index of sample to log (default: 0, first sample in batch)
    """
    # A processor exposes its tokenizer as `.tokenizer`; a bare tokenizer is used as-is.
    _tok = getattr(processing_class, "tokenizer", processing_class)

    # Known vision token IDs, used directly regardless of the tokenizer.
    # The original comments attribute these to Qwen2.5-VL
    # (<|vision_start|>, <|vision_end|>, a pad token, <|image_pad|>).
    # NOTE(review): confirm the exact name <-> ID mapping against the tokenizer in use.
    image_pad_id = 151655
    vision_token_ids = [151652, 151653, 151654, image_pad_id]

    # Also try to discover the IDs from the tokenizer itself.
    vision_token_names = ["<|vision_start|>", "<|vision_end|>", "<|image_pad|>", "<|video_pad|>", "<|vision_pad|>"]
    for token_name in vision_token_names:
        try:
            encoded = _tok.encode(token_name, add_special_tokens=False)
        except Exception:
            # Best effort: a tokenizer that cannot encode the name just contributes nothing.
            continue
        # Only a name that maps to exactly ONE id is a real special token. If it
        # splits into several pieces, the first piece is an ordinary sub-token
        # and must not be treated as a vision token.
        if isinstance(encoded, list) and len(encoded) == 1 and encoded[0] not in vision_token_ids:
            vision_token_ids.append(encoded[0])
    vision_id_set = set(vision_token_ids)

    # Extract the single sample and drop padding positions (mask == 0).
    valid_tokens_before = prompts_before[sample_idx][prompt_mask_before[sample_idx].bool()].tolist()
    valid_tokens_after = prompts_after[sample_idx][prompt_mask_after[sample_idx].bool()].tolist()

    def _vision_counts(token_ids):
        """Return {vision_token_id: occurrences} for the given token list."""
        counts = {}
        for tid in token_ids:
            if tid in vision_id_set:
                counts[tid] = counts.get(tid, 0) + 1
        return counts

    counts_before = _vision_counts(valid_tokens_before)
    counts_after = _vision_counts(valid_tokens_after)
    vision_tokens_before = set(counts_before)
    vision_tokens_after = set(counts_after)
    # A vision token counts as lost when it occurs fewer times after truncation,
    # which includes disappearing entirely.
    vision_tokens_lost = {
        tid for tid, count in counts_before.items() if counts_after.get(tid, 0) < count
    }
    has_vision_loss = bool(vision_tokens_lost)

    # Looked up once; a tokenizer without one of these attributes simply never matches.
    pad_id = getattr(_tok, "pad_token_id", None)
    eos_id = getattr(_tok, "eos_token_id", None)
    bos_id = getattr(_tok, "bos_token_id", None)
    chat_markers = ("<|im_start|>", "<|im_end|>", "<|im_sep|>")

    def tokens_to_readable(token_ids):
        """Convert token IDs to readable string with special tokens visible."""
        # ANSI escape codes for colors
        GREEN = "\033[92m"
        RESET = "\033[0m"

        pieces = []
        # State of the run of identical image-pad tokens currently being collapsed.
        run_tid = None
        run_text = ""
        run_len = 0

        def flush_run():
            """Emit the pending pad run (if any) as one entry and reset the state."""
            nonlocal run_tid, run_len
            if run_tid is None:
                return
            rendered = f"{GREEN}[IMG]{run_text}[/IMG]{RESET}"
            if run_len > 1:
                rendered += f"×{run_len}"
            pieces.append(rendered)
            run_tid = None
            run_len = 0

        for tid in token_ids:
            try:
                # Decode a single token; str() guards against a non-string result.
                token_str = str(_tok.decode([tid], skip_special_tokens=False))
            except Exception:
                flush_run()
                pieces.append(f"[{tid}:?]")
                continue

            is_image_pad = tid == image_pad_id or (tid in vision_id_set and "pad" in token_str.lower())
            if is_image_pad:
                # Extend the current run, or start a new one. A run of length one
                # is still emitted by flush_run(), so a lone pad is never dropped.
                if tid != run_tid:
                    flush_run()
                    run_tid = tid
                    run_text = token_str.strip()
                run_len += 1
                continue

            # Any non-pad token ends a pending pad run.
            flush_run()
            stripped = token_str.strip()
            if tid in vision_id_set:
                pieces.append(f"{GREEN}[IMG]{stripped}[/IMG]{RESET}")
            elif tid == pad_id:
                pieces.append("<|pad|>")
            elif tid == eos_id:
                pieces.append("<|eos|>")
            elif tid == bos_id:
                pieces.append("<|bos|>")
            elif stripped in chat_markers:
                pieces.append(stripped)
            else:
                pieces.append(f"[{tid}:{repr(token_str)}]")

        # The sequence may end in the middle of a pad run.
        flush_run()
        return " ".join(pieces)

    # Log information
    logging.info("=" * 80)
    logging.info(f"[PROMPT TRUNCATION] Sample {sample_idx}")
    logging.info(f"Length before truncation: {len(valid_tokens_before)}")
    logging.info(f"Length after truncation: {len(valid_tokens_after)}")
    logging.info(f"Max prompt length: {max_prompt_length}")
    logging.info(f"Tokens truncated: {len(valid_tokens_before) - len(valid_tokens_after)}")

    # Warn if vision tokens were lost
    if has_vision_loss:
        logging.warning("⚠️ WARNING: IMAGE/VISION TOKENS WERE TRUNCATED!")
        logging.warning(f"⚠️ Lost vision token IDs: {vision_tokens_lost}")
        logging.warning(f"⚠️ Vision tokens before: {vision_tokens_before}")
        logging.warning(f"⚠️ Vision tokens after: {vision_tokens_after}")
        logging.warning(f"⚠️ Vision token counts before: {counts_before} | after: {counts_after}")
        logging.warning("⚠️ The model will NOT see the image information!")
    elif len(vision_tokens_before) > 0:
        logging.info(f"✓ Vision tokens preserved: {vision_tokens_before}")
    logging.info("-" * 80)

    # Log tokens before truncation
    logging.info("[BEFORE TRUNCATION]")
    logging.info(f"Tokens: {tokens_to_readable(valid_tokens_before)}")
    logging.info("-" * 80)

    # Log tokens after truncation
    logging.info("[AFTER TRUNCATION]")
    logging.info(f"Tokens: {tokens_to_readable(valid_tokens_after)}")
    logging.info("=" * 80)
def log_rollout_input(
    prompts: torch.Tensor,
    prompt_mask: torch.Tensor,
    processing_class,
    sample_idx: int = 0
) -> None:
    """
    Log the input tokens before model generation (rollout).

    Args:
        prompts: Prompt token IDs [batch_size, seq_len]
        prompt_mask: Attention mask [batch_size, seq_len]
        processing_class: Tokenizer or processor for decoding
        sample_idx: Index of sample to log (default: 0, first sample in batch)
    """
    # A processor exposes its tokenizer as `.tokenizer`; a bare tokenizer is used as-is.
    _tok = getattr(processing_class, "tokenizer", processing_class)

    # Discover the vision/image token IDs from the tokenizer.
    vision_token_names = ["<|vision_start|>", "<|vision_end|>", "<|image_pad|>", "<|video_pad|>", "<|vision_pad|>"]
    vision_token_ids = set()
    for token_name in vision_token_names:
        try:
            encoded = _tok.encode(token_name, add_special_tokens=False)
        except Exception:
            # Best effort: a tokenizer that cannot encode the name just contributes nothing.
            continue
        # Only a name that maps to exactly ONE id is a real special token. If it
        # splits into several pieces, the first piece is an ordinary sub-token
        # and must not be treated as a vision token.
        if isinstance(encoded, list) and len(encoded) == 1:
            vision_token_ids.add(encoded[0])

    # Extract the single sample and drop padding positions (mask == 0).
    valid_tokens = prompts[sample_idx][prompt_mask[sample_idx].bool()].tolist()

    # Check for vision tokens
    vision_tokens_present = set(valid_tokens) & vision_token_ids
    has_vision = len(vision_tokens_present) > 0

    # Looked up once; a tokenizer without one of these attributes simply never matches.
    pad_id = getattr(_tok, "pad_token_id", None)
    eos_id = getattr(_tok, "eos_token_id", None)
    bos_id = getattr(_tok, "bos_token_id", None)
    chat_markers = ("<|im_start|>", "<|im_end|>", "<|im_sep|>")

    def tokens_to_readable(token_ids):
        """Convert token IDs to readable string with special tokens visible."""
        # ANSI escape codes for colors
        GREEN = "\033[92m"
        RESET = "\033[0m"
        pieces = []
        for tid in token_ids:
            try:
                # Decode a single token; str() guards against a non-string result.
                token_str = str(_tok.decode([tid], skip_special_tokens=False))
            except Exception:
                pieces.append(f"[{tid}:?]")
                continue
            stripped = token_str.strip()
            if tid in vision_token_ids:
                # Highlight vision tokens
                pieces.append(f"{GREEN}[IMG]{stripped}[/IMG]{RESET}")
            elif tid == pad_id:
                pieces.append("<|pad|>")
            elif tid == eos_id:
                pieces.append("<|eos|>")
            elif tid == bos_id:
                pieces.append("<|bos|>")
            elif stripped in chat_markers:
                pieces.append(stripped)
            else:
                pieces.append(f"[{tid}:{repr(token_str)}]")
        return " ".join(pieces)

    # Log information
    logging.info("=" * 80)
    logging.info(f"[ROLLOUT INPUT] Sample {sample_idx}")
    logging.info(f"Prompt length: {len(valid_tokens)} tokens")
    logging.info(f"Batch shape: {prompts.shape}")
    if has_vision:
        logging.info(f"✓ Contains vision tokens: {vision_tokens_present}")
    else:
        logging.info("ℹ️ No vision tokens detected (text-only prompt)")
    logging.info("-" * 80)

    # Log tokens
    logging.info("[INPUT TOKENS]")
    logging.info(f"Tokens: {tokens_to_readable(valid_tokens)}")
    logging.info(f"Decoded text: {_tok.decode(valid_tokens, skip_special_tokens=False)}")
    logging.info("=" * 80)
def persist_grpo_logs(
log_file: str,
jsonl_file: str,
step: int,
mode: str,
prompt_texts: list[str],
completion_texts: list[str],
rewards: list[float],
rewards_by_func: dict[str, list[float]],
token_counts: list[int],
ground_truths: list[str] | None,
solutions_extracted: list[str] | None,
verifies: list[bool] | None,
reward_func_names: list[str],
stop_reasons: list[str] | None = None,
) -> None:
"""
Append per-sample human-readable and JSONL logs for GRPO.
"""
try:
# Flatten possibly nested lists (from distributed gather)
def _flatten(lst):
if isinstance(lst, list) and len(lst) > 0 and isinstance(lst[0], list):
return [item for sub in lst for item in sub]
return lst
prompt_texts = _flatten(prompt_texts)
completion_texts = _flatten(completion_texts)
rewards = _flatten(rewards)
token_counts = _flatten(token_counts)
rewards_by_func = {k: _flatten(v) for k, v in rewards_by_func.items()}
stop_reasons = _flatten(stop_reasons) if stop_reasons is not None else None
ground_truths = _flatten(ground_truths) if ground_truths is not None else None
solutions_extracted = _flatten(solutions_extracted) if solutions_extracted is not None else None
verifies = _flatten(verifies) if verifies is not None else None
# Guard against length mismatches
n = min(
len(prompt_texts),
len(completion_texts),
len(rewards),
len(token_counts),
*[len(rewards_by_func[name]) for name in reward_func_names],
*( [len(ground_truths)] if ground_truths is not None else [] ),
*( [len(solutions_extracted)] if solutions_extracted is not None else [] ),
*( [len(verifies)] if verifies is not None else [] ),
*( [len(stop_reasons)] if stop_reasons is not None else [] ),
)
if n == 0:
return
with open(log_file, "a", encoding="utf-8") as f_txt:
f_txt.write(f"\n{'='*80}\n")
f_txt.write(f"Step: {step} | Mode: {mode}\n")
f_txt.write(f"{'='*80}\n")
for idx in range(n):
p_txt = prompt_texts[idx]
c_txt = completion_texts[idx]
r_total = rewards[idx]
f_txt.write(f"\n[Sample {idx}]\n")
f_txt.write(f"Prompt: {p_txt}\n")
comp_str = ", ".join([f"{name}: {float(rewards_by_func[name][idx]):.6f}" for name in reward_func_names])
f_txt.write(f"Reward: {float(r_total):.6f} | Components: {comp_str}\n")
if ground_truths is not None:
f_txt.write(f"Ground truth: {ground_truths[idx]}\n")
if solutions_extracted is not None:
f_txt.write(f"Solution: {solutions_extracted[idx]}\n")
if verifies is not None:
f_txt.write(f"Verify: {bool(verifies[idx])}\n")
s_reason = (
stop_reasons[idx]
if stop_reasons is not None and idx < len(stop_reasons)
else "unknown"
)
f_txt.write(f"Stop reason: {s_reason}\n")
# Always place completion last in the per-sample block
f_txt.write(f"Completion: {c_txt}\n")
f_txt.write(f"{'-'*80}\n")
with open(jsonl_file, "a", encoding="utf-8") as f_jsonl:
for idx in range(n):
s_reason = (
stop_reasons[idx]
if stop_reasons is not None and idx < len(stop_reasons)
else "unknown"
)
record = {
"reward": float(rewards[idx]),
"token_count": int(token_counts[idx]),
# "step": int(step),
# "mode": mode,
# "sample_index": int(idx),
"stop_reason": s_reason,
}
if ground_truths is not None:
record["ground_truth"] = ground_truths[idx]
if solutions_extracted is not None:
record["solution"] = solutions_extracted[idx]
if verifies is not None:
record["verify"] = bool(verifies[idx])
# Ensure completion is always the last field
record["completion"] = completion_texts[idx]
f_jsonl.write(json.dumps(record, ensure_ascii=False) + "\n")
except Exception as e:
logging.warning(f"Failed to persist GRPO logs: {e}")