# Page header captured together with the source (not part of the program):
#   bertspamclassification / src / streamlit_app.py
#   anshu9749 - "Update src/streamlit_app.py" - commit 453de19 (verified)
import streamlit as st
import tensorflow as tf
# Ensure tensorflow_hub is installed
try:
import tensorflow_hub as hub
except ImportError:
hub = None
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix
import numpy as np
# Cache BERT layers and model creation for efficiency
def load_bert_layers():
    """Return the BERT preprocessing and encoder layers, built once per session.

    Streamlit re-executes the whole script on every widget interaction, so
    instantiating the two TF Hub layers unconditionally would rebuild them
    on each rerun and hand brand-new objects to ``build_model`` every time.
    The pair is therefore kept in ``st.session_state`` and reused. The
    encoder is created with ``trainable=True``, so it is stored per session
    rather than in a process-wide cache.

    If ``tensorflow_hub`` could not be imported, an error is shown and the
    script run is stopped via ``st.stop()``.

    Returns:
        tuple: ``(bert_preprocess, bert_encoder)`` as ``hub.KerasLayer``
        objects.
    """
    if hub is None:
        st.error("Missing dependency: tensorflow_hub. Run `pip install tensorflow-hub` in your environment.")
        st.stop()

    state_key = "bert_layers"
    if state_key not in st.session_state:
        bert_preprocess = hub.KerasLayer(
            "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3",
            name='preprocess',
        )
        bert_encoder = hub.KerasLayer(
            "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4",
            trainable=True,
            name='BERT_encoder',
        )
        st.session_state[state_key] = (bert_preprocess, bert_encoder)
    return st.session_state[state_key]
@st.cache(allow_output_mutation=True)
def build_model(bert_preprocess, bert_encoder):
    """Assemble and compile the binary spam classifier on top of BERT.

    A scalar string input named ``text`` is passed through
    ``bert_preprocess`` and then ``bert_encoder``. The encoder's
    ``pooled_output`` goes through dropout (rate 0.1) and a single sigmoid
    unit named ``output``. The model is compiled with binary cross-entropy,
    Adam at a learning rate of 2e-5, and the accuracy metric.

    NOTE(review): the Streamlit cache key is derived from the two layer
    arguments - confirm that Streamlit can hash these objects and that
    callers pass the same instances on every rerun.

    Args:
        bert_preprocess: Callable layer mapping raw strings to encoder inputs.
        bert_encoder: Callable layer whose output mapping contains
            ``'pooled_output'``.

    Returns:
        The compiled ``tf.keras.Model``.
    """
    raw_text = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
    encoded = bert_encoder(bert_preprocess(raw_text))
    pooled = encoded['pooled_output']

    # Classification head: dropout followed by one sigmoid unit.
    regularised = tf.keras.layers.Dropout(0.1)(pooled)
    spam_probability = tf.keras.layers.Dense(
        1, activation='sigmoid', name='output')(regularised)

    classifier = tf.keras.Model(inputs=raw_text, outputs=spam_probability)
    adam = tf.keras.optimizers.Adam(learning_rate=2e-5)
    classifier.compile(
        loss='binary_crossentropy',
        optimizer=adam,
        metrics=['accuracy'],
    )
    return classifier
# Main Streamlit app
# Keys under which per-session objects are kept in st.session_state so that
# they survive the script rerun Streamlit performs on every interaction.
_MODEL_STATE_KEY = "spam_model"
_RESULTS_STATE_KEY = "spam_eval_results"


def _load_labelled_messages(uploaded_file):
    """Parse the uploaded CSV into a frame with 'label', 'text' and 'spam'.

    Shows an error in the UI and returns ``None`` when the file cannot be
    parsed or lacks the 'Category' / 'Message' columns.
    """
    # Rewind in case the buffer was already consumed by an earlier run.
    uploaded_file.seek(0)
    try:
        raw = pd.read_csv(uploaded_file, encoding='latin-1')
    except (pd.errors.EmptyDataError, pd.errors.ParserError) as err:
        st.error(f"Could not read the CSV file: {err}")
        return None

    if 'Category' not in raw.columns or 'Message' not in raw.columns:
        st.error("CSV must contain 'Category' and 'Message' columns.")
        return None

    df = raw[['Category', 'Message']].dropna()
    df.columns = ['label', 'text']
    # 1 for rows labelled "spam" (case-insensitive), 0 for everything else.
    df['spam'] = df['label'].astype(str).str.lower().eq('spam').astype(int)
    return df


