#!/usr/bin/env python3
"""Book Classification Script for RetNet Explicitness Classifier

Usage:
    # As CLI
    python classify_book.py book.txt --format json --batch-size 64

    # As Python import
    from classify_book import BookClassifier
    classifier = BookClassifier()
    results = classifier.classify_book(paragraphs_list)
"""
import argparse
import json
import sys
import time
from pathlib import Path
from typing import List, Dict, Union

import torch

from test_model import RetNetExplicitnessClassifier


class BookClassifier:
    """Optimized book classification with batch processing"""

    def __init__(self, model_path=None, device='auto', batch_size=64,
                 confidence_threshold=0.5):
        """Initialize book classifier

        Args:
            model_path: Path to model file (auto-detected from config if None)
            device: Device to use ('auto', 'cpu', 'cuda', 'mps')
            batch_size: Batch size for processing (default: 64)
            confidence_threshold: Minimum confidence for classification
                (default: 0.5). Predictions below it are relabeled
                INCONCLUSIVE.
        """
        self.classifier = RetNetExplicitnessClassifier(model_path, device)
        self.batch_size = batch_size
        self.confidence_threshold = confidence_threshold

    def classify_book(self, paragraphs: List[str]) -> Dict:
        """Classify all paragraphs in a book with optimized batching

        Args:
            paragraphs: List of paragraph strings

        Returns:
            dict: Classification results with stats and paragraph results,
                or {"error": ...} when no paragraphs were supplied.
        """
        if not paragraphs:
            return {"error": "No paragraphs provided"}

        print(f"šŸ“– Classifying {len(paragraphs):,} paragraphs...")
        start_time = time.time()

        # Batch process for maximum efficiency
        results = self.classifier.classify_batch(paragraphs)

        # Apply confidence threshold: low-confidence predictions become
        # INCONCLUSIVE, keeping the original prediction/confidence for
        # later analysis (e.g. calculate_fun_stats).
        for result in results:
            if result['confidence'] < self.confidence_threshold:
                result['original_prediction'] = result['predicted_class']
                result['original_confidence'] = result['confidence']
                result['predicted_class'] = 'INCONCLUSIVE'

        elapsed_time = time.time() - start_time
        # FIX: guard against a zero elapsed time (coarse clock / tiny
        # input) which previously raised ZeroDivisionError.
        if elapsed_time > 0:
            paragraphs_per_sec = len(paragraphs) / elapsed_time
        else:
            paragraphs_per_sec = 0.0

        # Calculate statistics
        stats = self._calculate_stats(results)

        # Count inconclusive predictions
        inconclusive_count = sum(
            1 for r in results if r['predicted_class'] == 'INCONCLUSIVE'
        )

        # Calculate meta-class statistics
        meta_stats = self._calculate_meta_stats(results)

        return {
            "book_stats": {
                "total_paragraphs": len(paragraphs),
                "processing_time_seconds": round(elapsed_time, 3),
                "paragraphs_per_second": round(paragraphs_per_sec, 1),
                "batch_size_used": self.batch_size,
                "confidence_threshold": self.confidence_threshold,
                "inconclusive_count": inconclusive_count,
                "conclusive_count": len(paragraphs) - inconclusive_count
            },
            "explicitness_distribution": stats,
            "meta_class_distribution": meta_stats,
            "paragraph_results": results
        }

    def classify_book_summary(self, paragraphs: List[str]) -> Dict:
        """Fast book classification returning only summary stats

        Args:
            paragraphs: List of paragraph strings

        Returns:
            dict: Summary statistics without individual paragraph results,
                or {"error": ...} when no paragraphs were supplied.
        """
        results = self.classify_book(paragraphs)
        # FIX: propagate the error dict instead of raising KeyError on
        # the missing "book_stats" key when the input was empty.
        if "error" in results:
            return results
        # Return only summary, not individual results
        return {
            "book_stats": results["book_stats"],
            "explicitness_distribution": results["explicitness_distribution"]
        }

    def _calculate_stats(self, results: List[Dict]) -> Dict:
        """Calculate explicitness distribution statistics

        Returns a dict mapping label -> {"count", "percentage"}, ordered
        by increasing explicitness level.
        """
        stats = {}
        # Count predictions per label
        for result in results:
            label = result['predicted_class']
            stats[label] = stats.get(label, 0) + 1

        total = len(results)

        # Convert to percentages and add counts
        distribution = {}
        for label, count in stats.items():
            distribution[label] = {
                "count": count,
                "percentage": round(100 * count / total, 2)
            }

        # Sort by explicitness level
        label_order = [
            "NON-EXPLICIT", "SUGGESTIVE", "SEXUAL-REFERENCE",
            "EXPLICIT-SEXUAL", "EXPLICIT-OFFENSIVE", "EXPLICIT-VIOLENT",
            "EXPLICIT-DISCLAIMER", "INCONCLUSIVE"
        ]
        ordered_dist = {}
        for label in label_order:
            if label in distribution:
                ordered_dist[label] = distribution[label]
        return ordered_dist

    def _calculate_meta_stats(self, results: List[Dict]) -> Dict:
        """Calculate meta-class groupings statistics

        Note: meta-classes overlap (EXPLICIT-SEXUAL is counted in both
        SEXUAL and MATURE/EXPLICIT), so percentages may sum to > 100.
        """
        # Define meta-class mappings
        meta_classes = {
            'SAFE': ['NON-EXPLICIT'],
            'SEXUAL': ['SUGGESTIVE', 'SEXUAL-REFERENCE', 'EXPLICIT-SEXUAL'],
            'MATURE': ['EXPLICIT-SEXUAL', 'EXPLICIT-OFFENSIVE',
                       'EXPLICIT-VIOLENT'],
            'EXPLICIT': ['EXPLICIT-SEXUAL', 'EXPLICIT-OFFENSIVE',
                         'EXPLICIT-VIOLENT'],
            'WARNINGS': ['EXPLICIT-DISCLAIMER']
        }

        total = len(results)
        meta_stats = {}
        for meta_label, class_list in meta_classes.items():
            count = sum(
                1 for r in results if r['predicted_class'] in class_list
            )
            meta_stats[meta_label] = {
                "count": count,
                "percentage": round(100 * count / total, 2) if total > 0 else 0,
                "includes": class_list
            }

        # Add inconclusive as its own meta-class
        inconclusive_count = sum(
            1 for r in results if r['predicted_class'] == 'INCONCLUSIVE'
        )
        meta_stats['INCONCLUSIVE'] = {
            "count": inconclusive_count,
            "percentage": round(100 * inconclusive_count / total, 2)
            if total > 0 else 0,
            "includes": ['INCONCLUSIVE']
        }
        return meta_stats

    def calculate_fun_stats(self, results: List[Dict]) -> Dict:
        """Calculate fun statistics: strongest, borderline, and most
        confused examples.

        Args:
            results: Per-paragraph result dicts as produced by
                classify_book (each with 'predicted_class', 'confidence',
                'text', 'probabilities').

        Returns:
            dict with keys "strongest_examples", "borderline_examples",
            "most_confused", "most_inconclusive".
        """
        fun_stats = {
            "strongest_examples": {},   # Highest confidence per class
            "borderline_examples": {},  # Lowest confidence per class
            "most_confused": None,      # Overall lowest confidence
            "most_inconclusive": []     # Most inconclusive examples
        }

        # Group results by predicted class, keeping INCONCLUSIVE separate
        by_class = {}
        inconclusive_examples = []
        for i, result in enumerate(results):
            label = result['predicted_class']
            if label == 'INCONCLUSIVE':
                inconclusive_examples.append((i, result))
            else:
                if label not in by_class:
                    by_class[label] = []
                by_class[label].append((i, result))

        # Find strongest and borderline examples for each class
        for label, class_results in by_class.items():
            # Sort by confidence, highest first
            sorted_results = sorted(
                class_results, key=lambda x: x[1]['confidence'], reverse=True
            )

            # Strongest (highest confidence)
            strongest_idx, strongest_result = sorted_results[0]
            fun_stats["strongest_examples"][label] = {
                "text": strongest_result['text'],
                "confidence": strongest_result['confidence'],
                "paragraph_number": strongest_idx + 1
            }

            # Borderline (lowest confidence in this class)
            borderline_idx, borderline_result = sorted_results[-1]
            fun_stats["borderline_examples"][label] = {
                "text": borderline_result['text'],
                "confidence": borderline_result['confidence'],
                "paragraph_number": borderline_idx + 1
            }

        # Most confused overall (lowest confidence excluding INCONCLUSIVE)
        non_inconclusive = [
            (i, r) for i, r in enumerate(results)
            if r['predicted_class'] != 'INCONCLUSIVE'
        ]
        if non_inconclusive:
            most_confused_idx, most_confused_result = min(
                non_inconclusive, key=lambda x: x[1]['confidence']
            )
            fun_stats["most_confused"] = {
                "text": most_confused_result['text'],
                "predicted_class": most_confused_result['predicted_class'],
                "confidence": most_confused_result['confidence'],
                "paragraph_number": most_confused_idx + 1,
                "all_probabilities": most_confused_result['probabilities']
            }

        # Most inconclusive examples (lowest confidence among INCONCLUSIVE)
        if inconclusive_examples:
            inconclusive_sorted = sorted(
                inconclusive_examples, key=lambda x: x[1]['confidence']
            )
            fun_stats["most_inconclusive"] = []
            # Top 3 most inconclusive
            for para_idx, result in inconclusive_sorted[:3]:
                original_pred = result.get('original_prediction', 'UNKNOWN')
                fun_stats["most_inconclusive"].append({
                    "text": result['text'],
                    "confidence": result['confidence'],
                    "paragraph_number": para_idx + 1,
                    "original_prediction": original_pred,
                    "all_probabilities": result['probabilities']
                })

        return fun_stats


