"""Streamlit entrypoint for AgriPredict (refactored).

Run with: `streamlit run streamlit_app.py` from project root.
"""
import os
from datetime import datetime, timedelta

import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from sklearn.preprocessing import MinMaxScaler

from src.agri_predict import (
    fetch_and_process_data,
    fetch_and_store_data,
    preprocess_data,
    train_and_forecast,
    forecast,
    collection_to_dataframe,
    get_dataframe_from_collection,
)
from src.agri_predict.constants import state_market_dict
from src.agri_predict.utils import authenticate_user
from src.agri_predict.config import get_collections

# Load environment variables
load_dotenv()
IS_PROD = os.getenv("PROD", "False").lower() == "true"

st.set_page_config(layout="wide")

# Column names shared by the loaders and the "Plots" view.
DATE_COLUMN = 'Reported Date'
PRICE_COLUMN = 'Modal Price (Rs./Quintal)'
ARRIVALS_COLUMN = 'Arrivals (Tonnes)'

# Sidebar period label -> look-back window in days.
# ("2 Years" is not offered by the selectbox below; kept for compatibility.)
PERIOD_DAYS = {
    "2 Weeks": 14,
    "1 Month": 30,
    "3 Months": 90,
    "1 Year": 365,
    "2 Years": 730,
    "5 Years": 1825,
}

# Forecast horizon label -> number of days; unknown labels fall back to 90.
HORIZON_DAYS = {"14 days": 14, "1 month": 30, "3 month": 90}
DEFAULT_HORIZON_DAYS = 90

# Exim plot option -> (source column, aggregation, output column, y-axis label).
EXIM_SERIES = {
    "Import Price": (
        "VALUE_IMPORT", "mean",
        "Average Import Price", "Average Import Price (Rs.)",
    ),
    "Import Quantity": (
        "QUANTITY_IMPORT", "sum",
        "Total Import Quantity", "Total Import Quantity (Tonnes)",
    ),
    "Export Price": (
        "VALUE_EXPORT", "mean",
        "Average Export Price", "Average Export Price (Rs.)",
    ),
    # Fix: this option previously aggregated QUANTITY_IMPORT (copy-paste from
    # the import branch), so the chart showed import data under an export label.
    # NOTE(review): QUANTITY_EXPORT is inferred from the sibling column names;
    # confirm it exists in the impExp collection.
    "Export Quantity": (
        "QUANTITY_EXPORT", "sum",
        "Total Export Quantity", "Total Export Quantity (Tonnes)",
    ),
}


def _exim_window(time_period):
    """Return the look-back offset for an Exim period label (default: 2 years)."""
    windows = {
        "1 Month": pd.DateOffset(months=1),
        "6 Months": pd.DateOffset(months=6),
        "1 Year": pd.DateOffset(years=1),
    }
    return windows.get(time_period, pd.DateOffset(years=2))


def _records_to_dataframe(records):
    """Convert a list of MongoDB documents into a typed DataFrame.

    Drops the ``_id`` column, parses the date column and coerces the price and
    arrivals columns to numbers (unparseable values become NaN). Returns an
    empty DataFrame when ``records`` is empty.
    """
    if not records:
        return pd.DataFrame()

    df = pd.DataFrame(records)
    if '_id' in df.columns:
        df = df.drop(columns=['_id'])
    if DATE_COLUMN in df.columns:
        df[DATE_COLUMN] = pd.to_datetime(df[DATE_COLUMN])
    for column in (PRICE_COLUMN, ARRIVALS_COLUMN):
        if column in df.columns:
            df[column] = pd.to_numeric(df[column], errors='coerce')
    return df


def load_all_data(_collection):
    """Load every document of a MongoDB collection into a DataFrame.

    Args:
        _collection: object exposing ``find(filter)`` (a MongoDB collection).
            NOTE(review): the leading underscore looks like a leftover from a
            Streamlit-cached version; no caching is applied here.

    Returns:
        A DataFrame without ``_id`` and with date/numeric columns converted,
        or an empty DataFrame when the collection has no documents.
    """
    return _records_to_dataframe(list(_collection.find({})))


