# ============================================================
# PhishGuard AI - visual_analyzer.py
# Takes a screenshot of a webpage using a headless browser
# and analyzes it for visual phishing indicators.
#
# Screenshot parameters (from architecture doc 2.3):
#   Viewport:   1280×800 (standard desktop resolution)
#   Timeout:    10s (prevent hanging on slow/malicious pages)
#   Wait:       domcontentloaded (faster than networkidle)
#   Blocked:    fonts, media, video (60-70% faster load)
#   User-Agent: Chrome 120 string (avoid bot detection)
#
# Tier 4 is OPTIONAL, controlled by the env var ENABLE_VISUAL_TIER.
# Set ENABLE_VISUAL_TIER=1 to enable.
# Unset or set to 0 → tier 4 is skipped with "tier4_disabled".
#
# Render.com: If deploying with Playwright, your render.yaml
# build command must install Chromium deps. See render.yaml
# comments and the Dockerfile for required apt packages.
#
# Latency budget: < 200ms for screenshot capture
# ============================================================
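# A typical build step for the Playwright path might look like the sketch
# below (illustrative only; render.yaml and the Dockerfile remain the
# authoritative source for the exact commands and apt packages):
#
#   buildCommand: |
#     pip install -r requirements.txt
#     playwright install chromium          # browser binaries
#     # system libraries come from the Dockerfile's apt packages, or use:
#     # playwright install --with-deps chromium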
from __future__ import annotations

import os
import re
import time
import hashlib
import logging
from urllib.parse import urlparse

logger = logging.getLogger("phishguard.visual")
# ── Environment gate ─────────────────────────────────────────────────────────
ENABLE_VISUAL_TIER = os.environ.get("ENABLE_VISUAL_TIER", "0").strip().lower() in ("1", "true", "yes")

if not ENABLE_VISUAL_TIER:
    print("[PhishGuard] Tier 4 visual analysis DISABLED (set ENABLE_VISUAL_TIER=1 to enable)")
# ── Playwright availability ──────────────────────────────────────────────────
PLAYWRIGHT_AVAILABLE = False
if ENABLE_VISUAL_TIER:
    try:
        from playwright.async_api import async_playwright
        PLAYWRIGHT_AVAILABLE = True
        print("[PhishGuard] Playwright available → screenshot capture enabled")
    except ImportError:
        print("[PhishGuard] Playwright not installed → visual analysis will use heuristic-only mode")
# ── PIL availability ─────────────────────────────────────────────────────────
_pil_available = False
try:
    from PIL import Image
    import io as _io
    _pil_available = True
except ImportError:
    print("[PhishGuard] Pillow not available → color analysis disabled")

# ── Screenshot cache config ──────────────────────────────────────────────────
_CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "screenshots")
_CACHE_TTL = 24 * 60 * 60  # 24 hours in seconds
os.makedirs(_CACHE_DIR, exist_ok=True)
# ── Brand / financial keyword databases ──────────────────────────────────────
BRAND_DATABASE = {
    # brand_keyword → list of legitimate domains
    "paypal": ["paypal.com"],
    "apple": ["apple.com", "icloud.com"],
    "google": ["google.com", "gmail.com", "accounts.google.com"],
    "amazon": ["amazon.com", "amazon.co.uk", "aws.amazon.com"],
    "microsoft": ["microsoft.com", "live.com", "outlook.com", "office.com"],
    "netflix": ["netflix.com"],
    "facebook": ["facebook.com", "fb.com"],
    "instagram": ["instagram.com"],
    "chase": ["chase.com"],
    "wellsfargo": ["wellsfargo.com"],
    "bankofamerica": ["bankofamerica.com"],
    "citibank": ["citibank.com", "citi.com"],
    "hsbc": ["hsbc.com"],
    "hdfc": ["hdfcbank.com"],
    "icici": ["icicibank.com"],
    "sbi": ["onlinesbi.com", "sbi.co.in"],
}

FINANCIAL_BRANDS = {
    "paypal", "chase", "wellsfargo", "bankofamerica", "citibank",
    "hsbc", "hdfc", "icici", "sbi", "bank", "banking",
}

def _domain_hash(url: str) -> str:
    """Generate a stable hash for screenshot caching based on the domain."""
    try:
        parsed = urlparse(url if url.startswith("http") else "http://" + url)
        host = parsed.hostname or url
        return hashlib.sha256(host.encode()).hexdigest()[:16]
    except Exception:
        return hashlib.sha256(url.encode()).hexdigest()[:16]

def _get_root_domain(url: str) -> str:
    """Extract the root domain from a URL, e.g. https://login.paypal.com → paypal.com."""
    try:
        parsed = urlparse(url if url.startswith("http") else "http://" + url)
        host = (parsed.hostname or "").lower()
        if host.startswith("www."):
            host = host[4:]
        parts = host.split(".")
        # Keep three labels for common two-part public suffixes (co.uk, co.in, ...)
        # so domains like amazon.co.uk don't collapse to just "co.uk".
        if len(parts) >= 3 and parts[-2] in ("co", "com", "net", "org", "ac", "gov") and len(parts[-1]) == 2:
            return ".".join(parts[-3:])
        return ".".join(parts[-2:]) if len(parts) >= 2 else host
    except Exception:
        return ""

# ──────────────────────────────────────────────────────────────────────────────
# SCREENSHOT CAPTURE (with cache)
# ──────────────────────────────────────────────────────────────────────────────
def _get_cached_screenshot(url: str) -> bytes | None:
    """
    Check if a cached screenshot exists for this domain and is < 24 hours old.
    Returns the screenshot bytes or None.
    """
    dhash = _domain_hash(url)
    cache_path = os.path.join(_CACHE_DIR, f"{dhash}.png")
    if not os.path.exists(cache_path):
        return None

    # Check age
    age = time.time() - os.path.getmtime(cache_path)
    if age >= _CACHE_TTL:
        # Expired → delete stale cache
        try:
            os.remove(cache_path)
        except OSError:
            pass
        return None

    try:
        with open(cache_path, "rb") as f:
            data = f.read()
        logger.info(f"Screenshot cache HIT | url={url} | age={age:.0f}s")
        return data
    except Exception:
        return None

def _save_screenshot_cache(url: str, data: bytes):
    """Save screenshot bytes to cache as screenshots/<domain_hash>.png."""
    try:
        dhash = _domain_hash(url)
        cache_path = os.path.join(_CACHE_DIR, f"{dhash}.png")
        with open(cache_path, "wb") as f:
            f.write(data)
        logger.info(f"Screenshot cached | url={url} | path={cache_path}")
    except Exception as e:
        logger.warning(f"Screenshot cache write failed | error={e}")

async def take_screenshot(url: str) -> bytes | None:
    """
    Open the URL in a hidden (headless) browser and take a screenshot.
    The user never sees this browser window.

    Uses a 24-hour cache: if screenshots/<domain_hash>.png exists and is
    fresh, returns cached bytes without launching a browser.

    Returns: screenshot as bytes, or None if it fails.
    """
    # Gate: tier 4 disabled
    if not ENABLE_VISUAL_TIER:
        return None

    # Check cache first
    cached = _get_cached_screenshot(url)
    if cached is not None:
        return cached

    # Playwright not available → can't take a fresh screenshot
    if not PLAYWRIGHT_AVAILABLE:
        logger.warning(f"Screenshot skipped (no Playwright) | url={url}")
        return None
    try:
        async with async_playwright() as p:
            browser = await p.chromium.launch(headless=True)
            context = await browser.new_context(
                viewport={"width": 1280, "height": 800},
                ignore_https_errors=True,
                user_agent=(
                    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                    "AppleWebKit/537.36 (KHTML, like Gecko) "
                    "Chrome/120.0.0.0 Safari/537.36"
                ),
            )
            page = await context.new_page()
            # Block fonts and media to speed up loading (60-70% faster)
            await page.route(
                "**/*.{woff,woff2,ttf,mp4,mp3,wav}",
                lambda route: route.abort(),
            )
            await page.goto(url, timeout=10000, wait_until="domcontentloaded")
            # Metadata extraction (title, login forms) lives in
            # take_screenshot_with_metadata(); this path only captures pixels.
            screenshot = await page.screenshot(full_page=False)
            await browser.close()

            # Cache the screenshot for 24 hours
            if screenshot:
                _save_screenshot_cache(url, screenshot)
            return screenshot
    except Exception as e:
        logger.error(f"Screenshot failed | url={url} | error={e}")
        return None

async def take_screenshot_with_metadata(url: str) -> dict:
    """
    Enhanced screenshot capture that also extracts page metadata
    (title, login forms) for heuristic visual scoring.

    Returns: {
        "screenshot": bytes|None,
        "page_title": str,
        "has_password_field": bool,
        "uses_https": bool,
        "error": str|None
    }
    """
    result = {
        "screenshot": None,
        "page_title": "",
        "has_password_field": False,
        "uses_https": url.lower().startswith("https"),
        "error": None,
    }

    # Gate: tier 4 disabled
    if not ENABLE_VISUAL_TIER:
        result["error"] = "tier4_disabled"
        return result

    # Check screenshot cache (metadata won't be cached, just the image)
    cached = _get_cached_screenshot(url)
    if cached is not None:
        result["screenshot"] = cached
        # We can't get page metadata from cache, but we have the image
        return result

    if not PLAYWRIGHT_AVAILABLE:
        result["error"] = "playwright_not_available"
        return result
    try:
        async with async_playwright() as p:
            browser = await p.chromium.launch(headless=True)
            context = await browser.new_context(
                viewport={"width": 1280, "height": 800},
                ignore_https_errors=True,
                user_agent=(
                    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                    "AppleWebKit/537.36 (KHTML, like Gecko) "
                    "Chrome/120.0.0.0 Safari/537.36"
                ),
            )
            page = await context.new_page()
            await page.route(
                "**/*.{woff,woff2,ttf,mp4,mp3,wav}",
                lambda route: route.abort(),
            )
            await page.goto(url, timeout=10000, wait_until="domcontentloaded")

            # Extract metadata
            result["page_title"] = await page.title() or ""
            result["has_password_field"] = await page.locator("input[type='password']").count() > 0

            screenshot = await page.screenshot(full_page=False)
            await browser.close()

            result["screenshot"] = screenshot
            # Cache the screenshot
            if screenshot:
                _save_screenshot_cache(url, screenshot)
    except Exception as e:
        result["error"] = str(e)
        logger.error(f"Screenshot+metadata failed | url={url} | error={e}")

    return result

# ──────────────────────────────────────────────────────────────────────────────
# VISUAL PHISHING HEURISTICS (no CNN needed)
# ──────────────────────────────────────────────────────────────────────────────
def analyze_visual_heuristic(url: str, page_title: str = "",
                             has_password_field: bool = False) -> dict:
    """
    Heuristic visual phishing scoring WITHOUT needing a trained CNN.

    Returns heuristic_visual_score from 0.0 to 1.0 based on:
        Signal 1: Page title contains brand names but domain doesn't match
        Signal 2: Page has a login form (input[type=password])
        Signal 3: SSL cert missing for pages mentioning financial brands
        Signal 4: Brand keyword in URL path but not in domain (path spoofing)

    Returns: {
        heuristic_visual_score: float 0..1,
        flags: list[str],
        brand_mismatch: bool,
        has_login_form: bool,
        ssl_missing_financial: bool
    }
    """
    score = 0.0
    flags = []
    brand_mismatch = False
    ssl_missing_financial = False

    root_domain = _get_root_domain(url)
    url_lower = url.lower()
    title_lower = (page_title or "").lower()
    uses_https = url_lower.startswith("https")

    # ── Signal 1: Brand name in page title but domain doesn't match ───────
    for brand, legit_domains in BRAND_DATABASE.items():
        if brand in title_lower:
            # Exact or parent-domain match; a plain substring test would let
            # lookalikes such as "mypaypal.com" pass as legitimate.
            if not any(root_domain == d or d.endswith("." + root_domain) for d in legit_domains):
                score += 0.30
                flags.append(f"title_brand_mismatch:{brand}")
                brand_mismatch = True
                break  # One brand mismatch is enough

    # ── Signal 2: Login form detected (input[type=password]) ──────────────
    if has_password_field:
        score += 0.15
        flags.append("has_password_field")
        # Extra risk if combined with brand mismatch
        if brand_mismatch:
            score += 0.15
            flags.append("login_form_with_brand_mismatch")

    # ── Signal 3: No SSL for financial brand content ──────────────────────
    mentions_financial = any(
        fb in title_lower or fb in url_lower
        for fb in FINANCIAL_BRANDS
    )
    if mentions_financial and not uses_https:
        score += 0.25
        flags.append("no_ssl_financial_content")
        ssl_missing_financial = True

    # ── Signal 4: Brand keyword in URL path but not in domain ─────────────
    try:
        parsed = urlparse(url)
        path = (parsed.path or "").lower()
        for brand, legit_domains in BRAND_DATABASE.items():
            # Same exact/parent-domain match as Signal 1
            if brand in path and not any(root_domain == d or d.endswith("." + root_domain) for d in legit_domains):
                score += 0.15
                flags.append(f"brand_in_path_not_domain:{brand}")
                break
    except Exception:
        pass

    return {
        "heuristic_visual_score": round(min(score, 1.0), 4),
        "flags": flags,
        "brand_mismatch": brand_mismatch,
        "has_login_form": has_password_field,
        "ssl_missing_financial": ssl_missing_financial,
    }
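
# Worked example (hypothetical URL, for illustration only): a page titled
# "PayPal Login" with a password field, served over plain HTTP from a
# non-PayPal domain, accumulates
#   0.30 (title_brand_mismatch) + 0.15 (has_password_field)
#   + 0.15 (login_form_with_brand_mismatch) + 0.25 (no_ssl_financial_content)
#   = 0.85
#
#   analyze_visual_heuristic(
#       "http://paypal-secure-login.example.net/signin",
#       page_title="PayPal Login",
#       has_password_field=True,
#   )
#   # → {"heuristic_visual_score": 0.85, "brand_mismatch": True, ...}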

def analyze_visual_basic(screenshot_bytes: bytes, url: str) -> dict:
    """
    Basic visual analysis using color histograms.
    Detects if a page uses colors associated with known brands
    but the URL doesn't match that brand.

    Note: For full CNN analysis, see cnn/cnn_model.py
    """
    if not screenshot_bytes:
        return {"visual_risk": 0.1, "note": "screenshot_failed"}
    if not _pil_available:
        return {"visual_risk": 0.1, "note": "pil_not_available"}

    try:
        img = Image.open(_io.BytesIO(screenshot_bytes)).convert("RGB")
        img_small = img.resize((224, 224))

        # Get average value per color channel
        r_band, g_band, b_band = img_small.split()
        r_vals = list(r_band.getdata())
        g_vals = list(g_band.getdata())
        b_vals = list(b_band.getdata())
        r_avg = sum(r_vals) / len(r_vals)
        g_avg = sum(g_vals) / len(g_vals)
        b_avg = sum(b_vals) / len(b_vals)

        risk = 0.2  # baseline
        url_lower = url.lower()

        # PayPal brand colors: deep blue
        if b_avg > r_avg * 1.4 and b_avg > g_avg * 1.3:
            if "paypal" not in url_lower:
                risk += 0.25

        # Microsoft brand colors: orange/blue
        if r_avg > 180 and b_avg < 100:
            if "microsoft" not in url_lower and "office" not in url_lower:
                risk += 0.20

        # Apple brand: mostly white/grey
        if r_avg > 220 and g_avg > 220 and b_avg > 220:
            if "apple" not in url_lower:
                risk += 0.10

        return {
            "visual_risk": round(min(risk, 1.0), 4),
            "dominant_rgb": [round(r_avg), round(g_avg), round(b_avg)],
            "note": "basic_color_analysis"
        }
    except Exception as e:
        logger.warning(f"Color analysis failed | url={url} | error={e}")
        return {"visual_risk": 0.1, "note": "analysis_error"}

# ──────────────────────────────────────────────────────────────────────────────
# FULL TIER 4 ANALYSIS (combines CNN + heuristics + color)
# ──────────────────────────────────────────────────────────────────────────────
async def run_tier4_analysis(url: str, page_title: str = "",
                             page_snippet: str = "") -> dict:
    """
    Complete Tier 4 visual analysis pipeline.
    Called by main.py for borderline cases (0.40 ≤ P < 0.85).

    Graceful fallback chain:
        1. If ENABLE_VISUAL_TIER is off → tier4_disabled
        2. If screenshot fails → screenshot_failed (with heuristic fallback)
        3. If CNN fails → uses heuristic_visual_score only

    Returns: {
        tier4_score: float|None,
        tier4_status: str ("ok"|"screenshot_failed"|"tier4_disabled"|...),
        tier4_reason: str,
        visual_heuristic: dict,
        color_analysis: dict,
        screenshot_cached: bool
    }
    """
    # ── Gate: completely skip if not enabled ──────────────────────────────
    if not ENABLE_VISUAL_TIER:
        return {
            "tier4_score": None,
            "tier4_status": "tier4_disabled",
            "tier4_reason": "ENABLE_VISUAL_TIER env var not set",
        }

    # ── Attempt screenshot with metadata extraction ───────────────────────
    meta = await take_screenshot_with_metadata(url)
    screenshot = meta["screenshot"]
    extracted_title = meta["page_title"] or page_title
    has_password = meta["has_password_field"]
    screenshot_error = meta["error"]

    # ── Always run visual heuristics (no screenshot needed) ───────────────
    heuristic = analyze_visual_heuristic(
        url,
        page_title=extracted_title,
        has_password_field=has_password,
    )

    # ── Screenshot failed → return heuristic-only result ──────────────────
    if screenshot is None:
        reason = screenshot_error or "unknown_screenshot_error"
        return {
            "tier4_score": heuristic["heuristic_visual_score"],
            "tier4_status": "screenshot_failed",
            "tier4_reason": reason,
            "visual_heuristic": heuristic,
            "color_analysis": None,
            "screenshot_cached": False,
        }

    # ── Color-based analysis (works without trained CNN) ──────────────────
    color = analyze_visual_basic(screenshot, url)

    # ── Combine heuristic + color into a single tier4 score ───────────────
    # Weight: 60% heuristic, 40% color (since CNN isn't trained)
    combined = (heuristic["heuristic_visual_score"] * 0.60) + (color["visual_risk"] * 0.40)
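    # e.g. (illustrative numbers) heuristic 0.85 and color 0.45 →
    # 0.85 * 0.60 + 0.45 * 0.40 = 0.51 + 0.18 = 0.69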

    return {
        "tier4_score": round(min(combined, 1.0), 4),
        "tier4_status": "ok",
        "tier4_reason": "heuristic_and_color_analysis",
        "visual_heuristic": heuristic,
        "color_analysis": color,
        "screenshot_cached": _get_cached_screenshot(url) is not None,
    }
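
# ── Manual smoke test ─────────────────────────────────────────────────────────
# A minimal sketch of how this module can be exercised on its own; main.py is
# the real caller, and the target URL below is only a placeholder. Requires
# ENABLE_VISUAL_TIER=1 (and Playwright installed) to go beyond "tier4_disabled".
if __name__ == "__main__":
    import asyncio
    import json

    async def _demo() -> None:
        result = await run_tier4_analysis("https://example.com")
        print(json.dumps(result, indent=2, default=str))

    asyncio.run(_demo())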