| |
| """ |
| Class balance audit for Arcspan NER datasets. |
| Analyzes entity distribution across fixed/deleaked training sets. |
| """ |
| import json |
| from collections import defaultdict |
| from pathlib import Path |
|
|
# Entity classes the audit reports on, in display order.
CLASSES = [
    "Malware",
    "Indicator",
    "Organization",
    "System",
    "Vulnerability",
]

# Directory that holds every processed JSONL training set audited below.
_PROCESSED_DIR = "/home/ubuntu/alkyline/data/processed"

# JSONL files to audit, in report order.
FILES = [
    f"{_PROCESSED_DIR}/{_jsonl_name}"
    for _jsonl_name in (
        "enriched_5class_train_cleaned_trimmed.jsonl",
        "enriched_5class_train_cleaned_deleaked.jsonl",
        "aptner_5class_train_deleaked.jsonl",
        "securebert2_5class_train_deleaked.jsonl",
        "defanged_augmented.jsonl",
    )
]
|
|
def _class_name(label_key):
    """Return the entity class encoded in a ``spans`` key.

    A key containing ``": "`` maps to the text before the first ``": "``
    (presumably keys look like ``"<Class>: <mention>"`` -- confirm against the
    dataset writer); any other key is taken to be the class name itself.
    """
    head, sep, _ = label_key.partition(": ")
    return head if sep else label_key


def _span_count(offsets):
    """Return how many spans a ``spans`` value represents.

    A non-empty list whose first element is a list (``[[s, e], ...]``) counts
    one span per element; any other non-empty list (e.g. a single ``[s, e]``
    pair) counts as one span; anything else (empty list, non-list) counts 0.
    """
    if not isinstance(offsets, list) or not offsets:
        return 0
    return len(offsets) if isinstance(offsets[0], list) else 1


def analyze_file(filepath):
    """Analyze a single JSONL file.

    Each non-blank line is parsed as one JSON object; its ``"spans"`` mapping
    is tallied per entity class.

    Args:
        filepath: Path (str or ``Path``) of the JSONL file.

    Returns:
        None if ``filepath`` does not exist; otherwise a dict with:
        ``total_examples`` (records counted), ``class_counts``
        (``defaultdict(int)`` of spans per class name), ``all_o_examples``
        (records contributing zero spans) and ``total_spans``.

    Malformed lines (invalid JSON, non-object records, non-dict ``spans``)
    are reported with a warning and skipped rather than aborting the audit.
    """
    path = Path(filepath)
    if not path.exists():
        return None

    stats = {
        "total_examples": 0,
        "class_counts": defaultdict(int),
        "all_o_examples": 0,
        "total_spans": 0,
    }

    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines are not records; don't report them as JSON errors.
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                print(f" ⚠ JSON error in {path.name}: {e}")
                continue
            if not isinstance(record, dict):
                print(f" ⚠ Non-object record in {path.name}: skipped")
                continue

            # A missing/None/empty "spans" value means no labelled entities.
            spans = record.get("spans") or {}
            if not isinstance(spans, dict):
                print(f" ⚠ Unexpected 'spans' type in {path.name}: "
                      f"{type(spans).__name__}, skipped")
                continue

            stats["total_examples"] += 1
            example_spans = 0
            for label_key, offsets in spans.items():
                count = _span_count(offsets)
                if count:
                    stats["class_counts"][_class_name(label_key)] += count
                    example_spans += count

            stats["total_spans"] += example_spans
            # All-O means the example contributes no entity spans at all,
            # including when every span list is empty.
            if example_spans == 0:
                stats["all_o_examples"] += 1

    return stats
|
|
def format_report(filename, stats):
    """Render one stats dict as a multi-line text report.

    Args:
        filename: Label shown on the report's first line.
        stats: A stats dict (keys ``total_examples``, ``all_o_examples``,
            ``total_spans``, ``class_counts``), or None for a missing file.

    Returns:
        Report text ending in a newline. For None ``stats`` this is a single
        "NOT FOUND" line. Only classes listed in ``CLASSES`` are shown.
    """
    if stats is None:
        return f" ✗ {filename}: NOT FOUND\n"

    n_examples = stats["total_examples"]
    n_all_o = stats["all_o_examples"]
    pct_all_o = (100.0 * n_all_o / n_examples) if n_examples > 0 else 0

    # .get avoids inserting zero entries into a defaultdict.
    per_class = stats["class_counts"]
    ordered = [(name, per_class.get(name, 0)) for name in CLASSES]

    # Zero-count classes are ignored; with fewer than two populated classes
    # there is nothing to compare, so the ratio is reported as 1.0.
    present = [n for _, n in ordered if n > 0]
    ratio = max(present) / min(present) if len(present) >= 2 else 1.0

    report = [
        f" {filename}",
        f" Examples: {n_examples:,} | All-O: {n_all_o:,} ({pct_all_o:.1f}%)",
        f" Total spans: {stats['total_spans']:,} | Imbalance ratio: {ratio:.2f}x",
    ]
    report.extend(f" {name}: {n:,}" for name, n in ordered)
    return "\n".join(report) + "\n"
|
|
| |
# Report header: title rule, title, audited classes, closing rule, blank line.
_HEADER_RULE = "=" * 80
print(
    _HEADER_RULE,
    "ARCSPAN NER CLASS BALANCE AUDIT",
    f"Classes: {', '.join(CLASSES)}",
    _HEADER_RULE,
    "",
    sep="\n",
)
|
|
# Per-file stats keyed by file name (None when the file was not found).
all_stats = {}
# Running totals across every file that was found.
combined = {
    "total_examples": 0,
    "class_counts": defaultdict(int),
    "all_o_examples": 0,
    "total_spans": 0,
}

for filepath in FILES:
    filename = Path(filepath).name
    stats = analyze_file(filepath)
    all_stats[filename] = stats

    # format_report renders both a found file and the "NOT FOUND" line.
    print(format_report(filename, stats))
    if stats is None:
        continue

    for key in ("total_examples", "all_o_examples", "total_spans"):
        combined[key] += stats[key]
    # Merge every class (not only CLASSES) so combined["class_counts"] stays
    # consistent with combined["total_spans"], which counts all spans.
    for cls, count in stats["class_counts"].items():
        combined["class_counts"][cls] += count
|
|
# Combined section: banner, aggregated report, then overall extremes.
print("", "=" * 80, "COMBINED TOTAL (all files)", "=" * 80, sep="\n")
print(format_report("COMBINED", combined))

combined_class_counts = {cls: combined["class_counts"][cls] for cls in CLASSES}
nonzero = [n for n in combined_class_counts.values() if n > 0]
# A ratio needs at least two populated classes to be meaningful.
if len(nonzero) >= 2:
    most_common, least_common = max(nonzero), min(nonzero)
    combined_imbalance = most_common / least_common
    print(
        f" Overall imbalance ratio: {combined_imbalance:.2f}x\n"
        f" Most common: {most_common:,} | Least common: {least_common:,}"
    )
|
|