File size: 4,094 Bytes
8d35ea4
453de19
 
 
 
 
 
91550db
453de19
 
 
91550db
453de19
 
 
 
 
 
 
 
 
 
91550db
453de19
 
 
 
 
 
 
 
 
 
 
 
91550db
453de19
91550db
453de19
 
 
 
91550db
453de19
 
 
 
 
 
 
 
91550db
453de19
 
91550db
453de19
 
 
 
 
91550db
453de19
 
 
91550db
453de19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91550db
453de19
 
 
 
 
 
 
 
 
 
91550db
453de19
91550db
453de19
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import streamlit as st
import tensorflow as tf
# tensorflow_hub is optional at import time; load_bert_layers() shows an install hint if it is missing.
try:
    import tensorflow_hub as hub
except ImportError:
    hub = None
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix
import numpy as np

def load_bert_layers():
    """Return the ``(bert_preprocess, bert_encoder)`` TF-Hub layers.

    Shows an install hint and stops the Streamlit script if tensorflow_hub
    could not be imported. Otherwise the layers come from
    ``_load_hub_layers``, which is wrapped in ``st.cache_resource`` so they
    are created once per process rather than on every Streamlit rerun.

    Returns:
        Tuple of (preprocessing KerasLayer, BERT encoder KerasLayer).
    """
    if hub is None:
        st.error("Missing dependency: tensorflow_hub. Run `pip install tensorflow-hub` in your environment.")
        st.stop()
    return _load_hub_layers()


@st.cache_resource
def _load_hub_layers():
    """Create the BERT preprocessing and encoder layers from TF-Hub (cached).

    Kept separate from the dependency check in ``load_bert_layers`` so the
    ``st.stop()`` path is never routed through the cache.
    """
    bert_preprocess = hub.KerasLayer(
        "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3", name='preprocess')
    # trainable=True: the encoder weights are updated together with the
    # classification head when the model is fitted.
    bert_encoder = hub.KerasLayer(
        "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4", trainable=True, name='BERT_encoder')
    return bert_preprocess, bert_encoder

def build_model(bert_preprocess, bert_encoder):
    """Return the compiled BERT spam classifier, building it on first use.

    The model feeds a batch of raw strings through ``bert_preprocess`` and
    ``bert_encoder``. It applies dropout (0.1) to the encoder's
    ``pooled_output``, then a single sigmoid unit. It is compiled with binary
    cross-entropy, Adam (lr=2e-5) and an accuracy metric.

    Construction is delegated to ``_build_model_cached`` (``st.cache_resource``),
    which hands back the same model object on later calls, so weights learned
    by ``model.fit`` survive Streamlit reruns. The legacy ``st.cache`` used
    before had to hash the KerasLayer arguments on every call and is deprecated.

    Args:
        bert_preprocess: TF-Hub preprocessing layer; receives the string input.
        bert_encoder: TF-Hub BERT encoder layer whose output dict contains
            ``'pooled_output'``.

    Returns:
        A compiled ``tf.keras.Model`` with one ``tf.string`` input named
        ``'text'`` and one sigmoid output named ``'output'``.
    """
    return _build_model_cached(bert_preprocess, bert_encoder)


@st.cache_resource
def _build_model_cached(_bert_preprocess, _bert_encoder):
    """Build and compile the classifier; cached as a shared resource.

    The leading underscores tell ``st.cache_resource`` not to hash these
    arguments. The cache therefore holds a single model per process, built from
    the layers passed on the first call, and later calls return that same
    object. It is shared by all sessions, so training in one session also
    changes the model used by the others.
    """
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
    encoder_inputs = _bert_preprocess(text_input)
    encoder_outputs = _bert_encoder(encoder_inputs)
    x = tf.keras.layers.Dropout(0.1)(encoder_outputs['pooled_output'])
    x = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
    model = tf.keras.Model(inputs=text_input, outputs=x)
    model.compile(
        loss='binary_crossentropy',
        optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
        metrics=['accuracy']
    )
    return model

