File size: 4,864 Bytes
9030cc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
RAE Tokenizer Utilities
═══════════════════════════════════════════════════════════════
Phase-aware tokenization for RAE training data.

Handles the special structure of RAE responses where XML-style
phase tags delineate cognitive phases. Ensures proper tokenization
of phase boundaries and provides utilities for phase-level analysis.
═══════════════════════════════════════════════════════════════
"""

from typing import Optional
import re


# Maps each RAE phase name to its (opening tag, closing tag) pair.
# Insertion order matches the phase order used elsewhere in this module:
# saturation -> abstraction -> descent -> integration.
PHASE_TAGS = {
    "saturation": ("<SATURATION>", "</SATURATION>"),
    "abstraction": ("<ABSTRACTION>", "</ABSTRACTION>"),
    "descent": ("<DESCENT>", "</DESCENT>"),
    "integration": ("<INTEGRATION>", "</INTEGRATION>"),
}

# Flat list of every opening and closing tag, in PHASE_TAGS order
# (open, close, open, close, ...). Built with a comprehension so that no
# loop variables are left behind in the module namespace.
ALL_TAGS = [tag for tag_pair in PHASE_TAGS.values() for tag in tag_pair]


def add_rae_tokens(tokenizer):
    """
    Add RAE phase tags as special tokens to the tokenizer.

    This ensures phase boundaries are tokenized as single tokens
    rather than being split across subwords, which makes phase
    detection much more reliable during loss computation.

    Args:
        tokenizer: Object exposing ``add_special_tokens(dict) -> int``
            (presumably a Hugging Face tokenizer -- confirm against callers).
            It is modified in place.

    Returns:
        A ``(tokenizer, num_added)`` tuple, where ``num_added`` is the value
        returned by ``tokenizer.add_special_tokens``.

    NOTE(review): when ``num_added > 0`` the caller presumably has to resize
    the model's token embeddings to match the new vocabulary -- confirm.
    """
    # Registering only ALL_TAGS under "additional_special_tokens" can replace
    # whatever additional special tokens the tokenizer already carries (this
    # is the default behavior in Hugging Face transformers). Merge instead,
    # so that pre-existing tokens such as chat-template markers keep their
    # special status.
    existing = list(getattr(tokenizer, "additional_special_tokens", None) or [])
    merged = existing + [tag for tag in ALL_TAGS if tag not in existing]

    special_tokens = {"additional_special_tokens": merged}
    num_added = tokenizer.add_special_tokens(special_tokens)

    if num_added > 0:
        print(f"  Added {num_added} RAE phase tokens to tokenizer")

    return tokenizer, num_added


def extract_phases(text: str) -> dict[str, str]:
    """Return the stripped body of each RAE phase found in ``text``.

    Every phase name in PHASE_TAGS is present as a key. A phase whose
    open/close tag pair does not occur in ``text`` maps to the empty string.
    Only the first open/close pair of each phase is used, and the body may
    span multiple lines.
    """

    def _first_body(opening: str, closing: str) -> str:
        # Non-greedy capture so the body stops at the nearest closing tag.
        found = re.search(
            f"{re.escape(opening)}(.*?){re.escape(closing)}",
            text,
            flags=re.DOTALL,
        )
        if found is None:
            return ""
        return found.group(1).strip()

    return {
        name: _first_body(*tag_pair)
        for name, tag_pair in PHASE_TAGS.items()
    }


def validate_rae_response(text: str) -> dict:
    """
    Validate that a response contains proper RAE structure.

    Args:
        text: Response text expected to contain the four phase tag pairs.

    Returns a report with:
    - is_valid: bool, True only when all four phases have content and
      no warnings were raised
    - phases_found: list of phase names found (with non-empty content)
    - phases_missing: list of phase names missing (absent or empty)
    - phase_lengths: whitespace-delimited word count per phase
    - compression_ratio: abstraction_len / saturation_len, or None when
      the saturation phase has no words
    - warnings: list of potential issues
    """
    phases = extract_phases(text)
    found = [name for name, content in phases.items() if content]
    missing = [name for name, content in phases.items() if not content]
    phase_lengths = {name: len(content.split()) for name, content in phases.items()}

    warnings = []

    # Check phase ordering.
    # ``found`` follows the iteration order of the extracted dict, not the
    # order in which the phases occur in ``text``, so it cannot be used to
    # detect misordering. Compare the expected order against the position of
    # each found phase's opening tag in the text instead.
    expected_order = ["saturation", "abstraction", "descent", "integration"]
    canonical_order = [p for p in expected_order if p in found]
    textual_order = sorted(
        canonical_order,
        key=lambda p: text.find(PHASE_TAGS[p][0]),
    )
    if textual_order != canonical_order:
        warnings.append("Phases appear out of order")

    # Check compression: the abstraction should not be longer than the
    # saturation it condenses.
    compression_ratio = None
    sat_len = phase_lengths.get("saturation", 0)
    abs_len = phase_lengths.get("abstraction", 0)
    if sat_len > 0:
        compression_ratio = abs_len / sat_len
        if compression_ratio > 1.0:
            warnings.append(f"Abstraction is LONGER than Saturation (ratio={compression_ratio:.2f})")

    # Check for degenerate phases (only phases that actually have content).
    for phase_name, content in phases.items():
        if not content:
            continue
        word_count = phase_lengths[phase_name]
        if word_count < 10:
            warnings.append(f"{phase_name} is very short ({word_count} words)")
        if word_count > 1000:
            warnings.append(f"{phase_name} is very long ({word_count} words)")

    return {
        "is_valid": len(found) == 4 and len(warnings) == 0,
        "phases_found": found,
        "phases_missing": missing,
        "phase_lengths": phase_lengths,
        "compression_ratio": compression_ratio,
        "warnings": warnings,
    }


def format_rae_chat(
    system_prompt: str,
    user_message: str,
    phases: dict[str, str],
    tokenizer=None,
) -> "str | list[dict[str, str]]":
    """
    Format RAE phases into a chat-template-ready message.

    The assistant turn contains the four phases in fixed order
    (saturation, abstraction, descent, integration), each wrapped in its
    tag pair and separated by a blank line. Phases missing from ``phases``
    (or set to None) are emitted with an empty body.

    Args:
        system_prompt: Content of the system message.
        user_message: Content of the user message.
        phases: Mapping of phase name to phase text.
        tokenizer: Optional object providing
            ``apply_chat_template(messages, tokenize=..., add_generation_prompt=...)``.

    Returns:
        If tokenizer is provided, the result of applying its chat template
        with ``tokenize=False``. Otherwise the raw message list
        (system, user, assistant dicts with "role" and "content" keys).
    """
    sections = []
    for phase_name in ["saturation", "abstraction", "descent", "integration"]:
        open_tag, close_tag = PHASE_TAGS[phase_name]
        # ``or ""`` keeps a None value from being rendered as the text "None".
        content = phases.get(phase_name) or ""
        sections.append(f"{open_tag}\n{content}\n{close_tag}")
    # Joining with a blank line yields tag-delimited sections with no
    # leading or trailing whitespace.
    assistant_content = "\n\n".join(sections)

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_content},
    ]

    # Identity check rather than truthiness: an object that defines __len__
    # could evaluate as False even though a tokenizer was supplied.
    if tokenizer is not None:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

    return messages