Mitchins's picture
Upload folder using huggingface_hub
54097f9 verified
#!/usr/bin/env python3
"""
Book Classification Script for RetNet Explicitness Classifier
Usage:
# As CLI
python classify_book.py book.txt --format json --batch-size 64
# As Python import
from classify_book import BookClassifier
classifier = BookClassifier()
results = classifier.classify_book(paragraphs_list)
"""
import argparse
import json
import sys
import time
from pathlib import Path
from typing import List, Dict, Union
import torch
from test_model import RetNetExplicitnessClassifier
class BookClassifier:
    """Classify the paragraphs of a book and aggregate the predictions.

    Wraps a ``RetNetExplicitnessClassifier``. Predictions whose confidence is
    below ``confidence_threshold`` are relabelled ``INCONCLUSIVE`` (the
    original label is kept under ``original_prediction``).
    """

    # Display order used for the per-label distribution; labels outside this
    # list are appended afterwards so they are never silently dropped.
    _LABEL_ORDER = (
        "NON-EXPLICIT", "SUGGESTIVE", "SEXUAL-REFERENCE",
        "EXPLICIT-SEXUAL", "EXPLICIT-OFFENSIVE", "EXPLICIT-VIOLENT",
        "EXPLICIT-DISCLAIMER", "INCONCLUSIVE",
    )

    def __init__(self, model_path=None, device='auto', batch_size=64, confidence_threshold=0.5):
        """Initialize book classifier

        Args:
            model_path: Path to model file (auto-detected from config if None)
            device: Device to use ('auto', 'cpu', 'cuda', 'mps')
            batch_size: Number of paragraphs sent to the model per call (default: 64)
            confidence_threshold: Minimum confidence for classification (default: 0.5)
        """
        self.classifier = RetNetExplicitnessClassifier(model_path, device)
        self.batch_size = batch_size
        self.confidence_threshold = confidence_threshold

    def classify_book(self, paragraphs: List[str]) -> Dict:
        """Classify all paragraphs in a book, ``batch_size`` paragraphs at a time.

        Args:
            paragraphs: List of paragraph strings

        Returns:
            dict: ``book_stats``, ``explicitness_distribution``,
            ``meta_class_distribution`` and ``paragraph_results``; or
            ``{"error": ...}`` when ``paragraphs`` is empty.
        """
        if not paragraphs:
            return {"error": "No paragraphs provided"}

        print(f"📖 Classifying {len(paragraphs):,} paragraphs...")
        start_time = time.perf_counter()

        # Honour the configured batch size. A non-positive or non-integer
        # value falls back to one call over the whole book.
        step = self.batch_size
        if not isinstance(step, int) or step < 1:
            step = len(paragraphs)
        results = []
        for offset in range(0, len(paragraphs), step):
            # assumes classify_batch returns one result dict per input text,
            # in input order -- TODO confirm against test_model
            results.extend(self.classifier.classify_batch(paragraphs[offset:offset + step]))

        # Apply confidence threshold. 'confidence' itself is left untouched so
        # it stays available for analysis.
        for result in results:
            if result['confidence'] < self.confidence_threshold:
                result['original_prediction'] = result['predicted_class']
                result['original_confidence'] = result['confidence']
                result['predicted_class'] = 'INCONCLUSIVE'

        elapsed_time = time.perf_counter() - start_time
        # Guard against a zero-length interval on very fast runs.
        paragraphs_per_sec = len(paragraphs) / elapsed_time if elapsed_time > 0 else 0.0

        stats = self._calculate_stats(results)
        meta_stats = self._calculate_meta_stats(results)
        inconclusive_count = sum(1 for r in results if r['predicted_class'] == 'INCONCLUSIVE')

        return {
            "book_stats": {
                "total_paragraphs": len(paragraphs),
                "processing_time_seconds": round(elapsed_time, 3),
                "paragraphs_per_second": round(paragraphs_per_sec, 1),
                "batch_size_used": self.batch_size,
                "confidence_threshold": self.confidence_threshold,
                "inconclusive_count": inconclusive_count,
                "conclusive_count": len(paragraphs) - inconclusive_count
            },
            "explicitness_distribution": stats,
            "meta_class_distribution": meta_stats,
            "paragraph_results": results
        }

    def classify_book_summary(self, paragraphs: List[str]) -> Dict:
        """Classify a book and return only the summary statistics.

        Args:
            paragraphs: List of paragraph strings

        Returns:
            dict: ``book_stats`` and ``explicitness_distribution`` only; the
            ``{"error": ...}`` dict from ``classify_book`` is passed through
            unchanged for empty input.
        """
        results = self.classify_book(paragraphs)
        if "error" in results:
            return results
        return {
            "book_stats": results["book_stats"],
            "explicitness_distribution": results["explicitness_distribution"]
        }

    def _calculate_stats(self, results: List[Dict]) -> Dict:
        """Return ``{label: {"count", "percentage"}}`` for the predicted labels.

        Known labels come first in ``_LABEL_ORDER`` order; any other label
        follows in first-seen order.
        """
        counts = {}
        for result in results:
            label = result['predicted_class']
            counts[label] = counts.get(label, 0) + 1

        total = len(results)
        ordered_labels = [label for label in self._LABEL_ORDER if label in counts]
        ordered_labels += [label for label in counts if label not in self._LABEL_ORDER]

        return {
            label: {
                "count": counts[label],
                "percentage": round(100 * counts[label] / total, 2)
            }
            for label in ordered_labels
        }

    def _calculate_meta_stats(self, results: List[Dict]) -> Dict:
        """Return count/percentage per meta-class (groups of predicted labels)."""
        # NOTE(review): 'MATURE' and 'EXPLICIT' list exactly the same classes;
        # confirm whether one of them was meant to differ.
        meta_classes = {
            'SAFE': ['NON-EXPLICIT'],
            'SEXUAL': ['SUGGESTIVE', 'SEXUAL-REFERENCE', 'EXPLICIT-SEXUAL'],
            'MATURE': ['EXPLICIT-SEXUAL', 'EXPLICIT-OFFENSIVE', 'EXPLICIT-VIOLENT'],
            'EXPLICIT': ['EXPLICIT-SEXUAL', 'EXPLICIT-OFFENSIVE', 'EXPLICIT-VIOLENT'],
            'WARNINGS': ['EXPLICIT-DISCLAIMER'],
            # Inconclusive predictions are reported as their own meta-class.
            'INCONCLUSIVE': ['INCONCLUSIVE'],
        }

        total = len(results)
        # Count each label once instead of rescanning results per meta-class.
        label_counts = {}
        for result in results:
            label = result['predicted_class']
            label_counts[label] = label_counts.get(label, 0) + 1

        meta_stats = {}
        for meta_label, class_list in meta_classes.items():
            count = sum(label_counts.get(label, 0) for label in class_list)
            meta_stats[meta_label] = {
                "count": count,
                "percentage": round(100 * count / total, 2) if total > 0 else 0,
                "includes": class_list
            }
        return meta_stats

    def calculate_fun_stats(self, results: List[Dict]) -> Dict:
        """Pick out the strongest, borderline, most confused and most inconclusive examples.

        Args:
            results: Paragraph results as returned under ``paragraph_results``;
                each dict is read for 'predicted_class', 'confidence', 'text'
                and 'probabilities'.

        Returns:
            dict with ``strongest_examples`` / ``borderline_examples`` (per
            class), ``most_confused`` (lowest-confidence conclusive result, or
            None) and ``most_inconclusive`` (up to three entries). Paragraph
            numbers are 1-based positions in ``results``.
        """
        fun_stats = {
            "strongest_examples": {},   # Highest confidence per class
            "borderline_examples": {},  # Lowest confidence per class
            "most_confused": None,      # Overall lowest confidence
            "most_inconclusive": []     # Most inconclusive examples
        }

        # Split results into conclusive (grouped by class) and inconclusive.
        by_class = {}
        conclusive = []
        inconclusive_examples = []
        for idx, result in enumerate(results):
            label = result['predicted_class']
            if label == 'INCONCLUSIVE':
                inconclusive_examples.append((idx, result))
            else:
                by_class.setdefault(label, []).append((idx, result))
                conclusive.append((idx, result))

        for label, class_results in by_class.items():
            # Stable descending sort: first entry is the strongest, last the
            # most borderline.
            ranked = sorted(class_results, key=lambda x: x[1]['confidence'], reverse=True)
            strongest_idx, strongest_result = ranked[0]
            borderline_idx, borderline_result = ranked[-1]
            fun_stats["strongest_examples"][label] = {
                "text": strongest_result['text'],
                "confidence": strongest_result['confidence'],
                "paragraph_number": strongest_idx + 1
            }
            fun_stats["borderline_examples"][label] = {
                "text": borderline_result['text'],
                "confidence": borderline_result['confidence'],
                "paragraph_number": borderline_idx + 1
            }

        # Most confused overall (lowest confidence excluding INCONCLUSIVE)
        if conclusive:
            confused_idx, confused_result = min(conclusive, key=lambda x: x[1]['confidence'])
            fun_stats["most_confused"] = {
                "text": confused_result['text'],
                "predicted_class": confused_result['predicted_class'],
                "confidence": confused_result['confidence'],
                "paragraph_number": confused_idx + 1,
                "all_probabilities": confused_result['probabilities']
            }

        # Three lowest-confidence INCONCLUSIVE results.
        for para_idx, result in sorted(inconclusive_examples, key=lambda x: x[1]['confidence'])[:3]:
            fun_stats["most_inconclusive"].append({
                "text": result['text'],
                "confidence": result['confidence'],
                "paragraph_number": para_idx + 1,
                "original_prediction": result.get('original_prediction', 'UNKNOWN'),
                "all_probabilities": result['probabilities']
            })

        return fun_stats
