|
|
|
|
|
""" |
|
|
Book Classification Script for RetNet Explicitness Classifier |
|
|
|
|
|
Usage: |
|
|
# As CLI |
|
|
python classify_book.py book.txt --format json --batch-size 64 |
|
|
|
|
|
# As Python import |
|
|
from classify_book import BookClassifier |
|
|
classifier = BookClassifier() |
|
|
results = classifier.classify_book(paragraphs_list) |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import sys |
|
|
import time |
|
|
from pathlib import Path |
|
|
from typing import List, Dict, Union |
|
|
|
|
|
import torch |
|
|
from test_model import RetNetExplicitnessClassifier |
|
|
|
|
|
|
|
|
class BookClassifier:
    """Classify every paragraph of a book and summarise the predictions.

    Orchestration layer over ``RetNetExplicitnessClassifier``: paragraphs are
    sent to the model in chunks of at most ``batch_size``, predictions whose
    confidence is below ``confidence_threshold`` are relabelled
    ``INCONCLUSIVE``, and per-label and meta-class distributions are computed.
    """

    def __init__(self, model_path=None, device='auto', batch_size=64, confidence_threshold=0.5):
        """Initialize book classifier.

        Args:
            model_path: Path to model file (auto-detected from config if None);
                forwarded unchanged to ``RetNetExplicitnessClassifier``.
            device: Device to use ('auto', 'cpu', 'cuda', 'mps'); forwarded
                unchanged to ``RetNetExplicitnessClassifier``.
            batch_size: Maximum number of paragraphs per ``classify_batch``
                call (default: 64). ``None`` or a value < 1 disables chunking
                and sends every paragraph in a single call.
            confidence_threshold: Predictions with confidence strictly below
                this value are relabelled INCONCLUSIVE (default: 0.5).
        """
        self.classifier = RetNetExplicitnessClassifier(model_path, device)
        self.batch_size = batch_size
        self.confidence_threshold = confidence_threshold

    def classify_book(self, paragraphs: List[str]) -> Dict:
        """Classify all paragraphs in a book with batched model calls.

        Args:
            paragraphs: List of paragraph strings.

        Returns:
            dict with ``book_stats``, ``explicitness_distribution``,
            ``meta_class_distribution`` and ``paragraph_results`` (the result
            dicts returned by the model, mutated in place for low-confidence
            entries), or ``{"error": "No paragraphs provided"}`` when
            ``paragraphs`` is empty.
        """
        if not paragraphs:
            return {"error": "No paragraphs provided"}

        # Progress goes to stderr so it cannot corrupt JSON written to stdout.
        print(f"π Classifying {len(paragraphs):,} paragraphs...", file=sys.stderr)
        start_time = time.perf_counter()

        results = self._classify_in_batches(paragraphs)
        self._apply_confidence_threshold(results)

        elapsed_time = time.perf_counter() - start_time
        # A zero-length interval is possible for tiny inputs; avoid dividing by it.
        paragraphs_per_sec = len(paragraphs) / elapsed_time if elapsed_time > 0 else 0.0

        stats = self._calculate_stats(results)
        inconclusive_count = sum(1 for r in results if r['predicted_class'] == 'INCONCLUSIVE')
        meta_stats = self._calculate_meta_stats(results)

        return {
            "book_stats": {
                "total_paragraphs": len(paragraphs),
                "processing_time_seconds": round(elapsed_time, 3),
                "paragraphs_per_second": round(paragraphs_per_sec, 1),
                "batch_size_used": self.batch_size,
                "confidence_threshold": self.confidence_threshold,
                "inconclusive_count": inconclusive_count,
                "conclusive_count": len(paragraphs) - inconclusive_count
            },
            "explicitness_distribution": stats,
            "meta_class_distribution": meta_stats,
            "paragraph_results": results
        }

    def _classify_in_batches(self, paragraphs: List[str]) -> List[Dict]:
        """Run the model over ``paragraphs`` in chunks of ``self.batch_size``.

        Results are concatenated in chunk order. NOTE(review): this assumes
        ``classify_batch`` returns one result per input, in input order (the
        rest of this class pairs results with paragraph positions).
        """
        size = self.batch_size if self.batch_size and self.batch_size > 0 else len(paragraphs)
        results: List[Dict] = []
        for start in range(0, len(paragraphs), size):
            results.extend(self.classifier.classify_batch(paragraphs[start:start + size]))
        return results

    def _apply_confidence_threshold(self, results: List[Dict]) -> None:
        """Relabel low-confidence results as INCONCLUSIVE, in place.

        The model's label and confidence are preserved under
        ``original_prediction`` / ``original_confidence``; ``confidence``
        itself is left unchanged.
        """
        for result in results:
            if result['confidence'] < self.confidence_threshold:
                result['original_prediction'] = result['predicted_class']
                result['original_confidence'] = result['confidence']
                result['predicted_class'] = 'INCONCLUSIVE'

    def classify_book_summary(self, paragraphs: List[str]) -> Dict:
        """Classify a book and return only the summary sections.

        Args:
            paragraphs: List of paragraph strings

        Returns:
            dict with ``book_stats`` and ``explicitness_distribution`` only,
            or the ``{"error": ...}`` dict from ``classify_book`` when there
            is nothing to classify.
        """
        results = self.classify_book(paragraphs)
        if "error" in results:
            # Empty input: there are no stats to extract.
            return results

        return {
            "book_stats": results["book_stats"],
            "explicitness_distribution": results["explicitness_distribution"]
        }

    def _calculate_stats(self, results: List[Dict]) -> Dict:
        """Count predictions per label, with percentages of all results.

        Known labels come first in a fixed display order; any other label the
        model emits is appended afterwards (first-seen order) so no count is
        ever dropped from the distribution.
        """
        counts: Dict[str, int] = {}
        for result in results:
            label = result['predicted_class']
            counts[label] = counts.get(label, 0) + 1

        total = len(results)
        label_order = [
            "NON-EXPLICIT", "SUGGESTIVE", "SEXUAL-REFERENCE",
            "EXPLICIT-SEXUAL", "EXPLICIT-OFFENSIVE", "EXPLICIT-VIOLENT",
            "EXPLICIT-DISCLAIMER", "INCONCLUSIVE"
        ]
        ordered_labels = [label for label in label_order if label in counts]
        ordered_labels += [label for label in counts if label not in label_order]

        # total > 0 whenever counts is non-empty, so the division is safe.
        return {
            label: {
                "count": counts[label],
                "percentage": round(100 * counts[label] / total, 2)
            }
            for label in ordered_labels
        }

    def _calculate_meta_stats(self, results: List[Dict]) -> Dict:
        """Count predictions per meta-class grouping.

        Groups overlap (e.g. EXPLICIT-SEXUAL belongs to SEXUAL, MATURE and
        EXPLICIT), so meta-class percentages need not sum to 100. An extra
        INCONCLUSIVE entry is always appended.
        """
        # NOTE(review): MATURE and EXPLICIT list identical labels — confirm
        # whether MATURE was intended to cover a different set.
        meta_classes = {
            'SAFE': ['NON-EXPLICIT'],
            'SEXUAL': ['SUGGESTIVE', 'SEXUAL-REFERENCE', 'EXPLICIT-SEXUAL'],
            'MATURE': ['EXPLICIT-SEXUAL', 'EXPLICIT-OFFENSIVE', 'EXPLICIT-VIOLENT'],
            'EXPLICIT': ['EXPLICIT-SEXUAL', 'EXPLICIT-OFFENSIVE', 'EXPLICIT-VIOLENT'],
            'WARNINGS': ['EXPLICIT-DISCLAIMER']
        }

        counts: Dict[str, int] = {}
        for result in results:
            label = result['predicted_class']
            counts[label] = counts.get(label, 0) + 1
        total = len(results)

        def _entry(count: int, includes: List[str]) -> Dict:
            return {
                "count": count,
                "percentage": round(100 * count / total, 2) if total > 0 else 0,
                "includes": includes
            }

        meta_stats = {
            meta_label: _entry(sum(counts.get(label, 0) for label in class_list), class_list)
            for meta_label, class_list in meta_classes.items()
        }
        meta_stats['INCONCLUSIVE'] = _entry(counts.get('INCONCLUSIVE', 0), ['INCONCLUSIVE'])
        return meta_stats

    def calculate_fun_stats(self, results: List[Dict]) -> Dict:
        """Pick illustrative examples from per-paragraph results.

        Returns a dict with:
          * ``strongest_examples`` / ``borderline_examples``: for each
            conclusive label, its highest- and lowest-confidence paragraph;
          * ``most_confused``: the lowest-confidence conclusive paragraph, or
            ``None`` when every paragraph is INCONCLUSIVE;
          * ``most_inconclusive``: up to three INCONCLUSIVE paragraphs with
            the lowest confidence.
        Paragraph numbers are 1-based positions in ``results``.
        """
        by_class: Dict[str, list] = {}
        inconclusive_examples = []
        for idx, result in enumerate(results):
            if result['predicted_class'] == 'INCONCLUSIVE':
                inconclusive_examples.append((idx, result))
            else:
                by_class.setdefault(result['predicted_class'], []).append((idx, result))

        def _example(idx: int, result: Dict) -> Dict:
            return {
                "text": result['text'],
                "confidence": result['confidence'],
                "paragraph_number": idx + 1
            }

        strongest: Dict[str, Dict] = {}
        borderline: Dict[str, Dict] = {}
        for label, class_results in by_class.items():
            # sorted() is stable even with reverse=True, so ties keep input order.
            ranked = sorted(class_results, key=lambda item: item[1]['confidence'], reverse=True)
            strongest[label] = _example(*ranked[0])
            borderline[label] = _example(*ranked[-1])

        most_confused = None
        conclusive = [(i, r) for i, r in enumerate(results) if r['predicted_class'] != 'INCONCLUSIVE']
        if conclusive:
            confused_idx, confused_result = min(conclusive, key=lambda item: item[1]['confidence'])
            most_confused = {
                "text": confused_result['text'],
                "predicted_class": confused_result['predicted_class'],
                "confidence": confused_result['confidence'],
                "paragraph_number": confused_idx + 1,
                "all_probabilities": confused_result['probabilities']
            }

        most_inconclusive = []
        for para_idx, result in sorted(inconclusive_examples, key=lambda item: item[1]['confidence'])[:3]:
            most_inconclusive.append({
                "text": result['text'],
                "confidence": result['confidence'],
                "paragraph_number": para_idx + 1,
                "original_prediction": result.get('original_prediction', 'UNKNOWN'),
                "all_probabilities": result['probabilities']
            })

        return {
            "strongest_examples": strongest,
            "borderline_examples": borderline,
            "most_confused": most_confused,
            "most_inconclusive": most_inconclusive
        }
|
|
|
|
|
|
|
|
def load_book_file(file_path: str) -> List[str]:
    """Load a book file and split it into paragraphs.

    The file is decoded as UTF-8 (a leading byte-order mark is dropped) and
    falls back to latin-1 when it is not valid UTF-8. If splitting on blank
    lines (empty, or containing only spaces/tabs) yields more than ten parts,
    the non-empty stripped blocks are the paragraphs. Otherwise the text is
    split into single lines and only stripped lines longer than 20 characters
    are kept.

    Args:
        file_path: Path to text file

    Returns:
        List of paragraph strings (stripped, non-empty).
    """
    import re  # local import: the module header does not import re

    try:
        # utf-8-sig also decodes plain UTF-8 but strips a leading BOM, which
        # would otherwise survive str.strip() into the first paragraph.
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            content = f.read()
    except UnicodeDecodeError:
        # latin-1 maps every byte to a character, so this fallback always decodes.
        with open(file_path, 'r', encoding='latin-1') as f:
            content = f.read()

    # Text mode uses universal newlines, so '\r\n' has already become '\n'.
    # Separator lines holding only spaces/tabs still count as paragraph breaks;
    # with truly empty separators this matches a plain split('\n\n').
    parts = re.split(r'\n[ \t]*\n', content)
    if len(parts) > 10:
        return [p.strip() for p in parts if p.strip()]

    # Too few blank-line breaks: treat each line as a candidate paragraph and
    # drop short ones (heuristic — presumably headings or page numbers).
    lines = (line.strip() for line in content.split('\n'))
    return [line for line in lines if len(line) > 20]
|
|
|
|
|
|
|
|
def main():
    """CLI interface for book classification.

    Parses arguments, loads and splits the book, classifies every paragraph
    and writes either a JSON dump or a human-readable summary to ``--output``
    (or stdout). Progress and error messages go to stderr so stdout carries
    only the requested output (e.g. valid JSON). Exits with status 1 when the
    file is missing or empty, on Ctrl-C, or on any other error.
    """
    # NOTE(review): the emoji prefixes in the message literals below look
    # mis-encoded (UTF-8 decoded as ISO-8859-7); restore the intended glyphs.
    parser = argparse.ArgumentParser(
        description="Classify explicitness levels in book text files",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python classify_book.py book.txt --summary
  python classify_book.py book.txt --format json --output results.json
  python classify_book.py book.txt --batch-size 32 --device cpu
"""
    )

    parser.add_argument('file', help='Path to book text file')
    parser.add_argument('--format', choices=['json', 'summary'], default='summary',
                        help='Output format (default: summary)')
    parser.add_argument('--output', '-o', help='Output file (default: stdout)')
    parser.add_argument('--batch-size', type=int, default=64,
                        help='Batch size for processing (default: 64)')
    parser.add_argument('--device', choices=['auto', 'cpu', 'cuda', 'mps'],
                        default='auto', help='Device to use (default: auto)')
    parser.add_argument('--summary', action='store_true',
                        help='Show only summary stats (faster)')
    parser.add_argument('--fun-stats', action='store_true',
                        help='Show strongest, most borderline, and most confused examples')
    parser.add_argument('--confidence-threshold', type=float, default=0.5,
                        help='Minimum confidence threshold (default: 0.5). Below this = INCONCLUSIVE')
    parser.add_argument('--show-meta-classes', action='store_true',
                        help='Show meta-class groupings (SAFE, SEXUAL, MATURE, etc.)')
    parser.add_argument('--export-fun-stats', type=str, metavar='FILE',
                        help='Export detailed fun-stats to JSON file (full text, no truncation)')

    args = parser.parse_args()

    if not Path(args.file).exists():
        print(f"β Error: File '{args.file}' not found", file=sys.stderr)
        sys.exit(1)

    # Summary-only output is the default; --fun-stats forces the full result.
    summary_only = (args.summary or args.format == 'summary') and not args.fun_stats

    try:
        print(f"π Loading book from '{args.file}'...", file=sys.stderr)
        paragraphs = load_book_file(args.file)
        print(f"π Found {len(paragraphs):,} paragraphs", file=sys.stderr)

        if len(paragraphs) == 0:
            print("β Error: No paragraphs found in file", file=sys.stderr)
            sys.exit(1)

        classifier = BookClassifier(
            batch_size=args.batch_size,
            device=args.device,
            confidence_threshold=args.confidence_threshold
        )

        # Always classify fully: the summary variant performs the same work and
        # only trims the result, while --export-fun-stats and
        # --show-meta-classes need the full data even in summary mode.
        results = classifier.classify_book(paragraphs)

        if args.fun_stats or args.export_fun_stats:
            results['fun_stats'] = classifier.calculate_fun_stats(results['paragraph_results'])

        if args.export_fun_stats:
            export_data = {
                'book_stats': results['book_stats'],
                'fun_stats': results['fun_stats'],
                'export_info': {
                    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
                    'confidence_threshold': args.confidence_threshold,
                    'note': 'Full text examples with no truncation'
                }
            }
            with open(args.export_fun_stats, 'w', encoding='utf-8') as f:
                json.dump(export_data, f, indent=2)
            print(f"π Fun stats exported to '{args.export_fun_stats}'", file=sys.stderr)

        if summary_only:
            summary = {
                "book_stats": results["book_stats"],
                "explicitness_distribution": results["explicitness_distribution"]
            }
            if args.show_meta_classes:
                summary["meta_class_distribution"] = results["meta_class_distribution"]
            results = summary

        if args.format == 'json':
            output = json.dumps(results, indent=2)
        else:
            output = format_summary_output(results)

        if args.output:
            # Explicit UTF-8: the text report contains non-ASCII characters.
            with open(args.output, 'w', encoding='utf-8') as f:
                f.write(output)
            print(f"π Results saved to '{args.output}'", file=sys.stderr)
        else:
            print(output)

    except KeyboardInterrupt:
        print("\nβ οΈ Classification interrupted by user", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        print(f"β Error: {e}", file=sys.stderr)
        sys.exit(1)
|
|
|
|
|
|
|
|
def format_summary_output(results: Dict) -> str: |
|
|
"""Format results as human-readable summary""" |
|
|
stats = results['book_stats'] |
|
|
dist = results['explicitness_distribution'] |
|
|
|
|
|
output = [] |
|
|
output.append("π Book Classification Results") |
|
|
output.append("=" * 50) |
|
|
output.append(f"π Total paragraphs: {stats['total_paragraphs']:,}") |
|
|
output.append(f"β‘ Processing time: {stats['processing_time_seconds']}s") |
|
|
output.append(f"π Speed: {stats['paragraphs_per_second']} paragraphs/sec") |
|
|
|
|
|
|
|
|
if 'confidence_threshold' in stats: |
|
|
threshold = stats['confidence_threshold'] |
|
|
inconclusive = stats.get('inconclusive_count', 0) |
|
|
conclusive = stats.get('conclusive_count', stats['total_paragraphs']) |
|
|
inconclusive_pct = 100 * inconclusive / stats['total_paragraphs'] |
|
|
|
|
|
output.append(f"π― Confidence threshold: {threshold:.1f}") |
|
|
output.append(f"β
Conclusive predictions: {conclusive:,} ({100-inconclusive_pct:.1f}%)") |
|
|
output.append(f"β Inconclusive predictions: {inconclusive:,} ({inconclusive_pct:.1f}%)") |
|
|
|
|
|
output.append("") |
|
|
|
|
|
output.append("π Explicitness Distribution:") |
|
|
output.append("-" * 30) |
|
|
|
|
|
for label, data in dist.items(): |
|
|
bar_length = int(data['percentage'] / 2) |
|
|
bar = "β" * bar_length |
|
|
output.append(f"{label:18} {data['count']:5,} ({data['percentage']:5.1f}%) {bar}") |
|
|
|
|
|
|
|
|
if 'meta_class_distribution' in results: |
|
|
meta_dist = results['meta_class_distribution'] |
|
|
output.append("") |
|
|
output.append("π·οΈ Meta-Class Distribution:") |
|
|
output.append("-" * 30) |
|
|
|
|
|
|
|
|
meta_order = ['SAFE', 'SEXUAL', 'MATURE', 'EXPLICIT', 'WARNINGS', 'INCONCLUSIVE'] |
|
|
|
|
|
for meta_label in meta_order: |
|
|
if meta_label in meta_dist: |
|
|
data = meta_dist[meta_label] |
|
|
if data['count'] > 0: |
|
|
bar_length = int(data['percentage'] / 2) |
|
|
bar = "β" * bar_length |
|
|
output.append(f"{meta_label:12} {data['count']:5,} ({data['percentage']:5.1f}%) {bar}") |
|
|
|
|
|
|
|
|
if 'fun_stats' in results: |
|
|
output.append("") |
|
|
output.append("π― Fun Stats:") |
|
|
output.append("=" * 50) |
|
|
|
|
|
fun_stats = results['fun_stats'] |
|
|
|
|
|
|
|
|
output.append("\nπ Strongest Examples (Highest Confidence):") |
|
|
output.append("-" * 45) |
|
|
for label, example in fun_stats['strongest_examples'].items(): |
|
|
output.append(f"\n{label} ({example['confidence']:.3f} confidence)") |
|
|
output.append(f" Paragraph #{example['paragraph_number']}: \"{example['text'][:250]}...\"") |
|
|
|
|
|
|
|
|
output.append("\nπ€ Most Borderline Examples (Lowest Confidence per Class):") |
|
|
output.append("-" * 55) |
|
|
for label, example in fun_stats['borderline_examples'].items(): |
|
|
output.append(f"\n{label} ({example['confidence']:.3f} confidence)") |
|
|
output.append(f" Paragraph #{example['paragraph_number']}: \"{example['text'][:250]}...\"") |
|
|
|
|
|
|
|
|
if fun_stats['most_confused']: |
|
|
confused = fun_stats['most_confused'] |
|
|
output.append(f"\nπ€― Most Confused Conclusive Paragraph ({confused['confidence']:.3f} confidence):") |
|
|
output.append("-" * 55) |
|
|
output.append(f"Paragraph #{confused['paragraph_number']}: \"{confused['text'][:250]}...\"") |
|
|
output.append(f"Predicted: {confused['predicted_class']}") |
|
|
|
|
|
|
|
|
output.append("All probabilities:") |
|
|
sorted_probs = sorted(confused['all_probabilities'].items(), |
|
|
key=lambda x: x[1], reverse=True) |
|
|
for label, prob in sorted_probs[:3]: |
|
|
output.append(f" {label}: {prob:.3f}") |
|
|
|
|
|
|
|
|
if fun_stats['most_inconclusive']: |
|
|
output.append(f"\nβ Most Inconclusive Examples:") |
|
|
output.append("-" * 35) |
|
|
for i, inc in enumerate(fun_stats['most_inconclusive']): |
|
|
output.append(f"\n{i+1}. Paragraph #{inc['paragraph_number']} ({inc['confidence']:.3f} confidence)") |
|
|
output.append(f" \"{inc['text'][:250]}...\"") |
|
|
output.append(f" Original prediction: {inc['original_prediction']}") |
|
|
|
|
|
return "\n".join(output) |
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run the CLI only when executed directly, not on import.
    main()