Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer | |
| from huggingface_hub import hf_hub_download | |
| import gradio as gr | |
| # --- import your architecture --- | |
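| # Make sure the module that defines DeBERTaLSTMClassifier is in the repo (currently model.py at the repo root). | |
| # If it lives elsewhere (e.g., models/deberta_lstm_classifier.py), update the import path below accordingly. | |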
| from model import DeBERTaLSTMClassifier # <-- your class | |
# --------- Config ----------
REPO_ID = "khoa-done/phishing-detector"  # HF repo that holds the checkpoint
CKPT_NAME = "deberta_lstm_checkpoint.pt"  # checkpoint file name inside REPO_ID
MODEL_NAME = "microsoft/deberta-base"  # base tokenizer/backbone
LABELS = ["benign", "phishing"]  # LABELS[i] names model output index i; adjust to your classes

# Constructor hyperparameters are read from the checkpoint's "model_args" entry
# (if present) and passed as DeBERTaLSTMClassifier(**model_args).

# --------- Load model/tokenizer once (global) ----------
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_NAME)
# SECURITY NOTE(review): torch.load unpickles the downloaded file, which can
# execute arbitrary code. Only point REPO_ID at repositories you trust. On
# torch>=2.6 the default weights_only=True restricts unpickling to tensors and
# plain containers; on older versions no such restriction applies.
checkpoint = torch.load(ckpt_path, map_location=device)

# Two checkpoint layouts are accepted:
#   * a wrapper dict {"model_args": {...}, "model_state_dict": {...}}
#   * a bare state_dict (no wrapper keys) -> default constructor args.
# `or {}` also covers a wrapper that stored model_args=None.
model_args = checkpoint.get("model_args") or {}  # e.g., {"lstm_hidden":256, "num_labels":2, ...}
state_dict = checkpoint.get("model_state_dict", checkpoint)

model = DeBERTaLSTMClassifier(**model_args)
model.load_state_dict(state_dict)
model.to(device).eval()  # inference only: disables dropout etc.
# --------- Inference function ----------
def predict_fn(text: str):
    """Classify a URL or text snippet and return per-class probabilities.

    Args:
        text: Raw user input from the Gradio textbox.

    Returns:
        A ``{label: probability}`` dict for the ``gr.Label`` output. Names
        come from ``LABELS`` when its length matches the model's output
        width; otherwise generic ``class_<i>`` names are used. For empty or
        whitespace-only input a plain message string is returned instead,
        which ``gr.Label`` displays as text.
    """
    if not text or not text.strip():
        # gr.Label interprets a dict as {label: numeric confidence}, so the
        # old {"error": "<message>"} could not be rendered; a bare string is
        # shown as the label text.
        return "Please enter a URL or text."

    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,  # single example -> tensors of shape [1, seq_len]
        max_length=256,  # adjust to the max length used during training
    )
    # DeBERTa typically doesn't use token_type_ids; drop them so forward()
    # only receives input_ids / attention_mask.
    encoded.pop("token_type_ids", None)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        outputs = model(**encoded)

    # forward() is expected to return raw logits; also tolerate an HF-style
    # output object (.logits) or a tuple whose first element is the logits.
    logits = getattr(outputs, "logits", outputs)
    if isinstance(logits, (tuple, list)):
        logits = logits[0]

    probs = F.softmax(logits, dim=-1).squeeze(0).tolist()

    if len(probs) == len(LABELS):
        names = LABELS
    else:
        # LABELS doesn't match the model's output width: use generic names.
        names = [f"class_{i}" for i in range(len(probs))]
    return {name: float(p) for name, p in zip(names, probs)}
# --------- Gradio UI ----------
# Sample inputs shown under the textbox; each inner list is one row of
# arguments for predict_fn.
_EXAMPLES = [
    ["http://rendmoiunserviceeee.com"],
    ["https://www.google.com"],
    ["https://mail-secure-login-verify.example/path?token=..."],
]

# Single free-text input -> class-probability label output.
_text_input = gr.Textbox(
    label="URL or text",
    placeholder="e.g., http://suspicious-site.example",
)
_label_output = gr.Label(label="Prediction")

demo = gr.Interface(
    fn=predict_fn,
    inputs=_text_input,
    outputs=_label_output,
    title="Phishing Detector (DeBERTa + LSTM)",
    description="Enter a URL/text. The model outputs class probabilities.",
    examples=_EXAMPLES,
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()