def load_book_file(file_path: str) -> List[str]:
    """Load a book file and split into paragraphs

    The file is decoded as UTF-8 (a leading byte-order mark is discarded);
    if that fails it is re-read as latin-1.

    Args:
        file_path: Path to text file

    Returns:
        List of non-empty, stripped paragraph strings
    """
    # More than this many blank-line-separated parts is taken to mean the
    # file really uses blank lines between paragraphs.
    min_blank_line_parts = 10
    # In the single-newline fallback, lines this short or shorter are dropped.
    min_line_length = 20

    try:
        # 'utf-8-sig' reads plain UTF-8 as well, but strips a leading BOM so
        # it does not end up glued to the first paragraph as '\ufeff'.
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            content = f.read()
    except UnicodeDecodeError:
        # latin-1 maps every byte to a character, so this read cannot fail
        # on decoding.
        with open(file_path, 'r', encoding='latin-1') as f:
            content = f.read()

    # First try double newlines
    blocks = content.split('\n\n')
    if len(blocks) > min_blank_line_parts:
        stripped_blocks = (block.strip() for block in blocks)
        return [block for block in stripped_blocks if block]

    # Fall back to single newlines, keeping only reasonably long lines
    stripped_lines = (line.strip() for line in content.split('\n'))
    return [line for line in stripped_lines if len(line) > min_line_length]
