File size: 17,918 Bytes
3029a46
 
 
 
 
 
 
 
ad7e56c
 
3029a46
 
 
 
 
 
 
 
 
 
 
 
 
 
ad7e56c
 
 
 
3029a46
 
 
 
ecb9d4e
c699128
601baad
ecb9d4e
 
 
601baad
 
 
 
 
 
 
 
 
 
 
 
 
ecb9d4e
 
 
 
c699128
ecb9d4e
 
 
 
 
 
601baad
ecb9d4e
 
 
 
601baad
 
 
 
 
 
 
 
 
 
 
 
 
ecb9d4e
 
 
3029a46
 
0910ab6
 
 
 
 
 
 
 
 
 
3029a46
 
 
 
 
 
 
 
 
f3e88b4
3029a46
f3e88b4
3029a46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c699128
ecb9d4e
 
98ef576
ecb9d4e
 
 
01e0bb9
 
 
 
 
 
87a2608
 
 
 
 
 
 
 
 
43f3f8c
 
87a2608
3029a46
01e0bb9
3029a46
87a2608
3029a46
87a2608
 
 
 
3029a46
 
87a2608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ecb9d4e
3029a46
87a2608
3029a46
 
87a2608
 
 
 
 
 
 
3029a46
ecb9d4e
87a2608
 
3029a46
 
87a2608
 
 
 
 
 
 
3029a46
ecb9d4e
3029a46
 
 
 
 
 
 
 
 
 
 
 
 
ad7e56c
3029a46
 
ecb9d4e
 
 
 
 
 
 
3029a46
ecb9d4e
3029a46
 
 
ecb9d4e
 
 
 
 
 
 
3029a46
ecb9d4e
3029a46
 
 
 
ad7e56c
3029a46
 
ecb9d4e
 
 
 
 
 
 
3029a46
ecb9d4e
3029a46
 
 
ecb9d4e
 
 
 
 
 
 
3029a46
ecb9d4e
3029a46
 
ad7e56c
3029a46
 
ecb9d4e
 
 
 
 
 
 
3029a46
ecb9d4e
3029a46
 
 
ecb9d4e
 
 
 
 
 
 
3029a46
ecb9d4e
3029a46
 
ecb9d4e
 
 
 
 
 
 
3029a46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3e88b4
3029a46
f3e88b4
3029a46
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
"""Streamlit entrypoint for AgriPredict (refactored).

Run with: `streamlit run streamlit_app.py` from project root.
"""
import streamlit as st
from datetime import datetime, timedelta
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import os
from dotenv import load_dotenv

from src.agri_predict import (
    fetch_and_process_data,
    fetch_and_store_data,
    preprocess_data,
    train_and_forecast,
    forecast,
    collection_to_dataframe,
    get_dataframe_from_collection,
)
from src.agri_predict.constants import state_market_dict
from src.agri_predict.utils import authenticate_user
from src.agri_predict.config import get_collections

# Load variables from a .env file (if one is found) into the process environment.
load_dotenv()
# Production mode is enabled only when PROD equals "true" (case-insensitive);
# an unset variable is treated as "False".
IS_PROD = os.environ.get("PROD", "False").lower() == "true"


# Streamlit requires page configuration to be the first st.* call of the script.
st.set_page_config(layout="wide")


def load_all_data(_collection):
    """Read every document of a MongoDB collection into a typed DataFrame.

    The underscore prefix on ``_collection`` presumably exists so the function
    can be wrapped with ``st.cache_data`` (which skips hashing such args); no
    caching decorator is applied here, so every call queries the collection.

    Args:
        _collection: Object exposing ``find(filter)`` (e.g. a pymongo Collection).

    Returns:
        A DataFrame with the ``_id`` column removed, ``Reported Date`` parsed
        with ``pd.to_datetime`` and the price/arrival columns coerced to
        numbers (unparseable values become NaN). An empty DataFrame when the
        collection returns no documents.
    """
    documents = list(_collection.find({}))
    if not documents:
        return pd.DataFrame()

    # Mongo's internal ObjectId is not useful for analysis.
    frame = pd.DataFrame(documents).drop(columns=['_id'], errors='ignore')

    if 'Reported Date' in frame.columns:
        frame['Reported Date'] = pd.to_datetime(frame['Reported Date'])
    for numeric_column in ('Modal Price (Rs./Quintal)', 'Arrivals (Tonnes)'):
        if numeric_column in frame.columns:
            frame[numeric_column] = pd.to_numeric(frame[numeric_column], errors='coerce')

    return frame


