| """ |
| XAI API Routes — Explainability, User Profiling, and Feedback |
| |
| Endpoints: |
| POST /api/xai/explain — SHAP token attribution + fused risk score |
| POST /api/xai/explain-deepfake — Deepfake detection explainability |
| GET /api/xai/user-profile/{user_id} — User vulnerability score & history |
| POST /api/xai/feedback — Submit false-positive / false-negative feedback |
| POST /api/xai/deepfake-feedback — Submit deepfake detection feedback |
| GET /api/xai/feedback/stats — Global feedback statistics |
| """ |
|
|
| import logging |
| from fastapi import APIRouter, HTTPException |
| from pydantic import BaseModel |
| from typing import Optional, List |
|
|
| from ..services.xai_engine import xai_engine |
| from ..services.fusion_engine import fuse_scores |
| from ..services.user_profile import user_profile_manager |
| from ..services.feedback_store import feedback_store |
| from ..services.email_analyzer import get_email_analyzer |
| from ..services.phishing_analyzer import analyze_url as rule_based_analyze |
|
|
# Module-level logger; inherits the application-wide logging configuration.
logger = logging.getLogger(__name__)
# Router for all XAI endpoints; per the module docstring the app appears to
# mount it under /api, so paths here resolve to /api/xai/* — TODO confirm.
router = APIRouter()
|
|
|
|
| |
|
|
class ExplainRequest(BaseModel):
    """Request body for POST /xai/explain."""

    text: str  # text to analyze; handler rejects fewer than 3 stripped chars
    url: Optional[str] = None  # optional URL for rule-based analysis
    user_id: Optional[str] = None  # enables history boost + scan recording
|
|
|
|
class ExplainResponse(BaseModel):
    """Response body for POST /xai/explain."""

    tokens: List[str]  # tokenized input, aligned with shap_values
    shap_values: List[float]  # per-token SHAP attribution values
    base_value: float  # SHAP expected/base value
    model_score: float  # normalized model score on a 0-100 scale
    model_label: str  # raw classifier label (e.g. LABEL_0 / LABEL_1)
    fused_score: int  # final fused risk score, 0-100
    severity: str  # LOW / MODERATE / HIGH / CRITICAL
    cta: str  # user-facing call-to-action text
    component_scores: dict  # per-signal breakdown from the fusion engine
    fallback: bool  # True when the XAI engine used its fallback path
    timestamp: str  # timestamp reported by the XAI engine
|
|
|
|
class FeedbackRequest(BaseModel):
    """Request body for POST /xai/feedback."""

    url: str  # URL the original verdict was issued for
    original_verdict: str  # verdict shown to the user
    original_score: int  # risk score shown to the user
    user_label: str  # must be 'safe' or 'phishing' (validated in handler)
    user_id: Optional[str] = None  # optional submitting user
    raw_text: Optional[str] = None  # optional original analyzed text
|
|
|
|
class DeepfakeExplainRequest(BaseModel):
    """Request body for POST /xai/explain-deepfake."""

    verdict: str  # handler branches on "DEEPFAKE" / "AUTHENTIC"; else neutral
    confidence: float  # detector confidence; presumably 0.0-1.0 — TODO confirm
    media_url: Optional[str] = None  # source of the analyzed media
    details: Optional[dict] = None  # raw detection metadata (scores, frames, ...)
    user_id: Optional[str] = None  # optional requesting user
|
|
|
|
class DeepfakeFeature(BaseModel):
    """One feature-level explanation row in a deepfake report."""

    name: str  # display name of the feature
    value: str  # formatted value (e.g. "87%")
    impact: str  # "high" / "medium" / "low"
|
|
|
|
class DeepfakeExplainResponse(BaseModel):
    """Response body for POST /xai/explain-deepfake."""

    risk_score: int  # manipulation risk, 0-100
    severity: str  # LOW / MODERATE / HIGH / CRITICAL
    cta: str  # user-facing call-to-action text
    features: List[DeepfakeFeature]  # feature-level explanations
