Trouter-Library committed
Commit 04dc2a9 · verified · 1 Parent(s): 2c59201

Create inference/test_long_context.py

Files changed (1)
  1. inference/test_long_context.py +316 -0
inference/test_long_context.py ADDED
@@ -0,0 +1,316 @@
+ """
+ Test script to verify 250K context length support.
+ Tests RoPE scaling and long-context handling.
+ """
+
+ import logging
+ import math
+ from typing import Optional
+
+ from transformers import AutoTokenizer, AutoConfig
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
16
+ class LongContextTester:
17
+ """Test long context capabilities of Helion-OSC"""
18
+
19
+ def __init__(self, model_path: str = "./inference"):
20
+ """
21
+ Initialize tester
22
+
23
+ Args:
24
+ model_path: Path to model inference directory
25
+ """
26
+ self.model_path = model_path
27
+ logger.info("Loading model configuration...")
28
+
29
+ # Load config
30
+ self.config = AutoConfig.from_pretrained(model_path)
31
+
32
+ # Verify context length
33
+ max_pos = self.config.max_position_embeddings
34
+ logger.info(f"Model max position embeddings: {max_pos:,}")
35
+
36
+ if max_pos < 250000:
37
+ logger.warning(f"Context length ({max_pos:,}) is less than 250K!")
38
+ else:
39
+ logger.info(f"✓ Context length supports 250K+ tokens ({max_pos:,})")
40
+
41
+ # Check RoPE scaling
42
+ rope_scaling = getattr(self.config, 'rope_scaling', None)
43
+ rope_theta = getattr(self.config, 'rope_theta', None)
44
+
45
+ if rope_scaling:
46
+ logger.info(f"RoPE Scaling: {rope_scaling}")
47
+ if rope_theta:
48
+ logger.info(f"RoPE Theta: {rope_theta:,}")
49
+
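+     # For reference (illustrative values, not the actual Helion-OSC config):
+     # a 250K-capable Llama-style config typically pairs a large rope_theta with
+     # a scaling dict, e.g.
+     #   "max_position_embeddings": 262144,
+     #   "rope_theta": 5000000.0,
+     #   "rope_scaling": {"type": "dynamic", "factor": 32.0}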
+     def test_tokenization_capacity(self, tokenizer_path: str = "DeepXR/Helion-OSC"):
+         """Test that the tokenizer supports long sequences."""
+         logger.info("\n" + "="*80)
+         logger.info("TEST 1: Tokenizer Capacity")
+         logger.info("="*80)
+
+         try:
+             tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+
+             max_length = tokenizer.model_max_length
+             logger.info(f"Tokenizer max length: {max_length:,}")
+
+             # NOTE: HF tokenizers report model_max_length = int(1e30) when no
+             # limit is set, so a huge value here can be a sentinel rather than
+             # a real limit.
+             if max_length >= 250000:
+                 logger.info("✓ Tokenizer supports 250K+ tokens")
+             else:
+                 logger.warning(f"✗ Tokenizer max length is only {max_length:,}")
+
+             # Test with a long sequence ("Hello world! " is roughly three BPE
+             # tokens, so repeat it test_tokens // 3 times for ~test_tokens tokens)
+             test_tokens = 10000
+             test_text = "Hello world! " * (test_tokens // 3)
+
+             logger.info(f"Testing tokenization of ~{test_tokens:,} tokens...")
+             encoded = tokenizer(test_text, return_tensors="pt", truncation=False)
+             actual_tokens = encoded['input_ids'].shape[1]
+
+             logger.info(f"Successfully tokenized {actual_tokens:,} tokens")
+             logger.info("✓ Tokenization test passed")
+
+             return True
+
+         except Exception as e:
+             logger.error(f"✗ Tokenization test failed: {e}")
+             return False
+
+     def test_position_embeddings(self):
+         """Test position embedding capacity."""
+         logger.info("\n" + "="*80)
+         logger.info("TEST 2: Position Embeddings")
+         logger.info("="*80)
+
+         max_pos = self.config.max_position_embeddings
+         hidden_size = self.config.hidden_size
+
+         logger.info(f"Max positions: {max_pos:,}")
+         logger.info(f"Hidden size: {hidden_size:,}")
+
+         if getattr(self.config, 'rope_theta', None):
+             logger.info("Using RoPE (Rotary Position Embeddings)")
+             logger.info("✓ RoPE scales efficiently to long contexts")
+
+             # RoPE doesn't store position embeddings; it computes them on the fly
+             logger.info(f"RoPE Theta: {self.config.rope_theta:,}")
+
+             # Guard with getattr: configs often carry rope_scaling = None, so a
+             # bare hasattr() check would pass and .get() would crash on None
+             scaling = getattr(self.config, 'rope_scaling', None)
+             if scaling:
+                 logger.info("RoPE Scaling Configuration:")
+                 logger.info(f"  Type: {scaling.get('type', 'N/A')}")
+                 logger.info(f"  Factor: {scaling.get('factor', 'N/A')}")
+
+                 if scaling.get('factor', 0) >= 32:
+                     logger.info("✓ RoPE scaling factor supports 250K+ context (32x from 8K base)")
+                 else:
+                     logger.warning("✗ RoPE scaling factor may be insufficient")
+
+             return True
+         else:
+             # Learned position embeddings: the table itself must cover max_pos
+             pos_emb_size = max_pos * hidden_size * 2  # bfloat16 (2 bytes/param)
+             pos_emb_gb = pos_emb_size / (1024**3)
+             logger.info(f"Position embedding size: {pos_emb_gb:.2f} GB")
+
+             if max_pos >= 250000:
+                 logger.info("✓ Sufficient position embeddings for 250K context")
+                 return True
+             else:
+                 logger.warning("✗ Insufficient position embeddings")
+                 return False
+
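+     # Worked example for the learned-embedding branch (illustrative hidden size,
+     # not read from the config): 250,000 positions x 8,192 dims x 2 bytes
+     # = 4.096e9 bytes ≈ 3.8 GB for the embedding table alone.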
+     def test_attention_computation(self, sequence_lengths: Optional[list] = None):
+         """Test attention computation at various sequence lengths."""
+         # Avoid a mutable default argument
+         if sequence_lengths is None:
+             sequence_lengths = [1024, 8192, 32768, 131072]
+
+         logger.info("\n" + "="*80)
+         logger.info("TEST 3: Attention Computation Scaling")
+         logger.info("="*80)
+
+         hidden_size = self.config.hidden_size
+         num_heads = self.config.num_attention_heads
+         head_dim = hidden_size // num_heads
+
+         logger.info(f"Attention heads: {num_heads}")
+         logger.info(f"Head dimension: {head_dim}")
+
+         for seq_len in sequence_lengths:
+             # Full self-attention matrix: (batch, heads, seq_len, seq_len)
+             attn_size = 1 * num_heads * seq_len * seq_len * 2  # bfloat16
+             attn_gb = attn_size / (1024**3)
+
+             logger.info(f"\nSequence length: {seq_len:,} tokens")
+             logger.info(f"  Attention matrix: {attn_gb:.2f} GB")
+
+             if seq_len <= 32768:
+                 logger.info("  ✓ Manageable size")
+             elif seq_len <= 131072:
+                 logger.info("  ⚠ Large - may need Flash Attention")
+             else:
+                 logger.info("  ⚠ Very large - requires optimizations")
+
+         # Check for Flash Attention support
+         use_flash = getattr(self.config, 'use_flash_attention_2', False)
+         if use_flash:
+             logger.info("\n✓ Flash Attention 2 enabled - efficient for long contexts")
+         else:
+             logger.warning("\n⚠ Flash Attention not configured - may be slow for long contexts")
+
+         return True
+
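+     # Worked example (illustrative head count, not read from the config):
+     # with 64 heads at seq_len = 131,072, the materialized attention matrix is
+     # 64 * 131072^2 * 2 bytes = 2^41 bytes = 2 TiB per layer - which is why
+     # Flash Attention, which never materializes the L x L matrix, is effectively
+     # mandatory at 250K context.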
+     def test_memory_requirements(self):
+         """Estimate memory requirements for a 250K context."""
+         logger.info("\n" + "="*80)
+         logger.info("TEST 4: Memory Requirements")
+         logger.info("="*80)
+
+         context_length = 250000
+         batch_size = 1
+         hidden_size = self.config.hidden_size
+         num_layers = self.config.num_hidden_layers
+
+         logger.info("Configuration:")
+         logger.info(f"  Context: {context_length:,} tokens")
+         logger.info(f"  Batch size: {batch_size}")
+         logger.info(f"  Hidden size: {hidden_size:,}")
+         logger.info(f"  Layers: {num_layers}")
+
+         # Rough activation estimate: hidden states plus attention outputs.
+         # Note this ignores the KV cache, which usually dominates long-context
+         # generation memory.
+         hidden_states_size = batch_size * context_length * hidden_size * 2  # bfloat16
+         hidden_states_gb = hidden_states_size / (1024**3)
+
+         # Per layer (rough 2x to account for attention intermediates)
+         layer_memory_gb = hidden_states_gb * 2
+         total_activation_gb = layer_memory_gb * num_layers
+
+         logger.info("\nMemory estimates:")
+         logger.info(f"  Hidden states per layer: {hidden_states_gb:.2f} GB")
+         logger.info(f"  Total activation memory: {total_activation_gb:.2f} GB")
+         logger.info("  Model weights: ~349 GB")
+         logger.info(f"  Total (weights + activations): ~{349 + total_activation_gb:.2f} GB")
+
+         logger.info("\nRecommendations:")
+         if total_activation_gb < 50:
+             logger.info("  ✓ Should fit on 8x A100 (80GB) GPUs")
+         elif total_activation_gb < 100:
+             logger.info("  ⚠ May need gradient checkpointing")
+         else:
+             logger.info("  ⚠ Will need aggressive optimizations (checkpointing, CPU offload)")
+
+         return True
+
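+     # For reference (hedged estimate, not computed by this script): the KV
+     # cache for generation adds roughly
+     #   2 (K and V) * num_layers * context_length * hidden_size * 2 bytes,
+     # e.g. ~610 GB at 80 layers / 8,192 hidden / 250K tokens with full MHA,
+     # reduced by the group factor when num_key_value_heads < num_attention_heads (GQA).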
+     def test_rope_frequencies(self):
+         """Test RoPE frequency calculations for long contexts."""
+         logger.info("\n" + "="*80)
+         logger.info("TEST 5: RoPE Frequency Analysis")
+         logger.info("="*80)
+
+         rope_theta = getattr(self.config, 'rope_theta', 10000)
+         hidden_size = self.config.hidden_size
+         num_heads = self.config.num_attention_heads
+         head_dim = hidden_size // num_heads
+
+         logger.info(f"RoPE theta: {rope_theta:,}")
+         logger.info(f"Head dimension: {head_dim}")
+
+         # Frequencies: theta^(-2i/d) for i in [0, d/2), so the highest frequency
+         # is theta^0 = 1 and the lowest is at i = d/2 - 1
+         min_freq = rope_theta ** (-(head_dim - 2) / head_dim)
+         max_freq = 1.0  # theta^0
+
+         logger.info(f"Frequency range: [{min_freq:.6f}, {max_freq:.6f}]")
+
+         # Wavelengths (2*pi / freq) sampled across the spectrum, always
+         # including the lowest frequency (longest wavelength)
+         indices = list(range(0, head_dim // 2, head_dim // 8)) + [head_dim // 2 - 1]
+         wavelengths = [2 * math.pi / (rope_theta ** (-2 * i / head_dim))
+                        for i in indices]
+
+         logger.info("\nWavelengths (in tokens):")
+         for i, wl in zip(indices, wavelengths):
+             logger.info(f"  Index {i}: {wl:,.0f} tokens")
+
+         max_wavelength = max(wavelengths)
+         if max_wavelength >= 250000:
+             logger.info(f"\n✓ Maximum wavelength ({max_wavelength:,.0f}) supports 250K context")
+         else:
+             logger.warning(f"\n⚠ Maximum wavelength ({max_wavelength:,.0f}) may be insufficient")
+
+         return True
+
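+     # Sanity check (illustrative, assuming head_dim = 128): the longest RoPE
+     # wavelength is 2*pi * theta^((d-2)/d). With the classic theta = 10,000 that
+     # is ~5.4e4 tokens (too short for 250K); theta = 5,000,000 gives ~2.5e7
+     # tokens, comfortably beyond 250K - hence the large theta / scaling factor.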
+     def run_all_tests(self):
+         """Run all context length tests."""
+         logger.info("\n" + "="*80)
+         logger.info("HELION-OSC 250K CONTEXT LENGTH TEST SUITE")
+         logger.info("="*80)
+
+         results = {
+             "tokenization": self.test_tokenization_capacity(),
+             "position_embeddings": self.test_position_embeddings(),
+             "attention_scaling": self.test_attention_computation(),
+             "memory_requirements": self.test_memory_requirements(),
+             "rope_frequencies": self.test_rope_frequencies()
+         }
+
+         # Summary
+         logger.info("\n" + "="*80)
+         logger.info("TEST SUMMARY")
+         logger.info("="*80)
+
+         for test_name, passed in results.items():
+             status = "✓ PASS" if passed else "✗ FAIL"
+             logger.info(f"{test_name}: {status}")
+
+         all_passed = all(results.values())
+
+         if all_passed:
+             logger.info("\n✓ All tests passed - Model supports 250K context length")
+         else:
+             logger.warning("\n⚠ Some tests failed - Check configuration")
+
+         return all_passed
+
+
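+ # Example usage (invocations inferred from the argparse flags below):
+ #   python inference/test_long_context.py
+ #   python inference/test_long_context.py --model-path ./inference --test rope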
+ def main():
+     """Entry point for the test script."""
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Test Helion-OSC 250K context support")
+     parser.add_argument(
+         "--model-path",
+         type=str,
+         default="./inference",
+         help="Path to model inference directory"
+     )
+     parser.add_argument(
+         "--test",
+         choices=["all", "tokenization", "position", "attention", "memory", "rope"],
+         default="all",
+         help="Which test to run"
+     )
+
+     args = parser.parse_args()
+
+     tester = LongContextTester(args.model_path)
+
+     if args.test == "all":
+         tester.run_all_tests()
+     elif args.test == "tokenization":
+         tester.test_tokenization_capacity()
+     elif args.test == "position":
+         tester.test_position_embeddings()
+     elif args.test == "attention":
+         tester.test_attention_computation()
+     elif args.test == "memory":
+         tester.test_memory_requirements()
+     elif args.test == "rope":
+         tester.test_rope_frequencies()
+
+
+ if __name__ == "__main__":
+     main()