# Main Streamlit app
def main():
    """Render the spam-detection app.

    Flow:
      1. Upload a CSV with 'Category' and 'Message' columns and optionally preview it.
      2. Split it 80/20, stratified on the spam label.
      3. Train the BERT classifier when 'Train Model' is pressed and report
         test-set accuracy, precision, recall and the confusion matrix.
      4. Classify a free-text message when 'Predict' is pressed.

    Problems with the input are shown with ``st.error`` / ``st.info`` /
    ``st.warning`` and the function returns early, instead of raising.
    """
    st.title("Spam Detection with BERT and Streamlit")
    st.write("Upload an SMS dataset (CSV with 'Category' and 'Message' columns) to train the classifier.")

    uploaded_file = st.file_uploader("Choose a CSV file", type='csv')
    if uploaded_file is None:
        st.info('Awaiting CSV file upload.')
        return

    # latin-1 maps every byte to a character, so decoding cannot fail on
    # files that are not valid UTF-8.
    df = pd.read_csv(uploaded_file, encoding='latin-1')
    if 'Category' not in df.columns or 'Message' not in df.columns:
        st.error("CSV must contain 'Category' and 'Message' columns.")
        return

    # .copy() gives an independent frame, so adding the 'spam' column below
    # cannot trigger pandas' SettingWithCopyWarning.
    df = df[['Category', 'Message']].dropna().copy()
    df.columns = ['label', 'text']
    # 1 = spam, 0 = anything else. strip() keeps values such as "spam " (stray
    # whitespace in the CSV) from being labelled ham.
    df['spam'] = df['label'].apply(lambda x: 1 if str(x).strip().lower() == 'spam' else 0)

    if st.checkbox('Show data sample'):
        st.dataframe(df.head(5))

    # Prepare data
    X = df['text'].astype(str)
    y = df['spam']
    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, stratify=y, test_size=0.2, random_state=42)
    except ValueError as err:
        # e.g. an empty dataset, or a class with too few rows for a stratified
        # split; report it in the app instead of showing a raw traceback.
        st.error(f"Could not split the dataset into train/test sets: {err}")
        return

    # Load BERT and model
    bert_preprocess, bert_encoder = load_bert_layers()
    model = build_model(bert_preprocess, bert_encoder)

    # Train model
    if st.button('Train Model'):
        with st.spinner('Training model...'):
            model.fit(X_train, y_train, epochs=3, batch_size=32, verbose=1)
        # Evaluate on the held-out split with a 0.5 decision threshold.
        preds = (model.predict(X_test) > 0.5).astype(int).flatten()
        # zero_division=0: when nothing is predicted (or labelled) as spam,
        # report 0 explicitly instead of emitting UndefinedMetricWarning.
        metrics = {
            'accuracy': accuracy_score(y_test, preds),
            'precision': precision_score(y_test, preds, zero_division=0),
            'recall': recall_score(y_test, preds, zero_division=0),
        }
        st.success("Model trained!")
        st.write("**Accuracy:**", metrics['accuracy'])
        st.write("**Precision:**", metrics['precision'])
        st.write("**Recall:**", metrics['recall'])
        cm = confusion_matrix(y_test, preds)
        st.write("**Confusion Matrix:**")
        st.write(cm)

    # Prediction section. NOTE(review): unless this model object was already
    # fitted, predictions come from the untrained classification head.
    st.subheader("Classify a New Message")
    user_input = st.text_area("Enter your message:")
    if st.button('Predict'):
        if not user_input.strip():
            st.warning("Please enter a message to classify.")
        else:
            prob = model.predict([user_input])[0][0]
            label = 'Spam' if prob > 0.5 else 'Ham'
            st.write(f"**Prediction:** {label}")
            st.write(f"**Probability:** {prob:.4f}")

# Render the app only when this file is executed as the main module.
if __name__ == "__main__":
    main()