phishguard-api / data_collector.py
prashanth135's picture
Upload 38 files
bebe233 verified
# ============================================================
# PhishGuard AI - data_collector.py
# Downloads all training data from public HTTP endpoints.
# No API keys required.
#
# Datasets:
# 1. PhishTank (bz2 JSON → phishing URLs)
# 2. TRANCO Top-10K (zip CSV → legitimate domains)
# 3. Kaggle GitHub mirror (CSV → pre-extracted features)
# ============================================================
from __future__ import annotations
import bz2
import csv
import io
import json
import zipfile
import hashlib
import logging
from pathlib import Path
from typing import List, Tuple, Optional
import requests
import pandas as pd
from sklearn.model_selection import train_test_split
# Module logger; the name is given explicitly rather than derived from __name__.
logger = logging.getLogger("phishguard.data_collector")
# ── Data directory ────────────────────────────────────────────────────
# Every download and cache file lives in a "data" folder next to this module.
# NOTE: the directory is created as a side effect of importing this module.
DATA_DIR = Path(__file__).parent / "data"
DATA_DIR.mkdir(parents=True, exist_ok=True)
# ── Public URLs (no API keys) ────────────────────────────────────────
# PhishTank feed: bz2-compressed JSON, decoded by download_phishtank().
# NOTE(review): fetched over plain HTTP (no TLS) — consider https if the host supports it.
PHISHTANK_URL = "http://data.phishtank.com/data/online-valid.json.bz2"
# Tranco ranking: zip archive; download_tranco() reads its first member as "rank,domain" rows.
TRANCO_URL = "https://tranco-list.eu/top-1m.csv.zip"
# Pre-extracted feature CSVs; the backup is only tried if the primary fails.
KAGGLE_PRIMARY = "https://raw.githubusercontent.com/GregaVrbancic/Phishing-Dataset/master/dataset_full.csv"
KAGGLE_BACKUP = "https://raw.githubusercontent.com/datasets/phishing-websites/master/data.csv"
# Browser-like User-Agent sent with every request; presumably so the hosts do
# not reject the default python-requests UA — TODO confirm it is still needed.
HEADERS = {
    "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
    "AppleWebKit/537.36 (KHTML, like Gecko) "
    "Chrome/120.0.0.0 Safari/537.36"
}
def download_phishtank(max_urls: int = 30000) -> List[str]:
    """
    Download phishing URLs from the PhishTank public feed.

    Fetches the bz2 feed → decompresses → parses the JSON array → keeps
    records whose ``verified`` and ``online`` fields are both truthy
    ("yes"/"true"/"True"/"1"/True/1).

    The on-disk cache ``DATA_DIR / "phishing_urls.txt"`` is returned instead
    of downloading when it is larger than 1000 bytes and holds at least 100
    URLs (there is no age check). If the download or parsing fails, or
    yields no URLs, a non-empty cache is used; failing that, synthetic
    phishing-like URLs are generated.

    Args:
        max_urls: Maximum number of URLs to return; values <= 0 yield [].

    Returns:
        Up to ``max_urls`` phishing URLs.
    """
    if max_urls <= 0:
        # Previously the append-then-check loop still returned one URL here.
        return []
    logger.info("Downloading PhishTank data...")
    phish_cache = DATA_DIR / "phishing_urls.txt"
    # Values treated as "yes" for the verified/online fields.
    truthy = (True, "yes", "true", "True", "1", 1)

    # Serve a populated cache without touching the network.
    if phish_cache.exists() and phish_cache.stat().st_size > 1000:
        cached = phish_cache.read_text(encoding="utf-8").strip().splitlines()
        if len(cached) >= 100:
            logger.info(f"Using cached PhishTank data: {len(cached)} URLs")
            return cached[:max_urls]

    try:
        # bz2.decompress needs the whole body anyway; stream=True bought
        # nothing and left the connection unreleased if raise_for_status fired.
        resp = requests.get(PHISHTANK_URL, headers=HEADERS, timeout=120)
        resp.raise_for_status()
        records = json.loads(bz2.decompress(resp.content))
        if not isinstance(records, list):
            raise ValueError(f"expected a JSON array, got {type(records).__name__}")

        urls: List[str] = []
        for record in records:
            if not isinstance(record, dict):
                continue
            raw_url = record.get("url")
            # A null/non-string URL used to raise AttributeError and abort the
            # entire feed; skip just that record instead.
            if not isinstance(raw_url, str):
                continue
            url = raw_url.strip()
            if not url:
                continue
            is_verified = record.get("verified", "no") in truthy
            is_online = record.get("online", "no") in truthy
            if is_verified and is_online:
                urls.append(url)
                if len(urls) >= max_urls:
                    break
        if not urls:
            # Don't cache/return an empty result as if it were a success.
            raise ValueError("feed contained no verified+online URLs")
    except Exception as e:  # best-effort download: any failure falls back
        logger.warning(f"PhishTank download failed: {e}")
        if phish_cache.exists():
            cached = phish_cache.read_text(encoding="utf-8").strip().splitlines()
            # An empty cache file is no better than no cache at all.
            if cached:
                logger.info(f"Using fallback cached data: {len(cached)} URLs")
                return cached[:max_urls]
        logger.warning("Generating synthetic phishing URLs as fallback")
        return _generate_synthetic_phishing(min(500, max_urls))

    logger.info(f"PhishTank: {len(urls)} verified+online URLs extracted")
    try:
        # Explicit UTF-8: URLs may contain characters the locale codec can't encode.
        phish_cache.write_text("\n".join(urls), encoding="utf-8")
    except OSError as e:
        # A cache-write failure must not throw away a good download.
        logger.warning(f"Could not write PhishTank cache: {e}")
    return urls
