Spaces:
Sleeping
Sleeping
import html
import os
import pickle

import folium
import joblib
import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
import xgboost as xgb
from folium.plugins import HeatMap, HeatMapWithTime
from groq import Groq
from scipy.sparse import csr_matrix, hstack
from streamlit_folium import folium_static

from data_loader import load_data
from preprocessing import get_season, preprocess_pipeline
# Page configuration: browser-tab title/icon, wide layout, sidebar open.
_PAGE_SETTINGS = {
    "page_title": "SF Crime Analytics | AI-Powered",
    "page_icon": "🚓",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)

# Dark "premium" theme. The class names defined here (metric-card,
# report-text, chat-bubble-*, glass-card, user-message, ai-message,
# chat-container) are referenced by the HTML snippets emitted further down.
_CUSTOM_CSS = """
<style>
.main {
background-color: #0e1117;
}
.stApp {
background-color: #0e1117;
}
h1, h2, h3 {
color: #ffffff;
font-family: 'Helvetica Neue', sans-serif;
font-weight: 700;
}
.stButton>button {
background-color: #ff4b4b;
color: white;
border-radius: 20px;
padding: 10px 24px;
font-weight: 600;
border: none;
transition: all 0.3s ease;
}
.stButton>button:hover {
background-color: #ff3333;
transform: scale(1.05);
}
.metric-card {
background-color: #262730;
padding: 20px;
border-radius: 10px;
border-left: 5px solid #ff4b4b;
box-shadow: 0 4px 6px rgba(0,0,0,0.3);
}
.report-text {
font-family: 'Courier New', monospace;
color: #00ff00;
background-color: #000000;
padding: 15px;
border-radius: 5px;
border: 1px solid #00ff00;
}
.chat-bubble-user {
background-color: #2b313e;
color: white;
padding: 10px;
border-radius: 15px 15px 0 15px;
margin: 5px;
text-align: right;
}
.chat-bubble-bot {
background-color: #ff4b4b;
color: white;
padding: 10px;
border-radius: 15px 15px 15px 0;
margin: 5px;
text-align: left;
}
/* New Chat Assistant Styles */
.glass-card {
background: rgba(255, 255, 255, 0.05);
backdrop-filter: blur(10px);
-webkit-backdrop-filter: blur(10px);
padding: 30px;
border-radius: 24px;
border: 1px solid rgba(255, 255, 255, 0.1);
box-shadow: 0 8px 32px 0 rgba(31, 38, 135, 0.37);
transition: all 0.4s ease;
margin-bottom: 25px;
}
.user-message {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 15px 20px;
border-radius: 18px 18px 5px 18px;
margin: 10px 0;
max-width: 80%;
margin-left: auto;
color: white;
font-size: 1rem;
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.3);
}
.ai-message {
background: rgba(255, 255, 255, 0.08);
backdrop-filter: blur(10px);
padding: 15px 20px;
border-radius: 18px 18px 18px 5px;
margin: 10px 0;
max-width: 80%;
margin-right: auto;
color: #e2e8f0;
font-size: 1rem;
border: 1px solid rgba(255, 255, 255, 0.1);
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.2);
}
.chat-container {
background: rgba(255, 255, 255, 0.03);
backdrop-filter: blur(10px);
padding: 25px;
border-radius: 20px;
border: 1px solid rgba(255, 255, 255, 0.1);
max-height: 500px;
overflow-y: auto;
margin-bottom: 20px;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
| # Load Resources | |
@st.cache_resource(show_spinner=False)
def _load_joblib_artifact(path):
    """Deserialize one joblib file; cached so script reruns reuse the loaded object."""
    return joblib.load(path)


def load_resources():
    """Load the baseline classifier, its label encoders and the KMeans model.

    The three artifacts are read from the ``models`` directory located next
    to this script. Deserialization goes through a cached helper so the
    files are not re-read every time Streamlit re-executes the script.

    Returns:
        tuple: ``(model, encoders, kmeans)``, or ``(None, None, None)`` when
        any of the three files is missing.
    """
    models_dir = os.path.join(os.path.dirname(__file__), 'models')
    artifact_paths = [
        os.path.join(models_dir, file_name)
        for file_name in ('best_model.pkl', 'label_encoders.pkl', 'kmeans.pkl')
    ]
    # The existence check deliberately stays outside the cache: a "missing"
    # result must not be remembered once the files have been created.
    if not all(os.path.exists(path) for path in artifact_paths):
        return None, None, None
    model, encoders, kmeans = (_load_joblib_artifact(path) for path in artifact_paths)
    return model, encoders, kmeans
@st.cache_resource(show_spinner=False)
def _read_pickled_artifacts(pkl_path):
    """Unpickle the artifact bundle at ``pkl_path``; cached so reruns reuse it."""
    # NOTE(security): pickle runs arbitrary code while loading. Only point
    # this at artifact files produced by the project's own training run.
    with open(pkl_path, 'rb') as f:
        return pickle.load(f)


def load_new_artifacts():
    """Load the advanced XGBoost artifact bundle from ``models/crime_xgb_artifacts.pkl``.

    Returns:
        The unpickled object (indexed by key elsewhere in this script), or
        ``None`` when loading fails for any reason. Failures are shown to the
        user via ``st.error`` and are not cached, so a later rerun retries.
    """
    try:
        models_dir = os.path.join(os.path.dirname(__file__), 'models')
        pkl_path = os.path.join(models_dir, "crime_xgb_artifacts.pkl")
        return _read_pickled_artifacts(pkl_path)
    except Exception as e:
        # UI boundary: report the problem and let the page render without
        # the advanced model instead of crashing the whole app.
        st.error(f"❌ Artifact loading error: {e}")
        return None
@st.cache_data(show_spinner=False)
def _read_crime_sample(csv_path, sample_size, seed):
    """Read the training CSV and return a reproducible random sample of its rows."""
    df = pd.read_csv(csv_path, parse_dates=['Dates'])
    # Cap the request: DataFrame.sample raises when asked for more rows than
    # exist (without replacement), which would otherwise hide a small file
    # behind a generic error.
    return df.sample(min(sample_size, len(df)), random_state=seed)


def load_data_sample():
    """Return up to 10,000 randomly sampled rows of ``data/crimedataset/train.csv``.

    The CSV read and sampling are cached with ``st.cache_data`` so the file
    is not parsed again on every script rerun.

    Returns:
        pandas.DataFrame: the sample (seeded with 42), or an empty frame when
        the file is missing or cannot be read. Errors are surfaced with
        ``st.error``.
    """
    data_dir = os.path.join(os.path.dirname(__file__), 'data/crimedataset')
    try:
        return _read_crime_sample(os.path.join(data_dir, 'train.csv'), 10000, 42)
    except FileNotFoundError as e:
        # Display a clear error if the file isn't found
        st.error(f"❌ Data file not found! Check path: {e}")
        return pd.DataFrame()
    except Exception as e:
        # Display any other unexpected errors
        st.error(f"❌ Unexpected Data Loading Error: {e}")
        return pd.DataFrame()
# Populate the module-level handles that the rest of the page reads.
_baseline_bundle = load_resources()
model, encoders, kmeans = _baseline_bundle
new_artifacts = load_new_artifacts()
df_sample = load_data_sample()
| # ------------------- GROQ SETUP ------------------- | |
def get_groq_client():
    """Build a Groq API client from the ``GROQ_API_KEY`` environment variable.

    The key is intentionally not stored in the source file: a credential
    committed to a repository is readable by anyone who can read the code and
    must be treated as leaked. Any key that was previously hard-coded here
    should be revoked and replaced.

    Returns:
        Groq: a client authenticated with the configured key.

    Raises:
        RuntimeError: if ``GROQ_API_KEY`` is unset or blank.
    """
    api_key = os.environ.get("GROQ_API_KEY", "").strip()
    if not api_key:
        raise RuntimeError(
            "GROQ_API_KEY is not set; configure it as an environment variable or Space secret."
        )
    return Groq(api_key=api_key)
def explain_prediction_with_llama(prompt):
    """Send ``prompt`` to Groq's Llama chat model and return the reply text.

    Any exception raised while creating the client or calling the API is
    converted into a warning string, so this function does not raise.
    """
    user_turn = {"role": "user", "content": prompt}
    try:
        completion = get_groq_client().chat.completions.create(
            messages=[user_turn],
            model="llama-3.3-70b-versatile",
        )
        reply = completion.choices[0].message.content
    except Exception as e:
        return f"⚠️ Could not generate explanation: {e}"
    return reply
# Header
col1, col2 = st.columns([3, 1])
with col1:
    st.title("San Francisco Crime Analytics")
    st.markdown("#### AI-Powered Predictive Policing Dashboard")
with col2:
    # Compare against None explicitly: truthiness of an arbitrary model
    # object depends on whether its class defines __len__/__bool__.
    if model is not None:
        st.success("🟢 System Online: Models Loaded")
    else:
        st.error("🔴 System Offline: Models Missing")

st.sidebar.markdown("---")
st.sidebar.markdown("**System Status**")
# Keep the sidebar badge consistent with the header banner instead of
# always claiming the system is online.
if model is not None:
    st.sidebar.markdown("🟢 **Online** | ⚡ **12ms**")
else:
    st.sidebar.markdown("🔴 **Offline**")
st.sidebar.markdown(f"📅 {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M')}")
st.sidebar.markdown("---")

# Sidebar: inputs describing the incident to score.
st.sidebar.image("https://img.icons8.com/fluency/96/police-badge.png", width=80)
st.sidebar.header("Incident Parameters")
date_input = st.sidebar.date_input("Date")
time_input = st.sidebar.time_input("Time")
# District choices come from the fitted label encoder; empty when the
# encoders could not be loaded.
district = st.sidebar.selectbox("District", options=encoders['PdDistrict'].classes_ if encoders else [])
st.sidebar.subheader("Geolocation")
latitude = st.sidebar.number_input("Latitude", value=37.7749, format="%.6f")
longitude = st.sidebar.number_input("Longitude", value=-122.4194, format="%.6f")
# Main Prediction Logic: score the incident described in the sidebar.
if st.sidebar.button("Analyze Risk Level", type="primary"):
    if model is None:
        st.error("Model not trained yet. Please run training script.")
    else:
        # Prepare Input
        datetime_combined = pd.to_datetime(f"{date_input} {time_input}")
        input_data = pd.DataFrame({
            'Dates': [datetime_combined],
            'X': [longitude],
            'Y': [latitude],
            'PdDistrict': [district]
        })
        # Preprocess
        processed_df, _ = preprocess_pipeline(input_data, is_train=False, kmeans_model=kmeans)
        # Encoding
        processed_df['PdDistrict'] = encoders['PdDistrict'].transform(processed_df['PdDistrict'])
        processed_df['Season'] = encoders['Season'].transform(processed_df['Season'])
        # Features
        features = ['Hour', 'Day', 'Month', 'Year', 'DayOfWeek', 'IsWeekend', 'IsHoliday', 'LocationCluster', 'PdDistrict', 'Season']
        model_input = processed_df[features]
        prediction = model.predict(model_input)[0]
        proba = model.predict_proba(model_input)[0]
        # Index 1 is treated as the "violent" class throughout this script
        # (prediction == 1 -> VIOLENT) -- assumes model.classes_ is [0, 1].
        violent_prob = proba[1]
        # Positional access: does not depend on the frame keeping label 0.
        location_cluster = processed_df['LocationCluster'].iloc[0]

        st.markdown("---")
        st.subheader("Analysis Results")
        r_col1, r_col2, r_col3 = st.columns(3)
        with r_col1:
            st.markdown('<div class="metric-card">', unsafe_allow_html=True)
            # Show the probability of violence, matching the report below.
            # max(proba) would instead show the confidence of whichever
            # class won, so a confident NON-VIOLENT call read as high risk.
            st.metric("Risk Probability", f"{violent_prob*100:.1f}%")
            st.markdown('</div>', unsafe_allow_html=True)
        with r_col2:
            st.markdown('<div class="metric-card">', unsafe_allow_html=True)
            if prediction == 1:
                st.metric("Predicted Classification", "VIOLENT", delta="High Risk", delta_color="inverse")
            else:
                st.metric("Predicted Classification", "NON-VIOLENT", delta="Low Risk", delta_color="normal")
            st.markdown('</div>', unsafe_allow_html=True)
        with r_col3:
            st.markdown('<div class="metric-card">', unsafe_allow_html=True)
            st.metric("Location Cluster", f"Zone {location_cluster}")
            st.markdown('</div>', unsafe_allow_html=True)

        # AI Analyst Report
        st.markdown("### 🤖 AI Analyst Report")
        risk_level = "CRITICAL" if violent_prob > 0.7 else "ELEVATED" if violent_prob > 0.4 else "STANDARD"
        report = f"""
