ataeff commited on
Commit
9d4dd3f
·
verified ·
1 Parent(s): d305f4a

Upload 5 files

Browse files
haze/tests/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # tests/__init__.py
haze/tests/test_async_modules.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for async haze modules: mathbrain, experts, trauma, subjectivity, cleanup.
3
+ """
4
+
5
+ import pytest
6
+ import asyncio
7
+ import numpy as np
8
+
9
+
10
+ # ============================================================
11
+ # MATHBRAIN TESTS
12
+ # ============================================================
13
+
14
class TestMathBrain:
    """Exercises MathBrain / AsyncMathBrain field perception and learning."""

    def test_import(self):
        """MathBrain symbols are importable."""
        from haze.mathbrain import MathBrain, AsyncMathBrain, FieldPerception
        for symbol in (MathBrain, AsyncMathBrain, FieldPerception):
            assert symbol is not None

    def test_create_brain(self):
        """Constructing a MathBrain yields a non-empty layer stack."""
        from haze.mathbrain import MathBrain
        brain = MathBrain()
        assert brain is not None
        assert len(brain.layers) > 0

    def test_forward_pass(self):
        """A forward pass maps five inputs to four bounded outputs."""
        from haze.mathbrain import MathBrain
        brain = MathBrain()
        inputs = np.array([0.5, 0.3, 0.7, 0.1, 0.6])
        result = brain._forward(inputs)
        assert result.shape == (4,)
        # Sigmoid activation keeps every component inside [0, 1].
        assert np.all(result >= 0) and np.all(result <= 1)

    def test_async_perceive(self):
        """Awaiting perceive() produces a bounded FieldPerception."""
        from haze.mathbrain import AsyncMathBrain

        async def scenario():
            brain = AsyncMathBrain()
            perception = await brain.perceive(
                arousal=0.5,
                novelty=0.3,
                entropy=0.7,
                trauma=0.1,
                coherence=0.6,
            )
            assert perception is not None
            assert perception.mood in ["calm", "excited", "focused", "diffuse", "alert"]
            assert 0.4 <= perception.recommended_temp <= 1.2
            assert 0.0 <= perception.identity_weight <= 1.0

        asyncio.run(scenario())

    def test_perceive_smooth(self):
        """EMA smoothing keeps arousal between the two observed values."""
        from haze.mathbrain import AsyncMathBrain

        async def scenario():
            brain = AsyncMathBrain()
            await brain.perceive_smooth(arousal=0.2)
            smoothed = await brain.perceive_smooth(arousal=0.8)
            # Smoothed arousal must stay between the raw observations.
            assert 0.2 <= smoothed.arousal <= 0.8

        asyncio.run(scenario())

    def test_hebbian_update(self):
        """A positive-reward Hebbian update perturbs layer-0 weights."""
        from haze.mathbrain import AsyncMathBrain

        async def scenario():
            brain = AsyncMathBrain()
            await brain.perceive(arousal=0.5)
            before = brain.layers[0].weights.copy()
            await brain.hebbian_update(reward=1.0)
            assert not np.allclose(brain.layers[0].weights, before)

        asyncio.run(scenario())
97
+
98
+
99
+ # ============================================================
100
+ # EXPERTS TESTS
101
+ # ============================================================
102
+
103
class TestExperts:
    """Covers resonant-expert routing (MOE-style mixture selection)."""

    def test_import(self):
        """Expert-routing symbols are importable."""
        from haze.experts import route_to_mixture, pulse_to_signals, ExpertMixture
        for symbol in (route_to_mixture, pulse_to_signals, ExpertMixture):
            assert symbol is not None

    def test_pulse_to_signals(self):
        """A pulse maps onto a FieldSignals dataclass exposing the three fields."""
        from haze.experts import pulse_to_signals
        signals = pulse_to_signals(novelty=0.5, arousal=0.3, entropy=0.7)
        for attr in ('novelty', 'arousal', 'entropy'):
            assert hasattr(signals, attr)

    def test_route_to_mixture(self):
        """Routing yields all four experts, normalized weights, valid temperature."""
        from haze.experts import route_to_mixture, pulse_to_signals
        mixture = route_to_mixture(pulse_to_signals(novelty=0.5, arousal=0.3, entropy=0.7))

        for expert in ('structural', 'semantic', 'creative', 'precise'):
            assert expert in mixture.weights

        # Weights approximate a probability distribution.
        assert 0.99 <= sum(mixture.weights.values()) <= 1.01

        # Temperature stays inside the sampling-safe band.
        assert 0.3 <= mixture.temperature <= 1.5

    def test_high_arousal_boosts_semantic(self):
        """Raising arousal never lowers the semantic expert's weight."""
        from haze.experts import route_to_mixture, pulse_to_signals

        calm = route_to_mixture(pulse_to_signals(arousal=0.1))
        excited = route_to_mixture(pulse_to_signals(arousal=0.9))

        assert excited.weights['semantic'] >= calm.weights['semantic']

    def test_high_novelty_boosts_creative(self):
        """The creative expert keeps positive weight at both novelty extremes."""
        from haze.experts import route_to_mixture, pulse_to_signals

        familiar = route_to_mixture(pulse_to_signals(novelty=0.1))
        novel = route_to_mixture(pulse_to_signals(novelty=0.9))

        assert novel.weights['creative'] > 0
        assert familiar.weights['creative'] > 0
161
+
162
+
163
+ # ============================================================
164
+ # TRAUMA TESTS
165
+ # ============================================================
166
+
167
class TestTrauma:
    """Covers trauma detection and identity-return prefixes."""

    def test_import(self):
        """Trauma symbols are importable."""
        from haze.trauma import AsyncTrauma, TraumaState, get_identity_prefix
        for symbol in (AsyncTrauma, TraumaState, get_identity_prefix):
            assert symbol is not None

    def test_detect_trauma(self):
        """Bootstrap vocabulary should register as trauma with trigger words."""
        from haze.trauma import AsyncTrauma

        async def scenario():
            trauma = AsyncTrauma()
            state = await trauma.process("The haze resonates with the field pattern")
            assert state is not None
            assert state.level > 0  # some trauma must be detected
            assert len(state.trigger_words) > 0

        asyncio.run(scenario())

    def test_no_trauma_on_neutral(self):
        """Everyday text yields either no trauma state or only a mild one."""
        from haze.trauma import AsyncTrauma

        async def scenario():
            trauma = AsyncTrauma()
            state = await trauma.process("Hello how are you today")
            # Neutral input may legitimately return None.
            if state is not None:
                assert state.level < 0.5

        asyncio.run(scenario())

    def test_identity_prefix(self):
        """A randomly chosen identity prefix mentions haze or field."""
        from haze.trauma import get_identity_prefix

        # Takes no arguments; returns one prefix at random.
        prefix = get_identity_prefix()
        assert prefix is not None
        lowered = prefix.lower()
        assert "haze" in lowered or "field" in lowered
