Spaces:
Sleeping
Sleeping
import logging
import os

import gdown
import gradio as gr
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModel, BertTokenizerFast

from model import BERT_Arch
from preprocess_data import remove_html, remove_links
class TextRequest(BaseModel):
    """JSON request body accepted by the prediction endpoint."""

    # Raw email or website text to be classified.
    text: str
# Fetch the fine-tuned classifier weights from Google Drive.
# Skip the download when the file is already on disk (e.g. after a restart
# that kept the working directory), so startup does not re-fetch it each time.
# NOTE(review): a locally cached file is reused even if the Drive copy changes.
model_url = "https://drive.google.com/uc?id=16ZWVa0d2V0T3s11Oq86rLOTA6bOR0DnR"
model_path = "model.pth"
if not os.path.exists(model_path):
    downloaded = gdown.download(model_url, model_path, quiet=False)
    # gdown.download can return None instead of raising when Drive refuses
    # the request; fail here rather than later with an obscure torch.load error.
    if downloaded is None:
        raise RuntimeError(f"Failed to download model weights from {model_url}")
# Load the pre-trained BERT encoder and freeze all of its parameters.
# Presumably only the layers BERT_Arch adds on top were trained (TODO confirm
# against model.py).
bert = AutoModel.from_pretrained("bert-base-uncased")
for param in bert.parameters():
    param.requires_grad = False

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Wrap the frozen encoder in the custom classifier and load the fine-tuned weights.
# The checkpoint was downloaded from an external URL, so weights_only=True
# restricts torch.load to tensors and plain containers instead of arbitrary
# pickled objects. A state dict needs nothing more.
model = BERT_Arch(bert)
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # evaluation mode (e.g. disables dropout) for inference

# Tokenizer matching the encoder's vocabulary.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
# FastAPI application exposing the classifier over HTTP.
# NOTE: these routes are reachable only when `app` is served by an ASGI
# server (e.g. uvicorn); launching the Gradio UI does not serve them.
app = FastAPI()


@app.get("/")
def home():
    """Health check: report that the API is up."""
    return {"message": "Phishing Detection API is running!"}


@app.post("/predict")
def predict(request: TextRequest):
    """Classify ``request.text`` and return ``{"prediction": <label>}``."""
    return {"prediction": classify_text(request.text)}
def classify_text(text: str) -> str:
    """Classify ``text`` as phishing or not with the fine-tuned BERT model.

    Leading/trailing whitespace, HTML and links are removed before
    tokenization. The input is truncated/padded to 512 tokens.

    Args:
        text: Raw email or website text.

    Returns:
        ``"Phishing"`` when the model predicts class 1, ``"Not Phishing"``
        otherwise, or ``"Error: <message>"`` if preprocessing or inference
        raises. The failure is also logged with its traceback.
    """
    try:
        text = text.strip()
        text = remove_html(text)
        text = remove_links(text)
        tokens = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=512,
        )
        input_ids = tokens["input_ids"].to(device)
        attention_mask = tokens["attention_mask"].to(device)
        with torch.no_grad():
            output = model(input_ids, attention_mask)
        prediction = torch.argmax(output.cpu(), dim=1).item()
    except Exception as e:
        # This is the top-level boundary for both the UI and the HTTP
        # endpoint. Keep returning a readable message to the caller, but log
        # the traceback so failures are not silently swallowed.
        logging.getLogger(__name__).exception("Failed to classify text")
        return f"Error: {e}"
    return "Phishing" if prediction == 1 else "Not Phishing"
# Gradio UI: one textbox in, the predicted label out.
demo = gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(label="Enter website content or email text"),
    outputs=gr.Label(label="Prediction"),
    title="Phishing Text Detector",
    description="Website text to check if it's phishing.",
)

# Launch only when run as a script, so importing this module (e.g. to serve
# the FastAPI `app` or to test `classify_text`) does not start the UI server.
if __name__ == "__main__":
    demo.launch()