import torch
from datasets import load_dataset
# from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import accuracy_score
# import argparse
import warnings

warnings.filterwarnings("ignore")


def load_and_prepare_data():
"""Load MADAR-TUN and prepare normalization & transliteration pairs."""
print("Loading MADAR-TUN dataset...")
ds = load_dataset("tunis-ai/MADAR-TUN", split="train")
valid_examples = [
ex for ex in ds
if ex["arabish"] != "<eos>"
and ex["words"] != "<eos>"
and ex["lem"] != "<eos>"
and ex["arabish"] is not None
and ex["arabish"].strip()
and ex["words"] is not None
and ex["words"].strip()
and ex["lem"] is not None
and ex["lem"].strip()
]
print(f"Loaded {len(valid_examples)} valid token entries.")
# Build unique pairs (deduplicate)
norm_pairs = {} # arabish -> canonical lemma
trans_pairs = {} # arabish <-> arabic
for ex in valid_examples:
arabizi = ex["arabish"]
arabic = ex["words"]
lemma = ex["lem"]
# For normalization: use lemma as canonical form
if arabizi not in norm_pairs:
norm_pairs[arabizi] = lemma
if arabizi not in trans_pairs:
trans_pairs[arabizi] = arabic
print(f"Normalization pairs: {len(norm_pairs)}")
print(f"Transliteration pairs: {len(trans_pairs)}")
return norm_pairs, trans_pairs


def evaluate_word_classification(model, tokenizer, word_pairs, device, task_name):
    """
    Evaluate word-level classification (normalization or transliteration).
    Treats it as closed-vocabulary classification via embedding similarity.
    """
    words = list(word_pairs.keys())
    targets = list(word_pairs.values())

    # Build the target vocabulary and map each gold target to its ID
    unique_targets = sorted(set(targets))
    target_to_id = {t: i for i, t in enumerate(unique_targets)}
    true_labels = [target_to_id[t] for t in targets]

    print(f"\n[{task_name}] Vocabulary size: {len(unique_targets)}")
    print(f"[{task_name}] Evaluation samples: {len(words)}")

    # Get embeddings for all target forms (CLS-token embedding of each target string)
    print(f"[{task_name}] Encoding target vocabulary...")
    target_encodings = tokenizer(
        unique_targets,
        padding=True,
        truncation=True,
        max_length=32,
        return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        target_embeds = model(**target_encodings).last_hidden_state[:, 0]  # [V, H]

    # Predict for each input word: pick the nearest target by dot-product similarity
    predictions = []
    batch_size = 32
    print(f"[{task_name}] Predicting...")
    for i in range(0, len(words), batch_size):
        batch_words = words[i:i + batch_size]
        inputs = tokenizer(
            batch_words,
            padding=True,
            truncation=True,
            max_length=32,
            return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            word_embeds = model(**inputs).last_hidden_state[:, 0]  # [B, H]
            logits = torch.matmul(word_embeds, target_embeds.T)    # [B, V]
            preds = logits.argmax(dim=1).cpu().tolist()
        predictions.extend(preds)

    # Score predicted target IDs against the gold target IDs
    acc = accuracy_score(true_labels, predictions)
    print(f"[{task_name}] Accuracy: {acc:.4f}")
    return acc
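

# Illustrative usage sketch: it assumes an encoder-style checkpoint that works with
# AutoModel/AutoTokenizer and exposes last_hidden_state. The checkpoint name below is
# only a placeholder; swap in the model actually being evaluated.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_name = "bert-base-multilingual-cased"  # placeholder checkpoint (assumption)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device).eval()

    norm_pairs, trans_pairs = load_and_prepare_data()
    evaluate_word_classification(model, tokenizer, norm_pairs, device, "Normalization")
    evaluate_word_classification(model, tokenizer, trans_pairs, device, "Transliteration")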