Che237 committed on
Commit
df8ae40
·
verified ·
1 Parent(s): 7951cc7

Phase 3: Add transformer models (URL BERT + DGA + SecurityLLM via Inference API)

Browse files
Files changed (1) hide show
  1. app.py +227 -0
app.py CHANGED
@@ -504,6 +504,182 @@ def extract_url_features(url: str) -> Dict:
504
  }
505
 
506
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
507
  # ============================================================================
508
  # NOTEBOOK EXECUTION (existing Gradio functionality)
509
  # ============================================================================
@@ -853,6 +1029,57 @@ async def api_list_models():
853
  return {"models": result, "total": len(ml_loader.MODEL_NAMES), "loaded": len(ml_loader.models)}
854
 
855
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
856
  @api.post("/api/analysis/network")
857
  async def api_analyze_network(request: Request):
858
  """Network traffic analysis – called by backend mlService.analyzeNetworkTraffic()"""
 
504
  }
505
 
506
 
507
+ # ============================================================================
508
+ # TRANSFORMER MODEL LOADER (Phase 3 — real HF transformer models)
509
+ # ============================================================================
510
+ #
511
+ # Loads pretrained BERT-based classifiers from the Hugging Face Hub.
512
+ # Models are loaded lazily on first request to keep cold-start fast.
513
+ # A 7B Security LLM is NOT loaded locally (too big for free tier) —
514
+ # it's accessed via the HF Inference API on demand.
515
+
516
class TransformerModelLoader:
    """Loads pretrained Transformer classifiers from the HF Hub on demand.

    Classification pipelines listed in ``REGISTRY`` are created on first use
    and cached in ``self.pipelines``. The most recent load failure per model
    is kept in ``self.load_errors`` so it can be reported by ``status()``.
    The security LLM is never loaded locally; ``security_chat`` calls the
    Hugging Face Inference API over HTTP instead.
    """

    # Model registry — name → HF repo + task description
    REGISTRY = {
        "url_phishing_bert": {
            "repo": "elftsdmr/malware-url-detect",
            "task": "text-classification",
            "labels": ["benign", "malicious"],
            "desc": "BERT-based URL phishing/malware classifier",
        },
        "dga_detector": {
            "repo": "YangYang-Research/dga-detection",
            "task": "text-classification",
            "labels": ["legit", "dga"],
            "desc": "Domain Generation Algorithm detector (45-char domain input)",
        },
    }
    SECURITY_LLM_REPO = "ZySec-AI/SecurityLLM"  # Used via HF Inference API only

    # Lower-cased classifier labels that count as a "threat" verdict.
    _THREAT_LABELS = frozenset({"malicious", "phishing", "malware", "dga", "label_1", "1"})

    def __init__(self):
        """Probe for transformers/torch; no model is loaded at construction time."""
        self.pipelines = {}    # name → transformers.Pipeline (lazy-loaded)
        self.load_errors = {}  # name → last error message
        self.transformers_available = False
        try:
            import transformers  # noqa: F401
            import torch  # noqa: F401
            self.transformers_available = True
            logger.info("✅ transformers + torch available")
        except ImportError as e:
            logger.warning(f"⚠️ transformers/torch not installed: {e}")

    def _ensure(self, name: str):
        """Load a pipeline on first use. Returns the pipeline or None on failure.

        On failure the reason is stored in ``self.load_errors[name]``. A later
        successful load removes that entry so ``status()`` does not keep
        reporting an error for a model that is in fact loaded.
        """
        if name in self.pipelines:
            return self.pipelines[name]
        if not self.transformers_available:
            self.load_errors[name] = "transformers/torch not installed"
            return None
        if name not in self.REGISTRY:
            self.load_errors[name] = f"Unknown model: {name}"
            return None
        spec = self.REGISTRY[name]
        try:
            from transformers import pipeline
            logger.info(f"⏳ Loading {name} from {spec['repo']}...")
            # device=-1 selects CPU; top_k=None asks for the score of every label.
            pipe = pipeline(spec["task"], model=spec["repo"], device=-1, top_k=None)
        except Exception as e:
            err = f"{type(e).__name__}: {str(e)[:200]}"
            self.load_errors[name] = err
            logger.error(f"❌ Failed to load {name}: {err}")
            return None
        self.pipelines[name] = pipe
        self.load_errors.pop(name, None)  # clear any stale failure from an earlier attempt
        logger.info(f"✅ Loaded {name}")
        return pipe

    def predict_url_phishing(self, url: str) -> Dict:
        """Classify a URL as benign or malicious using elftsdmr/malware-url-detect.

        Only a leading ``http://`` / ``https://`` scheme is removed (matched
        case-insensitively); scheme strings embedded later in the URL, such as
        redirect targets in a query string, are kept. The text is capped at
        512 characters before it is passed to the pipeline.

        Returns the normalized classification dict, or a dict with an
        ``error`` key when the model is unavailable or inference fails.
        """
        pipe = self._ensure("url_phishing_bert")
        if pipe is None:
            return self._unavailable("url_phishing_bert")
        try:
            # Strip the protocol — model was trained on bare URLs
            text = self._strip_scheme(url)[:512]
            scores = self._single_input_scores(pipe(text))
            return self._format_classification(scores, "url_phishing_bert")
        except Exception as e:
            return self._error("url_phishing_bert", e)

    def predict_dga(self, domain: str) -> Dict:
        """Classify a domain as legitimate or DGA-generated.

        The domain is stripped, lower-cased and capped at 45 characters before
        it is passed to the pipeline.

        Returns the normalized classification dict, or a dict with an
        ``error`` key when the model is unavailable or inference fails.
        """
        pipe = self._ensure("dga_detector")
        if pipe is None:
            return self._unavailable("dga_detector")
        try:
            # Model expects bare domain, optimized for ≤45 chars
            d = domain.strip().lower()[:45]
            scores = self._single_input_scores(pipe(d))
            return self._format_classification(scores, "dga_detector")
        except Exception as e:
            return self._error("dga_detector", e)

    def security_chat(self, query: str, max_tokens: int = 512) -> Dict:
        """Cybersecurity Q&A via ZySec-AI/SecurityLLM hosted on HF Inference API.

        This method does not fall back to Gemini itself. Every failure result
        carries ``fallback_available`` (taken from ``gemini_service.ready``) so
        the caller can decide whether to fall back.

        Args:
            query: Prompt text; only the first 2000 characters are sent.
            max_tokens: Requested ``max_new_tokens``; values below 1 are raised to 1.

        Returns:
            On success a dict with ``response`` (the generated text),
            ``source`` and ``model_id``. Otherwise a dict with ``error`` and
            ``fallback_available``. A 200 response that carries no generated
            text is reported as an error rather than as an empty success.
        """
        if not HF_TOKEN:
            return {
                "model": "security-llm",
                "response": None,
                "error": "HF_TOKEN env var required to call ZySec-AI/SecurityLLM via Inference API",
                "fallback_available": gemini_service.ready,
            }
        try:
            import requests
            url = f"https://api-inference.huggingface.co/models/{self.SECURITY_LLM_REPO}"
            headers = {"Authorization": f"Bearer {HF_TOKEN}", "Content-Type": "application/json"}
            payload = {
                "inputs": query[:2000],
                "parameters": {
                    "max_new_tokens": max(1, int(max_tokens)),
                    "temperature": 0.3,
                    "return_full_text": False,
                },
                "options": {"wait_for_model": True},
            }
            r = requests.post(url, headers=headers, json=payload, timeout=45)
            if r.status_code != 200:
                return {
                    "model": "security-llm",
                    "error": f"HF Inference API HTTP {r.status_code}: {r.text[:200]}",
                    "fallback_available": gemini_service.ready,
                }
            text = self._extract_generated_text(r.json())
            if not text:
                # Without an error key the caller would treat a null answer as success.
                return {
                    "model": "security-llm",
                    "error": "HF Inference API returned no generated_text",
                    "fallback_available": gemini_service.ready,
                }
            return {
                "model": "security-llm",
                "source": "huggingface_inference_api",
                "response": text,
                "model_id": self.SECURITY_LLM_REPO,
            }
        except Exception as e:
            return {
                "model": "security-llm",
                "error": f"{type(e).__name__}: {str(e)[:200]}",
                "fallback_available": gemini_service.ready,
            }

    def status(self) -> Dict:
        """Return which models are loaded, which are available, and the last load errors."""
        return {
            "transformers_available": self.transformers_available,
            "loaded": list(self.pipelines.keys()),
            "available": list(self.REGISTRY.keys()) + ["security_llm (via HF Inference API)"],
            # Copy so callers cannot mutate the loader's internal bookkeeping.
            "load_errors": dict(self.load_errors),
        }

    @staticmethod
    def _strip_scheme(url: str) -> str:
        """Remove one leading http:// or https:// (case-insensitive) from *url*."""
        text = url.strip()
        lowered = text.lower()
        for scheme in ("https://", "http://"):
            if lowered.startswith(scheme):
                return text[len(scheme):]
        return text

    @staticmethod
    def _single_input_scores(result):
        """Unwrap pipeline output for one input to a flat list of {label, score}.

        Handles both a nested ``[[{...}, ...]]`` result and a flat
        ``[{...}, ...]`` result; anything else is returned unchanged.
        """
        if isinstance(result, list) and result and isinstance(result[0], list):
            return result[0]
        return result

    @staticmethod
    def _extract_generated_text(data):
        """Pull ``generated_text`` out of an Inference API JSON payload.

        Accepts a list of result dicts (first item is used) or a single dict.
        Returns None for an empty list, a missing key or a null payload; any
        other payload is stringified.
        """
        if isinstance(data, list):
            if not data:
                return None
            data = data[0]
        if isinstance(data, dict):
            return data.get("generated_text")
        return None if data is None else str(data)

    @staticmethod
    def _format_classification(scores, model_name) -> Dict:
        """Normalize HF text-classification output to the cyberforge schema.

        Args:
            scores: A list of ``{"label", "score"}`` dicts, or a single such dict.
            model_name: Registry name echoed back in the ``model`` field.

        Returns:
            A dict with the top ``prediction``, ``is_threat`` (whether the top
            label is a known threat label), ``confidence`` (top score as a
            percentage) and ``threat_score``. ``threat_score`` is the summed
            score of all threat labels, capped at 1.0. If no threat label is
            present at all it falls back to ``1 - top score``.
        """
        if not scores:
            return {"model": model_name, "error": "Empty scores"}
        entries = scores if isinstance(scores, list) else [scores]
        threat_labels = TransformerModelLoader._THREAT_LABELS
        # Top prediction
        top = max(entries, key=lambda s: s["score"])
        is_threat = str(top["label"]).lower() in threat_labels
        # Threat score: probability of the threat class (not just the top class).
        # Reading it from the threat entries stays correct with more than two
        # labels, where "1 - top score" would also count other benign classes.
        threat_parts = [s["score"] for s in entries if str(s["label"]).lower() in threat_labels]
        if threat_parts:
            threat_score = min(1.0, sum(threat_parts))
        else:
            threat_score = 1.0 - top["score"]
        return {
            "model": model_name,
            "prediction": top["label"],
            "is_threat": is_threat,
            "confidence": round(top["score"] * 100, 2),
            "threat_score": round(threat_score, 4),
            "all_scores": scores,
            "inference_source": "huggingface_transformer",
        }

    @staticmethod
    def _unavailable(model_name) -> Dict:
        """Build the error payload returned when a model could not be loaded."""
        return {"model": model_name, "error": "Model unavailable — see /api/v2/status"}

    @staticmethod
    def _error(model_name, e: Exception) -> Dict:
        """Build an error payload from an exception; the message is capped at 200 chars."""
        return {"model": model_name, "error": f"{type(e).__name__}: {str(e)[:200]}"}
