"""Evaluate word-level normalization and transliteration on MADAR-TUN,
framed as closed-vocabulary classification via embedding similarity."""
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import accuracy_score
import warnings

warnings.filterwarnings("ignore")


def load_and_prepare_data():
    """Load MADAR-TUN and prepare normalization & transliteration pairs."""
    print("Loading MADAR-TUN dataset...")
    ds = load_dataset("tunis-ai/MADAR-TUN", split="train")

    # Keep only rows where all three fields are present, non-empty,
    # and not the <eos> sentence delimiter.
    valid_examples = [
        ex for ex in ds
        if ex["arabish"] != "<eos>"
        and ex["words"] != "<eos>"
        and ex["lem"] != "<eos>"
        and ex["arabish"] is not None
        and ex["arabish"].strip()
        and ex["words"] is not None
        and ex["words"].strip()
        and ex["lem"] is not None
        and ex["lem"].strip()
    ]
    print(f"Loaded {len(valid_examples)} valid token entries.")

    # Build unique pairs, deduplicating on the Arabizi form
    # and keeping the first occurrence.
    norm_pairs = {}   # arabish -> canonical lemma
    trans_pairs = {}  # arabish -> Arabic-script form
    for ex in valid_examples:
        arabizi = ex["arabish"]
        arabic = ex["words"]
        lemma = ex["lem"]
        # For normalization: use the lemma as the canonical form.
        if arabizi not in norm_pairs:
            norm_pairs[arabizi] = lemma
        if arabizi not in trans_pairs:
            trans_pairs[arabizi] = arabic

    print(f"Normalization pairs: {len(norm_pairs)}")
    print(f"Transliteration pairs: {len(trans_pairs)}")
    return norm_pairs, trans_pairs


def evaluate_word_classification(model, tokenizer, word_pairs, device, task_name):
    """
    Evaluate word-level classification (normalization or transliteration).
    Treats it as closed-vocabulary classification via embedding similarity:
    each input word is assigned the target form whose [CLS] embedding has
    the highest dot product with the input's [CLS] embedding.
    """
    words = list(word_pairs.keys())
    targets = list(word_pairs.values())

    # Build the closed target vocabulary.
    unique_targets = sorted(set(targets))
    target_to_id = {t: i for i, t in enumerate(unique_targets)}
    print(f"\n[{task_name}] Vocabulary size: {len(unique_targets)}")
    print(f"[{task_name}] Evaluation samples: {len(words)}")

    # Embed every target form once ([CLS] token of the last hidden state).
    print(f"[{task_name}] Encoding target vocabulary...")
    target_encodings = tokenizer(
        unique_targets,
        padding=True,
        truncation=True,
        max_length=32,
        return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        target_embeds = model(**target_encodings).last_hidden_state[:, 0]  # [V, H]

    # Predict the nearest target for each input word, in batches.
    predictions = []
    batch_size = 32
    print(f"[{task_name}] Predicting...")
    for i in range(0, len(words), batch_size):
        batch_words = words[i:i + batch_size]
        inputs = tokenizer(
            batch_words,
            padding=True,
            truncation=True,
            max_length=32,
            return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            word_embeds = model(**inputs).last_hidden_state[:, 0]  # [B, H]
        logits = torch.matmul(word_embeds, target_embeds.T)  # [B, V]
        preds = logits.argmax(dim=1).cpu().tolist()
        predictions.extend(preds)

    # Score against the gold target IDs.
    true_labels = [target_to_id[t] for t in targets]
    acc = accuracy_score(true_labels, predictions)
    print(f"[{task_name}] Accuracy: {acc:.4f}")
    return acc
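

# --- Usage sketch (not part of the original file): a minimal __main__ driver
# showing how the two functions above fit together. The checkpoint name
# "UBC-NLP/MARBERT" is an assumption chosen for illustration; any Hugging Face
# encoder that exposes last_hidden_state should work here.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_name = "UBC-NLP/MARBERT"  # assumed checkpoint; swap in your own
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()

    norm_pairs, trans_pairs = load_and_prepare_data()

    # Peek at a few deduplicated pairs as a sanity check.
    for arabizi, lemma in list(norm_pairs.items())[:5]:
        print(arabizi, "->", lemma)

    evaluate_word_classification(model, tokenizer, norm_pairs, device, "Normalization")
    evaluate_word_classification(model, tokenizer, trans_pairs, device, "Transliteration")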