Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, Body | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import torch | |
| import torch.nn as nn | |
| import pickle | |
| from sklearn.feature_extraction.text import HashingVectorizer | |
| import warnings | |
# Suppress every UserWarning process-wide.
# NOTE(review): presumably meant to hide noise emitted while loading the
# pickled vectorizer / torch checkpoint; this also hides unrelated
# UserWarnings — consider narrowing it with `message=` or `module=`.
warnings.simplefilter("ignore", category=UserWarning)

# The FastAPI application instance.
app = FastAPI()
# Define the RNN model class
class EmailRNN(nn.Module):
    """GRU-based binary classifier producing one sigmoid score per input row.

    Each input row is treated as a sequence of length 1: ``forward`` inserts a
    time axis with ``unsqueeze(1)``, runs the stacked GRU, applies dropout to
    the output of the (only) time step, and maps it through a single-unit
    linear layer followed by a sigmoid.

    Args:
        input_size: Number of features per input row (GRU input dimension).
        hidden_size: Width of each GRU layer.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout probability, passed to ``nn.GRU`` (between stacked
            layers) and to the ``nn.Dropout`` applied before the final layer.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.3):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Keep the attribute names ``gru`` / ``fc``: they are the prefixes of
        # the state_dict keys that a saved checkpoint is loaded into.
        self.gru = nn.GRU(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_size, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return sigmoid scores of shape ``(batch, 1)``.

        Args:
            x: 2-D tensor of shape ``(batch, input_size)``. It is presumably
                float32, to match the layer weights — confirm with callers.
        """
        batch = x.size(0)
        initial_hidden = torch.zeros(
            self.num_layers, batch, self.hidden_size, device=x.device
        )
        sequence = x.unsqueeze(1)  # (batch, 1, input_size): one time step per row
        gru_out, _ = self.gru(sequence, initial_hidden)
        last_step = gru_out[:, -1, :]
        logits = self.fc(self.dropout(last_step))
        return torch.sigmoid(logits)
# Load the trained model and vectorizer
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Architecture hyper-parameters. These must match the checkpoint:
# load_state_dict (strict by default) rejects missing or mismatched
# parameters. input_size presumably also has to equal the hashing
# vectorizer's n_features — confirm.
input_size, hidden_size, num_layers = 10000, 64, 2

model = EmailRNN(input_size, hidden_size, num_layers)
model.to(device)  # nn.Module.to moves parameters in place and returns self
# map_location remaps stored tensors onto the selected device.
model.load_state_dict(torch.load('best_gmail_scam_detection_model.pth', map_location=device))
model.eval()  # disable dropout for inference

# NOTE(review): pickle.load (and torch.load in its non-weights_only mode)
# can execute code embedded in the file, so only load these artifacts from a
# trusted source.
with open('gmail_hash_vectorizer.pkl', 'rb') as f:
    vectorizer = pickle.load(f)
# Prediction functions
def preprocess_email(subject, body):
    """Vectorise one email into a dense float32 feature tensor.

    The subject and body are embedded in a fixed prompt template before being
    passed to the module-level ``vectorizer``. The template is presumably the
    same one used at training time — confirm.

    Args:
        subject: Email subject line.
        body: Email body text.

    Returns:
        A float32 ``torch.Tensor`` with a single row (this one email).
    """
    combined = f"Subject: {subject}\n\nBody: {body}"
    prompt = f"System: This message needs to be classified as scam or ham. Message: {combined}"
    dense_features = vectorizer.transform([prompt]).toarray()
    return torch.tensor(dense_features, dtype=torch.float32)
@torch.no_grad()
def predict_email(email_tensor):
    """Score a feature tensor with the loaded model, with autograd disabled.

    Args:
        email_tensor: Features as produced by ``preprocess_email``. They are
            moved to the module-level ``device`` before the forward pass.

    Returns:
        The model's sigmoid output as a Python float. ``.item()`` requires
        the output to hold exactly one element, i.e. a single-email batch.
    """
    return model(email_tensor.to(device)).item()
# Request body schema for the prediction endpoint; both fields are required
# (no defaults) and annotated as str.
class Email(BaseModel):
    subject: str  # email subject line
    body: str  # email body text
# CORS setup
# Accept cross-origin calls from any origin with any request headers; only
# POST is listed among the allowed methods.
# NOTE(review): combining a wildcard origin with allow_credentials=True is a
# known risky configuration. Starlette reportedly reflects the caller's Origin
# in that case, which lets any site make credentialed requests (verify). The
# visible endpoint uses no cookies or auth, so allow_credentials=False is
# probably sufficient.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["POST"],
    allow_origins=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
@app.post("/predict")
async def predict(email: Email = Body(...)):
    """Classify an email as ``'Scam'`` or ``'Ham'``.

    The request body is an ``Email`` JSON object with ``subject`` and
    ``body`` fields.

    Returns:
        A dict that echoes ``subject`` and ``body`` and adds two fields:
        ``predicted_result`` is ``'Scam'`` when the model score is above 0.5
        and ``'Ham'`` otherwise. ``confidence`` is the raw model score itself.
        Because it is the scam score, a ``'Ham'`` result carries a value
        below 0.5; it is not the confidence in the returned label.

    Raises:
        HTTPException: status 500 carrying the underlying error message if
            preprocessing or inference fails.
    """
    # Local import: only this handler needs it.
    from fastapi.concurrency import run_in_threadpool

    def _score():
        # Vectorisation and the forward pass are synchronous and compute-bound.
        # Running them in a worker thread keeps the event loop free to serve
        # other requests.
        return predict_email(preprocess_email(email.subject, email.body))

    try:
        prediction = await run_in_threadpool(_score)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    result = 'Scam' if prediction > 0.5 else 'Ham'
    return {
        'subject': email.subject,
        'body': email.body,
        'predicted_result': result,
        'confidence': prediction,
    }
if __name__ == '__main__':
    # Script entry point: serve `app` with uvicorn on 0.0.0.0 (all IPv4
    # interfaces), port 8001.
    import uvicorn

    uvicorn.run(app=app, host='0.0.0.0', port=8001)