arcspan / scripts /augment_defanged.py
chairulridjal's picture
Add files using upload-large-folder tool
3dac39e verified
#!/usr/bin/env python3
"""Augment training data with defanged indicator variants."""
import json
import random
import re
import sys
from pathlib import Path
# Seed the module-level RNG so the sample drawn in main() and the defang
# styles picked through random.choice are reproducible from run to run.
random.seed(42)
# File main() reads the examples from, one JSON object per line.
INPUT = Path("data/processed/enriched_5class_train_cleaned.jsonl")
# File main() writes the augmented examples to, one JSON object per line.
OUTPUT = Path("data/processed/defanged_augmented.jsonl")
# Fraction of the indicator-bearing examples that main() samples for augmentation.
SAMPLE_RATE = 0.30
def classify_indicator(value: str) -> str:
    """Return the coarse kind of an indicator: "url", "ip", "domain" or "other".

    The checks run in order and the first match wins:

    * "url"    - contains "://" or starts with "hxxp".
    * "ip"     - made up only of digits, dots, brackets, parentheses and
                 spaces, with at least three dots.
    * "domain" - contains a dot (and is not purely lowercase-hex characters).
    * "other"  - everything else.
    """
    looks_like_url = value.startswith("hxxp") or "://" in value
    if looks_like_url:
        return "url"

    # Brackets, parentheses and spaces are allowed so that defanged
    # spellings such as "1[.]2[.]3[.]4" still count as IPs.
    ip_alphabet = set("0123456789.[]() ")
    if value.count(".") >= 3 and ip_alphabet.issuperset(value):
        return "ip"

    # "." is not a hex character, so the second test is implied by the
    # first; it is kept to mirror the stated intent of excluding hex strings.
    hex_alphabet = set("0123456789abcdef")
    if "." in value and not hex_alphabet.issuperset(value):
        return "domain"

    return "other"
def defang_ip(text: str) -> str:
    """Defang an IP address by wrapping each bare dot in brackets.

    One of two styles is picked at random for the whole string: compact
    ("1[.]2[.]3[.]4") or spaced ("1 [ . ] 2 [ . ] 3 [ . ] 4").  Dots that are
    already wrapped ("[.]", "[ . ]", "(.)", "( . )") are left untouched, so
    partially or fully defanged input is never double-bracketed.

    Args:
        text: The IP address as it appears in the source text.

    Returns:
        The defanged string; equal to ``text`` when it has no bare dot.
    """
    # Draw the style first and unconditionally so the amount of RNG state
    # consumed per call does not depend on the input.
    style = random.choice(["bracket", "bracket_space"])
    replacement = "[.]" if style == "bracket" else " [ . ] "

    # Alternation order matters: an already-wrapped dot is matched (and kept)
    # as a whole before the bare-dot branch can see the dot inside it.
    wrapped_or_bare_dot = r"\[\s*\.\s*\]|\(\s*\.\s*\)|\."
    return re.sub(
        wrapped_or_bare_dot,
        lambda m: replacement if m.group() == "." else m.group(),
        text,
    )
def defang_domain(text: str) -> str:
    """Defang a domain name by wrapping each bare dot in brackets.

    One of two styles is picked at random for the whole string: compact
    ("evil[.]com") or spaced ("evil[ . ]com").  Dots that are already wrapped
    ("[.]", "[ . ]", "(.)", "( . )") are left untouched, so partially or
    fully defanged input is never double-bracketed.

    Args:
        text: The domain as it appears in the source text.

    Returns:
        The defanged string; equal to ``text`` when it has no bare dot.
    """
    # Draw the style first and unconditionally so the amount of RNG state
    # consumed per call does not depend on the input.
    style = random.choice(["bracket", "bracket_space"])
    replacement = "[.]" if style == "bracket" else "[ . ]"

    # Alternation order matters: an already-wrapped dot is matched (and kept)
    # as a whole before the bare-dot branch can see the dot inside it.
    wrapped_or_bare_dot = r"\[\s*\.\s*\]|\(\s*\.\s*\)|\."
    return re.sub(
        wrapped_or_bare_dot,
        lambda m: replacement if m.group() == "." else m.group(),
        text,
    )
def defang_url(text: str) -> str:
    """Defang a URL: rewrite the scheme and bracket the dots of the host part.

    "https://" becomes "hxxps://" and "http://" becomes "hxxp://" wherever
    they occur.  If the result contains "://", the dots between that
    separator and the first "/", "?" or "#" (or the end of the string) are
    wrapped in one randomly chosen style, "[.]" or "[ . ]".  Dots in the
    path, query and fragment are left alone, and so are host dots that are
    already wrapped ("[.]", "[ . ]", "(.)", "( . )").

    Args:
        text: The URL as it appears in the source text.

    Returns:
        The defanged URL; equal to ``text`` when nothing needed changing.
    """
    result = text.replace("https://", "hxxps://").replace("http://", "hxxp://")

    proto_end = result.find("://")
    if proto_end < 0:
        # No scheme separator: there is no host part to locate.
        return result

    host_start = proto_end + 3
    # The host ends at the first path, query or fragment delimiter.  Stopping
    # only at "/" would defang dots inside "?a=1.2" when the URL has no path.
    cut_points = [result.find(sep, host_start) for sep in "/?#"]
    host_end = min((i for i in cut_points if i != -1), default=len(result))

    style = random.choice(["bracket", "bracket_space"])
    replacement = "[.]" if style == "bracket" else "[ . ]"

    # Alternation order matters: an already-wrapped dot is matched (and kept)
    # as a whole before the bare-dot branch can see the dot inside it.
    wrapped_or_bare_dot = r"\[\s*\.\s*\]|\(\s*\.\s*\)|\."
    host_defanged = re.sub(
        wrapped_or_bare_dot,
        lambda m: replacement if m.group() == "." else m.group(),
        result[host_start:host_end],
    )
    return result[:host_start] + host_defanged + result[host_end:]
