| |
| """Downsample all-O (no entity) sentences to ~10% of training/valid data.""" |
|
|
| import json |
| import random |
| from collections import Counter |
| from pathlib import Path |
|
|
# Seed for the random subsampling, so repeated runs produce the same output.
SEED: int = 42
# Desired share of all-O (no-entity) examples in the downsampled output.
TARGET_RATIO: float = 0.10
|
|
|
|
def has_entities(example: dict) -> bool:
    """Return True if *example* carries at least one entity span.

    ``example["spans"]`` is read as a mapping whose values are collections of
    spans (presumably one list per entity type — verify against the data).
    A missing or ``null`` ``spans`` field, ``null`` values, and empty
    collections all count as "no entity", i.e. an all-O example.
    """
    # `or {}` also covers an explicit JSON null, which previously crashed on
    # `.values()`.
    spans = example.get("spans") or {}
    # Skip null values instead of calling len(None).
    return any(len(v) > 0 for v in spans.values() if v is not None)
|
|
|
|
def _source_of(example: dict) -> str:
    """Return the example's ``info.source`` label, or ``"unknown"`` if absent."""
    return example.get("info", {}).get("source", "unknown")


def process_file(input_path: Path, output_path: Path) -> None:
    """Downsample the all-O examples of one JSONL file and write the result.

    Every example with at least one entity span (per ``has_entities``) is
    kept. All-O examples are subsampled with a seeded RNG (``SEED``),
    stratified by ``info.source``, so that they make up roughly
    ``TARGET_RATIO`` of the output. Per-source rounding means the final
    all-O count can land slightly below the target. Surviving examples
    keep their original line order.

    Args:
        input_path: JSONL file, one JSON object per non-blank line.
        output_path: Destination JSONL file; overwritten if it exists.
    """
    print(f"\n{'='*60}")
    print(f"Processing: {input_path.name}")
    print(f"{'='*60}")

    # Explicit UTF-8: the output is written with ensure_ascii=False, so the
    # locale's default encoding could fail on or mangle non-ASCII text.
    with open(input_path, encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) rather than crashing in
        # json.loads.
        data = [json.loads(line) for line in f if line.strip()]

    entity_examples = [ex for ex in data if has_entities(ex)]
    all_o_examples = [ex for ex in data if not has_entities(ex)]

    n_entity = len(entity_examples)
    n_all_o = len(all_o_examples)
    # Guard against an empty input file.
    orig_ratio = n_all_o / len(data) if data else 0.0

    print(f"Original: {len(data)} total, {n_entity} entity, {n_all_o} all-O ({orig_ratio:.1%})")

    # Group all-O examples by source in one pass. Insertion order matches
    # first appearance, the same order Counter.items() yields below.
    by_source = {}
    for ex in all_o_examples:
        by_source.setdefault(_source_of(ex), []).append(ex)

    source_counts = Counter(_source_of(ex) for ex in all_o_examples)
    print(f"\nAll-O by source:")
    for src, cnt in source_counts.most_common():
        print(f" {src}: {cnt}")

    # Solve n_all_o / (n_entity + n_all_o) == TARGET_RATIO for n_all_o.
    # For TARGET_RATIO = 0.10 this equals n_entity / 9.
    target_all_o = round(n_entity * TARGET_RATIO / (1 - TARGET_RATIO))
    print(f"\nTarget all-O count: {target_all_o}")

    if target_all_o >= n_all_o:
        print("Already at or below target — no downsampling needed.")
        kept = all_o_examples
    else:
        rng = random.Random(SEED)
        kept = []
        for src_examples in by_source.values():
            proportion = len(src_examples) / n_all_o
            # Keep at least one example per source so no source vanishes.
            n_keep = max(1, round(target_all_o * proportion))
            shuffled = list(src_examples)
            rng.shuffle(shuffled)
            kept.extend(shuffled[:n_keep])

        # The max(1, ...) floor can overshoot the target; trim at random.
        if len(kept) > target_all_o:
            rng.shuffle(kept)
            kept = kept[:target_all_o]

    # Select survivors by object identity, not info.id. Ids may be missing
    # (all None) or shared between sentences, and either case would
    # silently keep extra all-O examples. Iterating `data` preserves order.
    kept_refs = {id(ex) for ex in kept}
    result = [ex for ex in data if has_entities(ex) or id(ex) in kept_refs]

    final_all_o = sum(1 for ex in result if not has_entities(ex))
    final_ratio = final_all_o / len(result) if result else 0.0

    print(f"\nResult: {len(result)} total, {len(result) - final_all_o} entity, {final_all_o} all-O ({final_ratio:.1%})")
    print(f"Removed: {n_all_o - final_all_o} all-O examples")

    with open(output_path, "w", encoding="utf-8") as f:
        for ex in result:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")

    print(f"Written: {output_path}")
|
|
|
|
if __name__ == "__main__":
    base = Path("/home/ubuntu/alkyline/data/processed")

    # (input file, output file) pairs, processed in order: train, then valid.
    jobs = (
        (
            "enriched_5class_train_cleaned.jsonl",
            "enriched_5class_train_cleaned_trimmed.jsonl",
        ),
        (
            "enriched_5class_valid_cleaned.jsonl",
            "enriched_5class_valid_cleaned_trimmed.jsonl",
        ),
    )
    for src_name, dst_name in jobs:
        process_file(base / src_name, base / dst_name)
|
|