def load_book_file(file_path: str) -> List[str]:
    """Load a book file and split into paragraphs

    Args:
        file_path: Path to text file

    Returns:
        List of paragraph strings
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
    except UnicodeDecodeError:
        # Try with different encoding
        with open(file_path, 'r', encoding='latin-1') as f:
            content = f.read()

    # Split into paragraphs (double newlines or single newlines)
    paragraphs = []

    # First try double newlines
    parts = content.split('\n\n')
    if len(parts) > 10:
        # Likely good paragraph separation
        paragraphs = [p.strip() for p in parts if p.strip()]
    else:
        # Fall back to single newlines; drop very short lines (headings,
        # page numbers, etc.)
        parts = content.split('\n')
        paragraphs = [
            p.strip() for p in parts if p.strip() and len(p.strip()) > 20
        ]

    return paragraphs


def main():
    """CLI interface for book classification"""
    parser = argparse.ArgumentParser(
        description="Classify explicitness levels in book text files",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python classify_book.py book.txt --summary
  python classify_book.py book.txt --format json --output results.json
  python classify_book.py book.txt --batch-size 32 --device cpu
"""
    )
    parser.add_argument('file', help='Path to book text file')
    parser.add_argument('--format', choices=['json', 'summary'],
                        default='summary',
                        help='Output format (default: summary)')
    parser.add_argument('--output', '-o',
                        help='Output file (default: stdout)')
    parser.add_argument('--batch-size', type=int, default=64,
                        help='Batch size for processing (default: 64)')
    parser.add_argument('--device', choices=['auto', 'cpu', 'cuda', 'mps'],
                        default='auto',
                        help='Device to use (default: auto)')
    parser.add_argument('--summary', action='store_true',
                        help='Show only summary stats (faster)')
    parser.add_argument('--fun-stats', action='store_true',
                        help='Show strongest, most borderline, and most '
                             'confused examples')
    parser.add_argument('--confidence-threshold', type=float, default=0.5,
                        help='Minimum confidence threshold (default: 0.5). '
                             'Below this = INCONCLUSIVE')
    parser.add_argument('--show-meta-classes', action='store_true',
                        help='Show meta-class groupings (SAFE, SEXUAL, '
                             'MATURE, etc.)')
    parser.add_argument('--export-fun-stats', type=str, metavar='FILE',
                        help='Export detailed fun-stats to JSON file '
                             '(full text, no truncation)')
    args = parser.parse_args()

    # Validate file
    if not Path(args.file).exists():
        print(f"āŒ Error: File '{args.file}' not found", file=sys.stderr)
        sys.exit(1)

    try:
        # Load book
        print(f"šŸ“š Loading book from '{args.file}'...")
        paragraphs = load_book_file(args.file)
        print(f"šŸ“„ Found {len(paragraphs):,} paragraphs")

        if len(paragraphs) == 0:
            print("āŒ Error: No paragraphs found in file", file=sys.stderr)
            sys.exit(1)

        # Initialize classifier
        classifier = BookClassifier(
            batch_size=args.batch_size,
            device=args.device,
            confidence_threshold=args.confidence_threshold
        )

        # Classify. FIX: --export-fun-stats also needs per-paragraph
        # results; previously it was silently skipped in summary mode.
        need_full_results = args.fun_stats or args.export_fun_stats
        if (args.summary or args.format == 'summary') and not need_full_results:
            results = classifier.classify_book_summary(paragraphs)
        else:
            # Need full results for fun stats / export
            results = classifier.classify_book(paragraphs)

            # Add fun stats if requested
            if args.fun_stats and 'paragraph_results' in results:
                results['fun_stats'] = classifier.calculate_fun_stats(
                    results['paragraph_results']
                )

        # Export fun stats to JSON if requested
        if args.export_fun_stats and 'paragraph_results' in results:
            if 'fun_stats' not in results:
                results['fun_stats'] = classifier.calculate_fun_stats(
                    results['paragraph_results']
                )
            export_data = {
                'book_stats': results['book_stats'],
                'fun_stats': results['fun_stats'],
                'export_info': {
                    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
                    'confidence_threshold': args.confidence_threshold,
                    'note': 'Full text examples with no truncation'
                }
            }
            with open(args.export_fun_stats, 'w') as f:
                json.dump(export_data, f, indent=2)
            print(f"šŸ“ Fun stats exported to '{args.export_fun_stats}'")

        # Output results
        if args.format == 'json':
            output = json.dumps(results, indent=2)
        else:
            output = format_summary_output(results)

        if args.output:
            with open(args.output, 'w') as f:
                f.write(output)
            print(f"šŸ“ Results saved to '{args.output}'")
        else:
            print(output)

    except KeyboardInterrupt:
        print("\nāš ļø Classification interrupted by user")
        sys.exit(1)
    except Exception as e:
        print(f"āŒ Error: {e}", file=sys.stderr)
        sys.exit(1)