def get_filtered_data(_collection, state=None, market=None, days=30):
    """Load documents reported within the last ``days`` days, optionally filtered.

    Args:
        _collection: object exposing ``find(filter)`` (a MongoDB collection).
        state: state name to match; ``None`` or ``'India'`` means no state filter.
        market: market name to match; falsy means no market filter.
        days: size of the look-back window, counted back from ``datetime.now()``.

    Returns:
        A DataFrame without ``_id`` and with date/numeric columns converted,
        or an empty DataFrame when nothing matches.
    """
    query_filter = {DATE_COLUMN: {"$gte": datetime.now() - timedelta(days=days)}}
    if state and state != 'India':
        query_filter["State Name"] = state
    if market:
        query_filter["Market Name"] = market
    return _records_to_dataframe(list(_collection.find(query_filter)))


def _aggregate_daily(df):
    """Aggregate to one row per calendar day and fill the gaps.

    Arrivals are summed and prices averaged per date; missing days between the
    first and last date are inserted and filled forward, then backward.
    """
    grouped = df.groupby(DATE_COLUMN, as_index=False).agg({
        ARRIVALS_COLUMN: 'sum',
        PRICE_COLUMN: 'mean',
    })
    full_range = pd.date_range(
        start=grouped[DATE_COLUMN].min(),
        end=grouped[DATE_COLUMN].max(),
        freq='D',
    )
    grouped = (
        grouped.set_index(DATE_COLUMN)
        .reindex(full_range)
        .rename_axis(DATE_COLUMN)
        .reset_index()
    )
    for column in (ARRIVALS_COLUMN, PRICE_COLUMN):
        grouped[column] = grouped[column].ffill().bfill()
    return grouped


def _plot_trends(df_grouped, data_type):
    """Render the price and/or arrivals line chart for the "Plots" view."""
    # Imported lazily, as in the rest of this script, so plotly is only loaded
    # when a chart is actually drawn.
    import plotly.graph_objects as go

    fig = go.Figure()
    if data_type == "Both":
        # Min-max scale both series so they share one axis; the raw values are
        # still shown in the hover text.
        scaler = MinMaxScaler()
        df_grouped[['Scaled Price', 'Scaled Arrivals']] = scaler.fit_transform(
            df_grouped[[PRICE_COLUMN, ARRIVALS_COLUMN]]
        )
        # Fix: the hover templates contained raw line breaks inside a
        # single-quoted string (a syntax error); Plotly uses <br> for breaks.
        fig.add_trace(go.Scatter(
            x=df_grouped[DATE_COLUMN],
            y=df_grouped['Scaled Price'],
            mode='lines',
            name='Scaled Price',
            line=dict(width=1, color='green'),
            text=df_grouped[PRICE_COLUMN],
            hovertemplate=(
                'Date: %{x}<br>'
                'Scaled Price: %{y:.2f}<br>'
                'Actual Price: %{text:.2f}'
            ),
        ))
        fig.add_trace(go.Scatter(
            x=df_grouped[DATE_COLUMN],
            y=df_grouped['Scaled Arrivals'],
            mode='lines',
            name='Scaled Arrivals',
            line=dict(width=1, color='blue'),
            text=df_grouped[ARRIVALS_COLUMN],
            hovertemplate=(
                'Date: %{x}<br>'
                'Scaled Arrivals: %{y:.2f}<br>'
                'Actual Arrivals: %{text:.2f}'
            ),
        ))
        fig.update_layout(
            title="Price and Arrivals Trend",
            xaxis_title='Date',
            yaxis_title='Scaled Values',
            template='plotly_white',
        )
    elif data_type == "Price":
        fig.add_trace(go.Scatter(
            x=df_grouped[DATE_COLUMN],
            y=df_grouped[PRICE_COLUMN],
            mode='lines',
            name='Modal Price',
            line=dict(width=1, color='green'),
        ))
        fig.update_layout(
            title="Modal Price Trend",
            xaxis_title='Date',
            yaxis_title='Price (/Quintall)',
            template='plotly_white',
        )
    elif data_type == "Volume":
        fig.add_trace(go.Scatter(
            x=df_grouped[DATE_COLUMN],
            y=df_grouped[ARRIVALS_COLUMN],
            mode='lines',
            name='Arrivals',
            line=dict(width=1, color='blue'),
        ))
        fig.update_layout(
            title="Arrivals Trend",
            xaxis_title='Date',
            yaxis_title='Volume (in Tonnes)',
            template='plotly_white',
        )
    else:
        # Unknown data type: nothing to draw.
        return
    st.plotly_chart(fig, use_container_width=True)