[CLASSIFIED REPORT - GENERATED BY AI]
-------------------------------------
DATE: {date_input} | TIME: {time_input}
LOCATION: {district} (Lat: {latitude}, Lon: {longitude})
ASSESSMENT: {risk_level} RISK DETECTED
PROBABILITY OF VIOLENCE: {violent_prob*100:.2f}%
KEY FACTORS:
- Time of Day: {time_input.hour}:00 hours (Historical high-risk window)
- District Profile: {district} shows elevated activity trends.
- Seasonal Context: {get_season(datetime_combined.month)} patterns observed.
RECOMMENDATION:
Immediate deployment of patrol units advised if risk > 50%.
Monitor sector {location_cluster} closely.
"""
        st.markdown(f'<div class="report-text">{report}</div>', unsafe_allow_html=True)
        st.download_button(
            label="📄 Download Full Report",
            data=report,
            file_name=f"crime_report_{date_input}_{district}.txt",
            mime="text/plain"
        )

        # Explainability: global feature importances of the fitted model
        # (only available for estimators exposing feature_importances_).
        st.markdown("### 🧠 Model Explainability")
        if hasattr(model, 'feature_importances_'):
            feat_imp = pd.DataFrame({
                'Feature': features,
                'Importance': model.feature_importances_
            }).sort_values(by='Importance', ascending=False)
            fig_imp = px.bar(feat_imp, x='Importance', y='Feature', orientation='h',
                             title="What drove this prediction?", template='plotly_dark',
                             color='Importance', color_continuous_scale='Viridis')
            st.plotly_chart(fig_imp)
# Dashboard Tabs
st.markdown("---")
_TAB_TITLES = [
    "📊 Historical Trends",
    "🗺️ Geospatial Intelligence",
    "🚨 Tactical Simulation",
    "💬 Chat with Data",
    "🧪 Scenario Tester",
    "🚀 Advanced Prediction (99%)",
]
tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs(_TAB_TITLES)

# Tab 1: two summary charts built from the sampled incidents.
with tab1:
    if df_sample.empty:
        st.warning("Data loading...")
    else:
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Crime Distribution by Hour")
            # Adds an 'Hour' column to the shared sample frame.
            df_sample['Hour'] = df_sample['Dates'].dt.hour
            hourly_counts = (
                df_sample.groupby('Hour')
                .size()
                .reset_index(name='Count')
            )
            fig_hour = px.bar(
                hourly_counts,
                x='Hour',
                y='Count',
                color='Count',
                color_continuous_scale='RdBu_r',
                template='plotly_dark',
            )
            st.plotly_chart(fig_hour)
        with col2:
            st.subheader("Incidents by District")
            district_counts = df_sample['PdDistrict'].value_counts().reset_index()
            district_counts.columns = ['District', 'Count']
            fig_dist = px.pie(
                district_counts,
                values='Count',
                names='District',
                hole=0.4,
                template='plotly_dark',
                color_discrete_sequence=px.colors.sequential.RdBu,
            )
            st.plotly_chart(fig_dist)
# Tab 2: animated and static heatmaps of the sampled incident coordinates.
with tab2:
    st.subheader("Spatiotemporal Crime Analysis")
    if not df_sample.empty:
        # Time-Lapse Heatmap
        st.write("**24-Hour Crime Evolution (Time-Lapse)**")
        # HeatMapWithTime takes one list of [lat, lon] points per time step.
        # Extract the hour once instead of once per loop iteration.
        incident_hours = df_sample['Dates'].dt.hour
        heat_data_time = []
        time_index = []
        for hour in range(24):
            hour_data = df_sample[incident_hours == hour]
            heat_data_time.append(hour_data[['Y', 'X']].values.tolist())
            time_index.append(f"{hour:02d}:00")
        m = folium.Map(location=[37.7749, -122.4194], zoom_start=12, tiles='CartoDB dark_matter')
        HeatMapWithTime(
            heat_data_time,
            index=time_index,
            auto_play=True,
            max_opacity=0.8,
            radius=15
        ).add_to(m)
        folium_static(m, width=1000)

        st.markdown("---")
        st.write("**Static Density Heatmap**")
        m_static = folium.Map(location=[37.7749, -122.4194], zoom_start=12, tiles='CartoDB dark_matter')
        # Vectorized column extraction; yields the same [Y, X] pairs as a
        # row-by-row iterrows() loop without the per-row Python overhead.
        heat_data = df_sample[['Y', 'X']].values.tolist()
        HeatMap(heat_data, radius=15).add_to(m_static)
        folium_static(m_static, width=1000)
    else:
        st.warning("Data not loaded.")
# Tab 3: score a hypothetical patrol slot (district + hour + date).
with tab3:
    st.subheader("Resource Allocation Simulator")
    st.info("Use this tool to simulate patrol strategies based on predictive risk modeling.")
    sim_col1, sim_col2 = st.columns([1, 2])
    with sim_col1:
        st.markdown("### Simulation Controls")
        sim_district = st.selectbox("Target District", options=encoders['PdDistrict'].classes_ if encoders else [], key='sim_dist')
        sim_hour = st.slider("Patrol Hour", 0, 23, 22)
        sim_date = st.date_input("Patrol Date", key='sim_date')
    with sim_col2:
        st.markdown("### AI Recommendation Engine")
        # Explicit None checks: truthiness of model objects is class-specific.
        if model is not None and kmeans is not None:
            # Default to the city-centre coordinates; refine to the mean
            # location of the district's sampled incidents when available.
            sim_lat, sim_lon = 37.7749, -122.4194
            if not df_sample.empty:
                district_rows = df_sample.loc[df_sample['PdDistrict'] == sim_district, ['Y', 'X']]
                # A district with no sampled incidents would give a NaN
                # centroid, so keep the default in that case.
                if not district_rows.empty:
                    district_center = district_rows.mean()
                    sim_lat = district_center['Y']
                    sim_lon = district_center['X']
            sim_datetime = pd.to_datetime(f"{sim_date} {sim_hour}:00:00")
            sim_input = pd.DataFrame({
                'Dates': [sim_datetime],
                'X': [sim_lon],
                'Y': [sim_lat],
                'PdDistrict': [sim_district]
            })
            # Process
            sim_processed, _ = preprocess_pipeline(sim_input, is_train=False, kmeans_model=kmeans)
            sim_processed['PdDistrict'] = encoders['PdDistrict'].transform(sim_processed['PdDistrict'])
            sim_processed['Season'] = encoders['Season'].transform(sim_processed['Season'])
            # Features
            features = ['Hour', 'Day', 'Month', 'Year', 'DayOfWeek', 'IsWeekend', 'IsHoliday', 'LocationCluster', 'PdDistrict', 'Season']
            # Predict; index 1 is used as the "violent" class probability.
            sim_prob = model.predict_proba(sim_processed[features])[0]
            violent_prob = sim_prob[1]
            st.write(f"Analyzing sector **{sim_district}** at **{sim_hour}:00**...")
            # Gauge Chart (single horizontal bar on a fixed 0-1 axis)
            fig_gauge = px.bar(x=[violent_prob], y=["Risk"], orientation='h', range_x=[0, 1],
                               labels={'x': 'Violent Crime Probability', 'y': ''}, height=100,
                               color=[violent_prob], color_continuous_scale=['green', 'yellow', 'red'])
            fig_gauge.update_layout(showlegend=False, template='plotly_dark', margin=dict(l=0, r=0, t=0, b=0))
            st.plotly_chart(fig_gauge)
            if violent_prob > 0.7:
                st.error("⚠️ **CRITICAL RISK DETECTED**")
                st.markdown("""
