File size: 7,425 Bytes
3dac39e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
#!/usr/bin/env python3
"""Augment training data with defanged indicator variants."""

import json
import random
import re
import sys
from pathlib import Path

# Input corpus and output location; both are relative paths, so they resolve
# against the current working directory when opened.
INPUT: Path = Path("data/processed/enriched_5class_train_cleaned.jsonl")
OUTPUT: Path = Path("data/processed/defanged_augmented.jsonl")

# Fraction of indicator-bearing examples that receive a defanged copy.
SAMPLE_RATE: float = 0.3

# Seed the module-level RNG at import time so every random draw this script
# makes (sampling and defang-style choices) is reproducible across runs.
random.seed(42)


def classify_indicator(value: str) -> str:
    """Classify an indicator value as ``"ip"``, ``"domain"``, ``"url"`` or ``"other"``.

    Heuristic checks, applied in order:

    * ``"url"``: contains ``"://"`` or starts with ``hxxp`` in any case, so
      already-defanged schemes such as ``hXXps[:]//`` are recognised.
    * ``"ip"``: at least three dots and nothing but digits, dots, brackets,
      parentheses and spaces. This also admits defanged forms like
      ``1[.]2[.]3[.]4``.
    * ``"domain"``: any remaining value that contains a dot.
    * ``"other"``: everything else, e.g. hashes (which contain no dots).
    """
    if "://" in value or value.lower().startswith("hxxp"):
        return "url"
    if value.count(".") >= 3 and all(c in "0123456789.[]() " for c in value):
        return "ip"
    # Hash-like values are excluded implicitly: they contain no dot. A separate
    # "is it all hex?" test would be dead code, because a dotted string is never all-hex.
    if "." in value:
        return "domain"
    return "other"


def defang_ip(text: str) -> str:
    """Return *text* with every ``.`` defanged, using one randomly chosen style.

    ``"bracket"`` produces ``1[.]2[.]3[.]4``. ``"bracket_space"`` produces
    ``1 [ . ] 2 [ . ] 3 [ . ] 4``, with spaces padding each marker. The same
    style is used for the whole string, and each call makes exactly one
    ``random.choice`` draw.
    """
    marker_for_style = {"bracket": "[.]", "bracket_space": " [ . ] "}
    picked = random.choice(["bracket", "bracket_space"])
    return text.replace(".", marker_for_style[picked])


def defang_domain(text: str) -> str:
    """Return *text* with every ``.`` defanged, using one randomly chosen style.

    Each dot becomes either ``[.]`` or ``[ . ]``. The same style is used for
    the whole string, and each call makes exactly one ``random.choice`` draw.
    """
    chosen = random.choice(["bracket", "bracket_space"])
    marker = "[.]" if chosen == "bracket" else "[ . ]"
    return text.replace(".", marker)


def defang_url(text: str) -> str:
    """Defang a URL's scheme and the dots in its authority (host) part.

    Every ``http://`` / ``https://`` is rewritten, case-insensitively, to
    ``hxxp://`` / ``hxxps://``. If the result contains ``://``, every dot
    between the first ``://`` and the first following ``/``, ``?`` or ``#`` is
    replaced with ``[.]`` or ``[ . ]``. That region is the host plus any
    userinfo or port. The style is chosen at random once per URL, using one
    ``random.choice`` draw. Dots in the path, query and fragment are left
    untouched.

    Text without ``://`` is returned with only the scheme rewrite applied, and
    no random draw is made.
    """
    # Case-insensitive so "HTTP://" is defanged too; "s" is normalised to lowercase.
    result = re.sub(
        r"(?i)http(s?)://", lambda m: f"hxxp{m.group(1).lower()}://", text
    )
    scheme_end = result.find("://")
    if scheme_end < 0:
        return result
    host_start = scheme_end + 3
    # The host ends at the path, query or fragment, whichever comes first.
    # Stopping only at "/" would defang dots in "?q=1.2" for path-less URLs.
    boundary = re.compile(r"[/?#]").search(result, host_start)
    host_end = boundary.start() if boundary else len(result)
    style = random.choice(["bracket", "bracket_space"])
    marker = "[.]" if style == "bracket" else "[ . ]"
    host = result[host_start:host_end].replace(".", marker)
    return result[:host_start] + host + result[host_end:]


def defang_span_text(text: str, indicator_value: str) -> str | None:
    """Defang the text of a span. Returns new text or None if not defangable."""
    itype = classify_indicator(indicator_value)
    if itype == "ip":
        return defang_ip(text)
    elif itype == "domain":
        return defang_domain(text)
    elif itype == "url":
        return defang_url(text)
    return None


