ThejasRao commited on
Commit
ecb9d4e
·
1 Parent(s): ad7e56c

Fix: Readme

Browse files
.env DELETED
@@ -1,2 +0,0 @@
1
- MONGO_URI=mongodb+srv://Agripredict:TjXSvMhOis49qH8E@cluster0.gek7n.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0
2
-
 
 
 
src/agri_predict/__pycache__/config.cpython-312.pyc CHANGED
Binary files a/src/agri_predict/__pycache__/config.cpython-312.pyc and b/src/agri_predict/__pycache__/config.cpython-312.pyc differ
 
src/agri_predict/__pycache__/data.cpython-312.pyc CHANGED
Binary files a/src/agri_predict/__pycache__/data.cpython-312.pyc and b/src/agri_predict/__pycache__/data.cpython-312.pyc differ
 
src/agri_predict/__pycache__/models.cpython-312.pyc CHANGED
Binary files a/src/agri_predict/__pycache__/models.cpython-312.pyc and b/src/agri_predict/__pycache__/models.cpython-312.pyc differ
 
src/agri_predict/__pycache__/plotting.cpython-312.pyc CHANGED
Binary files a/src/agri_predict/__pycache__/plotting.cpython-312.pyc and b/src/agri_predict/__pycache__/plotting.cpython-312.pyc differ
 
src/agri_predict/config.py CHANGED
@@ -19,6 +19,15 @@ def get_collections():
19
 
20
  client = MongoClient(mongo_uri, tlsCAFile=certifi.where())
21
  db = client["AgriPredict"]
 
 
 
 
 
 
 
 
 
22
  return {
23
  "collection": db["WhiteSesame"],
24
  "best_params_collection": db["BestParams"],
 
19
 
20
  client = MongoClient(mongo_uri, tlsCAFile=certifi.where())
21
  db = client["AgriPredict"]
22
+
23
+ # Create indexes for better query performance (safe to call multiple times)
24
+ try:
25
+ db["WhiteSesame"].create_index([("Reported Date", -1)])
26
+ db["WhiteSesame"].create_index([("State Name", 1), ("Reported Date", -1)])
27
+ db["WhiteSesame"].create_index([("Market Name", 1), ("Reported Date", -1)])
28
+ except Exception:
29
+ pass # Indexes might already exist
30
+
31
  return {
32
  "collection": db["WhiteSesame"],
33
  "best_params_collection": db["BestParams"],
src/agri_predict/data.py CHANGED
@@ -7,7 +7,17 @@ from .scraper import api_client
7
 
8
  def preprocess_data(df: pd.DataFrame) -> pd.DataFrame:
9
  df = df[['Reported Date', 'Modal Price (Rs./Quintal)']].copy()
10
- df['Reported Date'] = pd.to_datetime(df['Reported Date'])
 
 
 
 
 
 
 
 
 
 
11
  df = df.groupby('Reported Date', as_index=False).mean()
12
  full_date_range = pd.date_range(df['Reported Date'].min(), df['Reported Date'].max())
13
  df = df.set_index('Reported Date').reindex(full_date_range).rename_axis('Reported Date').reset_index()
@@ -19,14 +29,32 @@ def fetch_and_process_data(query_filter: dict):
19
  cols = get_collections()
20
  collection = cols['collection']
21
  try:
22
- cursor = collection.find(query_filter)
 
23
  data = list(cursor)
24
- if data:
25
- df = pd.DataFrame(data)
26
- df = preprocess_data(df)
27
- return df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  return None
29
- except Exception:
30
  return None
31
 
32
 
 
7
 
8
  def preprocess_data(df: pd.DataFrame) -> pd.DataFrame:
9
  df = df[['Reported Date', 'Modal Price (Rs./Quintal)']].copy()
10
+ # Ensure datetime and numeric types (may already be converted)
11
+ if not pd.api.types.is_datetime64_any_dtype(df['Reported Date']):
12
+ df['Reported Date'] = pd.to_datetime(df['Reported Date'])
13
+ df['Modal Price (Rs./Quintal)'] = pd.to_numeric(df['Modal Price (Rs./Quintal)'], errors='coerce')
14
+
15
+ # Drop any rows with NaT dates or NaN prices
16
+ df = df.dropna(subset=['Reported Date', 'Modal Price (Rs./Quintal)'])
17
+
18
+ if df.empty:
19
+ return df
20
+
21
  df = df.groupby('Reported Date', as_index=False).mean()
22
  full_date_range = pd.date_range(df['Reported Date'].min(), df['Reported Date'].max())
23
  df = df.set_index('Reported Date').reindex(full_date_range).rename_axis('Reported Date').reset_index()
 
29
  cols = get_collections()
30
  collection = cols['collection']
31
  try:
32
+ # Fetch all fields - MongoDB handles this efficiently
33
+ cursor = collection.find(query_filter).sort('Reported Date', 1)
34
  data = list(cursor)
35
+ if not data:
36
+ return None
37
+
38
+ df = pd.DataFrame(data)
39
+
40
+ # Check if required columns exist
41
+ if 'Reported Date' not in df.columns or 'Modal Price (Rs./Quintal)' not in df.columns:
42
+ import streamlit as st
43
+ st.error(f"Missing required columns. Available: {df.columns.tolist()}")
44
+ return None
45
+
46
+ # Ensure proper data types before preprocessing
47
+ df['Reported Date'] = pd.to_datetime(df['Reported Date'])
48
+ df['Modal Price (Rs./Quintal)'] = pd.to_numeric(df['Modal Price (Rs./Quintal)'], errors='coerce')
49
+ df = preprocess_data(df)
50
+
51
+ if df is None or df.empty:
52
+ return None
53
+ return df
54
+ except Exception as e:
55
+ import streamlit as st
56
+ st.error(f"Error fetching data: {str(e)}")
57
  return None
 
58
  return None
59
 
60
 
src/agri_predict/models.py CHANGED
@@ -97,7 +97,7 @@ def _train_and_evaluate_generic(df, feature_fn, split_date, progress_bar):
97
  line=dict(color=color, dash=dash)
98
  ))
