# Source provenance (Hugging Face blob-view residue, kept as comments so the
# file remains valid Python):
# hamzabouajila's picture
# refactor the code for better scalability and update tsac naming to sentiment
# analysis, adding madar dataset for transliteration and normalization eval
# commit bde1c71 · raw · history · blame · 5.51 kB
import torch
from datasets import load_dataset
import traceback
import time
def evaluate_tsac_sentiment(model, tokenizer, device, num_samples=10):
    """Evaluate a model on the TSAC sentiment analysis task.

    Loads the ``fbougares/tsac`` test split, tokenizes a small sample of
    sentences, runs the model batch-by-batch under ``torch.no_grad()``, and
    reports accuracy of the argmax prediction against the ``target`` column.

    Args:
        model: torch module whose forward returns logits. Dict outputs with
            ``'logits'`` or ``'prediction_logits'``, tuple outputs, and raw
            tensors are all handled.
        tokenizer: Hugging Face-style tokenizer callable on a list of strings.
        device: torch device the model lives on; batches are moved there.
        num_samples: number of test examples to evaluate. Defaults to 10 to
            keep this a quick smoke-test run (backward compatible with the
            previous hard-coded value).

    Returns:
        dict: ``{"fbougares/tsac": accuracy}`` with accuracy in ``[0, 1]``.

    Raises:
        Re-raises any exception after printing the full traceback.
    """
    try:
        print("\n=== Starting TSAC sentiment evaluation ===")
        print(f"Current device: {device}")

        # Load and subsample the test split.
        print("\nLoading and preprocessing TSAC dataset...")
        dataset = load_dataset("fbougares/tsac", split="test", trust_remote_code=True)
        # Guard: never ask for more rows than the split contains.
        dataset = dataset.select(range(min(num_samples, len(dataset))))

        def preprocess(examples):
            # Pad/truncate within each map batch; return plain lists
            # (return_tensors=None) — tensors are produced by set_format below.
            return tokenizer(
                examples['sentence'],
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors=None
            )

        print(dataset.column_names)
        dataset = dataset.map(preprocess, batched=True)
        dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'target'])

        # Sanity-check the first tokenized example.
        first_example = dataset[0]
        print("\nFirst example details:")
        print(f"Input IDs shape: {first_example['input_ids'].shape}")
        print(f"Attention mask shape: {first_example['attention_mask'].shape}")
        print(f"Target: {first_example['target']}")

        model.eval()
        print(f"\nModel class: {model.__class__.__name__}")
        print(f"Model device: {next(model.parameters()).device}")

        from torch.utils.data import DataLoader

        def collate_fn(batch):
            # Stack per-sample tensors into batch tensors. NOTE(review): this
            # assumes every sample in a DataLoader batch shares one padded
            # length (true here because map() pads per map-batch and the
            # sample count is small) — confirm if num_samples grows large.
            return {
                'input_ids': torch.stack([sample['input_ids'] for sample in batch]),
                'attention_mask': torch.stack([sample['attention_mask'] for sample in batch]),
                'target': torch.stack([sample['target'] for sample in batch]),
            }

        dataloader = DataLoader(
            dataset,
            batch_size=16,
            shuffle=False,
            collate_fn=collate_fn
        )

        predictions = []
        targets = []
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                if i % 10 == 0:
                    # Periodic progress log (previous message misleadingly
                    # said "first batch" on every 10th batch).
                    print(f"\nProcessing batch {i}...")
                    print(f"Batch keys: {list(batch.keys())}")
                    print(f"Target shape: {batch['target'].shape}")

                inputs = {k: v.to(device) for k, v in batch.items() if k != 'target'}
                target = batch['target'].to(device)
                outputs = model(**inputs)

                # Normalize the model's output format to a logits tensor.
                if isinstance(outputs, dict):
                    if 'logits' in outputs:
                        logits = outputs['logits']
                    elif 'prediction_logits' in outputs:
                        logits = outputs['prediction_logits']
                    else:
                        raise ValueError(f"Unknown output format. Available keys: {list(outputs.keys())}")
                elif isinstance(outputs, tuple):
                    print(f"Output tuple length: {len(outputs)}")
                    logits = outputs[0]
                else:
                    logits = outputs

                # Token-level logits: use the [CLS] (first token) position
                # for sequence classification.
                if len(logits.shape) == 3:  # [batch_size, seq_len, num_classes]
                    logits = logits[:, 0, :]

                batch_predictions = logits.argmax(dim=-1).cpu().tolist()
                batch_targets = target.cpu().tolist()
                predictions.extend(batch_predictions)
                targets.extend(batch_targets)

                if i % 10 == 0:
                    print(f"Predictions: {batch_predictions[:5]}")
                    print(f"Targets: {batch_targets[:5]}")

        print(f"\nTotal predictions: {len(predictions)}")
        print(f"Total targets: {len(targets)}")

        # Accuracy = exact-match rate; guarded against an empty sample.
        correct = sum(p == t for p, t in zip(predictions, targets))
        total = len(predictions)
        accuracy = correct / total if total > 0 else 0.0
        print("\nEvaluation results:")
        print(f"Correct predictions: {correct}")
        print(f"Total predictions: {total}")
        print(f"Accuracy: {accuracy:.4f}")
        return {"fbougares/tsac": accuracy}
    except Exception as e:
        print(f"\n=== Error in TSAC evaluation: {str(e)} ===")
        print(f"Full traceback: {traceback.format_exc()}")
        # Bare raise preserves the original traceback.
        raise