#!/usr/bin/env python3
"""
Comprehensive test suite for Mon tokenizer Hugging Face integration.

This script provides extensive testing for the Mon language tokenizer,
including functionality tests, performance benchmarks, and compatibility checks.
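
Typical invocation (the filename here is illustrative; substitute this
script's actual name and your local tokenizer directory):

    python test_mon_tokenizer.py --tokenizer-path ./mon-tokenizer --report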
"""

import logging
import sys
import time
from pathlib import Path

import torch  # not used directly, but return_tensors="pt" requires torch; importing here fails fast
from transformers import AutoTokenizer

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)


class MonTokenizerTester:
    """Comprehensive testing suite for Mon tokenizer."""

    def __init__(self, tokenizer_path: str = "."):
        """
        Initialize the tester.

        Args:
            tokenizer_path: Path to the tokenizer files
        """
        self.tokenizer_path = tokenizer_path
        self.tokenizer = None
        self.test_results = {}

    def load_tokenizer(self) -> bool:
        """
        Load the tokenizer for testing.

        Returns:
            bool: True if tokenizer loaded successfully, False otherwise
        """
        try:
            logger.info(f"Loading tokenizer from: {self.tokenizer_path}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.tokenizer_path,
                local_files_only=True,
                trust_remote_code=False
            )
            
            logger.info(f"✓ Tokenizer loaded successfully")
            logger.info(f"  - Vocabulary size: {self.tokenizer.vocab_size:,}")
            logger.info(f"  - Model max length: {self.tokenizer.model_max_length:,}")
            logger.info(f"  - Tokenizer class: {self.tokenizer.__class__.__name__}")
            
            return True

        except Exception as e:
            logger.error(f"✗ Failed to load tokenizer: {e}")
            return False

    def test_basic_functionality(self) -> bool:
        """
        Test basic tokenizer functionality.

        Returns:
            bool: True if all basic tests pass, False otherwise
        """
        logger.info("=== Testing Basic Functionality ===")

        test_cases = [
            {
                "text": "ဘာသာမန်",
                "description": "Single Mon word",
                "expected_min_tokens": 1
            },
            {
                "text": "ဘာသာမန် ပရူပရာတံဂှ် ကၠောန်ဗဒှ်လဝ်ရ။",
                "description": "Complete Mon sentence",
                "expected_min_tokens": 3
            },
            {
                "text": "မန်တံဂှ် မံင်ပ္ဍဲ ရးမန် ကဵု ရးသေံ။",
                "description": "Mon geographical text",
                "expected_min_tokens": 3
            },
            {
                "text": "၁၂၃၄၅ ဂတာပ်ခ္ဍာ် ၂၀၂၄ သၞာံ",
                "description": "Mon numerals and dates",
                "expected_min_tokens": 2
            },
            {
                "text": "အရေဝ်ဘာသာမန် ပ္ဍဲလောကဏအ် ဂွံဆဵုကေတ် ပ္ဍဲဍုင်သေံ ကဵု ဍုင်ဗၟာ ရ။",
                "description": "Complex Mon linguistics text",
                "expected_min_tokens": 5
            }
        ]

        passed = 0
        total = len(test_cases)

        for i, test_case in enumerate(test_cases, 1):
            text = test_case["text"]
            description = test_case["description"]
            expected_min_tokens = test_case["expected_min_tokens"]

            try:
                # Test encoding (perf_counter is a monotonic, high-resolution clock)
                start_time = time.perf_counter()
                tokens = self.tokenizer(text, return_tensors="pt")
                encoding_time = time.perf_counter() - start_time

                # Test decoding
                start_time = time.perf_counter()
                decoded = self.tokenizer.decode(
                    tokens["input_ids"][0],
                    skip_special_tokens=True
                )
                decoding_time = time.perf_counter() - start_time

                # Validate results
                token_count = tokens["input_ids"].shape[1]
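                # strip() tolerates outer-whitespace normalization; the content itself must round-trip exactly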
                round_trip_success = text.strip() == decoded.strip()

                if token_count >= expected_min_tokens and round_trip_success:
                    logger.info(f"✓ Test {i}: {description}")
                    logger.info(f"  Tokens: {token_count}, Encoding: {encoding_time*1000:.2f}ms, "
                               f"Decoding: {decoding_time*1000:.2f}ms")
                    passed += 1
                else:
                    logger.warning(f"⚠ Test {i}: {description}")
                    if token_count < expected_min_tokens:
                        logger.warning(f"  Token count too low: {token_count} < {expected_min_tokens}")
                    if not round_trip_success:
                        logger.warning(f"  Round-trip failed:")
                        logger.warning(f"    Input:  '{text}'")
                        logger.warning(f"    Output: '{decoded}'")

            except Exception as e:
                logger.error(f"✗ Test {i}: {description} - ERROR: {e}")

        success = passed == total
        self.test_results["basic_functionality"] = {
            "passed": passed,
            "total": total,
            "success": success
        }

        logger.info(f"Basic functionality: {passed}/{total} tests passed")
        return success

    def test_special_tokens(self) -> bool:
        """
        Test special token handling.

        Returns:
            bool: True if special token tests pass, False otherwise
        """
        logger.info("=== Testing Special Tokens ===")

        try:
            # Test special token IDs
            special_tokens = {
                "bos_token": self.tokenizer.bos_token,
                "eos_token": self.tokenizer.eos_token,
                "unk_token": self.tokenizer.unk_token,
                "pad_token": self.tokenizer.pad_token,
            }

            special_token_ids = {
                "bos_token_id": self.tokenizer.bos_token_id,
                "eos_token_id": self.tokenizer.eos_token_id,
                "unk_token_id": self.tokenizer.unk_token_id,
                "pad_token_id": self.tokenizer.pad_token_id,
            }

            logger.info("Special tokens:")
            for name, token in special_tokens.items():
                token_id = special_token_ids[f"{name}_id"]
                logger.info(f"  {name}: '{token}' (ID: {token_id})")

            # Test that special tokens are properly handled
            test_text = "ဘာသာမန်"
            tokens_with_special = self.tokenizer(
                test_text,
                add_special_tokens=True,
                return_tensors="pt"
            )
            tokens_without_special = self.tokenizer(
                test_text,
                add_special_tokens=False,
                return_tensors="pt"
            )

            with_special_count = tokens_with_special["input_ids"].shape[1]
            without_special_count = tokens_without_special["input_ids"].shape[1]

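            # add_special_tokens=True should add at least one token (e.g. BOS/EOS) over the raw encoding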
            if with_special_count > without_special_count:
                logger.info("✓ Special tokens are properly added")
                success = True
            else:
                logger.warning("⚠ Special tokens may not be properly added")
                success = False

            self.test_results["special_tokens"] = {"success": success}
            return success

        except Exception as e:
            logger.error(f"✗ Special token test failed: {e}")
            self.test_results["special_tokens"] = {"success": False}
            return False

    def test_edge_cases(self) -> bool:
        """
        Test edge cases and error handling.

        Returns:
            bool: True if edge case tests pass, False otherwise
        """
        logger.info("=== Testing Edge Cases ===")

        edge_cases = [
            ("", "Empty string"),
            ("   ", "Whitespace only"),
            ("a", "Single ASCII character"),
            ("123", "Numbers only"),
            ("!@#$%", "Special characters only"),
            ("ဘာသာမန်" * 100, "Very long text"),
            ("ဟ", "Single Mon character"),
            ("၀၁၂၃၄၅၆၇၈၉", "Mon numerals"),
        ]

        passed = 0
        total = len(edge_cases)

        for text, description in edge_cases:
            try:
                tokens = self.tokenizer(text, return_tensors="pt")
                # For edge cases we mainly check that encode/decode don't raise
                self.tokenizer.decode(tokens["input_ids"][0], skip_special_tokens=True)

                logger.info(f"✓ {description}: {tokens['input_ids'].shape[1]} tokens")
                passed += 1

            except Exception as e:
                logger.error(f"✗ {description}: {e}")

        success = passed == total
        self.test_results["edge_cases"] = {
            "passed": passed,
            "total": total,
            "success": success
        }

        logger.info(f"Edge cases: {passed}/{total} tests passed")
        return success

    def test_performance_benchmark(self) -> bool:
        """
        Run performance benchmarks.

        Returns:
            bool: True if performance is acceptable, False otherwise
        """
        logger.info("=== Performance Benchmark ===")

        # Test texts of varying lengths
        test_texts = [
            "ဘာသာမန်",
            "ဘာသာမန် ပရူပရာတံဂှ် ကၠောန်ဗဒှ်လဝ်ရ။",
            ("အရေဝ်ဘာသာမန် ပ္ဍဲလောကဏအ် ဂွံဆဵုကေတ် ပ္ဍဲဍုင်သေံ ကဵု ဍုင်ဗၟာ ရ။ " * 10),
            ("မန်တံဂှ် မံင်ပ္ဍဲ ရးမန် ကဵု ရးသေံ။ " * 50),
        ]

        benchmark_results = []

        for i, text in enumerate(test_texts, 1):
            char_count = len(text)
            
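            # Warm-up pass so one-time costs (lazy initialization, caches) don't skew the measurements
            self.tokenizer(text, return_tensors="pt")
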
            # Benchmark encoding
            start_time = time.perf_counter()
            for _ in range(10):  # average over 10 runs
                tokens = self.tokenizer(text, return_tensors="pt")
            encoding_time = (time.perf_counter() - start_time) / 10

            # Benchmark decoding
            start_time = time.perf_counter()
            for _ in range(10):  # average over 10 runs
                decoded = self.tokenizer.decode(tokens["input_ids"][0])
            decoding_time = (time.perf_counter() - start_time) / 10

            token_count = tokens["input_ids"].shape[1]
            
            result = {
                "text_length": char_count,
                "token_count": token_count,
                "encoding_time": encoding_time,
                "decoding_time": decoding_time,
                "chars_per_second": char_count / encoding_time if encoding_time > 0 else 0,
                "tokens_per_second": token_count / decoding_time if decoding_time > 0 else 0
            }
            
            benchmark_results.append(result)
            
            logger.info(f"Text {i} ({char_count} chars, {token_count} tokens):")
            logger.info(f"  Encoding: {encoding_time*1000:.2f}ms ({result['chars_per_second']:.0f} chars/s)")
            logger.info(f"  Decoding: {decoding_time*1000:.2f}ms ({result['tokens_per_second']:.0f} tokens/s)")

        # Check if performance is acceptable (very lenient thresholds)
        avg_encoding_time = sum(r["encoding_time"] for r in benchmark_results) / len(benchmark_results)
        avg_decoding_time = sum(r["decoding_time"] for r in benchmark_results) / len(benchmark_results)

        success = avg_encoding_time < 1.0 and avg_decoding_time < 1.0  # Less than 1 second average

        self.test_results["performance"] = {
            "avg_encoding_time": avg_encoding_time,
            "avg_decoding_time": avg_decoding_time,
            "success": success,
            "details": benchmark_results
        }

        logger.info(f"Performance benchmark: {'PASSED' if success else 'FAILED'}")
        return success

    def test_compatibility(self) -> bool:
        """
        Test compatibility with transformers ecosystem.

        Returns:
            bool: True if compatibility tests pass, False otherwise
        """
        logger.info("=== Testing Compatibility ===")

        try:
            # Test tensor types
            text = "ဘာသာမန် ပရူပရာတံဂှ် ကၠောန်ဗဒှ်လဝ်ရ။"
            
            # Test different return types
            tokens_pt = self.tokenizer(text, return_tensors="pt")
            tokens_list = self.tokenizer(text, return_tensors=None)

            logger.info(f"✓ PyTorch tensor support: input_ids shape {tuple(tokens_pt['input_ids'].shape)}")
            logger.info(f"✓ List output support: {len(tokens_list['input_ids'])} ids")

            # Test padding and truncation
            texts = [
                "ဘာသာမန်",
                "ဘာသာမန် ပရူပရာတံဂှ် ကၠောန်ဗဒှ်လဝ်ရ။",
                "မန်တံဂှ် မံင်ပ္ဍဲ ရးမန် ကဵု ရးသေံ။"
            ]
            
            # Test batch processing
            batch_tokens = self.tokenizer(
                texts, 
                padding=True, 
                truncation=True, 
                return_tensors="pt"
            )
            
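            # padding=True pads to the longest sequence in the batch, giving a rectangular (batch, max_len) tensor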
            logger.info(f"✓ Batch processing: {batch_tokens['input_ids'].shape}")

            # Test attention masks
            if "attention_mask" in batch_tokens:
                logger.info("✓ Attention mask generation")
            else:
                logger.warning("⚠ No attention mask generated")

            success = True

        except Exception as e:
            logger.error(f"✗ Compatibility test failed: {e}")
            success = False

        self.test_results["compatibility"] = {"success": success}
        return success

    def run_all_tests(self) -> bool:
        """
        Run all test suites.

        Returns:
            bool: True if all tests pass, False otherwise
        """
        logger.info("🚀 Starting Mon Tokenizer Test Suite")
        logger.info("=" * 50)

        # Load tokenizer
        if not self.load_tokenizer():
            return False

        # Run all test suites
        test_suites = [
            ("Basic Functionality", self.test_basic_functionality),
            ("Special Tokens", self.test_special_tokens),
            ("Edge Cases", self.test_edge_cases),
            ("Performance Benchmark", self.test_performance_benchmark),
            ("Compatibility", self.test_compatibility),
        ]

        results = []
        for suite_name, test_func in test_suites:
            logger.info(f"\n--- {suite_name} ---")
            success = test_func()
            results.append((suite_name, success))
            logger.info(f"{suite_name}: {'✅ PASSED' if success else '❌ FAILED'}")

        # Summary
        logger.info("\n" + "=" * 50)
        logger.info("📊 TEST SUMMARY")
        logger.info("=" * 50)

        passed_suites = sum(1 for _, success in results if success)
        total_suites = len(results)

        for suite_name, success in results:
            status = "✅ PASSED" if success else "❌ FAILED"
            logger.info(f"{suite_name}: {status}")

        overall_success = passed_suites == total_suites
        logger.info(f"\nOverall Result: {passed_suites}/{total_suites} test suites passed")
        
        if overall_success:
            logger.info("🎉 ALL TESTS PASSED! Tokenizer is ready for production.")
        else:
            logger.error("⚠️  Some tests failed. Please review the issues above.")

        return overall_success

    def generate_test_report(self) -> str:
        """
        Generate a detailed test report.

        Returns:
            str: Formatted test report
        """
        if not self.test_results:
            return "No test results available. Run tests first."

        report = ["# Mon Tokenizer Test Report", ""]
        
        for test_name, result in self.test_results.items():
            report.append(f"## {test_name.replace('_', ' ').title()}")
            
            if isinstance(result, dict) and "success" in result:
                status = "✅ PASSED" if result["success"] else "❌ FAILED"
                report.append(f"Status: {status}")
                
                if "passed" in result and "total" in result:
                    report.append(f"Tests: {result['passed']}/{result['total']}")
                
            report.append("")

        return "\n".join(report)


def main():
    """Main entry point for the test script."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Test Mon tokenizer Hugging Face integration"
    )
    parser.add_argument(
        "--tokenizer-path",
        default=".",
        help="Path to tokenizer files (default: current directory)",
    )
    parser.add_argument(
        "--report",
        action="store_true",
        help="Generate detailed test report",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )

    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Create tester and run tests
    tester = MonTokenizerTester(tokenizer_path=args.tokenizer_path)
    success = tester.run_all_tests()

    # Generate report if requested
    if args.report:
        report = tester.generate_test_report()
        report_path = Path("test_report.md")
        with open(report_path, "w", encoding="utf-8") as f:
            f.write(report)
        logger.info(f"Test report saved to: {report_path}")

    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()