File size: 17,807 Bytes
e34b94f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
import torch
import os
import json
import logging
from typing import List, Dict

# torch.nanstd doesn't exist, so we define it here
def nanstd(tensor: torch.Tensor) -> torch.Tensor:
    """
    Standard deviation of a 1D tensor with NaN entries ignored.

    Applies Bessel's correction (N - 1 denominator) over the non-NaN subset,
    matching `torch.std`'s default behaviour.

    Args:
        tensor (`torch.Tensor`):
            Input tensor of shape `(N,)`.

    Returns:
        `torch.Tensor`:
            Scalar standard deviation, ignoring NaNs.
    """
    center = torch.nanmean(tensor, keepdim=True)
    biased_var = torch.nanmean((tensor - center) ** 2)
    n_valid = torch.sum(~torch.isnan(tensor))
    # Bessel's correction: rescale the biased estimate by N / (N - 1).
    return torch.sqrt(biased_var * (n_valid / (n_valid - 1)))

def nanmax(tensor: torch.Tensor) -> torch.Tensor:
    """
    Maximum of a 1D tensor with NaN entries ignored.

    Args:
        tensor (`torch.Tensor`): Input tensor of shape `(N,)`.

    Returns:
        `torch.Tensor`: Maximum over the non-NaN values; NaN when every value is NaN.
    """
    valid = ~torch.isnan(tensor)
    if not valid.any():
        # All-NaN input: propagate NaN with the input's dtype and device.
        return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
    return tensor[valid].max()

def nanmin(tensor: torch.Tensor) -> torch.Tensor:
    """
    Minimum of a 1D tensor with NaN entries ignored.

    Args:
        tensor (`torch.Tensor`): Input tensor of shape `(N,)`.

    Returns:
        `torch.Tensor`: Minimum over the non-NaN values; NaN when every value is NaN.
    """
    valid = ~torch.isnan(tensor)
    if not valid.any():
        # All-NaN input: propagate NaN with the input's dtype and device.
        return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
    return tensor[valid].min()


def init_grpo_log_files(output_dir: str) -> tuple[str, str]:
    """
    Create (or reset) the GRPO log files.

    Two files are prepared under ``<output_dir>/../logs``: a human-readable
    text log (seeded with a banner header) and an empty machine-readable
    JSONL file for per-sample records.

    Args:
        output_dir (`str`): Training output directory; logs live beside it.

    Returns:
        `tuple[str, str]`: Paths of (txt_log_path, jsonl_log_path).
    """
    txt_path = os.path.join(output_dir, "../logs/grpo_logs.txt")
    jsonl_path = os.path.join(output_dir, "../logs/grpo_samples.jsonl")

    # Ensure both the output dir and the sibling logs dir exist.
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.dirname(txt_path), exist_ok=True)

    banner = "=" * 80
    # Truncate the text log and write the header banner.
    with open(txt_path, "w", encoding="utf-8") as handle:
        handle.write(f"{banner}\nGRPO Training Logs - WeaverGRPOTrainer\n{banner}\n\n")

    # Truncate the JSONL file so each run starts from an empty sample log.
    with open(jsonl_path, "w", encoding="utf-8"):
        pass

    return txt_path, jsonl_path


