#!/usr/bin/env python3
"""Clean training data: fix unlabeled IOCs, downsample ExploitDB, remove noisy labels."""

import json
import random
import re
import sys
from pathlib import Path

# Seed the global RNG once at import so the ExploitDB downsample is repeatable.
random.seed(42)

# --- IOC Regexes ---
# Order matters: find_unlabeled_iocs() tries the patterns top to bottom and the
# first pattern to claim a region wins (later overlapping matches are dropped).
IOC_PATTERNS = [
    # SHA256 (64 hex chars, word boundary)
    re.compile(r'\b[a-fA-F0-9]{64}\b'),
    # SHA1 (40 hex chars)
    re.compile(r'\b[a-fA-F0-9]{40}\b'),
    # MD5 (32 hex chars)
    re.compile(r'\b[a-fA-F0-9]{32}\b'),
    # IPv4
    re.compile(r'\b(?:\d{1,3}\.){3}\d{1,3}\b'),
    # URL (http/https/hxxp/hxxps)
    re.compile(r'(?:https?|hxxps?)://[^\s<>"\')\]]+'),
    # Defanged URL (already covered by the pattern above; kept so the list
    # positions stay stable for anything that indexes IOC_PATTERNS)
    re.compile(r'hxxps?://[^\s<>"\')\]]+'),
    # Domain-like (at least one dot, TLD 2-10 chars, not all digits)
    re.compile(r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,10}\b'),
]

# Validation kind for each entry of IOC_PATTERNS, position by position.
_IOC_KINDS = ('hash', 'hash', 'hash', 'ip', 'url', 'url', 'domain')

# Common non-IOC domains/words to skip
DOMAIN_SKIPLIST = {
    'e.g.', 'i.e.', 'et.al.', 'Fig.', 'fig.', 'vs.', 'etc.',
}

# Common TLDs for domain validation
VALID_TLDS = {
    'com', 'net', 'org', 'io', 'ru', 'cn', 'de', 'uk', 'fr', 'jp', 'kr',
    'info', 'biz', 'xyz', 'top', 'online', 'site', 'club', 'pro',
    'gov', 'edu', 'mil', 'int', 'co', 'us', 'ca', 'au', 'in', 'br',
    'it', 'es', 'nl', 'se', 'no', 'fi', 'dk', 'pl', 'cz', 'at', 'ch',
    'be', 'ie', 'pt', 'gr', 'hu', 'ro', 'bg', 'hr', 'sk', 'si', 'lt',
    'lv', 'ee', 'me', 'pw', 'tk', 'ml', 'ga', 'cf', 'gq', 'onion',
    'bit', 'cc', 'tv', 'ws', 'la', 'ly', 'su', 'ua', 'kz',
    'ddns', 'duckdns', 'no-ip',
}

# Words that look like domains but aren't
FAKE_DOMAINS = {
    'Super Mario', 'e.g', 'i.e', 'et al', 'Fig', 'Remote Code',
}

# Bare file extensions that should NOT be Indicator
BARE_EXTENSIONS = re.compile(r'^\.[a-zA-Z]{2,5}$')  # .dll, .exe, .pdf, etc.

# Directory main() processes when no other directory is supplied.
DEFAULT_DATA_DIR = Path('/home/ubuntu/alkyline/data/processed')


def get_existing_spans(example):
    """Return every labeled span of *example* as a list of (start, end) tuples.

    Spans are read from ``example['spans']``, a mapping of label key to a list
    of ``[start, end]`` offsets. A missing or null ``'spans'`` yields ``[]``.
    """
    spans = example.get('spans') or {}
    return [
        (start, end)
        for offsets in spans.values()
        for start, end in offsets
    ]


def overlaps_any(start, end, intervals):
    """Return True if the half-open range [start, end) overlaps any (s, e) in *intervals*."""
    return any(start < e and end > s for s, e in intervals)


