arcspan / scripts /clean_training_data.py
chairulridjal's picture
Add files using upload-large-folder tool
3dac39e verified
#!/usr/bin/env python3
"""Clean training data: fix unlabeled IOCs, downsample ExploitDB, remove noisy labels."""
import json
import random
import re
import sys
from pathlib import Path
# Fixed seed so the random.sample() downsampling of ExploitDB examples in
# clean_file() picks the same examples on every run.
random.seed(42)
# --- IOC Regexes ---
# NOTE: order is load-bearing. find_unlabeled_iocs() dispatches on the list
# index (0-2 hashes, 3 IPv4, 4-5 URLs, 6 domains) and a later match that
# overlaps an already-accepted one is discarded, so longer / more specific
# patterns must come first.
IOC_PATTERNS = [
    # SHA256 (64 hex chars, word boundary)
    re.compile(r'\b[a-fA-F0-9]{64}\b'),
    # SHA1 (40 hex chars)
    re.compile(r'\b[a-fA-F0-9]{40}\b'),
    # MD5 (32 hex chars)
    re.compile(r'\b[a-fA-F0-9]{32}\b'),
    # IPv4 (shape only; octet range is checked separately by is_valid_ip())
    re.compile(r'\b(?:\d{1,3}\.){3}\d{1,3}\b'),
    # URL (http/https/hxxp/hxxps); runs until whitespace, <, >, ", ', ) or ]
    re.compile(r'(?:https?|hxxps?)://[^\s<>"\')\]]+'),
    # Defanged URL
    # NOTE(review): the pattern above already accepts hxxp/hxxps, so every
    # match here coincides with an earlier one and is discarded as an
    # overlap; kept so the index-based dispatch stays stable.
    re.compile(r'hxxps?://[^\s<>"\')\]]+'),
    # Domain-like (at least one dot, TLD 2-10 chars, not all digits)
    re.compile(r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,10}\b'),
]
# Common non-IOC domains/words to skip.
# NOTE(review): this set has no effect as written. The domain regex never
# yields a match ending in '.', and every entry would also be rejected by
# is_valid_domain()'s length (< 5) or TLD check.
DOMAIN_SKIPLIST = {
    'e.g.', 'i.e.', 'et.al.', 'Fig.', 'fig.', 'vs.', 'etc.',
}
# Common TLDs for domain validation; is_valid_domain() compares the
# lower-cased last dot-separated label against this set.
# NOTE(review): 'ddns', 'duckdns' and 'no-ip' look like dynamic-DNS provider
# names rather than TLDs, and 'no-ip' contains a hyphen, so the domain
# regex's letters-only TLD group can never produce it.
VALID_TLDS = {
    'com', 'net', 'org', 'io', 'ru', 'cn', 'de', 'uk', 'fr', 'jp', 'kr',
    'info', 'biz', 'xyz', 'top', 'online', 'site', 'club', 'pro',
    'gov', 'edu', 'mil', 'int', 'co', 'us', 'ca', 'au', 'in', 'br',
    'it', 'es', 'nl', 'se', 'no', 'fi', 'dk', 'pl', 'cz', 'at', 'ch',
    'be', 'ie', 'pt', 'gr', 'hu', 'ro', 'bg', 'hr', 'sk', 'si', 'lt',
    'lv', 'ee', 'me', 'pw', 'tk', 'ml', 'ga', 'cf', 'gq',
    'onion', 'bit', 'cc', 'tv', 'ws', 'la', 'ly', 'su', 'ua', 'kz',
    'ddns', 'duckdns', 'no-ip',
}
# Words that look like domains but aren't.
# NOTE(review): not referenced by any code in this file.
FAKE_DOMAINS = {
    'Super Mario', 'e.g', 'i.e', 'et al', 'Fig', 'Remote Code',
}
# Bare file extensions that should NOT be Indicator: the whole span text must
# be a dot followed by 2-5 ASCII letters, with nothing before the dot.
BARE_EXTENSIONS = re.compile(r'^\.[a-zA-Z]{2,5}$') # .dll, .exe, .pdf, etc.
def get_existing_spans(example):
    """Flatten every labeled span of *example* into a list of (start, end) tuples.

    Offsets are read from ``example['spans']`` (a mapping of label key to a
    list of [start, end] pairs); a missing ``'spans'`` entry yields ``[]``.
    """
    spans_by_key = example.get('spans', {})
    return [
        (begin, finish)
        for offset_list in spans_by_key.values()
        for begin, finish in offset_list
    ]
def overlaps_any(start, end, intervals):
    """Return True if the half-open range [start, end) intersects any [s, e) in *intervals*.

    Ranges that merely touch (one ends exactly where the other starts) do not count.
    """
    return any(start < e and end > s for s, e in intervals)