def log_prompt_truncation(
    prompts_before: torch.Tensor,
    prompts_after: torch.Tensor,
    prompt_mask_before: torch.Tensor,
    prompt_mask_after: torch.Tensor,
    processing_class,
    max_prompt_length: int,
    sample_idx: int = 0
) -> None:
    """
    Log one sample's prompt tokens before and after truncation.

    Also warns when image/vision tokens are present before truncation but
    missing afterwards — in that case the model will not see the image.

    Args:
        prompts_before: Prompt token IDs before truncation [batch_size, seq_len_before]
        prompts_after: Prompt token IDs after truncation [batch_size, seq_len_after]
        prompt_mask_before: Attention mask before truncation
        prompt_mask_after: Attention mask after truncation
        processing_class: Tokenizer or processor for decoding
        max_prompt_length: Maximum prompt length configured
        sample_idx: Index of sample to log (default: 0, first sample in batch)
    """
    # A processor wraps the tokenizer; fall back to the object itself.
    _tok = getattr(processing_class, "tokenizer", processing_class)

    # Known Qwen2.5-VL vision token IDs, used even when the tokenizer cannot
    # encode the names below:
    # 151652: <|vision_start|>, 151653: <|vision_end|>,
    # 151654: <|video_pad|>,    151655: <|image_pad|>
    vision_token_ids = [151652, 151653, 151654, 151655]

    # Also resolve vision token IDs from the tokenizer itself (best effort),
    # so non-Qwen vocabularies are covered too.
    vision_token_names = ["<|vision_start|>", "<|vision_end|>", "<|image_pad|>", "<|video_pad|>", "<|vision_pad|>"]
    for token_name in vision_token_names:
        try:
            token_id = _tok.encode(token_name, add_special_tokens=False)
            if isinstance(token_id, list) and len(token_id) > 0:
                if token_id[0] not in vision_token_ids:
                    vision_token_ids.append(token_id[0])
        except Exception:
            pass  # tokenizer may not know this token name; ignore

    # Extract the requested sample and drop padding (mask == 0) positions.
    prompt_before = prompts_before[sample_idx]
    prompt_after = prompts_after[sample_idx]
    mask_before = prompt_mask_before[sample_idx]
    mask_after = prompt_mask_after[sample_idx]
    valid_tokens_before = prompt_before[mask_before.bool()].tolist()
    valid_tokens_after = prompt_after[mask_after.bool()].tolist()

    # Detect vision tokens that were lost to truncation.
    vision_tokens_before = set(valid_tokens_before) & set(vision_token_ids)
    vision_tokens_after = set(valid_tokens_after) & set(vision_token_ids)
    vision_tokens_lost = vision_tokens_before - vision_tokens_after
    has_vision_loss = len(vision_tokens_lost) > 0

    def tokens_to_readable(token_ids):
        """Render token IDs as a readable string; consecutive runs of the
        same image-pad token are collapsed to one entry with a repeat count.

        Fixed: a run of exactly one image-pad token was previously dropped
        from the output (the old counter-based flush required a run >= 2).
        """
        # ANSI escape codes for colors
        GREEN = "\033[92m"
        RESET = "\033[0m"

        tokens = []
        # Pending run of identical image-pad tokens: (token_id, run_length).
        pending_pad = None

        def flush_pad_run():
            # Emit the pending image-pad run (if any) as a single entry.
            nonlocal pending_pad
            if pending_pad is None:
                return
            run_tid, run_len = pending_pad
            pending_pad = None
            try:
                pad_str = _tok.decode([run_tid], skip_special_tokens=False).strip()
            except Exception:
                tokens.append(f"[{run_tid}:?]")
                return
            if run_len > 1:
                tokens.append(f"{GREEN}[IMG]{pad_str}[/IMG]{RESET}×{run_len}")
            else:
                tokens.append(f"{GREEN}[IMG]{pad_str}[/IMG]{RESET}")

        for tid in token_ids:
            try:
                token_str = _tok.decode([tid], skip_special_tokens=False)

                # <|image_pad|> (151655) and other vision pad tokens repeat
                # once per image patch, so collapse consecutive runs.
                is_image_pad = tid == 151655 or (tid in vision_token_ids and 'pad' in token_str.lower())
                if is_image_pad:
                    if pending_pad is not None and pending_pad[0] == tid:
                        pending_pad = (tid, pending_pad[1] + 1)
                    else:
                        flush_pad_run()
                        pending_pad = (tid, 1)
                    continue

                flush_pad_run()
                # Highlight non-pad vision tokens
                if tid in vision_token_ids:
                    tokens.append(f"{GREEN}[IMG]{token_str.strip()}[/IMG]{RESET}")
                # Show special tokens
                elif tid == _tok.pad_token_id:
                    tokens.append("<|pad|>")
                elif tid == _tok.eos_token_id:
                    tokens.append("<|eos|>")
                elif tid == _tok.bos_token_id:
                    tokens.append("<|bos|>")
                elif token_str.strip() in ["<|im_start|>", "<|im_end|>", "<|im_sep|>"]:
                    tokens.append(token_str.strip())
                else:
                    tokens.append(f"[{tid}:{repr(token_str)}]")
            except Exception:
                # Flush first so the error marker cannot corrupt a pad run.
                flush_pad_run()
                tokens.append(f"[{tid}:?]")

        # Emit any run that reaches the end of the sequence.
        flush_pad_run()
        return " ".join(tokens)

    # Log summary information
    logging.info("=" * 80)
    logging.info(f"[PROMPT TRUNCATION] Sample {sample_idx}")
    logging.info(f"Length before truncation: {len(valid_tokens_before)}")
    logging.info(f"Length after truncation: {len(valid_tokens_after)}")
    logging.info(f"Max prompt length: {max_prompt_length}")
    logging.info(f"Tokens truncated: {len(valid_tokens_before) - len(valid_tokens_after)}")

    # Warn if vision tokens were lost
    if has_vision_loss:
        logging.warning("⚠️  WARNING: IMAGE/VISION TOKENS WERE TRUNCATED!")
        logging.warning(f"⚠️  Lost vision token IDs: {vision_tokens_lost}")
        logging.warning(f"⚠️  Vision tokens before: {vision_tokens_before}")
        logging.warning(f"⚠️  Vision tokens after: {vision_tokens_after}")
        logging.warning("⚠️  The model will NOT see the image information!")
    elif len(vision_tokens_before) > 0:
        logging.info(f"✓ Vision tokens preserved: {vision_tokens_before}")

    logging.info("-" * 80)

    # Log tokens before truncation
    logging.info("[BEFORE TRUNCATION]")
    logging.info(f"Tokens: {tokens_to_readable(valid_tokens_before)}")
    logging.info("-" * 80)

    # Log tokens after truncation
    logging.info("[AFTER TRUNCATION]")
    logging.info(f"Tokens: {tokens_to_readable(valid_tokens_after)}")
    logging.info("=" * 80)


