TrueV1sion123 committed on
Commit
9030cc5
·
verified ·
1 Parent(s): 3143539

Upload src/rae_tokenizer_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/rae_tokenizer_utils.py +133 -0
src/rae_tokenizer_utils.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RAE Tokenizer Utilities
3
+ ═══════════════════════════════════════════════════════════════
4
+ Phase-aware tokenization for RAE training data.
5
+
6
+ Handles the special structure of RAE responses where XML-style
7
+ phase tags delineate cognitive phases. Ensures proper tokenization
8
+ of phase boundaries and provides utilities for phase-level analysis.
9
+ ═══════════════════════════════════════════════════════════════
10
+ """
11
+
12
+ from typing import Optional
13
+ import re
14
+
15
+
16
# Opening/closing tag pair for each RAE phase, keyed by phase name.
# The insertion order here matches the phase order used elsewhere in this module.
PHASE_TAGS = {
    "saturation": ("<SATURATION>", "</SATURATION>"),
    "abstraction": ("<ABSTRACTION>", "</ABSTRACTION>"),
    "descent": ("<DESCENT>", "</DESCENT>"),
    "integration": ("<INTEGRATION>", "</INTEGRATION>"),
}

# Flat list of every tag (open, then close) in phase order.
# Built with a comprehension so that no loop variables are left behind in the
# module namespace (a module-level ``for`` loop would export them).
ALL_TAGS = [tag for tag_pair in PHASE_TAGS.values() for tag in tag_pair]
26
+
27
+
28
def add_rae_tokens(tokenizer):
    """
    Add RAE phase tags as special tokens to the tokenizer.

    This ensures phase boundaries are tokenized as single tokens
    rather than being split across subwords, which makes phase
    detection much more reliable during loss computation.

    Entries already listed in ``tokenizer.additional_special_tokens`` are
    passed back together with the RAE tags, so registering the tags does
    not replace special tokens the tokenizer already declares.

    Args:
        tokenizer: Object exposing ``add_special_tokens(dict) -> int``
            (presumably a Hugging Face tokenizer; confirm against callers).
            It is modified in place.

    Returns:
        A ``(tokenizer, num_added)`` tuple, where ``num_added`` is the value
        returned by ``tokenizer.add_special_tokens``.
    """
    # Tolerate a missing or None attribute by treating it as "no existing tokens".
    existing = list(getattr(tokenizer, "additional_special_tokens", None) or [])
    # dict.fromkeys de-duplicates while keeping order (existing tokens first).
    merged = list(dict.fromkeys([*existing, *ALL_TAGS]))
    num_added = tokenizer.add_special_tokens({"additional_special_tokens": merged})

    if num_added > 0:
        print(f" Added {num_added} RAE phase tokens to tokenizer")
        # NOTE(review): callers presumably need to resize the model's
        # embedding matrix when num_added > 0; verify at the call site.

    return tokenizer, num_added
43
+
44
+
45
def extract_phases(text: str) -> dict[str, str]:
    """Return the body of each RAE phase found in ``text``.

    Every phase name in ``PHASE_TAGS`` appears as a key. Its value is the
    whitespace-stripped text between the leftmost matching open/close tag
    pair, or an empty string when no such pair is present.
    """

    def _body(opener: str, closer: str) -> str:
        # Non-greedy capture stops at the nearest closing tag; DOTALL lets
        # the phase body span multiple lines.
        hit = re.search(
            f"{re.escape(opener)}(.*?){re.escape(closer)}",
            text,
            flags=re.DOTALL,
        )
        if hit is None:
            return ""
        return hit.group(1).strip()

    return {name: _body(*tag_pair) for name, tag_pair in PHASE_TAGS.items()}
53
+
54
+
55
def validate_rae_response(text: str) -> dict:
    """
    Validate that a response contains proper RAE structure.

    Args:
        text: Response text expected to contain the tagged RAE phases.

    Returns a report with:
    - is_valid: bool, True only when all four phases have content and
      no warnings were raised
    - phases_found: list of phase names found (non-empty content)
    - phases_missing: list of phase names missing (or empty)
    - phase_lengths: word count of each phase's content
    - compression_ratio: abstraction_len / saturation_len in words,
      or None when saturation has no words
    - warnings: list of potential issues
    """
    phases = extract_phases(text)
    found = [name for name, content in phases.items() if content]
    missing = [name for name, content in phases.items() if not content]
    phase_lengths = {name: len(content.split()) for name, content in phases.items()}

    warnings = []

    # Check phase ordering by where each found phase's opening tag first
    # appears in the text. The order of keys in `phases` is always canonical,
    # so it cannot be used to detect misordered phases.
    expected_order = ["saturation", "abstraction", "descent", "integration"]
    tag_positions = [
        text.find(PHASE_TAGS[name][0]) for name in expected_order if name in found
    ]
    if tag_positions != sorted(tag_positions):
        warnings.append("Phases appear out of order")

    # Check compression: abstraction is expected to be no longer than saturation.
    compression_ratio = None
    sat_len = phase_lengths.get("saturation", 0)
    abs_len = phase_lengths.get("abstraction", 0)
    if sat_len > 0:
        compression_ratio = abs_len / sat_len
        if compression_ratio > 1.0:
            warnings.append(f"Abstraction is LONGER than Saturation (ratio={compression_ratio:.2f})")

    # Check for degenerate phases (only phases that actually have content).
    for phase_name, content in phases.items():
        if not content:
            continue
        word_count = phase_lengths[phase_name]
        if word_count < 10:
            warnings.append(f"{phase_name} is very short ({word_count} words)")
        if word_count > 1000:
            warnings.append(f"{phase_name} is very long ({word_count} words)")

    return {
        "is_valid": len(found) == 4 and not warnings,
        "phases_found": found,
        "phases_missing": missing,
        "phase_lengths": phase_lengths,
        "compression_ratio": compression_ratio,
        "warnings": warnings,
    }
104
+
105
+
106
def format_rae_chat(
    system_prompt: str,
    user_message: str,
    phases: dict[str, str],
    tokenizer=None,
) -> "str | list[dict[str, str]]":
    """
    Format RAE phases into a chat-template-ready message.

    The assistant turn is built by wrapping each phase's content in its
    open/close tags, in the order saturation, abstraction, descent,
    integration, with a blank line between phases. A phase missing from
    ``phases`` is emitted with empty content.

    Args:
        system_prompt: Content of the system message.
        user_message: Content of the user message.
        phases: Mapping of phase name to phase text.
        tokenizer: Optional object exposing ``apply_chat_template``.

    Returns:
        If tokenizer is provided, the result of applying its chat template
        (called with ``tokenize=False``). Otherwise the raw message list of
        ``{"role": ..., "content": ...}`` dicts.
    """
    phase_order = ("saturation", "abstraction", "descent", "integration")
    sections = []
    for phase_name in phase_order:
        open_tag, close_tag = PHASE_TAGS[phase_name]
        content = phases.get(phase_name, "")
        sections.append(f"{open_tag}\n{content}\n{close_tag}")
    assistant_content = "\n\n".join(sections).strip()

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_content},
    ]

    # Compare against None explicitly: an object that defines __len__ can be
    # falsy even though the caller did supply it.
    if tokenizer is not None:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

    return messages