def is_valid_domain(text, match_str):
    """Decide whether a domain-regex match is a plausible domain IOC.

    A match is rejected when any of these hold:
      * it is listed in DOMAIN_SKIPLIST;
      * it is shorter than 5 characters or contains no dot;
      * its lower-cased last label is not in VALID_TLDS;
      * it has at most two labels, all purely alphabetic and <= 3 chars
        (short word pairs such as "no.in").

    *text* is accepted for interface compatibility but is not used.
    """
    # All checks below are independent rejections, so their order does not
    # affect the result.
    if match_str in DOMAIN_SKIPLIST:
        return False
    if len(match_str) < 5 or '.' not in match_str:
        return False
    labels = match_str.rstrip('.').split('.')
    if labels[-1].lower() not in VALID_TLDS:
        return False
    looks_like_short_words = all(
        label.isalpha() and len(label) <= 3 for label in labels
    )
    return not (looks_like_short_words and len(labels) <= 2)
def is_valid_ip(match_str):
    """Return True if *match_str* is a dotted-quad IPv4 address with octets 0-255.

    Requires exactly four dot-separated parts, each made only of decimal
    digits. Malformed input returns False instead of raising ValueError.
    """
    parts = match_str.split('.')
    if len(parts) != 4:
        return False
    # isdecimal() rejects empty parts, signs and whitespace, which int() would
    # otherwise either accept or raise on.
    return all(p.isdecimal() and 0 <= int(p) <= 255 for p in parts)
def is_valid_hash(match_str, expected_len):
    """Return True if *match_str* plausibly is a hex digest of *expected_len* chars.

    Rejects strings of the wrong length, strings with non-hex characters, and
    low-entropy strings with fewer than 4 distinct characters (e.g. runs of a
    single repeated character), which are unlikely to be real hashes.
    """
    if len(match_str) != expected_len:
        return False
    distinct = set(match_str.lower())
    if not distinct <= set('0123456789abcdef'):
        return False
    return len(distinct) >= 4
def find_unlabeled_iocs(text, existing_intervals):
    """Find IOC matches in *text* that don't overlap existing spans.

    Patterns in IOC_PATTERNS are tried in list order (hashes, IPv4, URLs,
    domains). A match is discarded if it overlaps an existing labeled span,
    fails its type-specific validation, or overlaps a span already accepted
    from an earlier match. Earlier (longer / more specific) patterns therefore win.

    Args:
        text: Example text to scan.
        existing_intervals: (start, end) half-open offsets already labeled.

    Returns:
        List of new (start, end) half-open offsets, in acceptance order.
    """
    new_spans = []
    for i, pattern in enumerate(IOC_PATTERNS):
        for m in pattern.finditer(text):
            start, end = m.start(), m.end()
            match_str = m.group()
            if i in (4, 5):  # URLs
                # The URL character class also swallows sentence punctuation
                # ("see http://evil.com/x."), so trim it off the span.
                trimmed = match_str.rstrip('.,;:!?')
                if not trimmed.partition('://')[2]:
                    continue  # nothing left after the scheme
                end = start + len(trimmed)
                match_str = trimmed
            # Skip if overlaps existing span
            if overlaps_any(start, end, existing_intervals):
                continue
            # Validate by type
            if i <= 2:  # Hash patterns (SHA256, SHA1, MD5)
                expected = [64, 40, 32][i]
                if len(match_str) != expected:
                    continue
                if not is_valid_hash(match_str, expected):
                    continue
            elif i == 3:  # IPv4
                if not is_valid_ip(match_str):
                    continue
            elif i == 6:  # Domain
                if not is_valid_domain(text, match_str):
                    continue
            # Also covers exact duplicates (e.g. defanged URLs already found by
            # the general URL pattern), so no separate "seen" set is needed.
            if overlaps_any(start, end, new_spans):
                continue
            new_spans.append((start, end))
    return new_spans
def add_indicator_spans(example, new_offsets):
    """Add Indicator spans for *new_offsets* to *example* in place and return it.

    Spans are stored in ``example['spans']`` under keys of the form
    ``"Indicator: <span text>"`` (the "Class: text" convention), each mapping to
    a list of [start, end] offsets. ``example['spans']`` is created if missing.

    Args:
        example: Dict with a ``'text'`` string and optionally a ``'spans'`` dict.
        new_offsets: Iterable of (start, end) half-open offsets into the text.
    """
    text = example['text']
    spans = example.setdefault('spans', {})
    for start, end in new_offsets:
        # Lists (not tuples) to match the JSON-loaded offsets already present.
        spans.setdefault(f"Indicator: {text[start:end]}", []).append([start, end])
    return example
def is_exploitdb(example):
    """Return True if the example's ``info.source`` mentions ExploitDB.

    The check is a case-insensitive substring match. A missing or null
    ``info`` object, or a missing or null ``source``, is treated as "not
    ExploitDB" instead of raising.
    """
    info = example.get('info') or {}
    source = info.get('source') or ''
    return 'exploitdb' in str(source).lower()