def augment_example(example: dict) -> dict | None:
    """Create a defanged copy of an example. Returns None if nothing to defang."""
    text = example["text"]
    spans = example["spans"]

    # Collect all indicator spans with positions, sorted by start offset
    indicator_spans = []
    for label, positions in spans.items():
        if label.startswith("Indicator:"):
            indicator_value = label.split(": ", 1)[1]
            for start, end in positions:
                indicator_spans.append((start, end, label, indicator_value))

    if not indicator_spans:
        return None

    # Sort by start position
    indicator_spans.sort(key=lambda x: x[0])

    # Try to defang each indicator span, track replacements
    replacements = []  # (old_start, old_end, new_text)
    for start, end, label, indicator_value in indicator_spans:
        old_text = text[start:end]
        new_text = defang_span_text(old_text, indicator_value)
        if new_text and new_text != old_text:
            replacements.append((start, end, new_text))

    if not replacements:
        return None

    # Build new text and offset mapping
    # Process replacements from end to start to not mess up offsets
    # But we need a forward pass to compute cumulative offset shifts

    # Compute cumulative offset adjustments
    # For each position in original text, compute how much it shifts
    shifts = []  # (original_pos, delta) - at original_pos, cumulative delta changes
    cumulative = 0
    for old_start, old_end, new_text in replacements:
        old_len = old_end - old_start
        new_len = len(new_text)
        delta = new_len - old_len
        shifts.append((old_start, old_end, delta, cumulative))
        cumulative += delta

    # Build new text
    new_text_parts = []
    prev_end = 0
    for old_start, old_end, new_text in replacements:
        new_text_parts.append(text[prev_end:old_start])
        new_text_parts.append(new_text)
        prev_end = old_end
    new_text_parts.append(text[prev_end:])
    new_full_text = "".join(new_text_parts)

    # Adjust all span offsets
    def adjust_offset(pos: int) -> int:
        """Adjust an original offset to account for replacements."""
        cum = 0
        for old_start, old_end, new_text in replacements:
            old_len = old_end - old_start
            new_len = len(new_text)
            delta = new_len - old_len
            if pos <= old_start:
                break
            elif pos >= old_end:
                cum += delta
            else:
                # pos is inside a replacement - scale proportionally
                # This handles the span endpoints that ARE the replacement
                frac = (pos - old_start) / old_len
                cum += int(frac * delta)
                break
        return pos + cum

    new_spans = {}
    for label, positions in spans.items():
        new_positions = []
        for start, end in positions:
            new_start = adjust_offset(start)
            new_end = adjust_offset(end)
            new_positions.append([new_start, new_end])
        new_spans[label] = new_positions

    new_example = {
        "text": new_full_text,
        "spans": new_spans,
        "info": {**example.get("info", {}), "source": "defanged_augment"},
    }
    return new_example


def main():
    """Sample indicator-bearing examples, defang them, and write the copies.

    Steps:

    1. Read JSONL from ``INPUT`` (UTF-8, blank lines skipped).
    2. Keep examples with at least one ``Indicator:`` span label.
    3. Sample ``SAMPLE_RATE`` of them with the module's seeded ``random``.
    4. Write every successfully augmented copy to ``OUTPUT`` as UTF-8 JSONL,
       creating the parent directory if needed.
    5. Print summary counts and up to three examples for manual inspection.
    """
    with open(INPUT, encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) rather than crash in json.loads.
        examples = [json.loads(line) for line in f if line.strip()]

    indicator_examples = [
        ex for ex in examples
        if any(k.startswith("Indicator:") for k in ex["spans"])
    ]

    print(f"Total examples: {len(examples)}")
    print(f"Examples with Indicator spans: {len(indicator_examples)}")

    sampled = random.sample(indicator_examples, int(len(indicator_examples) * SAMPLE_RATE))
    print(f"Sampled for augmentation: {len(sampled)}")

    augmented = []
    defanged_count = 0
    for ex in sampled:
        result = augment_example(ex)
        if result is None:
            continue
        augmented.append(result)
        # Counts every indicator span in the copy, whether it was rewritten or not.
        defanged_count += sum(
            len(positions)
            for label, positions in result["spans"].items()
            if label.startswith("Indicator:")
        )

    print(f"Successfully augmented: {len(augmented)}")
    print(f"Total indicator spans in augmented data: {defanged_count}")

    OUTPUT.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8: records are dumped with ensure_ascii=False, so the
    # platform default encoding could fail on non-ASCII text.
    with open(OUTPUT, "w", encoding="utf-8") as f:
        for ex in augmented:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")

    print(f"Written to {OUTPUT}")

    # Print a few examples so the remapped offsets can be checked by eye.
    print("\n=== Sample verification ===")
    for ex in augmented[:3]:
        print(f"\nText: {ex['text'][:120]}...")
        for label, positions in ex["spans"].items():
            if label.startswith("Indicator:"):
                for s, e in positions:
                    print(f"  {label}: '{ex['text'][s:e]}' [{s}:{e}]")


if __name__ == "__main__":
    main()