Spaces:
Runtime error
Runtime error
"""
Parallel HTML Feature Extraction Pipeline
Processes ~80k HTML files using multiprocessing for CPU-bound parsing.
The optional quality filter runs on a single BeautifulSoup parse per file;
HTMLFeatureExtractor then parses the HTML again internally for extraction.
Includes progress tracking (tqdm) and optional class balancing of the output.
Usage:
    python scripts/feature_extraction/html/extract_features.py
    python scripts/feature_extraction/html/extract_features.py --no-filter
    python scripts/feature_extraction/html/extract_features.py --workers 8
"""
| import argparse | |
| import json | |
| import logging | |
| import sys | |
| import time | |
| from concurrent.futures import ProcessPoolExecutor, as_completed | |
| from pathlib import Path | |
| import pandas as pd | |
| from tqdm import tqdm | |
# ---------------------------------------------------------------------------
# Make project-local packages importable regardless of the working directory.
# This script sits at <root>/scripts/feature_extraction/html/, so the project
# root (src/) is the fourth ancestor of the resolved file path.
# ---------------------------------------------------------------------------
PROJECT_ROOT = Path(__file__).resolve().parents[3]
sys.path.insert(0, str(PROJECT_ROOT))
from scripts.feature_extraction.html.html_feature_extractor import (  # noqa: E402
    HTMLFeatureExtractor,
)
# ---------------------------------------------------------------------------
# Logging: INFO-level, timestamped "HH:MM:SS - LEVEL - message" lines.
# ---------------------------------------------------------------------------
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger('extract_features')
# ---------------------------------------------------------------------------
# Quality-filter thresholds
# ---------------------------------------------------------------------------
MIN_FILE_SIZE = 800  # smallest file (bytes) worth parsing
MIN_TAGS = 8         # minimum number of HTML elements in the parsed page
MIN_WORDS = 30       # minimum number of whitespace-separated visible words

# Lower-case phrases; the filter matches them against lower-cased page text
# to reject error, parked and placeholder pages.
ERROR_PATTERNS = [
    'page not found',
    '404 not found',
    '403 forbidden',
    'access denied',
    'server error',
    'not available',
    'domain for sale',
    'website expired',
    'coming soon',
    'under construction',
    'parked domain',
    'buy this domain',
    'domain has expired',
    'this site can',
]
# ---------------------------------------------------------------------------
# Worker function – runs in a subprocess
# ---------------------------------------------------------------------------
def _parse_html(raw: str):
    """Parse *raw* with lxml, falling back to the stdlib 'html.parser'."""
    # Imported lazily so workers only pay for bs4 when the filter is enabled.
    from bs4 import BeautifulSoup
    try:
        return BeautifulSoup(raw, 'lxml')
    except Exception:
        # e.g. lxml not installed (bs4 raises FeatureNotFound) or lxml failed.
        return BeautifulSoup(raw, 'html.parser')


def _passes_quality_filter(soup) -> bool:
    """Return True if a parsed page looks like real content rather than a stub.

    Rejects pages that have no <body>, fewer than MIN_TAGS elements, fewer
    than MIN_WORDS visible words, contain any ERROR_PATTERNS phrase in the
    first 2000 characters of their lower-cased text, or have no content
    elements (<a>, <form>, <input>, <img>, or more than three <div>s).
    """
    if soup.find('body') is None:
        return False
    # limit= stops the tree walk as soon as the threshold is reached.
    if len(soup.find_all(limit=MIN_TAGS)) < MIN_TAGS:
        return False
    text = soup.get_text(separator=' ', strip=True).lower()
    if len(text.split()) < MIN_WORDS:
        return False
    # Only the head of the text is checked for error-page phrases.
    text_head = text[:2000]
    if any(pat in text_head for pat in ERROR_PATTERNS):
        return False
    # Must have at least some content elements; find() stops at the first match.
    return (
        soup.find('a') is not None
        or soup.find('form') is not None
        or soup.find('input') is not None
        or soup.find('img') is not None
        or len(soup.find_all('div', limit=4)) > 3
    )


def _process_file(args: tuple) -> dict | None:
    """
    Process a single HTML file: (optionally filter) → read → extract.

    Designed to run inside ProcessPoolExecutor – must be picklable
    (top-level function, not a method).

    Args:
        args: (file_path_str, label, apply_filter)

    Returns:
        Feature dict with 'filename' and 'label' added, or None when the file
        is filtered out or any error occurs (errors are logged at DEBUG level).
    """
    file_path_str, label, apply_filter = args
    try:
        path = Path(file_path_str)
        # Cheap pre-filter on the on-disk size (MIN_FILE_SIZE is in bytes);
        # tiny files are skipped without being read at all.
        if apply_filter and path.stat().st_size < MIN_FILE_SIZE:
            return None

        raw = path.read_text(encoding='utf-8', errors='ignore')

        # The soup is only needed by the filter, so skip the parse entirely
        # when filtering is disabled.
        if apply_filter and not _passes_quality_filter(_parse_html(raw)):
            return None

        # NOTE: HTMLFeatureExtractor parses `raw` again internally.
        extractor = HTMLFeatureExtractor()
        features = extractor.extract_features(raw)
        features['filename'] = path.name
        features['label'] = label
        return features
    except Exception:
        # Best-effort per file: one bad file must not abort the whole run, but
        # keep the traceback available for debugging extractor failures.
        logger.debug("Skipping %s: processing failed", file_path_str, exc_info=True)
        return None