def main():
    """CLI interface for book classification

    Parses the command line, loads and classifies the book, optionally
    exports fun stats to a JSON file, and writes the result (JSON or
    human-readable summary) to stdout or ``--output``. Exits with status 1
    on a missing file, an empty book, Ctrl-C or any other error.
    """
    parser = argparse.ArgumentParser(
        description="Classify explicitness levels in book text files",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python classify_book.py book.txt --summary
python classify_book.py book.txt --format json --output results.json
python classify_book.py book.txt --batch-size 32 --device cpu
"""
    )
    parser.add_argument('file', help='Path to book text file')
    parser.add_argument('--format', choices=['json', 'summary'], default='summary',
                        help='Output format (default: summary)')
    parser.add_argument('--output', '-o', help='Output file (default: stdout)')
    parser.add_argument('--batch-size', type=int, default=64,
                        help='Batch size for processing (default: 64)')
    parser.add_argument('--device', choices=['auto', 'cpu', 'cuda', 'mps'],
                        default='auto', help='Device to use (default: auto)')
    parser.add_argument('--summary', action='store_true',
                        help='Show only summary stats (faster)')
    parser.add_argument('--fun-stats', action='store_true',
                        help='Show strongest, most borderline, and most confused examples')
    parser.add_argument('--confidence-threshold', type=float, default=0.5,
                        help='Minimum confidence threshold (default: 0.5). Below this = INCONCLUSIVE')
    parser.add_argument('--show-meta-classes', action='store_true',
                        help='Show meta-class groupings (SAFE, SEXUAL, MATURE, etc.)')
    parser.add_argument('--export-fun-stats', type=str, metavar='FILE',
                        help='Export detailed fun-stats to JSON file (full text, no truncation)')
    args = parser.parse_args()

    # Validate file
    if not Path(args.file).exists():
        print(f"❌ Error: File '{args.file}' not found", file=sys.stderr)
        sys.exit(1)

    try:
        # Load book
        print(f"📚 Loading book from '{args.file}'...")
        paragraphs = load_book_file(args.file)
        print(f"📄 Found {len(paragraphs):,} paragraphs")
        if not paragraphs:
            print("❌ Error: No paragraphs found in file", file=sys.stderr)
            sys.exit(1)

        # Initialize classifier
        classifier = BookClassifier(
            batch_size=args.batch_size,
            device=args.device,
            confidence_threshold=args.confidence_threshold
        )

        # The summary-only path discards per-paragraph results and the
        # meta-class distribution, so any option that needs either of them
        # must go through the full classification.
        wants_summary = args.summary or args.format == 'summary'
        needs_full = bool(args.fun_stats or args.show_meta_classes or args.export_fun_stats)
        if wants_summary and not needs_full:
            results = classifier.classify_book_summary(paragraphs)
        else:
            results = classifier.classify_book(paragraphs)

        # Add fun stats if requested
        if args.fun_stats and 'paragraph_results' in results:
            results['fun_stats'] = classifier.calculate_fun_stats(results['paragraph_results'])

        # Export fun stats to JSON if requested
        if args.export_fun_stats and 'paragraph_results' in results:
            if 'fun_stats' not in results:
                results['fun_stats'] = classifier.calculate_fun_stats(results['paragraph_results'])
            export_data = {
                'book_stats': results['book_stats'],
                'fun_stats': results['fun_stats'],
                'export_info': {
                    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
                    'confidence_threshold': args.confidence_threshold,
                    'note': 'Full text examples with no truncation'
                }
            }
            # Explicit encoding so the export does not depend on the locale.
            with open(args.export_fun_stats, 'w', encoding='utf-8') as f:
                json.dump(export_data, f, indent=2)
            # NOTE(review): the leading glyph of this message (and of the
            # "Results saved" one below) could not be recovered from the
            # mis-encoded source; the memo emoji is a best guess.
            print(f"📝 Fun stats exported to '{args.export_fun_stats}'")

        # Output results
        if args.format == 'json':
            output = json.dumps(results, indent=2)
        else:
            output = format_summary_output(results)

        if args.output:
            # The summary contains emoji and block characters; writing with
            # the platform default encoding can fail on non-UTF-8 locales.
            with open(args.output, 'w', encoding='utf-8') as f:
                f.write(output)
            print(f"📝 Results saved to '{args.output}'")
        else:
            print(output)
    except KeyboardInterrupt:
        print("\n⚠️ Classification interrupted by user")
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: report and exit non-zero instead of a traceback.
        print(f"❌ Error: {e}", file=sys.stderr)
        sys.exit(1)
def format_summary_output(results: Dict) -> str:
    """Format results as human-readable summary

    Args:
        results: Dict with 'book_stats' and 'explicitness_distribution';
            'meta_class_distribution' and 'fun_stats' are rendered when present.

    Returns:
        The report as a single newline-joined string.
    """
    def preview(text, limit=250):
        # Only add the ellipsis when something was actually cut off.
        return text if len(text) <= limit else text[:limit] + "..."

    stats = results['book_stats']
    dist = results['explicitness_distribution']
    total = stats['total_paragraphs']

    output = [
        "📊 Book Classification Results",
        "=" * 50,
        f"📖 Total paragraphs: {total:,}",
        f"⚡ Processing time: {stats['processing_time_seconds']}s",
        f"🚀 Speed: {stats['paragraphs_per_second']} paragraphs/sec",
    ]

    # Show confidence threshold info
    if 'confidence_threshold' in stats:
        threshold = stats['confidence_threshold']
        inconclusive = stats.get('inconclusive_count', 0)
        conclusive = stats.get('conclusive_count', total)
        # An empty book has no meaningful percentages; report 0 rather than
        # dividing by zero.
        inconclusive_pct = 100 * inconclusive / total if total else 0.0
        conclusive_pct = 100 - inconclusive_pct if total else 0.0
        # ':g' shows the threshold as configured (0.75 stays 0.75 instead of
        # being rounded to 0.8).
        output.append(f"🎯 Confidence threshold: {threshold:g}")
        output.append(f"✅ Conclusive predictions: {conclusive:,} ({conclusive_pct:.1f}%)")
        output.append(f"❓ Inconclusive predictions: {inconclusive:,} ({inconclusive_pct:.1f}%)")

    output.append("")
    output.append("📈 Explicitness Distribution:")
    output.append("-" * 30)
    # Widen the label column when a label is longer than the default 18
    # characters so the counts stay aligned.
    label_width = max([18] + [len(str(label)) for label in dist])
    for label, data in dist.items():
        bar = "█" * int(data['percentage'] / 2)  # Scale for display
        output.append(f"{label:{label_width}} {data['count']:5,} ({data['percentage']:5.1f}%) {bar}")

    # Show meta-classes whenever they are part of the results
    if 'meta_class_distribution' in results:
        meta_dist = results['meta_class_distribution']
        output.append("")
        output.append("🏷️ Meta-Class Distribution:")
        output.append("-" * 30)
        # Order meta-classes meaningfully
        meta_order = ['SAFE', 'SEXUAL', 'MATURE', 'EXPLICIT', 'WARNINGS', 'INCONCLUSIVE']
        for meta_label in meta_order:
            data = meta_dist.get(meta_label)
            # Only show meta-classes that actually have examples
            if data is None or data['count'] <= 0:
                continue
            bar = "█" * int(data['percentage'] / 2)
            output.append(f"{meta_label:12} {data['count']:5,} ({data['percentage']:5.1f}%) {bar}")

    # Add fun stats if available
    if 'fun_stats' in results:
        fun_stats = results['fun_stats']
        output.append("")
        output.append("🎯 Fun Stats:")
        output.append("=" * 50)

        # Strongest examples
        output.append("\n🏆 Strongest Examples (Highest Confidence):")
        output.append("-" * 45)
        for label, example in fun_stats['strongest_examples'].items():
            output.append(f"\n{label} ({example['confidence']:.3f} confidence)")
            output.append(f" Paragraph #{example['paragraph_number']}: \"{preview(example['text'])}\"")

        # Borderline examples
        output.append("\n🤔 Most Borderline Examples (Lowest Confidence per Class):")
        output.append("-" * 55)
        for label, example in fun_stats['borderline_examples'].items():
            output.append(f"\n{label} ({example['confidence']:.3f} confidence)")
            output.append(f" Paragraph #{example['paragraph_number']}: \"{preview(example['text'])}\"")

        # Most confused (among conclusive predictions)
        if fun_stats['most_confused']:
            confused = fun_stats['most_confused']
            output.append(f"\n🤯 Most Confused Conclusive Paragraph ({confused['confidence']:.3f} confidence):")
            output.append("-" * 55)
            output.append(f"Paragraph #{confused['paragraph_number']}: \"{preview(confused['text'])}\"")
            output.append(f"Predicted: {confused['predicted_class']}")
            # Only the three most probable classes are listed.
            output.append("All probabilities:")
            sorted_probs = sorted(confused['all_probabilities'].items(),
                                  key=lambda x: x[1], reverse=True)
            for label, prob in sorted_probs[:3]:
                output.append(f" {label}: {prob:.3f}")

        # Most inconclusive examples
        if fun_stats['most_inconclusive']:
            output.append("\n❓ Most Inconclusive Examples:")
            output.append("-" * 35)
            for rank, inc in enumerate(fun_stats['most_inconclusive'], start=1):
                output.append(f"\n{rank}. Paragraph #{inc['paragraph_number']} ({inc['confidence']:.3f} confidence)")
                output.append(f" \"{preview(inc['text'])}\"")
                output.append(f" Original prediction: {inc['original_prediction']}")

    return "\n".join(output)
if __name__ == "__main__":
main()