"""
RAE Tokenizer Utilities
═══════════════════════════════════════════════════════════════
Phase-aware tokenization for RAE training data.

Handles the special structure of RAE responses where XML-style phase
tags delineate cognitive phases. Ensures proper tokenization of phase
boundaries and provides utilities for phase-level analysis.
═══════════════════════════════════════════════════════════════
"""

from typing import Optional, Union
import re

# Opening/closing tag pair for each phase, listed in the canonical order in
# which the phases are expected to appear in a response.
# NOTE(review): confirm these exact tag strings match the training-data format.
PHASE_TAGS = {
    "saturation": ("<saturation>", "</saturation>"),
    "abstraction": ("<abstraction>", "</abstraction>"),
    "descent": ("<descent>", "</descent>"),
    "integration": ("<integration>", "</integration>"),
}

# Phase names in canonical order (dicts preserve insertion order).
PHASE_ORDER = tuple(PHASE_TAGS)

# Flat list of every open/close tag, used to register special tokens.
ALL_TAGS = [tag for tag_pair in PHASE_TAGS.values() for tag in tag_pair]

# One compiled pattern per phase; group 1 captures the text between the
# phase's tags. DOTALL lets a phase span multiple lines, and the lazy
# quantifier stops at the first closing tag after the opening tag.
_PHASE_PATTERNS = {
    name: re.compile(re.escape(open_tag) + r"(.*?)" + re.escape(close_tag), re.DOTALL)
    for name, (open_tag, close_tag) in PHASE_TAGS.items()
}


def add_rae_tokens(tokenizer):
    """
    Add RAE phase tags as special tokens to the tokenizer.

    This ensures phase boundaries are tokenized as single tokens rather
    than being split across subwords, which makes phase detection much
    more reliable during loss computation.

    Args:
        tokenizer: Object exposing ``add_special_tokens(dict) -> int``
            (presumably a Hugging Face tokenizer — confirm with callers).

    Returns:
        ``(tokenizer, num_added)``: the tokenizer object that was passed in,
        and the count returned by ``add_special_tokens``.

    NOTE(review): when ``num_added > 0`` the model's embedding matrix
    presumably needs resizing (e.g. ``model.resize_token_embeddings``);
    that is not done here.
    """
    # Pass a copy so the tokenizer cannot alias or mutate the module-level list.
    special_tokens = {"additional_special_tokens": list(ALL_TAGS)}
    num_added = tokenizer.add_special_tokens(special_tokens)
    if num_added > 0:
        print(f" Added {num_added} RAE phase tokens to tokenizer")
    return tokenizer, num_added


def _match_phases(text: str) -> dict[str, Optional[re.Match]]:
    """Return the first regex match (or None) for each phase, keyed by phase name."""
    return {name: pattern.search(text) for name, pattern in _PHASE_PATTERNS.items()}


def extract_phases(text: str) -> dict[str, str]:
    """
    Extract phase content from RAE-structured text.

    Returns:
        A dict mapping every phase name (in canonical order) to the
        whitespace-stripped text between the first occurrence of its open
        and close tags, or ``""`` when that tag pair is absent.
    """
    return {
        name: match.group(1).strip() if match else ""
        for name, match in _match_phases(text).items()
    }


def validate_rae_response(
    text: str,
    *,
    min_words: int = 10,
    max_words: int = 1000,
) -> dict:
    """
    Validate that a response contains proper RAE structure.

    Args:
        text: Raw response text containing phase tags.
        min_words: Non-empty phases with fewer words than this are flagged.
        max_words: Phases with more words than this are flagged.

    Returns a report with:
        - is_valid: bool (every phase present and no warnings)
        - phases_found: list of phase names with non-empty content
        - phases_missing: list of phase names absent or empty
        - phase_lengths: word count per phase
        - compression_ratio: abstraction_len / saturation_len, or None when
          saturation is empty
        - warnings: list of potential issues
    """
    matches = _match_phases(text)
    phases = {name: m.group(1).strip() if m else "" for name, m in matches.items()}

    # Both lists are in canonical order because `phases` follows PHASE_TAGS.
    found = [name for name, content in phases.items() if content]
    missing = [name for name, content in phases.items() if not content]
    warnings = []

    # Check phase ordering: sort the found phases by where their opening tag
    # actually occurs in the text and compare against canonical order.
    # (Every found phase has a match, since its content is non-empty.)
    textual_order = sorted(found, key=lambda name: matches[name].start())
    if textual_order != found:
        warnings.append("Phases appear out of order")

    lengths = {name: len(content.split()) for name, content in phases.items()}

    # Check compression: Abstraction is expected to be shorter than Saturation.
    compression_ratio = None
    sat_len = lengths["saturation"]
    if sat_len > 0:
        compression_ratio = lengths["abstraction"] / sat_len
        if compression_ratio > 1.0:
            warnings.append(
                f"Abstraction is LONGER than Saturation (ratio={compression_ratio:.2f})"
            )

    # Check for degenerate phases. Empty phases are skipped here because
    # they are already reported through `phases_missing`.
    for name, word_count in lengths.items():
        if not phases[name]:
            continue
        if word_count < min_words:
            warnings.append(f"{name} is very short ({word_count} words)")
        if word_count > max_words:
            warnings.append(f"{name} is very long ({word_count} words)")

    return {
        "is_valid": not missing and not warnings,
        "phases_found": found,
        "phases_missing": missing,
        "phase_lengths": lengths,
        "compression_ratio": compression_ratio,
        "warnings": warnings,
    }


def format_rae_chat(
    system_prompt: str,
    user_message: str,
    phases: dict[str, str],
    tokenizer=None,
) -> Union[str, list[dict[str, str]]]:
    """
    Format RAE phases into a chat-template-ready message.

    The assistant turn contains every phase in canonical order, each wrapped
    in its tags and separated by a blank line. A phase missing from
    ``phases`` is emitted with empty content.

    Args:
        system_prompt: Content of the system message.
        user_message: Content of the user message.
        phases: Mapping of phase name to phase text.
        tokenizer: Optional object exposing ``apply_chat_template``.

    Returns:
        If ``tokenizer`` is given, the result of
        ``tokenizer.apply_chat_template(messages, tokenize=False,
        add_generation_prompt=False)`` (presumably the rendered string).
        Otherwise the raw ``[system, user, assistant]`` message list.
    """
    sections = []
    for name in PHASE_ORDER:
        open_tag, close_tag = PHASE_TAGS[name]
        sections.append(f"{open_tag}\n{phases.get(name, '')}\n{close_tag}")
    assistant_content = "\n\n".join(sections).strip()

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_content},
    ]

    if tokenizer is not None:
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
    return messages