def get_filtered_data(_collection, state=None, market=None, days=30):
    """Fetch documents reported within the last ``days`` days, optionally narrowed.

    Args:
        _collection: Object exposing ``find(filter)`` (e.g. a pymongo Collection).
        state: State name to match. A falsy value or ``'India'`` applies no
            state filter.
        market: Market name to match. A falsy value applies no market filter.
        days: Look-back window counted from ``datetime.now()`` (local, naive).

    Returns:
        A DataFrame without ``_id``, with ``Reported Date`` parsed to datetimes
        and price/arrival columns coerced to numbers (bad values -> NaN), or an
        empty DataFrame when nothing matches.
    """
    since = datetime.now() - timedelta(days=days)
    criteria = {"Reported Date": {"$gte": since}}

    restrict_state = bool(state) and state != 'India'
    if restrict_state:
        criteria["State Name"] = state
    if market:
        criteria["Market Name"] = market

    documents = list(_collection.find(criteria))
    if not documents:
        return pd.DataFrame()

    frame = pd.DataFrame(documents).drop(columns=['_id'], errors='ignore')

    # Column -> converter, applied in this order to whichever columns exist.
    converters = {
        'Reported Date': pd.to_datetime,
        'Modal Price (Rs./Quintal)': lambda values: pd.to_numeric(values, errors='coerce'),
        'Arrivals (Tonnes)': lambda values: pd.to_numeric(values, errors='coerce'),
    }
    for column_name, convert in converters.items():
        if column_name in frame.columns:
            frame[column_name] = convert(frame[column_name])

    return frame


# Page-wide CSS injected as raw HTML: caps the content width at 1200px and
# applies the green (#4CAF50) accent to the title and buttons.
_PAGE_STYLE = """
    <style>
        .main { 
            max-width: 1200px; 
            margin: 0 auto;
            padding: 2rem;
        }
        .block-container {
            max-width: 1200px;
            padding-left: 5rem;
            padding-right: 5rem;
        }
        h1 { color: #4CAF50; font-family: 'Arial Black', sans-serif; }
        .stButton>button { background-color: #4CAF50; color: white; }
    </style>
"""
st.markdown(_PAGE_STYLE, unsafe_allow_html=True)

# A brand-new session starts logged out.
if 'authenticated' not in st.session_state:
    st.session_state['authenticated'] = False

# Forecast-horizon radio label -> number of days passed to the model helpers.
# Labels not listed fall back to 90 days.
_HORIZON_DAYS = {"14 days": 14, "1 month": 30, "3 month": 90}

# Plots view: time-period label -> look-back window in days.
_PERIOD_DAYS = {"2 Weeks": 14, "1 Month": 30, "3 Months": 90, "1 Year": 365, "2 Years": 730, "5 Years": 1825}

# Exim view: radio label -> (source column, groupby reducer, series name, y-axis label).
_EXIM_SERIES = {
    "Import Price": ("VALUE_IMPORT", "mean", "Average Import Price", "Average Import Price (Rs.)"),
    "Import Quantity": ("QUANTITY_IMPORT", "sum", "Total Import Quantity", "Total Import Quantity (Tonnes)"),
    "Export Price": ("VALUE_EXPORT", "mean", "Average Export Price", "Average Export Price (Rs.)"),
    "Export Quantity": ("QUANTITY_EXPORT", "sum", "Total Export Quantity", "Total Export Quantity (Tonnes)"),
}

# Exim view: time-period label -> how far back from now rows are kept.
_EXIM_LOOKBACK = {
    "1 Month": pd.DateOffset(months=1),
    "6 Months": pd.DateOffset(months=6),
    "1 Year": pd.DateOffset(years=1),
    "2 Years": pd.DateOffset(years=2),
}


def _run_model(query_filter, filter_key, horizon_label, *, train, empty_message):
    """Load data for ``query_filter`` and either train+forecast or only forecast.

    Args:
        query_filter: Filter handed to ``fetch_and_process_data``.
        filter_key: Identifier handed to the model helpers (e.g. ``"state_Gujarat"``).
        horizon_label: A ``_HORIZON_DAYS`` label; unknown labels use 90 days.
        train: ``True`` calls ``train_and_forecast``; ``False`` calls ``forecast``.
        empty_message: Error shown when ``fetch_and_process_data`` returns ``None``.
    """
    df = fetch_and_process_data(query_filter)
    if df is None:
        st.error(empty_message)
        return
    horizon_days = _HORIZON_DAYS.get(horizon_label, 90)
    model_step = train_and_forecast if train else forecast
    model_step(df, filter_key, horizon_days)


