Spaces:
Runtime error
Runtime error
Refactor the code for better scalability; rename the TSAC task to sentiment analysis; add the MADAR dataset for transliteration and normalization evaluation.
bde1c71
| import torch | |
| from datasets import load_dataset | |
| import traceback | |
| import time | |
def evaluate_tsac_sentiment(model, tokenizer, device,
                            num_samples=10, batch_size=16, max_length=512):
    """Evaluate a model on the TSAC sentiment analysis task.

    Loads the ``fbougares/tsac`` test split, tokenizes the ``sentence``
    column, runs batched greedy (argmax) classification, and reports
    accuracy against the ``target`` column.

    Args:
        model: torch module whose forward returns logits — either a dict
            (``logits`` / ``prediction_logits`` key), a tuple with logits
            first, or a raw tensor. Assumed to already live on ``device``.
        tokenizer: HuggingFace-style tokenizer callable on a list of strings.
        device: torch device inputs and targets are moved to.
        num_samples: how many test examples to evaluate (default 10, kept
            small so the evaluation stays cheap).
        batch_size: DataLoader batch size (default 16).
        max_length: tokenizer truncation length (default 512).

    Returns:
        dict mapping the dataset id ``"fbougares/tsac"`` to float accuracy.

    Raises:
        Re-raises any exception after printing its full traceback.
    """
    from torch.utils.data import DataLoader

    def _extract_logits(outputs):
        # Normalize the various HF/torch output containers to a logits tensor.
        if isinstance(outputs, dict):
            if 'logits' in outputs:
                return outputs['logits']
            if 'prediction_logits' in outputs:
                return outputs['prediction_logits']
            raise ValueError(f"Unknown output format. Available keys: {list(outputs.keys())}")
        if isinstance(outputs, tuple):
            print(f"Output tuple length: {len(outputs)}")
            return outputs[0]
        return outputs

    try:
        print("\n=== Starting TSAC sentiment evaluation ===")
        print(f"Current device: {device}")

        print("\nLoading and preprocessing TSAC dataset...")
        dataset = load_dataset("fbougares/tsac", split="test", trust_remote_code=True)
        # Evaluate on a small subset to keep runtime bounded
        # (was mislabeled "200 samples" while selecting only 10).
        dataset = dataset.select(range(num_samples))

        def preprocess(examples):
            # Batched tokenization; padding=True pads to the longest item
            # within each map() batch.
            return tokenizer(
                examples['sentence'],
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors=None,
            )

        print(dataset.column_names)
        dataset = dataset.map(preprocess, batched=True)
        dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'target'])

        # Sanity-check the first example before running the model.
        first_example = dataset[0]
        print("\nFirst example details:")
        print(f"Input IDs shape: {first_example['input_ids'].shape}")
        print(f"Attention mask shape: {first_example['attention_mask'].shape}")
        print(f"Target: {first_example['target']}")

        model.eval()
        print(f"\nModel class: {model.__class__.__name__}")
        print(f"Model device: {next(model.parameters()).device}")

        def collate_fn(batch):
            # Stack per-sample tensors into batch tensors; works because
            # map() padded every sample in a batch to the same length.
            return {
                'input_ids': torch.stack([sample['input_ids'] for sample in batch]),
                'attention_mask': torch.stack([sample['attention_mask'] for sample in batch]),
                'target': torch.stack([sample['target'] for sample in batch]),
            }

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=collate_fn,
        )

        predictions = []
        targets = []
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                if i % 10 == 0:
                    # Periodic progress log (every 10th batch, not only the first).
                    print(f"\nProcessing batch {i}...")
                    print(f"Batch keys: {list(batch.keys())}")
                    print(f"Target shape: {batch['target'].shape}")

                inputs = {k: v.to(device) for k, v in batch.items() if k != 'target'}
                target = batch['target'].to(device)

                outputs = model(**inputs)
                logits = _extract_logits(outputs)

                # For sequence classification with per-token logits
                # [batch, seq_len, num_classes], use the [CLS] position.
                if len(logits.shape) == 3:
                    logits = logits[:, 0, :]

                batch_predictions = logits.argmax(dim=-1).cpu().tolist()
                batch_targets = target.cpu().tolist()
                predictions.extend(batch_predictions)
                targets.extend(batch_targets)

                if i % 10 == 0:
                    print(f"Predictions: {batch_predictions[:5]}")
                    print(f"Targets: {batch_targets[:5]}")

        print(f"\nTotal predictions: {len(predictions)}")
        print(f"Total targets: {len(targets)}")

        # Accuracy; guard against an empty dataset.
        correct = sum(p == t for p, t in zip(predictions, targets))
        total = len(predictions)
        accuracy = correct / total if total > 0 else 0.0
        print(f"\nEvaluation results:")
        print(f"Correct predictions: {correct}")
        print(f"Total predictions: {total}")
        print(f"Accuracy: {accuracy:.4f}")

        return {"fbougares/tsac": accuracy}

    except Exception as e:
        print(f"\n=== Error in TSAC evaluation: {str(e)} ===")
        print(f"Full traceback: {traceback.format_exc()}")
        # Bare raise preserves the original traceback for the caller.
        raise