Gmail / app.py
varun324242's picture
Update app.py
aa1d607 verified
import logging
import pickle
import warnings

import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException, Body
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sklearn.feature_extraction.text import HashingVectorizer
# Suppress every UserWarning for the whole process.
# NOTE(review): presumably meant to hide library warnings emitted while the
# model and vectorizer are loaded -- confirm nothing important is being hidden.
warnings.filterwarnings("ignore", category=UserWarning)
# The FastAPI application object; the CORS middleware and the /predict route
# defined later in this file are registered on it.
app = FastAPI()
# Define the RNN model class
class EmailRNN(nn.Module):
    """Stacked-GRU binary classifier that scores one feature vector per email.

    Every input row is fed to the GRU as a sequence of length one; the GRU
    output for that step goes through dropout and a single-unit linear layer,
    and a sigmoid squashes the result into the range [0, 1].

    Args:
        input_size: Number of features in each input row.
        hidden_size: Width of the GRU hidden state.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout probability, used both between GRU layers and in
            front of the linear head.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.3):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # The attribute names below (gru, fc, dropout) become the prefixes of
        # the state_dict keys, so they must not be renamed.
        self.gru = nn.GRU(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_size, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Score a batch of feature rows.

        Args:
            x: Tensor of shape (batch, input_size).

        Returns:
            Tensor of shape (batch, 1) holding sigmoid outputs.
        """
        batch_size = x.size(0)
        # Explicit all-zero initial hidden state, created on the input's device.
        initial_hidden = torch.zeros(
            self.num_layers, batch_size, self.hidden_size, device=x.device
        )
        # Insert a sequence axis of length 1: (batch, features) -> (batch, 1, features).
        as_sequence = x.unsqueeze(1)
        gru_out, _ = self.gru(as_sequence, initial_hidden)
        last_step = gru_out[:, -1, :]
        logits = self.fc(self.dropout(last_step))
        return torch.sigmoid(logits)
# Load the trained model and vectorizer
# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Architecture hyperparameters. load_state_dict() below is strict by default,
# so these have to agree with the shapes stored in the checkpoint.
# NOTE(review): input_size presumably equals the vectorizer's feature count -- confirm.
input_size, hidden_size, num_layers = 10000, 64, 2

model = EmailRNN(input_size, hidden_size, num_layers)
model = model.to(device)
model.load_state_dict(
    torch.load('best_gmail_scam_detection_model.pth', map_location=device)
)
# Inference mode: disables dropout.
model.eval()

# NOTE: pickle.load executes arbitrary code embedded in the file, so this
# .pkl must only ever come from a trusted source.
with open('gmail_hash_vectorizer.pkl', 'rb') as vectorizer_file:
    vectorizer = pickle.load(vectorizer_file)
# Prediction functions
def preprocess_email(subject, body):
    """Convert an email's subject and body into a float32 feature tensor.

    The two fields are merged into one text, wrapped in a fixed instruction
    prefix, and passed through the module-level ``vectorizer``.

    Args:
        subject: Subject line of the email.
        body: Body text of the email.

    Returns:
        A dense ``torch.float32`` tensor with a single row (one row because a
        one-element list is vectorized).
    """
    combined_text = f"Subject: {subject}\n\nBody: {body}"
    # NOTE(review): this wrapper text presumably mirrors how the training data
    # was formatted -- confirm before changing it.
    wrapped = f"System: This message needs to be classified as scam or ham. Message: {combined_text}"
    sparse_features = vectorizer.transform([wrapped])
    dense_features = sparse_features.toarray()
    return torch.tensor(dense_features, dtype=torch.float32)
@torch.no_grad()
def predict_email(email_tensor):
    """Run the module-level model on one vectorized email.

    Gradient tracking is disabled for the whole call.

    Args:
        email_tensor: Feature tensor for a single email; it is moved to the
            module-level ``device`` before the forward pass.

    Returns:
        The model output as a Python float (``.item()`` requires the output to
        contain exactly one element).
    """
    on_device = email_tensor.to(device)
    return model(on_device).item()
class Email(BaseModel):
    # Request payload for the /predict endpoint: both fields are required strings.
    # Subject line of the email to classify.
    subject: str
    # Body text of the email to classify.
    body: str
# CORS setup
app.add_middleware(
    CORSMiddleware,
    # Any origin may call the API.
    allow_origins=["*"],
    # Credentialed requests (cookies, Authorization headers) must not be
    # combined with wildcard origins/headers: the FastAPI CORS docs require
    # explicit values whenever allow_credentials is True. This API does not
    # use cookies or authentication, so credentials are switched off.
    allow_credentials=False,
    # Only POST is exposed; preflight OPTIONS requests are answered by the
    # middleware itself.
    allow_methods=["POST"],
    allow_headers=["*"],
)
@app.post('/predict')
async def predict(email: Email = Body(...)):
    """Classify one email as scam or ham.

    Args:
        email: Request body carrying the ``subject`` and ``body`` strings.

    Returns:
        A dict echoing ``subject`` and ``body`` plus:
        ``predicted_result`` -- ``'Scam'`` when the model score is above 0.5,
        otherwise ``'Ham'``;
        ``confidence`` -- the raw model score itself (so it is at most 0.5
        whenever the result is ``'Ham'``).

    Raises:
        HTTPException: Status 500 when vectorization or inference fails; the
            original exception is attached as ``__cause__``.
    """

    def _classify():
        # Vectorization and the forward pass are CPU-bound and blocking.
        return predict_email(preprocess_email(email.subject, email.body))

    try:
        # Run the blocking work in the threadpool so the event loop stays
        # free to serve other requests while the model is evaluated.
        prediction = await run_in_threadpool(_classify)
    except Exception as e:
        # Keep a full traceback on the server; the client still receives the
        # same 500 response and detail text as before.
        logging.getLogger(__name__).exception("Email classification failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    result = 'Scam' if prediction > 0.5 else 'Ham'
    return {
        'subject': email.subject,
        'body': email.body,
        'predicted_result': result,
        'confidence': prediction,
    }
if __name__ == '__main__':
    # Imported here so uvicorn is only needed when this file is run directly
    # as a script rather than imported by an external ASGI server.
    import uvicorn
    # Serve the app on all network interfaces, port 8001.
    uvicorn.run(app, host='0.0.0.0', port=8001)