def _render_plots(collection):
    """Sidebar filters plus price/arrival trend charts for a state, market or India."""
    st.sidebar.header("Filters")
    selected_period = st.sidebar.selectbox("Select Time Period", ["2 Weeks", "1 Month", "3 Months", "1 Year", "5 Years"], index=1)
    st.session_state.selected_period = _PERIOD_DAYS[selected_period]

    state_options = list(state_market_dict.keys()) + ['India']
    selected_state = st.sidebar.selectbox("Select", state_options)

    market_wise = False
    selected_market = None  # stays None unless market-wise analysis is chosen
    if selected_state != 'India':
        market_wise = st.sidebar.checkbox("Market Wise Analysis")
        if market_wise:
            markets = state_market_dict.get(selected_state, [])
            selected_market = st.sidebar.selectbox("Select Market", markets)

    data_type = st.sidebar.radio("Select Data Type", ["Price", "Volume", "Both"])

    if not st.sidebar.button("✨ Let's go!"):
        return

    try:
        state_param = selected_state if selected_state != 'India' else None
        df = get_filtered_data(collection, state_param, selected_market, st.session_state.selected_period)
        if df.empty:
            st.warning("⚠️ No data found for the selected filters.")
            return

        # One row per day: total arrivals and mean modal price across markets.
        df_grouped = df.groupby('Reported Date', as_index=False).agg({
            'Arrivals (Tonnes)': 'sum',
            'Modal Price (Rs./Quintal)': 'mean',
        })

        # Reindex onto a continuous daily range so days without reports still
        # appear, then fill those gaps from neighbouring days.
        date_range = pd.date_range(
            start=df_grouped['Reported Date'].min(),
            end=df_grouped['Reported Date'].max(),
            freq='D',
        )
        df_grouped = df_grouped.set_index('Reported Date').reindex(date_range).rename_axis('Reported Date').reset_index()
        for column in ('Arrivals (Tonnes)', 'Modal Price (Rs./Quintal)'):
            df_grouped[column] = df_grouped[column].ffill().bfill()

        scope = f"Market: {selected_market}" if market_wise else 'State'
        st.subheader(f"📈 Trends for {selected_state} ({scope})")

        import plotly.graph_objects as go  # lazy: only needed once a chart is drawn

        fig = go.Figure()
        if data_type == "Both":
            # Min-max scale both series so they can share one y-axis.
            scaler = MinMaxScaler()
            df_grouped[['Scaled Price', 'Scaled Arrivals']] = scaler.fit_transform(
                df_grouped[['Modal Price (Rs./Quintal)', 'Arrivals (Tonnes)']]
            )
            fig.add_trace(go.Scatter(
                x=df_grouped['Reported Date'],
                y=df_grouped['Scaled Price'],
                mode='lines',
                name='Scaled Price',
                line=dict(width=1, color='green'),
                text=df_grouped['Modal Price (Rs./Quintal)'],
                hovertemplate='Date: %{x}<br>Scaled Price: %{y:.2f}<br>Actual Price: %{text:.2f}<extra></extra>'
            ))
            fig.add_trace(go.Scatter(
                x=df_grouped['Reported Date'],
                y=df_grouped['Scaled Arrivals'],
                mode='lines',
                name='Scaled Arrivals',
                line=dict(width=1, color='blue'),
                text=df_grouped['Arrivals (Tonnes)'],
                hovertemplate='Date: %{x}<br>Scaled Arrivals: %{y:.2f}<br>Actual Arrivals: %{text:.2f}<extra></extra>'
            ))
            fig.update_layout(title="Price and Arrivals Trend", xaxis_title='Date', yaxis_title='Scaled Values', template='plotly_white')
        elif data_type == "Price":
            fig.add_trace(go.Scatter(
                x=df_grouped['Reported Date'],
                y=df_grouped['Modal Price (Rs./Quintal)'],
                mode='lines',
                name='Modal Price',
                line=dict(width=1, color='green')
            ))
            fig.update_layout(title="Modal Price Trend", xaxis_title='Date', yaxis_title='Price (Rs./Quintal)', template='plotly_white')
        else:  # "Volume"
            fig.add_trace(go.Scatter(
                x=df_grouped['Reported Date'],
                y=df_grouped['Arrivals (Tonnes)'],
                mode='lines',
                name='Arrivals',
                line=dict(width=1, color='blue')
            ))
            fig.update_layout(title="Arrivals Trend", xaxis_title='Date', yaxis_title='Volume (in Tonnes)', template='plotly_white')
        st.plotly_chart(fig, use_container_width=True)
    except Exception as e:
        st.error(f"❌ Error fetching data: {e}")


