| | from fastapi import FastAPI |
| | from pydantic import BaseModel |
| | from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| | import torch |
| |
|
| | |
# ASGI application instance; the @app.post / @app.get route decorators in this
# module register their endpoints on it.
app = FastAPI()
| |
|
| | |
# Hugging Face Hub ID of the sequence-classification checkpoint (judging by its
# name, a BERT model fine-tuned for phishing detection).
MODEL_NAME = "ealvaradob/bert-finetuned-phishing"
# Loaded once at module import so every request reuses the same tokenizer and
# model instead of reloading them per call. from_pretrained resolves MODEL_NAME
# against the Hugging Face Hub (or the local cache), so first startup may
# download the weights.
# NOTE(review): predict_spam treats class index 1 as "phishing"; confirm this
# against model.config.id2label if MODEL_NAME is ever changed.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
| |
|
| | |
class TextInput(BaseModel):
    """Request body for POST /predict: the raw text to classify."""

    # Any length is accepted here; the endpoint truncates to 512 tokens.
    text: str
| |
|
@app.post("/predict")
def predict_spam(input_data: TextInput):
    """Classify the submitted text as a phishing or non-phishing email.

    The text is tokenized (truncated to at most 512 tokens) and run through
    the sequence-classification model without gradient tracking. The
    highest-scoring class index is then mapped to a human-readable label.

    Args:
        input_data: Request body carrying the ``text`` to classify.

    Returns:
        A dict echoing the input ``text`` plus a ``prediction`` label.
        The label is ``"Phishing Email"`` when the top class index is 1,
        otherwise ``"Not Phishing Email"``.
    """
    raw_text = input_data.text
    encoded = tokenizer(
        raw_text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**encoded).logits

    # NOTE(review): class index 1 is assumed to mean "phishing"; confirm
    # against model.config.id2label if the checkpoint changes.
    top_class = logits.argmax(dim=1).item()
    if top_class == 1:
        label = "Phishing Email"
    else:
        label = "Not Phishing Email"

    return {"text": raw_text, "prediction": label}
| |
|
| | |
@app.get("/")
def home():
    """Return a static welcome payload for the API root."""
    greeting = "Welcome to the Spam Classifier API!"
    return dict(message=greeting)
| |
|