|
|
|
|
class DeepfakeFeedbackRequest(BaseModel):
    """Request body for POST /xai/deepfake-feedback."""

    media_url: str  # media the original verdict was issued for
    original_verdict: str  # verdict shown to the user
    original_confidence: float  # detector confidence shown to the user
    user_label: str  # must be 'authentic' or 'deepfake' (validated in handler)
    user_id: Optional[str] = None  # optional submitting user
|
|
|
|
| |
|
|
def _extract_domain(url: Optional[str]) -> str:
    """Return the hostname of *url*, or "" when the URL is absent or unparseable."""
    if not url:
        return ""
    try:
        from urllib.parse import urlparse
        return urlparse(url).hostname or ""
    except Exception:
        # Best-effort: a malformed URL simply yields no domain signal.
        return ""


def _severity_and_cta(score: int):
    """Map a fused 0-100 risk score to a (severity, call-to-action) pair."""
    if score >= 80:
        return "CRITICAL", "This URL is very likely malicious. Do not visit or share it."
    if score >= 60:
        return "HIGH", "This URL shows strong signs of phishing. Verify before proceeding."
    if score >= 40:
        return "MODERATE", "This URL has some suspicious characteristics. Use caution."
    return "LOW", "This URL appears safe. Standard caution is still recommended."


