| |
| """Clean training data: fix unlabeled IOCs, downsample ExploitDB, remove noisy labels.""" |
|
|
| import json |
| import random |
| import re |
| import sys |
| from pathlib import Path |
|
|
# Seed the module-level RNG so the ExploitDB downsampling done with
# random.sample() in clean_file() is reproducible across runs. main() cleans
# 'train' and then 'valid' from this same RNG stream, so the 'valid' sample
# also depends on the RNG state left behind by the 'train' pass.
random.seed(42)


# Regexes for indicators of compromise (IOCs).
# find_unlabeled_iocs() chooses the validator for each pattern by its
# *position* in this list (0-2 hashes, 3 IPv4, 4-5 URLs, 6 domains), so do not
# reorder, insert or remove entries without updating that function.
IOC_PATTERNS = [
    # 0: 64 hex chars (digest length of e.g. SHA-256)
    re.compile(r'\b[a-fA-F0-9]{64}\b'),
    # 1: 40 hex chars (digest length of e.g. SHA-1)
    re.compile(r'\b[a-fA-F0-9]{40}\b'),
    # 2: 32 hex chars (digest length of e.g. MD5)
    re.compile(r'\b[a-fA-F0-9]{32}\b'),
    # 3: dotted-quad IPv4 candidate; octet range is checked by is_valid_ip()
    re.compile(r'\b(?:\d{1,3}\.){3}\d{1,3}\b'),
    # 4: http(s) URLs, plus "hxxp(s)" defanged URLs
    re.compile(r'(?:https?|hxxps?)://[^\s<>"\')\]]+'),
    # 5: defanged hxxp(s) URLs only. NOTE(review): redundant -- every match it
    #    finds lies within a match of pattern 4; kept because indices are
    #    positional.
    re.compile(r'hxxps?://[^\s<>"\')\]]+'),
    # 6: hostname-like runs of dotted labels ending in a 2-10 letter suffix;
    #    filtered further by is_valid_domain().
    re.compile(r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,10}\b'),
]


# Exact match strings that is_valid_domain() rejects outright.
# NOTE(review): every entry ends with '.', but pattern 6 matches always end in
# a letter, so these entries appear unreachable; the length/TLD checks in
# is_valid_domain() reject these abbreviations anyway.
DOMAIN_SKIPLIST = {
    'e.g.', 'i.e.', 'et.al.', 'Fig.', 'fig.', 'vs.', 'etc.',
}


# Suffixes accepted (case-insensitively) as the last label of a domain match.
# NOTE(review): 'no-ip' contains a hyphen, which pattern 6's suffix part
# ([a-zA-Z]{2,10}) cannot match, so that entry is unreachable; 'ddns' and
# 'duckdns' look like dynamic-DNS provider labels rather than TLDs -- confirm.
VALID_TLDS = {
    'com', 'net', 'org', 'io', 'ru', 'cn', 'de', 'uk', 'fr', 'jp', 'kr',
    'info', 'biz', 'xyz', 'top', 'online', 'site', 'club', 'pro',
    'gov', 'edu', 'mil', 'int', 'co', 'us', 'ca', 'au', 'in', 'br',
    'it', 'es', 'nl', 'se', 'no', 'fi', 'dk', 'pl', 'cz', 'at', 'ch',
    'be', 'ie', 'pt', 'gr', 'hu', 'ro', 'bg', 'hr', 'sk', 'si', 'lt',
    'lv', 'ee', 'me', 'pw', 'tk', 'ml', 'ga', 'cf', 'gq',
    'onion', 'bit', 'cc', 'tv', 'ws', 'la', 'ly', 'su', 'ua', 'kz',
    'ddns', 'duckdns', 'no-ip',
}


# Known false-positive "domains". Not referenced anywhere in this file.
# NOTE(review): presumably used by another script or left over -- confirm
# before removing.
FAKE_DOMAINS = {
    'Super Mario', 'e.g', 'i.e', 'et al', 'Fig', 'Remote Code',
}


