File size: 5,514 Bytes
bde1c71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import torch
from datasets import load_dataset
import traceback
import time


def _extract_logits(outputs):
    """Pull the logits tensor out of a model's output.

    Supports the three output shapes seen in practice: a dict (HF-style,
    with ``logits`` or ``prediction_logits``), a tuple (logits first), or
    a raw tensor.

    Raises:
        ValueError: if a dict output carries no recognizable logits key.
    """
    if isinstance(outputs, dict):
        if 'logits' in outputs:
            return outputs['logits']
        if 'prediction_logits' in outputs:
            return outputs['prediction_logits']
        raise ValueError(f"Unknown output format. Available keys: {list(outputs.keys())}")
    if isinstance(outputs, tuple):
        print(f"Output tuple length: {len(outputs)}")
        return outputs[0]
    return outputs


def evaluate_tsac_sentiment(model, tokenizer, device):
    """Evaluate a model on the TSAC sentiment-classification task.

    Loads a small slice of the ``fbougares/tsac`` test split, tokenizes the
    ``sentence`` column, runs batched inference, and reports accuracy against
    the ``target`` column.

    Args:
        model: torch module whose forward returns logits (dict, tuple, or
            raw tensor); already placed on ``device``.
        tokenizer: HF-style tokenizer callable on a list of strings.
        device: torch device inputs are moved to before the forward pass.

    Returns:
        dict: ``{"fbougares/tsac": accuracy}`` with accuracy in [0, 1].

    Raises:
        ValueError: if the model's dict output has no recognizable logits key.
        Exception: any evaluation error is logged with its traceback and
            re-raised unchanged.
    """
    try:
        print("\n=== Starting TSAC sentiment evaluation ===")
        print(f"Current device: {device}")

        # Load and preprocess dataset.
        print("\nLoading and preprocessing TSAC dataset...")
        dataset = load_dataset("fbougares/tsac", split="test", trust_remote_code=True)
        # NOTE(review): only the first 10 samples are evaluated (an earlier
        # comment claimed 200) — widen this range for a meaningful estimate.
        dataset = dataset.select(range(10))

        def preprocess(examples):
            # padding=True pads within each map batch, so torch.stack in the
            # collate function below sees equal-length sequences.
            return tokenizer(
                examples['sentence'],
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors=None
            )

        print(dataset.column_names)
        dataset = dataset.map(preprocess, batched=True)
        dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'target'])

        # Sanity-check the first preprocessed example.
        first_example = dataset[0]
        print("\nFirst example details:")
        print(f"Input IDs shape: {first_example['input_ids'].shape}")
        print(f"Attention mask shape: {first_example['attention_mask'].shape}")
        print(f"Target: {first_example['target']}")

        model.eval()
        print(f"\nModel class: {model.__class__.__name__}")
        print(f"Model device: {next(model.parameters()).device}")

        with torch.no_grad():
            predictions = []
            targets = []

            from torch.utils.data import DataLoader

            def collate_fn(batch):
                # Stack per-sample tensors into batch tensors; relies on the
                # per-map-batch padding above for equal sequence lengths.
                return {
                    'input_ids': torch.stack([sample['input_ids'] for sample in batch]),
                    'attention_mask': torch.stack([sample['attention_mask'] for sample in batch]),
                    'target': torch.stack([sample['target'] for sample in batch]),
                }

            dataloader = DataLoader(
                dataset,
                batch_size=16,
                shuffle=False,
                collate_fn=collate_fn
            )

            for i, batch in enumerate(dataloader):
                if i % 10 == 0:
                    # Periodic progress output (every 10th batch, not only the first).
                    print(f"\nProcessing batch {i}...")
                    print(f"Batch keys: {list(batch.keys())}")
                    print(f"Target shape: {batch['target'].shape}")

                inputs = {k: v.to(device) for k, v in batch.items() if k != 'target'}
                target = batch['target'].to(device)
                outputs = model(**inputs)

                logits = _extract_logits(outputs)

                # For sequence classification use the [CLS] (first token) prediction.
                if len(logits.shape) == 3:  # [batch_size, sequence_length, num_classes]
                    logits = logits[:, 0, :]

                batch_predictions = logits.argmax(dim=-1).cpu().tolist()
                batch_targets = target.cpu().tolist()

                predictions.extend(batch_predictions)
                targets.extend(batch_targets)

                if i % 10 == 0:
                    print(f"\nBatch {i} predictions:")
                    print(f"Predictions: {batch_predictions[:5]}")
                    print(f"Targets: {batch_targets[:5]}")

            print(f"\nTotal predictions: {len(predictions)}")
            print(f"Total targets: {len(targets)}")

            # Plain accuracy; guard against an empty evaluation set.
            correct = sum(p == t for p, t in zip(predictions, targets))
            total = len(predictions)
            accuracy = correct / total if total > 0 else 0.0

            print("\nEvaluation results:")
            print(f"Correct predictions: {correct}")
            print(f"Total predictions: {total}")
            print(f"Accuracy: {accuracy:.4f}")

            return {"fbougares/tsac": accuracy}
    except Exception as e:
        print(f"\n=== Error in TSAC evaluation: {str(e)} ===")
        print(f"Full traceback: {traceback.format_exc()}")
        raise  # bare raise preserves the original traceback