|
|
|
|
|
""" |
|
|
Test script for RetNet Explicitness Classifier |
|
|
Usage: python test_model.py |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import json |
|
|
from transformers import AutoTokenizer |
|
|
from model import ProductionRetNet |
|
|
import time |
|
|
|
|
|
class RetNetExplicitnessClassifier:
    """Easy-to-use interface for RetNet explicitness classification."""

    def __init__(self, model_path=None, device='auto', config_path='config.json'):
        """Initialize the classifier.

        Args:
            model_path: Path to the trained model file. Defaults to the
                ``model_file`` entry in the config (``model.safetensors``).
            device: Device to run on ('auto', 'cpu', 'cuda', 'mps').
            config_path: Path to the JSON config file describing the model
                architecture, labels, and weight file.
        """
        with open(config_path, 'r') as f:
            self.config = json.load(f)

        if model_path is None:
            model_path = self.config.get('model_file', 'model.safetensors')

        # Resolve 'auto' to the best available backend: CUDA, then Apple
        # MPS, then CPU.
        if device == 'auto':
            if torch.cuda.is_available():
                self.device = 'cuda'
            elif torch.backends.mps.is_available():
                self.device = 'mps'
            else:
                self.device = 'cpu'
        else:
            self.device = device

        print(f"π Using device: {self.device}")

        # GPT-2's tokenizer ships without a pad token; reuse EOS so that
        # padded batching works.
        self.tokenizer = AutoTokenizer.from_pretrained('gpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = self._load_model(model_path)
        self.labels = self.config['labels']

    def _load_model(self, model_path):
        """Build a ProductionRetNet from config and load safetensors weights.

        Returns the model moved to ``self.device`` and switched to eval mode.
        """
        model = ProductionRetNet(
            vocab_size=self.config['vocab_size'],
            dim=self.config['model_dim'],
            num_layers=self.config['num_layers'],
            num_heads=self.config['num_heads'],
            num_classes=self.config['num_classes'],
            max_length=self.config['max_length']
        )

        from safetensors.torch import load_file
        state_dict = load_file(model_path, device=self.device)
        model.load_state_dict(state_dict)

        model.to(self.device)
        model.eval()  # inference only: disable dropout/batch-norm updates

        return model

    def classify(self, text):
        """Classify a single text.

        Args:
            text: Input text to classify.

        Returns:
            dict: Classification results with ``text``, ``predicted_class``,
            ``confidence``, and per-label ``probabilities``.
        """
        # Delegate to the batch path so single-text and batch results can
        # never diverge (previously this method duplicated the whole
        # tokenize/forward/softmax pipeline).
        return self.classify_batch([text])[0]

    def classify_batch(self, texts, batch_size=32):
        """Classify multiple texts efficiently.

        Args:
            texts: List of input texts.
            batch_size: Number of texts tokenized and scored per forward
                pass (default 32).

        Returns:
            list: One result dict per input text, in input order, with the
            same schema as :meth:`classify`.
        """
        results = []

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]

            inputs = self.tokenizer(
                batch,
                truncation=True,
                padding=True,
                max_length=self.config['max_length'],
                return_tensors='pt'
            )

            input_ids = inputs['input_ids'].to(self.device)
            attention_mask = inputs['attention_mask'].to(self.device)

            with torch.no_grad():
                logits = self.model(input_ids, attention_mask)
                probabilities = F.softmax(logits, dim=-1)

            for j, text in enumerate(batch):
                probs = probabilities[j].cpu().numpy()
                pred_id = int(probs.argmax())

                results.append({
                    'text': text,
                    'predicted_class': self.labels[pred_id],
                    'confidence': float(probs[pred_id]),
                    'probabilities': {
                        label: float(probs[k]) for k, label in enumerate(self.labels)
                    }
                })

        return results
|
|
|
|
|
def main():
    """Test the RetNet classifier with example texts.

    Loads the trained model, runs single-text and batch classification on a
    fixed set of samples spanning the label space, and prints predictions,
    throughput, and model metadata from the config.
    """
    from collections import Counter

    print("π§ͺ Testing RetNet Explicitness Classifier")
    print("=" * 60)

    classifier = RetNetExplicitnessClassifier()

    # Samples chosen to span the expected label spectrum: neutral prose,
    # mild romance, explicit sexual content, violence, and profanity.
    test_texts = [
        "The morning sun cast long shadows across the peaceful meadow as birds sang in the trees.",
        "She felt a spark of attraction as their eyes met across the crowded room.",
        "The romance novel described their passionate night together in tasteful detail.",
        "His hands explored every inch of her naked body as she moaned with pleasure.",
        "The killer slowly twisted the knife deeper into his victim's chest.",
        "What the fuck is wrong with you, you goddamn idiot?",
        "Warning: This content contains explicit sexual material and violence."
    ]

    print(f"π Testing {len(test_texts)} example texts...\n")

    print("π Single Text Classification:")
    print("-" * 40)

    for i, text in enumerate(test_texts):
        result = classifier.classify(text)
        print(f"\n{i+1}. Text: {result['text']}")
        print(f"   Prediction: {result['predicted_class']}")
        print(f"   Confidence: {result['confidence']:.3f}")

    print("\nβ‘ Batch Classification Performance:")
    print("-" * 40)

    start_time = time.time()
    batch_results = classifier.classify_batch(test_texts)
    elapsed_time = time.time() - start_time

    # Guard against a zero elapsed time on coarse clocks / tiny batches.
    texts_per_sec = len(test_texts) / max(elapsed_time, 1e-9)

    print(f"π Processed {len(test_texts)} texts in {elapsed_time:.3f}s")
    print(f"π Speed: {texts_per_sec:.1f} texts/second")

    # Tally how many texts landed in each predicted class.
    pred_counts = Counter(r['predicted_class'] for r in batch_results)

    print("\nπ Prediction Distribution:")
    for label, count in sorted(pred_counts.items()):
        print(f"   {label}: {count}")

    print("\nπ€ Model Information:")
    print(f"   Parameters: {classifier.config['performance']['parameters']:,}")
    print(f"   Holdout F1: {classifier.config['performance']['holdout_macro_f1']:.3f}")
    print(f"   Holdout Accuracy: {classifier.config['performance']['holdout_accuracy']:.3f}")
    print(f"   Training Time: {classifier.config['training']['training_time_hours']:.1f} hours")

    print("\nβ RetNet classifier test completed!")
|
|
|
|
|
# Entry point: run the manual smoke test only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()