def entity_density(example):
    """Return the fraction of characters in ``example['text']`` covered by any span.

    Overlapping spans are counted once. Empty or missing text yields 0.
    """
    text_len = len(example.get('text', ''))
    if not text_len:
        return 0
    covered = set()
    for offset_list in example.get('spans', {}).values():
        for begin, finish in offset_list:
            covered.update(range(begin, finish))
    return len(covered) / text_len
def remove_bare_extension_indicators(example):
    """Remove Indicator spans whose text is a bare file extension (e.g. ".dll").

    The span text is taken from the key itself (``"Indicator: <text>"``), and
    the whole key is deleted when that text matches BARE_EXTENSIONS.

    Returns:
        (example, removed): the same example, mutated in place, and the number
        of individual offsets that were removed. An example without a
        ``'spans'`` entry is returned unchanged with a count of 0.
    """
    prefix = 'Indicator:'
    spans = example.get('spans')
    if not spans:
        return example, 0
    keys_to_remove = [
        key for key in spans
        if key.startswith(prefix) and BARE_EXTENSIONS.match(key[len(prefix):].strip())
    ]
    removed = sum(len(spans[key]) for key in keys_to_remove)
    for key in keys_to_remove:
        del spans[key]
    return example, removed
def clean_file(input_path, output_path, max_exploitdb=500):
    """Clean a single JSONL file and write the result to *output_path*.

    Steps, in order:
      1. Load every non-blank line of *input_path* as a JSON example.
      2. Keep a random sample of at most *max_exploitdb* ExploitDB-sourced
         examples (drawn from the module-level ``random`` state) and drop the rest.
      3. For each kept example, add Indicator spans for unlabeled IOC matches.
      4. Remove Indicator spans whose text is a bare file extension.

    Args:
        input_path: Source JSONL file, read as UTF-8.
        output_path: Destination JSONL file, written as UTF-8 (overwritten).
        max_exploitdb: Maximum number of ExploitDB examples to keep.

    Returns:
        Dict of integer counters describing what was read, changed and written.
    """
    stats = {
        'total': 0,
        'iocs_added': 0,
        'examples_with_new_iocs': 0,
        'exploitdb_removed': 0,
        'exploitdb_kept': 0,
        'extension_labels_removed': 0,
        'output': 0,
    }
    # First pass: load examples and identify exploitdb ones for downsampling
    examples = []
    exploitdb_indices = []
    with open(input_path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            ex = json.loads(line)
            if is_exploitdb(ex):
                # Store the position in `examples`, not the raw line number,
                # so blank lines cannot shift which examples get dropped.
                exploitdb_indices.append(len(examples))
            examples.append(ex)
    stats['total'] = len(examples)

    # Fix 2: downsample exploitdb
    sample_size = max(0, min(max_exploitdb, len(exploitdb_indices)))
    keep_exploitdb = set(random.sample(exploitdb_indices, sample_size))
    drop_indices = set(exploitdb_indices) - keep_exploitdb  # set: O(1) lookups below
    stats['exploitdb_removed'] = len(drop_indices)
    stats['exploitdb_kept'] = len(keep_exploitdb)

    with open(output_path, 'w', encoding='utf-8') as out:
        for i, ex in enumerate(examples):
            if i in drop_indices:
                continue
            # Fix 1: Add unlabeled IOCs
            existing = get_existing_spans(ex)
            new_iocs = find_unlabeled_iocs(ex['text'], existing)
            if new_iocs:
                ex = add_indicator_spans(ex, new_iocs)
                stats['iocs_added'] += len(new_iocs)
                stats['examples_with_new_iocs'] += 1
            # Fix 3: Remove bare extension indicators
            ex, ext_removed = remove_bare_extension_indicators(ex)
            stats['extension_labels_removed'] += ext_removed
            out.write(json.dumps(ex, ensure_ascii=False) + '\n')
            stats['output'] += 1
    return stats
def main(argv=None):
    """Clean the train and valid splits of the enriched 5-class dataset.

    For each split, reads ``enriched_5class_<split>.jsonl`` from the data
    directory and writes ``enriched_5class_<split>_cleaned.jsonl`` next to it,
    then prints the per-file stats. Missing input files are reported and skipped.

    Args:
        argv: Arguments excluding the program name. Defaults to ``sys.argv[1:]``.
            An optional first argument overrides the default data directory.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    base = Path(args[0]) if args else Path('/home/ubuntu/alkyline/data/processed')
    for split in ['train', 'valid']:
        inp = base / f'enriched_5class_{split}.jsonl'
        outp = base / f'enriched_5class_{split}_cleaned.jsonl'
        if not inp.exists():
            print(f"SKIP: {inp} not found")
            continue
        print(f"\n{'='*60}")
        print(f"Cleaning: {inp.name} -> {outp.name}")
        print(f"{'='*60}")
        stats = clean_file(inp, outp)
        for k, v in stats.items():
            print(f" {k}: {v}")


if __name__ == '__main__':
    main()