#!/usr/bin/env python3 """Augment training data with defanged indicator variants.""" import json import random import re import sys from pathlib import Path random.seed(42) INPUT = Path("data/processed/enriched_5class_train_cleaned.jsonl") OUTPUT = Path("data/processed/defanged_augmented.jsonl") SAMPLE_RATE = 0.30 def classify_indicator(value: str) -> str: """Classify an indicator value as ip, domain, url, or other.""" if "://" in value or value.startswith("hxxp"): return "url" # IP: all digits and dots, at least 3 dots if all(c in "0123456789.[]() " for c in value) and value.count(".") >= 3: return "ip" if "." in value and not all(c in "0123456789abcdef" for c in value): return "domain" return "other" def defang_ip(text: str) -> str: """Defang an IP address with random style.""" style = random.choice(["bracket", "bracket_space"]) if style == "bracket": return text.replace(".", "[.]") else: return text.replace(".", " [ . ] ") def defang_domain(text: str) -> str: """Defang a domain with random style.""" style = random.choice(["bracket", "bracket_space"]) if style == "bracket": return text.replace(".", "[.]") else: return text.replace(".", "[ . ]") def defang_url(text: str) -> str: """Defang a URL: protocol + domain part.""" result = text.replace("https://", "hxxps://").replace("http://", "hxxp://") # Also defang dots in the domain portion (before first /) proto_end = result.find("://") if proto_end >= 0: after_proto = proto_end + 3 slash_pos = result.find("/", after_proto) if slash_pos == -1: slash_pos = len(result) domain_part = result[after_proto:slash_pos] style = random.choice(["bracket", "bracket_space"]) if style == "bracket": domain_defanged = domain_part.replace(".", "[.]") else: domain_defanged = domain_part.replace(".", "[ . ]") result = result[:after_proto] + domain_defanged + result[slash_pos:] return result def defang_span_text(text: str, indicator_value: str) -> str | None: """Defang the text of a span. Returns new text or None if not defangable.""" itype = classify_indicator(indicator_value) if itype == "ip": return defang_ip(text) elif itype == "domain": return defang_domain(text) elif itype == "url": return defang_url(text) return None def augment_example(example: dict) -> dict | None: """Create a defanged copy of an example. Returns None if nothing to defang.""" text = example["text"] spans = example["spans"] # Collect all indicator spans with positions, sorted by start offset indicator_spans = [] for label, positions in spans.items(): if label.startswith("Indicator:"): indicator_value = label.split(": ", 1)[1] for start, end in positions: indicator_spans.append((start, end, label, indicator_value)) if not indicator_spans: return None # Sort by start position indicator_spans.sort(key=lambda x: x[0]) # Try to defang each indicator span, track replacements replacements = [] # (old_start, old_end, new_text) for start, end, label, indicator_value in indicator_spans: old_text = text[start:end] new_text = defang_span_text(old_text, indicator_value) if new_text and new_text != old_text: replacements.append((start, end, new_text)) if not replacements: return None # Build new text and offset mapping # Process replacements from end to start to not mess up offsets # But we need a forward pass to compute cumulative offset shifts # Compute cumulative offset adjustments # For each position in original text, compute how much it shifts shifts = [] # (original_pos, delta) - at original_pos, cumulative delta changes cumulative = 0 for old_start, old_end, new_text in replacements: old_len = old_end - old_start new_len = len(new_text) delta = new_len - old_len shifts.append((old_start, old_end, delta, cumulative)) cumulative += delta # Build new text new_text_parts = [] prev_end = 0 for old_start, old_end, new_text in replacements: new_text_parts.append(text[prev_end:old_start]) new_text_parts.append(new_text) prev_end = old_end new_text_parts.append(text[prev_end:]) new_full_text = "".join(new_text_parts) # Adjust all span offsets def adjust_offset(pos: int) -> int: """Adjust an original offset to account for replacements.""" cum = 0 for old_start, old_end, new_text in replacements: old_len = old_end - old_start new_len = len(new_text) delta = new_len - old_len if pos <= old_start: break elif pos >= old_end: cum += delta else: # pos is inside a replacement - scale proportionally # This handles the span endpoints that ARE the replacement frac = (pos - old_start) / old_len cum += int(frac * delta) break return pos + cum new_spans = {} for label, positions in spans.items(): new_positions = [] for start, end in positions: new_start = adjust_offset(start) new_end = adjust_offset(end) new_positions.append([new_start, new_end]) new_spans[label] = new_positions new_example = { "text": new_full_text, "spans": new_spans, "info": {**example.get("info", {}), "source": "defanged_augment"}, } return new_example def main(): examples = [] with open(INPUT) as f: for line in f: examples.append(json.loads(line)) # Find examples with Indicator spans indicator_examples = [ ex for ex in examples if any(k.startswith("Indicator:") for k in ex["spans"]) ] print(f"Total examples: {len(examples)}") print(f"Examples with Indicator spans: {len(indicator_examples)}") # Sample 30% sampled = random.sample(indicator_examples, int(len(indicator_examples) * SAMPLE_RATE)) print(f"Sampled for augmentation: {len(sampled)}") augmented = [] defanged_count = 0 for ex in sampled: result = augment_example(ex) if result: augmented.append(result) # Count defanged indicators for label in result["spans"]: if label.startswith("Indicator:"): defanged_count += len(result["spans"][label]) print(f"Successfully augmented: {len(augmented)}") print(f"Total indicator spans in augmented data: {defanged_count}") with open(OUTPUT, "w") as f: for ex in augmented: f.write(json.dumps(ex, ensure_ascii=False) + "\n") print(f"Written to {OUTPUT}") # Verify a few examples print("\n=== Sample verification ===") for ex in augmented[:3]: print(f"\nText: {ex['text'][:120]}...") for label, positions in ex["spans"].items(): if label.startswith("Indicator:"): for s, e in positions: print(f" {label}: '{ex['text'][s:e]}' [{s}:{e}]") if __name__ == "__main__": main()