@router.post("/xai/explain", response_model=ExplainResponse, tags=["XAI"])
async def explain_text(request: ExplainRequest):
    """
    Analyze text with SHAP token attribution and produce a fused risk score.

    - Runs DistilBERT + SHAP for token-level explainability
    - Runs rule-based URL analysis if a URL is provided
    - Fuses all signals via the fusion engine
    - Incorporates user history boost if user_id is provided

    Raises:
        HTTPException 400: text missing or shorter than 3 characters.
        HTTPException 500: any failure inside the analysis pipeline.
    """
    if not request.text or len(request.text.strip()) < 3:
        raise HTTPException(status_code=400, detail="Text must be at least 3 characters")

    try:
        # Lazily initialize the XAI engine from the shared email-analyzer model.
        if not xai_engine._initialized:
            analyzer = get_email_analyzer()
            if analyzer.phishing_model:
                xai_engine.initialize(analyzer.phishing_model)

        xai_result = await xai_engine.explain(request.text)

        # Rule-based URL score; stays 0 when no URL was supplied.
        rule_score = 0.0
        if request.url:
            rule_result = rule_based_analyze(request.url)
            rule_score = rule_result.get("riskScore", 0)

        model_score_raw = xai_result.get("model_score", 0.0)
        model_label = xai_result.get("model_label", "UNKNOWN")

        # Normalize the classifier output to a 0-100 phishing-likelihood scale.
        # LABEL_1 is taken as the phishing class, LABEL_0 as benign (inverted).
        if model_label == "LABEL_1":
            model_score = model_score_raw * 100
        elif model_label == "LABEL_0":
            model_score = (1 - model_score_raw) * 100
        else:
            model_score = 0.0

        # Resolve the domain once; used for both the history boost and the
        # scan record (the original computed it twice with duplicated code).
        domain = _extract_domain(request.url)

        user_boost = 0.0
        if request.user_id:
            user_boost = user_profile_manager.get_history_anomaly_boost(
                request.user_id, domain
            )

        fused = fuse_scores(
            rule_based_score=rule_score,
            model_score=model_score,
            shap_values=xai_result.get("shap_values", []),
            user_history_boost=user_boost,
        )
        fused_score_val = int(round(fused["fused_score"]))

        severity, cta = _severity_and_cta(fused_score_val)

        # Persist the scan so future requests can leverage the user's history.
        if request.user_id:
            user_profile_manager.record_scan(
                user_id=request.user_id,
                domain=domain,
                risk_score=fused_score_val,
                category=_score_to_category(fused_score_val),
            )

        return ExplainResponse(
            tokens=xai_result.get("tokens", []),
            shap_values=xai_result.get("shap_values", []),
            base_value=xai_result.get("base_value", 0.0),
            model_score=round(model_score, 2),
            model_label=model_label,
            fused_score=fused_score_val,
            severity=severity,
            cta=cta,
            component_scores=fused["component_scores"],
            fallback=xai_result.get("fallback", True),
            timestamp=xai_result.get("timestamp", ""),
        )

    except Exception as e:
        logger.error(f"[XAI] Explain error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"XAI analysis failed: {str(e)}")
|
|
|
|
@router.get("/xai/user-profile/{user_id}", tags=["XAI"])
async def get_user_profile(user_id: str):
    """
    Get user vulnerability score, scan history summary, and prediction trend.

    Raises:
        HTTPException 500: if the profile summary cannot be loaded.
    """
    try:
        return user_profile_manager.get_profile_summary(user_id)
    except Exception as e:
        # logger.exception preserves the traceback (the original logger.error
        # dropped it), matching the exc_info style used by explain_text.
        logger.exception("[XAI] Profile error: %s", e)
        raise HTTPException(status_code=500, detail="Failed to load user profile")
|
|
|
|
@router.post("/xai/feedback", tags=["XAI"])
async def submit_feedback(request: FeedbackRequest):
    """
    Submit user feedback on a detection result (false positive / false negative).

    Returns:
        {"status": "ok", "feedback_id": <id>} on success.

    Raises:
        HTTPException 400: user_label is not 'safe' or 'phishing'.
        HTTPException 500: the feedback store rejects the entry.
    """
    if request.user_label not in ("safe", "phishing"):
        raise HTTPException(
            status_code=400,
            detail="user_label must be 'safe' or 'phishing'"
        )

    try:
        entry = feedback_store.add_feedback(
            url=request.url,
            original_verdict=request.original_verdict,
            original_score=request.original_score,
            user_label=request.user_label,
            user_id=request.user_id,
            raw_text=request.raw_text,
        )
        return {"status": "ok", "feedback_id": entry["id"]}
    except Exception as e:
        # Keep the traceback in the log (original logger.error discarded it).
        logger.exception("[XAI] Feedback error: %s", e)
        raise HTTPException(status_code=500, detail="Failed to store feedback")
|
|
|
|
@router.get("/xai/feedback/stats", tags=["XAI"])
async def get_feedback_stats():
    """
    Get global feedback statistics: FP/FN rates, total feedback count, recent entries.

    Raises:
        HTTPException 500: if statistics cannot be computed.
    """
    try:
        return feedback_store.get_stats()
    except Exception as e:
        # Keep the traceback in the log (original logger.error discarded it).
        logger.exception("[XAI] Feedback stats error: %s", e)
        raise HTTPException(status_code=500, detail="Failed to compute feedback stats")
|
|
|
|
def _deepfake_risk_score(verdict: str, confidence: float) -> int:
    """Convert a verdict/confidence pair into a 0-100 manipulation risk score."""
    if verdict == "DEEPFAKE":
        # More confident DEEPFAKE verdict -> higher risk, capped at 100.
        return int(min(confidence * 100, 100))
    if verdict == "AUTHENTIC":
        # More confident AUTHENTIC verdict -> lower risk, floored at 0.
        return int(max((1 - confidence) * 100, 0))
    # Unknown verdict -> neutral midpoint.
    return 50


def _deepfake_severity(risk_score: int) -> str:
    """Bucket a 0-100 risk score into a severity label."""
    if risk_score >= 70:
        return "CRITICAL"
    if risk_score >= 50:
        return "HIGH"
    if risk_score >= 30:
        return "MODERATE"
    return "LOW"


# User-facing call-to-action text per severity bucket.
_DEEPFAKE_CTA = {
    "CRITICAL": "This media is very likely manipulated. Do not share or trust this content without independent verification.",
    "HIGH": "This media shows strong signs of manipulation. Verify the source before sharing or acting on it.",
    "MODERATE": "This media has some suspicious characteristics. Consider verifying with the original source.",
    "LOW": "This media appears genuine. Standard caution is still recommended for any online content.",
}


def _deepfake_features(details: dict, confidence: float) -> List[DeepfakeFeature]:
    """Translate raw detection metadata into feature-level explanation rows."""
    features: List[DeepfakeFeature] = []

    # Detector identity: prefer an explicit model name, fall back to method.
    model = details.get("model", details.get("method", ""))
    if model:
        features.append(DeepfakeFeature(
            name="Detection Model",
            value=model.split("/")[-1] if "/" in model else model,
            impact="high" if "Deep-Fake" in model or "ViT" in model else "medium",
        ))

    if "average_score" in details:
        avg = details["average_score"]
        features.append(DeepfakeFeature(
            name="Average Deepfake Score",
            value=f"{round(avg * 100)}%",
            impact="high" if avg > 0.6 else "medium" if avg > 0.3 else "low",
        ))

    if "score" in details:
        score = details["score"]
        features.append(DeepfakeFeature(
            name="Model Detection Score",
            value=f"{round(score * 100)}%",
            impact="high" if score > 0.6 else "medium" if score > 0.3 else "low",
        ))

    if "faces_detected" in details:
        features.append(DeepfakeFeature(
            name="Faces Detected",
            value=str(details["faces_detected"]),
            impact="medium",
        ))

    if "max_score" in details:
        mx = details["max_score"]
        features.append(DeepfakeFeature(
            name="Peak Frame Score",
            value=f"{round(mx * 100)}%",
            impact="high" if mx > 0.7 else "medium",
        ))

    if "frames_analyzed" in details:
        features.append(DeepfakeFeature(
            name="Frames Analyzed",
            value=f"{details['frames_analyzed']}/{details.get('total_frames', '?')}",
            impact="low",
        ))

    if details.get("method") == "face_analysis":
        features.append(DeepfakeFeature(
            name="Analysis Method",
            value="Heuristic face analysis (blur + frequency)",
            impact="medium",
        ))

    # Always surface the raw detector confidence as the final row.
    features.append(DeepfakeFeature(
        name="Confidence Level",
        value=f"{round(confidence * 100)}%",
        impact="high" if confidence > 0.8 else "medium" if confidence > 0.5 else "low",
    ))

    return features


@router.post("/xai/explain-deepfake", response_model=DeepfakeExplainResponse, tags=["XAI"])
async def explain_deepfake(request: DeepfakeExplainRequest):
    """
    Generate an explainability report for a deepfake detection result.

    Maps detection metadata into severity, CTA, and feature-level explanations.

    Raises:
        HTTPException 500: if report generation fails.
    """
    try:
        risk_score = _deepfake_risk_score(request.verdict, request.confidence)
        severity = _deepfake_severity(risk_score)

        return DeepfakeExplainResponse(
            risk_score=risk_score,
            severity=severity,
            cta=_DEEPFAKE_CTA[severity],
            features=_deepfake_features(request.details or {}, request.confidence),
        )

    except Exception as e:
        logger.error(f"[XAI] Deepfake explain error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Deepfake XAI analysis failed: {str(e)}")
|
|
|
|
@router.post("/xai/deepfake-feedback", tags=["XAI"])
async def submit_deepfake_feedback(request: DeepfakeFeedbackRequest):
    """
    Submit user feedback on a deepfake detection result.

    Returns:
        {"status": "ok", "feedback_id": <id>} on success.

    Raises:
        HTTPException 400: user_label is not 'authentic' or 'deepfake'.
        HTTPException 500: the feedback store rejects the entry.
    """
    if request.user_label not in ("authentic", "deepfake"):
        raise HTTPException(
            status_code=400,
            detail="user_label must be 'authentic' or 'deepfake'"
        )

    try:
        entry = feedback_store.add_feedback(
            url=request.media_url,
            # Confidence (0-1) is stored on the 0-100 score scale used by the store.
            original_verdict=request.original_verdict,
            original_score=int(request.original_confidence * 100),
            user_label=request.user_label,
            user_id=request.user_id,
        )
        return {"status": "ok", "feedback_id": entry["id"]}
    except Exception as e:
        # Keep the traceback in the log (original logger.error discarded it).
        logger.exception("[XAI] Deepfake feedback error: %s", e)
        raise HTTPException(status_code=500, detail="Failed to store deepfake feedback")
|
|
|
|
| def _score_to_category(score: int) -> str: |
| """Map risk score to category string.""" |
| if score >= 70: |
| return "high_risk" |
| elif score >= 40: |
| return "medium_risk" |
| elif score >= 20: |
| return "low_risk" |
| else: |
| return "safe" |
|
|