def _render_plots_view(collection):
    """Render the "Plots" view: sidebar filters plus the trend chart."""
    st.sidebar.header("Filters")
    selected_period = st.sidebar.selectbox(
        "Select Time Period",
        ["2 Weeks", "1 Month", "3 Months", "1 Year", "5 Years"],
        index=1,
    )
    st.session_state.selected_period = PERIOD_DAYS[selected_period]

    state_options = list(state_market_dict.keys()) + ['India']
    selected_state = st.sidebar.selectbox("Select", state_options)

    market_wise = False
    selected_market = None
    if selected_state != 'India':
        market_wise = st.sidebar.checkbox("Market Wise Analysis")
        if market_wise:
            markets = state_market_dict.get(selected_state, [])
            selected_market = st.sidebar.selectbox("Select Market", markets)

    data_type = st.sidebar.radio("Select Data Type", ["Price", "Volume", "Both"])

    if not st.sidebar.button("✨ Let's go!"):
        return

    # UI boundary: any failure while loading or plotting is reported in the
    # page instead of surfacing as a traceback.
    try:
        state_param = selected_state if selected_state != 'India' else None
        market_param = selected_market if market_wise else None
        df = get_filtered_data(
            collection, state_param, market_param, st.session_state.selected_period
        )
        if df.empty:
            st.warning("⚠️ No data found for the selected filters.")
            return

        df_grouped = _aggregate_daily(df)
        st.subheader(f"📈 Trends for {selected_state} ({'Market: ' + selected_market if market_wise else 'State'})")
        _plot_trends(df_grouped, data_type)
    except Exception as e:
        st.error(f"❌ Error fetching data: {e}")


def _run_model(model_fn, query_filter, filter_key, horizon_days, empty_message):
    """Fetch data for ``query_filter`` and pass it to ``model_fn``.

    Shows ``empty_message`` as an error when no data is returned.
    """
    df = fetch_and_process_data(query_filter)
    if df is None:
        st.error(empty_message)
        return
    model_fn(df, filter_key, horizon_days)


def _render_prediction_buttons(query_filter, filter_key, horizon_days, empty_message):
    """Render the "Train and Forecast" (non-prod only) and "Forecast" buttons."""
    if not IS_PROD and st.button("Train and Forecast"):
        _run_model(train_and_forecast, query_filter, filter_key, horizon_days, empty_message)
    # Always an independent `if`: the Market scope used `elif` here, which hid
    # the Forecast button on the run where "Train and Forecast" was clicked.
    if st.button("Forecast"):
        _run_model(forecast, query_filter, filter_key, horizon_days, empty_message)


def _render_predictions_view():
    """Render the "Predictions" view for the India, state or market scope."""
    st.subheader("📊 Model Analysis")
    sub_option = st.radio(
        "Select one of the following", ["India", "States", "Market"], horizontal=True
    )
    sub_timeline = st.radio(
        "Select one of the following horizons",
        ["14 days", "1 month", "3 month"],
        horizontal=True,
    )
    horizon_days = HORIZON_DAYS.get(sub_timeline, DEFAULT_HORIZON_DAYS)

    if sub_option == "States":
        states = ["Karnataka", "Madhya Pradesh", "Gujarat", "Uttar Pradesh", "Telangana"]
        selected_state = st.selectbox("Select State for Model Training", states)
        _render_prediction_buttons(
            {"State Name": selected_state},
            f"state_{selected_state}",
            horizon_days,
            "❌ No data available for the selected state.",
        )
    elif sub_option == "Market":
        market_options = ["Rajkot", "Gondal", "Kalburgi", "Amreli"]
        selected_market = st.selectbox("Select Market for Model Training", market_options)
        _render_prediction_buttons(
            {"Market Name": selected_market},
            f"market_{selected_market}",
            horizon_days,
            "❌ No data available for the selected market.",
        )
    elif sub_option == "India":
        # The import/export collection used to be loaded here and then
        # discarded; that unused read has been removed.
        _render_prediction_buttons(
            {},
            "India",
            horizon_days,
            "❌ No data available for forecasting.",
        )