678
+
679
+
680
+ transformer_loader = TransformerModelLoader()
681
+
682
+
683
  # ============================================================================
684
  # NOTEBOOK EXECUTION (existing Gradio functionality)
685
  # ============================================================================
 
1029
  return {"models": result, "total": len(ml_loader.MODEL_NAMES), "loaded": len(ml_loader.models)}
1030
 
1031
 
1032
+ # ============================================================================
1033
+ # PHASE-3 ENDPOINTS — Real HF Transformer models
1034
+ # ============================================================================
1035
+
1036
@api.post("/api/v2/url-classify")
async def api_v2_url_classify(request: Request):
    """URL phishing/malware classification using elftsdmr/malware-url-detect (BERT).

    Body: { "url": "https://..." }

    Returns the dict produced by ``transformer_loader.predict_url_phishing``.
    Responds with HTTP 400 when the body is not valid JSON, is not a JSON
    object, or does not contain a non-empty string ``url``.
    """
    try:
        body = await request.json()
    except ValueError:
        # Malformed JSON is a client error, not a server fault.
        return JSONResponse(status_code=400, content={"detail": "invalid JSON body"})
    # A JSON array/scalar body or a non-string "url" would otherwise raise
    # AttributeError on .get()/.strip() and surface as HTTP 500.
    url = body.get("url") if isinstance(body, dict) else None
    if not isinstance(url, str) or not url.strip():
        return JSONResponse(status_code=400, content={"detail": "url required"})
    return transformer_loader.predict_url_phishing(url.strip())
