#!/usr/bin/env python3
"""
Class balance audit for Arcspan NER datasets.

Analyzes entity distribution across fixed/deleaked training sets.

Each input file is JSONL: one JSON object per line with an optional
``"spans"`` mapping.  Keys of that mapping are either a bare class name or
``"Label: entity_text"``; values are either a list of offset pairs
(one entity per pair) or a single flat offset list (one entity).
"""

import json
from collections import defaultdict
from pathlib import Path

CLASSES = ["Malware", "Indicator", "Organization", "System", "Vulnerability"]

FILES = [
    "/home/ubuntu/alkyline/data/processed/enriched_5class_train_cleaned_trimmed.jsonl",
    "/home/ubuntu/alkyline/data/processed/enriched_5class_train_cleaned_deleaked.jsonl",
    "/home/ubuntu/alkyline/data/processed/aptner_5class_train_deleaked.jsonl",
    "/home/ubuntu/alkyline/data/processed/securebert2_5class_train_deleaked.jsonl",
    "/home/ubuntu/alkyline/data/processed/defanged_augmented.jsonl",
]


def _new_stats():
    """Return an empty statistics accumulator (same shape analyze_file returns)."""
    return {
        "total_examples": 0,
        "class_counts": defaultdict(int),
        "all_o_examples": 0,
        "total_spans": 0,
    }


def _count_offsets(offsets):
    """Return how many entities one ``spans`` value encodes.

    A list of lists counts one entity per inner list; any other non-empty
    list is treated as a single entity; everything else counts as zero.
    """
    if not isinstance(offsets, list) or not offsets:
        return 0
    return len(offsets) if isinstance(offsets[0], list) else 1


def analyze_file(filepath):
    """Analyze a single JSONL file.

    Args:
        filepath: Path (str or Path) of the JSONL file to read.

    Returns:
        ``None`` if the file does not exist; otherwise a dict with keys
        ``total_examples``, ``class_counts`` (defaultdict of class name ->
        entity count), ``all_o_examples`` (records with no entities) and
        ``total_spans``.

    Lines that are blank are ignored.  Lines that are not valid JSON, are
    not JSON objects, or carry a ``spans`` value that is not a mapping are
    reported on stdout and skipped without being counted as examples.
    """
    path = Path(filepath)
    if not path.exists():
        return None

    stats = _new_stats()

    # Explicit UTF-8: entity text may be non-ASCII and the locale default
    # encoding is not guaranteed to decode it.
    with open(path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # Blank lines are not records; don't report them as errors.
                continue

            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                print(f" ⚠ JSON error in {path.name}: {e}")
                continue

            if not isinstance(record, dict):
                print(f" ⚠ Skipping non-object record in {path.name}")
                continue

            spans = record.get("spans", {})
            if spans and not isinstance(spans, dict):
                print(f" ⚠ Skipping record with non-mapping spans in {path.name}")
                continue

            stats["total_examples"] += 1

            # Count entities by class for this example.
            example_spans = 0
            for label_key, offsets in (spans or {}).items():
                # Key format is "Label: entity_text" or just the class name;
                # partition leaves a bare class name untouched.
                class_name = label_key.partition(": ")[0]
                count = _count_offsets(offsets)
                if count:
                    stats["class_counts"][class_name] += count
                    example_spans += count

            stats["total_spans"] += example_spans

            # All-O means "no entities", which also covers a spans mapping
            # whose values are all empty.
            if example_spans == 0:
                stats["all_o_examples"] += 1

    return stats


def format_report(filename, stats):
    """Format stats for a single file.

    Args:
        filename: Label printed as the report heading.
        stats: Dict as returned by ``analyze_file``, or ``None`` for a
            missing file.

    Returns:
        A newline-terminated, multi-line report string.  The imbalance ratio
        is max/min over the classes in ``CLASSES`` that have a non-zero
        count, or 1.0 when fewer than two classes are present.
    """
    if stats is None:
        return f" ✗ {filename}: NOT FOUND\n"

    total = stats["total_examples"]
    all_o_pct = 100.0 * stats["all_o_examples"] / total if total > 0 else 0

    # Get min/max for imbalance ratio (zero-count classes are excluded so the
    # ratio never divides by zero).
    class_counts = {c: stats["class_counts"].get(c, 0) for c in CLASSES}
    nonzero_counts = [n for n in class_counts.values() if n > 0]
    if len(nonzero_counts) < 2:
        imbalance_ratio = 1.0
    else:
        imbalance_ratio = max(nonzero_counts) / min(nonzero_counts)

    lines = [
        f" {filename}",
        f" Examples: {total:,} | All-O: {stats['all_o_examples']:,} ({all_o_pct:.1f}%)",
        f" Total spans: {stats['total_spans']:,} | Imbalance ratio: {imbalance_ratio:.2f}x",
    ]
    lines.extend(f" {cls}: {class_counts[cls]:,}" for cls in CLASSES)

    return "\n".join(lines) + "\n"


def main():
    """Run the audit over ``FILES`` and print per-file and combined reports."""
    print("=" * 80)
    print("ARCSPAN NER CLASS BALANCE AUDIT")
    print(f"Classes: {', '.join(CLASSES)}")
    print("=" * 80)
    print()

    combined = _new_stats()

    for filepath in FILES:
        filename = Path(filepath).name
        stats = analyze_file(filepath)

        # format_report renders the NOT FOUND line itself when stats is None.
        print(format_report(filename, stats))
        if stats is None:
            continue

        combined["total_examples"] += stats["total_examples"]
        combined["all_o_examples"] += stats["all_o_examples"]
        combined["total_spans"] += stats["total_spans"]
        for cls in CLASSES:
            # .get avoids inserting zero entries into the per-file defaultdict.
            combined["class_counts"][cls] += stats["class_counts"].get(cls, 0)

    print("\n" + "=" * 80)
    print("COMBINED TOTAL (all files)")
    print("=" * 80)
    print(format_report("COMBINED", combined))

    # Class imbalance for combined
    nonzero = [
        n for n in (combined["class_counts"].get(c, 0) for c in CLASSES) if n > 0
    ]
    if len(nonzero) >= 2:
        combined_imbalance = max(nonzero) / min(nonzero)
        print(f" Overall imbalance ratio: {combined_imbalance:.2f}x")
        print(f" Most common: {max(nonzero):,} | Least common: {min(nonzero):,}")


if __name__ == "__main__":
    main()