| from fastapi import FastAPI |
| from pydantic import BaseModel |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| import torch |
|
|
| |
# ASGI application instance; HTTP routes are registered on it via decorators.
app = FastAPI()

# Hugging Face Hub identifier of the fine-tuned BERT phishing classifier.
MODEL_NAME = "ealvaradob/bert-finetuned-phishing"

# Load the tokenizer and the sequence-classification model once, at import
# time, so every request reuses the same in-memory objects instead of
# reloading them per call.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
|
|
| |
class TextInput(BaseModel):
    """JSON request body accepted by the prediction endpoint.

    Attributes:
        text: The raw message to be classified.
    """

    # Required string field; validated by pydantic before the handler runs.
    text: str
|
|
@app.post("/predict")
def predict_spam(input_data: TextInput):
    """Classify ``input_data.text`` as phishing or not phishing.

    The text is tokenized into PyTorch tensors (truncated to at most 512
    tokens) and passed through the module-level classifier with gradient
    tracking disabled. The index of the highest logit is then mapped to a
    label: index 1 is reported as phishing, any other index as not phishing.

    Returns:
        A dict that echoes the submitted text under ``"text"`` and carries
        the human-readable label under ``"prediction"``.
    """
    message = input_data.text
    encoded = tokenizer(
        message,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    # Pure inference: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(**encoded).logits

    # Pick the highest-scoring class for the (single) input sequence.
    best_class = logits.argmax(dim=1).item()
    if best_class == 1:
        label = "Phishing Email"
    else:
        label = "Not Phishing Email"

    return {"text": message, "prediction": label}
|
|
| |
@app.get("/")
def home():
    """Return a static greeting payload for the API root."""
    greeting = "Welcome to the Spam Classifier API!"
    return {"message": greeting}
|
|