**Recommended Action Plan:**
- 🔴 Deploy SWAT / Heavy Tactical Units
- 🚁 Request Aerial Surveillance
- 🚧 Establish Perimeter Checkpoints
""")
            elif violent_prob > 0.4:
                st.warning("⚠️ **ELEVATED RISK**")
                st.markdown("""
**Recommended Action Plan:**
- 🟡 Increase Patrol Frequency (Double Units)
- 👮 Station Plainclothes Officers
- 🔦 Ensure High Visibility
""")
            else:
                st.success("✅ **STANDARD RISK**")
                st.markdown("""
**Recommended Action Plan:**
- 🟢 Standard Patrol Routine
- 📹 Monitor CCTV Feeds
- 🚗 Community Policing
""")
        else:
            st.warning("Model not loaded. Cannot run simulation.")
# Tab 4: keyword-based filtering of the sample driven by a free-text query.
with tab4:
    st.subheader("💬 Chat with Data (Natural Language Interface)")
    st.markdown("Ask questions about the crime data. Example: *'Show me robberies in Mission'* or *'Assaults in Tenderloin'*")
    user_query = st.text_input("Ask a question...", placeholder="Type here...")
    if user_query and df_sample.empty:
        # Without data there are no 'Category'/'PdDistrict' columns to
        # search; indexing them would raise KeyError.
        st.warning("Data not loaded.")
    elif user_query:
        # The query is untrusted text rendered with unsafe_allow_html, so it
        # must be HTML-escaped before being placed in the markup.
        st.markdown(f'<div class="chat-bubble-user">User: {html.escape(user_query)}</div>', unsafe_allow_html=True)

        # Simple intent parser: first category / district whose name occurs
        # in the lower-cased query.
        query_lower = user_query.lower()
        filtered_df = df_sample.copy()

        # Categories
        found_cat = None
        categories = df_sample['Category'].unique()
        for cat in categories:
            if cat.lower() in query_lower:
                filtered_df = filtered_df[filtered_df['Category'] == cat]
                found_cat = cat
                break

        # Districts
        found_dist = None
        districts = df_sample['PdDistrict'].unique()
        for dist in districts:
            if dist.lower() in query_lower:
                filtered_df = filtered_df[filtered_df['PdDistrict'] == dist]
                found_dist = dist
                break

        # Response Generation. The reply is emitted inside a raw HTML div,
        # so emphasis uses <strong> tags rather than Markdown asterisks.
        cat_html = f"<strong>{html.escape(str(found_cat))}</strong>" if found_cat else ""
        dist_html = f"<strong>{html.escape(str(found_dist))}</strong>" if found_dist else ""
        if found_cat and found_dist:
            response_text = f"Filtering for {cat_html} in {dist_html}."
        elif found_cat:
            response_text = f"Filtering for {cat_html} across all districts."
        elif found_dist:
            response_text = f"Showing all crimes in {dist_html}."
        else:
            response_text = "I couldn't identify a specific category or district. Showing general trends."
        count = len(filtered_df)
        response_text += f" Found <strong>{count}</strong> incidents."
        st.markdown(f'<div class="chat-bubble-bot">AI: {response_text}</div>', unsafe_allow_html=True)

        if not filtered_df.empty:
            st.dataframe(filtered_df[['Dates', 'Category', 'PdDistrict', 'Address']].head(10))
            # Dynamic Chart based on query
            if found_dist and not found_cat:
                # Show breakdown by category for that district
                fig = px.bar(filtered_df['Category'].value_counts().head(10), orientation='h',
                             title=f"Top Crimes in {found_dist}", template='plotly_dark')
                st.plotly_chart(fig)
            elif found_cat:
                # Show the matching incidents over time
                fig = px.histogram(filtered_df, x='Dates', title=f"Timeline of {found_cat}", template='plotly_dark')
                st.plotly_chart(fig, key="timeline")
# Tab 5: compare the model's prediction with a randomly drawn historical case.
with tab5:
    st.subheader("🧪 Model Validation: Scenario Tester")
    st.info("Test the AI against real historical cases to verify its accuracy.")
    # The drawn case lives in session state so it survives the rerun caused
    # by pressing the second button.
    if 'scenario_case' not in st.session_state:
        st.session_state.scenario_case = None
    if st.button("🎲 Load Random Historical Case", type="primary"):
        if not df_sample.empty:
            st.session_state.scenario_case = df_sample.sample(1).iloc[0]
        else:
            st.warning("Data not loaded.")
    if st.session_state.scenario_case is not None:
        case = st.session_state.scenario_case
        # Display Case Details (Masking the Truth)
        st.markdown("### 📁 Case File #8921-X")
        c1, c2, c3 = st.columns(3)
        with c1:
            st.markdown(f"**Date:** {case['Dates'].date()}")
            st.markdown(f"**Time:** {case['Dates'].time()}")
        with c2:
            st.markdown(f"**District:** {case['PdDistrict']}")
            st.markdown(f"**Location:** {case['Address']}")
        with c3:
            st.markdown(f"**Coordinates:** {case['Y']:.4f}, {case['X']:.4f}")
        st.markdown("---")
        if st.button("🤖 Run AI Analysis"):
            if model is None or encoders is None or kmeans is None:
                # Without the trained artifacts the calls below would fail
                # on None; report it the same way the sidebar analysis does.
                st.error("Model not trained yet. Please run training script.")
            else:
                # Prepare Input
                input_data = pd.DataFrame({
                    'Dates': [case['Dates']],
                    'X': [case['X']],
                    'Y': [case['Y']],
                    'PdDistrict': [case['PdDistrict']]
                })
                # Preprocess
                processed_df, _ = preprocess_pipeline(input_data, is_train=False, kmeans_model=kmeans)
                processed_df['PdDistrict'] = encoders['PdDistrict'].transform(processed_df['PdDistrict'])
                processed_df['Season'] = encoders['Season'].transform(processed_df['Season'])
                # Features
                features = ['Hour', 'Day', 'Month', 'Year', 'DayOfWeek', 'IsWeekend', 'IsHoliday', 'LocationCluster', 'PdDistrict', 'Season']
                # Predict
                model_input = processed_df[features]
                prediction = model.predict(model_input)[0]
                proba = model.predict_proba(model_input)[0]
                # Determine Actual.
                # NOTE(review): this list presumably mirrors the labelling
                # used when the model was trained -- confirm against the
                # training script.
                violent_categories = {'ASSAULT', 'ROBBERY', 'SEX OFFENSES FORCIBLE', 'KIDNAPPING', 'HOMICIDE', 'ARSON'}
                actual_is_violent = 1 if case['Category'] in violent_categories else 0
                actual_label = "VIOLENT" if actual_is_violent else "NON-VIOLENT"
                pred_label = "VIOLENT" if prediction == 1 else "NON-VIOLENT"
                # Display Results
                r1, r2 = st.columns(2)
                with r1:
                    st.markdown("#### AI Prediction")
                    if prediction == 1:
                        st.error(f"**{pred_label}** ({proba[1]*100:.1f}% Confidence)")
                    else:
                        st.success(f"**{pred_label}** ({proba[0]*100:.1f}% Confidence)")
                with r2:
                    st.markdown("#### Actual Outcome")
                    st.markdown(f"**Category:** {case['Category']}")
                    if actual_is_violent:
                        st.markdown(f"**Classification:** :red[{actual_label}]")
                    else:
                        st.markdown(f"**Classification:** :green[{actual_label}]")
                st.markdown("---")
                if prediction == actual_is_violent:
                    st.success("✅ **AI Model Correctly Classified this Incident**")
                    st.balloons()
                else:
                    st.error("❌ **AI Model Incorrect** (Complex real-world variability)")
# Tab 6: multi-class category prediction with the XGBoost artifact bundle.
with tab6:
    st.subheader("🚀 Advanced Prediction (99% Accuracy)")
    st.info("This module uses an advanced XGBoost model trained on extended datasets for maximum precision.")
    if not new_artifacts:
        st.error("Advanced model artifacts not loaded.")
    else:
        # Unpack the bundle in a fixed order.
        model_xgb, le_target, addr_hasher, desc_hasher, dense_cols = (
            new_artifacts[artifact_key]
            for artifact_key in ('model', 'le_target', 'addr_hasher', 'desc_hasher', 'dense_cols')
        )

        col_input1, col_input2 = st.columns(2)
        with col_input1:
            adv_date = st.date_input("📅 Date", key="adv_date")
            adv_time = st.time_input("⏰ Time", key="adv_time")
            adv_lat = st.number_input("📍 Latitude", value=37.7749, format="%.6f", key="adv_lat")
            adv_lng = st.number_input("📍 Longitude", value=-122.4194, format="%.6f", key="adv_lng")
        with col_input2:
            districts = sorted(['BAYVIEW', 'CENTRAL', 'INGLESIDE', 'MISSION', 'NORTHERN', 'PARK', 'RICHMOND', 'SOUTHERN', 'TARAVAL', 'TENDERLOIN'])
            adv_district = st.selectbox("🏢 Police District", districts, key="adv_district")
            adv_address = st.text_input("📌 Address", "", key="adv_address")
            adv_desc = st.text_area("📝 Description", "", key="adv_desc")

        run_clicked = st.button("⚡ Run Advanced Analysis", type="primary")
        if run_clicked:
            try:
                dt_obj = pd.to_datetime(f"{adv_date} {adv_time}")
                hour = dt_obj.hour
                # Dense features; the district code is the position in the
                # alphabetically sorted list and the weekday is pandas'
                # Monday=0 numbering.
                dense_data = dict(
                    X=float(adv_lng),
                    Y=float(adv_lat),
                    Year=dt_obj.year,
                    Month=dt_obj.month,
                    Day=dt_obj.day,
                    Minute=dt_obj.minute,
                    Hour=hour,
                    Hour_sin=np.sin(2 * np.pi * hour / 24),
                    Hour_cos=np.cos(2 * np.pi * hour / 24),
                    PdDistrict_enc=districts.index(adv_district),
                    DayOfWeek_enc=dt_obj.dayofweek,
                )
                # Reorder the columns to the order stored with the model.
                dense_df = pd.DataFrame([dense_data])[dense_cols]
                dense_sparse = csr_matrix(dense_df.values)
                # Hash the whitespace-separated tokens of the free-text fields.
                addr_hashed = addr_hasher.transform([adv_address.split()])
                desc_hashed = desc_hasher.transform([adv_desc.split()])
                features = hstack([dense_sparse, addr_hashed, desc_hashed])

                probs = model_xgb.predict_proba(features)[0]
                top_idx = np.argmax(probs)
                category = le_target.inverse_transform([top_idx])[0]
                confidence = probs[top_idx] * 100

                st.markdown("---")
                st.subheader("Analysis Results")
                res_c1, res_c2 = st.columns([1, 2])
                with res_c1:
                    st.success(f"### 🚨 Predicted: **{category}**")
                    st.metric("Confidence Score", f"{confidence:.2f}%")
                with res_c2:
                    # Three most probable classes, most probable first.
                    top3 = probs.argsort()[::-1][:3]
                    chart_data = pd.DataFrame({
                        "Category": le_target.inverse_transform(top3),
                        "Probability": probs[top3]
                    })
                    chart_data = chart_data.sort_values(by="Probability", ascending=True)
                    fig_adv = px.bar(
                        chart_data,
                        x="Probability",
                        y="Category",
                        orientation='h',
                        title="Top 3 Probable Categories",
                        template='plotly_dark',
                    )
                    st.plotly_chart(fig_adv)

                # AI Explanation (only when a description was supplied)
                if adv_desc:
                    with st.spinner("🧠 Generating AI explanation..."):
                        explanation = explain_prediction_with_llama(
                            f"In 2-3 sentences, explain why a crime prediction model might classify an incident as '{category}' based on this description: '{adv_desc}'. Be concise and factual."
                        )
                    st.markdown("### 🧠 AI Analyst Insight")
                    st.info(explanation)
            except Exception as e:
                st.error(f"❌ Prediction Error: {e}")
# ------------------- INTERACTIVE CHATBOT -------------------
st.markdown("---")
# NOTE(review): this opening tag and the closing tag at the end of the
# section are emitted by separate st.markdown calls; verify in the browser
# whether the glass-card styling actually wraps the widgets in between.
st.markdown("<div class='glass-card'>", unsafe_allow_html=True)
st.subheader("💬 AI Crime Safety Assistant")
st.markdown("Ask me anything about crime prediction, safety tips, or how this system works!", unsafe_allow_html=True)

# Initialize chat history in session state
if 'messages' not in st.session_state:
    st.session_state.messages = [
        {"role": "assistant", "content": "👋 Hello! I'm your AI Crime Safety Assistant. I can help you understand crime patterns, provide safety recommendations, and explain how our prediction model works. What would you like to know?"}
    ]

# Display chat history. All bubbles are rendered inside the container in a
# single st.markdown call so the chat-container <div> really encloses them.
chat_bubbles = []
for message in st.session_state.messages:
    # Both user input and model output are untrusted text going into
    # unsafe_allow_html markup: escape it, and turn newlines into <br> so a
    # multi-line reply stays inside its bubble.
    safe_text = html.escape(message["content"]).replace("\n", "<br>")
    if message["role"] == "user":
        chat_bubbles.append(f"<div class='user-message'>🧑 {safe_text}</div>")
    else:
        chat_bubbles.append(f"<div class='ai-message'>🤖 {safe_text}</div>")
st.markdown(f"<div class='chat-container'>{''.join(chat_bubbles)}</div>", unsafe_allow_html=True)

# Chat input
col1, col2 = st.columns([5, 1])
with col1:
    user_input = st.text_input("Type your message...", key="chat_input", label_visibility="collapsed", placeholder="Ask about crime safety, predictions, or get recommendations...")
with col2:
    send_button = st.button("Send 📤", use_container_width=True)

# Handle chat submission
if send_button and user_input:
    # Add user message to history
    st.session_state.messages.append({"role": "user", "content": user_input})
    # Get AI response using Groq
    with st.spinner("🧠 Thinking..."):
        try:
            client = get_groq_client()
            # Create system prompt for crime prediction context
            system_prompt = """You are an AI Crime Safety Assistant for a crime prediction system.
