# test_bg/test_script.py: tokenizer test suite for tinycompany/ShawtyIsBad-bgem3
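"""Smoke tests for the bgem3 tokenizer on Hindi and English text: Unicode
coverage, language-specific edge cases, script complexity, and token-length
visualizations."""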
import json
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Any
from transformers import AutoTokenizer
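# Assumption not in the original script: the suite only saves figures to disk,
# so force a non-interactive backend to keep it runnable on headless machines.
import matplotlib
matplotlib.use('Agg')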


class HFTokenizerTestSuite:
    def __init__(self, model_name: str, test_data_paths: List[str]):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.languages = ['hindi', 'english']
        self.edge_cases = {
            'hindi': {
                'script_test': 'नमस्ते, मैं भारत से हूँ। दिल्ली बहुत बड़ा शहर है।',
                'unicode_test': 'हिन्दी १२३४५६७८९ vowels: अ आ इ ई उ ऊ',
                'special_chars': 'हिन्दी! @ # $ % ^ & * ( ) _ + = [ ] { }',
            },
            'english': {
                'script_test': 'Hello, I am from the United States. New York is a beautiful city.',
                'unicode_test': 'English 0123456789 vowels: a e i o u',
                'special_chars': 'English! @ # $ % ^ & * ( ) _ + = [ ] { }',
            }
        }
        self.test_data = self._load_test_data(test_data_paths)
        # Result keys mirror the sections written by the analysis methods below.
        self.results = {
            'unicode_coverage': {},
            'edge_cases': {},
            'script_complexity': {}
        }

    def _load_test_data(self, data_paths: List[str]) -> Dict[str, List[str]]:
        # Lines are assigned to languages round-robin, so input files are
        # expected to alternate Hindi and English lines.
        test_data = {lang: [] for lang in self.languages}
        for path in data_paths:
            try:
                with open(path, 'r', encoding='utf-8') as f:
                    texts = f.readlines()
                for i, text in enumerate(texts):
                    lang = self.languages[i % len(self.languages)]
                    test_data[lang].append(text.strip())
            except (OSError, UnicodeDecodeError) as e:
                print(f"Error loading {path}: {e}")
        return test_data

    def unicode_coverage_analysis(self) -> Dict[str, Any]:
        unicode_results = {}
        for lang, edge_cases in self.edge_cases.items():
            unicode_test = edge_cases['unicode_test']
            # Plain encode (no return_tensors) avoids a hard torch dependency.
            token_ids = self.tokenizer(unicode_test)['input_ids']
            tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
            unicode_results[lang] = {
                'original_text': unicode_test,
                'tokens': tokens,
                'token_count': len(tokens),
                'unique_tokens': len(set(tokens)),
                'coverage_ratio': len(set(tokens)) / len(tokens)
            }
        self.results['unicode_coverage'] = unicode_results
        return unicode_results

    def language_specific_edge_cases(self) -> Dict[str, Any]:
        edge_case_results = {}
        for lang, cases in self.edge_cases.items():
            language_results = {}
            for case_name, text in cases.items():
                try:
                    token_ids = self.tokenizer(text)['input_ids']
                    tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
                    language_results[case_name] = {
                        'tokens': tokens,
                        'token_count': len(tokens),
                        'unique_tokens': len(set(tokens))
                    }
                except Exception as e:
                    language_results[case_name] = {
                        'error': str(e)
                    }
            edge_case_results[lang] = language_results
        self.results['edge_cases'] = edge_case_results
        return edge_case_results

    def script_complexity_analysis(self) -> Dict[str, Any]:
        complexity_results = {}
        for lang in self.languages:
            text = self.edge_cases[lang]['script_test']
            token_ids = self.tokenizer(text)['input_ids']
            tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
            # Drop special tokens (e.g. <s>, </s>) so they do not skew the
            # average token length; matching on '[...]' would miss them here.
            filtered_tokens = [t for t in tokens if t not in self.tokenizer.all_special_tokens]
            complexity_results[lang] = {
                'original_text_length': len(text),
                'tokens': tokens,
                'token_count': len(tokens),
                # float() keeps the value JSON-serializable (np.mean returns np.float64).
                'avg_token_length': float(np.mean([len(t) for t in filtered_tokens])) if filtered_tokens else 0.0,
                'token_diversity': len(set(tokens)) / len(tokens)
            }
        self.results['script_complexity'] = complexity_results
        return complexity_results

    def generate_token_histograms(self):
        plt.figure(figsize=(15, 10))
        for i, lang in enumerate(self.languages):
            text = self.test_data[lang][0] if self.test_data[lang] else self.edge_cases[lang]['script_test']
            token_ids = self.tokenizer(text)['input_ids']
            tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
            # Drop special tokens before measuring lengths.
            filtered_tokens = [t for t in tokens if t not in self.tokenizer.all_special_tokens]
            token_lengths = [len(t) for t in filtered_tokens]
            if not token_lengths:
                continue
            plt.subplot(len(self.languages), 1, i + 1)
            plt.hist(token_lengths, bins=range(1, max(token_lengths) + 2), alpha=0.7)
            plt.title(f'Token Length Distribution for {lang.capitalize()}')
            plt.xlabel('Token Length')
            plt.ylabel('Frequency')
            plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig('token_length_histograms.png')
        plt.close()

    def generate_unicode_visualization(self):
        plt.figure(figsize=(15, 10))
        unicode_results = self.results.get('unicode_coverage', {})
        langs = list(unicode_results)
        plt.subplot(2, 1, 1)
        token_counts = [result['token_count'] for result in unicode_results.values()]
        plt.bar(langs, token_counts)
        plt.title('Token Count in Unicode Test Texts')
        plt.xlabel('Language')
        plt.ylabel('Number of Tokens')
        plt.xticks(rotation=45)
        plt.subplot(2, 1, 2)
        coverage_ratios = [result['coverage_ratio'] for result in unicode_results.values()]
        plt.bar(langs, coverage_ratios)
        plt.title('Token Diversity Ratio')
        plt.xlabel('Language')
        plt.ylabel('Unique Tokens / Total Tokens')
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig('unicode_token_analysis.png')
        plt.close()

    def run_all_tests(self):
        print("Running Tokenizer Test Suite for Hindi and English...")
        print("1. Unicode Coverage Analysis...")
        self.unicode_coverage_analysis()
        print("2. Language-Specific Edge Cases...")
        self.language_specific_edge_cases()
        print("3. Script Complexity Analysis...")
        self.script_complexity_analysis()
        print("4. Generating Token Histograms...")
        self.generate_token_histograms()
        print("5. Generating Unicode Visualizations...")
        self.generate_unicode_visualization()
        print("Test Suite Complete!")
        return self.results
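

# Not part of the original suite: a minimal round-trip sanity check, sketched
# on the standard transformers encode/decode API. Tokenizer normalization can
# cause benign mismatches, so a False result is a prompt to inspect the text,
# not necessarily a failure.
def tokenizer_roundtrip_check(tokenizer, texts: Dict[str, str]) -> Dict[str, bool]:
    outcomes = {}
    for lang, text in texts.items():
        ids = tokenizer(text)['input_ids']
        decoded = tokenizer.decode(ids, skip_special_tokens=True)
        outcomes[lang] = decoded.strip() == text.strip()
    return outcomes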

if __name__ == "__main__":
    MODEL_NAME = "tinycompany/ShawtyIsBad-bgem3"
    TEST_DATA_PATHS = [
        './test2.txt'
    ]
    test_suite = HFTokenizerTestSuite(MODEL_NAME, TEST_DATA_PATHS)
    results = test_suite.run_all_tests()
    with open('result1.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)
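    # A run leaves three artifacts: token_length_histograms.png,
    # unicode_token_analysis.png, and result1.json. Optionally, report the
    # round-trip check (added above) on the script_test sentences as well.
    print(tokenizer_roundtrip_check(
        test_suite.tokenizer,
        {lang: cases['script_test'] for lang, cases in test_suite.edge_cases.items()}
    ))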