Refactor the code for better scalability, rename TSAC to sentiment analysis, and add the MADAR dataset for transliteration and normalization evaluation
bde1c71
```python
import torch
from datasets import load_dataset
# from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import accuracy_score
# import argparse
import warnings

warnings.filterwarnings("ignore")


def load_and_prepare_data():
    """Load MADAR-TUN and prepare normalization & transliteration pairs."""
    print("Loading MADAR-TUN dataset...")
    ds = load_dataset("tunis-ai/MADAR-TUN", split="train")

    # Keep only rows where every field is present, non-empty, and not a sentence marker
    valid_examples = [
        ex for ex in ds
        if ex["arabish"] is not None and ex["arabish"].strip() and ex["arabish"] != "<eos>"
        and ex["words"] is not None and ex["words"].strip() and ex["words"] != "<eos>"
        and ex["lem"] is not None and ex["lem"].strip() and ex["lem"] != "<eos>"
    ]
    print(f"Loaded {len(valid_examples)} valid token entries.")

    # Build unique pairs; on duplicate Arabizi forms, the first occurrence wins
    norm_pairs = {}   # arabish -> canonical lemma
    trans_pairs = {}  # arabish -> Arabic-script form
    for ex in valid_examples:
        arabizi = ex["arabish"]
        arabic = ex["words"]
        lemma = ex["lem"]
        # For normalization: use the lemma as the canonical form
        if arabizi not in norm_pairs:
            norm_pairs[arabizi] = lemma
        if arabizi not in trans_pairs:
            trans_pairs[arabizi] = arabic
    print(f"Normalization pairs: {len(norm_pairs)}")
    print(f"Transliteration pairs: {len(trans_pairs)}")
    return norm_pairs, trans_pairs
```
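As a quick sanity check, the two returned dictionaries are keyed by the same Arabizi forms, so the lemma and Arabic-script targets for a given form can be inspected side by side. A minimal usage sketch (the printed contents depend entirely on the dataset, so the output here is illustrative, not taken from MADAR-TUN):

```python
# Minimal usage sketch: peek at a few deduplicated pairs.
norm_pairs, trans_pairs = load_and_prepare_data()
for arabizi in list(norm_pairs)[:5]:
    print(f"{arabizi!r} -> lemma {norm_pairs[arabizi]!r}, "
          f"arabic {trans_pairs[arabizi]!r}")
```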
```python
def evaluate_word_classification(model, tokenizer, word_pairs, device, task_name):
    """
    Evaluate word-level classification (normalization or transliteration).
    Treats it as closed-vocabulary classification via embedding similarity.
    """
    words = list(word_pairs.keys())
    targets = list(word_pairs.values())

    # Build the target vocabulary and the gold label IDs
    unique_targets = sorted(set(targets))
    target_to_id = {t: i for i, t in enumerate(unique_targets)}
    true_labels = [target_to_id[t] for t in targets]

    print(f"\n[{task_name}] Vocabulary size: {len(unique_targets)}")
    print(f"[{task_name}] Evaluation samples: {len(words)}")

    # Encode every target form once; use the first-token ([CLS]) embedding
    print(f"[{task_name}] Encoding target vocabulary...")
    target_encodings = tokenizer(
        unique_targets,
        padding=True,
        truncation=True,
        max_length=32,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        target_embeds = model(**target_encodings).last_hidden_state[:, 0]  # [V, H]

    # Predict the nearest target by raw dot-product similarity
    # (embeddings are not L2-normalized, so this is not cosine similarity)
    predictions = []
    batch_size = 32
    print(f"[{task_name}] Predicting...")
    for i in range(0, len(words), batch_size):
        batch_words = words[i:i + batch_size]
        inputs = tokenizer(
            batch_words,
            padding=True,
            truncation=True,
            max_length=32,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            word_embeds = model(**inputs).last_hidden_state[:, 0]  # [B, H]
        logits = torch.matmul(word_embeds, target_embeds.T)  # [B, V]
        predictions.extend(logits.argmax(dim=1).cpu().tolist())

    acc = accuracy_score(true_labels, predictions)
    print(f"[{task_name}] Accuracy: {acc:.4f}")
    return acc
```
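A minimal driver sketch tying the two functions together, assuming a BERT-style encoder loaded via the commented-out transformers imports above; the checkpoint name is a placeholder assumption, not necessarily the one this Space actually uses:

```python
# Hypothetical driver; "bert-base-multilingual-cased" is a placeholder checkpoint.
from transformers import AutoModel, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
model = AutoModel.from_pretrained("bert-base-multilingual-cased").to(device).eval()

norm_pairs, trans_pairs = load_and_prepare_data()
evaluate_word_classification(model, tokenizer, norm_pairs, device, "Normalization")
evaluate_word_classification(model, tokenizer, trans_pairs, device, "Transliteration")
```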