# Span text that is only a dot plus 2-5 ASCII letters (e.g. ".exe"); used by
# remove_bare_extension_indicators() to drop such Indicator labels.
BARE_EXTENSIONS = re.compile(r'^\.[a-zA-Z]{2,5}$')
|
|
|
|
def get_existing_spans(example):
    """Flatten every labelled span of *example* into a list of (start, end) tuples.

    Offsets are collected across all span keys in dict order. An example with
    no 'spans' entry yields an empty list.
    """
    return [
        (start, end)
        for offsets in example.get('spans', {}).values()
        for start, end in offsets
    ]
|
|
|
|
def overlaps_any(start, end, intervals):
    """Return True if the half-open range [start, end) intersects any [s, e) in *intervals*.

    Ranges that merely touch (one ends where the other starts) do not overlap.
    """
    return any(s < end and start < e for s, e in intervals)
|
|
|
|
def is_valid_domain(text, match_str):
    """Decide whether a domain-regex match is a plausible domain IOC.

    Args:
        text: the full example text. Not used by the current checks; kept so
            the call signature stays stable.
        match_str: the matched substring.

    Returns False when the match is shorter than 5 characters, contains no
    dot, is listed in DOMAIN_SKIPLIST, has a last label outside VALID_TLDS
    (compared lower-cased), or consists of at most two labels that are all
    alphabetic and at most 3 characters long (e.g. "the.com"). Otherwise True.
    """
    if len(match_str) < 5 or '.' not in match_str:
        return False
    if match_str in DOMAIN_SKIPLIST:
        return False

    labels = match_str.rstrip('.').split('.')
    if labels[-1].lower() not in VALID_TLDS:
        return False

    # Two short alphabetic words joined by a dot are far more often prose
    # than a real hostname.
    only_short_words = all(label.isalpha() and len(label) <= 3 for label in labels)
    return not (only_short_words and len(labels) <= 2)
|
|
|
|
def is_valid_ip(match_str):
    """Return True if every dot-separated field of *match_str* is an int in 0..255.

    The number of fields is not checked; the IPv4 regex that produces the
    candidates already guarantees four. Non-numeric fields raise ValueError
    when reached (evaluation stops at the first out-of-range field).
    """
    return all(0 <= octet <= 255 for octet in map(int, match_str.split('.')))
|
|
|
|
def is_valid_hash(match_str, expected_len):
    """Return True if *match_str* is a plausible hex digest of *expected_len* characters.

    Checks:
      * the length equals ``expected_len`` (pass ``None`` to skip this check);
      * every character is a hexadecimal digit (case-insensitive);
      * at least 4 distinct characters appear, which rejects degenerate runs
        such as ``"0" * 32`` or ``"ab" * 16`` that a fixed-width hex regex
        matches but that are not real digests.
    """
    if expected_len is not None and len(match_str) != expected_len:
        return False
    chars = set(match_str.lower())
    if not chars <= set('0123456789abcdef'):
        return False
    return len(chars) >= 4