def defang_span_text(text: str, indicator_value: str) -> str | None:
    """Defang a span's text according to the kind of its indicator.

    The kind is decided from ``indicator_value`` (not from ``text``) by
    classify_indicator; the matching defang_* function is then applied to
    ``text``.

    Returns:
        The defanged text, or None when the indicator is neither an IP, a
        domain nor a URL.
    """
    handlers = {
        "ip": defang_ip,
        "domain": defang_domain,
        "url": defang_url,
    }
    handler = handlers.get(classify_indicator(indicator_value))
    return handler(text) if handler is not None else None
def augment_example(example: dict) -> dict | None:
"""Create a defanged copy of an example. Returns None if nothing to defang."""
text = example["text"]
spans = example["spans"]
# Collect all indicator spans with positions, sorted by start offset
indicator_spans = []
for label, positions in spans.items():
if label.startswith("Indicator:"):
indicator_value = label.split(": ", 1)[1]
for start, end in positions:
indicator_spans.append((start, end, label, indicator_value))
if not indicator_spans:
return None
# Sort by start position
indicator_spans.sort(key=lambda x: x[0])
# Try to defang each indicator span, track replacements
replacements = [] # (old_start, old_end, new_text)
for start, end, label, indicator_value in indicator_spans:
old_text = text[start:end]
new_text = defang_span_text(old_text, indicator_value)
if new_text and new_text != old_text:
replacements.append((start, end, new_text))
if not replacements:
return None
# Build new text and offset mapping
# Process replacements from end to start to not mess up offsets
# But we need a forward pass to compute cumulative offset shifts
# Compute cumulative offset adjustments
# For each position in original text, compute how much it shifts
shifts = [] # (original_pos, delta) - at original_pos, cumulative delta changes
cumulative = 0
for old_start, old_end, new_text in replacements:
old_len = old_end - old_start
new_len = len(new_text)
delta = new_len - old_len
shifts.append((old_start, old_end, delta, cumulative))
cumulative += delta
# Build new text
new_text_parts = []
prev_end = 0
for old_start, old_end, new_text in replacements:
new_text_parts.append(text[prev_end:old_start])
new_text_parts.append(new_text)
prev_end = old_end
new_text_parts.append(text[prev_end:])
new_full_text = "".join(new_text_parts)
# Adjust all span offsets
def adjust_offset(pos: int) -> int:
"""Adjust an original offset to account for replacements."""
cum = 0
for old_start, old_end, new_text in replacements:
old_len = old_end - old_start
new_len = len(new_text)
delta = new_len - old_len
if pos <= old_start:
break
elif pos >= old_end:
cum += delta
else:
# pos is inside a replacement - scale proportionally
# This handles the span endpoints that ARE the replacement
frac = (pos - old_start) / old_len
cum += int(frac * delta)
break
return pos + cum
new_spans = {}
for label, positions in spans.items():
new_positions = []
for start, end in positions:
new_start = adjust_offset(start)
new_end = adjust_offset(end)
new_positions.append([new_start, new_end])
new_spans[label] = new_positions
new_example = {
"text": new_full_text,
"spans": new_spans,
"info": {**example.get("info", {}), "source": "defanged_augment"},
}
return new_example
def main():
    """Sample indicator-bearing examples, defang them and write them out.

    Reads the JSONL examples from INPUT, keeps those with at least one
    "Indicator:" span label, draws SAMPLE_RATE of them at random, runs
    augment_example on each, and writes the successful results to OUTPUT as
    JSONL.  Progress counts and a short sample of the output are printed.
    """
    # Read and write as UTF-8 explicitly: the output is dumped with
    # ensure_ascii=False, so relying on the platform default encoding could
    # fail (or mis-decode the input) on non-UTF-8 locales.
    with open(INPUT, encoding="utf-8") as f:
        # Blank lines (e.g. a trailing newline) are skipped instead of
        # crashing json.loads.
        examples = [json.loads(line) for line in f if line.strip()]

    # Find examples with Indicator spans
    indicator_examples = [
        ex for ex in examples
        if any(k.startswith("Indicator:") for k in ex["spans"])
    ]
    print(f"Total examples: {len(examples)}")
    print(f"Examples with Indicator spans: {len(indicator_examples)}")

    # Sample a SAMPLE_RATE fraction (rounded down) without replacement.
    sampled = random.sample(indicator_examples, int(len(indicator_examples) * SAMPLE_RATE))
    print(f"Sampled for augmentation: {len(sampled)}")

    augmented = []
    defanged_count = 0
    for ex in sampled:
        result = augment_example(ex)
        if not result:
            continue
        augmented.append(result)
        # Count every indicator span occurrence in the augmented example.
        defanged_count += sum(
            len(positions)
            for label, positions in result["spans"].items()
            if label.startswith("Indicator:")
        )
    print(f"Successfully augmented: {len(augmented)}")
    print(f"Total indicator spans in augmented data: {defanged_count}")

    with open(OUTPUT, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(ex, ensure_ascii=False) + "\n" for ex in augmented)
    print(f"Written to {OUTPUT}")

    # Print the first few results so the adjusted offsets can be eyeballed
    # against the text they now cover.
    print("\n=== Sample verification ===")
    for ex in augmented[:3]:
        print(f"\nText: {ex['text'][:120]}...")
        for label, positions in ex["spans"].items():
            if label.startswith("Indicator:"):
                for s, e in positions:
                    print(f" {label}: '{ex['text'][s:e]}' [{s}:{e}]")


if __name__ == "__main__":
    main()