arcspan / scripts /build_r9_dataset.py
chairulridjal's picture
Add files using upload-large-folder tool
3dac39e verified
#!/usr/bin/env python3
"""Build the R9 merged training dataset from R8 + CyberNER (harmonized) + DNRTI (deleaked).

R9 keeps the R8 validation split fixed, so the train sources must first be
filtered against the validation split and the benchmark test sets (by exact
text and by 80-character text prefix) before train-only deduplication.
"""
import json, shutil
from collections import Counter
# Training inputs, merged in this order (the same order drives the reports).
SOURCES = dict(
    r8_train="/home/ubuntu/alkyline/data/processed/r8_5class_train.jsonl",
    cyberner="/home/ubuntu/alkyline/data/processed/cyberner_harmonized_5class_deleaked.jsonl",
    dnrti="/home/ubuntu/alkyline/data/processed/dnrti_5class_deleaked.jsonl",
)

# R8 validation is reused unchanged as R9 validation.
VALID_IN = "/home/ubuntu/alkyline/data/processed/r8_5class_valid.jsonl"
TRAIN_OUT = "/home/ubuntu/alkyline/data/processed/r9_5class_train.jsonl"
VALID_OUT = "/home/ubuntu/alkyline/data/processed/r9_5class_valid.jsonl"

# Held-out files whose texts (exact or by prefix) must not leak into R9 train.
EXCLUDE_FILES = dict(
    r8_valid=VALID_IN,
    enriched_test="/home/ubuntu/alkyline/data/processed/enriched_5class_test.jsonl",
    cyner_test="/home/ubuntu/alkyline/data/processed/cyner_test.jsonl",
    securebert2_test="/home/ubuntu/alkyline/data/processed/securebert2_5class_test.jsonl",
    aptner_test="/home/ubuntu/alkyline/data/processed/aptner_5class_test_clean.jsonl",
)
def load_lines(path, encoding="utf-8"):
    """Return the non-blank lines of the text file at *path*.

    Lines are returned verbatim (trailing newline included); empty and
    whitespace-only lines are skipped.  The file is decoded as UTF-8 by
    default so the result does not depend on the process locale (the JSONL
    inputs may contain raw non-ASCII text).
    """
    with open(path, encoding=encoding) as f:
        return [line for line in f if line.strip()]
def load_text_sets(paths, prefix_len=80):
    """Collect the texts of several JSONL files for leakage filtering.

    Args:
        paths: mapping of a display name to a JSONL path.  Every non-blank
            line must be a JSON object with a ``"text"`` field (a missing
            field raises ``KeyError``).
        prefix_len: number of leading characters used for the prefix keys
            (default 80).

    Returns:
        ``(exact, prefix, counts)``: the set of full texts, the set of
        ``text[:prefix_len]`` keys, and a dict mapping each name to its number
        of non-blank lines.
    """
    exact = set()
    prefix = set()
    counts = {}
    for name, path in paths.items():
        lines = load_lines(path)
        counts[name] = len(lines)
        for line in lines:
            text = json.loads(line)["text"]
            exact.add(text)
            prefix.add(text[:prefix_len])
    return exact, prefix, counts
def sanitize_spans(record):
    """Drop malformed/zero-length spans from one OPF-format record.

    ``record["spans"]`` is expected to be a dict mapping a key to a list of
    ``[start, end]`` offset pairs.  A pair is kept only when it is a 2-element
    list/tuple whose values convert to ``int`` and satisfy
    ``0 <= start < end <= len(record["text"])``.  Kept pairs are normalized to
    ``[int, int]`` lists, and keys left without any pair are dropped.

    Records whose ``spans`` is missing, falsy, or not a dict are returned
    unchanged.  The input record is never mutated: when the spans change, a
    shallow copy with a rebuilt ``spans`` dict is returned.

    Returns:
        ``(record, removed)`` where ``removed`` counts dropped offset entries;
        a non-list value stored under a key counts as one removed entry.
    """
    spans = record.get("spans") or {}
    if not isinstance(spans, dict):
        return record, 0
    text_len = len(record.get("text", ""))
    cleaned = {}
    removed = 0
    for key, offsets in spans.items():
        if not isinstance(offsets, (list, tuple)):
            # None/scalars cannot be interpreted as offsets; drop instead of
            # letting the iteration below crash the whole build.
            removed += 1
            continue
        kept_offsets = []
        for offset in offsets:
            if not isinstance(offset, (list, tuple)) or len(offset) != 2:
                removed += 1
                continue
            try:
                start, end = int(offset[0]), int(offset[1])
            except (TypeError, ValueError):
                # Non-numeric offsets are malformed, not fatal.
                removed += 1
                continue
            if 0 <= start < end <= text_len:
                kept_offsets.append([start, end])
            else:
                removed += 1
        if kept_offsets:
            cleaned[key] = kept_offsets
    # Write back whenever the normalized spans differ from the input, not only
    # when something was removed; otherwise string offsets or empty offset
    # lists would be emitted unnormalized.
    if removed or cleaned != spans:
        record = dict(record)
        record["spans"] = cleaned
    return record, removed