def _get_session_model():
    """Return this session's classifier, building it on first use."""
    if _MODEL_STATE_KEY not in st.session_state:
        bert_preprocess, bert_encoder = load_bert_layers()
        st.session_state[_MODEL_STATE_KEY] = build_model(bert_preprocess, bert_encoder)
    return st.session_state[_MODEL_STATE_KEY]


def _train_and_evaluate(model, X_train, X_test, y_train, y_test):
    """Fit ``model`` for 3 epochs and return its metrics on the test split."""
    model.fit(X_train, y_train, epochs=3, batch_size=32, verbose=1)
    preds = (model.predict(X_test) > 0.5).astype(int).flatten()
    return {
        'accuracy': accuracy_score(y_test, preds),
        # zero_division=0 keeps the score at 0 without a warning when no
        # positives are predicted / present.
        'precision': precision_score(y_test, preds, zero_division=0),
        'recall': recall_score(y_test, preds, zero_division=0),
        # Fix the label order so the matrix is always 2x2.
        'confusion_matrix': confusion_matrix(y_test, preds, labels=[0, 1]),
    }


def _show_results(results):
    """Render the metrics produced by ``_train_and_evaluate``."""
    st.write("**Accuracy:**", results['accuracy'])
    st.write("**Precision:**", results['precision'])
    st.write("**Recall:**", results['recall'])
    st.write("**Confusion Matrix:**")
    st.write(results['confusion_matrix'])


def main():
    """Run the spam-detection app: upload data, train, evaluate, classify.

    Streamlit re-executes this function on every widget interaction, and a
    button only reports ``True`` during the run triggered by its click. The
    model and the latest evaluation results are therefore kept in
    ``st.session_state``: the metrics stay visible after later
    interactions, and "Predict" is answered by the model trained in this
    session instead of a freshly built, untrained one.
    """
    st.title("Spam Detection with BERT and Streamlit")
    st.write("Upload an SMS dataset (CSV with 'Category' and 'Message' columns) to train the classifier.")

    # File uploader
    uploaded_file = st.file_uploader("Choose a CSV file", type='csv')
    if uploaded_file is None:
        st.info('Awaiting CSV file upload.')
        return

    df = _load_labelled_messages(uploaded_file)
    if df is None:
        return

    if st.checkbox('Show data sample'):
        st.dataframe(df.head(5))

    # Prepare data
    X = df['text'].astype(str)
    y = df['spam']
    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, stratify=y, test_size=0.2, random_state=42)
    except ValueError as err:
        # Raised for empty data or when a class is too small to stratify.
        st.error(f"Cannot split the dataset into train and test sets: {err}")
        return

    # Load BERT and model
    model = _get_session_model()

    # Train model
    if st.button('Train Model'):
        with st.spinner('Training model...'):
            st.session_state[_RESULTS_STATE_KEY] = _train_and_evaluate(
                model, X_train, X_test, y_train, y_test)
        st.success("Model trained!")

    if _RESULTS_STATE_KEY in st.session_state:
        _show_results(st.session_state[_RESULTS_STATE_KEY])

    # Prediction section
    st.subheader("Classify a New Message")
    user_input = st.text_area("Enter your message:")
    if st.button('Predict') and user_input.strip():
        if _RESULTS_STATE_KEY not in st.session_state:
            # Without training, the classification head still has its
            # initial weights and its output would be meaningless.
            st.warning("Train the model before classifying messages.")
            return
        prob = float(model.predict([user_input])[0][0])
        label = 'Spam' if prob > 0.5 else 'Ham'
        st.write(f"**Prediction:** {label}")
        st.write(f"**Probability:** {prob:.4f}")
# Start the app only when this file is executed as the entry script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()