def is_valid_domain(text, match_str):
    """Return True if a domain-regex match looks like a plausible domain IOC.

    Args:
        text: The full text the match came from (currently unused; kept for
            interface compatibility).
        match_str: The matched candidate string.
    """
    # Very short matches and dot-less strings are never domains.
    if len(match_str) < 5 or '.' not in match_str:
        return False

    parts = match_str.rstrip('.').split('.')
    if parts[-1].lower() not in VALID_TLDS:
        return False

    # Abbreviations that merely look like a sentence fragment.
    if match_str in DOMAIN_SKIPLIST:
        return False

    # Two short all-alphabetic labels (e.g. "the.end") are more likely words
    # than domains.
    if len(parts) <= 2 and all(p.isalpha() and len(p) <= 3 for p in parts):
        return False

    return True


def is_valid_ip(match_str):
    """Return True if *match_str* is four dot-separated decimal octets in 0..255.

    Non-numeric or wrongly-sized input returns False rather than raising.
    """
    parts = match_str.split('.')
    if len(parts) != 4:
        return False
    return all(p.isdecimal() and int(p) <= 255 for p in parts)


def is_valid_hash(match_str, expected_len):
    """Return True unless the candidate hash is implausibly repetitive.

    A candidate with fewer than four distinct characters (case-insensitive) is
    rejected. *expected_len* is accepted for interface compatibility; the
    regexes already pin the exact length.
    """
    return len(set(match_str.lower())) >= 4


def find_unlabeled_iocs(text, existing_intervals):
    """Find IOC matches in *text* that do not overlap already-labeled spans.

    Patterns from IOC_PATTERNS are tried in list order and the first accepted
    match claims its region, so a hash or IPv4 address embedded in a URL is
    claimed first and the enclosing URL is then skipped as overlapping.

    Args:
        text: The text to scan.
        existing_intervals: Iterable of (start, end) spans that are already
            labeled and must not be overlapped.

    Returns:
        List of (start, end) tuples, in the order they were accepted.
    """
    new_spans = []
    for i, pattern in enumerate(IOC_PATTERNS):
        # Patterns beyond the known kinds are accepted without extra checks.
        kind = _IOC_KINDS[i] if i < len(_IOC_KINDS) else None
        for m in pattern.finditer(text):
            start, end = m.span()
            match_str = m.group()

            if overlaps_any(start, end, existing_intervals):
                continue

            if kind == 'hash' and not is_valid_hash(match_str, len(match_str)):
                continue
            if kind == 'ip' and not is_valid_ip(match_str):
                continue
            if kind == 'domain' and not is_valid_domain(text, match_str):
                continue
            # URLs are accepted as matched.

            # Also covers exact duplicates found by an earlier pattern.
            if overlaps_any(start, end, new_spans):
                continue

            new_spans.append((start, end))
    return new_spans


def add_indicator_spans(example, new_offsets):
    """Add an ``"Indicator: <text>"`` span to *example* for each new offset.

    The example is modified in place and also returned. A missing or null
    ``'spans'`` mapping is created first, matching what get_existing_spans()
    tolerates on the read side.
    """
    if example.get('spans') is None:
        example['spans'] = {}
    spans = example['spans']
    for start, end in new_offsets:
        span_text = example['text'][start:end]
        # Span keys follow the "Class: text" convention.
        spans.setdefault(f"Indicator: {span_text}", []).append([start, end])
    return example


def is_exploitdb(example):
    """Return True if ``example['info']['source']`` mentions 'exploitdb' (any case)."""
    info = example.get('info') or {}
    source = info.get('source') or ''
    return 'exploitdb' in source.lower()


def entity_density(example):
    """Return the fraction of the text's characters covered by at least one span.

    Returns 0 for an empty or missing text.
    """
    text_len = len(example.get('text', ''))
    if text_len == 0:
        return 0
    covered = set()
    for offsets in (example.get('spans') or {}).values():
        for start, end in offsets:
            covered.update(range(start, end))
    return len(covered) / text_len