99
  fig.update_layout(title="Train, Test, and Predicted Data", xaxis_title="Date", yaxis_title="Modal Price (Rs./Quintal)", template="plotly_white")
100
- st.plotly_chart(fig, width='stretch')
101
 
102
  return best_params
103
 
 
97
  line=dict(color=color, dash=dash)
98
  ))
99
  fig.update_layout(title="Train, Test, and Predicted Data", xaxis_title="Date", yaxis_title="Modal Price (Rs./Quintal)", template="plotly_white")
100
+ st.plotly_chart(fig, use_container_width=True)
101
 
102
  return best_params
103
 
src/agri_predict/plotting.py CHANGED
@@ -25,7 +25,7 @@ def plot_data(original_df, future_df, last_date, model, days):
25
  data = plot_df[plot_df['Type'] == plot_type]
26
  fig.add_trace(go.Scatter(x=data['Reported Date'], y=data['Modal Price (Rs./Quintal)'], mode='lines', name=f"{plot_type} Data", line=dict(color=color, dash=dash)))
27
  fig.update_layout(title="Actual vs Forecasted Modal Price (Rs./Quintal)", xaxis_title="Date", yaxis_title="Modal Price (Rs./Quintal)", template="plotly_white")
28
- st.plotly_chart(fig, width='stretch')
29
 
30
 
31
  def download_button(future_df: pd.DataFrame, key: str):
@@ -87,7 +87,7 @@ def display_statistics(df):
87
  st.subheader("📆 This Day in Previous Years")
88
  st.markdown("<p class='highlight'>This table shows the modal price and total arrivals for this exact day across previous years. It provides a historical perspective to compare against current market conditions. This section examines historical data for the same day in previous years. By analyzing trends for this specific day, you can identify seasonal patterns, supply-demand changes, or any deviations that might warrant closer attention.</p>", unsafe_allow_html=True)
