SubhaL committed on
Commit
5286d26
·
verified ·
1 Parent(s): 9fac93d

Updated app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -27
app.py CHANGED
@@ -1,22 +1,22 @@
1
- from fastapi import FastAPI, Request
2
- from fastapi.responses import JSONResponse
3
  import gradio as gr
4
  from transformers import pipeline, AutoTokenizer
5
- import uvicorn
6
 
7
  # Load model and tokenizer
8
  model_name = "ealvaradob/bert-finetuned-phishing"
9
  classifier = pipeline("text-classification", model=model_name)
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
11
  MAX_TOKENS = 512
12
 
13
- # Functions
14
  def count_tokens(text):
15
  return len(tokenizer.encode(text, truncation=False))
16
 
17
  def chunk_text(text, max_tokens=MAX_TOKENS):
18
  words = text.split()
19
- chunks, current_chunk, current_length = [], [], 0
 
 
 
20
  for word in words:
21
  word_length = len(tokenizer.encode(word, add_special_tokens=False))
22
  if current_length + word_length > max_tokens:
@@ -26,53 +26,53 @@ def chunk_text(text, max_tokens=MAX_TOKENS):
26
  else:
27
  current_chunk.append(word)
28
  current_length += word_length
 
29
  if current_chunk:
30
  chunks.append(" ".join(current_chunk))
 
31
  return chunks
32
 
33
  def process_chunks(chunks):
34
- phishing_count, legitimate_count, total_score = 0, 0, 0
 
 
 
35
  for chunk in chunks:
36
  result = classifier(chunk)[0]
37
- label, score = result['label'].lower(), result['score']
 
38
  total_score += score
 
39
  if label == "phishing":
40
  phishing_count += 1
41
  else:
42
  legitimate_count += 1
 
43
  final_label = "Phishing" if phishing_count > legitimate_count else "Legitimate"
44
  average_confidence = total_score / len(chunks)
45
- return {"label": final_label, "confidence": round(average_confidence, 4)}
 
46
 
47
  def detect_phishing(input_text):
48
- if count_tokens(input_text) <= MAX_TOKENS:
 
 
49
  result = classifier(input_text)[0]
50
  label = "Phishing" if result['label'].lower() == "phishing" else "Legitimate"
51
- return {"label": label, "confidence": round(result['score'], 4)}
52
  else:
53
  chunks = chunk_text(input_text)
54
  return process_chunks(chunks)
55
 
56
- # FastAPI app
57
- api = FastAPI()
58
-
59
- @api.post("/predict")
60
- async def predict(request: Request):
61
- data = await request.json()
62
- input_text = data.get("text", "")
63
- if not input_text:
64
- return JSONResponse({"error": "No text provided."}, status_code=400)
65
- result = detect_phishing(input_text)
66
- return JSONResponse(result)
67
-
68
- # Gradio interface (optional)
69
  demo = gr.Interface(
70
- fn=lambda x: f"{detect_phishing(x)['label']} ({detect_phishing(x)['confidence']*100:.2f}%)",
71
- inputs=gr.Textbox(lines=6, label="Paste Email Text"),
72
  outputs="text",
73
  title="Phishing Email Detector",
74
- description="Detects whether an email is Phishing or Legitimate using BERT."
75
  )
76
 
77
- demo.launch(server_name="0.0.0.0", server_port=7860, inline=False, share=False)
 
78
 
 
 
 
1
  import gradio as gr
2
  from transformers import pipeline, AutoTokenizer
 
3
 
4
  # Load model and tokenizer
5
  model_name = "ealvaradob/bert-finetuned-phishing"
6
  classifier = pipeline("text-classification", model=model_name)
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+
9
  MAX_TOKENS = 512
10
 
 
11
  def count_tokens(text):
12
  return len(tokenizer.encode(text, truncation=False))
13
 
14
  def chunk_text(text, max_tokens=MAX_TOKENS):
15
  words = text.split()
16
+ chunks = []
17
+ current_chunk = []
18
+ current_length = 0
19
+
20
  for word in words:
21
  word_length = len(tokenizer.encode(word, add_special_tokens=False))
22
  if current_length + word_length > max_tokens:
 
26
  else:
27
  current_chunk.append(word)
28
  current_length += word_length
29
+
30
  if current_chunk:
31
  chunks.append(" ".join(current_chunk))
32
+
33
  return chunks
34
 
35
  def process_chunks(chunks):
36
+ phishing_count = 0
37
+ legitimate_count = 0
38
+ total_score = 0
39
+
40
  for chunk in chunks:
41
  result = classifier(chunk)[0]
42
+ label = result['label'].lower()
43
+ score = result['score']
44
  total_score += score
45
+
46
  if label == "phishing":
47
  phishing_count += 1
48
  else:
49
  legitimate_count += 1
50
+
51
  final_label = "Phishing" if phishing_count > legitimate_count else "Legitimate"
52
  average_confidence = total_score / len(chunks)
53
+
54
+ return f"Prediction: {final_label}\nAverage Confidence: {average_confidence:.2%}"
55
 
56
  def detect_phishing(input_text):
57
+ token_count = count_tokens(input_text)
58
+
59
+ if token_count <= MAX_TOKENS:
60
  result = classifier(input_text)[0]
61
  label = "Phishing" if result['label'].lower() == "phishing" else "Legitimate"
62
+ return f"Prediction: {label}\nConfidence: {result['score']:.2%}"
63
  else:
64
  chunks = chunk_text(input_text)
65
  return process_chunks(chunks)
66
 
67
+ # Gradio interface
 
 
 
 
 
 
 
 
 
 
 
 
68
  demo = gr.Interface(
69
+ fn=detect_phishing,
70
+ inputs=gr.Textbox(lines=8, placeholder="Paste email content here..."),
71
  outputs="text",
72
  title="Phishing Email Detector",
73
+ description="Uses a fine-tuned BERT model to classify whether the email is phishing or legitimate. Handles long emails by chunking."
74
  )
75
 
76
+ demo.launch()
77
+
78