def _generate_synthetic_phishing(count: int) -> List[str]:
"""Generate synthetic phishing URLs for training when real data unavailable."""
import random
brands = ["paypal", "google", "apple", "microsoft", "amazon", "netflix",
"facebook", "chase", "wellsfargo", "bankofamerica"]
tlds = [".xyz", ".tk", ".ml", ".ga", ".cf", ".gq", ".pw", ".top", ".click"]
keywords = ["login", "verify", "secure", "update", "account", "signin",
"reset", "confirm", "suspend", "banking", "alert", "password"]
urls: List[str] = []
for _ in range(count):
brand = random.choice(brands)
tld = random.choice(tlds)
kw = random.choice(keywords)
sep = random.choice(["-", ".", ""])
prefix = random.choice(["http://", "https://"])
sub = random.choice(["", "www.", "secure.", "login.", "m."])
urls.append(f"{prefix}{sub}{brand}{sep}{kw}{tld}/{kw}/index.html")
return urls
def download_tranco(n: int = 10000) -> List[str]:
    """
    Download the TRANCO Top-1M list and return the top-N domains as https:// URLs.

    Fetches the zip → reads its first member as CSV → takes the domain
    (second column of each ``rank,domain`` row) → keeps the first N rows.

    The cache ``DATA_DIR / "legitimate_urls.txt"`` is served instead of
    downloading when it is larger than 1000 bytes and holds at least
    ``min(n, 100)`` URLs. On download failure a non-empty cache is used;
    otherwise ``n`` synthetic URLs are generated.

    Args:
        n: Number of domains to return; values <= 0 yield [].

    Returns:
        Up to ``n`` URLs of the form ``https://<domain>``.
    """
    if n <= 0:
        # Previously the append-then-check loop still returned one URL here.
        return []
    logger.info(f"Downloading TRANCO top-{n} domains...")
    legit_cache = DATA_DIR / "legitimate_urls.txt"

    # Use cache if present
    if legit_cache.exists() and legit_cache.stat().st_size > 1000:
        cached = legit_cache.read_text(encoding="utf-8").strip().splitlines()
        if len(cached) >= min(n, 100):
            logger.info(f"Using cached TRANCO data: {len(cached)} domains")
            return cached[:n]

    try:
        resp = requests.get(TRANCO_URL, headers=HEADERS, timeout=60)
        resp.raise_for_status()
        with zipfile.ZipFile(io.BytesIO(resp.content)) as zf:
            csv_text = zf.read(zf.namelist()[0]).decode("utf-8")

        urls: List[str] = []
        # csv.reader handles quoting and CRLF endings that split(",") does not.
        for row in csv.reader(io.StringIO(csv_text)):
            if len(row) < 2:
                continue
            domain = row[1].strip()
            if domain:
                urls.append(f"https://{domain}")
                if len(urls) >= n:
                    break
        if not urls:
            raise ValueError("TRANCO archive contained no domains")
    except Exception as e:  # best-effort download: any failure falls back
        logger.warning(f"TRANCO download failed: {e}")
        if legit_cache.exists():
            cached = legit_cache.read_text(encoding="utf-8").strip().splitlines()
            # An empty cache file is no better than no cache at all.
            if cached:
                logger.info(f"Using fallback cached data: {len(cached)} domains")
                return cached[:n]
        logger.warning("Generating synthetic legitimate URLs as fallback")
        return _generate_synthetic_legitimate(n)

    logger.info(f"TRANCO: {len(urls)} legitimate domains extracted")
    try:
        legit_cache.write_text("\n".join(urls), encoding="utf-8")
    except OSError as e:
        # A cache-write failure must not throw away a good download.
        logger.warning(f"Could not write TRANCO cache: {e}")
    return urls
