#!/usr/bin/env python3
"""Build R9 merged training dataset from R8 + CyberNER harmonized + DNRTI deleaked.

R9 keeps R8 validation fixed, so train sources must be filtered against
validation and benchmark tests before train-only deduplication.

Pipeline (see ``main``):
    1. Load every train source listed in ``SOURCES``.
    2. Drop train records whose text (exact, or first ``PREFIX_LEN`` chars)
       matches any record in ``EXCLUDE_FILES``.
    3. Deduplicate the remaining train records by text prefix.
    4. Sanitize spans, tag each record with its source and write ``TRAIN_OUT``.
    5. Copy the R8 validation file to ``VALID_OUT`` unchanged.
"""

import json
import shutil
from collections import Counter

SOURCES = {
    "r8_train": "/home/ubuntu/alkyline/data/processed/r8_5class_train.jsonl",
    "cyberner": "/home/ubuntu/alkyline/data/processed/cyberner_harmonized_5class_deleaked.jsonl",
    "dnrti": "/home/ubuntu/alkyline/data/processed/dnrti_5class_deleaked.jsonl",
}
VALID_IN = "/home/ubuntu/alkyline/data/processed/r8_5class_valid.jsonl"
TRAIN_OUT = "/home/ubuntu/alkyline/data/processed/r9_5class_train.jsonl"
VALID_OUT = "/home/ubuntu/alkyline/data/processed/r9_5class_valid.jsonl"
EXCLUDE_FILES = {
    "r8_valid": VALID_IN,
    "enriched_test": "/home/ubuntu/alkyline/data/processed/enriched_5class_test.jsonl",
    "cyner_test": "/home/ubuntu/alkyline/data/processed/cyner_test.jsonl",
    "securebert2_test": "/home/ubuntu/alkyline/data/processed/securebert2_5class_test.jsonl",
    "aptner_test": "/home/ubuntu/alkyline/data/processed/aptner_5class_test_clean.jsonl",
}

# Number of leading characters of "text" used as the near-duplicate key.
PREFIX_LEN = 80


def load_lines(path):
    """Return the non-blank lines of the file at ``path``.

    Lines are returned as read (trailing newline included); lines that are
    empty or whitespace-only are skipped. The file is decoded as UTF-8 so the
    result does not depend on the locale of the machine running the build.
    """
    with open(path, encoding="utf-8") as f:
        return [line for line in f if line.strip()]


