# Page-header residue from the hosting site (not code), kept verbatim:
# "3v324v23's picture" / "app" / "eabb121"
import streamlit as st
import joblib
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix
# Set Streamlit page config.
# The emoji literals below were stored as mojibake (UTF-8 bytes decoded with
# a legacy code page); they are restored to the intended characters.
st.set_page_config(page_title="SMS Spam Detector", page_icon="📩", layout="wide")

# Custom CSS for centering and styling, injected once for the whole page.
st.markdown("""
<style>
.centered-container {
    display: flex;
    justify-content: center;
    align-items: center;
    flex-direction: column;
    text-align: center;
    width: 80%;
}
.padded-container {
    padding: 20px;
}
.big-dataset {
    font-size: 12px;
    max-width: 100%;
    margin: auto;
}
.stDataFrame {
    display: flex;
    justify-content: center;
    align-items: center;
}
/* Applies to every <img>, including charts from st.pyplot. A fixed
   height combined with a 150px max width stretched charts vertically;
   'auto' keeps each image's aspect ratio. */
img {
    max-width: 150px;
    height: auto;
}
</style>
""", unsafe_allow_html=True)

# Title
st.title("📩 SMS Spam Detector")
# Load dataset
@st.cache_data
def load_data(dataset_path="spam.csv"):
    """Load the SMS spam CSV as a clean two-column DataFrame.

    Only the ``v1`` (label) and ``v2`` (message text) columns are read, with
    Latin-1 encoding, and they are renamed to ``label`` and ``message``.
    Labels are mapped ``'ham'`` -> 0 and ``'spam'`` -> 1.

    Args:
        dataset_path: Path to the CSV file. Defaults to ``"spam.csv"``.

    Returns:
        DataFrame with columns ``['label', 'message']``, where ``label`` is
        an integer. Rows whose label is neither ``'ham'`` nor ``'spam'``, or
        whose message is missing, are dropped. The index is reset to
        0..n-1.
    """
    raw = pd.read_csv(dataset_path, encoding='latin-1', usecols=['v1', 'v2'])
    data = pd.DataFrame({
        'label': raw['v1'].map({'ham': 0, 'spam': 1}),
        'message': raw['v2'],
    })
    # An unrecognised label maps to NaN, which would silently turn the label
    # column into floats. A missing message would crash TfidfVectorizer and
    # len() further down. Discard such rows and restore the integer dtype.
    data = data.dropna(subset=['label', 'message'])
    return data.astype({'label': int}).reset_index(drop=True)


df = load_data()
# Train and save model
@st.cache_resource
def train_and_save_model():
    """Fit a linear-kernel SVM on TF-IDF features of the global ``df``.

    The data is split 80/20 with ``random_state=42``. The Data Visualization
    tab rebuilds this same split, so the two calls must stay in sync.
    The fitted model and vectorizer are written to ``svm_sms_spam.pkl`` and
    ``vectorizer.pkl`` with joblib.

    Returns:
        Tuple ``(model, vectorizer, accuracy)``, where ``accuracy`` is the
        hold-out accuracy of the model on the 20% test split.
    """
    texts_fit, texts_eval, labels_fit, labels_eval = train_test_split(
        df['message'],
        df['label'],
        test_size=0.2,
        random_state=42,
    )

    tfidf = TfidfVectorizer(stop_words='english', max_features=5000)
    features_fit = tfidf.fit_transform(texts_fit)

    classifier = SVC(kernel='linear')
    classifier.fit(features_fit, labels_fit)

    eval_predictions = classifier.predict(tfidf.transform(texts_eval))
    holdout_score = accuracy_score(labels_eval, eval_predictions)

    # Persist the model first, then the vectorizer.
    for artifact, filename in ((classifier, "svm_sms_spam.pkl"), (tfidf, "vectorizer.pkl")):
        joblib.dump(artifact, filename)

    return classifier, tfidf, holdout_score