def _generate_synthetic_legitimate(count: int) -> List[str]:
"""Generate legitimate-looking URLs as fallback."""
top_domains = [
"google.com", "youtube.com", "facebook.com", "amazon.com",
"wikipedia.org", "twitter.com", "instagram.com", "linkedin.com",
"microsoft.com", "apple.com", "github.com", "stackoverflow.com",
"reddit.com", "netflix.com", "paypal.com", "yahoo.com", "bing.com",
"adobe.com", "dropbox.com", "zoom.us", "slack.com", "spotify.com",
"twitch.tv", "ebay.com", "walmart.com", "target.com", "cnn.com",
"bbc.com", "nytimes.com", "medium.com",
]
urls = [f"https://{d}" for d in top_domains]
# Pad with numbered subpages
while len(urls) < count:
d = top_domains[len(urls) % len(top_domains)]
urls.append(f"https://{d}/page/{len(urls)}")
return urls[:count]
# Column names recognised as the label column, in priority order.
_KAGGLE_LABEL_CANDIDATES = ("CLASS_LABEL", "class_label", "Result", "result", "label")
# Text labels (compared after strip + lower-casing) → 0 legitimate / 1 phishing.
_KAGGLE_TEXT_LABELS = {
    "legitimate": 0,
    "safe": 0,
    "benign": 0,
    "phishing": 1,
    "malicious": 1,
}


