arcspan / scripts /fix_all_o_ratio.py
chairulridjal's picture
Add files using upload-large-folder tool
3dac39e verified
#!/usr/bin/env python3
"""Downsample all-O (no entity) sentences to ~10% of training/valid data."""
import json
import random
from collections import Counter
from pathlib import Path
# Fixed seed for random.Random so repeated runs pick the same all-O subset.
SEED: int = 42
# Desired share of all-O (entity-free) examples in the output: 10%.
TARGET_RATIO: float = 1 / 10
def has_entities(example: dict) -> bool:
    """Return True if *example* has at least one non-empty ``spans`` entry.

    ``example["spans"]`` is read as a mapping whose values are collections
    (presumably per-label span lists -- confirm against the data schema).
    A missing or null ``spans`` field, or null per-label values, count as
    "no entities" instead of raising.
    """
    spans = example.get("spans") or {}
    # For JSON lists/dicts/strings, truthiness is exactly "len(v) > 0";
    # it additionally treats null values as empty instead of raising.
    return any(v for v in spans.values())
def _source_of(example: dict) -> str:
    """Return the example's ``info.source`` label, or "unknown" if absent."""
    return example.get("info", {}).get("source", "unknown")


def _sample_all_o(by_source, n_all_o, target, rng):
    """Pick exactly ``target`` indices from ``by_source``, proportionally per source.

    Args:
        by_source: mapping of source label -> list of example indices.
        n_all_o: total number of indices across all sources (> target).
        target: number of indices to return.
        rng: ``random.Random`` used for every shuffle (determinism).

    Returns:
        A list of ``target`` indices. Each source first contributes
        ``max(1, round(target * share))`` picks. The pooled picks are then
        randomly trimmed, or padded from the unpicked remainder, so the
        final length is exactly ``target``.
    """
    kept = []
    leftover = []
    for indices in by_source.values():
        proportion = len(indices) / n_all_o
        n_keep = max(1, round(target * proportion))
        pool = list(indices)
        rng.shuffle(pool)
        kept.extend(pool[:n_keep])
        leftover.extend(pool[n_keep:])
    if len(kept) > target:
        # max(1, ...) and rounding up can overshoot: drop a random surplus.
        rng.shuffle(kept)
        kept = kept[:target]
    elif len(kept) < target:
        # Rounding down can undershoot: top up from examples not yet picked.
        rng.shuffle(leftover)
        kept.extend(leftover[:target - len(kept)])
    return kept


def process_file(input_path: Path, output_path: Path):
    """Downsample all-O examples in one JSONL file and write the result.

    Every example for which ``has_entities`` is true is kept. All-O examples
    are reduced, if needed, so they make up ``TARGET_RATIO`` of the output.
    They are sampled per ``info.source`` in proportion to each source's share
    of the all-O pool, using ``random.Random(SEED)``. Surviving examples keep
    their original file order. Progress statistics are printed to stdout.

    Args:
        input_path: JSONL file to read (one JSON object per line; blank
            lines are ignored).
        output_path: JSONL file to write (overwritten), UTF-8 encoded.
    """
    print(f"\n{'='*60}")
    print(f"Processing: {input_path.name}")
    print(f"{'='*60}")

    with open(input_path, encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing empty line) instead of crashing.
        data = [json.loads(line) for line in f if line.strip()]

    # Classify each example once and work with positions, so membership in
    # the output does not depend on examples carrying a unique info.id.
    is_entity = [has_entities(ex) for ex in data]
    all_o_idx = [i for i, flag in enumerate(is_entity) if not flag]
    n_all_o = len(all_o_idx)
    n_entity = len(data) - n_all_o
    orig_ratio = n_all_o / len(data) if data else 0.0
    print(f"Original: {len(data)} total, {n_entity} entity, {n_all_o} all-O ({orig_ratio:.1%})")

    # Source distribution of all-O examples (insertion order = first appearance).
    by_source = {}
    for i in all_o_idx:
        by_source.setdefault(_source_of(data[i]), []).append(i)
    source_counts = Counter({src: len(idx) for src, idx in by_source.items()})
    print("\nAll-O by source:")
    for src, cnt in source_counts.most_common():
        print(f" {src}: {cnt}")

    # Solve T / (n_entity + T) == TARGET_RATIO  =>  T = n_entity * r / (1 - r)
    # (n_entity / 9 for r = 0.10).
    target_all_o = round(n_entity * TARGET_RATIO / (1 - TARGET_RATIO))
    print(f"\nTarget all-O count: {target_all_o}")

    if target_all_o >= n_all_o:
        print("Already at or below target — no downsampling needed.")
        kept = all_o_idx
    else:
        kept = _sample_all_o(by_source, n_all_o, target_all_o, random.Random(SEED))

    # Combine while preserving the original file order.
    kept_set = set(kept)
    result = [ex for i, ex in enumerate(data) if is_entity[i] or i in kept_set]

    final_all_o = len(kept_set)
    # result can be empty (empty file, or no entity examples => target 0).
    final_ratio = final_all_o / len(result) if result else 0.0
    print(f"\nResult: {len(result)} total, {len(result) - final_all_o} entity, {final_all_o} all-O ({final_ratio:.1%})")
    print(f"Removed: {n_all_o - final_all_o} all-O examples")

    # ensure_ascii=False emits raw non-ASCII, so pin the file encoding.
    with open(output_path, "w", encoding="utf-8") as f:
        for ex in result:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")
    print(f"Written: {output_path}")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description=__doc__)
    # Default preserves the original hard-coded location; override it to
    # run on another machine or directory.
    parser.add_argument(
        "--data-dir",
        type=Path,
        default=Path("/home/ubuntu/alkyline/data/processed"),
        help="Directory containing enriched_5class_<split>_cleaned.jsonl files",
    )
    args = parser.parse_args()
    base = args.data_dir
    for split in ("train", "valid"):
        process_file(
            base / f"enriched_5class_{split}_cleaned.jsonl",
            base / f"enriched_5class_{split}_cleaned_trimmed.jsonl",
        )