# Spaces: Sleeping
| import streamlit as st | |
| import tensorflow as tf | |
| # Ensure tensorflow_hub is installed | |
| try: | |
| import tensorflow_hub as hub | |
| except ImportError: | |
| hub = None | |
| import pandas as pd | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix | |
| import numpy as np | |
# Build the BERT preprocessing and encoder layers from TensorFlow Hub.
def load_bert_layers():
    """Create the TF Hub BERT preprocessing and encoder layers.

    If ``tensorflow_hub`` could not be imported (``hub`` is ``None``), an
    error is shown in the Streamlit app and the script run is stopped.

    Note that no caching is applied here: every call constructs two new
    ``hub.KerasLayer`` objects.

    Returns:
        A ``(bert_preprocess, bert_encoder)`` tuple of ``hub.KerasLayer``
        instances. The encoder is created with ``trainable=True``.
    """
    if hub is None:
        st.error("Missing dependency: tensorflow_hub. Run `pip install tensorflow-hub` in your environment.")
        st.stop()

    preprocess_url = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
    encoder_url = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4"

    preprocess_layer = hub.KerasLayer(preprocess_url, name='preprocess')
    encoder_layer = hub.KerasLayer(encoder_url, trainable=True, name='BERT_encoder')
    return preprocess_layer, encoder_layer
def build_model(bert_preprocess, bert_encoder):
    """Assemble and compile the binary text classifier.

    The graph is: scalar string input -> ``bert_preprocess`` ->
    ``bert_encoder`` -> dropout (rate 0.1) on the encoder's
    ``'pooled_output'`` -> single sigmoid unit.

    Args:
        bert_preprocess: Callable layer applied to the raw string input.
        bert_encoder: Callable layer applied to the preprocessed input; its
            result must be indexable with ``'pooled_output'``.

    Returns:
        A ``tf.keras.Model`` compiled with binary cross-entropy loss, Adam
        (learning rate 2e-5) and the ``accuracy`` metric.
    """
    raw_text = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')

    encoder_outputs = bert_encoder(bert_preprocess(raw_text))
    pooled = encoder_outputs['pooled_output']

    regularized = tf.keras.layers.Dropout(0.1)(pooled)
    spam_probability = tf.keras.layers.Dense(
        1, activation='sigmoid', name='output')(regularized)

    classifier = tf.keras.Model(inputs=raw_text, outputs=spam_probability)
    classifier.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )
    return classifier
# Main Streamlit app
def main():
    """Run the spam-detection Streamlit page.

    Flow: upload a CSV with 'Category' and 'Message' columns, optionally
    preview it, train a BERT classifier on an 80/20 stratified split, show
    accuracy/precision/recall and the confusion matrix, then classify
    free-text messages.

    Streamlit re-executes this function from the top on every widget
    interaction, and ``st.button`` is only True during the run triggered by
    its own click. The fitted model is therefore stored in
    ``st.session_state`` so that a later 'Predict' click uses the trained
    classifier rather than a freshly built, untrained one. BERT is loaded
    and the model is built only when training is actually requested, not on
    every rerun.
    """
    # Session-state key under which the fitted model survives reruns.
    model_key = 'trained_spam_model'

    st.title("Spam Detection with BERT and Streamlit")
    st.write("Upload an SMS dataset (CSV with 'Category' and 'Message' columns) to train the classifier.")

    # File uploader
    uploaded_file = st.file_uploader("Choose a CSV file", type='csv')
    if uploaded_file is None:
        st.info('Awaiting CSV file upload.')
        return

    df = pd.read_csv(uploaded_file, encoding='latin-1')
    if 'Category' not in df.columns or 'Message' not in df.columns:
        st.error("CSV must contain 'Category' and 'Message' columns.")
        return

    df = df[['Category', 'Message']].dropna()
    df.columns = ['label', 'text']
    # Anything whose label is not (case-insensitively) 'spam' counts as ham.
    df['spam'] = df['label'].apply(lambda x: 1 if str(x).lower() == 'spam' else 0)

    if st.checkbox('Show data sample'):
        st.dataframe(df.head(5))

    # Prepare data
    X = df['text'].astype(str)
    y = df['spam']
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, test_size=0.2, random_state=42)

    # Train model
    if st.button('Train Model'):
        with st.spinner('Training model...'):
            # Load BERT and build the model only when training is requested.
            bert_preprocess, bert_encoder = load_bert_layers()
            model = build_model(bert_preprocess, bert_encoder)
            model.fit(X_train, y_train, epochs=3, batch_size=32, verbose=1)

        # Keep the fitted model for subsequent reruns (e.g. 'Predict' clicks).
        st.session_state[model_key] = model

        # Evaluate
        preds = (model.predict(X_test) > 0.5).astype(int).flatten()
        metrics = {
            'accuracy': accuracy_score(y_test, preds),
            'precision': precision_score(y_test, preds),
            'recall': recall_score(y_test, preds),
        }
        st.success("Model trained!")
        st.write("**Accuracy:**", metrics['accuracy'])
        st.write("**Precision:**", metrics['precision'])
        st.write("**Recall:**", metrics['recall'])
        cm = confusion_matrix(y_test, preds)
        st.write("**Confusion Matrix:**")
        st.write(cm)

    # Prediction section
    st.subheader("Classify a New Message")
    user_input = st.text_area("Enter your message:")
    if st.button('Predict') and user_input:
        trained_model = st.session_state.get(model_key)
        if trained_model is None:
            # Without this guard the prediction would come from random
            # classifier-head weights and be meaningless.
            st.warning("Please train the model before requesting a prediction.")
            return
        prob = trained_model.predict([user_input])[0][0]
        label = 'Spam' if prob > 0.5 else 'Ham'
        st.write(f"**Prediction:** {label}")
        st.write(f"**Probability:** {prob:.4f}")
# Script entry point: launch the app only when this file is run directly.
if __name__ == "__main__":
    main()