import json
import logging
import os

import torch


def nanstd(tensor: torch.Tensor) -> torch.Tensor:
    """
    Compute the standard deviation of a tensor, ignoring NaNs. This function only supports 1D tensors.

    Args:
        tensor (`torch.Tensor`):
            Input tensor of shape `(N,)`.

    Returns:
        `torch.Tensor`:
            Standard deviation of the tensor, ignoring NaNs.
    """
    # Mean squared deviation from the NaN-aware mean, then Bessel's correction
    # using the count of non-NaN values to obtain the sample variance.
    variance = torch.nanmean((tensor - torch.nanmean(tensor, keepdim=True)) ** 2)
    count = torch.sum(~torch.isnan(tensor))
    variance *= count / (count - 1)
    return torch.sqrt(variance)


def nanmax(tensor: torch.Tensor) -> torch.Tensor:
    """
    Compute the maximum value of a tensor, ignoring NaNs. This function only supports 1D tensors.

    Args:
        tensor (`torch.Tensor`): Input tensor of shape `(N,)`.

    Returns:
        `torch.Tensor`: Maximum value of the tensor, ignoring NaNs. Returns NaN if all values are NaN.
    """
    if torch.isnan(tensor).all():
        return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
    return torch.max(tensor[~torch.isnan(tensor)])


def nanmin(tensor: torch.Tensor) -> torch.Tensor:
    """
    Compute the minimum value of a tensor, ignoring NaNs. This function only supports 1D tensors.

    Args:
        tensor (`torch.Tensor`): Input tensor of shape `(N,)`.

    Returns:
        `torch.Tensor`: Minimum value of the tensor, ignoring NaNs. Returns NaN if all values are NaN.
    """
    if torch.isnan(tensor).all():
        return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
    return torch.min(tensor[~torch.isnan(tensor)])


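# Illustrative check (not part of the original module): with one NaN entry,
# the helpers above operate on the remaining values only.
#
#     >>> t = torch.tensor([1.0, float("nan"), 3.0])
#     >>> nanstd(t)   # sample std of [1.0, 3.0], with Bessel's correction
#     tensor(1.4142)
#     >>> nanmax(t)
#     tensor(3.)
#     >>> nanmin(t)
#     tensor(1.)

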
def init_grpo_log_files(output_dir: str) -> tuple[str, str]:
    """
    Initialize GRPO log files (human-readable txt and machine-readable jsonl).

    Returns the tuple of (txt_log_path, jsonl_log_path).
    """
    grpo_log_file = os.path.join(output_dir, "../logs/grpo_logs.txt")
    grpo_jsonl_file = os.path.join(output_dir, "../logs/grpo_samples.jsonl")
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.dirname(grpo_log_file), exist_ok=True)

    # Write a header to the human-readable log.
    with open(grpo_log_file, "w", encoding="utf-8") as f:
        f.write("=" * 80 + "\n")
        f.write("GRPO Training Logs - WeaverGRPOTrainer\n")
        f.write("=" * 80 + "\n\n")

    # Create (or truncate) the JSONL log.
    with open(grpo_jsonl_file, "w", encoding="utf-8"):
        pass

    return grpo_log_file, grpo_jsonl_file


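# Illustrative usage (paths are hypothetical, not from the original module).
# Note that the files are created in "<output_dir>/../logs/", i.e. in a "logs/"
# directory that sits next to output_dir rather than inside it:
#
#     txt_path, jsonl_path = init_grpo_log_files("runs/exp1/output")
#     # txt_path   == "runs/exp1/output/../logs/grpo_logs.txt"
#     # jsonl_path == "runs/exp1/output/../logs/grpo_samples.jsonl"

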
def log_prompt_truncation(
    prompts_before: torch.Tensor,
    prompts_after: torch.Tensor,
    prompt_mask_before: torch.Tensor,
    prompt_mask_after: torch.Tensor,
    processing_class,
    max_prompt_length: int,
    sample_idx: int = 0,
) -> None:
    """
    Log prompt before and after truncation in token format.
    Also checks whether image/vision tokens were truncated.

    Args:
        prompts_before: Prompt token IDs before truncation [batch_size, seq_len_before]
        prompts_after: Prompt token IDs after truncation [batch_size, seq_len_after]
        prompt_mask_before: Attention mask before truncation
        prompt_mask_after: Attention mask after truncation
        processing_class: Tokenizer or processor for decoding
        max_prompt_length: Maximum prompt length configured
        sample_idx: Index of sample to log (default: 0, first sample in batch)
    """
    # Multimodal processors wrap the tokenizer; fall back to the object itself.
    _tok = getattr(processing_class, "tokenizer", processing_class)

    # Hard-coded Qwen2-VL vision-related special token IDs, extended below with
    # whatever IDs the tokenizer itself reports for the known token names.
    vision_token_ids = [151652, 151653, 151654, 151655]
    vision_token_names = ["<|vision_start|>", "<|vision_end|>", "<|image_pad|>", "<|video_pad|>", "<|vision_pad|>"]
    for token_name in vision_token_names:
        try:
            token_id = _tok.encode(token_name, add_special_tokens=False)
            if isinstance(token_id, list) and len(token_id) > 0:
                if token_id[0] not in vision_token_ids:
                    vision_token_ids.append(token_id[0])
        except Exception:
            pass

    # Pull out the selected sample and keep only tokens under the attention mask.
    prompt_before = prompts_before[sample_idx]
    prompt_after = prompts_after[sample_idx]
    mask_before = prompt_mask_before[sample_idx]
    mask_after = prompt_mask_after[sample_idx]

    valid_tokens_before = prompt_before[mask_before.bool()].tolist()
    valid_tokens_after = prompt_after[mask_after.bool()].tolist()

    # Detect vision tokens that were present before truncation but lost after it.
    vision_tokens_before = set(valid_tokens_before) & set(vision_token_ids)
    vision_tokens_after = set(valid_tokens_after) & set(vision_token_ids)
    vision_tokens_lost = vision_tokens_before - vision_tokens_after
    has_vision_loss = len(vision_tokens_lost) > 0

    def tokens_to_readable(token_ids):
        """Convert token IDs to a readable string with special tokens visible."""
        GREEN = "\033[92m"
        RESET = "\033[0m"

        tokens = []
        prev_tid = None
        consecutive_count = 0  # length of the current run of image-pad tokens

        def flush_pad_run():
            # Emit a collapsed "[IMG]<tok>[/IMG]×N" entry for a pending pad run.
            # (The original flushed only runs of length >= 2, silently dropping
            # a lone pad token; counting from 1 fixes that.)
            nonlocal consecutive_count
            if consecutive_count > 0 and prev_tid is not None:
                try:
                    prev_str = _tok.decode([prev_tid], skip_special_tokens=False)
                    tokens.append(f"{GREEN}[IMG]{prev_str.strip()}[/IMG]{RESET}×{consecutive_count}")
                except Exception:
                    pass
            consecutive_count = 0

        for tid in token_ids:
            try:
                token_str = _tok.decode([tid], skip_special_tokens=False)

                # Image-pad tokens repeat once per image patch; collapse the run
                # into a single entry instead of printing thousands of tokens.
                is_image_pad = tid == 151655 or (tid in vision_token_ids and "pad" in token_str.lower())

                if is_image_pad:
                    if prev_tid != tid:
                        flush_pad_run()  # a different pad run was pending
                    consecutive_count += 1
                    prev_tid = tid
                    continue

                # Current token is not an image pad: flush a pending run, if any.
                flush_pad_run()

                if tid in vision_token_ids:
                    tokens.append(f"{GREEN}[IMG]{token_str.strip()}[/IMG]{RESET}")
                elif tid == _tok.pad_token_id:
                    tokens.append("<|pad|>")
                elif tid == _tok.eos_token_id:
                    tokens.append("<|eos|>")
                elif tid == _tok.bos_token_id:
                    tokens.append("<|bos|>")
                elif token_str.strip() in ["<|im_start|>", "<|im_end|>", "<|im_sep|>"]:
                    tokens.append(token_str.strip())
                else:
                    tokens.append(f"[{tid}:{repr(token_str)}]")

                prev_tid = tid
            except Exception:
                tokens.append(f"[{tid}:?]")
                prev_tid = tid

        # Flush a pad run that reached the end of the sequence.
        flush_pad_run()

        return " ".join(tokens)

logging.info("=" * 80) |
|
|
logging.info(f"[PROMPT TRUNCATION] Sample {sample_idx}") |
|
|
logging.info(f"Length before truncation: {len(valid_tokens_before)}") |
|
|
logging.info(f"Length after truncation: {len(valid_tokens_after)}") |
|
|
logging.info(f"Max prompt length: {max_prompt_length}") |
|
|
logging.info(f"Tokens truncated: {len(valid_tokens_before) - len(valid_tokens_after)}") |
|
|
|
|
|
|
|
|
if has_vision_loss: |
|
|
logging.warning("⚠️ WARNING: IMAGE/VISION TOKENS WERE TRUNCATED!") |
|
|
logging.warning(f"⚠️ Lost vision token IDs: {vision_tokens_lost}") |
|
|
logging.warning(f"⚠️ Vision tokens before: {vision_tokens_before}") |
|
|
logging.warning(f"⚠️ Vision tokens after: {vision_tokens_after}") |
|
|
logging.warning("⚠️ The model will NOT see the image information!") |
|
|
elif len(vision_tokens_before) > 0: |
|
|
logging.info(f"✓ Vision tokens preserved: {vision_tokens_before}") |
|
|
|
|
|
logging.info("-" * 80) |
|
|
|
|
|
|
|
|
logging.info("[BEFORE TRUNCATION]") |
|
|
tokens_before_str = tokens_to_readable(valid_tokens_before) |
|
|
logging.info(f"Tokens: {tokens_before_str}") |
|
|
|
|
|
logging.info("-" * 80) |
|
|
|
|
|
|
|
|
logging.info("[AFTER TRUNCATION]") |
|
|
tokens_after_str = tokens_to_readable(valid_tokens_after) |
|
|
logging.info(f"Tokens: {tokens_after_str}") |
|
|
|
|
|
logging.info("=" * 80) |
|
|
|
|
|
|
|
|
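# Sketch of a call site (names are illustrative, not from this module). TRL-style
# trainers truncate prompts by keeping the last `max_prompt_length` tokens:
#
#     truncated_ids = prompt_ids[:, -max_prompt_length:]
#     truncated_mask = prompt_mask[:, -max_prompt_length:]
#     log_prompt_truncation(
#         prompts_before=prompt_ids,
#         prompts_after=truncated_ids,
#         prompt_mask_before=prompt_mask,
#         prompt_mask_after=truncated_mask,
#         processing_class=processor,  # tokenizer or multimodal processor
#         max_prompt_length=max_prompt_length,
#     )

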
def log_rollout_input(
    prompts: torch.Tensor,
    prompt_mask: torch.Tensor,
    processing_class,
    sample_idx: int = 0,
) -> None:
    """
    Log the input tokens before model generation (rollout).

    Args:
        prompts: Prompt token IDs [batch_size, seq_len]
        prompt_mask: Attention mask [batch_size, seq_len]
        processing_class: Tokenizer or processor for decoding
        sample_idx: Index of sample to log (default: 0, first sample in batch)
    """
    # Multimodal processors wrap the tokenizer; fall back to the object itself.
    _tok = getattr(processing_class, "tokenizer", processing_class)

    # Collect the tokenizer's IDs for the known vision-related token names.
    vision_token_names = ["<|vision_start|>", "<|vision_end|>", "<|image_pad|>", "<|video_pad|>", "<|vision_pad|>"]
    vision_token_ids = []
    for token_name in vision_token_names:
        try:
            token_id = _tok.encode(token_name, add_special_tokens=False)
            if isinstance(token_id, list) and len(token_id) > 0:
                vision_token_ids.append(token_id[0])
        except Exception:
            pass

    # Pull out the selected sample and keep only tokens under the attention mask.
    prompt = prompts[sample_idx]
    mask = prompt_mask[sample_idx]
    valid_tokens = prompt[mask.bool()].tolist()

    vision_tokens_present = set(valid_tokens) & set(vision_token_ids)
    has_vision = len(vision_tokens_present) > 0

    def tokens_to_readable(token_ids):
        """Convert token IDs to a readable string with special tokens visible."""
        GREEN = "\033[92m"
        RESET = "\033[0m"

        tokens = []
        for tid in token_ids:
            try:
                token_str = _tok.decode([tid], skip_special_tokens=False)
                if tid in vision_token_ids:
                    tokens.append(f"{GREEN}[IMG]{token_str.strip()}[/IMG]{RESET}")
                elif tid == _tok.pad_token_id:
                    tokens.append("<|pad|>")
                elif tid == _tok.eos_token_id:
                    tokens.append("<|eos|>")
                elif tid == _tok.bos_token_id:
                    tokens.append("<|bos|>")
                elif token_str.strip() in ["<|im_start|>", "<|im_end|>", "<|im_sep|>"]:
                    tokens.append(token_str.strip())
                else:
                    tokens.append(f"[{tid}:{repr(token_str)}]")
            except Exception:
                tokens.append(f"[{tid}:?]")
        return " ".join(tokens)

logging.info("=" * 80) |
|
|
logging.info(f"[ROLLOUT INPUT] Sample {sample_idx}") |
|
|
logging.info(f"Prompt length: {len(valid_tokens)} tokens") |
|
|
logging.info(f"Batch shape: {prompts.shape}") |
|
|
|
|
|
if has_vision: |
|
|
logging.info(f"✓ Contains vision tokens: {vision_tokens_present}") |
|
|
else: |
|
|
logging.info("ℹ️ No vision tokens detected (text-only prompt)") |
|
|
|
|
|
logging.info("-" * 80) |
|
|
|
|
|
|
|
|
logging.info("[INPUT TOKENS]") |
|
|
tokens_str = tokens_to_readable(valid_tokens) |
|
|
logging.info(f"Tokens: {tokens_str}") |
|
|
logging.info(f"Decoded text: {_tok.decode(valid_tokens, skip_special_tokens=False)}") |
|
|
logging.info("=" * 80) |
|
|
|
|
|
|
|
|
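# Sketch of a call site just before generation (names are illustrative):
#
#     log_rollout_input(prompt_ids, prompt_mask, processing_class=processor)

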
def persist_grpo_logs(
    log_file: str,
    jsonl_file: str,
    step: int,
    mode: str,
    prompt_texts: list[str],
    completion_texts: list[str],
    rewards: list[float],
    rewards_by_func: dict[str, list[float]],
    token_counts: list[int],
    ground_truths: list[str] | None,
    solutions_extracted: list[str] | None,
    verifies: list[bool] | None,
    reward_func_names: list[str],
    stop_reasons: list[str] | None = None,
) -> None:
    """
    Append per-sample human-readable and JSONL logs for GRPO.
    """
    try:
        # Inputs may arrive as nested lists (e.g., grouped per generation);
        # flatten one level so that indexing below is uniform.
        def _flatten(lst):
            if isinstance(lst, list) and len(lst) > 0 and isinstance(lst[0], list):
                return [item for sub in lst for item in sub]
            return lst

        prompt_texts = _flatten(prompt_texts)
        completion_texts = _flatten(completion_texts)
        rewards = _flatten(rewards)
        token_counts = _flatten(token_counts)
        rewards_by_func = {k: _flatten(v) for k, v in rewards_by_func.items()}
        stop_reasons = _flatten(stop_reasons) if stop_reasons is not None else None
        ground_truths = _flatten(ground_truths) if ground_truths is not None else None
        solutions_extracted = _flatten(solutions_extracted) if solutions_extracted is not None else None
        verifies = _flatten(verifies) if verifies is not None else None

        # Log only as many samples as every provided list can supply.
        n = min(
            len(prompt_texts),
            len(completion_texts),
            len(rewards),
            len(token_counts),
            *[len(rewards_by_func[name]) for name in reward_func_names],
            *([len(ground_truths)] if ground_truths is not None else []),
            *([len(solutions_extracted)] if solutions_extracted is not None else []),
            *([len(verifies)] if verifies is not None else []),
            *([len(stop_reasons)] if stop_reasons is not None else []),
        )
        if n == 0:
            return

        # Human-readable log.
        with open(log_file, "a", encoding="utf-8") as f_txt:
            f_txt.write(f"\n{'='*80}\n")
            f_txt.write(f"Step: {step} | Mode: {mode}\n")
            f_txt.write(f"{'='*80}\n")
            for idx in range(n):
                p_txt = prompt_texts[idx]
                c_txt = completion_texts[idx]
                r_total = rewards[idx]
                f_txt.write(f"\n[Sample {idx}]\n")
                f_txt.write(f"Prompt: {p_txt}\n")
                comp_str = ", ".join(
                    f"{name}: {float(rewards_by_func[name][idx]):.6f}" for name in reward_func_names
                )
                f_txt.write(f"Reward: {float(r_total):.6f} | Components: {comp_str}\n")
                if ground_truths is not None:
                    f_txt.write(f"Ground truth: {ground_truths[idx]}\n")
                if solutions_extracted is not None:
                    f_txt.write(f"Solution: {solutions_extracted[idx]}\n")
                if verifies is not None:
                    f_txt.write(f"Verify: {bool(verifies[idx])}\n")
                s_reason = (
                    stop_reasons[idx]
                    if stop_reasons is not None and idx < len(stop_reasons)
                    else "unknown"
                )
                f_txt.write(f"Stop reason: {s_reason}\n")
                f_txt.write(f"Completion: {c_txt}\n")
                f_txt.write(f"{'-'*80}\n")

        # Machine-readable JSONL log (one record per sample).
        with open(jsonl_file, "a", encoding="utf-8") as f_jsonl:
            for idx in range(n):
                s_reason = (
                    stop_reasons[idx]
                    if stop_reasons is not None and idx < len(stop_reasons)
                    else "unknown"
                )
                record = {
                    "reward": float(rewards[idx]),
                    "token_count": int(token_counts[idx]),
                    "stop_reason": s_reason,
                }
                if ground_truths is not None:
                    record["ground_truth"] = ground_truths[idx]
                if solutions_extracted is not None:
                    record["solution"] = solutions_extracted[idx]
                if verifies is not None:
                    record["verify"] = bool(verifies[idx])
                record["completion"] = completion_texts[idx]
                f_jsonl.write(json.dumps(record, ensure_ascii=False) + "\n")
    except Exception as e:
        logging.warning(f"Failed to persist GRPO logs: {e}")