refactor the code for better scalability, update the TSAC naming to sentiment analysis, and add the MADAR dataset for transliteration and normalization eval
bde1c71
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
import traceback
import time


def evaluate_tsac_sentiment(model, tokenizer, device):
"""Evaluate model on TSAC sentiment analysis task"""
try:
print("\n=== Starting TSAC sentiment evaluation ===")
print(f"Current device: {device}")
# Load and preprocess dataset
print("\nLoading and preprocessing TSAC dataset...")
dataset = load_dataset("fbougares/tsac", split="test", trust_remote_code=True)
dataset = dataset.select(range(10)) # Only evaluate on 200 samples
# print(f"Dataset size: {len(dataset)} examples")
        def preprocess(examples):
            return tokenizer(
                examples['sentence'],
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors=None
            )

        print(dataset.column_names)
        dataset = dataset.map(preprocess, batched=True)
        dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'target'])
        # Check first example
        first_example = dataset[0]
        print("\nFirst example details:")
        print(f"Input IDs shape: {first_example['input_ids'].shape}")
        print(f"Attention mask shape: {first_example['attention_mask'].shape}")
        print(f"Target: {first_example['target']}")

        model.eval()
        print(f"\nModel class: {model.__class__.__name__}")
        print(f"Model device: {next(model.parameters()).device}")
        with torch.no_grad():
            predictions = []
            targets = []

            # Define a custom collate function that stacks each tensor column
            def collate_fn(batch):
                input_ids = torch.stack([sample['input_ids'] for sample in batch])
                attention_mask = torch.stack([sample['attention_mask'] for sample in batch])
                targets_batch = torch.stack([sample['target'] for sample in batch])
                return {
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                    'target': targets_batch
                }

            # Create DataLoader with batch size 16
            dataloader = DataLoader(
                dataset,
                batch_size=16,
                shuffle=False,
                collate_fn=collate_fn
            )
            for i, batch in enumerate(dataloader):
                if i % 10 == 0:
                    print(f"\nProcessing batch {i}...")
                    print(f"Batch keys: {list(batch.keys())}")
                    print(f"Target shape: {batch['target'].shape}")

                inputs = {k: v.to(device) for k, v in batch.items() if k != 'target'}
                target = batch['target'].to(device)

                before = time.time()
                outputs = model(**inputs)
                elapsed = time.time() - before
                # print(f"\nBatch {i} output type: {type(outputs)}")

                # Handle different model output formats
                if isinstance(outputs, dict):
                    # print(f"Output keys: {list(outputs.keys())}")
                    if 'logits' in outputs:
                        logits = outputs['logits']
                    elif 'prediction_logits' in outputs:
                        logits = outputs['prediction_logits']
                    else:
                        raise ValueError(f"Unknown output format. Available keys: {list(outputs.keys())}")
                elif isinstance(outputs, tuple):
                    # print(f"Output tuple length: {len(outputs)}")
                    logits = outputs[0]
                else:
                    logits = outputs
                # print(f"Logits shape: {logits.shape}")

                # For sequence classification, we typically use the [CLS] token's prediction
                if len(logits.shape) == 3:  # [batch_size, sequence_length, num_classes]
                    logits = logits[:, 0, :]  # Take the [CLS] token prediction
                # print(f"Final logits shape: {logits.shape}")

                batch_predictions = logits.argmax(dim=-1).cpu().tolist()
                batch_targets = target.cpu().tolist()
                predictions.extend(batch_predictions)
                targets.extend(batch_targets)

                if i % 10 == 0:
                    print(f"Batch {i} inference time: {elapsed:.3f}s")
                    print(f"Predictions: {batch_predictions[:5]}")
                    print(f"Targets: {batch_targets[:5]}")

        print(f"\nTotal predictions: {len(predictions)}")
        print(f"Total targets: {len(targets)}")

        # Calculate accuracy
        correct = sum(p == t for p, t in zip(predictions, targets))
        total = len(predictions)
        accuracy = correct / total if total > 0 else 0.0

        print("\nEvaluation results:")
        print(f"Correct predictions: {correct}")
        print(f"Total predictions: {total}")
        print(f"Accuracy: {accuracy:.4f}")

        return {"fbougares/tsac": accuracy}

    except Exception as e:
        print(f"\n=== Error in TSAC evaluation: {str(e)} ===")
        print(f"Full traceback: {traceback.format_exc()}")
        raise
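

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original commit: it assumes the function is
    # driven by a Hugging Face sequence-classification checkpoint. "model_name" below is
    # a placeholder; swap in the checkpoint actually being evaluated. TSAC is a binary
    # sentiment task, hence num_labels=2.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model_name = "bert-base-multilingual-cased"  # placeholder checkpoint (assumption)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(device)

    results = evaluate_tsac_sentiment(model, tokenizer, device)
    print(results)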