Spaces:
Build error
Build error
| import streamlit as st | |
| import joblib | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.svm import SVC | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.metrics import accuracy_score, confusion_matrix | |
# Page setup: browser-tab title/icon, wide layout, custom CSS and the page
# heading.
st.set_page_config(page_title="SMS Spam Detector", page_icon="📩", layout="wide")

# Custom CSS for centering and styling. The stylesheet is injected as raw HTML
# (unsafe_allow_html=True) and its lines are kept flush-left on purpose:
# Markdown can treat indented lines as a code block.
# NOTE(review): the `img` rule below matches every <img> on the page and pins
# it to a 600px height with a 150px max width -- confirm this is intended for
# the rendered charts.
st.markdown("""
<style>
.centered-container {
display: flex;
justify-content: center;
align-items: center;
flex-direction: column;
text-align: center;
width: 80%;
}
.padded-container {
padding: 20px;
}
.big-dataset {
font-size: 12px;
max-width: 100%;
margin: auto;
}
.stDataFrame {
display: flex;
justify-content: center;
align-items: center;
}
img {
max-width: 150px;
height: 600px;
}
</style>
""", unsafe_allow_html=True)

# Title
st.title("📩 SMS Spam Detector")
# Load dataset
@st.cache_data
def load_data() -> pd.DataFrame:
    """Read ``spam.csv`` and return it as a ``label`` / ``message`` frame.

    Only the ``v1`` and ``v2`` columns of the CSV are kept; they are renamed
    to ``label`` and ``message``. Labels are mapped ``'ham' -> 0`` and
    ``'spam' -> 1``; any other label value becomes NaN (``Series.map``
    semantics).

    The function is wrapped in ``st.cache_data`` so the file is parsed once
    and reused across Streamlit script reruns instead of being re-read on
    every widget interaction.

    Returns:
        DataFrame with columns ``label`` and ``message``.

    Raises:
        FileNotFoundError: if ``spam.csv`` is not in the working directory.
        KeyError: if the CSV has no ``v1`` / ``v2`` columns.
    """
    dataset_path = "spam.csv"
    raw = pd.read_csv(dataset_path, encoding='latin-1')
    # Work on an explicit copy of the two-column subset so the renaming and
    # label mapping below never write into a view of ``raw``.
    data = raw[['v1', 'v2']].copy()
    data.columns = ['label', 'message']
    data['label'] = data['label'].map({'ham': 0, 'spam': 1})
    return data