# ---------------------------------------------------------------------------
# Directory processor
# ---------------------------------------------------------------------------
def extract_from_directory(
    html_dir: Path,
    label: int,
    apply_filter: bool = True,
    max_workers: int = 6,
    limit: int | None = None,
) -> list[dict]:
    """
    Extract features from all .html files in a directory using multiprocessing.

    Args:
        html_dir: Directory with .html files
        label: 0 = legitimate, 1 = phishing
        apply_filter: Apply quality filter
        max_workers: Number of parallel workers
        limit: Max results to return (None or 0 = all). When set, processing
            stops once that many results have been collected; which files make
            it in depends on completion order.

    Returns:
        List of feature dictionaries, sorted by filename.
    """
    html_files = sorted(html_dir.glob('*.html'))
    total = len(html_files)
    label_name = 'Phishing' if label == 1 else 'Legitimate'

    logger.info(f"\n{'='*60}")
    logger.info(f"Processing {label_name}: {total:,} files")
    logger.info(f"  Directory: {html_dir}")
    logger.info(f"  Quality filter: {'ON' if apply_filter else 'OFF'}")
    logger.info(f"  Workers: {max_workers}")
    logger.info(f"{'='*60}")

    tasks = [(str(f), label, apply_filter) for f in html_files]
    results: list[dict] = []
    n_done = 0
    # Counts every None from the worker: filtered files AND files that errored.
    n_filtered = 0
    t0 = time.perf_counter()

    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(_process_file, t) for t in tasks]
        with tqdm(total=total, desc=f'{label_name}', unit='file') as pbar:
            for future in as_completed(futures):
                pbar.update(1)
                n_done += 1
                result = future.result()
                if result is None:
                    n_filtered += 1
                    continue
                results.append(result)
                if limit and len(results) >= limit:
                    # Only pending futures can be cancelled; running ones
                    # finish before the executor's shutdown returns.
                    for f in futures:
                        f.cancel()
                    break

    elapsed = time.perf_counter() - t0
    # Rates are based on files actually processed, which is fewer than
    # `total` when `limit` stops the loop early.
    speed = n_done / elapsed if elapsed > 0 else 0

    # as_completed() yields in completion order, which varies between runs;
    # sort so downstream seeded shuffles are reproducible.
    results.sort(key=lambda r: r['filename'])

    if n_done < total:
        logger.info(f"  Stopped early after {n_done:,}/{total:,} files (limit={limit})")
    logger.info(f"  Extracted: {len(results):,} quality samples")
    logger.info(f"  Filtered out: {n_filtered:,} ({n_filtered/max(n_done,1)*100:.1f}%)")
    logger.info(f"  Time: {elapsed:.1f}s ({speed:.0f} files/sec)")
    return results
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse sys.argv."""
    parser = argparse.ArgumentParser(
        description='Extract HTML features for phishing detection (parallel)')
    parser.add_argument('--phishing-dir', type=str, nargs='+',
                        default=['data/html/phishing', 'data/html/phishing_v1'],
                        help='Directories with phishing HTML files')
    parser.add_argument('--legit-dir', type=str, nargs='+',
                        default=['data/html/legitimate', 'data/html/legitimate_v1'],
                        help='Directories with legitimate HTML files')
    parser.add_argument('--output', type=str, default='data/features/html_features.csv',
                        help='Output CSV path')
    parser.add_argument('--workers', type=int, default=6,
                        help='Number of parallel workers (default: 6)')
    parser.add_argument('--no-filter', action='store_true',
                        help='Disable quality filtering')
    parser.add_argument('--limit', type=int, default=None,
                        help='Limit samples per class (for testing)')
    parser.add_argument('--no-balance', action='store_true',
                        help='Do not balance classes')
    return parser.parse_args()


def _collect_class_features(
    dirs: list[Path],
    label: int,
    apply_filter: bool,
    max_workers: int,
    limit: int | None,
) -> list[dict]:
    """Extract one class's features from several directories, capping the class total at *limit*."""
    features: list[dict] = []
    for d in dirs:
        if not d.exists():
            continue  # already warned about in main()
        # The limit applies to the whole class, not to each directory.
        remaining = None if limit is None else limit - len(features)
        if remaining is not None and remaining <= 0:
            break
        features.extend(extract_from_directory(
            d, label=label, apply_filter=apply_filter,
            max_workers=max_workers, limit=remaining))
    return features