def _standardize_kaggle_labels(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return ``df`` with its label column renamed to CLASS_LABEL and coerced to 0/1 ints.

    Numeric labels, including numeric strings, map ``> 0`` → 1 and ``<= 0`` → 0,
    so -1/1 and 0/1 encodings both work. Text labels are looked up in
    ``_KAGGLE_TEXT_LABELS``. Rows whose label is missing or unrecognised are
    dropped with a warning. Previously NaN became phishing and unknown
    text became legitimate, silently mislabelling data.

    Raises:
        IndexError: if ``df`` has no columns at all.
    """
    for col in _KAGGLE_LABEL_CANDIDATES:
        if col in df.columns:
            df = df.rename(columns={col: "CLASS_LABEL"})
            break
    else:
        # No known label name: assume the last column is the label.
        df = df.rename(columns={df.columns[-1]: "CLASS_LABEL"})

    raw = df["CLASS_LABEL"]
    numeric = pd.to_numeric(raw, errors="coerce")
    # NOTE(review): treating -1 as legitimate follows the original code; confirm
    # the sign convention of each mirror's label column.
    # (NaN > 0) is False, so mask non-numeric entries back to NaN afterwards.
    labels = (numeric > 0).astype(float).where(numeric.notna())
    if not pd.api.types.is_numeric_dtype(raw):
        text = raw.astype(str).str.strip().str.lower()
        labels = labels.fillna(text.map(_KAGGLE_TEXT_LABELS))

    unknown = labels.isna()
    if unknown.any():
        logger.warning(f"Dropping {int(unknown.sum())} rows with unrecognised labels")
        df = df.loc[~unknown].reset_index(drop=True)
        labels = labels.loc[~unknown].reset_index(drop=True)
    df["CLASS_LABEL"] = labels.astype(int)
    return df


def download_kaggle_mirror() -> pd.DataFrame:
    """
    Download pre-extracted URL features from a Kaggle GitHub mirror.

    Uses ``DATA_DIR / "kaggle_features.csv"`` when it is larger than 1000
    bytes. Otherwise tries KAGGLE_PRIMARY, then KAGGLE_BACKUP. A mirror whose
    download, parsing or labelling fails, or which yields no labelled rows,
    is skipped.

    Returns:
        DataFrame of features plus an int 0/1 CLASS_LABEL column
        (1 = phishing), or an empty DataFrame if every mirror failed.
    """
    logger.info("Downloading Kaggle URL features dataset...")
    kaggle_cache = DATA_DIR / "kaggle_features.csv"
    if kaggle_cache.exists() and kaggle_cache.stat().st_size > 1000:
        logger.info("Using cached Kaggle features")
        return pd.read_csv(kaggle_cache)

    for url in (KAGGLE_PRIMARY, KAGGLE_BACKUP):
        try:
            resp = requests.get(url, headers=HEADERS, timeout=60)
            resp.raise_for_status()
            df = _standardize_kaggle_labels(pd.read_csv(io.StringIO(resp.text)))
            if df.empty:
                raise ValueError("no usable labelled rows")
        except Exception as e:  # best-effort: try the next mirror
            logger.warning(f"Kaggle mirror {url} failed: {e}")
            continue

        try:
            df.to_csv(kaggle_cache, index=False)
        except OSError as e:
            # A cache-write failure must not throw away a good download.
            logger.warning(f"Could not write Kaggle cache: {e}")
        logger.info(f"Kaggle features: {len(df)} rows, {len(df.columns)} columns")
        return df

    logger.error("All Kaggle mirrors failed")
    return pd.DataFrame()
def merge_datasets(
    phish_urls: List[str],
    legit_urls: List[str],
    test_size: float = 0.15,
    val_size: float = 0.15,
) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]], List[Tuple[str, int]]]:
    """
    Merge phishing + legitimate URLs and return a stratified train/val/test split.

    With the default sizes the split is 70/15/15. Duplicates are removed, and
    a URL present in both lists is kept only as phishing. First-occurrence
    order is preserved so that, together with the fixed ``random_state``, the
    split is identical across runs.

    Args:
        phish_urls: Phishing URLs (label 1).
        legit_urls: Legitimate URLs (label 0).
        test_size: Fraction of all samples held out for the test split.
        val_size: Fraction of all samples used for validation.

    Returns:
        (train, val, test), each a list of (url, label) pairs.
        Label: 1 = phishing, 0 = legitimate.

    Raises:
        ValueError: propagated from ``train_test_split`` when there are too
            few samples (per class) to split and stratify.
    """
    # dict.fromkeys dedupes while keeping input order. Iterating a set of str
    # depends on per-process hash randomisation, so the old set-based version
    # produced a different "seeded" split on every run.
    phish_unique = list(dict.fromkeys(phish_urls))
    phish_lookup = set(phish_unique)
    legit_unique = [url for url in dict.fromkeys(legit_urls) if url not in phish_lookup]

    urls = phish_unique + legit_unique
    labels = [1] * len(phish_unique) + [0] * len(legit_unique)

    # First split: train+val vs test
    train_val_urls, test_urls, train_val_labels, test_labels = train_test_split(
        urls, labels,
        test_size=test_size,
        stratify=labels,
        random_state=42,
    )
    # Second split: val_size is a fraction of the whole dataset, so rescale it
    # relative to what is left after removing the test set.
    relative_val = val_size / (1 - test_size)
    train_urls, val_urls, train_labels, val_labels = train_test_split(
        train_val_urls, train_val_labels,
        test_size=relative_val,
        stratify=train_val_labels,
        random_state=42,
    )

    train = list(zip(train_urls, train_labels))
    val = list(zip(val_urls, val_labels))
    test = list(zip(test_urls, test_labels))
    logger.info(f"Dataset split: train={len(train)}, val={len(val)}, test={len(test)}")
    return train, val, test
def save_url_lists(
    phish_urls: List[str],
    legit_urls: List[str],
    phish_path: Optional[Path] = None,
    legit_path: Optional[Path] = None,
) -> None:
    """
    Write the phishing and legitimate URL lists as newline-separated UTF-8 text.

    Defaults to ``DATA_DIR / "phishing_urls.txt"`` and
    ``DATA_DIR / "legitimate_urls.txt"``, which are the same files the
    downloaders read as caches. Missing parent directories of explicitly
    supplied paths are created. Existing files are overwritten.
    """
    phish_path = phish_path or DATA_DIR / "phishing_urls.txt"
    legit_path = legit_path or DATA_DIR / "legitimate_urls.txt"
    for path, urls in ((phish_path, phish_urls), (legit_path, legit_urls)):
        path.parent.mkdir(parents=True, exist_ok=True)
        # Explicit UTF-8: URLs may contain non-ASCII characters that the
        # platform default encoding cannot represent.
        path.write_text("\n".join(urls), encoding="utf-8")
    logger.info(f"Saved {len(phish_urls)} phishing URLs to {phish_path}")
    logger.info(f"Saved {len(legit_urls)} legitimate URLs to {legit_path}")
def url_hash(url: str) -> str:
    """Return the hex SHA-256 digest of ``url`` encoded as UTF-8 (for dedup and privacy)."""
    hasher = hashlib.sha256()
    hasher.update(url.encode("utf-8"))
    return hasher.hexdigest()
# ── Entry point ──────────────────────────────────────────────────────
def main() -> None:
    """
    Command-line entry point.

    Downloads all three sources and saves the phishing/legitimate URL lists.
    Builds the train/val/test split and prints a summary. The Kaggle feature
    table is only downloaded/cached and summarised here; it is not merged
    into the URL splits.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)-7s | %(message)s",
    )
    print("=" * 60)
    print("PhishGuard AI — Data Collection")
    print("=" * 60)

    # 1. PhishTank
    phish_urls = download_phishtank()
    print(f"\n✅ PhishTank: {len(phish_urls)} phishing URLs")

    # 2. TRANCO
    legit_urls = download_tranco(n=10000)
    print(f"✅ TRANCO: {len(legit_urls)} legitimate URLs")

    # 3. Kaggle features
    kaggle_df = download_kaggle_mirror()
    if kaggle_df.empty:
        print("⚠️ Kaggle: download failed (will use PhishTank + TRANCO only)")
    else:
        phish_count = int((kaggle_df["CLASS_LABEL"] == 1).sum())
        legit_count = int((kaggle_df["CLASS_LABEL"] == 0).sum())
        print(f"✅ Kaggle: {len(kaggle_df)} rows ({phish_count} phishing, {legit_count} legit)")

    # 4. Save URL lists
    save_url_lists(phish_urls, legit_urls)

    # 5. Merge and split
    train, val, test = merge_datasets(phish_urls, legit_urls)
    print("\n📊 Dataset splits:")
    for name, split in (("Train", train), ("Val", val), ("Test", test)):
        n_phish = sum(1 for _, label in split if label == 1)
        n_legit = sum(1 for _, label in split if label == 0)
        print(f"  {name + ':':<6} {len(split)} ({n_phish} phish / {n_legit} legit)")

    print(f"\n✅ All data saved to {DATA_DIR}")
    print("=" * 60)


if __name__ == "__main__":
    main()