# Hugging Face repo-page residue (not code), kept for provenance:
# uploaded by Mitchins via huggingface_hub, commit 54097f9 (verified)
#!/usr/bin/env python3
"""
Test script for RetNet Explicitness Classifier
Usage: python test_model.py
"""
import torch
import torch.nn.functional as F
import json
from transformers import AutoTokenizer
from model import ProductionRetNet
import time
class RetNetExplicitnessClassifier:
    """Easy-to-use interface for RetNet explicitness classification.

    Wraps tokenization, device placement, weight loading, and softmax
    post-processing behind two entry points: `classify` (single text)
    and `classify_batch` (list of texts).
    """

    def __init__(self, model_path=None, device='auto'):
        """Initialize the classifier.

        Args:
            model_path: Path to the trained model weights file. Defaults to
                the 'model_file' entry in config.json ('model.safetensors'
                if that key is absent).
            device: Device to run on ('auto', 'cpu', 'cuda', 'mps').
                'auto' prefers CUDA, then Apple MPS, then CPU.
        """
        # Load config (expects config.json in the current working directory).
        with open('config.json', 'r') as f:
            self.config = json.load(f)

        # Auto-detect model path from config if not provided
        if model_path is None:
            model_path = self.config.get('model_file', 'model.safetensors')

        # Auto device selection: CUDA > MPS > CPU.
        if device == 'auto':
            if torch.cuda.is_available():
                self.device = 'cuda'
            elif torch.backends.mps.is_available():
                self.device = 'mps'
            else:
                self.device = 'cpu'
        else:
            self.device = device
        print(f"πŸš€ Using device: {self.device}")

        # GPT-2 has no pad token by default; reuse EOS so padding works.
        self.tokenizer = AutoTokenizer.from_pretrained('gpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model and human-readable class labels.
        self.model = self._load_model(model_path)
        self.labels = self.config['labels']

    def _load_model(self, model_path):
        """Build the RetNet from config and load trained safetensors weights."""
        model = ProductionRetNet(
            vocab_size=self.config['vocab_size'],
            dim=self.config['model_dim'],
            num_layers=self.config['num_layers'],
            num_heads=self.config['num_heads'],
            num_classes=self.config['num_classes'],
            max_length=self.config['max_length']
        )
        # Load trained weights directly onto the target device.
        from safetensors.torch import load_file
        state_dict = load_file(model_path, device=self.device)
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()  # inference only: disable dropout / train-mode layers
        return model

    def _predict_probs(self, texts):
        """Tokenize a list of texts and return per-text probability rows.

        Args:
            texts: Non-empty list of input strings (one forward pass).
        Returns:
            numpy array of shape (len(texts), num_classes) of softmax probs.
        """
        inputs = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=self.config['max_length'],
            return_tensors='pt'
        )
        input_ids = inputs['input_ids'].to(self.device)
        attention_mask = inputs['attention_mask'].to(self.device)
        with torch.no_grad():
            logits = self.model(input_ids, attention_mask)
            probabilities = F.softmax(logits, dim=-1)
        return probabilities.cpu().numpy()

    def _build_result(self, text, probs):
        """Assemble the public result dict for one text.

        Args:
            text: The original input string.
            probs: 1-D probability vector aligned with self.labels.
        Returns:
            dict with 'text', 'predicted_class', 'confidence', 'probabilities'.
        """
        pred_id = int(probs.argmax())
        return {
            'text': text,  # Keep full text for fun-stats display
            'predicted_class': self.labels[pred_id],
            'confidence': float(probs[pred_id]),
            'probabilities': {
                label: float(probs[i]) for i, label in enumerate(self.labels)
            }
        }

    def classify(self, text):
        """Classify a single text.

        Args:
            text: Input text to classify.
        Returns:
            dict: Classification results with label, confidence, and all
            probabilities.
        """
        # Delegate to the shared batch path so single and batch predictions
        # are guaranteed to agree.
        probs = self._predict_probs([text])[0]
        return self._build_result(text, probs)

    def classify_batch(self, texts, batch_size=32):
        """Classify multiple texts efficiently.

        Args:
            texts: List of input texts (may be empty; returns []).
            batch_size: Number of texts per forward pass (default 32,
                matching the previous hard-coded value).
        Returns:
            list: One classification result dict per input text, in order.
        """
        results = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            for text, probs in zip(batch, self._predict_probs(batch)):
                results.append(self._build_result(text, probs))
        return results
def main():
    """Smoke-test the RetNet classifier on one example per category."""
    print("πŸ§ͺ Testing RetNet Explicitness Classifier")
    print("=" * 60)

    classifier = RetNetExplicitnessClassifier()

    # One representative example per target category.
    test_texts = [
        # NON-EXPLICIT
        "The morning sun cast long shadows across the peaceful meadow as birds sang in the trees.",
        # SUGGESTIVE
        "She felt a spark of attraction as their eyes met across the crowded room.",
        # SEXUAL-REFERENCE
        "The romance novel described their passionate night together in tasteful detail.",
        # EXPLICIT-SEXUAL
        "His hands explored every inch of her naked body as she moaned with pleasure.",
        # EXPLICIT-VIOLENT
        "The killer slowly twisted the knife deeper into his victim's chest.",
        # EXPLICIT-OFFENSIVE
        "What the fuck is wrong with you, you goddamn idiot?",
        # EXPLICIT-DISCLAIMER
        "Warning: This content contains explicit sexual material and violence."
    ]
    print(f"πŸ“Š Testing {len(test_texts)} example texts...\n")

    # Classify each text one at a time.
    print("πŸ” Single Text Classification:")
    print("-" * 40)
    for idx, sample in enumerate(test_texts, start=1):
        outcome = classifier.classify(sample)
        print(f"\n{idx}. Text: {outcome['text']}")
        print(f" Prediction: {outcome['predicted_class']}")
        print(f" Confidence: {outcome['confidence']:.3f}")

    # Time the batch path and report throughput.
    print(f"\n⚑ Batch Classification Performance:")
    print("-" * 40)
    started = time.time()
    batch_results = classifier.classify_batch(test_texts)
    elapsed = time.time() - started
    rate = len(test_texts) / elapsed
    print(f"πŸ“ˆ Processed {len(test_texts)} texts in {elapsed:.3f}s")
    print(f"πŸš€ Speed: {rate:.1f} texts/second")

    # Tally how many texts landed in each predicted class.
    tally = {}
    for label in (r['predicted_class'] for r in batch_results):
        tally[label] = tally.get(label, 0) + 1
    print(f"\nπŸ“Š Prediction Distribution:")
    for label, count in sorted(tally.items()):
        print(f" {label}: {count}")

    # Report model metadata straight from the loaded config.
    perf = classifier.config['performance']
    print(f"\nπŸ€– Model Information:")
    print(f" Parameters: {perf['parameters']:,}")
    print(f" Holdout F1: {perf['holdout_macro_f1']:.3f}")
    print(f" Holdout Accuracy: {perf['holdout_accuracy']:.3f}")
    print(f" Training Time: {classifier.config['training']['training_time_hours']:.1f} hours")
    print(f"\nβœ… RetNet classifier test completed!")
# Standard script entry guard: run the demo only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()