"""Streamlit app that fine-tunes a BERT classifier to detect spam SMS messages."""

import streamlit as st
import tensorflow as tf

# tensorflow_hub is imported defensively so that a missing install is reported
# inside the app (see load_bert_layers) instead of as an import traceback.
try:
    import tensorflow_hub as hub
except ImportError:
    hub = None

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    confusion_matrix,
)
import numpy as np

# Keys under which the per-session model and its training status are kept.
# Streamlit reruns this script on every widget interaction, so anything that
# must survive a rerun (such as trained weights) has to live in session state.
_MODEL_STATE_KEY = "spam_model"
_TRAINED_STATE_KEY = "spam_model_trained"


def load_bert_layers():
    """Create the BERT preprocessing and encoder layers from TensorFlow Hub.

    This function is not cached itself: each call constructs two new
    ``hub.KerasLayer`` objects. If ``tensorflow_hub`` could not be imported,
    an error is shown in the app and the script run is halted via
    ``st.stop()``.

    Returns:
        A ``(bert_preprocess, bert_encoder)`` tuple of ``hub.KerasLayer``
        objects. The encoder is created with ``trainable=True``.
    """
    if hub is None:
        st.error("Missing dependency: tensorflow_hub. Run `pip install tensorflow-hub` in your environment.")
        st.stop()
    bert_preprocess = hub.KerasLayer(
        "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3",
        name='preprocess')
    bert_encoder = hub.KerasLayer(
        "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4",
        trainable=True,
        name='BERT_encoder')
    return bert_preprocess, bert_encoder


# NOTE(review): st.cache is reported as deprecated in recent Streamlit
# releases in favour of st.cache_resource -- confirm against the deployed
# Streamlit version before changing the decorator.
@st.cache(allow_output_mutation=True)
def build_model(bert_preprocess, bert_encoder):
    """Build and compile the binary spam classifier on top of BERT.

    The model takes a batch of raw strings, runs them through
    ``bert_preprocess`` and ``bert_encoder``, applies dropout (rate 0.1) to
    the encoder's ``pooled_output`` and ends in a single sigmoid unit.

    Args:
        bert_preprocess: Layer that turns raw strings into encoder inputs.
        bert_encoder: Layer whose output mapping contains ``pooled_output``.

    Returns:
        A compiled ``tf.keras.Model`` (binary cross-entropy loss, Adam with
        learning rate 2e-5, accuracy metric).
    """
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
    x = bert_preprocess(text_input)
    outputs = bert_encoder(x)
    x = tf.keras.layers.Dropout(0.1)(outputs['pooled_output'])
    x = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
    model = tf.keras.Model(inputs=text_input, outputs=x)
    model.compile(
        loss='binary_crossentropy',
        optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
        metrics=['accuracy']
    )
    return model


def _prepare_dataset(raw):
    """Return a frame with 'label', 'text' and a 0/1 'spam' column."""
    df = raw[['Category', 'Message']].dropna()
    df.columns = ['label', 'text']
    # A row is spam when its label, compared case-insensitively, is 'spam'.
    df['spam'] = df['label'].astype(str).str.lower().eq('spam').astype(int)
    return df


def _get_session_model():
    """Return this session's model, building it only on first use."""
    if _MODEL_STATE_KEY not in st.session_state:
        bert_preprocess, bert_encoder = load_bert_layers()
        st.session_state[_MODEL_STATE_KEY] = build_model(bert_preprocess, bert_encoder)
        st.session_state[_TRAINED_STATE_KEY] = False
    return st.session_state[_MODEL_STATE_KEY]


def _train_and_report(model, X_train, X_test, y_train, y_test):
    """Fit the model, mark it as trained, and display hold-out metrics."""
    with st.spinner('Training model...'):
        model.fit(X_train, y_train, epochs=3, batch_size=32, verbose=1)
    st.session_state[_TRAINED_STATE_KEY] = True

    # Threshold the sigmoid outputs at 0.5 to get hard 0/1 predictions.
    preds = (model.predict(X_test) > 0.5).astype(int).flatten()
    metrics = {
        'accuracy': accuracy_score(y_test, preds),
        'precision': precision_score(y_test, preds),
        'recall': recall_score(y_test, preds),
    }
    st.success("Model trained!")
    st.write("**Accuracy:**", metrics['accuracy'])
    st.write("**Precision:**", metrics['precision'])
    st.write("**Recall:**", metrics['recall'])
    st.write("**Confusion Matrix:**")
    st.write(confusion_matrix(y_test, preds))


def main():
    """Run the Streamlit app: upload a CSV, train the model, classify text.

    The uploaded CSV must contain 'Category' and 'Message' columns. The model
    and its "trained" flag are stored in ``st.session_state`` so that the
    weights fitted by "Train Model" are still in use when a later rerun
    (for example one triggered by "Predict") executes this function again.
    """
    st.title("Spam Detection with BERT and Streamlit")
    st.write("Upload an SMS dataset (CSV with 'Category' and 'Message' columns) to train the classifier.")

    uploaded_file = st.file_uploader("Choose a CSV file", type='csv')
    if uploaded_file is None:
        st.info('Awaiting CSV file upload.')
        return

    try:
        raw = pd.read_csv(uploaded_file, encoding='latin-1')
    except (pd.errors.EmptyDataError, pd.errors.ParserError) as err:
        st.error(f"Could not read the uploaded CSV: {err}")
        return

    if not {'Category', 'Message'}.issubset(raw.columns):
        st.error("CSV must contain 'Category' and 'Message' columns.")
        return

    df = _prepare_dataset(raw)
    if st.checkbox('Show data sample'):
        st.dataframe(df.head(5))

    X = df['text'].astype(str)
    y = df['spam']
    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, stratify=y, test_size=0.2, random_state=42)
    except ValueError as err:
        # Raised for empty data or when a class is too small to stratify.
        st.error(f"Cannot split the dataset for training: {err}")
        return

    model = _get_session_model()

    if st.button('Train Model'):
        _train_and_report(model, X_train, X_test, y_train, y_test)

    st.subheader("Classify a New Message")
    user_input = st.text_area("Enter your message:")
    if st.button('Predict') and user_input:
        if not st.session_state[_TRAINED_STATE_KEY]:
            # An unfitted classification head would only produce noise.
            st.warning("Train the model before classifying a message.")
            return
        prob = model.predict([user_input])[0][0]
        label = 'Spam' if prob > 0.5 else 'Ham'
        st.write(f"**Prediction:** {label}")
        st.write(f"**Probability:** {prob:.4f}")


if __name__ == '__main__':
    main()