#!/usr/bin/env python3
"""Clean training data: fix unlabeled IOCs, downsample ExploitDB, remove noisy labels."""
import json
import random
import re
import sys
from pathlib import Path
# Seed the module-level RNG at import time so the random.sample() call used for
# ExploitDB downsampling in clean_file() picks the same examples on every run.
random.seed(42)

# --- IOC Regexes ---
# NOTE: the order of this list is significant. find_unlabeled_iocs() selects the
# validation to apply by list index (0-2 hashes, 3 IPv4, 4-5 URLs, 6 domain), and
# a match from an earlier pattern blocks any later match that overlaps it.
IOC_PATTERNS = [
    # SHA256 (64 hex chars, word boundary)
    re.compile(r'\b[a-fA-F0-9]{64}\b'),
    # SHA1 (40 hex chars)
    re.compile(r'\b[a-fA-F0-9]{40}\b'),
    # MD5 (32 hex chars)
    re.compile(r'\b[a-fA-F0-9]{32}\b'),
    # IPv4 (octet range is not checked here; see is_valid_ip)
    re.compile(r'\b(?:\d{1,3}\.){3}\d{1,3}\b'),
    # URL (http/https/hxxp/hxxps)
    re.compile(r'(?:https?|hxxps?)://[^\s<>"\')\]]+'),
    # Defanged URL
    # NOTE(review): everything this matches is already matched by the URL pattern
    # above (which accepts hxxp/hxxps too), so this entry looks redundant. It is
    # left in place because removing it would shift the index-based dispatch.
    re.compile(r'hxxps?://[^\s<>"\')\]]+'),
    # Domain-like (at least one dot, TLD 2-10 chars, not all digits)
    re.compile(r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,10}\b'),
]

# Common non-IOC domains/words to skip (checked by exact match in is_valid_domain)
DOMAIN_SKIPLIST = {
    'e.g.', 'i.e.', 'et.al.', 'Fig.', 'fig.', 'vs.', 'etc.',
}

# Common TLDs for domain validation (compared against the lower-cased last label)
# NOTE(review): 'no-ip' contains a hyphen, but the domain regex only allows
# letters in the final label, so that entry cannot be reached through the
# domain pattern.
VALID_TLDS = {
    'com', 'net', 'org', 'io', 'ru', 'cn', 'de', 'uk', 'fr', 'jp', 'kr',
    'info', 'biz', 'xyz', 'top', 'online', 'site', 'club', 'pro',
    'gov', 'edu', 'mil', 'int', 'co', 'us', 'ca', 'au', 'in', 'br',
    'it', 'es', 'nl', 'se', 'no', 'fi', 'dk', 'pl', 'cz', 'at', 'ch',
    'be', 'ie', 'pt', 'gr', 'hu', 'ro', 'bg', 'hr', 'sk', 'si', 'lt',
    'lv', 'ee', 'me', 'pw', 'tk', 'ml', 'ga', 'cf', 'gq',
    'onion', 'bit', 'cc', 'tv', 'ws', 'la', 'ly', 'su', 'ua', 'kz',
    'ddns', 'duckdns', 'no-ip',
}

# Words that look like domains but aren't
# NOTE(review): this set is not referenced anywhere in the visible part of the
# file — confirm whether it is still needed.
FAKE_DOMAINS = {
    'Super Mario', 'e.g', 'i.e', 'et al', 'Fig', 'Remote Code',
}

# Bare file extensions that should NOT be Indicator
# Matches the whole string: a dot followed by 2-5 ASCII letters.
BARE_EXTENSIONS = re.compile(r'^\.[a-zA-Z]{2,5}$')  # .dll, .exe, .pdf, etc.
def get_existing_spans(example):
    """Flatten every labelled span of *example* into a list of (start, end) tuples.

    The ``'spans'`` entry maps a label key to a list of two-element offset
    pairs; an example without that entry yields an empty list. Pairs are
    returned in mapping order, then in list order within each key.
    """
    span_map = example.get('spans', {})
    return [
        (begin, finish)
        for offset_list in span_map.values()
        for begin, finish in offset_list
    ]
def overlaps_any(start, end, intervals):
    """Return True when the half-open range [start, end) intersects any interval.

    Each item of *intervals* is a (lo, hi) pair that is also treated as
    half-open, so ranges that merely touch end-to-start do not overlap.
    """
    return any(start < hi and lo < end for lo, hi in intervals)