svm_model, vectorizer, accuracy = train_and_save_model()
# Create tabs. The tab labels previously held mojibake in place of these emoji.
tab1, tab2, tab3 = st.tabs(["📊 Data Overview", "📈 Data Visualization", "🔍 Spam Detector"])
# Tab 1: Data Overview. Shows the raw dataset, a class-balance bar chart,
# and the model's hold-out accuracy.
with tab1:
    st.subheader("Dataset Overview")
    # Render the table directly. Raw <div> tags emitted by separate
    # st.markdown calls become separate elements and cannot wrap it.
    st.dataframe(df, height=300, width=1000)

    # Smaller class distribution title
    st.subheader("Class Distribution")
    fig, ax = plt.subplots(figsize=(2, 2))  # Smaller figure size
    sns.countplot(
        x=df['label'].map({0: 'Not Spam', 1: 'Spam'}),
        palette='coolwarm',
        ax=ax,
        width=0.2,
    )
    ax.set_title("Distribution of Spam vs. Not Spam Messages", fontsize=8)  # Smaller title
    ax.set_xlabel("Message Type", fontsize=5)  # Smaller x-axis label
    ax.set_ylabel("Count", fontsize=5)  # Smaller y-axis label
    ax.tick_params(axis='both', labelsize=5)  # Smaller tick labels
    st.pyplot(fig)
    # st.pyplot(fig) leaves the figure registered with pyplot. Close it so
    # every Streamlit rerun does not accumulate another open figure.
    plt.close(fig)

    st.markdown(f"### 📊 Model Accuracy: **{accuracy * 100:.2f}%**")
# Tab 2: Data Visualization. Shows a confusion matrix on the hold-out split
# and a correlation heatmap between message length and label.
with tab2:
    st.subheader("Data Visualizations")

    # Confusion Matrix
    st.markdown("### Confusion Matrix")
    # Rebuild the same 80/20 split that train_and_save_model uses (same
    # test_size and random_state), so the matrix only covers messages the
    # model was not trained on. The training halves are not needed here.
    _, X_test, _, y_test = train_test_split(
        df['message'], df['label'], test_size=0.2, random_state=42
    )
    y_pred = svm_model.predict(vectorizer.transform(X_test))
    # Pin the label order so the matrix is always 2x2 and lines up with the
    # fixed 'Not Spam' / 'Spam' tick labels, even if a class is absent.
    cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
    fig, ax = plt.subplots(figsize=(5, 3))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=['Not Spam', 'Spam'],
        yticklabels=['Not Spam', 'Spam'],
        ax=ax,  # draw explicitly on the axes created above
    )
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")
    ax.set_title("Confusion Matrix")
    st.pyplot(fig)
    # st.pyplot(fig) does not close the figure. Close it to avoid
    # accumulating open matplotlib figures across reruns.
    plt.close(fig)

    # Heatmap
    st.markdown("### Heatmap of Feature Correlations")
    df['message_length'] = df['message'].apply(len)
    correlation_matrix = df[['message_length', 'label']].corr()
    fig, ax = plt.subplots(figsize=(5, 3))
    sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', ax=ax)
    ax.set_title("Feature Correlation Heatmap")
    st.pyplot(fig)
    plt.close(fig)
# Tab 3: Spam Detector. Classifies a user-entered SMS with the trained model.
with tab3:
    st.subheader("Check SMS Message")
    st.write("Enter an SMS message below to check if it's spam or not.")
    user_input = st.text_area("Enter SMS Message:")
    if st.button("Check Message"):
        # Whitespace-only text has nothing to classify, so treat it the
        # same as empty input.
        if user_input and user_input.strip():
            input_features = vectorizer.transform([user_input])
            prediction = svm_model.predict(input_features)
            if prediction[0] == 1:
                st.error("🚨 This message is Spam!")
            else:
                # This literal previously held mojibake for the check-mark emoji.
                st.success("✅ This message is NOT Spam!")
        else:
            st.warning("Please enter a message before checking.")