def remove_bare_extension_indicators(example):
    """Drop Indicator spans whose text is only a file extension (e.g. ".dll").

    The example is modified in place.

    Returns:
        Tuple ``(example, removed)`` where *removed* is the number of offsets
        dropped. An example without ``'spans'`` is returned untouched.
    """
    spans = example.get('spans') or {}
    removed = 0
    # Iterate over a snapshot of the keys because entries are deleted.
    for key in list(spans):
        if not key.startswith('Indicator:'):
            continue
        span_text = key[len('Indicator:'):].strip()
        if BARE_EXTENSIONS.match(span_text):
            removed += len(spans[key])
            del spans[key]
    return example, removed


def clean_file(input_path, output_path, max_exploitdb=500):
    """Clean a single JSONL file and write the result to *output_path*.

    Steps: keep at most *max_exploitdb* randomly chosen ExploitDB examples,
    add Indicator spans for unlabeled IOCs, and remove bare-extension
    Indicator spans.

    Args:
        input_path: Path of the JSONL file to read; blank lines are ignored.
        output_path: Path of the JSONL file to write.
        max_exploitdb: Maximum number of ExploitDB examples to keep.

    Returns:
        Dict of counters describing what was changed.
    """
    stats = {
        'total': 0,
        'iocs_added': 0,
        'examples_with_new_iocs': 0,
        'exploitdb_removed': 0,
        'exploitdb_kept': 0,
        'extension_labels_removed': 0,
        'output': 0,
    }

    # First pass: load everything and note which examples are ExploitDB.
    examples = []
    exploitdb_indices = []
    with open(input_path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            ex = json.loads(line)
            if is_exploitdb(ex):
                # Record the position in `examples`, not the file line number:
                # the two differ as soon as a blank line has been skipped.
                exploitdb_indices.append(len(examples))
            examples.append(ex)

    stats['total'] = len(examples)

    # Downsample ExploitDB.
    keep_count = max(0, min(max_exploitdb, len(exploitdb_indices)))
    keep_exploitdb = set(random.sample(exploitdb_indices, keep_count))
    drop_exploitdb = set(exploitdb_indices) - keep_exploitdb
    stats['exploitdb_removed'] = len(drop_exploitdb)
    stats['exploitdb_kept'] = len(keep_exploitdb)

    # Output is written with ensure_ascii=False, so pin the encoding.
    with open(output_path, 'w', encoding='utf-8') as out:
        for i, ex in enumerate(examples):
            if i in drop_exploitdb:
                continue

            # Fix 1: Add unlabeled IOCs
            existing = get_existing_spans(ex)
            new_iocs = find_unlabeled_iocs(ex.get('text') or '', existing)
            if new_iocs:
                ex = add_indicator_spans(ex, new_iocs)
                stats['iocs_added'] += len(new_iocs)
                stats['examples_with_new_iocs'] += 1

            # Fix 3: Remove bare extension indicators
            ex, ext_removed = remove_bare_extension_indicators(ex)
            stats['extension_labels_removed'] += ext_removed

            out.write(json.dumps(ex, ensure_ascii=False) + '\n')
            stats['output'] += 1

    return stats


def main(base_dir=None):
    """Clean the train and valid splits found in *base_dir*.

    Args:
        base_dir: Directory holding ``enriched_5class_<split>.jsonl``.
            Defaults to DEFAULT_DATA_DIR when omitted.
    """
    base = DEFAULT_DATA_DIR if base_dir is None else Path(base_dir)

    for split in ['train', 'valid']:
        inp = base / f'enriched_5class_{split}.jsonl'
        outp = base / f'enriched_5class_{split}_cleaned.jsonl'

        if not inp.exists():
            print(f"SKIP: {inp} not found")
            continue

        print(f"\n{'='*60}")
        print(f"Cleaning: {inp.name} -> {outp.name}")
        print(f"{'='*60}")

        stats = clean_file(inp, outp)
        for k, v in stats.items():
            print(f" {k}: {v}")


if __name__ == '__main__':
    # An optional first CLI argument overrides the default data directory.
    main(sys.argv[1] if len(sys.argv) > 1 else None)