def _balance_classes(
    phishing: list[dict],
    legit: list[dict],
    seed: int = 42,
) -> tuple[list[dict], list[dict]]:
    """Shuffle both classes in place with a fixed seed and truncate each to the smaller size."""
    import random

    min_count = min(len(phishing), len(legit))
    logger.info(f"\nBalancing to {min_count:,} per class")
    if min_count == 0:
        logger.warning("One class has no samples; balancing drops every sample "
                       "(use --no-balance to keep the other class)")
    # A private RNG yields the same sequence as random.seed(seed) followed by
    # random.shuffle(), without reseeding the global generator.
    rng = random.Random(seed)
    rng.shuffle(phishing)
    rng.shuffle(legit)
    return phishing[:min_count], legit[:min_count]


def _log_feature_stats(df: pd.DataFrame, feature_cols: list[str], n_rows: int = 15) -> None:
    """Log mean/std/min/max for the first *n_rows* numeric feature columns."""
    numeric = df[feature_cols].select_dtypes(include='number')
    logger.info("\nFeature statistics (sample):")
    if numeric.empty:
        # describe() would raise (no columns) or lack 'mean'/'std' (no numeric columns).
        logger.info("  (no numeric feature columns)")
        return
    stats = numeric.describe().T[['mean', 'std', 'min', 'max']]
    logger.info(stats.head(n_rows).to_string())


def main():
    """Extract features for both classes, optionally balance, save CSV, and log a summary."""
    args = _parse_args()
    apply_filter = not args.no_filter
    # A falsy --limit (None or 0) means "no limit", matching extract_from_directory.
    per_class_limit = args.limit or None

    # Resolve paths relative to project root
    phishing_dirs = [(PROJECT_ROOT / d).resolve() for d in args.phishing_dir]
    legit_dirs = [(PROJECT_ROOT / d).resolve() for d in args.legit_dir]
    output_path = (PROJECT_ROOT / args.output).resolve()

    logger.info("=" * 70)
    logger.info("HTML FEATURE EXTRACTION PIPELINE")
    logger.info("=" * 70)
    for d in phishing_dirs:
        logger.info(f"  Phishing dir:   {d}")
    for d in legit_dirs:
        logger.info(f"  Legitimate dir: {d}")
    logger.info(f"  Output:         {output_path}")
    logger.info(f"  Workers:        {args.workers}")
    logger.info(f"  Quality filter: {'ON' if apply_filter else 'OFF'}")

    # Validate directories
    for d in phishing_dirs:
        if not d.exists():
            logger.warning(f"Phishing directory not found (skipping): {d}")
    for d in legit_dirs:
        if not d.exists():
            logger.warning(f"Legitimate directory not found (skipping): {d}")

    # ---- Extract features ----
    t_start = time.perf_counter()
    phishing_features = _collect_class_features(
        phishing_dirs, 1, apply_filter, args.workers, per_class_limit)
    legit_features = _collect_class_features(
        legit_dirs, 0, apply_filter, args.workers, per_class_limit)

    # Worker results arrive in completion order, which varies between runs;
    # sort by filename so the seeded shuffles below are reproducible.
    phishing_features.sort(key=lambda r: r['filename'])
    legit_features.sort(key=lambda r: r['filename'])

    # ---- Balance ----
    if not args.no_balance:
        phishing_features, legit_features = _balance_classes(
            phishing_features, legit_features)

    # ---- Build DataFrame ----
    all_features = phishing_features + legit_features
    if not all_features:
        logger.error("No features extracted!")
        sys.exit(1)

    df = pd.DataFrame(all_features)
    # Reorder columns: metadata first, then sorted features
    meta_cols = ['filename', 'label']
    feature_cols = sorted(c for c in df.columns if c not in meta_cols)
    df = df[meta_cols + feature_cols]
    # Shuffle rows
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)

    # ---- Save ----
    output_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(output_path, index=False)
    elapsed = time.perf_counter() - t_start

    # ---- Summary ----
    logger.info("\n" + "=" * 70)
    logger.info("EXTRACTION COMPLETE")
    logger.info("=" * 70)
    logger.info(f"  Total samples:  {len(df):,}")
    logger.info(f"    Phishing:     {(df['label']==1).sum():,}")
    logger.info(f"    Legitimate:   {(df['label']==0).sum():,}")
    logger.info(f"  Features:       {len(feature_cols)}")
    logger.info(f"  Total time:     {elapsed:.1f}s")
    logger.info(f"  Output:         {output_path}")
    logger.info("=" * 70)

    _log_feature_stats(df, feature_cols)


if __name__ == '__main__':
    main()