218
+
219
+
220
+ # ============================================================
221
+ # CLEANUP TESTS
222
+ # ============================================================
223
+
224
+ class TestCleanup:
225
+ """Tests for cleanup module."""
226
+
227
+ def test_import(self):
228
+ """Cleanup can be imported."""
229
+ from haze.cleanup import cleanup_output
230
+ assert cleanup_output is not None
231
+
232
+ def test_basic_cleanup(self):
233
+ """Basic cleanup works."""
234
+ from haze.cleanup import cleanup_output
235
+ result = cleanup_output(" hello world ")
236
+ assert "hello" in result.lower()
237
+
238
+ def test_contraction_preservation(self):
239
+ """Contractions are preserved."""
240
+ from haze.cleanup import cleanup_output
241
+
242
+ tests = ["I'm", "don't", "they're", "it's", "won't"]
243
+ for t in tests:
244
+ result = cleanup_output(f"{t} here")
245
+ # Should contain some form of apostrophe (ASCII or fancy)
246
+ has_apostrophe = "'" in result or chr(8217) in result
247
+ assert has_apostrophe, f"Failed for {t}: {result}"
248
+
249
+ def test_broken_contraction_fix(self):
250
+ """Broken contractions are fixed."""
251
+ from haze.cleanup import cleanup_output
252
+
253
+ # "I'" + space should become "I'm"
254
+ result = cleanup_output("I' trying")
255
+ has_im = "I'm" in result or "I'm" in result or ("I" in result and "m" in result)
256
+ assert has_im, f"Got: {result}"
257
+
258
+ # "don" + space + verb should become "don't"
259
+ result = cleanup_output("don believe")
260
+ has_dont = "don't" in result or "don't" in result or ("don" in result and "t" in result)
261
+ assert has_dont, f"Got: {result}"
262
+
263
+ def test_heuristic_contraction_fix(self):
264
+ """Heuristic patterns work (-ing, -ed, -en)."""
265
+ from haze.cleanup import cleanup_output
266
+
267
+ # -ing
268
+ result = cleanup_output("don trying")
269
+ assert "don't" in result or "don't" in result
270
+
271
+ # -ed
272
+ result = cleanup_output("don tired")
273
+ assert "don't" in result or "don't" in result
274
+
275
+ # -en
276
+ result = cleanup_output("don forgotten")
277
+ assert "don't" in result or "don't" in result
278
+
279
+ def test_they_re_fix(self):
280
+ """'they re' becomes 'they're'."""
281
+ from haze.cleanup import cleanup_output
282
+
283
+ result = cleanup_output("they re here")
284
+ assert "they're" in result or "they're" in result
285
+
286
+ def test_em_dash_removal(self):
287
+ """Em-dash at start is removed."""
288
+ from haze.cleanup import cleanup_output
289
+
290
+ result = cleanup_output("— Hello there")
291
+ assert not result.startswith("—")
292
+ assert not result.startswith("–")
293
+
294
+
295
+ # ============================================================
296
+ # SUBWORD FIELD TESTS
297
+ # ============================================================
298
+
299
class TestSubwordField:
    """BPE-based SubwordField tests (skipped when sentencepiece is absent)."""

    @staticmethod
    def _require_sentencepiece():
        """Skip the calling test when sentencepiece is unavailable."""
        try:
            import sentencepiece  # noqa: F401
        except ImportError:
            pytest.skip("sentencepiece not installed")

    @staticmethod
    def _build_field(tmp_path):
        """Train a tiny SubwordField from a synthetic corpus under *tmp_path*."""
        from haze.subword_field import SubwordField

        corpus = tmp_path / "corpus.txt"
        corpus.write_text("Hello world. I love you. The living room. Don't worry.\n" * 100)
        # Small vocab keeps sentencepiece training happy on a tiny corpus.
        return SubwordField.from_corpus(str(corpus), vocab_size=50)

    def test_import(self):
        """SubwordField can be imported."""
        try:
            from haze.subword_field import SubwordField, AsyncSubwordField
            assert SubwordField is not None
        except ImportError:
            pytest.skip("sentencepiece not installed")

    def test_build_from_corpus(self, tmp_path):
        """A field built from a corpus exposes vocab and bigram statistics."""
        self._require_sentencepiece()
        field = self._build_field(tmp_path)
        assert field is not None
        assert field.vocab is not None
        assert len(field.bigram_counts) > 0

    def test_generate(self, tmp_path):
        """Generation from a freshly built field produces non-empty text."""
        self._require_sentencepiece()
        field = self._build_field(tmp_path)
        result = field.generate("Hello", length=10, temperature=0.7)
        assert result is not None
        assert len(result) > 0
349
+
350
+
351
+ # ============================================================
352
+ # SUBJECTIVITY TESTS
353
+ # ============================================================
354
+
355
class TestSubjectivity:
    """Subjectivity tests: pulse computation and prompt-free internal seeding."""

    def test_import(self):
        """Subjectivity symbols are importable."""
        from haze.subjectivity import Subjectivity, AsyncSubjectivity, PulseSnapshot
        for symbol in (Subjectivity, AsyncSubjectivity, PulseSnapshot):
            assert symbol is not None

    def test_pulse_computation(self):
        """An excited input yields a bounded, high-arousal pulse."""
        from haze.subjectivity import Subjectivity
        from haze.haze import Vocab

        corpus = "Hello world. I love you. The living room."
        subj = Subjectivity(corpus, Vocab.from_text(corpus))

        pulse = subj.compute_pulse("AMAZING!!! I LOVE THIS!!!")

        assert pulse is not None
        for component in (pulse.arousal, pulse.novelty, pulse.entropy):
            assert 0 <= component <= 1
        # Exclamation-heavy input should read as aroused.
        assert pulse.arousal > 0.3

    def test_internal_seed_excludes_prompt(self):
        """The internal seed shares no words with the prompt."""
        from haze.subjectivity import Subjectivity
        from haze.haze import Vocab

        corpus = "Hello world. I love you. The living room. Darling sweetheart."
        subj = Subjectivity(corpus, Vocab.from_text(corpus))

        prompt = "I love"
        tokens, pulse, seed_text = subj.get_internal_seed(prompt)

        # Subjectivity's contract: no prompt word leaks into the seed.
        overlap = set(seed_text.lower().split()) & set(prompt.lower().split())
        assert len(overlap) == 0, f"Seed contains prompt words: {overlap}"
