SubhaL committed on
Commit
ecf2409
·
verified ·
1 Parent(s): 27d4105

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -26
app.py CHANGED
@@ -1,22 +1,22 @@
 
 
1
  import gradio as gr
2
  from transformers import pipeline, AutoTokenizer
 
3
 
4
  # Load model and tokenizer
5
  model_name = "ealvaradob/bert-finetuned-phishing"
6
  classifier = pipeline("text-classification", model=model_name)
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
-
9
  MAX_TOKENS = 512
10
 
 
def count_tokens(text):
    """Return how many token ids the tokenizer emits for ``text``.

    The text is encoded with truncation disabled, so the count reflects the
    full input rather than a clipped version of it.
    """
    token_ids = tokenizer.encode(text, truncation=False)
    return len(token_ids)
13
 
14
  def chunk_text(text, max_tokens=MAX_TOKENS):
15
  words = text.split()
16
- chunks = []
17
- current_chunk = []
18
- current_length = 0
19
-
20
  for word in words:
21
  word_length = len(tokenizer.encode(word, add_special_tokens=False))
22
  if current_length + word_length > max_tokens:
@@ -26,51 +26,53 @@ def chunk_text(text, max_tokens=MAX_TOKENS):
26
  else:
27
  current_chunk.append(word)
28
  current_length += word_length
29
-
30
  if current_chunk:
31
  chunks.append(" ".join(current_chunk))
32
-
33
  return chunks
34
 
def process_chunks(chunks):
    """Classify each chunk and summarise the results by majority vote.

    Args:
        chunks: Non-empty list of text chunks to classify.

    Returns:
        A two-line string with the majority label ("Phishing" only when
        phishing chunks strictly outnumber the others, otherwise
        "Legitimate") and the mean classifier score across all chunks.

    Raises:
        ValueError: If ``chunks`` is empty, since no average can be computed.
    """
    if not chunks:
        # Previously this fell through to a ZeroDivisionError on the average.
        raise ValueError("No text chunks to classify.")

    phishing_count = 0
    legitimate_count = 0
    total_score = 0

    for chunk in chunks:
        # Chunks are sized by summing per-word token counts taken without
        # special tokens, so the encoded chunk can end up longer than the
        # budget once the tokenizer adds them. Let the pipeline truncate
        # instead of failing on an over-long sequence.
        result = classifier(chunk, truncation=True)[0]
        label = result['label'].lower()
        score = result['score']
        total_score += score

        if label == "phishing":
            phishing_count += 1
        else:
            legitimate_count += 1

    final_label = "Phishing" if phishing_count > legitimate_count else "Legitimate"
    average_confidence = total_score / len(chunks)

    return f"Prediction: {final_label}\nAverage Confidence: {average_confidence:.2%}"
55
 
def detect_phishing(input_text):
    """Classify ``input_text`` and return a human-readable prediction string.

    Inputs whose token count is within MAX_TOKENS are classified in a single
    call; longer inputs are split with chunk_text and summarised by
    process_chunks.
    """
    # Handle the long-input case first so the common path reads straight down.
    if count_tokens(input_text) > MAX_TOKENS:
        return process_chunks(chunk_text(input_text))

    prediction = classifier(input_text)[0]
    is_phishing = prediction['label'].lower() == "phishing"
    verdict = "Phishing" if is_phishing else "Legitimate"
    return f"Prediction: {verdict}\nConfidence: {prediction['score']:.2%}"
66
 
# Gradio interface: a single multi-line textbox feeding detect_phishing,
# with the returned prediction string shown as plain text.
email_input = gr.Textbox(lines=8, placeholder="Paste email content here...")

demo = gr.Interface(
    fn=detect_phishing,
    inputs=email_input,
    outputs="text",
    title="Phishing Email Detector",
    description="Uses a fine-tuned BERT model to classify whether the email is phishing or legitimate. Handles long emails by chunking.",
)

demo.launch()
 
 
1
+ from fastapi import FastAPI, Request
2
+ from fastapi.responses import JSONResponse
3
  import gradio as gr
4
  from transformers import pipeline, AutoTokenizer
5
+ import uvicorn
6
 
7
  # Load model and tokenizer
8
  model_name = "ealvaradob/bert-finetuned-phishing"
9
  classifier = pipeline("text-classification", model=model_name)
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
11
  MAX_TOKENS = 512
12
 
13
+ # Functions
def count_tokens(text):
    """Return the length of the tokenizer's encoding of ``text``.

    Truncation is switched off so that over-long inputs report their real
    token count.
    """
    encoded = tokenizer.encode(text, truncation=False)
    return len(encoded)