df = load_data()
# Train and save model
@st.cache_resource
def train_and_save_model():
    """Fit a TF-IDF + linear SVM spam classifier and persist it to disk.

    Reads the module-level ``df`` (columns ``message`` and ``label``), holds
    out 20% of the rows as a test set (``random_state=42``), fits a
    ``TfidfVectorizer`` (English stop words, at most 5000 features) on the
    training messages, trains a linear-kernel ``SVC`` and scores it on the
    held-out rows.

    Side effects: writes ``svm_sms_spam.pkl`` and ``vectorizer.pkl`` to the
    working directory.

    The function is wrapped in ``st.cache_resource`` so the training and the
    two file writes run once and the fitted objects are reused across script
    reruns, rather than re-fitting on every widget interaction.

    Returns:
        Tuple ``(svm_model, vectorizer, accuracy)`` where ``accuracy`` is the
        accuracy score on the held-out test split.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        df['message'], df['label'], test_size=0.2, random_state=42
    )

    # Fit the vocabulary on training messages only; the test messages are
    # transformed with that same fitted vectorizer.
    vectorizer = TfidfVectorizer(stop_words='english', max_features=5000)
    X_train_tfidf = vectorizer.fit_transform(X_train)
    X_test_tfidf = vectorizer.transform(X_test)

    svm_model = SVC(kernel='linear')
    svm_model.fit(X_train_tfidf, y_train)

    y_pred = svm_model.predict(X_test_tfidf)
    accuracy = accuracy_score(y_test, y_pred)

    joblib.dump(svm_model, "svm_sms_spam.pkl")
    joblib.dump(vectorizer, "vectorizer.pkl")
    return svm_model, vectorizer, accuracy


svm_model, vectorizer, accuracy = train_and_save_model()
# Create tabs
# NOTE(review): each label (and the accuracy heading below) starts with a
# lone "π" that looks like a mis-decoded emoji; the original glyphs cannot be
# recovered from this file, so the strings are left exactly as found.
tab1, tab2, tab3 = st.tabs(["π Data Overview", "π Data Visualization", "π Spam Detector"])

# Tab 1: Data Overview -- raw dataset table, class-balance bar chart and the
# hold-out accuracy of the trained model.
with tab1:
    st.subheader("Dataset Overview")
    # NOTE(review): these opening/closing <div> tags are emitted as separate
    # markdown elements; verify that they actually wrap the dataframe in the
    # rendered page.
    st.markdown('<div class="centered-container">', unsafe_allow_html=True)
    st.markdown('<div style="display: flex; justify-content: center;">', unsafe_allow_html=True)
    st.dataframe(df, height=300, width=1000)
    st.markdown('</div>', unsafe_allow_html=True)
    st.markdown('</div>', unsafe_allow_html=True)

    # Smaller class distribution title
    st.subheader("Class Distribution")
    fig, ax = plt.subplots(figsize=(2, 2))  # Smaller figure size
    sns.countplot(
        # Show readable class names instead of the 0/1 codes.
        x=df['label'].map({0: 'Not Spam', 1: 'Spam'}),
        palette='coolwarm',
        ax=ax,
        width=0.2
    )
    ax.set_title("Distribution of Spam vs. Not Spam Messages", fontsize=8)  # Smaller title
    ax.set_xlabel("Message Type", fontsize=5)  # Smaller x-axis label
    ax.set_ylabel("Count", fontsize=5)  # Smaller y-axis label
    ax.tick_params(axis='both', labelsize=5)  # Smaller tick labels
    st.pyplot(fig)
    # Figures created through pyplot stay registered until closed; release
    # this one so figures do not pile up each time the script is re-executed.
    plt.close(fig)

    st.markdown(f"### π Model Accuracy: **{accuracy * 100:.2f}%**")
# Tab 2: Data Visualization -- confusion matrix on the hold-out split and a
# correlation heatmap between message length and the spam label.
with tab2:
    st.subheader("Data Visualizations")

    # Confusion Matrix
    st.markdown("### Confusion Matrix")
    # Re-create the hold-out split with the same test_size and random_state
    # that train_and_save_model uses, so the matrix is computed on the rows
    # the model was scored on. Only the test half is needed here.
    _, X_test, _, y_test = train_test_split(df['message'], df['label'], test_size=0.2, random_state=42)
    X_test_tfidf = vectorizer.transform(X_test)
    y_pred = svm_model.predict(X_test_tfidf)
    cm = confusion_matrix(y_test, y_pred)

    fig, ax = plt.subplots(figsize=(5, 3))
    # Draw on this figure's axes explicitly instead of relying on pyplot's
    # implicit "current axes".
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=['Not Spam', 'Spam'],
        yticklabels=['Not Spam', 'Spam'],
        ax=ax,
    )
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")
    ax.set_title("Confusion Matrix")
    st.pyplot(fig)
    # Release the figure so figures do not pile up across script reruns.
    plt.close(fig)

    # Heatmap
    st.markdown("### Heatmap of Feature Correlations")
    # Character count of each message, used as the single numeric feature.
    df['message_length'] = df['message'].apply(len)
    correlation_matrix = df[['message_length', 'label']].corr()

    fig, ax = plt.subplots(figsize=(5, 3))
    sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', ax=ax)
    ax.set_title("Feature Correlation Heatmap")
    st.pyplot(fig)
    plt.close(fig)
# Tab 3: Spam Detector -- classify a user-supplied SMS with the trained model.
with tab3:
    st.subheader("Check SMS Message")
    st.write("Enter an SMS message below to check if it's spam or not.")
    user_input = st.text_area("Enter SMS Message:")

    if st.button("Check Message"):
        # Treat whitespace-only input like empty input: there is nothing to
        # classify, so ask for a message instead of predicting on blanks.
        if user_input and user_input.strip():
            # The vectorizer expects an iterable of documents, hence the
            # one-element list.
            input_features = vectorizer.transform([user_input])
            prediction = svm_model.predict(input_features)
            # Label encoding used at load time: 1 = spam, 0 = ham.
            if prediction[0] == 1:
                st.error("🚨 This message is Spam!")
            else:
                st.success("✅ This message is NOT Spam!")
        else:
            st.warning("Please enter a message before checking.")