haze/tests/test_cleanup.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Tests for enhanced cleanup.py functionality.
4
+ """
5
+
6
+ import unittest
7
+ from haze.cleanup import (
8
+ cleanup_output,
9
+ cleanup_with_resonance,
10
+ ensure_sentence_boundaries,
11
+ calculate_garbage_score,
12
+ _detect_poetic_repetition,
13
+ _calculate_local_entropy,
14
+ )
15
+
16
+
17
class TestBasicCleanup(unittest.TestCase):
    """Core cleanup_output behavior: punctuation normalization and casing."""

    def test_punctuation_normalization(self):
        """Apostrophes are normalized to the typographic form (U+2019)."""
        result = cleanup_output("don't say 'hello'")
        # Verify by codepoint rather than pasting the fancy character.
        self.assertTrue(any(ord(c) == 0x2019 for c in result),
                        "Should contain right single quote (U+2019)")

    def test_repeated_punctuation(self):
        """Runs of punctuation collapse to canonical lengths."""
        self.assertEqual(cleanup_output("Wait..... really????"), "Wait... Really???")

    def test_capitalize_first_letter(self):
        """Output begins with an uppercase letter."""
        self.assertTrue(cleanup_output("hello world")[0].isupper())

    def test_capitalize_i(self):
        """A standalone 'i' is promoted to 'I'."""
        self.assertIn("I am", cleanup_output("i am here"))
42
+
43
+
44
class TestRepetitionHandling(unittest.TestCase):
    """Collapsing accidental word repetition while keeping poetic repetition."""

    def test_double_repetition_removed(self):
        """Two copies of a word collapse to one."""
        self.assertEqual(cleanup_output("the the house").lower(), "the house.")

    def test_triple_repetition_removed(self):
        """Three copies of a word collapse to one."""
        self.assertEqual(cleanup_output("the the the house").lower(), "the house.")

    def test_quad_repetition_removed(self):
        """Four or more copies of a word collapse to one."""
        self.assertEqual(cleanup_output("word word word word here").lower(), "word here.")

    def test_poetic_repetition_preserved(self):
        """Comma-separated repetition is stylistic and must survive."""
        result = cleanup_output("Love, love, love in the morning")
        self.assertIn("love, love, love", result.lower())

    def test_emphatic_repetition_preserved(self):
        """Emphatic punctuation-bearing repetition keeps all occurrences."""
        result = cleanup_output("Never, never, never!")
        # Cleanup may recapitalize, so count case-insensitively.
        self.assertGreaterEqual(result.lower().count("never"), 3)
72
+
73
+
74
class TestContractions(unittest.TestCase):
    """Contraction repair and its/it's disambiguation."""

    def _assert_all_fixed(self, cases):
        """Assert each (raw input, expected lowercase substring) pair."""
        for raw, expected in cases:
            self.assertIn(expected, cleanup_output(raw).lower())

    def test_basic_contractions(self):
        """Simple missing-apostrophe contractions are repaired."""
        self._assert_all_fixed([
            ("dont go", "don't go"),
            ("wont work", "won't work"),
            ("cant see", "can't see"),
            ("isnt it", "isn't it"),
        ])

    def test_advanced_contractions(self):
        """Compound 'have' forms contract to -'ve."""
        self._assert_all_fixed([
            ("would have gone", "would've gone"),
            ("could have been", "could've been"),
            ("should have said", "should've said"),
        ])

    def test_possessive_vs_contraction(self):
        """'its' becomes 'it's' only in the 'it is' reading."""
        # "it is" reading -> contraction added.
        self.assertIn("it's", cleanup_output("its going to rain").lower())

        # Possessive reading -> left untouched.
        possessive = cleanup_output("its wings spread wide").lower()
        self.assertIn("its wings", possessive)
        self.assertNotIn("it's wings", possessive)
110
+
111
+
112
class TestSentenceStructure(unittest.TestCase):
    """Sentence endings, ellipses, capitalization, and run-on splitting."""

    def test_sentence_ending_added(self):
        """Text without terminal punctuation gains one."""
        result = cleanup_output("Hello world")
        self.assertTrue(result.endswith('.') or result.endswith('!') or result.endswith('?'))

    def test_ellipsis_cleanup(self):
        """Trailing ellipsis collapses to a period; mid-sentence ones may stay."""
        result = cleanup_output("I don't know...")
        self.assertTrue(result.endswith('.'))
        self.assertFalse(result.endswith('...'))

        # With text after the ellipsis the output still terminates cleanly.
        continued = cleanup_output("I don't know... but I think so")
        self.assertTrue(continued.endswith('.'))

    def test_capitalize_after_period(self):
        """The word following a period is capitalized."""
        self.assertIn("Hello. World.", cleanup_output("Hello. world."))

    def test_run_on_sentences_moderate(self):
        """Moderate mode introduces at least one sentence boundary."""
        result = cleanup_output("I went there I saw things", mode="moderate")
        self.assertGreaterEqual(result.count('.'), 1)
143
+
144
+
145
class TestArtifactCleanup(unittest.TestCase):
    """Removal of generation artifacts and garbage scoring."""

    def test_orphan_apostrophe_fragments(self):
        """Floating 't / 's fragments are stripped."""
        result = cleanup_output("hello 't there 's world")
        self.assertNotIn(" 't ", result)
        self.assertNotIn(" 's ", result)

    def test_short_word_validation(self):
        """Legitimate short words survive artifact filtering."""
        result = cleanup_output("I go to a place")
        for word in ["I", "go", "to", "a"]:
            self.assertIn(word, result)

    def test_garbage_score_calculation(self):
        """A clean sentence scores lower than punctuation soup."""
        self.assertLess(calculate_garbage_score("Hello, world."),
                        calculate_garbage_score("H..,,ello,,.world..,,"))