# Load every train source, tagging each raw JSONL line with its source name.
source_counts = {}
all_examples = []
for name, path in SOURCES.items():
    source_lines = load_lines(path)
    source_counts[name] = len(source_lines)
    all_examples.extend((name, raw) for raw in source_lines)

print("=== Per-source counts (before dedup) ===")
for name, count in source_counts.items():
    print(f" {name}: {count}")
print(f" TOTAL: {len(all_examples)}")
# Exclude validation/test overlaps before train-only dedup.  A train example is
# dropped when its full text equals a held-out text, or when its first 80
# characters equal the first 80 characters of any held-out text.
exclude_exact, exclude_prefix, exclude_counts = load_text_sets(EXCLUDE_FILES)
print("\n=== Exclusion sets ===")
for name, count in exclude_counts.items():
    print(f" {name}: {count}")
print(f" exact text keys: {len(exclude_exact)}")
print(f" prefix-80 keys: {len(exclude_prefix)}")

filtered_examples = []
excluded_exact = Counter()
excluded_prefix = Counter()
for name, line in all_examples:
    text = json.loads(line)["text"]
    if text in exclude_exact:
        excluded_exact[name] += 1
    elif text[:80] in exclude_prefix:
        excluded_prefix[name] += 1
    else:
        filtered_examples.append((name, line))

# Tally survivors per source once, instead of rescanning the whole filtered
# list for every source inside the report loop.
kept_for_dedup = Counter(src for src, _ in filtered_examples)
print("\n=== Excluded against validation/test ===")
for name in SOURCES:
    print(
        f" {name}: exact={excluded_exact[name]} "
        f"prefix={excluded_prefix[name]} kept_for_dedup={kept_for_dedup[name]}"
    )
print(f" Total excluded: {sum(excluded_exact.values()) + sum(excluded_prefix.values())}")
# Deduplicate the surviving train records by 80-character text prefix: the
# first record seen for a prefix is kept, later ones are dropped and counted
# against their source.
seen = set()
kept = []
dedup_removed = Counter()
for name, line in filtered_examples:
    prefix_key = json.loads(line)["text"][:80]
    if prefix_key not in seen:
        seen.add(prefix_key)
        kept.append((name, line))
        continue
    dedup_removed[name] += 1

print("\n=== Dedup stats ===")
for name in SOURCES:
    print(f" {name}: removed {dedup_removed[name]}")
print(f" Total removed: {sum(dedup_removed.values())}")
print(f" Total kept: {len(kept)}")
# Write the train split.  Each record gets an "info" dict with its source name
# (an existing "source" value is kept), its spans are sanitized, and per-class
# entity totals are tallied from the spans that were actually written.
entity_counts = Counter()
source_kept = Counter()
invalid_spans_removed = Counter()
# Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters, which would
# raise UnicodeEncodeError under a non-UTF-8 default locale.
with open(TRAIN_OUT, "w", encoding="utf-8") as f:
    for name, line in kept:
        obj = json.loads(line)
        info = obj.get("info")
        if not isinstance(info, dict):
            info = {}
        info.setdefault("source", name)
        obj["info"] = info
        obj, removed = sanitize_spans(obj)
        invalid_spans_removed[name] += removed
        f.write(json.dumps(obj, ensure_ascii=False) + "\n")
        source_kept[name] += 1
        spans = obj.get("spans", {})
        if isinstance(spans, dict):
            # Format: {"Label: text": [[start, end], ...], ...}
            for key, offsets in spans.items():
                label = key.split(":")[0].strip()
                entity_counts[label] += len(offsets)
        elif isinstance(spans, list):
            # Format: [{"start": int, "end": int, "label": str}, ...]
            for span in spans:
                entity_counts[span["label"]] += 1

print("\n=== Per-source counts (after dedup) ===")
for name in SOURCES:
    print(f" {name}: {source_kept[name]}")
print("\n=== Invalid spans removed from train output ===")
for name in SOURCES:
    print(f" {name}: {invalid_spans_removed[name]}")
print(f"\n=== R9 Train: {len(kept)} examples ===")
print("Entity counts per class:")
for label, count in sorted(entity_counts.items()):
    print(f" {label}: {count}")
print(f" TOTAL entities: {sum(entity_counts.values())}")
# Reuse the R8 validation file unchanged as R9 validation.
shutil.copy2(VALID_IN, VALID_OUT)
# Count only non-blank lines, matching how load_lines counts every other file,
# so a trailing blank line is not reported as an example.
with open(VALID_OUT, encoding="utf-8") as f:
    valid_count = sum(1 for line in f if line.strip())
print(f"\n=== R9 Valid: {valid_count} examples (copied from R8) ===")
print(f"\nOutput: {TRAIN_OUT}")
print(f" {VALID_OUT}")