def load_text_sets(paths):
    """Collect the texts of every record in the given JSONL files.

    Args:
        paths: mapping of a display name to a JSONL file path. Every
            non-blank line must be a JSON object with a ``"text"`` key.

    Returns:
        ``(exact, prefix, counts)`` where ``exact`` is the set of full texts,
        ``prefix`` is the set of their first ``PREFIX_LEN`` characters, and
        ``counts`` maps each name to its number of non-blank lines.

    Raises:
        KeyError: if a record has no ``"text"`` key.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    exact = set()
    prefix = set()
    counts = {}
    for name, path in paths.items():
        lines = load_lines(path)
        counts[name] = len(lines)
        for line in lines:
            text = json.loads(line)["text"]
            exact.add(text)
            prefix.add(text[:PREFIX_LEN])
    return exact, prefix, counts


def _coerce_offset(offset, text_len):
    """Return ``[start, end]`` for a valid span offset, else ``None``.

    Valid means: a 2-item list/tuple whose items convert with ``int()`` and
    satisfy ``0 <= start < end <= text_len``.
    """
    if not isinstance(offset, (list, tuple)) or len(offset) != 2:
        return None
    try:
        start, end = int(offset[0]), int(offset[1])
    except (TypeError, ValueError, OverflowError):
        # Non-numeric offsets (None, "abc", inf, ...) are malformed, not fatal.
        return None
    if 0 <= start < end <= text_len:
        return [start, end]
    return None


def sanitize_spans(record):
    """Drop malformed/zero-length spans from one OPF-format record.

    ``record["spans"]`` is expected to be a dict mapping a key to a list of
    ``[start, end]`` offsets. Offsets that are not a 2-item pair, are not
    convertible to integers, or fall outside ``0 <= start < end <= len(text)``
    are dropped. A key whose value is not a list/tuple at all counts as one
    removed span. Keys left with no valid offsets are omitted from the result.

    Args:
        record: the parsed record. It is never mutated.

    Returns:
        ``(record, removed)``. When nothing was removed (or ``spans`` is
        missing, empty, or not a dict) the *same* record object is returned
        with ``removed == 0``; otherwise a shallow copy with the cleaned
        ``spans`` dict is returned.
    """
    spans = record.get("spans") or {}
    if not isinstance(spans, dict):
        return record, 0

    text_len = len(record.get("text", ""))
    cleaned = {}
    removed = 0
    for key, offsets in spans.items():
        if not isinstance(offsets, (list, tuple)):
            # e.g. null or a scalar where a list of offsets was expected.
            removed += 1
            continue
        kept_offsets = []
        for offset in offsets:
            pair = _coerce_offset(offset, text_len)
            if pair is None:
                removed += 1
            else:
                kept_offsets.append(pair)
        if kept_offsets:
            cleaned[key] = kept_offsets

    if removed:
        # Copy so the caller's record is left untouched.
        record = dict(record)
        record["spans"] = cleaned
    return record, removed


def _filter_excluded(examples, exclude_exact, exclude_prefix):
    """Split off examples whose text collides with the exclusion sets.

    Returns ``(kept, hit_exact, hit_prefix)``; the two Counters are keyed by
    source name. An exact match is counted as exact only, never as prefix.
    """
    kept = []
    hit_exact = Counter()
    hit_prefix = Counter()
    for name, record in examples:
        text = record["text"]
        if text in exclude_exact:
            hit_exact[name] += 1
        elif text[:PREFIX_LEN] in exclude_prefix:
            hit_prefix[name] += 1
        else:
            kept.append((name, record))
    return kept, hit_exact, hit_prefix


def _dedup_by_prefix(examples):
    """Keep the first example for each text prefix, in input order.

    Returns ``(kept, removed)`` where ``removed`` is a Counter of dropped
    examples per source name.
    """
    seen = set()
    kept = []
    removed = Counter()
    for name, record in examples:
        key = record["text"][:PREFIX_LEN]
        if key in seen:
            removed[name] += 1
        else:
            seen.add(key)
            kept.append((name, record))
    return kept, removed


def _count_entities(spans, entity_counts):
    """Add the entities found in ``spans`` to ``entity_counts`` (in place)."""
    if isinstance(spans, dict):
        # Format: {"Label: text": [[start, end], ...], ...}
        for key, offsets in spans.items():
            label = key.split(":")[0].strip()
            entity_counts[label] += len(offsets)
    elif isinstance(spans, list):
        # Format: [{"start": int, "end": int, "label": str}, ...]
        for span in spans:
            entity_counts[span["label"]] += 1


def main():
    """Run the full R9 build and print per-stage statistics."""
    # Load all sources. Each line is parsed once here and the parsed record
    # is carried through every later stage.
    all_examples = []
    source_counts = {}
    for name, path in SOURCES.items():
        lines = load_lines(path)
        source_counts[name] = len(lines)
        all_examples.extend((name, json.loads(line)) for line in lines)

    print("=== Per-source counts (before dedup) ===")
    for name, count in source_counts.items():
        print(f" {name}: {count}")
    print(f" TOTAL: {len(all_examples)}")

    # Exclude validation/test overlaps before train-only dedup.
    exclude_exact, exclude_prefix, exclude_counts = load_text_sets(EXCLUDE_FILES)
    print("\n=== Exclusion sets ===")
    for name, count in exclude_counts.items():
        print(f" {name}: {count}")
    print(f" exact text keys: {len(exclude_exact)}")
    print(f" prefix-{PREFIX_LEN} keys: {len(exclude_prefix)}")

    filtered_examples, excluded_exact, excluded_prefix = _filter_excluded(
        all_examples, exclude_exact, exclude_prefix
    )
    kept_for_dedup = Counter(name for name, _ in filtered_examples)

    print("\n=== Excluded against validation/test ===")
    for name in SOURCES:
        print(
            f" {name}: exact={excluded_exact[name]} "
            f"prefix={excluded_prefix[name]} kept_for_dedup={kept_for_dedup[name]}"
        )
    print(f" Total excluded: {sum(excluded_exact.values()) + sum(excluded_prefix.values())}")

    # Dedup remaining train records by text prefix.
    kept, dedup_removed = _dedup_by_prefix(filtered_examples)

    print("\n=== Dedup stats ===")
    for name in SOURCES:
        print(f" {name}: removed {dedup_removed[name]}")
    print(f" Total removed: {sum(dedup_removed.values())}")
    print(f" Total kept: {len(kept)}")

    # Write train
    entity_counts = Counter()
    source_kept = Counter()
    invalid_spans_removed = Counter()
    with open(TRAIN_OUT, "w", encoding="utf-8") as f:
        for name, obj in kept:
            # Tag the record with its source without overwriting an
            # existing "source" value.
            info = obj.get("info")
            if not isinstance(info, dict):
                info = {}
            info.setdefault("source", name)
            obj["info"] = info

            obj, removed = sanitize_spans(obj)
            invalid_spans_removed[name] += removed
            f.write(json.dumps(obj, ensure_ascii=False) + "\n")
            source_kept[name] += 1
            _count_entities(obj.get("spans", {}), entity_counts)

    print("\n=== Per-source counts (after dedup) ===")
    for name in SOURCES:
        print(f" {name}: {source_kept[name]}")

    print("\n=== Invalid spans removed from train output ===")
    for name in SOURCES:
        print(f" {name}: {invalid_spans_removed[name]}")

    print(f"\n=== R9 Train: {len(kept)} examples ===")
    print("Entity counts per class:")
    for label, count in sorted(entity_counts.items()):
        print(f" {label}: {count}")
    print(f" TOTAL entities: {sum(entity_counts.values())}")

    # Copy validation
    shutil.copy2(VALID_IN, VALID_OUT)
    # Count the same way as every other file (blank lines are not examples).
    valid_count = len(load_lines(VALID_IN))
    print(f"\n=== R9 Valid: {valid_count} examples (copied from R8) ===")

    print(f"\nOutput: {TRAIN_OUT}")
    print(f" {VALID_OUT}")


if __name__ == "__main__":
    main()