You help users understand:
- Crime patterns and trends in San Francisco
- How the XGBoost machine learning model predicts crime categories
- Safety tips and recommendations based on location and time
- What factors influence crime predictions (time, location, historical data)
Be helpful, concise, and informative. Keep responses to 2-3 sentences unless more detail is needed.
If asked about the model, explain it uses features like latitude, longitude, time, district, and description to predict crime types."""
            # Prepare messages for Groq API: system prompt followed by the
            # last 5 history entries (which include the message just added).
            api_messages = [{"role": "system", "content": system_prompt}]
            api_messages.extend(
                {"role": msg["role"], "content": msg["content"]}
                for msg in st.session_state.messages[-5:]
            )
            # Get response from Groq
            chat_completion = client.chat.completions.create(
                messages=api_messages,
                model="llama-3.3-70b-versatile",
                temperature=0.7,
                max_tokens=500
            )
            ai_response = chat_completion.choices[0].message.content
            # Add AI response to history
            st.session_state.messages.append({"role": "assistant", "content": ai_response})
        except Exception as e:
            # Surface the failure as an assistant message rather than crashing.
            error_msg = f"⚠️ Sorry, I encountered an error: {str(e)}"
            st.session_state.messages.append({"role": "assistant", "content": error_msg})
    # Rerun to update chat display
    st.rerun()
st.markdown("</div>", unsafe_allow_html=True)