khoa-done committed on
Commit
9a6970e
·
1 Parent(s): 3f09186

Add all files for the model

Browse files
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
.idea/Phishing-Detector-HF.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="jdk" jdkName="Python 3.13 (Chatbot)" jdkType="Python SDK" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
5
+ <option name="ignoredPackages">
6
+ <value>
7
+ <list size="9">
8
+ <item index="0" class="java.lang.String" itemvalue="chromadb" />
9
+ <item index="1" class="java.lang.String" itemvalue="protobuf" />
10
+ <item index="2" class="java.lang.String" itemvalue="langchain-community" />
11
+ <item index="3" class="java.lang.String" itemvalue="langchain" />
12
+ <item index="4" class="java.lang.String" itemvalue="streamlit" />
13
+ <item index="5" class="java.lang.String" itemvalue="langchain-huggingface" />
14
+ <item index="6" class="java.lang.String" itemvalue="python-dotenv" />
15
+ <item index="7" class="java.lang.String" itemvalue="pypdf" />
16
+ <item index="8" class="java.lang.String" itemvalue="pysqlite3-binary" />
17
+ </list>
18
+ </value>
19
+ </option>
20
+ </inspection_tool>
21
+ <inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
22
+ <option name="ignoredErrors">
23
+ <list>
24
+ <option value="N802" />
25
+ </list>
26
+ </option>
27
+ </inspection_tool>
28
+ </profile>
29
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="Black">
4
+ <option name="sdkName" value="Python 3.13 (Chatbot)" />
5
+ </component>
6
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.13 (Chatbot)" project-jdk-type="Python SDK" />
7
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/Phishing-Detector-HF.iml" filepath="$PROJECT_DIR$/.idea/Phishing-Detector-HF.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import gradio as gr

# --- import your architecture ---
# Make sure this file is in the repo (e.g., models/deberta_lstm_classifier.py)
# and update the import path accordingly.
from model import DeBERTaLSTMClassifier  # <-- your class

# --------- Config ----------
REPO_ID = "khoa-done/phishing-detector"  # HF repo that holds the checkpoint
CKPT_NAME = "deberta_lstm_checkpoint.pt"  # the .pt file name
MODEL_NAME = "microsoft/deberta-base"  # base tokenizer/backbone
LABELS = ["benign", "phishing"]  # adjust to your classes

# If your checkpoint contains hyperparams, you can fetch them like:
# checkpoint.get("config") or checkpoint.get("model_args")
# and pass into DeBERTaLSTMClassifier(**model_args)

# --------- Load model/tokenizer once (global) ----------
# Module-level setup: runs once at import so every Gradio request reuses the
# same tokenizer/model instead of reloading them per call.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Download the checkpoint from the Hub (cached locally after the first run).
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_NAME)
# NOTE(review): torch.load unpickles arbitrary objects. Fine for a trusted
# first-party checkpoint, but consider weights_only=True — TODO confirm the
# checkpoint still loads with it.
checkpoint = torch.load(ckpt_path, map_location=device)

# If you saved hyperparams in the checkpoint, use them:
model_args = checkpoint.get("model_args", {})  # e.g., {"lstm_hidden":256, "num_labels":2, ...}
model = DeBERTaLSTMClassifier(**model_args)
# The checkpoint is expected to hold the weights under "model_state_dict";
# this raises KeyError if that key is absent.
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device).eval()  # inference mode: disables dropout etc.
34
+
35
+ # --------- Inference function ----------
36
def predict_fn(text: str):
    """Classify a URL or text snippet as benign vs. phishing.

    Args:
        text: Raw URL or text to classify. May be None/empty.

    Returns:
        A label -> probability dict for the gr.Label output, or a plain
        string message when the input is empty (gr.Label also accepts a
        bare string as its value).
    """
    if not text or not text.strip():
        # BUG FIX: the original returned {"error": "..."} here, but gr.Label
        # requires dict values to be numeric confidences — a string value
        # breaks the component. A plain string is a valid Label value and
        # surfaces the message to the user.
        return "Please enter a URL or text."

    # Tokenize; max_length should match what was used during training.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,  # single example -> becomes [1, seq_len]
        max_length=256  # adjust as used during training
    )
    # DeBERTa typically doesn't use token_type_ids
    inputs.pop("token_type_ids", None)
    # Move every tensor to the same device as the model.
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs)  # model.forward accepts (input_ids, attention_mask)
        probs = F.softmax(logits, dim=-1).squeeze(0).tolist()

    # Build label->prob mapping for the Gradio Label output. If LABELS'
    # length doesn't match the logits dimension, fall back to generic names.
    if len(LABELS) == len(probs):
        return {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
    else:
        return {f"class_{i}": float(p) for i, p in enumerate(probs)}
63
+
64
# --------- Gradio UI ----------
# Single-input interface: one textbox feeding predict_fn; the returned
# label->probability dict is rendered by the Label component.
demo = gr.Interface(
    fn=predict_fn,
    inputs=gr.Textbox(label="URL or text", placeholder="e.g., http://suspicious-site.example"),
    outputs=gr.Label(label="Prediction"),
    title="Phishing Detector (DeBERTa + LSTM)",
    description="Enter a URL/text. The model outputs class probabilities.",
    examples=[
        ["http://rendmoiunserviceeee.com"],
        ["https://www.google.com"],
        ["https://mail-secure-login-verify.example/path?token=..."]
    ]
)

# Launch the app only when run as a script (HF Spaces also executes this).
if __name__ == "__main__":
    demo.launch()
model.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoModel
4
+
5
class DeBERTaLSTMClassifier(nn.Module):
    """Frozen DeBERTa encoder feeding a bidirectional-LSTM classifier head.

    Only the LSTM and the final linear layer receive gradients; the DeBERTa
    backbone is frozen to keep training within limited resources.
    """

    def __init__(self, hidden_dim=128, num_labels=2):
        super().__init__()

        # Pretrained encoder, frozen: we never backprop through DeBERTa
        # (not enough resources to fine-tune the backbone).
        self.deberta = AutoModel.from_pretrained("microsoft/deberta-base")
        for p in self.deberta.parameters():
            p.requires_grad = False

        # BiLSTM over the contextual token embeddings from the encoder.
        self.lstm = nn.LSTM(
            input_size=self.deberta.config.hidden_size,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True
        )

        # Bidirectional -> forward and backward states are concatenated,
        # hence the 2x input width on the classification layer.
        self.fc = nn.Linear(hidden_dim * 2, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return raw class logits of shape [batch, num_labels]."""
        with torch.no_grad():  # encoder is frozen; skip building its graph
            encoded = self.deberta(input_ids=input_ids, attention_mask=attention_mask)

        # sequence: [batch, seq_len, hidden_dim * 2]
        sequence, _ = self.lstm(encoded.last_hidden_state)
        # NOTE(review): pooling takes the last *time step*. On right-padded
        # batches this lands on a pad position, and the backward direction
        # has only seen one token there. The shipped checkpoint was trained
        # with this exact pooling, so it is left unchanged — confirm before
        # reusing this head elsewhere.
        return self.fc(sequence[:, -1, :])
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ transformers==4.41.2
3
+ huggingface_hub==0.24.5
4
+ safetensors
5
+ gradio==4.39.0