def log_rollout_input(
    prompts: torch.Tensor,
    prompt_mask: torch.Tensor,
    processing_class,
    sample_idx: int = 0
) -> None:
    """
    Log the tokens fed to the model right before generation (rollout).

    Args:
        prompts: Prompt token IDs [batch_size, seq_len]
        prompt_mask: Attention mask [batch_size, seq_len]
        processing_class: Tokenizer or processor for decoding
        sample_idx: Index of the sample in the batch to log (default: 0)
    """
    # A processor wraps the tokenizer; fall back to the object itself.
    _tok = getattr(processing_class, "tokenizer", processing_class)

    # Resolve vision/image token IDs from the tokenizer (best effort).
    vision_token_ids = []
    for name in ["<|vision_start|>", "<|vision_end|>", "<|image_pad|>", "<|video_pad|>", "<|vision_pad|>"]:
        try:
            encoded = _tok.encode(name, add_special_tokens=False)
            if isinstance(encoded, list) and len(encoded) > 0:
                vision_token_ids.append(encoded[0])
        except Exception:
            pass  # tokenizer may not know this token name; ignore

    # Pull out one sample and keep only positions where the mask is set.
    sample_tokens = prompts[sample_idx]
    sample_mask = prompt_mask[sample_idx]
    valid_tokens = sample_tokens[sample_mask.bool()].tolist()

    # Does this prompt carry any vision tokens at all?
    vision_tokens_present = set(valid_tokens) & set(vision_token_ids)
    has_vision = len(vision_tokens_present) > 0

    def tokens_to_readable(token_ids):
        """Render token IDs as a readable string with special tokens visible."""
        # ANSI escape codes for colors
        GREEN = "\033[92m"
        RESET = "\033[0m"

        rendered = []
        for tid in token_ids:
            try:
                text = _tok.decode([tid], skip_special_tokens=False)
                stripped = text.strip()
                if tid in vision_token_ids:
                    # Highlight vision tokens
                    rendered.append(f"{GREEN}[IMG]{stripped}[/IMG]{RESET}")
                elif tid == _tok.pad_token_id:
                    rendered.append("<|pad|>")
                elif tid == _tok.eos_token_id:
                    rendered.append("<|eos|>")
                elif tid == _tok.bos_token_id:
                    rendered.append("<|bos|>")
                elif stripped in ["<|im_start|>", "<|im_end|>", "<|im_sep|>"]:
                    rendered.append(stripped)
                else:
                    rendered.append(f"[{tid}:{repr(text)}]")
            except Exception:
                rendered.append(f"[{tid}:?]")
        return " ".join(rendered)

    # Log summary information
    logging.info("=" * 80)
    logging.info(f"[ROLLOUT INPUT] Sample {sample_idx}")
    logging.info(f"Prompt length: {len(valid_tokens)} tokens")
    logging.info(f"Batch shape: {prompts.shape}")

    if has_vision:
        logging.info(f"✓ Contains vision tokens: {vision_tokens_present}")
    else:
        logging.info("ℹ️  No vision tokens detected (text-only prompt)")

    logging.info("-" * 80)

    # Log the tokens themselves plus the fully decoded text
    logging.info("[INPUT TOKENS]")
    logging.info(f"Tokens: {tokens_to_readable(valid_tokens)}")
    logging.info(f"Decoded text: {_tok.decode(valid_tokens, skip_special_tokens=False)}")
    logging.info("=" * 80)