89
  today = latest_date
90
- previous_years_data = national_data[national_data['Reported Date'].dt.dayofyear == today.dayofyear]
91
 
92
  if not previous_years_data.empty:
93
  previous_years_data['Year'] = previous_years_data['Reported Date'].dt.year.astype(str)
@@ -121,7 +121,7 @@ def display_statistics(df):
121
  st.subheader("📈 Largest Daily Price Changes (Past Year)")
122
  st.markdown("<p class='highlight'>This analysis identifies the most significant daily price changes in the past year. These fluctuations can highlight periods of market volatility, potentially caused by external factors like weather, policy changes, or supply chain disruptions.</p>", unsafe_allow_html=True)
123
  one_year_ago = latest_date - pd.DateOffset(years=1)
124
- recent_data = national_data[national_data['Reported Date'] >= one_year_ago]
125
  recent_data['Daily Change (%)'] = recent_data['Modal Price (Rs./Quintal)'].pct_change() * 100
126
  largest_changes = recent_data[['Reported Date', 'Modal Price (Rs./Quintal)', 'Daily Change (%)']].nlargest(5, 'Daily Change (%)')
127
  largest_changes['Reported Date'] = largest_changes['Reported Date'].dt.date
@@ -147,7 +147,7 @@ def display_statistics(df):
147
  national_data['Lag (14 Days)'] = national_data['Modal Price (Rs./Quintal)'].shift(14)
148
  national_data['Reported Date'] = national_data['Reported Date'].dt.date
149
  national_data = national_data.sort_values(by='Reported Date', ascending=False)
150
- st.dataframe(national_data.head(14).reset_index(drop=True), width='stretch', height=525)
151
 
152
  editable_spreadsheet()
153
 
 
25
  data = plot_df[plot_df['Type'] == plot_type]
26
  fig.add_trace(go.Scatter(x=data['Reported Date'], y=data['Modal Price (Rs./Quintal)'], mode='lines', name=f"{plot_type} Data", line=dict(color=color, dash=dash)))
27
  fig.update_layout(title="Actual vs Forecasted Modal Price (Rs./Quintal)", xaxis_title="Date", yaxis_title="Modal Price (Rs./Quintal)", template="plotly_white")
28
+ st.plotly_chart(fig, use_container_width=True)
29
 
30
 
31
  def download_button(future_df: pd.DataFrame, key: str):
 
87
  st.subheader("📆 This Day in Previous Years")
88
  st.markdown("<p class='highlight'>This table shows the modal price and total arrivals for this exact day across previous years. It provides a historical perspective to compare against current market conditions. This section examines historical data for the same day in previous years. By analyzing trends for this specific day, you can identify seasonal patterns, supply-demand changes, or any deviations that might warrant closer attention.</p>", unsafe_allow_html=True)
89
  today = latest_date
90
+ previous_years_data = national_data[national_data['Reported Date'].dt.dayofyear == today.dayofyear].copy()
91
 
92
  if not previous_years_data.empty:
93
  previous_years_data['Year'] = previous_years_data['Reported Date'].dt.year.astype(str)
 
121
  st.subheader("📈 Largest Daily Price Changes (Past Year)")
122
  st.markdown("<p class='highlight'>This analysis identifies the most significant daily price changes in the past year. These fluctuations can highlight periods of market volatility, potentially caused by external factors like weather, policy changes, or supply chain disruptions.</p>", unsafe_allow_html=True)
123
  one_year_ago = latest_date - pd.DateOffset(years=1)
124
+ recent_data = national_data[national_data['Reported Date'] >= one_year_ago].copy()
125
  recent_data['Daily Change (%)'] = recent_data['Modal Price (Rs./Quintal)'].pct_change() * 100
126
  largest_changes = recent_data[['Reported Date', 'Modal Price (Rs./Quintal)', 'Daily Change (%)']].nlargest(5, 'Daily Change (%)')