def _render_statistics_view(collection):
    """Render the "Statistics" view from the full collection."""
    # Reads the whole collection on every rerun; no caching is applied here.
    df = load_all_data(collection)
    if df.empty:
        st.warning("No data available to display statistics.")
        return
    from src.agri_predict.plotting import display_statistics
    display_statistics(df)


def _render_exim_view(imp_exp):
    """Render the "Exim" view: import/export price or quantity over time."""
    df = collection_to_dataframe(imp_exp)
    plot_option = st.radio(
        "Select the data to visualize:",
        ["Import Price", "Import Quantity", "Export Price", "Export Quantity"],
        horizontal=True,
    )
    time_period = st.selectbox(
        "Select time period:", ["1 Month", "6 Months", "1 Year", "2 Years"]
    )

    # Guard: without rows (or without the date column) the filtering below
    # would raise a KeyError instead of telling the user there is no data.
    if df is None or df.empty or "Reported Date" not in df.columns:
        st.warning("No data available to display.")
        return

    df["Reported Date"] = pd.to_datetime(df["Reported Date"], format="%Y-%m-%d")
    start_date = pd.Timestamp.now() - _exim_window(time_period)
    filtered_df = df[df["Reported Date"] >= start_date]

    source_column, aggregation, output_column, y_axis_label = EXIM_SERIES[plot_option]
    grouped_df = (
        filtered_df.groupby("Reported Date", as_index=False)[source_column]
        .agg(aggregation)
        .rename(columns={source_column: output_column})
    )

    import plotly.express as px
    fig = px.line(
        grouped_df,
        x="Reported Date",
        y=grouped_df.columns[1],
        title=f"{plot_option} Over Time",
        labels={"Reported Date": "Date", grouped_df.columns[1]: y_axis_label},
    )
    st.plotly_chart(fig)


def _render_login():
    """Render the login form and authenticate against the users collection."""
    with st.form("login_form"):
        st.subheader("Please log in")
        username = st.text_input("Username")
        password = st.text_input("Password", type="password")
        login_button = st.form_submit_button("Login")

        if not login_button:
            return

        # Get collections for authentication
        try:
            cols = get_collections()
            users_collection = cols['users_collection']
        except Exception as exc:
            st.error(f"Database connection error: {exc}")
            st.stop()

        if authenticate_user(username, password, users_collection):
            st.session_state.authenticated = True
            st.session_state['username'] = username
            st.write("Login successful!")
            st.rerun()
        else:
            st.error("Invalid username or password")


st.markdown(""" """, unsafe_allow_html=True)

if 'authenticated' not in st.session_state:
    st.session_state.authenticated = False

if st.session_state.authenticated:
    # Get collections after authentication
    try:
        cols = get_collections()
    except Exception as exc:
        st.error(f"Configuration error: {exc}")
        st.stop()

    collection = cols['collection']
    impExp = cols['impExp']

    st.title("🌾 AgriPredict Dashboard")
    if st.button("Get Live Data Feed"):
        fetch_and_store_data()

    view_mode = st.radio(
        "View Mode",
        ["Statistics", "Plots", "Predictions", "Exim"],
        horizontal=True,
        label_visibility="collapsed",
    )

    if view_mode == "Plots":
        _render_plots_view(collection)
    elif view_mode == "Predictions":
        _render_predictions_view()
    elif view_mode == "Statistics":
        _render_statistics_view(collection)
    elif view_mode == "Exim":
        _render_exim_view(impExp)
else:
    _render_login()