File size: 2,220 Bytes
2667b42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/env python3
"""
Test script for plot arc classifier
"""

import json
import torch
from transformers import DebertaV2Tokenizer, DebertaV2ForSequenceClassification

def load_tests():
    """Read the synthetic test cases from disk and return the parsed list."""
    with open('tests/synthetic_tests.json', 'r') as handle:
        cases = json.load(handle)
    return cases

def run_tests(model_dir='.'):
    """Run all synthetic tests against the fine-tuned plot-arc classifier.

    Args:
        model_dir: Directory holding the DeBERTa-v2 tokenizer and sequence
            classification weights. Defaults to the current directory, matching
            the original behavior.

    Returns:
        float: Fraction of tests whose predicted class matches the expected
            class; 0.0 when the test file contains no cases.
    """
    print("Loading model...")
    tokenizer = DebertaV2Tokenizer.from_pretrained(model_dir)
    model = DebertaV2ForSequenceClassification.from_pretrained(model_dir)

    # Prefer GPU when available; inputs are moved to the same device below.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Index order must match the label order used at training time.
    class_names = ['NONE', 'INTERNAL', 'EXTERNAL', 'BOTH']

    tests = load_tests()

    correct = 0
    total = len(tests)

    # Guard the empty test file: without this, `correct / total` below
    # raises ZeroDivisionError.
    if total == 0:
        print("No synthetic tests found.")
        return 0.0

    print(f"Running {total} synthetic tests...\n")

    for i, test in enumerate(tests, 1):
        text = test['description']
        expected = test['expected_class']

        # Tokenize and move every tensor to the model's device.
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Inference only: disable autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)
            probabilities = torch.softmax(outputs.logits, dim=-1)
            predicted_idx = torch.argmax(probabilities, dim=-1).item()
            confidence = probabilities[0][predicted_idx].item()

        predicted = class_names[predicted_idx]
        is_correct = predicted == expected

        if is_correct:
            correct += 1
            status = "✅ PASS"
        else:
            status = "❌ FAIL"

        print(f"Test {i:2d}: {status}")
        print(f"  Text: {text[:100]}{'...' if len(text) > 100 else ''}")
        print(f"  Expected: {expected} | Predicted: {predicted} (conf: {confidence:.3f})")
        print(f"  Reasoning: {test['reasoning']}")
        print()

    accuracy = correct / total
    print(f"Results: {correct}/{total} correct ({accuracy:.1%})")

    return accuracy

# Script entry point: run the full synthetic test suite when executed directly.
if __name__ == "__main__":
    run_tests()