def is_valid_domain(text, match_str):
    """Decide whether a domain-regex match is a plausible domain indicator.

    *text* is accepted for interface compatibility but is not consulted.
    A match is rejected when it is shorter than five characters, has no dot,
    ends in a label that is not in VALID_TLDS, is listed in DOMAIN_SKIPLIST,
    or consists of at most two labels that are all alphabetic and at most
    three characters long.
    """
    if len(match_str) < 5 or '.' not in match_str:
        return False

    # A trailing dot is ignored when splitting into labels.
    labels = match_str.rstrip('.').split('.')
    if labels[-1].lower() not in VALID_TLDS:
        return False

    if match_str in DOMAIN_SKIPLIST:
        return False

    # Very short, purely alphabetic two-label matches read like word fragments.
    only_short_words = all(
        label.isalpha() and len(label) <= 3 for label in labels
    )
    return not (only_short_words and len(labels) <= 2)
def is_valid_ip(match_str):
    """Return True when *match_str* is a dotted-quad IPv4 address.

    The string must split on '.' into exactly four parts, each made only of
    decimal digits with a value between 0 and 255. Anything else — too few or
    too many parts, empty parts, non-numeric text — returns False instead of
    raising, so the function is safe to call on arbitrary strings and not
    just on matches of the IPv4 regex.
    """
    parts = match_str.split('.')
    if len(parts) != 4:
        return False
    # str.isdecimal() is False for '' and for anything int() could not parse
    # as a plain digit run, so the int() conversion below cannot raise.
    return all(p.isdecimal() and 0 <= int(p) <= 255 for p in parts)
def is_valid_hash(match_str, expected_len):
    """Return True when *match_str* looks like a real hex digest.

    Three conditions must hold:
      * the string is exactly *expected_len* characters long,
      * every character is a hexadecimal digit (case-insensitive),
      * at least four distinct characters occur, which filters out filler
        values such as runs of a single repeated character.
    """
    if len(match_str) != expected_len:
        return False
    distinct = set(match_str.lower())
    if not distinct <= set('0123456789abcdef'):
        return False
    return len(distinct) >= 4
def find_unlabeled_iocs(text, existing_intervals):
    """Scan *text* with every IOC pattern and return new (start, end) spans.

    A match is kept only if it does not overlap *existing_intervals*, passes
    the validation tied to its pattern's position in IOC_PATTERNS, and does
    not overlap a span already accepted during this call. Patterns are tried
    in list order, so an earlier pattern wins any overlap.

    NOTE(review): because the hash and IPv4 patterns run before the URL
    patterns, an IP or hash embedded in a URL is accepted on its own and the
    enclosing URL is then dropped as overlapping — confirm this is intended.
    """
    # Pattern index -> required digest length for the three hash patterns.
    hash_lengths = {0: 64, 1: 40, 2: 32}

    accepted = []
    accepted_lookup = set()  # exact (start, end) pairs already accepted

    for idx, regex in enumerate(IOC_PATTERNS):
        for hit in regex.finditer(text):
            span = hit.span()
            candidate = hit.group()

            if overlaps_any(span[0], span[1], existing_intervals):
                continue
            if span in accepted_lookup:
                continue

            if idx in hash_lengths:
                wanted = hash_lengths[idx]
                if len(candidate) != wanted:
                    continue
                if not is_valid_hash(candidate, wanted):
                    continue
            elif idx == 3:
                # IPv4: reject out-of-range octets.
                if not is_valid_ip(candidate):
                    continue
            elif idx == 6:
                # Domain-like text needs the plausibility checks.
                if not is_valid_domain(text, candidate):
                    continue
            # Indices 4 and 5 (URLs) need no extra validation.

            if overlaps_any(span[0], span[1], accepted):
                continue

            accepted.append(span)
            accepted_lookup.add(span)

    return accepted
def add_indicator_spans(example, new_offsets):
    """Record new Indicator spans on *example* and return it.

    Spans are stored under keys of the form ``"Indicator: <span text>"``,
    where the span text is sliced from ``example['text']``; each key maps to
    a list of ``[start, end]`` pairs. The example is modified in place. A
    missing ``'spans'`` mapping is created, so examples that had no labels
    at all can still receive indicators.
    """
    spans = example.setdefault('spans', {})
    for start, end in new_offsets:
        span_text = example['text'][start:end]
        spans.setdefault(f"Indicator: {span_text}", []).append([start, end])
    return example
def is_exploitdb(example):
    """Return True when the example's source mentions ExploitDB.

    Looks at ``example['info']['source']`` and checks, case-insensitively,
    whether it contains the substring ``'exploitdb'``. A missing, null, or
    unexpectedly typed ``info`` or ``source`` counts as "not ExploitDB"
    rather than raising.
    """
    info = example.get('info')
    if not isinstance(info, dict):
        return False
    source = info.get('source')
    if not isinstance(source, str):
        return False
    return 'exploitdb' in source.lower()
def entity_density(example):
    """Return the share of characters in the text that lie inside any span.

    Overlapping spans are counted once because covered positions are
    collected in a set. An example with empty or missing text returns 0.
    """
    total_chars = len(example.get('text', ''))
    if not total_chars:
        return 0

    labelled_positions = set()
    for offset_list in example.get('spans', {}).values():
        for begin, finish in offset_list:
            labelled_positions.update(range(begin, finish))

    return len(labelled_positions) / total_chars
