File size: 8,973 Bytes
3dac39e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
#!/usr/bin/env python3
"""Clean training data: fix unlabeled IOCs, downsample ExploitDB, remove noisy labels."""

import json
import random
import re
import sys
from pathlib import Path

# Seed the module-level RNG once at import time so that the ExploitDB
# downsampling done with random.sample() is reproducible between runs.
random.seed(42)

# --- IOC Regexes ---
# NOTE: the ORDER of this list is part of its contract. find_unlabeled_iocs()
# chooses the validator by list index (0-2 hashes, 3 IPv4, 4-5 URLs,
# 6 domain) and, when two candidate spans overlap, the one produced by the
# earlier pattern wins. Do not insert, remove or reorder entries without
# updating that function.
IOC_PATTERNS = [
    # SHA256 (64 hex chars, word boundary)
    re.compile(r'\b[a-fA-F0-9]{64}\b'),
    # SHA1 (40 hex chars)
    re.compile(r'\b[a-fA-F0-9]{40}\b'),
    # MD5 (32 hex chars)
    re.compile(r'\b[a-fA-F0-9]{32}\b'),
    # IPv4 (octet range is NOT checked here; see is_valid_ip)
    re.compile(r'\b(?:\d{1,3}\.){3}\d{1,3}\b'),
    # URL (http/https/hxxp/hxxps). The body stops at whitespace, quotes,
    # angle brackets, ')' and ']', but a bracket-defanged dot "[.]" is
    # consumed as a unit so that "hxxp://evil[.]com/x" is matched in full
    # instead of being cut off at "hxxp://evil[.".
    re.compile(r'(?:https?|hxxps?)://(?:\[\.\]|[^\s<>"\')\]])+'),
    # Defanged URL (hxxp/hxxps only). Every match of this pattern is also a
    # match of the previous one; it is kept so the list indices stay stable.
    re.compile(r'hxxps?://(?:\[\.\]|[^\s<>"\')\]])+'),
    # Domain-like (at least one dot, TLD 2-10 chars, not all digits)
    re.compile(r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,10}\b'),
]

# Common non-IOC domains/words to skip.
# Compared verbatim against the raw regex match in is_valid_domain().
DOMAIN_SKIPLIST = {
    'e.g.', 'i.e.', 'et.al.', 'Fig.', 'fig.', 'vs.', 'etc.',
}

# Common TLDs for domain validation.
# is_valid_domain() lower-cases the last dot-separated label of a match and
# rejects the match unless that label is in this set.
VALID_TLDS = {
    'com', 'net', 'org', 'io', 'ru', 'cn', 'de', 'uk', 'fr', 'jp', 'kr',
    'info', 'biz', 'xyz', 'top', 'online', 'site', 'club', 'pro',
    'gov', 'edu', 'mil', 'int', 'co', 'us', 'ca', 'au', 'in', 'br',
    'it', 'es', 'nl', 'se', 'no', 'fi', 'dk', 'pl', 'cz', 'at', 'ch',
    'be', 'ie', 'pt', 'gr', 'hu', 'ro', 'bg', 'hr', 'sk', 'si', 'lt',
    'lv', 'ee', 'me', 'pw', 'tk', 'ml', 'ga', 'cf', 'gq',
    'onion', 'bit', 'cc', 'tv', 'ws', 'la', 'ly', 'su', 'ua', 'kz',
    'ddns', 'duckdns', 'no-ip',
}

# Words that look like domains but aren't.
# NOTE(review): no function visible in this file reads this set.
FAKE_DOMAINS = {
    'Super Mario', 'e.g', 'i.e', 'et al', 'Fig', 'Remote Code',
}

# Bare file extensions that should NOT be Indicator:
# a dot followed by 2-5 ASCII letters and nothing else (.dll, .exe, .pdf, ...).
BARE_EXTENSIONS = re.compile(r'^\.[a-zA-Z]{2,5}$')


def get_existing_spans(example):
    """Flatten every labelled span of *example* into a list of (start, end).

    The span keys are ignored; only the offset pairs stored under
    ``example['spans']`` are collected, in dictionary order. An example
    without a ``'spans'`` entry yields an empty list.
    """
    return [
        (start, end)
        for offsets in example.get('spans', {}).values()
        for start, end in offsets
    ]


def overlaps_any(start, end, intervals):
    """Return True if half-open [start, end) intersects any [s, e) in *intervals*.

    Intervals that merely touch (one ends where the other begins) do not
    count as overlapping.
    """
    return any(start < hi and end > lo for lo, hi in intervals)


