| |
| """Augment training data with defanged indicator variants.""" |
|
|
import difflib
import json
import random
import re
import sys
from pathlib import Path
|
|
# Seed the module-level RNG once so that example sampling and the random
# defang-style choices below are reproducible from run to run.
random.seed(42)

# All inputs/outputs live under one processed-data directory.
_DATA_DIR = Path("data/processed")

INPUT = _DATA_DIR / "enriched_5class_train_cleaned.jsonl"  # JSONL records to read
OUTPUT = _DATA_DIR / "defanged_augmented.jsonl"  # JSONL augmented records to write
SAMPLE_RATE = 0.30  # fraction of indicator-bearing examples to augment
|
|
|
|
def classify_indicator(value: str) -> str:
    """Classify an indicator value as ``"ip"``, ``"domain"``, ``"url"`` or ``"other"``.

    Rules, checked in order:

    * ``"url"`` -- the value contains ``"://"`` or starts with ``hxxp`` in any
      letter case (so already-defanged schemes like ``HXXP[:]//`` count too).
    * ``"ip"`` -- the value consists only of digits, dots, brackets,
      parentheses and spaces and has at least three dots (plain or
      bracket-defanged dotted IPv4).
    * ``"domain"`` -- any other value containing a dot.
    * ``"other"`` -- everything else (no dot at all).
    """
    if "://" in value or value.lower().startswith("hxxp"):
        return "url"
    if value.count(".") >= 3 and all(c in "0123456789.[]() " for c in value):
        return "ip"
    # The former extra "not all hex digits" condition was always true here
    # ("." itself is not a hex digit), so a dot alone decides "domain".
    if "." in value:
        return "domain"
    return "other"
|
|
|
|
def defang_ip(text: str) -> str:
    """Return *text* with every dot swapped for one randomly chosen defang token.

    A single draw per call picks either ``"[.]"`` or ``" [ . ] "``; that
    token then replaces all dots in *text*.
    """
    # Same two-element choice as a style-name lookup would make, so the RNG
    # is consumed identically.
    dot_token = random.choice(["[.]", " [ . ] "])
    return text.replace(".", dot_token)
|
|
|
|
def defang_domain(text: str) -> str:
    """Return *text* with every dot swapped for one randomly chosen defang token.

    A single draw per call picks either ``"[.]"`` or ``"[ . ]"``; that token
    then replaces all dots in *text*.
    """
    # Same two-element choice as a style-name lookup would make, so the RNG
    # is consumed identically.
    dot_token = random.choice(["[.]", "[ . ]"])
    return text.replace(".", dot_token)
|
|
|
|
def defang_url(text: str) -> str:
    """Defang a URL: rewrite the http(s) scheme and bracket the dots in its host.

    Every ``https://`` / ``http://`` in *text* becomes ``hxxps://`` /
    ``hxxp://``. If the result contains ``://``, the dots in the part right
    after the first ``://`` -- up to the first ``/``, ``?`` or ``#`` -- are
    replaced with one randomly chosen token, ``"[.]"`` or ``"[ . ]"``. Path,
    query and fragment are left untouched. Text without ``://`` only gets the
    scheme rewrite (and no random draw is made).
    """
    result = text.replace("https://", "hxxps://").replace("http://", "hxxp://")

    proto_end = result.find("://")
    if proto_end < 0:
        return result

    host_start = proto_end + 3
    # The host ends at the first path, query or fragment delimiter; stopping
    # only at "/" would also defang dots in e.g. "http://a.com?v=1.2".
    host_end = next(
        (i for i in range(host_start, len(result)) if result[i] in "/?#"),
        len(result),
    )
    dot_token = random.choice(["[.]", "[ . ]"])
    host = result[host_start:host_end].replace(".", dot_token)
    return result[:host_start] + host + result[host_end:]
|
|
|
|
def defang_span_text(text: str, indicator_value: str) -> str | None:
    """Defang *text* using the helper that matches *indicator_value*'s type.

    *indicator_value* is classified with ``classify_indicator``. ``"ip"``,
    ``"domain"`` and ``"url"`` are routed to ``defang_ip``, ``defang_domain``
    and ``defang_url`` respectively, and that helper's output is returned.
    Any other type is treated as not defangable and yields ``None``.
    """
    handler = {
        "ip": defang_ip,
        "domain": defang_domain,
        "url": defang_url,
    }.get(classify_indicator(indicator_value))
    return None if handler is None else handler(text)
|
|
|
|
def _offset_maps(old: str, new: str) -> tuple[list[int], list[int]]:
    """Align *old* with its rewritten form *new* character by character.

    Returns ``(start_map, end_map)``, each of length ``len(old) + 1``:
    ``start_map[k]`` is where a span that starts just before ``old[k]``
    should start in *new*, and ``end_map[k]`` is where a span that ends just
    after ``old[k - 1]`` should end. Text present only in *new* (e.g.
    inserted brackets or spaces) therefore stays outside both kinds of
    boundary.
    """
    start_map = [0] * (len(old) + 1)
    end_map = [0] * (len(old) + 1)
    start_map[len(old)] = len(new)
    matcher = difflib.SequenceMatcher(None, old, new, autojunk=False)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "insert":
            # Pure insertions own no original character; leaving them out
            # places the inserted text between neighbouring end/start marks.
            continue
        width_old, width_new = i2 - i1, j2 - j1
        for k in range(i1, i2):
            # Exact for "equal" and same-length "replace" blocks (such as
            # "tt" -> "xx"); proportional for anything else.
            rel = k - i1
            start_map[k] = j1 + rel * width_new // width_old
            end_map[k + 1] = j1 + (rel + 1) * width_new // width_old
    return start_map, end_map


