Spaces:
Running
Running
Phase 3: Add transformer models (URL BERT + DGA + SecurityLLM via Inference API)
Browse files
app.py
CHANGED
|
@@ -504,6 +504,182 @@ def extract_url_features(url: str) -> Dict:
|
|
| 504 |
}
|
| 505 |
|
| 506 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
# ============================================================================
|
| 508 |
# NOTEBOOK EXECUTION (existing Gradio functionality)
|
| 509 |
# ============================================================================
|
|
@@ -853,6 +1029,57 @@ async def api_list_models():
|
|
| 853 |
return {"models": result, "total": len(ml_loader.MODEL_NAMES), "loaded": len(ml_loader.models)}
|
| 854 |
|
| 855 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 856 |
@api.post("/api/analysis/network")
|
| 857 |
async def api_analyze_network(request: Request):
|
| 858 |
"""Network traffic analysis – called by backend mlService.analyzeNetworkTraffic()"""
|
|
|
|
| 504 |
}
|
| 505 |
|
| 506 |
|
| 507 |
+
# ============================================================================
|
| 508 |
+
# TRANSFORMER MODEL LOADER (Phase 3 — real HF transformer models)
|
| 509 |
+
# ============================================================================
|
| 510 |
+
#
|
| 511 |
+
# Loads pretrained BERT-based classifiers from the Hugging Face Hub.
|
| 512 |
+
# Models are loaded lazily on first request to keep cold-start fast.
|
| 513 |
+
# A 7B Security LLM is NOT loaded locally (too big for free tier) —
|
| 514 |
+
# it's accessed via the HF Inference API on demand.
|
| 515 |
+
|
| 516 |
+
class TransformerModelLoader:
|
| 517 |
+
"""Loads pretrained Transformer classifiers from the HF Hub on demand."""
|
| 518 |
+
|
| 519 |
+
# Model registry — name → HF repo + task description
|
| 520 |
+
REGISTRY = {
|
| 521 |
+
"url_phishing_bert": {
|
| 522 |
+
"repo": "elftsdmr/malware-url-detect",
|
| 523 |
+
"task": "text-classification",
|
| 524 |
+
"labels": ["benign", "malicious"],
|
| 525 |
+
"desc": "BERT-based URL phishing/malware classifier",
|
| 526 |
+
},
|
| 527 |
+
"dga_detector": {
|
| 528 |
+
"repo": "YangYang-Research/dga-detection",
|
| 529 |
+
"task": "text-classification",
|
| 530 |
+
"labels": ["legit", "dga"],
|
| 531 |
+
"desc": "Domain Generation Algorithm detector (45-char domain input)",
|
| 532 |
+
},
|
| 533 |
+
}
|
| 534 |
+
SECURITY_LLM_REPO = "ZySec-AI/SecurityLLM" # Used via HF Inference API only
|
| 535 |
+
|
| 536 |
+
def __init__(self):
|
| 537 |
+
self.pipelines = {} # name → transformers.Pipeline (lazy-loaded)
|
| 538 |
+
self.load_errors = {} # name → last error message
|
| 539 |
+
self.transformers_available = False
|
| 540 |
+
try:
|
| 541 |
+
import transformers # noqa: F401
|
| 542 |
+
import torch # noqa: F401
|
| 543 |
+
self.transformers_available = True
|
| 544 |
+
logger.info("✅ transformers + torch available")
|
| 545 |
+
except ImportError as e:
|
| 546 |
+
logger.warning(f"⚠️ transformers/torch not installed: {e}")
|
| 547 |
+
|
| 548 |
+
def _ensure(self, name: str):
|
| 549 |
+
"""Load a pipeline on first use. Returns the pipeline or None on failure."""
|
| 550 |
+
if name in self.pipelines:
|
| 551 |
+
return self.pipelines[name]
|
| 552 |
+
if not self.transformers_available:
|
| 553 |
+
self.load_errors[name] = "transformers/torch not installed"
|
| 554 |
+
return None
|
| 555 |
+
if name not in self.REGISTRY:
|
| 556 |
+
self.load_errors[name] = f"Unknown model: {name}"
|
| 557 |
+
return None
|
| 558 |
+
spec = self.REGISTRY[name]
|
| 559 |
+
try:
|
| 560 |
+
from transformers import pipeline
|
| 561 |
+
logger.info(f"⏳ Loading {name} from {spec['repo']}...")
|
| 562 |
+
pipe = pipeline(spec["task"], model=spec["repo"], device=-1, top_k=None)
|
| 563 |
+
self.pipelines[name] = pipe
|
| 564 |
+
logger.info(f"✅ Loaded {name}")
|
| 565 |
+
return pipe
|
| 566 |
+
except Exception as e:
|
| 567 |
+
err = f"{type(e).__name__}: {str(e)[:200]}"
|
| 568 |
+
self.load_errors[name] = err
|
| 569 |
+
logger.error(f"❌ Failed to load {name}: {err}")
|
| 570 |
+
return None
|
| 571 |
+
|
| 572 |
+
def predict_url_phishing(self, url: str) -> Dict:
|
| 573 |
+
"""Classify a URL as benign or malicious using elftsdmr/malware-url-detect."""
|
| 574 |
+
pipe = self._ensure("url_phishing_bert")
|
| 575 |
+
if pipe is None:
|
| 576 |
+
return self._unavailable("url_phishing_bert")
|
| 577 |
+
try:
|
| 578 |
+
# Strip the protocol — model was trained on bare URLs
|
| 579 |
+
text = url.replace("https://", "").replace("http://", "")[:512]
|
| 580 |
+
result = pipe(text)
|
| 581 |
+
scores = result[0] if isinstance(result[0], list) else result
|
| 582 |
+
return self._format_classification(scores, "url_phishing_bert")
|
| 583 |
+
except Exception as e:
|
| 584 |
+
return self._error("url_phishing_bert", e)
|
| 585 |
+
|
| 586 |
+
def predict_dga(self, domain: str) -> Dict:
|
| 587 |
+
"""Classify a domain as legitimate or DGA-generated."""
|
| 588 |
+
pipe = self._ensure("dga_detector")
|
| 589 |
+
if pipe is None:
|
| 590 |
+
return self._unavailable("dga_detector")
|
| 591 |
+
try:
|
| 592 |
+
# Model expects bare domain, optimized for ≤45 chars
|
| 593 |
+
d = domain.lower().strip()[:45]
|
| 594 |
+
result = pipe(d)
|
| 595 |
+
scores = result[0] if isinstance(result[0], list) else result
|
| 596 |
+
return self._format_classification(scores, "dga_detector")
|
| 597 |
+
except Exception as e:
|
| 598 |
+
return self._error("dga_detector", e)
|
| 599 |
+
|
| 600 |
+
def security_chat(self, query: str, max_tokens: int = 512) -> Dict:
|
| 601 |
+
"""Cybersecurity Q&A via ZySec-AI/SecurityLLM hosted on HF Inference API.
|
| 602 |
+
Falls back to Gemini when the LLM is rate-limited or HF token is missing."""
|
| 603 |
+
if not HF_TOKEN:
|
| 604 |
+
return {
|
| 605 |
+
"model": "security-llm",
|
| 606 |
+
"response": None,
|
| 607 |
+
"error": "HF_TOKEN env var required to call ZySec-AI/SecurityLLM via Inference API",
|
| 608 |
+
"fallback_available": gemini_service.ready,
|
| 609 |
+
}
|
| 610 |
+
try:
|
| 611 |
+
import requests
|
| 612 |
+
url = f"https://api-inference.huggingface.co/models/{self.SECURITY_LLM_REPO}"
|
| 613 |
+
headers = {"Authorization": f"Bearer {HF_TOKEN}", "Content-Type": "application/json"}
|
| 614 |
+
payload = {
|
| 615 |
+
"inputs": query[:2000],
|
| 616 |
+
"parameters": {"max_new_tokens": max_tokens, "temperature": 0.3, "return_full_text": False},
|
| 617 |
+
"options": {"wait_for_model": True},
|
| 618 |
+
}
|
| 619 |
+
r = requests.post(url, headers=headers, json=payload, timeout=45)
|
| 620 |
+
if r.status_code == 200:
|
| 621 |
+
data = r.json()
|
| 622 |
+
text = data[0].get("generated_text") if isinstance(data, list) and data else (data.get("generated_text") if isinstance(data, dict) else str(data))
|
| 623 |
+
return {
|
| 624 |
+
"model": "security-llm",
|
| 625 |
+
"source": "huggingface_inference_api",
|
| 626 |
+
"response": text,
|
| 627 |
+
"model_id": self.SECURITY_LLM_REPO,
|
| 628 |
+
}
|
| 629 |
+
return {
|
| 630 |
+
"model": "security-llm",
|
| 631 |
+
"error": f"HF Inference API HTTP {r.status_code}: {r.text[:200]}",
|
| 632 |
+
"fallback_available": gemini_service.ready,
|
| 633 |
+
}
|
| 634 |
+
except Exception as e:
|
| 635 |
+
return {
|
| 636 |
+
"model": "security-llm",
|
| 637 |
+
"error": f"{type(e).__name__}: {str(e)[:200]}",
|
| 638 |
+
"fallback_available": gemini_service.ready,
|
| 639 |
+
}
|
| 640 |
+
|
| 641 |
+
def status(self) -> Dict:
|
| 642 |
+
return {
|
| 643 |
+
"transformers_available": self.transformers_available,
|
| 644 |
+
"loaded": list(self.pipelines.keys()),
|
| 645 |
+
"available": list(self.REGISTRY.keys()) + ["security_llm (via HF Inference API)"],
|
| 646 |
+
"load_errors": self.load_errors,
|
| 647 |
+
}
|
| 648 |
+
|
| 649 |
+
@staticmethod
|
| 650 |
+
def _format_classification(scores, model_name) -> Dict:
|
| 651 |
+
"""Normalize HF text-classification output to the cyberforge schema."""
|
| 652 |
+
if not scores:
|
| 653 |
+
return {"model": model_name, "error": "Empty scores"}
|
| 654 |
+
# scores is a list of {label, score} dicts. Find the threat label.
|
| 655 |
+
threat_labels = {"malicious", "phishing", "malware", "dga", "label_1", "1"}
|
| 656 |
+
# Top prediction
|
| 657 |
+
top = max(scores, key=lambda s: s["score"]) if isinstance(scores, list) else scores
|
| 658 |
+
is_threat = str(top["label"]).lower() in threat_labels
|
| 659 |
+
# Threat score: probability of the threat class (not just the top class)
|
| 660 |
+
threat_score = top["score"] if is_threat else 1.0 - top["score"]
|
| 661 |
+
return {
|
| 662 |
+
"model": model_name,
|
| 663 |
+
"prediction": top["label"],
|
| 664 |
+
"is_threat": is_threat,
|
| 665 |
+
"confidence": round(top["score"] * 100, 2),
|
| 666 |
+
"threat_score": round(threat_score, 4),
|
| 667 |
+
"all_scores": scores,
|
| 668 |
+
"inference_source": "huggingface_transformer",
|
| 669 |
+
}
|
| 670 |
+
|
| 671 |
+
@staticmethod
|
| 672 |
+
def _unavailable(model_name) -> Dict:
|
| 673 |
+
return {"model": model_name, "error": "Model unavailable — see /api/v2/status"}
|
| 674 |
+
|
| 675 |
+
@staticmethod
|
| 676 |
+
def _error(model_name, e: Exception) -> Dict:
|
| 677 |
+
return {"model": model_name, "error": f"{type(e).__name__}: {str(e)[:200]}"}
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
transformer_loader = TransformerModelLoader()
|
| 681 |
+
|
| 682 |
+
|
| 683 |
# ============================================================================
|
| 684 |
# NOTEBOOK EXECUTION (existing Gradio functionality)
|
| 685 |
# ============================================================================
|
|
|
|
| 1029 |
return {"models": result, "total": len(ml_loader.MODEL_NAMES), "loaded": len(ml_loader.models)}
|
| 1030 |
|
| 1031 |
|
| 1032 |
+
# ============================================================================
|
| 1033 |
+
# PHASE-3 ENDPOINTS — Real HF Transformer models
|
| 1034 |
+
# ============================================================================
|
| 1035 |
+
|
| 1036 |
+
@api.post("/api/v2/url-classify")
|
| 1037 |
+
async def api_v2_url_classify(request: Request):
|
| 1038 |
+
"""URL phishing/malware classification using elftsdmr/malware-url-detect (BERT).
|
| 1039 |
+
Body: { "url": "https://..." }"""
|
| 1040 |
+
body = await request.json()
|
| 1041 |
+
url = body.get("url", "").strip()
|
| 1042 |
+
if not url:
|
| 1043 |
+
return JSONResponse(status_code=400, content={"detail": "url required"})
|
| 1044 |
+
return transformer_loader.predict_url_phishing(url)
|
| 1045 |
+
|
| 1046 |
+
|
| 1047 |
+
@api.post("/api/v2/dga-detect")
|
| 1048 |
+
async def api_v2_dga_detect(request: Request):
|
| 1049 |
+
"""DGA-generated domain detection using YangYang-Research/dga-detection.
|
| 1050 |
+
Body: { "domain": "abc123xyz.com" }"""
|
| 1051 |
+
body = await request.json()
|
| 1052 |
+
domain = body.get("domain", "").strip()
|
| 1053 |
+
if not domain:
|
| 1054 |
+
return JSONResponse(status_code=400, content={"detail": "domain required"})
|
| 1055 |
+
return transformer_loader.predict_dga(domain)
|
| 1056 |
+
|
| 1057 |
+
|
| 1058 |
+
@api.post("/api/v2/security-chat")
|
| 1059 |
+
async def api_v2_security_chat(request: Request):
|
| 1060 |
+
"""Cybersecurity Q&A via ZySec-AI/SecurityLLM (HF Inference API).
|
| 1061 |
+
Body: { "query": "...", "max_tokens": 512 }"""
|
| 1062 |
+
body = await request.json()
|
| 1063 |
+
query = body.get("query", "").strip()
|
| 1064 |
+
if not query:
|
| 1065 |
+
return JSONResponse(status_code=400, content={"detail": "query required"})
|
| 1066 |
+
max_tokens = int(body.get("max_tokens", 512))
|
| 1067 |
+
result = transformer_loader.security_chat(query, max_tokens=max_tokens)
|
| 1068 |
+
# Auto-fallback to Gemini when LLM unavailable
|
| 1069 |
+
if result.get("error") and gemini_service.ready:
|
| 1070 |
+
gemini_result = gemini_service.analyze(query)
|
| 1071 |
+
gemini_result["source"] = "gemini-fallback"
|
| 1072 |
+
gemini_result["llm_error"] = result["error"]
|
| 1073 |
+
return gemini_result
|
| 1074 |
+
return result
|
| 1075 |
+
|
| 1076 |
+
|
| 1077 |
+
@api.get("/api/v2/status")
|
| 1078 |
+
async def api_v2_status():
|
| 1079 |
+
"""Status of phase-3 transformer models — what's loaded, what failed, what's available."""
|
| 1080 |
+
return transformer_loader.status()
|
| 1081 |
+
|
| 1082 |
+
|
| 1083 |
@api.post("/api/analysis/network")
|
| 1084 |
async def api_analyze_network(request: Request):
|
| 1085 |
"""Network traffic analysis – called by backend mlService.analyzeNetworkTraffic()"""
|