Davila Rostaing committed on
Commit
85c4814
·
1 Parent(s): c97c237

Add application file

Browse files
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
4
+ import torch.nn.functional as F
5
+
6
# 1. Load the model and tokenizer from your saved directory
# If uploading to Hugging Face Spaces, ensure your saved_model folder is uploaded too!
model_path = "./saved_model"  # Update this path to where your model is saved

try:
    tokenizer = DistilBertTokenizer.from_pretrained(model_path)
    model = DistilBertForSequenceClassification.from_pretrained(model_path)
    model.eval()  # evaluation mode: disables dropout for deterministic inference
except Exception as e:
    # Fallback to the base checkpoint so the demo still starts when the
    # fine-tuned weights are missing (e.g. wrong path). Include the actual
    # error so a bad `model_path` is diagnosable from the logs.
    # (Remove this try-except block in your final version; it only prevents
    # crashes during testing without the fine-tuned weights.)
    print(f"Could not load fine-tuned model ({e}), loading base model for demonstration purposes...")
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)
    model.eval()
21
+
22
+ # 2. Define the prediction function
23
def predict_spam(message):
    """Classify a text message as ham (legitimate) or spam.

    Parameters
    ----------
    message : str
        Raw email/SMS text entered by the user.

    Returns
    -------
    dict
        Mapping of label -> probability, the format gr.Label expects.
    """
    if not message.strip():
        # Nothing to classify; surface a prompt through the Label widget.
        return {"Please enter a message": 1.0}

    # Tokenize the input for the model (pad/truncate to a fixed length).
    encoded = tokenizer(
        message, return_tensors="pt", truncation=True, padding=True, max_length=128
    )

    # Forward pass without autograd bookkeeping — inference only.
    with torch.no_grad():
        logits = model(**encoded).logits

    # Softmax over the class axis, then drop the batch dimension.
    scores = F.softmax(logits, dim=1).squeeze()

    # Label convention from training: index 0 -> Ham, index 1 -> Spam.
    return {
        "Ham (Legitimate)": scores[0].item(),
        "Spam (Malicious)": scores[1].item(),
    }
44
+
45
# 3. Create the Gradio Interface
# We use a clean, modern interface; widgets and copy are named up front
# so the Interface call itself stays short.
_message_input = gr.Textbox(
    lines=4,
    placeholder="Type an email or SMS message here...",
    label="Message Content",
)

_confidence_output = gr.Label(num_top_classes=2, label="Prediction Confidence")

_description = """
    ### Deep Learning Project (2025)
    This application uses a fine-tuned **DistilBERT** Transformer model to classify text messages as either Spam or Ham.
    * Enter a message below and click Submit.
    * **Examples of Spam:** 'URGENT! You have won a 1 week FREE membership. Call 087124006024'
    * **Examples of Ham:** 'Hey, are we still meeting for lunch tomorrow?'
    """

# Clickable sample inputs shown under the interface (one-column examples).
_examples = [
    ["WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only."],
    ["Hey man, just checking in. Are we still on for the movie tonight at 8?"],
    ["URGENT: Your bank account has been locked due to suspicious activity. Please click this link to verify your identity: http://secure-login-update.com"],
    ["I'll be about 10 minutes late, stuck in traffic."],
]

demo = gr.Interface(
    fn=predict_spam,
    inputs=_message_input,
    outputs=_confidence_output,
    title="🛡️ Spam Detection AI",
    description=_description,
    examples=_examples,
    theme=gr.themes.Soft(),
)
71
+
72
# 4. Launch the app
# Only start the Gradio server when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ transformers
3
+ torch
saved_model/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:faf37237976cc13bad4520548a113c77e7529c3b4ff6b598648fbce1696499e3
3
+ size 663
saved_model/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c07d77d4bad466798bc928c36a02d91cadbef0c2047def7e445928f1825d959
3
+ size 267832560
saved_model/optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b616fb2661c2ad4dbc6c0ee5c4c1caefd532989e8c399005744f5442e4116ed
3
+ size 535724875
saved_model/rng_state.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c0f6bef14c378cf6dadcce9566f94c89a1de5a87f6dab300dc2cfb3620974e5
3
+ size 14455
saved_model/scheduler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0fd612b91f9c87518d4d69e324f23842dc16c0f0824a97f416850db8d65d6175
3
+ size 1465
saved_model/tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d241a60d5e8f04cc1b2b3e9ef7a4921b27bf526d9f6050ab90f9267a1f9e5c66
3
+ size 711396
saved_model/tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af5ecd1e2aa3d659df1ea773292e80eb41a5c3b4026126019e5c6dcdb7cbc044
3
+ size 328
saved_model/trainer_state.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a4748a47ae13b6f9fd0182a6f65f4ee77883b846e370684a964475170d957f80
3
+ size 5955
saved_model/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab8b54e46b16083bb42102640d659199bfe96644c6fe4869e1135f8028f75f5c
3
+ size 5201