def is_valid_domain(text, match_str):
    """Decide whether a domain-regex hit is a plausible domain indicator.

    *text* is accepted for interface compatibility but is not consulted.
    A match is rejected when it is shorter than five characters, has no dot,
    ends in a label that is not in VALID_TLDS, appears verbatim in
    DOMAIN_SKIPLIST, or consists of at most two labels that are all
    alphabetic and at most three characters long (e.g. "the.end").
    """
    # Too short to be meaningful, or not dotted at all.
    if len(match_str) < 5 or '.' not in match_str:
        return False

    labels = match_str.rstrip('.').split('.')

    # Unknown TLD, or a known abbreviation that only looks like a domain.
    if labels[-1].lower() not in VALID_TLDS or match_str in DOMAIN_SKIPLIST:
        return False

    # Short all-alpha pairs are far more often English words than hosts.
    short_alpha = all(lbl.isalpha() and len(lbl) <= 3 for lbl in labels)
    return not (short_alpha and len(labels) <= 2)


def is_valid_ip(match_str):
    """Return True if *match_str* is a dotted quad with every octet in 0-255.

    The IPv4 regex only guarantees the "1-3 digits, dot" shape, so values
    such as "999.1.1.1" still reach this function and must be rejected here.
    Strings that are not exactly four decimal octets (wrong count, empty or
    non-numeric parts) return False rather than raising.
    """
    parts = match_str.split('.')
    if len(parts) != 4:
        return False
    # isdecimal() guards int() against '', signs, whitespace and letters;
    # the length cap mirrors the regex's \d{1,3}.
    return all(
        len(p) <= 3 and p.isdecimal() and int(p) <= 255
        for p in parts
    )


def is_valid_hash(match_str, expected_len):
    """Return True if *match_str* looks like a real hex digest.

    The string must be exactly *expected_len* characters long, consist only
    of hexadecimal digits (case-insensitive), and use at least four distinct
    characters. The last rule filters out filler such as all-zero or
    single-character runs that happen to have a digest's length.
    """
    if len(match_str) != expected_len:
        return False
    distinct = set(match_str.lower())
    if not distinct <= set('0123456789abcdef'):
        return False
    return len(distinct) >= 4


def find_unlabeled_iocs(text, existing_intervals):
    """Collect IOC spans in *text* that no existing label already covers.

    Every regex in IOC_PATTERNS is run over *text* in list order. A hit is
    kept only if it does not overlap *existing_intervals*, passes the
    validator for its pattern index (0-2 hash, 3 IPv4, 6 domain; URL
    patterns 4-5 need no extra check), and does not overlap a span accepted
    earlier in this call. Returns the accepted ``(start, end)`` tuples in
    the order they were found, i.e. grouped by pattern, not sorted by offset.
    """
    hash_lengths = (64, 40, 32)  # SHA256, SHA1, MD5 - aligned with indices 0-2
    accepted = []
    visited = set()  # exact (start, end) pairs already accepted

    for idx, regex in enumerate(IOC_PATTERNS):
        for hit in regex.finditer(text):
            span = hit.span()
            candidate = hit.group()

            # Already labelled, or this exact span was taken by an earlier pattern.
            if overlaps_any(span[0], span[1], existing_intervals) or span in visited:
                continue

            # Per-type plausibility check, selected by pattern position.
            if idx < len(hash_lengths):
                want = hash_lengths[idx]
                plausible = len(candidate) == want and is_valid_hash(candidate, want)
            elif idx == 3:
                plausible = is_valid_ip(candidate)
            elif idx == 6:
                plausible = is_valid_domain(text, candidate)
            else:
                plausible = True  # URLs are generally valid if matched

            # Earlier patterns win: drop anything colliding with an accepted span.
            if not plausible or overlaps_any(span[0], span[1], accepted):
                continue

            accepted.append(span)
            visited.add(span)

    return accepted


def add_indicator_spans(example, new_offsets):
    """Record each ``(start, end)`` in *new_offsets* as an Indicator span.

    Spans are stored under ``example['spans']`` keyed by
    ``"Indicator: <covered text>"``, each key mapping to a list of
    ``[start, end]`` pairs. The ``'spans'`` mapping is created on demand, so
    examples that arrive without one are handled the same way
    get_existing_spans() reads them. *example* is modified in place and also
    returned.
    """
    for start, end in new_offsets:
        span_key = f"Indicator: {example['text'][start:end]}"
        example.setdefault('spans', {}).setdefault(span_key, []).append([start, end])
    return example


def is_exploitdb(example):
    """Tell whether *example* came from an ExploitDB source.

    True when ``example['info']['source']`` contains "exploitdb" in any
    letter case; a missing ``info`` or ``source`` entry counts as no.
    """
    origin = example.get('info', {}).get('source', '')
    return 'exploitdb' in origin.lower()


def entity_density(example):
    """Return the share of character positions covered by at least one span.

    Overlapping spans are counted once. An example with empty or missing
    text yields 0.
    """
    total_chars = len(example.get('text', ''))
    if not total_chars:
        return 0

    marked = set()
    for offsets in example.get('spans', {}).values():
        for begin, finish in offsets:
            marked.update(range(begin, finish))
    return len(marked) / total_chars