def format_summary_output(results: Dict) -> str:
    """Format results as human-readable summary

    Args:
        results: Output of BookClassifier.classify_book or
            classify_book_summary (must contain 'book_stats' and
            'explicitness_distribution').

    Returns:
        Multi-line summary string.
    """
    stats = results['book_stats']
    dist = results['explicitness_distribution']

    output = []
    output.append("šŸ“Š Book Classification Results")
    output.append("=" * 50)
    output.append(f"šŸ“– Total paragraphs: {stats['total_paragraphs']:,}")
    output.append(f"⚔ Processing time: {stats['processing_time_seconds']}s")
    output.append(
        f"šŸš€ Speed: {stats['paragraphs_per_second']} paragraphs/sec"
    )

    # Show confidence threshold info
    if 'confidence_threshold' in stats:
        threshold = stats['confidence_threshold']
        inconclusive = stats.get('inconclusive_count', 0)
        conclusive = stats.get('conclusive_count', stats['total_paragraphs'])
        inconclusive_pct = 100 * inconclusive / stats['total_paragraphs']
        output.append(f"šŸŽÆ Confidence threshold: {threshold:.1f}")
        output.append(
            f"āœ… Conclusive predictions: {conclusive:,} "
            f"({100-inconclusive_pct:.1f}%)"
        )
        output.append(
            f"ā“ Inconclusive predictions: {inconclusive:,} "
            f"({inconclusive_pct:.1f}%)"
        )

    output.append("")
    output.append("šŸ“ˆ Explicitness Distribution:")
    output.append("-" * 30)
    for label, data in dist.items():
        bar_length = int(data['percentage'] / 2)  # Scale for display
        bar = "ā–ˆ" * bar_length
        output.append(
            f"{label:18} {data['count']:5,} "
            f"({data['percentage']:5.1f}%) {bar}"
        )

    # Show meta-classes if available in results (always show them now)
    if 'meta_class_distribution' in results:
        meta_dist = results['meta_class_distribution']
        output.append("")
        output.append("šŸ·ļø Meta-Class Distribution:")
        output.append("-" * 30)
        # Order meta-classes meaningfully
        meta_order = ['SAFE', 'SEXUAL', 'MATURE', 'EXPLICIT', 'WARNINGS',
                      'INCONCLUSIVE']
        for meta_label in meta_order:
            if meta_label in meta_dist:
                data = meta_dist[meta_label]
                if data['count'] > 0:  # Only show if there are examples
                    bar_length = int(data['percentage'] / 2)
                    bar = "ā–ˆ" * bar_length
                    output.append(
                        f"{meta_label:12} {data['count']:5,} "
                        f"({data['percentage']:5.1f}%) {bar}"
                    )

    # Add fun stats if available
    if 'fun_stats' in results:
        output.append("")
        output.append("šŸŽÆ Fun Stats:")
        output.append("=" * 50)
        fun_stats = results['fun_stats']

        # Strongest examples
        output.append("\nšŸ† Strongest Examples (Highest Confidence):")
        output.append("-" * 45)
        for label, example in fun_stats['strongest_examples'].items():
            output.append(
                f"\n{label} ({example['confidence']:.3f} confidence)"
            )
            output.append(
                f" Paragraph #{example['paragraph_number']}: "
                f"\"{example['text'][:250]}...\""
            )

        # Borderline examples
        output.append(
            "\nšŸ¤” Most Borderline Examples (Lowest Confidence per Class):"
        )
        output.append("-" * 55)
        for label, example in fun_stats['borderline_examples'].items():
            output.append(
                f"\n{label} ({example['confidence']:.3f} confidence)"
            )
            output.append(
                f" Paragraph #{example['paragraph_number']}: "
                f"\"{example['text'][:250]}...\""
            )

        # Most confused (among conclusive predictions)
        if fun_stats['most_confused']:
            confused = fun_stats['most_confused']
            output.append(
                f"\n🤯 Most Confused Conclusive Paragraph "
                f"({confused['confidence']:.3f} confidence):"
            )
            output.append("-" * 55)
            output.append(
                f"Paragraph #{confused['paragraph_number']}: "
                f"\"{confused['text'][:250]}...\""
            )
            output.append(f"Predicted: {confused['predicted_class']}")
            # Show probability distribution for confused example
            output.append("All probabilities:")
            sorted_probs = sorted(
                confused['all_probabilities'].items(),
                key=lambda x: x[1], reverse=True
            )
            for label, prob in sorted_probs[:3]:  # Top 3
                output.append(f" {label}: {prob:.3f}")

        # Most inconclusive examples
        if fun_stats['most_inconclusive']:
            output.append(f"\nā“ Most Inconclusive Examples:")
            output.append("-" * 35)
            for i, inc in enumerate(fun_stats['most_inconclusive']):
                output.append(
                    f"\n{i+1}. Paragraph #{inc['paragraph_number']} "
                    f"({inc['confidence']:.3f} confidence)"
                )
                output.append(f" \"{inc['text'][:250]}...\"")
                output.append(
                    f" Original prediction: {inc['original_prediction']}"
                )

    return "\n".join(output)


if __name__ == "__main__":
    main()