def augment_example(example: dict) -> dict | None:
    """Create a defanged copy of an example.

    Each ``Indicator:<value>`` span whose text ``defang_span_text`` changes is
    rewritten in the copy. The offsets of every span (indicator or not) are
    remapped onto the new text.

    Args:
        example: A record with ``"text"`` (str) and ``"spans"`` (label ->
            iterable of ``(start, end)`` character offsets into ``text``).
            An optional ``"info"`` dict is copied into the result.

    Returns:
        A new ``{"text", "spans", "info"}`` record whose ``info["source"]`` is
        ``"defanged_augment"``, or ``None`` when the example has no indicator
        spans or none of them changed.
    """
    text = example["text"]
    spans = example["spans"]

    # Collect every indicator occurrence as (start, end, indicator_value).
    indicator_spans = []
    for label, positions in spans.items():
        if not label.startswith("Indicator:"):
            continue
        # Accept both "Indicator: x" and "Indicator:x".
        indicator_value = label.partition(":")[2].strip()
        indicator_spans.extend((start, end, indicator_value) for start, end in positions)
    if not indicator_spans:
        return None

    # Rewrite only non-overlapping spans, outermost first (earliest start, then
    # longest). An overlapping one -- e.g. a domain inside a URL, or the same
    # span listed under two labels -- would otherwise be spliced in twice and
    # corrupt the text; its offsets are remapped through the enclosing
    # replacement instead.
    indicator_spans.sort(key=lambda s: (s[0], -s[1]))
    # Each entry: (old_start, old_end, new_text, start_map, end_map).
    replacements = []
    for start, end, indicator_value in indicator_spans:
        if replacements and start < replacements[-1][1]:
            continue
        old_text = text[start:end]
        new_text = defang_span_text(old_text, indicator_value)
        if new_text and new_text != old_text:
            replacements.append((start, end, new_text, *_offset_maps(old_text, new_text)))
    if not replacements:
        return None

    # Splice the replacements (sorted and disjoint) into the text.
    parts = []
    prev_end = 0
    for old_start, old_end, new_text, _, _ in replacements:
        parts.append(text[prev_end:old_start])
        parts.append(new_text)
        prev_end = old_end
    parts.append(text[prev_end:])
    new_full_text = "".join(parts)

    def adjust_offset(pos: int, *, is_end: bool) -> int:
        """Map an original offset onto the rewritten text."""
        shift = 0
        for old_start, old_end, new_text, start_map, end_map in replacements:
            if pos <= old_start:
                break
            if pos >= old_end:
                shift += len(new_text) - (old_end - old_start)
                continue
            # Strictly inside a rewritten span: use the character alignment
            # rather than a proportional guess.
            inner = pos - old_start
            mapped = end_map[inner] if is_end else start_map[inner]
            return old_start + shift + mapped
        return pos + shift

    new_spans = {}
    for label, positions in spans.items():
        remapped = []
        for start, end in positions:
            new_start = adjust_offset(start, is_end=False)
            # max() only matters for an empty span sitting on inserted text,
            # where the start and end maps point to opposite sides of it.
            new_end = max(new_start, adjust_offset(end, is_end=True))
            remapped.append([new_start, new_end])
        new_spans[label] = remapped

    return {
        "text": new_full_text,
        "spans": new_spans,
        "info": {**example.get("info", {}), "source": "defanged_augment"},
    }
|
|
|
|
def main():
    """Build the defanged augmentation set from INPUT and write it to OUTPUT.

    Loads JSONL records from INPUT and randomly samples a SAMPLE_RATE fraction
    of the records that carry at least one ``Indicator:`` span. Each sample is
    defanged with ``augment_example``, and the successful results are written
    to OUTPUT as JSONL. Prints summary counts, then the indicator spans of the
    first three augmented records for a manual spot check.
    """
    # Explicit UTF-8 on both ends: records are written with ensure_ascii=False,
    # so the platform default encoding (e.g. on Windows) could fail or corrupt
    # non-ASCII text.
    with open(INPUT, encoding="utf-8") as f:
        # Skip blank lines such as a trailing newline; json.loads("") raises.
        examples = [json.loads(line) for line in f if line.strip()]

    indicator_examples = [
        ex for ex in examples
        if any(k.startswith("Indicator:") for k in ex["spans"])
    ]

    print(f"Total examples: {len(examples)}")
    print(f"Examples with Indicator spans: {len(indicator_examples)}")

    # Draws from the module-level RNG, which is seeded at import time.
    sampled = random.sample(indicator_examples, int(len(indicator_examples) * SAMPLE_RATE))
    print(f"Sampled for augmentation: {len(sampled)}")

    augmented = []
    defanged_count = 0
    for ex in sampled:
        result = augment_example(ex)
        if result is None:
            continue
        augmented.append(result)
        defanged_count += sum(
            len(positions)
            for label, positions in result["spans"].items()
            if label.startswith("Indicator:")
        )

    print(f"Successfully augmented: {len(augmented)}")
    print(f"Total indicator spans in augmented data: {defanged_count}")

    # The output directory may not exist yet on a fresh checkout.
    OUTPUT.parent.mkdir(parents=True, exist_ok=True)
    with open(OUTPUT, "w", encoding="utf-8") as f:
        for ex in augmented:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")

    print(f"Written to {OUTPUT}")

    print("\n=== Sample verification ===")
    for ex in augmented[:3]:
        print(f"\nText: {ex['text'][:120]}...")
        for label, positions in ex["spans"].items():
            if label.startswith("Indicator:"):
                for s, e in positions:
                    print(f"  {label}: '{ex['text'][s:e]}' [{s}:{e}]")
|
|
|
|
# Script entry point: run the augmentation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|