Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import pipeline, AutoTokenizer | |
# Load the phishing classifier and its tokenizer exactly once.
model_name = "ealvaradob/bert-finetuned-phishing"
# Load the tokenizer first and hand the same instance to the pipeline. This
# avoids loading it twice and guarantees that token counting/chunking below
# uses exactly the tokenizer the classifier uses.
tokenizer = AutoTokenizer.from_pretrained(model_name)
classifier = pipeline("text-classification", model=model_name, tokenizer=tokenizer)
# Maximum input length for one model pass, special tokens included
# (512 is the usual BERT position-embedding limit).
MAX_TOKENS = 512
def count_tokens(text):
    """Return the number of token ids the tokenizer produces for *text*.

    ``encode`` is called with its default special-token handling and with
    truncation disabled. The count therefore reflects the full text and may
    exceed ``MAX_TOKENS``.
    """
    token_ids = tokenizer.encode(text, truncation=False)
    return len(token_ids)
def chunk_text(text, max_tokens=MAX_TOKENS):
    """Split *text* into whitespace-delimited chunks that fit the model.

    Words are packed greedily and in order into chunks. Each chunk's token
    length, plus the special tokens the tokenizer adds to every input, is
    kept at most ``max_tokens``. Runs of whitespace between words are
    normalized to single spaces.

    A single word whose own token count exceeds the budget (e.g. a very long
    URL) is emitted as a chunk on its own. It cannot be split at a word
    boundary, so the caller should enable truncation when classifying it.

    Args:
        text: The input text.
        max_tokens: Maximum sequence length per chunk, special tokens included.

    Returns:
        A list of chunk strings, never containing empty strings. Empty or
        whitespace-only input yields an empty list.
    """
    # The pipeline wraps every chunk in special tokens ([CLS] ... [SEP] for
    # BERT). Reserve room for them so a "full" chunk does not overflow the
    # model's maximum sequence length.
    budget = max_tokens - tokenizer.num_special_tokens_to_add(pair=False)

    chunks = []
    current_chunk = []
    current_length = 0
    for word in text.split():
        # NOTE(review): assumes per-word token counts sum to the joined
        # chunk's count. This holds for BERT's whitespace pre-tokenization;
        # re-check if the model changes.
        word_length = len(tokenizer.encode(word, add_special_tokens=False))
        # Flush only a non-empty chunk. Otherwise an oversized first word
        # would append an empty "" chunk and skew the vote.
        if current_chunk and current_length + word_length > budget:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(word)
        current_length += word_length

    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks
def process_chunks(chunks):
    """Classify each chunk and combine the per-chunk verdicts by majority vote.

    Each chunk is classified independently. The final label is "Phishing"
    only when strictly more chunks were classified as phishing than not;
    ties resolve to "Legitimate".

    The reported confidence is the mean score of the chunks that agree with
    the final label. It is therefore not diluted (or inflated) by chunks
    that voted the other way.

    Args:
        chunks: Non-empty list of text chunks, e.g. from ``chunk_text``.

    Returns:
        A two-line string with the prediction and its average confidence.

    Raises:
        ValueError: If ``chunks`` is empty.
    """
    if not chunks:
        raise ValueError("process_chunks() requires at least one chunk")

    phishing_scores = []
    legitimate_scores = []
    for chunk in chunks:
        # Truncate rather than crash if a chunk (e.g. one giant word) is still
        # longer than the model's maximum sequence length.
        result = classifier(chunk, truncation=True, max_length=MAX_TOKENS)[0]
        if result["label"].lower() == "phishing":
            phishing_scores.append(result["score"])
        else:
            legitimate_scores.append(result["score"])

    if len(phishing_scores) > len(legitimate_scores):
        final_label, agreeing_scores = "Phishing", phishing_scores
    else:
        # Ties (and all-legitimate) resolve to "Legitimate", as before.
        final_label, agreeing_scores = "Legitimate", legitimate_scores

    # agreeing_scores is non-empty here. Chunks is non-empty, and the winning
    # side (including a tie) always has at least one vote.
    average_confidence = sum(agreeing_scores) / len(agreeing_scores)
    return f"Prediction: {final_label}\nAverage Confidence: {average_confidence:.2%}"
def detect_phishing(input_text):
    """Classify email text as phishing or legitimate (Gradio handler).

    Text whose token count (special tokens included) is at most
    ``MAX_TOKENS`` is classified in one pass. Longer text is split with
    ``chunk_text`` and the chunk verdicts are combined by ``process_chunks``.

    Args:
        input_text: The email content from the textbox. None is treated as
            empty.

    Returns:
        A two-line prediction string. For empty or whitespace-only input, a
        short prompt asking for input is returned instead.
    """
    # Classifying an empty string only produces a meaningless verdict.
    if not input_text or not input_text.strip():
        return "Please paste some email content to analyze."

    if count_tokens(input_text) > MAX_TOKENS:
        return process_chunks(chunk_text(input_text))

    result = classifier(input_text)[0]
    label = "Phishing" if result["label"].lower() == "phishing" else "Legitimate"
    return f"Prediction: {label}\nConfidence: {result['score']:.2%}"
# Gradio UI: a multi-line textbox goes in, the plain-text verdict comes out.
email_box = gr.Textbox(lines=8, placeholder="Paste email content here...")
demo = gr.Interface(
    fn=detect_phishing,
    inputs=email_box,
    outputs="text",
    title="Phishing Email Detector",
    description=(
        "Uses a fine-tuned BERT model to classify whether the email is phishing "
        "or legitimate. Handles long emails by chunking."
    ),
)
demo.launch()