|
|
|
|
|
""" |
|
|
Test script for plot arc classifier |
|
|
""" |
|
|
|
|
|
import json |
|
|
import torch |
|
|
from transformers import DebertaV2Tokenizer, DebertaV2ForSequenceClassification |
|
|
|
|
|
def load_tests():
    """Load synthetic test cases from tests/synthetic_tests.json.

    Returns:
        The deserialized JSON content — expected to be a list of test-case
        dicts with 'description', 'expected_class', and 'reasoning' keys
        (per how run_tests consumes them).

    Raises:
        FileNotFoundError: if the fixture file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # Explicit encoding avoids platform-dependent defaults (e.g. cp1252 on
    # Windows) mangling non-ASCII text in the fixtures.
    with open('tests/synthetic_tests.json', 'r', encoding='utf-8') as f:
        return json.load(f)
|
|
|
|
|
def _predict(model, tokenizer, device, text):
    """Classify one description; return (predicted_idx, confidence).

    Tokenizes `text` (truncated to 512 tokens), runs a single forward pass
    without gradient tracking, and softmaxes the logits into probabilities.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = torch.softmax(outputs.logits, dim=-1)
    predicted_idx = torch.argmax(probabilities, dim=-1).item()
    confidence = probabilities[0][predicted_idx].item()
    return predicted_idx, confidence


def run_tests():
    """Run all synthetic tests and print per-case results.

    Loads the model/tokenizer from the current directory, evaluates every
    case returned by load_tests(), prints PASS/FAIL per case, and prints
    overall accuracy.

    Returns:
        float: fraction of test cases classified correctly (0.0 if the
        test file is empty).
    """
    print("Loading model...")
    tokenizer = DebertaV2Tokenizer.from_pretrained('.')
    model = DebertaV2ForSequenceClassification.from_pretrained('.')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # disable dropout etc. for deterministic inference

    # Index order must match the label mapping the model was trained with.
    class_names = ['NONE', 'INTERNAL', 'EXTERNAL', 'BOTH']

    tests = load_tests()

    correct = 0
    total = len(tests)

    print(f"Running {total} synthetic tests...\n")

    for i, test in enumerate(tests, 1):
        text = test['description']
        expected = test['expected_class']
        # Fail fast on malformed fixtures instead of silently scoring a FAIL.
        if expected not in class_names:
            raise ValueError(f"Unknown expected_class: {expected!r}")

        predicted_idx, confidence = _predict(model, tokenizer, device, text)
        predicted = class_names[predicted_idx]
        is_correct = predicted == expected

        if is_correct:
            correct += 1
            status = "✅ PASS"
        else:
            status = "❌ FAIL"

        print(f"Test {i:2d}: {status}")
        print(f"  Text: {text[:100]}{'...' if len(text) > 100 else ''}")
        print(f"  Expected: {expected} | Predicted: {predicted} (conf: {confidence:.3f})")
        print(f"  Reasoning: {test['reasoning']}")
        print()

    # Guard against an empty test file — original raised ZeroDivisionError.
    accuracy = correct / total if total else 0.0
    print(f"Results: {correct}/{total} correct ({accuracy:.1%})")

    return accuracy
|
|
|
|
|
if __name__ == "__main__": |
|
|
run_tests() |
|
|
|