#!/usr/bin/env python3 """Downsample all-O (no entity) sentences to ~10% of training/valid data.""" import json import random from collections import Counter from pathlib import Path SEED = 42 TARGET_RATIO = 0.10 # 10% all-O def has_entities(example: dict) -> bool: spans = example.get("spans", {}) return any(len(v) > 0 for v in spans.values()) def process_file(input_path: Path, output_path: Path): print(f"\n{'='*60}") print(f"Processing: {input_path.name}") print(f"{'='*60}") with open(input_path) as f: data = [json.loads(line) for line in f] entity_examples = [ex for ex in data if has_entities(ex)] all_o_examples = [ex for ex in data if not has_entities(ex)] n_entity = len(entity_examples) n_all_o = len(all_o_examples) orig_ratio = n_all_o / len(data) print(f"Original: {len(data)} total, {n_entity} entity, {n_all_o} all-O ({orig_ratio:.1%})") # Source distribution of all-O source_counts = Counter(ex.get("info", {}).get("source", "unknown") for ex in all_o_examples) print(f"\nAll-O by source:") for src, cnt in source_counts.most_common(): print(f" {src}: {cnt}") # Target: N_entity / 9 all-O examples → 10% of total target_all_o = round(n_entity / 9) print(f"\nTarget all-O count: {target_all_o}") if target_all_o >= n_all_o: print("Already at or below target — no downsampling needed.") kept = all_o_examples else: # Proportional downsampling by source rng = random.Random(SEED) kept = [] for src, cnt in source_counts.items(): proportion = cnt / n_all_o n_keep = max(1, round(target_all_o * proportion)) src_examples = [ex for ex in all_o_examples if ex.get("info", {}).get("source", "unknown") == src] rng.shuffle(src_examples) kept.extend(src_examples[:n_keep]) # Trim or pad to exact target if len(kept) > target_all_o: rng.shuffle(kept) kept = kept[:target_all_o] # Combine and preserve original order by re-scanning kept_ids = {ex.get("info", {}).get("id") for ex in kept} result = [ex for ex in data if has_entities(ex) or ex.get("info", {}).get("id") in kept_ids] final_all_o = sum(1 for ex in result if not has_entities(ex)) final_ratio = final_all_o / len(result) print(f"\nResult: {len(result)} total, {len(result) - final_all_o} entity, {final_all_o} all-O ({final_ratio:.1%})") print(f"Removed: {n_all_o - final_all_o} all-O examples") with open(output_path, "w") as f: for ex in result: f.write(json.dumps(ex, ensure_ascii=False) + "\n") print(f"Written: {output_path}") if __name__ == "__main__": base = Path("/home/ubuntu/alkyline/data/processed") process_file( base / "enriched_5class_train_cleaned.jsonl", base / "enriched_5class_train_cleaned_trimmed.jsonl", ) process_file( base / "enriched_5class_valid_cleaned.jsonl", base / "enriched_5class_valid_cleaned_trimmed.jsonl", )