File size: 3,575 Bytes
b5c79a0
 
3ac594e
 
eccd941
47e416f
b5c79a0
c551752
b5c79a0
 
47e416f
30cda43
eccd941
 
 
47e416f
eccd941
 
1b555ea
 
eccd941
 
 
 
 
 
beecd3e
47e416f
 
 
1b555ea
47e416f
 
 
 
b5c79a0
 
 
 
 
 
c551752
b5c79a0
 
 
 
 
 
 
 
 
 
 
 
 
 
3ac594e
 
 
 
 
 
 
 
 
 
 
 
 
 
c551752
 
 
 
 
b5c79a0
 
c551752
 
 
 
b5c79a0
c551752
 
 
 
 
b5c79a0
 
c551752
b5c79a0
c551752
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import logging
import os
import re # Import the 're' module for regular expressions
import subprocess
from typing import List, Tuple, Optional

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from .ml import get_model_prediction, check_model_status

app = FastAPI(title="AI-Powered Phishing Email Detection System")

# Hugging Face URLs
HF_SPACE_DIRECT_URL = "https://lleratodev-multinomial-nb-phishing-email-detection-api.hf.space"
HF_SPACE_OLD_URL = "https://huggingface.co/spaces/lleratodev/multinomial-nb-phishing-email-detection-api"

# Pattern for any https://<subdomain(s)>.hf.space origin.
# Starlette's CORSMiddleware only evaluates patterns passed via
# `allow_origin_regex` (it compiles the string and uses fullmatch). A compiled
# regex placed inside `allow_origins` is compared with `in` against the Origin
# string and therefore never matches. Labels are restricted to [a-z0-9-]
# (browsers send lowercase hostnames) so the pattern cannot be satisfied by
# unexpected characters.
# NOTE(review): combined with allow_credentials=True this lets ANY Hugging Face
# Space, including ones owned by other users, make credentialed requests.
# Confirm that is intended.
HF_SPACE_ORIGIN_REGEX = r"https://([a-z0-9-]+\.)+hf\.space"

# Exact origins allowed for CORS (matched by string equality).
origins = [
    "https://ai-powered-phishing-email-detection-system.vercel.app",  # Our Next.js frontend hosted on Vercel
    "https://ai-powered-phishing-email-det-git-1342ca-lerato1ofones-projects.vercel.app",  # Our Next.js frontend hosted on Vercel
    "https://ai-powered-phishing-email-detection-syst-lerato1ofones-projects.vercel.app",  # Our Next.js frontend hosted on Vercel
    "http://localhost:3000",  # Local development
    HF_SPACE_DIRECT_URL,      # Specific direct URL for Hugging Face Space
    # NOTE(review): a browser Origin header is scheme://host[:port] with no
    # path, so this entry can never match. A page on huggingface.co sends
    # "https://huggingface.co". Kept as-is pending a decision.
    HF_SPACE_OLD_URL,         # Older Hugging Face Space URL
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,                     # Exact-match origins
    allow_origin_regex=HF_SPACE_ORIGIN_REGEX,  # Pattern-matched *.hf.space origins
    allow_credentials=True,  # Allows cookies to be included in requests
    allow_methods=["*"],     # Allows all methods (GET, POST, etc.)
    allow_headers=["*"],     # Allows all headers
)

# Input data model
class EmailInput(BaseModel):
    """Request body accepted by the POST /predict endpoint.

    Only ``body`` is required. ``subject`` and ``sender`` default to empty
    strings, and ``model_choice`` defaults to ``"nb"``.
    """
    subject: Optional[str] = ""   # Optional; may be omitted or sent as null
    sender: Optional[str] = ""    # Optional; may be omitted or sent as null
    body: str                     # Email body text; the only required field
    model_choice: Optional[str] = "nb" # Default to Naive Bayes; this model does not restrict the value

