|
|
import logging
import os
import re
import subprocess
from typing import List, Tuple, Optional

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from .ml import get_model_prediction, check_model_status
|
|
|
|
|
app = FastAPI(title="AI-Powered Phishing Email Detection System")

# Hugging Face Space that serves this API directly.
HF_SPACE_DIRECT_URL = "https://lleratodev-multinomial-nb-phishing-email-detection-api.hf.space"

# Hugging Face hub page for the same Space.
# NOTE(review): an Origin header is scheme://host[:port] and never carries a
# path, so this value can never match as an allowed origin. It is kept in
# `origins` for backward compatibility; confirm whether "https://huggingface.co"
# was what was intended.
HF_SPACE_OLD_URL = "https://huggingface.co/spaces/lleratodev/multinomial-nb-phishing-email-detection-api"

# Any single-label *.hf.space subdomain. This must be passed to CORSMiddleware
# as `allow_origin_regex`: `allow_origins` is checked by exact string
# membership, so a compiled pattern placed in that list never matches anything.
# Restricted to one label of [A-Za-z0-9-] because credentials are allowed.
HF_SPACE_ORIGIN_REGEX = r"https://[A-Za-z0-9-]+\.hf\.space"

# Exact origins (scheme://host[:port]) allowed to make cross-origin requests.
origins = [
    "https://ai-powered-phishing-email-detection-system.vercel.app",
    "https://ai-powered-phishing-email-det-git-1342ca-lerato1ofones-projects.vercel.app",
    "https://ai-powered-phishing-email-detection-syst-lerato1ofones-projects.vercel.app",
    "http://localhost:3000",
    HF_SPACE_DIRECT_URL,
    HF_SPACE_OLD_URL,
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    # Pattern-based origins go here, not in `allow_origins`.
    allow_origin_regex=HF_SPACE_ORIGIN_REGEX,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
class EmailInput(BaseModel):
    """Request body accepted by ``POST /predict``.

    Only ``body`` is required. The endpoint forwards a missing or null
    ``subject``/``sender`` to the model as an empty string, and accepts only
    ``"nb"`` or ``"bert-mini"`` for ``model_choice``.
    """

    # Email subject line; may be omitted or null.
    subject: Optional[str] = ""
    # Sender address/display text; may be omitted or null.
    sender: Optional[str] = ""
    # Email body text (required).
    body: str
    # Which model to run; defaults to "nb". Validated by the /predict handler,
    # not by this schema.
    model_choice: Optional[str] = "nb"
|
|
|
|
|
|
|
|
class PredictionResponse(BaseModel):
    """Response body returned by ``POST /predict``.

    Errors are reported in-band. On failure the handler sets
    ``prediction="Error"``, ``label=-1``, ``confidence=0.0``, an empty
    ``explanation`` and a message in ``error``.
    """

    # Human-readable verdict; "Error" when the request failed.
    prediction: str
    # Numeric class label; -1 when the request failed.
    label: int
    # Model confidence for the prediction; 0.0 when the request failed.
    confidence: float
    # presumably (feature/token, weight) pairs explaining the verdict —
    # TODO confirm against get_model_prediction in .ml
    explanation: List[Tuple[str, float]]
    # Error message, or None on success.
    error: Optional[str] = None
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Landing endpoint: tell callers where and how to submit an email."""
    usage = "AI-Powered Phishing Email Detection API. POST to /predict with 'subject', 'sender', 'body'."
    return dict(message=usage)
|
|
|
|
|
# Environment-variable names whose values must not be echoed back by
# /debug-info. This is a case-insensitive name heuristic that deliberately errs
# toward redacting; values of other variables are still returned.
_SENSITIVE_ENV_NAME = re.compile(
    r"TOKEN|SECRET|KEY|PASS|CREDENTIAL|AUTH|PRIVATE|COOKIE|SESSION|DSN|_URL$|_URI$",
    re.IGNORECASE,
)
_REDACTED = "<redacted>"


def _redacted_environment() -> dict:
    """Return a copy of ``os.environ`` with likely-secret values replaced."""
    return {
        name: _REDACTED if _SENSITIVE_ENV_NAME.search(name) else value
        for name, value in os.environ.items()
    }


@app.get("/debug-info")
async def get_debug_info():
    """Report the working directory, its listing and the environment.

    The endpoint is unauthenticated, so values of environment variables whose
    names look secret (tokens, keys, passwords, connection URLs, ...) are
    replaced with ``"<redacted>"`` instead of being returned verbatim.

    Returns:
        ``{"cwd", "ls_output", "environment_variables"}`` on success, or
        ``{"error": <message>}`` if any step fails (e.g. ``ls`` missing or
        timing out).
    """
    try:
        cwd = os.getcwd()
        # Fixed argv, no shell. The timeout keeps a hung `ls` from stalling
        # the event loop indefinitely.
        ls_output = subprocess.check_output(["ls", "-la", cwd], text=True, timeout=10)
        return {
            "cwd": cwd,
            "ls_output": ls_output,
            "environment_variables": _redacted_environment(),
        }
    except Exception as e:
        return {"error": str(e)}
|
|
|
|
|
@app.get("/status")
async def model_status():
    """Expose the model backend's status report unchanged.

    Whatever ``check_model_status()`` returns becomes the response body.
    """
    report = check_model_status()
    return report
|
|
|
|
|
|
|
|
@app.post("/predict", response_model=PredictionResponse)
async def predict_email(email_input: EmailInput):
    """Classify an email with the requested model.

    Args:
        email_input: Request body. ``subject`` and ``sender`` may be omitted or
            null and are forwarded as empty strings. ``model_choice`` must be
            ``"nb"`` or ``"bert-mini"``.

    Returns:
        A ``PredictionResponse`` built from the dict that
        ``get_model_prediction`` returns. Failures are reported in-band rather
        than via HTTP status: ``prediction="Error"``, ``label=-1``,
        ``confidence=0.0``, empty ``explanation`` and a message in ``error``.
    """
    if email_input.model_choice not in ("nb", "bert-mini"):
        return PredictionResponse(prediction="Error", label=-1, confidence=0.0, explanation=[],
                                  error="Invalid model_choice. Please use 'nb' or 'bert-mini'.")

    # NOTE(review): get_model_prediction is called synchronously inside an
    # async endpoint, so a slow inference blocks the event loop. Consider
    # offloading it to a worker thread once the model code is confirmed
    # thread-safe.
    try:
        result = get_model_prediction(
            subject=email_input.subject or "",
            sender=email_input.sender or "",
            body=email_input.body,
            model_choice=email_input.model_choice,
        )
        # If `result` is not a mapping with the fields PredictionResponse
        # requires, this raises and is reported through the error path below.
        return PredictionResponse(**result)
    except Exception as e:
        # Keep the full traceback server-side; the client only sees the message.
        logging.getLogger(__name__).exception(
            "Prediction failed for model_choice=%r", email_input.model_choice
        )
        return PredictionResponse(prediction="Error", label=-1, confidence=0.0, explanation=[],
                                  error=f"Critical API endpoint error: {e}")