169
+
170
+
171
class TestEntropyAndResonance(unittest.TestCase):
    """Entropy helpers and resonance-aware cleanup dispatch."""

    def test_entropy_calculation(self):
        """Character-diverse text scores higher entropy than repetitive text."""
        diverse = _calculate_local_entropy("abcdefghijklmnop")
        repetitive = _calculate_local_entropy("aaaaaaaaaaaaaaaa")
        self.assertGreater(diverse, repetitive)

    def test_poetic_repetition_detection(self):
        """Comma-separated repetition produces at least one detected region."""
        regions = _detect_poetic_repetition("Love, love, love in the morning light")
        self.assertGreater(len(regions), 0)

    def test_resonance_aware_cleanup(self):
        """Mode selection runs for both resonance regimes and fixes the text."""
        text = "Hello the the world"

        # High resonance + high entropy -> gentle mode.
        gentle_result = cleanup_with_resonance(text, resonance_score=0.8, entropy=3.0)

        # Low resonance + low entropy -> moderate mode (exercised for coverage).
        cleanup_with_resonance(text, resonance_score=0.3, entropy=1.2)

        # The duplicated word must be addressed either way.
        self.assertNotEqual(gentle_result, text)
206
+
207
+
208
class TestSentenceBoundaries(unittest.TestCase):
    """ensure_sentence_boundaries: terminal punctuation, casing, fragments."""

    def test_ensure_sentence_boundaries(self):
        """Bare text gains a period and a leading capital."""
        result = ensure_sentence_boundaries("hello world")
        self.assertTrue(result.endswith('.'))
        self.assertTrue(result[0].isupper())

    def test_fragment_removal(self):
        """A very short trailing fragment is dropped."""
        self.assertNotIn(" st", ensure_sentence_boundaries("Hello world st"))

    def test_multiple_sentences(self):
        """Every sentence in a multi-sentence string starts with a capital."""
        result = ensure_sentence_boundaries("hello. world. testing")
        # The final split element after a trailing period is empty; skip it.
        for sentence in result.split('.')[:-1]:
            stripped = sentence.strip()
            if stripped:
                self.assertTrue(stripped[0].isupper())
231
+
232
+
233
class TestModeVariations(unittest.TestCase):
    """gentle / moderate / strict cleanup modes."""

    def test_gentle_mode(self):
        """Gentle mode still applies baseline fixes like capitalization."""
        result = cleanup_output("hello world the the test", mode="gentle")
        self.assertTrue(result[0].isupper())

    def test_moderate_mode(self):
        """Moderate mode may restructure but must not shred the text."""
        text = "hello I went there I came back"
        result = cleanup_output(text, mode="moderate")
        self.assertGreaterEqual(len(result), len(text) - 5)

    def test_strict_mode(self):
        """Strict mode cleans trailing fragments and terminates the sentence."""
        result = cleanup_output("hello world st", mode="strict")
        self.assertTrue(result.endswith('.'))
256
+
257
+
258
class TestEdgeCases(unittest.TestCase):
    """Degenerate inputs: empty, None, tiny, punctuation-only."""

    def test_empty_string(self):
        """An empty string passes through unchanged."""
        self.assertEqual(cleanup_output(""), "")

    def test_none_input(self):
        """None passes through as None rather than raising."""
        self.assertIsNone(cleanup_output(None))

    def test_very_short_text(self):
        """Minimal text is capitalized and terminated."""
        self.assertEqual(cleanup_output("hi"), "Hi.")

    def test_only_punctuation(self):
        """Punctuation-only input is handled gracefully as a string."""
        self.assertIsInstance(cleanup_output("..."), str)
281
+
282
+
283
class TestRealWorldExamples(unittest.TestCase):
    """End-to-end cleanup on realistic generation output."""

    def test_gothic_dialogue(self):
        """Gothic dialogue: contraction, repetition, and ending all repaired."""
        result = cleanup_output("I dont know... the haze the haze settles over everything")

        self.assertIn("don't", result)                         # contraction repaired
        self.assertEqual(result.lower().count("the haze"), 1)  # repetition collapsed
        self.assertTrue(result.endswith('.'))                  # terminated properly

    def test_mixed_artifacts(self):
        """Several artifact types in one string are all addressed."""
        result = cleanup_output("hello 't its going st the the well")

        self.assertNotIn(" 't ", result)                       # orphan apostrophe gone
        self.assertIn("it's", result.lower())                  # its -> it's
        self.assertEqual(result.lower().count("the"), 1)       # repetition gone

    def test_preserves_style(self):
        """Comma-separated emphatic repetition is emergent style: keep all three."""
        result = cleanup_output("The darkness, the darkness, the darkness calls to me")
        self.assertEqual(result.lower().count("the darkness"), 3)
317
+
318
+
319
# Allow running this test file directly as a script.
if __name__ == "__main__":
    unittest.main()
haze/tests/test_haze.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # tests/test_haze.py — Tests for haze module
3
+
4
+ import unittest
5
+ import numpy as np
6
+ import sys
7
+ import os
8
+
9
+ # Add parent directory to path
10
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
11
+
12
+ from haze import Vocab, PostGPT, ReweightHead, ContentHead, HybridHead, Block, load_corpus, build_model_from_text
13
+ import nn
14
+
15
+
16
class TestVocab(unittest.TestCase):
    """Vocab construction, encode/decode round-trips, and edge cases."""

    def test_from_text(self):
        """Building from text yields a non-empty Vocab instance."""
        vocab = Vocab.from_text("hello world")
        self.assertIsInstance(vocab, Vocab)
        self.assertGreater(vocab.vocab_size, 0)

    def test_encode_decode(self):
        """encode followed by decode reproduces the input."""
        vocab = Vocab.from_text("hello world")
        self.assertEqual(vocab.decode(vocab.encode("hello")), "hello")

    def test_lowercase_conversion(self):
        """The vocabulary is case-insensitive: everything maps to lowercase."""
        vocab = Vocab.from_text("Hello World")
        self.assertEqual(vocab.decode(vocab.encode("HELLO")), "hello")

    def test_unknown_chars(self):
        """Characters absent from the corpus encode to nothing."""
        vocab = Vocab.from_text("abc")
        # 'x' never appeared in the corpus.
        self.assertEqual(len(vocab.encode("x")), 0)

    def test_vocab_size(self):
        """vocab_size counts distinct characters, not occurrences."""
        # "aabbcc" has exactly three distinct characters: a, b, c.
        self.assertEqual(Vocab.from_text("aabbcc").vocab_size, 3)