# Define output data model
class PredictionResponse(BaseModel):
    """Response body returned by the POST /predict endpoint.

    Errors are reported in-band. The /predict handler in this module builds
    failures with ``prediction="Error"``, ``label=-1``, ``confidence=0.0``,
    an empty ``explanation`` and a message in ``error``.
    """
    prediction: str  # Predicted class name; "Error" on failure
    label: int  # Numeric class label; -1 on failure
    confidence: float  # Presumably the model's score for the prediction (TODO confirm range); 0.0 on failure
    explanation: List[Tuple[str, float]]  # Presumably (feature/token, weight) pairs from the model — verify against .ml
    error: Optional[str] = None  # None on success; human-readable message on failure


@app.get("/")
async def root():
    """Describe the API and tell clients how to call the prediction endpoint."""
    usage = "AI-Powered Phishing Email Detection API. POST to /predict with 'subject', 'sender', 'body'."
    return {"message": usage}

@app.get("/debug-info")
async def get_debug_info():
    """Return basic runtime diagnostics for troubleshooting a deployment.

    Returns a dict with:
      - ``cwd``: the process working directory,
      - ``ls_output``: the output of ``ls -la`` on that directory,
      - ``environment_variables``: the *names* of all environment variables,
        each mapped to a redacted placeholder.

    Environment variable values are never returned. This route has no
    authentication in this module, and values commonly include credentials
    such as API tokens. On any failure (including a missing ``ls`` binary or
    a timeout), returns ``{"error": <message>}``.
    """
    try:
        cwd = os.getcwd()
        # Bounded: this runs synchronously inside an async handler, so a hung
        # `ls` (e.g. on a stalled mount) would otherwise stall the event loop.
        ls_output = subprocess.check_output(["ls", "-la", cwd], text=True, timeout=10)
        # Keep the names so operators can confirm a variable is set, but never
        # echo its value to an unauthenticated HTTP client.
        env_vars = {name: "<redacted>" for name in sorted(os.environ)}
        return {
            "cwd": cwd,
            "ls_output": ls_output,
            "environment_variables": env_vars,
        }
    except Exception as e:
        return {"error": str(e)}

@app.get("/status")
async def model_status():
    """Expose whatever ``check_model_status`` reports about the loaded models."""
    status = check_model_status()
    return status


@app.post("/predict", response_model=PredictionResponse)
async def predict_email(email_input: EmailInput):
    """Classify an email with the requested model.

    Errors are reported in-band rather than raised. For an invalid
    ``model_choice``, or for any exception during prediction, the handler
    returns a ``PredictionResponse`` with ``prediction="Error"``, ``label=-1``,
    ``confidence=0.0``, an empty ``explanation`` and a populated ``error``.
    Unexpected exceptions are also logged with their traceback.
    """
    if email_input.model_choice not in ("nb", "bert-mini"):
        return PredictionResponse(prediction="Error", label=-1, confidence=0.0, explanation=[],
                                  error="Invalid model_choice. Please use 'nb' or 'bert-mini'.")
    # NOTE(review): get_model_prediction is called synchronously inside an
    # async handler, so inference blocks the event loop while it runs.
    # Consider fastapi.concurrency.run_in_threadpool once .ml is confirmed to
    # be thread-safe.
    try:
        result = get_model_prediction(
            # subject/sender may be None (Optional fields); normalize to "".
            subject=email_input.subject or "",
            sender=email_input.sender or "",
            body=email_input.body,
            model_choice=email_input.model_choice,
        )
        return PredictionResponse(**result)

    except Exception as e:
        # Fallback for truly unexpected errors in the endpoint itself. Record
        # the full traceback server-side so the failure is diagnosable; the
        # client only receives the short message below.
        logging.getLogger(__name__).exception(
            "Prediction failed for model_choice=%r", email_input.model_choice
        )
        return PredictionResponse(prediction="Error", label=-1, confidence=0.0, explanation=[],
                                  error=f"Critical API endpoint error: {str(e)}")