|
|
|
|
def find_unlabeled_iocs(text, existing_intervals):
    """Find IOC matches in *text* that no existing span already covers.

    Every pattern in IOC_PATTERNS is run over the text. Matches are filtered
    by the validator for their pattern index:

      * 0-2 (hex digests): exact expected length and is_valid_hash();
      * 3 (IPv4): is_valid_ip();
      * 4-5 (URLs): trailing sentence punctuation is trimmed off;
      * 6 (domain-like): is_valid_domain().

    Candidates overlapping any of *existing_intervals* are discarded. Overlaps
    among the remaining candidates are resolved longest-first, so a URL such
    as ``http://1.2.3.4/x`` is kept whole instead of being reduced to the IP,
    hash or domain embedded in it. (The IP/hash patterns come before the URL
    patterns in IOC_PATTERNS, so first-come selection would split URLs.)

    Args:
        text: the example text to scan.
        existing_intervals: iterable of (start, end) offsets already labelled.

    Returns:
        Non-overlapping half-open (start, end) tuples, sorted by start.
    """
    hash_lengths = (64, 40, 32)  # expected length for IOC_PATTERNS[0..2]
    candidates = []  # (start, end, pattern_index)

    for idx, pattern in enumerate(IOC_PATTERNS):
        for m in pattern.finditer(text):
            start, end = m.span()
            match_str = m.group()

            if idx <= 2:
                expected = hash_lengths[idx]
                if len(match_str) != expected or not is_valid_hash(match_str, expected):
                    continue
            elif idx == 3:
                if not is_valid_ip(match_str):
                    continue
            elif idx in (4, 5):
                # The URL character class also swallows sentence punctuation
                # ("see http://x.com/a."), which is not part of the indicator.
                match_str = match_str.rstrip('.,;:!?')
                if match_str.endswith('://'):
                    continue  # nothing left after the scheme
                end = start + len(match_str)
            elif idx == 6:
                if not is_valid_domain(text, match_str):
                    continue

            if overlaps_any(start, end, existing_intervals):
                continue
            candidates.append((start, end, idx))

    # Longest first; ties broken by pattern priority then position, so the
    # selection is deterministic. Identical spans from patterns 4 and 5
    # collapse to one via the overlap check.
    candidates.sort(key=lambda c: (c[0] - c[1], c[2], c[0]))
    chosen = []
    for start, end, _ in candidates:
        if not overlaps_any(start, end, chosen):
            chosen.append((start, end))
    chosen.sort()
    return chosen
|
|
|
|
def add_indicator_spans(example, new_offsets):
    """Record each (start, end) in *new_offsets* as an Indicator span of *example*.

    Each span is stored under the key ``"Indicator: <span text>"``, where the
    span text is ``example['text'][start:end]``. Offsets are appended as
    ``[start, end]`` lists (JSON-serialisable). A missing ``'spans'`` dict is
    created, matching how get_existing_spans() treats it as optional. An offset
    pair already recorded under the same key is not added twice.

    The example is modified in place and also returned.
    """
    spans = example.setdefault('spans', {})
    text = example['text']
    for start, end in new_offsets:
        bucket = spans.setdefault(f"Indicator: {text[start:end]}", [])
        pair = [start, end]
        if pair not in bucket:
            bucket.append(pair)
    return example
|
|
|
|
def is_exploitdb(example):
    """Return True when example['info']['source'] contains 'exploitdb', ignoring case.

    A missing 'info' dict or 'source' field counts as not ExploitDB.
    """
    source_name = example.get('info', {}).get('source', '')
    return 'exploitdb' in source_name.lower()
|
|
|
|
def entity_density(example):
    """Return the fraction of character positions of example['text'] covered by spans.

    Positions covered by several spans are counted once. Offsets are not
    clipped to the text length. Returns the int 0 when the text is empty or
    missing.
    """
    n_chars = len(example.get('text', ''))
    if not n_chars:
        return 0
    covered_positions = {
        pos
        for offsets in example.get('spans', {}).values()
        for start, end in offsets
        for pos in range(start, end)
    }
    return len(covered_positions) / n_chars
|
|
|
|
def remove_bare_extension_indicators(example):
    """Drop Indicator spans whose text is only a file extension such as ".exe".

    A span key qualifies when it starts with ``"Indicator:"`` and the rest of
    the key, stripped of whitespace, matches BARE_EXTENSIONS. The example is
    modified in place. Examples without a ``'spans'`` entry are returned
    unchanged instead of raising KeyError.

    Returns:
        (example, removed) where *removed* is the number of offset pairs dropped.
    """
    spans = example.get('spans')
    if not spans:
        return example, 0

    prefix = 'Indicator:'
    doomed = [
        key for key in spans
        if key.startswith(prefix) and BARE_EXTENSIONS.match(key[len(prefix):].strip())
    ]
    removed = sum(len(spans.pop(key)) for key in doomed)
    return example, removed