16
 
17
  def chunk_text(text, max_tokens=MAX_TOKENS):
18
  words = text.split()
19
+ chunks, current_chunk, current_length = [], [], 0
 
 
 
20
  for word in words:
21
  word_length = len(tokenizer.encode(word, add_special_tokens=False))
22
  if current_length + word_length > max_tokens:
 
26
  else:
27
  current_chunk.append(word)
28
  current_length += word_length
 
29
  if current_chunk:
30
  chunks.append(" ".join(current_chunk))
 
31
  return chunks
32
 
def process_chunks(chunks):
    """Classify each chunk and combine the results by majority vote.

    Args:
        chunks: Non-empty list of text chunks to classify.

    Returns:
        A dict with "label" ("Phishing" only when phishing chunks strictly
        outnumber the others, otherwise "Legitimate") and "confidence" (the
        mean classifier score across chunks, rounded to 4 decimal places).

    Raises:
        ValueError: If ``chunks`` is empty, since no average can be computed.
    """
    if not chunks:
        # Previously this fell through to a ZeroDivisionError on the average.
        raise ValueError("No text chunks to classify.")

    phishing_count, legitimate_count, total_score = 0, 0, 0

    for chunk in chunks:
        # Chunks are sized by summing per-word token counts taken without
        # special tokens, so the encoded chunk can end up longer than the
        # budget once the tokenizer adds them. Let the pipeline truncate
        # instead of failing on an over-long sequence.
        result = classifier(chunk, truncation=True)[0]
        label, score = result['label'].lower(), result['score']
        total_score += score

        if label == "phishing":
            phishing_count += 1
        else:
            legitimate_count += 1

    final_label = "Phishing" if phishing_count > legitimate_count else "Legitimate"
    average_confidence = total_score / len(chunks)
    return {"label": final_label, "confidence": round(average_confidence, 4)}
 
46
 
def detect_phishing(input_text):
    """Classify ``input_text`` and return ``{"label", "confidence"}``.

    Inputs whose token count is within MAX_TOKENS are classified in a single
    call; longer inputs are split with chunk_text and combined by
    process_chunks.
    """
    # Handle the long-input case first so the common path reads straight down.
    if count_tokens(input_text) > MAX_TOKENS:
        return process_chunks(chunk_text(input_text))

    prediction = classifier(input_text)[0]
    is_phishing = prediction['label'].lower() == "phishing"
    verdict = "Phishing" if is_phishing else "Legitimate"
    return {"label": verdict, "confidence": round(prediction['score'], 4)}
55
 
# FastAPI app
api = FastAPI()


@api.post("/predict")
async def predict(request: Request):
    """Classify the email text supplied in a JSON request body.

    Expects a JSON object with a string "text" field. Responds with the
    dictionary returned by detect_phishing, or with a 400 error payload when
    the body is not valid JSON, is not an object, or carries no usable text.
    """
    try:
        data = await request.json()
    except ValueError:
        # JSON decoding errors are ValueError subclasses; a malformed body is
        # the client's mistake and should not surface as a 500.
        return JSONResponse({"error": "Invalid JSON body."}, status_code=400)

    # A JSON array or scalar has no .get(); treat it the same as missing text.
    input_text = data.get("text", "") if isinstance(data, dict) else ""
    if not isinstance(input_text, str) or not input_text.strip():
        return JSONResponse({"error": "No text provided."}, status_code=400)

    result = detect_phishing(input_text)
    return JSONResponse(result)
67
+
# Gradio interface (optional)
def _format_prediction(text):
    """Run detection once and render it as "<label> (<confidence>%)".

    A named function replaces the earlier lambda, which called
    detect_phishing twice per submission and so ran the model twice.
    """
    result = detect_phishing(text)
    return f"{result['label']} ({result['confidence']*100:.2f}%)"


demo = gr.Interface(
    fn=_format_prediction,
    inputs=gr.Textbox(lines=6, label="Paste Email Text"),
    outputs="text",
    title="Phishing Email Detector",
    description="Detects whether an email is Phishing or Legitimate using BERT.",
)
76
 
# Serve the FastAPI routes and the Gradio UI from a single server.
# demo.launch() on its own starts only Gradio's app, which leaves the
# /predict route registered on `api` unreachable and `uvicorn` unused.
app = gr.mount_gradio_app(api, demo, path="/")
uvicorn.run(app, host="0.0.0.0", port=7860)
78
+