nijatmammadov committed on
Commit
783b3c4
·
verified ·
1 Parent(s): b5e41ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -20
app.py CHANGED
@@ -4,14 +4,14 @@ from fastapi import FastAPI
4
  from transformers import AutoModel, BertTokenizerFast
5
  from pydantic import BaseModel
6
  from model import BERT_Arch
7
- from preprocess_data import remove_html,remove_links
8
  import gradio as gr
9
 
 
10
  class TextRequest(BaseModel):
11
  text: str
12
 
13
  # Download model from Google Drive
14
- #link:https://drive.google.com/drive/folders/102UPd446eHCCENR58EC3UxnJfcYkBa8U?usp=sharing
15
  model_url = "https://drive.google.com/uc?id=16ZWVa0d2V0T3s11Oq86rLOTA6bOR0DnR"
16
  model_path = "model.pth"
17
  gdown.download(model_url, model_path, quiet=False)
@@ -24,7 +24,7 @@ for param in bert.parameters():
24
  # Set device
25
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
 
27
- # Load custom model
28
  model = BERT_Arch(bert)
29
  model.load_state_dict(torch.load(model_path, map_location=device))
30
  model.to(device)
@@ -42,36 +42,37 @@ def home():
42
 
43
  @app.post("/predict/")
44
  def predict(request: TextRequest):
45
- try:
46
- text = request.text.strip()
47
 
48
- # Preprocess text
 
 
 
49
  text = remove_html(text)
50
  text = remove_links(text)
51
 
52
- # Tokenize input text
53
  tokens = tokenizer(
54
- text, return_tensors="pt", truncation=True, padding="max_length", max_length=512
 
55
  )
56
-
57
  input_ids = tokens["input_ids"].to(device)
58
  attention_mask = tokens["attention_mask"].to(device)
59
 
60
- # Perform inference
61
  with torch.no_grad():
62
  output = model(input_ids, attention_mask)
63
 
64
  prediction = torch.argmax(output.cpu(), dim=1).item()
65
-
66
- return {"prediction": "Phishing" if prediction == 1 else "Not Phishing"}
67
 
68
  except Exception as e:
69
- return {"error": str(e)}
70
- def greet(name):
71
- return "Hello " + name + "!"
72
  gr.Interface(
73
- fn=greet,
74
- inputs="text",
75
- outputs="text",
76
- allow_flagging="never"
77
- ).launch()
 
 
4
  from transformers import AutoModel, BertTokenizerFast
5
  from pydantic import BaseModel
6
  from model import BERT_Arch
7
+ from preprocess_data import remove_html, remove_links
8
  import gradio as gr
9
 
10
+ # Define input model structure
11
  class TextRequest(BaseModel):
12
  text: str
13
 
14
  # Download model from Google Drive
 
15
  model_url = "https://drive.google.com/uc?id=16ZWVa0d2V0T3s11Oq86rLOTA6bOR0DnR"
16
  model_path = "model.pth"
17
  gdown.download(model_url, model_path, quiet=False)
 
24
  # Set device
25
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
 
27
+ # Load your custom BERT-based model
28
  model = BERT_Arch(bert)
29
  model.load_state_dict(torch.load(model_path, map_location=device))
30
  model.to(device)
 
42
 
43
  @app.post("/predict/")
44
  def predict(request: TextRequest):
45
+ return {"prediction": classify_text(request.text)}
 
46
 
47
+ # Function to classify text
48
+ def classify_text(text: str) -> str:
49
+ try:
50
+ text = text.strip()
51
  text = remove_html(text)
52
  text = remove_links(text)
53
 
 
54
  tokens = tokenizer(
55
+ text, return_tensors="pt", truncation=True,
56
+ padding="max_length", max_length=512
57
  )
58
+
59
  input_ids = tokens["input_ids"].to(device)
60
  attention_mask = tokens["attention_mask"].to(device)
61
 
 
62
  with torch.no_grad():
63
  output = model(input_ids, attention_mask)
64
 
65
  prediction = torch.argmax(output.cpu(), dim=1).item()
66
+ return "Phishing" if prediction == 1 else "Not Phishing"
 
67
 
68
  except Exception as e:
69
+ return f"Error: {str(e)}"
70
+
71
+ # Gradio UI
72
  gr.Interface(
73
+ fn=classify_text,
74
+ inputs=gr.Textbox(label="Enter website content or email text"),
75
+ outputs=gr.Label(label="Prediction"),
76
+ title="Phishing Text Detector",
77
+ description="Website text to check if it's phishing."
78
+ ).launch()