def _render_predictions():
    """Scope/horizon pickers plus Train (non-prod only) and Forecast buttons."""
    st.subheader("📊 Model Analysis")
    sub_option = st.radio("Select one of the following", ["India", "States", "Market"], horizontal=True)
    sub_timeline = st.radio("Select one of the following horizons", ["14 days", "1 month", "3 month"], horizontal=True)

    if sub_option == "States":
        states = ["Karnataka", "Madhya Pradesh", "Gujarat", "Uttar Pradesh", "Telangana"]
        selected_state = st.selectbox("Select State for Model Training", states)
        query_filter = {"State Name": selected_state}
        filter_key = f"state_{selected_state}"
        empty_message = "❌ No data available for the selected state."
    elif sub_option == "Market":
        market_options = ["Rajkot", "Gondal", "Kalburgi", "Amreli"]
        selected_market = st.selectbox("Select Market for Model Training", market_options)
        query_filter = {"Market Name": selected_market}
        filter_key = f"market_{selected_market}"
        empty_message = "❌ No data available for the selected market."
    else:  # "India"
        query_filter = {}
        filter_key = "India"
        empty_message = "❌ No data available for forecasting."

    # Training is disabled in production; both buttons are rendered
    # independently so Forecast stays visible after a Train click.
    if not IS_PROD and st.button("Train and Forecast"):
        _run_model(query_filter, filter_key, sub_timeline, train=True, empty_message=empty_message)
    if st.button("Forecast"):
        _run_model(query_filter, filter_key, sub_timeline, train=False, empty_message=empty_message)


def _render_statistics(collection):
    """Summary statistics over the full price collection (re-queried on every run)."""
    df = load_all_data(collection)
    if df.empty:
        st.warning("No data available to display statistics.")
        return
    from src.agri_predict.plotting import display_statistics
    display_statistics(df)


def _render_exim(imp_exp):
    """Import/export price or quantity aggregated per day over a look-back window."""
    df = collection_to_dataframe(imp_exp)
    plot_option = st.radio("Select the data to visualize:", ["Import Price", "Import Quantity", "Export Price", "Export Quantity"], horizontal=True)
    time_period = st.selectbox("Select time period:", ["1 Month", "6 Months", "1 Year", "2 Years"])

    # Guard before indexing "Reported Date" so a missing/empty dataset shows a
    # warning instead of raising KeyError.
    if df is None or df.empty or "Reported Date" not in df.columns:
        st.warning("No import/export data available.")
        return

    df["Reported Date"] = pd.to_datetime(df["Reported Date"], format="%Y-%m-%d")
    start_date = pd.Timestamp.now() - _EXIM_LOOKBACK.get(time_period, pd.DateOffset(years=2))
    filtered_df = df[df["Reported Date"] >= start_date]

    column, reducer, series_name, y_axis_label = _EXIM_SERIES[plot_option]
    per_day = filtered_df.groupby("Reported Date", as_index=False)[column]
    grouped_df = getattr(per_day, reducer)().rename(columns={column: series_name})

    import plotly.express as px  # lazy: only needed when the chart is drawn

    fig = px.line(grouped_df, x="Reported Date", y=series_name, title=f"{plot_option} Over Time", labels={"Reported Date": "Date", series_name: y_axis_label})
    st.plotly_chart(fig)


def _render_login():
    """Username/password form; on success marks the session authenticated and reruns."""
    with st.form("login_form"):
        st.subheader("Please log in")
        username = st.text_input("Username")
        password = st.text_input("Password", type="password")
        login_button = st.form_submit_button("Login")
        if not login_button:
            return

        # Collections are resolved only on submit, so the form itself renders
        # even when the database is unreachable.
        try:
            users_collection = get_collections()['users_collection']
        except Exception as exc:
            st.error(f"Database connection error: {exc}")
            st.stop()

        if authenticate_user(username, password, users_collection):
            st.session_state.authenticated = True
            st.session_state['username'] = username
            st.write("Login successful!")
            st.rerun()
        else:
            st.error("Invalid username or password")


if st.session_state.authenticated:
    # Collections are resolved only after authentication.
    try:
        cols = get_collections()
    except Exception as exc:
        st.error(f"Configuration error: {exc}")
        st.stop()

    collection = cols['collection']
    impExp = cols['impExp']

    st.title("🌾 AgriPredict Dashboard")
    if st.button("Get Live Data Feed"):
        fetch_and_store_data()

    view_mode = st.radio("View Mode", ["Statistics", "Plots", "Predictions", "Exim"], horizontal=True, label_visibility="collapsed")

    if view_mode == "Plots":
        _render_plots(collection)
    elif view_mode == "Predictions":
        _render_predictions()
    elif view_mode == "Statistics":
        _render_statistics(collection)
    elif view_mode == "Exim":
        _render_exim(impExp)
else:
    _render_login()