arcspan / scripts /audit_class_balance.py
chairulridjal's picture
Add files using upload-large-folder tool
3dac39e verified
#!/usr/bin/env python3
"""
Class balance audit for Arcspan NER datasets.
Analyzes entity distribution across fixed/deleaked training sets.
"""
import json
from collections import defaultdict
from pathlib import Path
# Entity classes reported by the audit, in display order.
CLASSES = [
    "Malware",
    "Indicator",
    "Organization",
    "System",
    "Vulnerability",
]

# Directory holding every processed training set audited below.
_PROCESSED_DIR = "/home/ubuntu/alkyline/data/processed"

# Training-set JSONL files to audit, in report order.
FILES = [
    _PROCESSED_DIR + "/" + basename
    for basename in (
        "enriched_5class_train_cleaned_trimmed.jsonl",
        "enriched_5class_train_cleaned_deleaked.jsonl",
        "aptner_5class_train_deleaked.jsonl",
        "securebert2_5class_train_deleaked.jsonl",
        "defanged_augmented.jsonl",
    )
]
def _count_example_spans(spans):
    """Return a defaultdict mapping class name -> span count for one record.

    Label keys are either ``"Class: entity text"`` or a bare class name; the
    class is the text before the first ``": "``. An offsets value that is a
    list of lists counts one span per inner list; any other non-empty list
    counts as a single span. Empty or non-list values contribute nothing.
    """
    counts = defaultdict(int)
    for label_key, offsets in spans.items():
        # split(..., 1)[0] yields the whole key when ": " is absent.
        class_name = label_key.split(": ", 1)[0]
        if isinstance(offsets, list) and offsets:
            counts[class_name] += len(offsets) if isinstance(offsets[0], list) else 1
    return counts


def analyze_file(filepath):
    """Analyze a single JSONL file of NER records.

    Each non-blank line should be a JSON object whose optional ``"spans"``
    field maps label keys to offset lists. Blank lines are skipped silently.
    Lines that are not valid JSON, are not JSON objects, or carry a non-dict
    ``"spans"`` value are reported with their line number and skipped.

    Args:
        filepath: Path (str or ``Path``) to the JSONL file.

    Returns:
        None if the file does not exist. Otherwise a dict with:
          - ``total_examples``: number of records analyzed,
          - ``class_counts``: defaultdict(int) of class name -> span count,
          - ``all_o_examples``: records that contributed zero entity spans,
          - ``total_spans``: spans over all classes, including classes
            outside ``CLASSES``.
    """
    path = Path(filepath)
    if not path.exists():
        return None

    stats = {
        "total_examples": 0,
        "class_counts": defaultdict(int),
        "all_o_examples": 0,
        "total_spans": 0,
    }

    # Explicit UTF-8 so decoding does not depend on the machine's locale.
    with open(path, "r", encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                # E.g. a trailing newline at EOF: not a record, not an error.
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                print(f" ⚠ JSON error in {path.name} line {lineno}: {e}")
                continue
            if not isinstance(record, dict):
                print(f" ⚠ Non-object record in {path.name} line {lineno}; skipped")
                continue

            # A missing/null/empty "spans" value means the record has no entities.
            spans = record.get("spans") or {}
            if not isinstance(spans, dict):
                print(f" ⚠ Unexpected 'spans' type in {path.name} line {lineno}; skipped")
                continue

            stats["total_examples"] += 1
            example_counts = _count_example_spans(spans)
            example_total = sum(example_counts.values())
            # All-O means no entity spans were counted, even if label keys exist
            # with empty offset lists.
            if example_total == 0:
                stats["all_o_examples"] += 1
            for class_name, n in example_counts.items():
                stats["class_counts"][class_name] += n
            stats["total_spans"] += example_total
    return stats
def format_report(filename, stats):
    """Render a text summary block for one set of stats.

    Args:
        filename: Label printed on the first line of the block.
        stats: Stats dict (keys ``total_examples``, ``all_o_examples``,
            ``total_spans``, ``class_counts``), or None for a missing file.

    Returns:
        A multi-line string ending in a newline. The imbalance ratio is the
        largest over the smallest count among classes in ``CLASSES`` that
        occur at least once. It is 1.0 when fewer than two such classes occur.
    """
    if stats is None:
        return f" ✗ {filename}: NOT FOUND\n"

    n_examples = stats["total_examples"]
    n_all_o = stats["all_o_examples"]
    pct_all_o = (100.0 * n_all_o / n_examples) if n_examples > 0 else 0

    per_class = stats["class_counts"]
    counts = [per_class.get(cls, 0) for cls in CLASSES]
    # Zero-count classes are left out of the ratio (avoids division by zero).
    present = sorted(n for n in counts if n > 0)
    ratio = present[-1] / present[0] if len(present) >= 2 else 1.0

    out = [
        f" {filename}",
        f" Examples: {n_examples:,} | All-O: {n_all_o:,} ({pct_all_o:.1f}%)",
        f" Total spans: {stats['total_spans']:,} | Imbalance ratio: {ratio:.2f}x",
    ]
    out.extend(f" {cls}: {n:,}" for cls, n in zip(CLASSES, counts))
    return "\n".join(out) + "\n"
# Main: audit every file in FILES, print per-file reports, then a combined summary.
print("=" * 80)
print("ARCSPAN NER CLASS BALANCE AUDIT")
print(f"Classes: {', '.join(CLASSES)}")
print("=" * 80)
print()

# Per-file stats keyed by basename (None for files that do not exist).
all_stats = {}
# Running totals across every file that was found.
combined = {
    "total_examples": 0,
    "class_counts": defaultdict(int),
    "all_o_examples": 0,
    "total_spans": 0,
}
for filepath in FILES:
    filename = Path(filepath).name
    stats = analyze_file(filepath)
    all_stats[filename] = stats
    # format_report renders the " ✗ ... NOT FOUND" line itself when stats is None.
    print(format_report(filename, stats))
    if stats is None:
        continue
    for key in ("total_examples", "all_o_examples", "total_spans"):
        combined[key] += stats[key]
    for cls in CLASSES:
        # .get avoids inserting zero entries into the per-file defaultdict.
        combined["class_counts"][cls] += stats["class_counts"].get(cls, 0)

print("\n" + "=" * 80)
print("COMBINED TOTAL (all files)")
print("=" * 80)
print(format_report("COMBINED", combined))

# Class imbalance for combined, naming the most/least frequent classes.
# Zero-count classes cannot enter the ratio (division by zero), so they are
# reported separately instead of silently disappearing from the summary.
combined_class_counts = {c: combined["class_counts"][c] for c in CLASSES}
nonzero = {c: n for c, n in combined_class_counts.items() if n > 0}
missing = [c for c in CLASSES if c not in nonzero]
if len(nonzero) >= 2:
    most_common = max(nonzero, key=nonzero.get)
    least_common = min(nonzero, key=nonzero.get)
    combined_imbalance = nonzero[most_common] / nonzero[least_common]
    print(f" Overall imbalance ratio: {combined_imbalance:.2f}x")
    print(
        f" Most common: {most_common} ({nonzero[most_common]:,})"
        f" | Least common: {least_common} ({nonzero[least_common]:,})"
    )
if missing:
    print(f" ⚠ Classes with no spans (excluded from ratio): {', '.join(missing)}")