55
+
56
+
57
class TestReweightHead(unittest.TestCase):
    """ReweightHead attention: output-shape invariants."""

    def setUp(self):
        self.rng = nn.get_rng(42)
        self.n_emb = 16
        self.head_dim = 8
        self.T = 10
        self.head = ReweightHead(self.n_emb, self.head_dim, self.T, self.rng)

    def test_forward_shape(self):
        """A full-length sequence maps to (T, head_dim)."""
        x = np.random.randn(self.T, self.n_emb).astype(np.float32)
        self.assertEqual(self.head.forward(x).shape, (self.T, self.head_dim))

    def test_forward_shorter_sequence(self):
        """Sequences shorter than T keep their own length in the output."""
        x = np.random.randn(5, self.n_emb).astype(np.float32)
        self.assertEqual(self.head.forward(x).shape, (5, self.head_dim))
78
+
79
+
80
class TestContentHead(unittest.TestCase):
    """ContentHead attention: output-shape invariants."""

    def setUp(self):
        self.rng = nn.get_rng(42)
        self.n_emb = 16
        self.head_dim = 8
        self.T = 10
        self.head = ContentHead(self.n_emb, self.head_dim, self.T, self.rng)

    def test_forward_shape(self):
        """A full-length sequence maps to (T, head_dim)."""
        x = np.random.randn(self.T, self.n_emb).astype(np.float32)
        self.assertEqual(self.head.forward(x).shape, (self.T, self.head_dim))

    def test_forward_shorter_sequence(self):
        """Sequences shorter than T keep their own length in the output."""
        x = np.random.randn(5, self.n_emb).astype(np.float32)
        self.assertEqual(self.head.forward(x).shape, (5, self.head_dim))
101
+
102
+
103
class TestHybridHead(unittest.TestCase):
    """HybridHead attention: shape invariant and alpha blending parameter."""

    def setUp(self):
        self.rng = nn.get_rng(42)
        self.n_emb = 16
        self.head_dim = 8
        self.T = 10
        self.head = HybridHead(self.n_emb, self.head_dim, self.T, self.rng, alpha=0.5)

    def test_forward_shape(self):
        """A full-length sequence maps to (T, head_dim)."""
        x = np.random.randn(self.T, self.n_emb).astype(np.float32)
        self.assertEqual(self.head.forward(x).shape, (self.T, self.head_dim))

    def test_alpha_parameter(self):
        """The alpha blend factor is retained on the instance."""
        self.assertEqual(self.head.alpha, 0.5)
122
+
123
+
124
class TestBlock(unittest.TestCase):
    """Unit tests for the transformer Block across all head types."""

    def setUp(self):
        self.T = 10
        self.n_emb = 32
        self.nodes = 64
        self.rng = nn.get_rng(42)

    def _assert_forward_shape(self, head_type):
        # Shared driver: build a Block with the given head type and verify
        # that a (T, n_emb) input round-trips to the same shape.
        blk = Block(
            self.n_emb, self.T, self.nodes, self.rng,
            n_heads=4, head_type=head_type
        )
        inputs = np.random.randn(self.T, self.n_emb).astype(np.float32)
        self.assertEqual(blk.forward(inputs).shape, (self.T, self.n_emb))

    def test_block_hybrid_forward(self):
        """Forward pass works with hybrid heads."""
        self._assert_forward_shape("hybrid")

    def test_block_rrpram_forward(self):
        """Forward pass works with RRPRAM-only heads."""
        self._assert_forward_shape("rrpram")

    def test_block_reweight_backward_compat(self):
        """The legacy 'reweight' head_type is still accepted."""
        self._assert_forward_shape("reweight")

    def test_block_content_forward(self):
        """Forward pass works with content-only heads."""
        self._assert_forward_shape("content")
172
+
173
+
174
class TestPostGPT(unittest.TestCase):
    """Unit tests for PostGPT: construction, logits, sampling modes, and I/O."""

    def setUp(self):
        self.vocab_size = 20
        self.T = 16
        self.n_emb = 32
        self.model = PostGPT(
            vocab_size=self.vocab_size,
            T=self.T,
            n_emb=self.n_emb,
            nodes=32,
            n_blocks=2,
            n_heads=4,
            head_type="hybrid",
            seed=42,
        )

    def test_model_initialization(self):
        """Constructor arguments are reflected on the instance."""
        for attr, expected in (
            ("vocab_size", self.vocab_size),
            ("T", self.T),
            ("n_emb", self.n_emb),
        ):
            self.assertEqual(getattr(self.model, attr), expected)

    def test_logits_shape(self):
        """logits returns one row of vocab scores per input token."""
        seq = np.array([0, 1, 2, 3, 4], dtype=np.int32)
        self.assertEqual(self.model.logits(seq).shape, (5, self.vocab_size))

    def test_generate_simple(self):
        """generate_simple yields the requested count of in-vocabulary tokens."""
        out = self.model.generate_simple([0, 1, 2], length=10, temperature=1.0)
        self.assertEqual(len(out), 10)
        for tok in out:
            self.assertGreaterEqual(tok, 0)
            self.assertLess(tok, self.vocab_size)

    def test_generate_with_stats(self):
        """generate returns tokens plus the expected statistics keys."""
        out, info = self.model.generate(
            [0, 1, 2],
            length=10,
            temperature=1.0,
            sampling="basic"
        )
        self.assertEqual(len(out), 10)
        for key in ("mean_entropy", "mean_confidence", "mean_temp"):
            self.assertIn(key, info)

    def test_generate_entropy_sampling(self):
        """Entropy-aware sampling yields tokens and positive mean entropy."""
        out, info = self.model.generate(
            [0, 1, 2],
            length=10,
            sampling="entropy",
            target_entropy=2.0
        )
        self.assertEqual(len(out), 10)
        self.assertGreater(info["mean_entropy"], 0)

    def test_generate_top_k(self):
        """Top-k sampling honors the requested length."""
        out, _ = self.model.generate(
            [0, 1, 2],
            length=10,
            sampling="top_k",
            top_k=5,
            temperature=1.0
        )
        self.assertEqual(len(out), 10)

    def test_generate_top_p(self):
        """Nucleus (top-p) sampling honors the requested length."""
        out, _ = self.model.generate(
            [0, 1, 2],
            length=10,
            sampling="top_p",
            top_p=0.9,
            temperature=1.0
        )
        self.assertEqual(len(out), 10)

    def test_generate_mirostat(self):
        """Mirostat v1 sampling yields tokens and entropy statistics."""
        out, info = self.model.generate(
            [0, 1, 2],
            length=10,
            sampling="mirostat",
            target_entropy=2.0,
            mirostat_tau=0.1
        )
        self.assertEqual(len(out), 10)
        self.assertIn("mean_entropy", info)

    def test_generate_mirostat_v2(self):
        """Mirostat v2 sampling yields tokens and entropy statistics."""
        out, info = self.model.generate(
            [0, 1, 2],
            length=10,
            sampling="mirostat_v2",
            target_entropy=2.0,
            mirostat_tau=0.1
        )
        self.assertEqual(len(out), 10)
        self.assertIn("mean_entropy", info)

    def test_generate_resonance(self):
        """Resonance sampling reports a mean resonance strictly inside (0, 1)."""
        out, info = self.model.generate(
            [0, 1, 2],
            length=20,
            sampling="resonance",
            target_resonance=0.7
        )
        self.assertEqual(len(out), 20)
        self.assertIn("mean_resonance", info)
        self.assertGreater(info["mean_resonance"], 0)
        self.assertLess(info["mean_resonance"], 1.0)

    def test_generate_empty_seed(self):
        """Generation works even when no seed tokens are supplied."""
        out, _ = self.model.generate(
            seed_seq=[],
            length=10,
            temperature=1.0
        )
        self.assertEqual(len(out), 10)

    def test_save_and_load_theweightofhaze(self):
        """Weights survive a save/load round trip and the clone can generate."""
        import tempfile
        # delete=False so the file can be reopened by name on all platforms.
        with tempfile.NamedTemporaryFile(suffix='.npz', delete=False) as tmp:
            weights_path = tmp.name

        try:
            # persist the current weights
            self.model.save_theweightofhaze(weights_path)
            self.assertTrue(os.path.exists(weights_path))

            # rebuild a model from the saved file
            restored = PostGPT.theweightofhaze(
                vocab_size=self.vocab_size,
                path=weights_path
            )

            # structural fields must round-trip
            self.assertEqual(restored.vocab_size, self.vocab_size)
            self.assertEqual(restored.T, self.T)
            self.assertEqual(restored.n_emb, self.n_emb)

            # and the restored model must still be usable
            out, _ = restored.generate(
                seed_seq=[0, 1, 2],
                length=5,
                temperature=1.0
            )
            self.assertEqual(len(out), 5)
        finally:
            if os.path.exists(weights_path):
                os.remove(weights_path)
