Navya-Sree committed on
Commit
e706da8
·
verified ·
1 Parent(s): dc5f24a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +298 -0
app.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import plotly.graph_objects as go
5
+ from plotly.subplots import make_subplots
6
+ import matplotlib.pyplot as plt
7
+ import seaborn as sns
8
+ from datetime import datetime, timedelta
9
+ import yaml
10
+ import os
11
+ import sys
12
+
13
+ # Add src to path
14
+ sys.path.append('src')
15
+
16
+ from src.data_processing.processor import AdvancedDataProcessor
17
+ from src.modeling.advanced_models import AdvancedModelTrainer
18
+ from src.agents.genai_integration import ForecastingAIAssistant
19
+
# Page configuration: collected in one mapping and applied in a single call.
_PAGE_SETTINGS = {
    "page_title": "Advanced Forecasting",
    "page_icon": "📈",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)

# Custom CSS: classes used by the headers and highlight boxes rendered below.
_CUSTOM_CSS = """
<style>
.main-header {font-size: 3rem; color: #1f77b4;}
.section-header {font-size: 2rem; color: #ff7f0e; margin-top: 2rem;}
.highlight {background-color: #f7f7f7; padding: 15px; border-radius: 5px; margin: 10px 0;}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
36
+
# Load configuration
@st.cache_resource
def load_config():
    """Parse ``config/config.yaml`` and return its content.

    The function is decorated with ``st.cache_resource`` so Streamlit can
    reuse the parsed result instead of re-reading the file on every rerun.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file. The rest of the
        script indexes it with the ``'data_processing'`` and ``'modeling'``
        keys.

    Raises:
        OSError: if the configuration file cannot be opened.
        yaml.YAMLError: if the file is not valid YAML.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open('config/config.yaml', 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


config = load_config()
44
+
# Initialize components
@st.cache_resource
def init_components():
    """Build the processing, training and (optional) AI helper objects.

    Returns:
        A ``(processor, trainer, ai_assistant)`` tuple. ``ai_assistant`` is
        ``None`` when the ``OPENAI_API_KEY`` environment variable is unset
        or empty.
    """
    data_processor = AdvancedDataProcessor(config['data_processing'])
    model_trainer = AdvancedModelTrainer(config['modeling'])

    # The assistant is only created when an OpenAI API key is available.
    api_key = os.getenv('OPENAI_API_KEY')
    if api_key:
        assistant = ForecastingAIAssistant(api_key)
    else:
        assistant = None

    return data_processor, model_trainer, assistant


processor, trainer, ai_assistant = init_components()
58
+
# App title
st.markdown('<h1 class="main-header">Advanced Time Series Forecasting</h1>', unsafe_allow_html=True)
st.write("""
A comprehensive forecasting system with advanced features including deep learning models,
automated feature engineering, and AI-powered insights.
""")

# Sidebar
st.sidebar.title("Configuration")
st.sidebar.header("Data Input")

# Data input options
data_option = st.sidebar.radio(
    "Choose data source:",
    ["Use example data", "Upload your own data"]
)

# ``df`` stays None until a frame with a datetime 'date' column and a numeric
# 'value' column is available; the main section only renders when it is set.
df = None
if data_option == "Use example data":
    st.sidebar.info("Using example sales data")
    df = pd.read_csv('assets/example_data.csv')
    df['date'] = pd.to_datetime(df['date'])
else:
    uploaded_file = st.sidebar.file_uploader(
        "Upload your time series data (CSV)",
        type=['csv']
    )
    if uploaded_file is not None:
        try:
            uploaded_df = pd.read_csv(uploaded_file)
        except (pd.errors.EmptyDataError, pd.errors.ParserError) as exc:
            st.sidebar.error(f"Could not read the uploaded CSV: {exc}")
            uploaded_df = None

        if uploaded_df is not None and len(uploaded_df.columns) < 2:
            st.sidebar.error("The uploaded file needs at least a date column and a value column.")
            uploaded_df = None

        if uploaded_df is not None:
            columns = list(uploaded_df.columns)
            date_col = st.sidebar.selectbox("Select date column", columns)
            # Default the value column to the second column; if both
            # selectboxes started on the first column the rename below would
            # lose the 'date' column right after upload.
            value_col = st.sidebar.selectbox("Select value column", columns, index=1)

            if date_col == value_col:
                st.sidebar.error("Date and value columns must be different.")
            else:
                # Drop pre-existing 'date'/'value' columns that were NOT
                # selected, otherwise the rename creates duplicate names.
                clashing = [
                    name for name in ('date', 'value')
                    if name in columns and name not in (date_col, value_col)
                ]
                if clashing:
                    st.sidebar.warning(
                        f"Ignoring existing column(s) {clashing} that clash with the selected ones."
                    )
                    uploaded_df = uploaded_df.drop(columns=clashing)

                try:
                    uploaded_df[date_col] = pd.to_datetime(uploaded_df[date_col])
                except (ValueError, TypeError) as exc:
                    st.sidebar.error(f"Could not parse '{date_col}' as dates: {exc}")
                else:
                    renamed = uploaded_df.rename(columns={date_col: 'date', value_col: 'value'})
                    if pd.api.types.is_numeric_dtype(renamed['value']):
                        df = renamed
                    else:
                        st.sidebar.error(f"Column '{value_col}' must contain numeric values.")
92
+
# Main content: rendered only once a data frame with 'date' and 'value'
# columns has been loaded by the sidebar section.
if df is not None:
    # ---- Data overview ----
    st.markdown('<h2 class="section-header">Data Overview</h2>', unsafe_allow_html=True)

    col1, col2, col3, col4 = st.columns(4)
    col1.metric("Total Records", len(df))
    col2.metric("Date Range", f"{df['date'].min().date()} to {df['date'].max().date()}")
    col3.metric("Average Value", f"{df['value'].mean():.2f}")
    # NOTE(review): the frequency label is hard-coded and is not derived from
    # the data; confirm it is appropriate for uploaded files.
    col4.metric("Data Frequency", "Daily")

    # Data preview
    st.dataframe(df.head(10))

    # ---- Raw data plot ----
    st.markdown('<h2 class="section-header">Data Visualization</h2>', unsafe_allow_html=True)

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=df['date'], y=df['value'], mode='lines', name='Value'))
    fig.update_layout(
        title='Time Series Data',
        xaxis_title='Date',
        yaxis_title='Value',
        height=500
    )
    st.plotly_chart(fig, use_container_width=True)

    # ---- Feature engineering ----
    st.markdown('<h2 class="section-header">Feature Engineering</h2>', unsafe_allow_html=True)

    if st.button("Generate Features"):
        with st.spinner("Creating advanced features..."):
            df_engineered = processor.engineer_features(df, 'date', 'value')

            st.success(f"Created {len(processor.feature_columns)} features!")

            # Show feature importance (simplified)
            st.write("Top 10 features by correlation with target:")
            # Restrict to numeric columns so date/text columns cannot break
            # the correlation computation.
            numeric_features = df_engineered.select_dtypes(include='number')
            correlations = numeric_features.corr()['value'].abs().sort_values(ascending=False)
            # Remove the target by label rather than by position: another
            # feature with a perfect correlation could otherwise sort first.
            top_features = correlations.drop(labels='value', errors='ignore').head(10)

            corr_fig, ax = plt.subplots(figsize=(10, 6))
            top_features.plot(kind='bar', ax=ax)
            ax.set_title('Top Feature Correlations with Target')
            ax.set_ylabel('Absolute Correlation')
            st.pyplot(corr_fig)
            # Release the figure; the script reruns on every interaction and
            # unclosed figures would accumulate.
            plt.close(corr_fig)

            # Prepare data for modeling
            X, y = processor.create_sequences(
                df_engineered, 'value', processor.feature_columns, 30, 7
            )

            st.session_state.X = X
            st.session_state.y = y
            st.session_state.df_engineered = df_engineered

    # ---- Model training (available once sequences exist) ----
    if 'X' in st.session_state:
        st.markdown('<h2 class="section-header">Model Training</h2>', unsafe_allow_html=True)

        model_option = st.selectbox(
            "Select model type:",
            ["LSTM", "Prophet", "ARIMA", "Ensemble"]
        )

        if st.button("Train Model"):
            model = None
            with st.spinner(f"Training {model_option} model..."):
                if model_option == "LSTM":
                    seq_X = st.session_state.X
                    seq_y = st.session_state.y
                    # Hold out the last 100 sequences for validation, but
                    # never more than 20% of the data, so the training split
                    # cannot end up empty on short series.
                    val_size = min(100, len(seq_X) // 5)
                    if val_size < 1:
                        st.error("Not enough sequences to train an LSTM model.")
                    else:
                        model = trainer.train_lstm(
                            seq_X[:-val_size],
                            seq_y[:-val_size],
                            seq_X[-val_size:],
                            seq_y[-val_size:]
                        )
                elif model_option == "Prophet":
                    model = trainer.train_prophet(df, 'date', 'value')
                elif model_option == "ARIMA":
                    model = trainer.train_auto_arima(df['value'])
                else:
                    st.warning("Ensemble model not implemented in this demo")

            # Compare against None explicitly: a trained model object is not
            # guaranteed to be truthy.
            if model is not None:
                st.session_state.model = model
                st.session_state.model_type = model_option.lower()
                st.success(f"{model_option} model trained successfully!")

    # ---- Forecasting (available once a model is stored) ----
    if 'model' in st.session_state:
        st.markdown('<h2 class="section-header">Forecasting</h2>', unsafe_allow_html=True)

        forecast_days = st.slider("Forecast horizon (days)", 7, 90, 30)

        if st.button("Generate Forecast"):
            with st.spinner("Generating forecast..."):
                # For demo purposes, we'll create a simple forecast: the mean
                # of the last 30 observations plus noise and a small trend.
                # The stored model is not used here.
                last_values = df['value'].values[-30:]
                forecast = np.full(forecast_days, last_values.mean())

                # Add some randomness to simulate a forecast
                np.random.seed(42)
                noise = np.random.normal(0, df['value'].std() * 0.1, forecast_days)
                trend = np.linspace(0, forecast_days * 0.01, forecast_days)
                forecast = forecast + noise + trend

                # Create forecast dates
                last_date = df['date'].max()
                forecast_dates = [last_date + timedelta(days=i) for i in range(1, forecast_days + 1)]

                # Plot forecast
                fig = go.Figure()
                fig.add_trace(go.Scatter(
                    x=df['date'],
                    y=df['value'],
                    mode='lines',
                    name='Historical Data'
                ))
                fig.add_trace(go.Scatter(
                    x=forecast_dates,
                    y=forecast,
                    mode='lines',
                    name='Forecast',
                    line=dict(dash='dash')
                ))

                # Band of +/- half a standard deviation around the forecast.
                upper_bound = forecast + df['value'].std() * 0.5
                lower_bound = forecast - df['value'].std() * 0.5

                # Trace the upper bound forwards and the lower bound
                # backwards so 'toself' fills the area between them.
                fig.add_trace(go.Scatter(
                    x=forecast_dates + forecast_dates[::-1],
                    y=np.concatenate([upper_bound, lower_bound[::-1]]),
                    fill='toself',
                    fillcolor='rgba(0,100,80,0.2)',
                    line=dict(color='rgba(255,255,255,0)'),
                    name='Confidence Interval'
                ))

                fig.update_layout(
                    title=f'{forecast_days}-Day Forecast',
                    xaxis_title='Date',
                    yaxis_title='Value',
                    height=500
                )

                st.plotly_chart(fig, use_container_width=True)

                # Display forecast values
                forecast_df = pd.DataFrame({
                    'Date': forecast_dates,
                    'Forecast': forecast,
                    'Lower Bound': lower_bound,
                    'Upper Bound': upper_bound
                })

                st.dataframe(forecast_df)

    # ---- AI insights (needs an assistant and a stored model) ----
    if ai_assistant is not None and 'model' in st.session_state:
        st.markdown('<h2 class="section-header">AI-Powered Insights</h2>', unsafe_allow_html=True)

        if st.button("Generate AI Insights"):
            with st.spinner("Generating AI insights..."):
                # Prepare data for AI analysis
                data_summary = {
                    'period': f"{df['date'].min().date()} to {df['date'].max().date()}",
                    'data_points': len(df),
                    'mean': df['value'].mean(),
                    'std': df['value'].std(),
                    'trend': 'upward' if df['value'].iloc[-1] > df['value'].iloc[0] else 'downward'
                }

                # Generate interpretation
                interpretation = ai_assistant.generate_forecast_interpretation(
                    data_summary,
                    {'model_type': st.session_state.model_type},
                    {'rmse': 0.05, 'mae': 0.03}  # Placeholder metrics
                )

                st.markdown('<div class="highlight">', unsafe_allow_html=True)
                st.write("### AI Interpretation")
                st.write(interpretation)
                st.markdown('</div>', unsafe_allow_html=True)

                # Generate recommendations. ``forecast_days`` comes from the
                # slider in the forecasting section, which is rendered
                # whenever a model is stored (the same condition as here).
                recommendations = ai_assistant.generate_business_recommendations(
                    "Time series forecasting for business planning",
                    {'forecast_horizon': forecast_days, 'confidence': 0.8},
                    df['value']
                )

                st.markdown('<div class="highlight">', unsafe_allow_html=True)
                st.write("### AI Recommendations")
                st.write(recommendations)
                st.markdown('</div>', unsafe_allow_html=True)

else:
    st.info("Please load data to get started. Use the sidebar to upload a file or use example data.")
291
+
# Footer: a horizontal rule followed by a centred attribution line.
_FOOTER_HTML = """
<div style="text-align: center;">
<p>Advanced Time Series Forecasting System | Built with Streamlit</p>
</div>
"""
st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)