1045
+
1046
+
1047
@api.post("/api/v2/dga-detect")
async def api_v2_dga_detect(request: Request):
    """DGA-generated domain detection using YangYang-Research/dga-detection.

    Body: { "domain": "abc123xyz.com" }

    Returns the dict produced by ``transformer_loader.predict_dga``.
    Responds with HTTP 400 when the body is not valid JSON, is not a JSON
    object, or does not contain a non-empty string ``domain``.
    """
    try:
        body = await request.json()
    except ValueError:
        # Malformed JSON is a client error, not a server fault.
        return JSONResponse(status_code=400, content={"detail": "invalid JSON body"})
    # A JSON array/scalar body or a non-string "domain" would otherwise raise
    # AttributeError on .get()/.strip() and surface as HTTP 500.
    domain = body.get("domain") if isinstance(body, dict) else None
    if not isinstance(domain, str) or not domain.strip():
        return JSONResponse(status_code=400, content={"detail": "domain required"})
    return transformer_loader.predict_dga(domain.strip())
1056
+
1057
+
1058
@api.post("/api/v2/security-chat")
async def api_v2_security_chat(request: Request):
    """Cybersecurity Q&A via ZySec-AI/SecurityLLM (HF Inference API).

    Body: { "query": "...", "max_tokens": 512 }

    Returns the dict from ``transformer_loader.security_chat``. If that result
    carries an ``error`` and ``gemini_service.ready`` is truthy, the Gemini
    analysis is returned instead, tagged with ``source`` and ``llm_error``.
    Responds with HTTP 400 when the body is not valid JSON, is not a JSON
    object, lacks a non-empty string ``query``, or has a ``max_tokens`` that
    is not a positive integer.
    """
    import asyncio

    try:
        body = await request.json()
    except ValueError:
        # Malformed JSON is a client error, not a server fault.
        return JSONResponse(status_code=400, content={"detail": "invalid JSON body"})
    query = body.get("query") if isinstance(body, dict) else None
    if not isinstance(query, str) or not query.strip():
        return JSONResponse(status_code=400, content={"detail": "query required"})
    query = query.strip()
    try:
        max_tokens = int(body.get("max_tokens", 512))
    except (TypeError, ValueError):
        max_tokens = 0  # rejected by the range check below
    if max_tokens < 1:
        return JSONResponse(status_code=400, content={"detail": "max_tokens must be a positive integer"})
    # security_chat makes a blocking HTTP call (timeout up to 45 s); run it in
    # the default executor so the event loop keeps serving other requests.
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(
        None, lambda: transformer_loader.security_chat(query, max_tokens=max_tokens)
    )
    # Auto-fallback to Gemini when LLM unavailable
    if result.get("error") and gemini_service.ready:
        # NOTE(review): gemini_service.analyze is called inline as before; if it
        # blocks on network I/O it may deserve the same executor treatment —
        # confirm it is safe to call from a worker thread first.
        gemini_result = gemini_service.analyze(query)
        gemini_result["source"] = "gemini-fallback"
        gemini_result["llm_error"] = result["error"]
        return gemini_result
    return result
1075
+
1076
+
1077
@api.get("/api/v2/status")
async def api_v2_status():
    """Report the phase-3 transformer loader state: loaded models, load failures, and what is available."""
    loader_state = transformer_loader.status()
    return loader_state
1081
+
1082
+
1083
  @api.post("/api/analysis/network")
1084
  async def api_analyze_network(request: Request):
1085
  """Network traffic analysis – called by backend mlService.analyzeNetworkTraffic()"""