345
+
346
+
347
class TestModelVariants(unittest.TestCase):
    """Smoke tests for PostGPT under each attention head configuration."""

    def _logits_for(self, head_type, **extra):
        # Shared driver: build a small model with the given head type and
        # return its logits for a fixed 3-token sequence.
        model = PostGPT(
            vocab_size=20,
            T=16,
            n_emb=32,
            nodes=32,
            n_blocks=2,
            n_heads=4,
            head_type=head_type,
            seed=42,
            **extra,
        )
        return model.logits(np.array([0, 1, 2], dtype=np.int32))

    def test_rrpram_only_model(self):
        """RRPRAM-only heads produce correctly shaped logits."""
        self.assertEqual(self._logits_for("rrpram").shape, (3, 20))

    def test_reweight_backward_compat(self):
        """The legacy 'reweight' head_type remains accepted."""
        self.assertEqual(self._logits_for("reweight").shape, (3, 20))

    def test_content_only_model(self):
        """Content-only heads produce correctly shaped logits."""
        self.assertEqual(self._logits_for("content").shape, (3, 20))

    def test_hybrid_model(self):
        """Hybrid heads with an explicit alpha produce correctly shaped logits."""
        self.assertEqual(self._logits_for("hybrid", alpha=0.7).shape, (3, 20))
414
+
415
+
416
class TestHelpers(unittest.TestCase):
    """Unit tests for module-level helper functions."""

    def test_load_corpus(self):
        """load_corpus returns the text content of a file."""
        import tempfile
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as tmp:
            tmp.write("test corpus")
            corpus_path = tmp.name

        try:
            self.assertEqual(load_corpus(corpus_path), "test corpus")
        finally:
            os.remove(corpus_path)

    def test_build_model_from_text(self):
        """build_model_from_text wires text, vocab, and model together."""
        import tempfile
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as tmp:
            tmp.write("hello world this is a test")
            corpus_path = tmp.name

        try:
            text, vocab, model = build_model_from_text(
                corpus_path,
                T=16,
                n_emb=32,
                nodes=32,
                n_blocks=2,
                n_heads=4,
            )
            self.assertIsInstance(text, str)
            self.assertIsInstance(vocab, Vocab)
            self.assertIsInstance(model, PostGPT)
            # model vocabulary must match the vocab built from the corpus
            self.assertEqual(model.vocab_size, vocab.vocab_size)
        finally:
            os.remove(corpus_path)
455
+
456
+
457
class TestEndToEnd(unittest.TestCase):
    """End-to-end integration tests for the full generation pipeline."""

    def test_full_pipeline(self):
        """Corpus -> vocab -> model -> generate -> decode runs end to end."""
        corpus = "the quick brown fox jumps over the lazy dog"
        vocab = Vocab.from_text(corpus)

        model = PostGPT(
            vocab_size=vocab.vocab_size,
            T=16,
            n_emb=32,
            nodes=32,
            n_blocks=2,
            n_heads=4,
            head_type="hybrid",
            seed=42,
        )

        # encode a prompt, sample a continuation, and decode it back
        prompt = vocab.encode("the")
        tokens, info = model.generate(
            seed_seq=prompt,
            length=20,
            sampling="entropy",
            target_entropy=2.0
        )
        decoded = vocab.decode(tokens)

        self.assertEqual(len(tokens), 20)
        self.assertIsInstance(decoded, str)
        self.assertGreater(len(decoded), 0)
        self.assertIn("mean_entropy", info)
496
+
497
+
498
# Allow running this test module directly (outside a test runner).
if __name__ == "__main__":
    unittest.main()
haze/tests/test_nn.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # tests/test_nn.py — Tests for nn.py module
3
+
4
+ import unittest
5
+ import numpy as np
6
+ import sys
7
+ import os
8
+
9
+ # Add parent directory to path
10
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
11
+
12
+ import nn
13
+
14
+
15
class TestRNG(unittest.TestCase):
    """Unit tests for RNG construction helpers."""

    def test_get_rng_no_seed(self):
        """Seedless construction still yields a numpy Generator."""
        self.assertIsInstance(nn.get_rng(), np.random.Generator)

    def test_get_rng_with_seed(self):
        """Two generators built from the same seed agree on their first draw."""
        first = nn.get_rng(42).random()
        second = nn.get_rng(42).random()
        self.assertEqual(first, second)