127
  largest_changes['Reported Date'] = largest_changes['Reported Date'].dt.date
 
147
  national_data['Lag (14 Days)'] = national_data['Modal Price (Rs./Quintal)'].shift(14)
148
  national_data['Reported Date'] = national_data['Reported Date'].dt.date
149
  national_data = national_data.sort_values(by='Reported Date', ascending=False)
150
+ st.dataframe(national_data.head(14).reset_index(drop=True), use_container_width=True, height=525)
151
 
152
  editable_spreadsheet()
153
 
streamlit_app.py CHANGED
@@ -36,6 +36,39 @@ def get_cached_collections():
36
  return get_collections()
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  st.markdown("""
40
  <style>
41
  .main {
@@ -95,15 +128,15 @@ if st.session_state.authenticated:
95
  query_filter = {}
96
 
97
  data_type = st.sidebar.radio("Select Data Type", ["Price", "Volume", "Both"])
98
- query_filter["Reported Date"] = {"$gte": datetime.now() - timedelta(days=st.session_state.selected_period)}
99
 
100
  if st.sidebar.button("✨ Let's go!"):
101
  try:
102
- cursor = collection.find(query_filter)
103
- data = list(cursor)
104
- if data:
105
- df = pd.DataFrame(data)
106
- df['Reported Date'] = pd.to_datetime(df['Reported Date'])
 
107
  df_grouped = df.groupby('Reported Date', as_index=False).agg({'Arrivals (Tonnes)': 'sum', 'Modal Price (Rs./Quintal)': 'mean'})
108
  date_range = pd.date_range(start=df_grouped['Reported Date'].min(), end=df_grouped['Reported Date'].max())
109
  df_grouped = df_grouped.set_index('Reported Date').reindex(date_range).rename_axis('Reported Date').reset_index()
@@ -118,19 +151,19 @@ if st.session_state.authenticated:
118
  fig.add_trace(go.Scatter(x=df_grouped['Reported Date'], y=df_grouped['Scaled Price'], mode='lines', name='Scaled Price', line=dict(width=1, color='green')))
119
  fig.add_trace(go.Scatter(x=df_grouped['Reported Date'], y=df_grouped['Scaled Arrivals'], mode='lines', name='Scaled Arrivals', line=dict(width=1, color='blue')))
120
  fig.update_layout(title="Price and Arrivals Trend", xaxis_title='Date', yaxis_title='Scaled Values', template='plotly_white')
121
- st.plotly_chart(fig, width='stretch')
122
  elif data_type == "Price":
123
  import plotly.graph_objects as go
124
  fig = go.Figure()
125
  fig.add_trace(go.Scatter(x=df_grouped['Reported Date'], y=df_grouped['Modal Price (Rs./Quintal)'], mode='lines', name='Modal Price', line=dict(width=1, color='green')))
126
  fig.update_layout(title="Modal Price Trend", xaxis_title='Date', yaxis_title='Price (/Quintall)', template='plotly_white')
127
- st.plotly_chart(fig, width='stretch')
128
  else:
129
  import plotly.graph_objects as go
130
  fig = go.Figure()
131
  fig.add_trace(go.Scatter(x=df_grouped['Reported Date'], y=df_grouped['Arrivals (Tonnes)'], mode='lines', name='Arrivals', line=dict(width=1, color='blue')))
132
  fig.update_layout(title="Arrivals Trend", xaxis_title='Date', yaxis_title='Volume (in Tonnes)', template='plotly_white')
133
- st.plotly_chart(fig, width='stretch')
134
  else:
135
  st.warning("⚠️ No data found for the selected filters.")
136
  except Exception as e:
@@ -147,21 +180,27 @@ if st.session_state.authenticated:
147
  if not IS_PROD and st.button("Train and Forecast"):
148
  query_filter = {"State Name": selected_state}
149
  df = fetch_and_process_data(query_filter)
150
- if sub_timeline == "14 days":
151
- train_and_forecast(df, filter_key, 14)
152
- elif sub_timeline == "1 month":
153
- train_and_forecast(df, filter_key, 30)
 
 
 
154
  else:
155
- train_and_forecast(df, filter_key, 90)
156
  if st.button("Forecast"):
157
  query_filter = {"State Name": selected_state}
158
  df = fetch_and_process_data(query_filter)
159
- if sub_timeline == "14 days":
160
- forecast(df, filter_key, 14)
161
- elif sub_timeline == "1 month":
162
- forecast(df, filter_key, 30)
 
 
 
163
  else:
164
- forecast(df, filter_key, 90)
165
  elif sub_option == "Market":
166
  market_options = ["Rajkot", "Gondal", "Kalburgi", "Amreli"]
167
  selected_market = st.selectbox("Select Market for Model Training", market_options)
@@ -169,47 +208,62 @@ if st.session_state.authenticated:
169
  if not IS_PROD and st.button("Train and Forecast"):
170
  query_filter = {"Market Name": selected_market}
171
  df = fetch_and_process_data(query_filter)
172
- if sub_timeline == "14 days":
173
- train_and_forecast(df, filter_key, 14)
174
- elif sub_timeline == "1 month":
175
- train_and_forecast(df, filter_key, 30)
 
 
 
176
  else:
177
- train_and_forecast(df, filter_key, 90)
178
  elif st.button("Forecast"):
179
  query_filter = {"Market Name": selected_market}
180
  df = fetch_and_process_data(query_filter)
181
- if sub_timeline == "14 days":
182
- forecast(df, filter_key, 14)
183
- elif sub_timeline == "1 month":
184
- forecast(df, filter_key, 30)
 
 
 
185
  else:
186
- forecast(df, filter_key, 90)
187
  elif sub_option == "India":
188
  df = collection_to_dataframe(impExp)
189
  if not IS_PROD and st.button("Train and Forecast"):
190
  query_filter = {}
191
  df = fetch_and_process_data(query_filter)
192
- if sub_timeline == "14 days":
193
- train_and_forecast(df, "India", 14)
194
- elif sub_timeline == "1 month":
195
- train_and_forecast(df, "India", 30)
 
 
 
196
  else:
197
- train_and_forecast(df, "India", 90)
198
  if st.button("Forecast"):
199
  query_filter = {}
200
  df = fetch_and_process_data(query_filter)
201
- if sub_timeline == "14 days":
202
- forecast(df, "India", 14)
203
- elif sub_timeline == "1 month":
204
- forecast(df, "India", 30)
 
 
 
205
  else:
206
- forecast(df, "India", 90)
207
 
208
  elif view_mode == "Statistics":
209
- document = collection.find_one()
210
- df = get_dataframe_from_collection(collection)
211
- from src.agri_predict.plotting import display_statistics
212
- display_statistics(df)
 
 
 
213
  elif view_mode == "Exim":
214
  df = collection_to_dataframe(impExp)
215
  plot_option = st.radio("Select the data to visualize:", ["Import Price", "Import Quantity", "Export Price", "Export Quantity"], horizontal=True)
 
36
  return get_collections()
37
 
38
 
39
+ @st.cache_data(ttl=300) # Cache for 5 minutes
40
+ def load_all_data(_collection):
41
+ """Load and cache all data from MongoDB collection."""
42
+ data = list(_collection.find({}, {'_id': 0, 'Reported Date': 1, 'Modal Price (Rs./Quintal)': 1, 'Arrivals (Tonnes)': 1, 'State Name': 1, 'Market Name': 1}))
43
+ if not data:
44
+ return pd.DataFrame()
45
+ df = pd.DataFrame(data)
46
+ df['Reported Date'] = pd.to_datetime(df['Reported Date'])
47
+ df['Modal Price (Rs./Quintal)'] = pd.to_numeric(df['Modal Price (Rs./Quintal)'], errors='coerce')
48
+ df['Arrivals (Tonnes)'] = pd.to_numeric(df['Arrivals (Tonnes)'], errors='coerce')
49
+ return df
50
+
51
+
52
+ @st.cache_data(ttl=300)
53
+ def get_filtered_data(_collection, state=None, market=None, days=30):
54
+ """Get filtered and cached data based on parameters."""
55
+ query_filter = {"Reported Date": {"$gte": datetime.now() - timedelta(days=days)}}
56
+ if state and state != 'India':
57
+ query_filter["State Name"] = state
58
+ if market:
59
+ query_filter["Market Name"] = market
60
+
61
+ data = list(_collection.find(query_filter, {'_id': 0}))
62
+ if not data:
63
+ return pd.DataFrame()
64
+
65
+ df = pd.DataFrame(data)
66
+ df['Reported Date'] = pd.to_datetime(df['Reported Date'])
67
+ df['Modal Price (Rs./Quintal)'] = pd.to_numeric(df['Modal Price (Rs./Quintal)'], errors='coerce')
68
+ df['Arrivals (Tonnes)'] = pd.to_numeric(df['Arrivals (Tonnes)'], errors='coerce')
69
+ return df
70
+
71
+
72
  st.markdown("""
73
  <style>
74
  .main {
 
128
  query_filter = {}
129
 
130
  data_type = st.sidebar.radio("Select Data Type", ["Price", "Volume", "Both"])
 
131
 
132
  if st.sidebar.button("✨ Let's go!"):
133
  try:
134
+ # Use cached data loading
135
+ state_param = selected_state if selected_state != 'India' else None
136
+ market_param = selected_market if market_wise else None
137
+ df = get_filtered_data(collection, state_param, market_param, st.session_state.selected_period)
138
+
139
+ if not df.empty:
140
  df_grouped = df.groupby('Reported Date', as_index=False).agg({'Arrivals (Tonnes)': 'sum', 'Modal Price (Rs./Quintal)': 'mean'})
141
  date_range = pd.date_range(start=df_grouped['Reported Date'].min(), end=df_grouped['Reported Date'].max())
142
  df_grouped = df_grouped.set_index('Reported Date').reindex(date_range).rename_axis('Reported Date').reset_index()
 
151
  fig.add_trace(go.Scatter(x=df_grouped['Reported Date'], y=df_grouped['Scaled Price'], mode='lines', name='Scaled Price', line=dict(width=1, color='green')))
152
  fig.add_trace(go.Scatter(x=df_grouped['Reported Date'], y=df_grouped['Scaled Arrivals'], mode='lines', name='Scaled Arrivals', line=dict(width=1, color='blue')))
153
  fig.update_layout(title="Price and Arrivals Trend", xaxis_title='Date', yaxis_title='Scaled Values', template='plotly_white')
154
+ st.plotly_chart(fig, use_container_width=True)
155
  elif data_type == "Price":
156
  import plotly.graph_objects as go
157
  fig = go.Figure()
158
  fig.add_trace(go.Scatter(x=df_grouped['Reported Date'], y=df_grouped['Modal Price (Rs./Quintal)'], mode='lines', name='Modal Price', line=dict(width=1, color='green')))
159
  fig.update_layout(title="Modal Price Trend", xaxis_title='Date', yaxis_title='Price (/Quintall)', template='plotly_white')
160
+ st.plotly_chart(fig, use_container_width=True)
161
  else:
162
  import plotly.graph_objects as go
163
  fig = go.Figure()
164
  fig.add_trace(go.Scatter(x=df_grouped['Reported Date'], y=df_grouped['Arrivals (Tonnes)'], mode='lines', name='Arrivals', line=dict(width=1, color='blue')))
165
  fig.update_layout(title="Arrivals Trend", xaxis_title='Date', yaxis_title='Volume (in Tonnes)', template='plotly_white')
166
+ st.plotly_chart(fig, use_container_width=True)
167
  else:
168
  st.warning("⚠️ No data found for the selected filters.")
169
  except Exception as e:
 
180
  if not IS_PROD and st.button("Train and Forecast"):
181
  query_filter = {"State Name": selected_state}
182
  df = fetch_and_process_data(query_filter)
183
+ if df is not None:
184
+ if sub_timeline == "14 days":
185
+ train_and_forecast(df, filter_key, 14)
186
+ elif sub_timeline == "1 month":
187
+ train_and_forecast(df, filter_key, 30)
188
+ else:
189
+ train_and_forecast(df, filter_key, 90)
190
  else:
191
+ st.error("❌ No data available for the selected state.")
192
  if st.button("Forecast"):
193
  query_filter = {"State Name": selected_state}
194
  df = fetch_and_process_data(query_filter)
195
+ if df is not None:
196
+ if sub_timeline == "14 days":
197
+ forecast(df, filter_key, 14)
198
+ elif sub_timeline == "1 month":
199
+ forecast(df, filter_key, 30)
200
+ else:
201
+ forecast(df, filter_key, 90)
202
  else:
203
+ st.error("❌ No data available for the selected state.")
204
  elif sub_option == "Market":
205
  market_options = ["Rajkot", "Gondal", "Kalburgi", "Amreli"]
206
  selected_market = st.selectbox("Select Market for Model Training", market_options)
 
208
  if not IS_PROD and st.button("Train and Forecast"):
209
  query_filter = {"Market Name": selected_market}
210
  df = fetch_and_process_data(query_filter)
211
+ if df is not None:
212
+ if sub_timeline == "14 days":
213
+ train_and_forecast(df, filter_key, 14)
214
+ elif sub_timeline == "1 month":
215
+ train_and_forecast(df, filter_key, 30)
216
+ else:
217
+ train_and_forecast(df, filter_key, 90)
218
  else:
219
+ st.error("❌ No data available for the selected market.")
220
  elif st.button("Forecast"):
221
  query_filter = {"Market Name": selected_market}
222
  df = fetch_and_process_data(query_filter)
223
+ if df is not None:
224
+ if sub_timeline == "14 days":
225
+ forecast(df, filter_key, 14)
226
+ elif sub_timeline == "1 month":
227
+ forecast(df, filter_key, 30)
228
+ else:
229
+ forecast(df, filter_key, 90)
230
  else:
231
+ st.error("❌ No data available for the selected market.")
232
  elif sub_option == "India":
233
  df = collection_to_dataframe(impExp)
234
  if not IS_PROD and st.button("Train and Forecast"):
235
  query_filter = {}
236
  df = fetch_and_process_data(query_filter)
237
+ if df is not None:
238
+ if sub_timeline == "14 days":
239
+ train_and_forecast(df, "India", 14)
240
+ elif sub_timeline == "1 month":
241
+ train_and_forecast(df, "India", 30)
242
+ else:
243
+ train_and_forecast(df, "India", 90)
244
  else:
245
+ st.error("❌ No data available for forecasting.")
246
  if st.button("Forecast"):
247
  query_filter = {}
248
  df = fetch_and_process_data(query_filter)
249
+ if df is not None:
250
+ if sub_timeline == "14 days":
251
+ forecast(df, "India", 14)
252
+ elif sub_timeline == "1 month":
253
+ forecast(df, "India", 30)
254
+ else:
255
+ forecast(df, "India", 90)
256
  else:
257
+ st.error("❌ No data available for forecasting.")
258
 
259
  elif view_mode == "Statistics":
260
+ # Use cached data loading
261
+ df = load_all_data(collection)
262
+ if not df.empty:
263
+ from src.agri_predict.plotting import display_statistics
264
+ display_statistics(df)
265
+ else:
266
+ st.warning("No data available to display statistics.")
267
  elif view_mode == "Exim":
268
  df = collection_to_dataframe(impExp)
269
  plot_option = st.radio("Select the data to visualize:", ["Import Price", "Import Quantity", "Export Price", "Export Quantity"], horizontal=True)