def remove_bare_extension_indicators(example):
    """Drop Indicator spans whose text is only a file extension (".dll", ".exe").

    Only keys that start with ``"Indicator:"`` are examined; the text after
    that prefix is stripped of whitespace and tested against BARE_EXTENSIONS.
    Matching keys are removed together with all of their offsets.

    Returns ``(example, removed)`` where *removed* is the number of offset
    pairs deleted. *example* is modified in place. An example with no
    ``'spans'`` entry is returned unchanged with a count of 0.
    """
    spans = example.get('spans')
    if not spans:
        return example, 0

    removed = 0
    # Iterate over a snapshot of the keys because entries are popped in the loop.
    for key in list(spans):
        if not key.startswith('Indicator:'):
            continue
        span_text = key[len('Indicator:'):].strip()
        if BARE_EXTENSIONS.match(span_text):
            removed += len(spans.pop(key))
    return example, removed


def clean_file(input_path, output_path, max_exploitdb=500):
    """Clean one JSONL file and write the result to *output_path*.

    Steps applied, in order:
      1. Load every non-blank line of *input_path* as a JSON object.
      2. Downsample ExploitDB examples (see is_exploitdb) to at most
         *max_exploitdb*, chosen with ``random.sample`` from the module RNG.
      3. For each surviving example, add Indicator spans for unlabeled IOCs
         and remove Indicator spans that are bare file extensions.
      4. Write each surviving example as one JSON line.

    Args:
        input_path: path of the JSONL file to read (UTF-8).
        output_path: path of the JSONL file to create or overwrite (UTF-8).
        max_exploitdb: upper bound on ExploitDB examples kept; defaults to
            500, the value that was previously hard-coded.

    Returns:
        A dict of counters: ``total``, ``iocs_added``,
        ``examples_with_new_iocs``, ``exploitdb_removed``,
        ``exploitdb_kept``, ``extension_labels_removed`` and ``output``.

    Raises:
        json.JSONDecodeError: if a non-blank input line is not valid JSON.
    """
    stats = {
        'total': 0,
        'iocs_added': 0,
        'examples_with_new_iocs': 0,
        'exploitdb_removed': 0,
        'exploitdb_kept': 0,
        'extension_labels_removed': 0,
        'output': 0,
    }

    # First pass: load everything and note which examples are ExploitDB.
    # The recorded index is the position in `examples`, NOT the line number
    # in the file: blank lines are skipped, so the two can differ, and the
    # second pass below iterates over `examples`.
    examples = []
    exploitdb_indices = []
    with open(input_path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            ex = json.loads(line)
            if is_exploitdb(ex):
                exploitdb_indices.append(len(examples))
            examples.append(ex)

    stats['total'] = len(examples)

    # Downsample ExploitDB. Membership tests in the write loop use sets so
    # the second pass stays linear in the number of examples.
    keep_exploitdb = set(
        random.sample(exploitdb_indices, min(max_exploitdb, len(exploitdb_indices)))
    )
    drop_exploitdb = set(exploitdb_indices) - keep_exploitdb
    stats['exploitdb_removed'] = len(exploitdb_indices) - len(keep_exploitdb)
    stats['exploitdb_kept'] = len(keep_exploitdb)

    # ensure_ascii=False emits raw non-ASCII characters, so the output
    # encoding is pinned instead of being left to the locale default.
    with open(output_path, 'w', encoding='utf-8') as out:
        for i, ex in enumerate(examples):
            # Skip the ExploitDB examples that were not sampled.
            if i in drop_exploitdb:
                continue

            # Add unlabeled IOCs as Indicator spans.
            existing = get_existing_spans(ex)
            new_iocs = find_unlabeled_iocs(ex['text'], existing)
            if new_iocs:
                ex = add_indicator_spans(ex, new_iocs)
                stats['iocs_added'] += len(new_iocs)
                stats['examples_with_new_iocs'] += 1

            # Remove Indicator labels that are only a file extension.
            ex, ext_removed = remove_bare_extension_indicators(ex)
            stats['extension_labels_removed'] += ext_removed

            out.write(json.dumps(ex, ensure_ascii=False) + '\n')
            stats['output'] += 1

    return stats


def main():
    """Clean the train and valid splits found in the processed-data directory.

    For each split the input ``enriched_5class_<split>.jsonl`` is cleaned
    into ``enriched_5class_<split>_cleaned.jsonl`` and the returned counters
    are printed. A split whose input file does not exist is reported and
    skipped.
    """
    base = Path('/home/ubuntu/alkyline/data/processed')

    for split in ('train', 'valid'):
        inp = base / f'enriched_5class_{split}.jsonl'
        outp = base / f'enriched_5class_{split}_cleaned.jsonl'

        if inp.exists():
            print(f"\n{'='*60}")
            print(f"Cleaning: {inp.name} -> {outp.name}")
            print(f"{'='*60}")
            for k, v in clean_file(inp, outp).items():
                print(f"  {k}: {v}")
        else:
            print(f"SKIP: {inp} not found")

# Run the cleaning pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()