30
+
31
+
32
class TestWeightInit(unittest.TestCase):
    """Unit tests for weight initialization routines."""

    def setUp(self):
        self.rng = nn.get_rng(42)

    def test_init_weight_shape(self):
        """init_weight returns a float32 array of the requested shape."""
        weights = nn.init_weight((10, 20), self.rng)
        self.assertEqual(weights.shape, (10, 20))
        self.assertEqual(weights.dtype, np.float32)

    def test_init_weight_scale(self):
        """The sample standard deviation tracks the scale argument."""
        weights = nn.init_weight((100, 100), self.rng, scale=0.01)
        # std of a large sample should land close to the requested scale
        self.assertLess(abs(weights.std() - 0.01), 0.01)

    def test_init_weight_orthogonal_shape(self):
        """Orthogonal init returns a float32 array of the requested shape."""
        weights = nn.init_weight_orthogonal((10, 20), self.rng)
        self.assertEqual(weights.shape, (10, 20))
        self.assertEqual(weights.dtype, np.float32)
59
+
60
+
61
class TestActivations(unittest.TestCase):
    """Unit tests for activation functions."""

    def test_relu_positive(self):
        """ReLU is the identity on positive inputs."""
        vals = np.array([1.0, 2.0, 3.0])
        np.testing.assert_array_equal(nn.relu(vals), vals)

    def test_relu_negative(self):
        """ReLU maps negative inputs to zero."""
        vals = np.array([-1.0, -2.0, -3.0])
        np.testing.assert_array_equal(nn.relu(vals), np.zeros_like(vals))

    def test_relu_mixed(self):
        """ReLU zeroes only the negative entries of a mixed input."""
        vals = np.array([-1.0, 0.0, 1.0])
        np.testing.assert_array_equal(nn.relu(vals), np.array([0.0, 0.0, 1.0]))

    def test_leaky_relu(self):
        """Leaky ReLU scales negatives by alpha and keeps non-negatives."""
        vals = np.array([-1.0, 0.0, 1.0])
        np.testing.assert_array_almost_equal(
            nn.leaky_relu(vals, alpha=0.01),
            np.array([-0.01, 0.0, 1.0]),
        )

    def test_gelu_shape(self):
        """GELU is elementwise: output shape equals input shape."""
        vals = np.random.randn(10, 20)
        self.assertEqual(nn.gelu(vals).shape, vals.shape)

    def test_swish_shape(self):
        """Swish is elementwise: output shape equals input shape."""
        vals = np.random.randn(10, 20)
        self.assertEqual(nn.swish(vals).shape, vals.shape)

    def test_sigmoid_range(self):
        """Sigmoid outputs always land in [0, 1]."""
        out = nn.sigmoid(np.random.randn(100))
        self.assertTrue(np.all(out >= 0))
        self.assertTrue(np.all(out <= 1))

    def test_sigmoid_zero(self):
        """Sigmoid of zero is exactly one half."""
        np.testing.assert_almost_equal(nn.sigmoid(np.array([0.0]))[0], 0.5)

    def test_softmax_sum_to_one(self):
        """Softmax output is a probability distribution (sums to 1)."""
        np.testing.assert_almost_equal(nn.softmax(np.random.randn(10)).sum(), 1.0)

    def test_softmax_positive(self):
        """Every softmax probability is strictly positive."""
        self.assertTrue(np.all(nn.softmax(np.random.randn(10)) > 0))
124
+
125
+
126
class TestNormalization(unittest.TestCase):
    """Unit tests for normalization functions."""

    def test_layer_norm_shape(self):
        """Layer norm preserves the input's shape."""
        data = np.random.randn(5, 10).astype(np.float32)
        scale = np.ones(10, dtype=np.float32)
        shift = np.zeros(10, dtype=np.float32)
        self.assertEqual(nn.layer_norm(data, scale, shift).shape, data.shape)

    def test_layer_norm_mean_zero(self):
        """With unit gamma / zero beta, each row ends up centered at zero."""
        data = np.random.randn(5, 10).astype(np.float32)
        scale = np.ones(10, dtype=np.float32)
        shift = np.zeros(10, dtype=np.float32)
        normed = nn.layer_norm(data, scale, shift)
        np.testing.assert_array_almost_equal(
            normed.mean(axis=-1), np.zeros(5), decimal=5
        )

    def test_rms_norm_shape(self):
        """RMS norm preserves the input's shape."""
        data = np.random.randn(5, 10).astype(np.float32)
        scale = np.ones(10, dtype=np.float32)
        self.assertEqual(nn.rms_norm(data, scale).shape, data.shape)