def persist_grpo_logs(
    log_file: str,
    jsonl_file: str,
    step: int,
    mode: str,
    prompt_texts: list[str],
    completion_texts: list[str],
    rewards: list[float],
    rewards_by_func: dict[str, list[float]],
    token_counts: list[int],
    ground_truths: list[str] | None,
    solutions_extracted: list[str] | None,
    verifies: list[bool] | None,
    reward_func_names: list[str],
    stop_reasons: list[str] | None = None,
) -> None:
    """
    Append per-sample human-readable and JSONL logs for GRPO.
    """
    try:
        # Flatten possibly nested lists (from distributed gather)
        def _flatten(lst):
            if isinstance(lst, list) and len(lst) > 0 and isinstance(lst[0], list):
                return [item for sub in lst for item in sub]
            return lst

        prompt_texts = _flatten(prompt_texts)
        completion_texts = _flatten(completion_texts)
        rewards = _flatten(rewards)
        token_counts = _flatten(token_counts)
        rewards_by_func = {k: _flatten(v) for k, v in rewards_by_func.items()}
        stop_reasons = _flatten(stop_reasons) if stop_reasons is not None else None
        ground_truths = _flatten(ground_truths) if ground_truths is not None else None
        solutions_extracted = _flatten(solutions_extracted) if solutions_extracted is not None else None
        verifies = _flatten(verifies) if verifies is not None else None

        # Guard against length mismatches
        n = min(
            len(prompt_texts),
            len(completion_texts),
            len(rewards),
            len(token_counts),
            *[len(rewards_by_func[name]) for name in reward_func_names],
            *( [len(ground_truths)] if ground_truths is not None else [] ),
            *( [len(solutions_extracted)] if solutions_extracted is not None else [] ),
            *( [len(verifies)] if verifies is not None else [] ),
            *( [len(stop_reasons)] if stop_reasons is not None else [] ),
        )
        if n == 0:
            return

        with open(log_file, "a", encoding="utf-8") as f_txt:
            f_txt.write(f"\n{'='*80}\n")
            f_txt.write(f"Step: {step} | Mode: {mode}\n")
            f_txt.write(f"{'='*80}\n")
            for idx in range(n):
                p_txt = prompt_texts[idx]
                c_txt = completion_texts[idx]
                r_total = rewards[idx]
                f_txt.write(f"\n[Sample {idx}]\n")
                f_txt.write(f"Prompt: {p_txt}\n")
                comp_str = ", ".join([f"{name}: {float(rewards_by_func[name][idx]):.6f}" for name in reward_func_names])
                f_txt.write(f"Reward: {float(r_total):.6f} | Components: {comp_str}\n")
                if ground_truths is not None:
                    f_txt.write(f"Ground truth: {ground_truths[idx]}\n")
                if solutions_extracted is not None:
                    f_txt.write(f"Solution: {solutions_extracted[idx]}\n")
                if verifies is not None:
                    f_txt.write(f"Verify: {bool(verifies[idx])}\n")
                s_reason = (
                    stop_reasons[idx]
                    if stop_reasons is not None and idx < len(stop_reasons)
                    else "unknown"
                )
                f_txt.write(f"Stop reason: {s_reason}\n")
                # Always place completion last in the per-sample block
                f_txt.write(f"Completion: {c_txt}\n")
                f_txt.write(f"{'-'*80}\n")

        with open(jsonl_file, "a", encoding="utf-8") as f_jsonl:
            for idx in range(n):
                s_reason = (
                    stop_reasons[idx]
                    if stop_reasons is not None and idx < len(stop_reasons)
                    else "unknown"
                )
                record = {
                    "reward": float(rewards[idx]),
                    "token_count": int(token_counts[idx]),
                    # "step": int(step),
                    # "mode": mode,
                    # "sample_index": int(idx),
                    "stop_reason": s_reason,
                }
                if ground_truths is not None:
                    record["ground_truth"] = ground_truths[idx]
                if solutions_extracted is not None:
                    record["solution"] = solutions_extracted[idx]
                if verifies is not None:
                    record["verify"] = bool(verifies[idx])
                # Ensure completion is always the last field
                record["completion"] = completion_texts[idx]
                f_jsonl.write(json.dumps(record, ensure_ascii=False) + "\n")
    except Exception as e:
        logging.warning(f"Failed to persist GRPO logs: {e}")