def remove_bare_extension_indicators(example):
    """Drop Indicator labels whose text is only a file extension.

    Span keys have the form ``"<Class>: <text>"``. For every key starting
    with ``"Indicator:"`` the remaining text is stripped and tested against
    BARE_EXTENSIONS; matching keys are deleted from ``example['spans']``
    in place.

    Returns:
        A tuple ``(example, removed)`` where *removed* is the number of
        individual offset pairs that were deleted. An example with no
        ``'spans'`` mapping (or an empty one) is returned untouched with a
        count of 0.
    """
    spans = example.get('spans')
    if not spans:
        return example, 0

    prefix = 'Indicator:'
    removed = 0
    # Iterate over a snapshot of the keys so entries can be deleted safely.
    for key in list(spans):
        if not key.startswith(prefix):
            continue
        span_text = key[len(prefix):].strip()
        if BARE_EXTENSIONS.match(span_text):
            removed += len(spans.pop(key))
    return example, removed
def clean_file(input_path, output_path, max_exploitdb=500):
    """Clean one JSONL file and write the result to *output_path*.

    Three clean-ups are applied:
      1. ExploitDB examples are randomly downsampled to at most
         *max_exploitdb* (uses the module-level ``random`` state).
      2. IOCs found by regex that are not yet labelled get Indicator spans.
      3. Indicator labels that are bare file extensions are removed.

    Args:
        input_path: Path of the JSONL file to read; blank lines are ignored.
        output_path: Path of the JSONL file to write.
        max_exploitdb: Upper bound on ExploitDB examples kept (default 500).

    Returns:
        A dict of counters: ``total``, ``iocs_added``,
        ``examples_with_new_iocs``, ``exploitdb_removed``, ``exploitdb_kept``,
        ``extension_labels_removed`` and ``output``.
    """
    stats = {
        'total': 0,
        'iocs_added': 0,
        'examples_with_new_iocs': 0,
        'exploitdb_removed': 0,
        'exploitdb_kept': 0,
        'extension_labels_removed': 0,
        'output': 0,
    }

    # First pass: load everything and note which examples are ExploitDB.
    # Indices are positions in `examples`, not file line numbers, so that
    # skipped blank lines cannot misalign them with the second pass.
    examples = []
    exploitdb_indices = []
    with open(input_path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            ex = json.loads(line)
            if is_exploitdb(ex):
                exploitdb_indices.append(len(examples))
            examples.append(ex)
    stats['total'] = len(examples)

    # Downsample ExploitDB; everything not sampled is dropped.
    sample_size = min(max(0, max_exploitdb), len(exploitdb_indices))
    keep_exploitdb = set(random.sample(exploitdb_indices, sample_size))
    drop_exploitdb = set(exploitdb_indices) - keep_exploitdb
    stats['exploitdb_removed'] = len(drop_exploitdb)
    stats['exploitdb_kept'] = len(keep_exploitdb)

    # Output is written with ensure_ascii=False, so the encoding is pinned to
    # UTF-8 instead of depending on the platform default.
    with open(output_path, 'w', encoding='utf-8') as out:
        for i, ex in enumerate(examples):
            if i in drop_exploitdb:
                continue

            # Add Indicator spans for IOCs that were left unlabelled.
            existing = get_existing_spans(ex)
            new_iocs = find_unlabeled_iocs(ex['text'], existing)
            if new_iocs:
                ex = add_indicator_spans(ex, new_iocs)
                stats['iocs_added'] += len(new_iocs)
                stats['examples_with_new_iocs'] += 1

            # Remove Indicator labels that are only a file extension.
            ex, ext_removed = remove_bare_extension_indicators(ex)
            stats['extension_labels_removed'] += ext_removed

            out.write(json.dumps(ex, ensure_ascii=False) + '\n')
            stats['output'] += 1

    return stats
def main():
    """Clean the train and valid splits found in the processed-data directory.

    For each split the input file ``enriched_5class_<split>.jsonl`` is cleaned
    into ``enriched_5class_<split>_cleaned.jsonl`` next to it, and the counters
    returned by clean_file() are printed. A split whose input file does not
    exist is reported and skipped.
    """
    data_dir = Path('/home/ubuntu/alkyline/data/processed')

    for split in ('train', 'valid'):
        source = data_dir / f'enriched_5class_{split}.jsonl'
        target = data_dir / f'enriched_5class_{split}_cleaned.jsonl'

        if source.exists():
            print(f"\n{'='*60}")
            print(f"Cleaning: {source.name} -> {target.name}")
            print(f"{'='*60}")

            report = clean_file(source, target)
            for name, value in report.items():
                print(f" {name}: {value}")
        else:
            print(f"SKIP: {source} not found")


# Run the clean-up only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|