152
+
153
+
154
class TestSampling(unittest.TestCase):
    """Unit tests for token sampling strategies."""

    def setUp(self):
        self.rng = nn.get_rng(42)
        self.logits = np.array([1.0, 2.0, 3.0, 4.0, 5.0])

    def _assert_valid_token(self, token):
        # Shared check: a sampled token is a python int inside the vocab.
        self.assertIsInstance(token, int)
        self.assertGreaterEqual(token, 0)
        self.assertLess(token, len(self.logits))

    def _assert_token_mu_pair(self, result):
        # Shared check for mirostat-style samplers: an (int token, float mu) pair.
        self.assertIsInstance(result, tuple)
        self.assertEqual(len(result), 2)
        token, mu = result
        self.assertIsInstance(token, int)
        self.assertIsInstance(mu, float)

    def test_sample_basic_greedy(self):
        """Temperature zero reduces basic sampling to argmax."""
        token = nn.sample_basic(self.logits, temperature=0.0, rng=self.rng)
        self.assertEqual(token, 4)  # index of the largest logit

    def test_sample_basic_returns_valid_index(self):
        """Basic sampling yields a valid token index."""
        self._assert_valid_token(
            nn.sample_basic(self.logits, temperature=1.0, rng=self.rng)
        )

    def test_sample_top_k_greedy(self):
        """Temperature zero reduces top-k sampling to argmax."""
        token = nn.sample_top_k(self.logits, k=3, temperature=0.0, rng=self.rng)
        self.assertEqual(token, 4)  # index of the largest logit

    def test_sample_top_k_valid(self):
        """Top-k sampling yields a valid token index."""
        self._assert_valid_token(
            nn.sample_top_k(self.logits, k=3, temperature=1.0, rng=self.rng)
        )

    def test_sample_top_p_greedy(self):
        """Temperature zero reduces top-p sampling to argmax."""
        token = nn.sample_top_p(self.logits, p=0.9, temperature=0.0, rng=self.rng)
        self.assertEqual(token, 4)  # index of the largest logit

    def test_sample_top_p_valid(self):
        """Top-p sampling yields a valid token index."""
        self._assert_valid_token(
            nn.sample_top_p(self.logits, p=0.9, temperature=1.0, rng=self.rng)
        )

    def test_sample_mirostat_returns_tuple(self):
        """Mirostat v1 returns an (int token, float mu) pair."""
        self._assert_token_mu_pair(
            nn.sample_mirostat(
                self.logits,
                target_entropy=2.0,
                tau=0.1,
                mu=5.0,
                rng=self.rng
            )
        )

    def test_sample_mirostat_v2_returns_tuple(self):
        """Mirostat v2 returns an (int token, float mu) pair."""
        self._assert_token_mu_pair(
            nn.sample_mirostat_v2(
                self.logits,
                target_entropy=2.0,
                tau=0.1,
                mu=5.0,
                rng=self.rng
            )
        )

    def test_sample_mirostat_v2_clips_mu(self):
        """An out-of-range mu is pulled back into the clipped band."""
        target_entropy = 2.0
        _, new_mu = nn.sample_mirostat_v2(
            self.logits,
            target_entropy=target_entropy,
            tau=0.5,
            mu=100.0,  # deliberately far above any sane value
            rng=self.rng
        )
        self.assertLessEqual(new_mu, target_entropy * 3.0)
        self.assertGreaterEqual(new_mu, target_entropy * 0.5)
241
+
242
+
243
class TestEntropyMetrics(unittest.TestCase):
    """Unit tests for entropy and information-theoretic metrics."""

    def test_entropy_uniform(self):
        """A uniform distribution has strictly positive entropy."""
        uniform = np.array([0.25, 0.25, 0.25, 0.25])
        self.assertGreater(nn.entropy(uniform), 0)

    def test_entropy_bits_uniform(self):
        """Uniform over 4 outcomes carries exactly log2(4) = 2 bits."""
        uniform = np.array([0.25, 0.25, 0.25, 0.25])
        np.testing.assert_almost_equal(nn.entropy_bits(uniform), 2.0, decimal=5)

    def test_entropy_deterministic(self):
        """A one-hot distribution has (near-)zero entropy."""
        one_hot = np.array([1.0, 0.0, 0.0, 0.0])
        self.assertLess(nn.entropy(one_hot), 0.01)

    def test_perplexity_high_prob(self):
        """A high-probability target keeps perplexity low."""
        scores = np.array([1.0, 5.0, 2.0])
        self.assertLess(nn.perplexity(scores, target_idx=1), 2.0)

    def test_cross_entropy_positive(self):
        """Cross entropy is positive for every possible target index."""
        scores = np.random.randn(10)
        for idx in range(len(scores)):
            self.assertGreater(nn.cross_entropy(scores, idx), 0)

    def test_kl_divergence_identical(self):
        """KL divergence of a distribution with itself is approximately zero."""
        uniform = np.array([0.25, 0.25, 0.25, 0.25])
        self.assertLess(nn.kl_divergence(uniform, uniform), 0.01)
285
+
286
+
287
class TestAdaptiveTemperature(unittest.TestCase):
    """Unit tests for entropy/resonance-driven temperature control."""

    def test_entropy_temperature_bounds(self):
        """The adaptive temperature stays within [min_temp, max_temp]."""
        temp = nn.entropy_temperature(
            np.random.randn(10),
            target_entropy=2.0,
            min_temp=0.5,
            max_temp=1.5
        )
        self.assertGreaterEqual(temp, 0.5)
        self.assertLessEqual(temp, 1.5)

    def test_confidence_score_range(self):
        """Confidence is a value in the closed interval [0, 1]."""
        conf = nn.confidence_score(np.random.randn(10))
        self.assertGreaterEqual(conf, 0.0)
        self.assertLessEqual(conf, 1.0)

    def test_margin_score_positive(self):
        """Margin is strictly positive when logits are distinct."""
        self.assertGreater(nn.margin_score(np.array([1.0, 5.0, 2.0])), 0)

    def test_resonance_temperature_bounds(self):
        """With history present, the temperature stays within bounds."""
        past = [np.random.randn(10) for _ in range(5)]
        temp = nn.resonance_temperature(
            np.random.randn(10),
            past,
            target_resonance=0.7,
            min_temp=0.5,
            max_temp=1.5
        )
        self.assertGreaterEqual(temp, 0.5)
        self.assertLessEqual(temp, 1.5)

    def test_resonance_temperature_no_history(self):
        """With no history, the temperature falls strictly inside the bounds."""
        temp = nn.resonance_temperature(
            np.random.randn(10),
            [],
            target_resonance=0.7,
            min_temp=0.5,
            max_temp=1.5
        )
        # mid-range expected when there is nothing to resonate against
        self.assertGreater(temp, 0.5)
        self.assertLess(temp, 1.5)
342
+
343
+
344
class TestResonanceMetrics(unittest.TestCase):
    """Unit tests for resonance similarity metrics."""

    def test_resonance_score_identical(self):
        """A distribution resonates perfectly (score 1) with itself."""
        scores = np.random.randn(10)
        np.testing.assert_almost_equal(
            nn.resonance_score(scores, scores), 1.0, decimal=5
        )

    def test_resonance_score_range(self):
        """Resonance between arbitrary logits lies in [0, 1]."""
        score = nn.resonance_score(np.random.randn(10), np.random.randn(10))
        self.assertGreaterEqual(score, 0.0)
        self.assertLessEqual(score, 1.0)

    def test_harmonic_mean_positive(self):
        """Harmonic mean is positive and bounded above by the arithmetic mean."""
        data = np.array([1.0, 2.0, 3.0, 4.0])
        result = nn.harmonic_mean(data)
        self.assertGreater(result, 0)
        self.assertLess(result, data.mean())
368
+
369
+
370
# Allow running this test module directly (outside a test runner).
if __name__ == "__main__":
    unittest.main()