Spaces:
Runtime error
Runtime error
| import torch | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig | |
# Hugging Face Hub repository holding the fine-tuned classifier; also echoed
# back to clients in every detection response.
MODEL_ID = "rudycaz/qwen3-4b-phishing-detection"

# The ASGI application. Title/description/version feed FastAPI's generated
# OpenAPI schema and interactive docs.
app = FastAPI(
    title="Phishing Email Detection API",
    description="Detects phishing emails using Qwen3-4B phishing detection model",
    version="1.0.0",
)
# Load tokenizer and model once, at import time, so every request reuses them.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)

print("Loading model...")
if torch.cuda.is_available():
    # bitsandbytes 4-bit (NF4, double-quantized) weights to fit the 4B model in
    # modest GPU memory. Compute in bfloat16 only when the GPU reports bf16
    # support; otherwise fall back to float16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=(
            torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        ),
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        quantization_config=bnb_config,
        trust_remote_code=True,
    )
else:
    # CPU-only host: bitsandbytes quantization generally requires CUDA (depending
    # on transformers/bitsandbytes versions it refuses to load without a GPU), so
    # load unquantized weights instead. bfloat16 halves memory versus float32.
    # NOTE(review): CPU inference on a 4B model is slow; confirm it is acceptable.
    bnb_config = None
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
print("Model loaded successfully!")
class EmailRequest(BaseModel):
    """Request payload carrying the email to classify."""

    email: str  # raw email text, inserted verbatim into the classification prompt
class EmailResponse(BaseModel):
    """Response payload: the predicted label and the model that produced it.

    Mirrors the keys of the dict returned by ``detect_email``.
    """

    prediction: str  # label string produced by classify_email
    model: str  # model identifier (MODEL_ID) used for inference
def classify_email(email_text: str) -> str:
    """Classify an email as phishing or legitimate with the loaded model.

    Args:
        email_text: Raw text of the email to classify.

    Returns:
        ``"PHISHING"`` if the model's reply contains that word, else
        ``"LEGIT"`` if the reply contains that word, else ``"UNKNOWN"``.
    """
    prompt = (
        "You are a security assistant. Classify the following email as PHISHING or LEGIT.\n\n"
        f"EMAIL:\n{email_text}\n\n"
        "Answer with exactly one word: PHISHING or LEGIT."
    )
    # NOTE(review): the prompt is tokenized as plain text, not via
    # tokenizer.apply_chat_template; confirm this matches the format the
    # model was fine-tuned on.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=4,
        # Greedy decoding for a deterministic label. temperature=0 is not a
        # greedy switch: it is ignored (with a warning) when sampling is off and
        # raises a ValueError when the model's generation config enables sampling.
        do_sample=False,
    )

    # generate() returns prompt + continuation. Decode only the new tokens:
    # the prompt itself contains both "PHISHING" and "LEGIT", so decoding the
    # whole sequence made the check below always answer "PHISHING".
    prompt_len = inputs["input_ids"].shape[-1]
    reply = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    reply_upper = reply.upper()
    if "PHISHING" in reply_upper:
        return "PHISHING"
    if "LEGIT" in reply_upper:
        return "LEGIT"
    return "UNKNOWN"
@app.get("/")
def root():
    """Return a static message confirming the API process is up.

    Registered on ``GET /`` so clients can probe the service; without a route
    decorator the handler would never be reachable over HTTP.
    """
    return {"message": "Phishing Detection API is running"}
# NOTE(review): this handler had no route decorator, so it was unreachable over
# HTTP. "/detect" is a chosen default path; align it with what clients call.
@app.post("/detect", response_model=EmailResponse)
def detect_email(data: EmailRequest):
    """Classify the submitted email and report which model produced the label.

    Kept as a plain ``def`` (not ``async def``) so FastAPI runs it in its
    threadpool; the ``model.generate`` call inside classify_email is blocking.

    Args:
        data: Parsed request body holding the email text.

    Returns:
        A dict with ``prediction`` (label from classify_email) and ``model``
        (MODEL_ID), validated against EmailResponse.
    """
    prediction = classify_email(data.email)
    return {
        "prediction": prediction,
        "model": MODEL_ID,
    }