|
|
|
|
def clean_file(input_path, output_path, *, max_exploitdb=500, rng=None):
    """Clean a single JSONL file and write the result to *output_path*.

    Steps:
      1. Keep at most ``max_exploitdb`` randomly chosen ExploitDB examples
         (per is_exploitdb()); all other examples are kept.
      2. For every kept example, drop Indicator spans that are bare file
         extensions, then add Indicator spans for regex-detected IOCs that no
         remaining span covers. Extensions are removed first so a noisy label
         such as ".ru" cannot block labelling the real domain around it.

    Args:
        input_path: JSONL input, one example per non-blank line (UTF-8).
        output_path: JSONL output (UTF-8), overwritten.
        max_exploitdb: cap on ExploitDB examples kept (default 500).
        rng: object providing ``sample(population, k)``; defaults to the
            module-level ``random`` module and its global state.

    Returns:
        dict of counters: total, iocs_added, examples_with_new_iocs,
        exploitdb_removed, exploitdb_kept, extension_labels_removed, output.
    """
    rng = random if rng is None else rng
    stats = {
        'total': 0,
        'iocs_added': 0,
        'examples_with_new_iocs': 0,
        'exploitdb_removed': 0,
        'exploitdb_kept': 0,
        'extension_labels_removed': 0,
        'output': 0,
    }

    examples = []
    exploitdb_indices = []
    with open(input_path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            ex = json.loads(line)
            # Record the index into `examples`, not the raw line number: blank
            # lines are skipped, so the two would otherwise drift apart and the
            # wrong records would be dropped below.
            if is_exploitdb(ex):
                exploitdb_indices.append(len(examples))
            examples.append(ex)

    stats['total'] = len(examples)

    keep_exploitdb = set(rng.sample(exploitdb_indices, min(max_exploitdb, len(exploitdb_indices))))
    drop_indices = set(exploitdb_indices) - keep_exploitdb  # set: O(1) lookups below
    stats['exploitdb_removed'] = len(drop_indices)
    stats['exploitdb_kept'] = len(keep_exploitdb)

    # Explicit UTF-8: output is written with ensure_ascii=False, which would
    # fail or mis-encode under a non-UTF-8 locale default.
    with open(output_path, 'w', encoding='utf-8') as out:
        for i, ex in enumerate(examples):
            if i in drop_indices:
                continue

            ex, ext_removed = remove_bare_extension_indicators(ex)
            stats['extension_labels_removed'] += ext_removed

            new_iocs = find_unlabeled_iocs(ex.get('text', ''), get_existing_spans(ex))
            if new_iocs:
                ex = add_indicator_spans(ex, new_iocs)
                stats['iocs_added'] += len(new_iocs)
                stats['examples_with_new_iocs'] += 1

            out.write(json.dumps(ex, ensure_ascii=False) + '\n')
            stats['output'] += 1

    return stats
|
|
|
|
def main(argv=None):
    """Clean the train and valid splits of the enriched 5-class dataset.

    For each split, reads ``enriched_5class_<split>.jsonl`` from the data
    directory and writes ``enriched_5class_<split>_cleaned.jsonl`` next to it,
    printing the counters returned by clean_file(). Missing inputs are skipped.

    Args:
        argv: optional argument list, defaulting to ``sys.argv[1:]``. If it has
            a first element, that path is used as the data directory instead
            of the built-in default, so the script is not tied to one machine.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    base = Path(args[0]) if args else Path('/home/ubuntu/alkyline/data/processed')

    for split in ['train', 'valid']:
        inp = base / f'enriched_5class_{split}.jsonl'
        outp = base / f'enriched_5class_{split}_cleaned.jsonl'
        if not inp.exists():
            print(f"SKIP: {inp} not found")
            continue
        print(f"\n{'='*60}")
        print(f"Cleaning: {inp.name} -> {outp.name}")
        print(f"{'='*60}")
        stats = clean_file(inp, outp)
        for k, v in stats.items():
            print(f" {k}: {v}")
|
|
|
|
# Run the cleaner only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|
|