#!/usr/bin/env python3
"""
Test script for RetNet Explicitness Classifier

Usage:
    python test_model.py
"""

import json
import time
from collections import Counter

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

from model import ProductionRetNet


class RetNetExplicitnessClassifier:
    """Easy-to-use interface for RetNet explicitness classification"""

    def __init__(self, model_path=None, device='auto', config_path='config.json'):
        """Initialize the classifier

        Args:
            model_path: Path to the trained model file. When None, the
                'model_file' entry of the config is used, falling back to
                'model.safetensors'.
            device: Device to run on ('auto', 'cpu', 'cuda', 'mps')
            config_path: Path to the JSON config file (resolved against the
                current working directory when relative).
        """
        # Load config
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        # Auto-detect model path from config if not provided
        if model_path is None:
            model_path = self.config.get('model_file', 'model.safetensors')

        self.device = self._resolve_device(device)
        print(f"🚀 Using device: {self.device}")

        # Load tokenizer; the GPT-2 tokenizer is given its EOS token as the
        # padding token so that padding=True works below.
        self.tokenizer = AutoTokenizer.from_pretrained('gpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model
        self.model = self._load_model(model_path)
        self.labels = self.config['labels']

    @staticmethod
    def _resolve_device(device):
        """Return the device name to use, auto-detecting when device == 'auto'.

        Preference order for 'auto' is cuda, then mps, then cpu. Any other
        value is returned unchanged.
        """
        if device != 'auto':
            return device
        if torch.cuda.is_available():
            return 'cuda'
        # torch builds without the MPS backend have no torch.backends.mps
        # attribute; treat that as "not available" instead of crashing.
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is not None and mps_backend.is_available():
            return 'mps'
        return 'cpu'

    def _load_model(self, model_path):
        """Load the RetNet model

        Builds a ProductionRetNet from the hyperparameters in the config,
        loads the weights stored in the safetensors file at model_path,
        moves the model to self.device and switches it to eval mode.

        Args:
            model_path: Path to the safetensors weights file

        Returns:
            The model, ready for inference
        """
        model = ProductionRetNet(
            vocab_size=self.config['vocab_size'],
            dim=self.config['model_dim'],
            num_layers=self.config['num_layers'],
            num_heads=self.config['num_heads'],
            num_classes=self.config['num_classes'],
            max_length=self.config['max_length'],
        )

        # Load trained weights
        from safetensors.torch import load_file
        state_dict = load_file(model_path, device=self.device)
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        return model

    def _forward(self, text_or_texts):
        """Tokenize the input, run the model and return class probabilities.

        Args:
            text_or_texts: Whatever should be handed to the tokenizer
                (a single text or a list of texts)

        Returns:
            NumPy array of softmax probabilities, one row per input text
        """
        inputs = self.tokenizer(
            text_or_texts,
            truncation=True,
            padding=True,
            max_length=self.config['max_length'],
            return_tensors='pt',
        )
        input_ids = inputs['input_ids'].to(self.device)
        attention_mask = inputs['attention_mask'].to(self.device)

        with torch.no_grad():
            logits = self.model(input_ids, attention_mask)
            probabilities = F.softmax(logits, dim=-1)

        # One device-to-host transfer for the whole batch.
        return probabilities.cpu().numpy()

    def _build_result(self, text, probs):
        """Turn one row of probabilities into the public result dict."""
        pred_id = int(probs.argmax())
        return {
            'text': text,  # Keep full text for fun-stats display
            'predicted_class': self.labels[pred_id],
            'confidence': float(probs[pred_id]),
            'probabilities': {
                label: float(probs[i]) for i, label in enumerate(self.labels)
            },
        }

    def classify(self, text):
        """Classify a single text

        Args:
            text: Input text to classify

        Returns:
            dict: Classification results with label, confidence, and all probabilities
        """
        return self._build_result(text, self._forward(text)[0])

    def classify_batch(self, texts, batch_size=32):
        """Classify multiple texts efficiently

        Args:
            texts: Iterable of input texts
            batch_size: Number of texts sent through the model at once

        Returns:
            list: List of classification results, in input order

        Raises:
            ValueError: If batch_size is not a positive integer
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        texts = list(texts)  # allow any iterable, not just sliceable sequences
        results = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            batch_probs = self._forward(batch)
            results.extend(
                self._build_result(text, probs)
                for text, probs in zip(batch, batch_probs)
            )
        return results


def main():
    """Test the RetNet classifier with example texts"""
    print("🧪 Testing RetNet Explicitness Classifier")
    print("=" * 60)

    # Initialize classifier
    classifier = RetNetExplicitnessClassifier()

    # Test examples covering different categories
    test_texts = [
        # NON-EXPLICIT
        "The morning sun cast long shadows across the peaceful meadow as birds sang in the trees.",
        # SUGGESTIVE
        "She felt a spark of attraction as their eyes met across the crowded room.",
        # SEXUAL-REFERENCE
        "The romance novel described their passionate night together in tasteful detail.",
        # EXPLICIT-SEXUAL
        "His hands explored every inch of her naked body as she moaned with pleasure.",
        # EXPLICIT-VIOLENT
        "The killer slowly twisted the knife deeper into his victim's chest.",
        # EXPLICIT-OFFENSIVE
        "What the fuck is wrong with you, you goddamn idiot?",
        # EXPLICIT-DISCLAIMER
        "Warning: This content contains explicit sexual material and violence.",
    ]

    print(f"📊 Testing {len(test_texts)} example texts...\n")

    # Single text classification
    print("🔍 Single Text Classification:")
    print("-" * 40)
    for number, text in enumerate(test_texts, start=1):
        result = classifier.classify(text)
        print(f"\n{number}. Text: {result['text']}")
        print(f" Prediction: {result['predicted_class']}")
        print(f" Confidence: {result['confidence']:.3f}")

    # Batch classification with timing
    print("\n⚡ Batch Classification Performance:")
    print("-" * 40)

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    batch_results = classifier.classify_batch(test_texts)
    elapsed_time = time.perf_counter() - start_time

    # Guard against a zero reading on coarse clocks.
    if elapsed_time > 0:
        texts_per_sec = len(test_texts) / elapsed_time
    else:
        texts_per_sec = float('inf')

    print(f"📈 Processed {len(test_texts)} texts in {elapsed_time:.3f}s")
    print(f"🚀 Speed: {texts_per_sec:.1f} texts/second")

    # Show prediction distribution
    pred_counts = Counter(r['predicted_class'] for r in batch_results)
    print("\n📊 Prediction Distribution:")
    for label, count in sorted(pred_counts.items()):
        print(f" {label}: {count}")

    # Model info (read from the 'performance' and 'training' config sections)
    performance = classifier.config['performance']
    training = classifier.config['training']
    print("\n🤖 Model Information:")
    print(f" Parameters: {performance['parameters']:,}")
    print(f" Holdout F1: {performance['holdout_macro_f1']:.3f}")
    print(f" Holdout Accuracy: {performance['holdout_accuracy']:.3f}")
    print(f" Training Time: {training['training_time_hours']:.1f} hours")

    print("\n✅ RetNet classifier test completed!")


if __name__ == "__main__":
    main()