JayLacoma commited on
Commit
7d07255
·
verified ·
1 Parent(s): edeb44c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +595 -0
app.py ADDED
@@ -0,0 +1,595 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yfinance as yf
2
+ import pandas as pd
3
+ import numpy as np
4
+ import plotly.graph_objects as go
5
+ import gradio as gr
6
+ import io
7
+ from PIL import Image
8
+ import matplotlib.pyplot as plt
9
+ from datetime import datetime
10
+ import plotly.express as px
11
+ import warnings
12
+ import timesfm
13
+ from prophet import Prophet
14
+
15
+ class StockDataFetcher:
16
+ """Handles fetching and preprocessing stock data"""
17
+
18
+ @staticmethod
19
+ def fetch_stock_data(ticker, start_date, end_date):
20
+ """Fetch and preprocess stock data"""
21
+ stock_data = yf.download(ticker, start=start_date, end=end_date)
22
+
23
+ # Handle MultiIndex columns if present
24
+ if isinstance(stock_data.columns, pd.MultiIndex):
25
+ stock_data.columns = stock_data.columns.droplevel(level=1)
26
+
27
+ # Standardize column names
28
+ stock_data.columns = ['Close', 'High', 'Low', 'Open', 'Volume']
29
+
30
+ return stock_data
31
+
32
+ # Function for TimesFM forecasting
33
+ def timesfm_forecast(ticker, start_date, end_date):
34
+ try:
35
+ # Fetch historical data using the StockDataFetcher class
36
+ stock_data = StockDataFetcher.fetch_stock_data(ticker, start_date, end_date)
37
+
38
+ # Reset index to have 'Date' as a column
39
+ stock_data.reset_index(inplace=True)
40
+
41
+ # Select relevant columns and rename them
42
+ df = stock_data[['Date', 'Close']].rename(columns={'Date': 'ds', 'Close': 'y'})
43
+
44
+ # Ensure the dates are in datetime format
45
+ df['ds'] = pd.to_datetime(df['ds'])
46
+
47
+ # Add a unique identifier for the time series
48
+ df['unique_id'] = ticker
49
+
50
+ # Initialize the TimesFM model
51
+ tfm = timesfm.TimesFm(
52
+ hparams=timesfm.TimesFmHparams(
53
+ backend="pytorch",
54
+ per_core_batch_size=32,
55
+ horizon_len=30, # Predicting the next 30 days
56
+ input_patch_len=32,
57
+ output_patch_len=128,
58
+ num_layers=50,
59
+ model_dims=1280,
60
+ use_positional_embedding=False,
61
+ ),
62
+ checkpoint=timesfm.TimesFmCheckpoint(
63
+ huggingface_repo_id="google/timesfm-2.0-500m-pytorch"
64
+ ),
65
+ )
66
+
67
+ # Forecast using the prepared DataFrame
68
+ forecast_df = tfm.forecast_on_df(
69
+ inputs=df,
70
+ freq="D", # Daily frequency
71
+ value_name="y",
72
+ num_jobs=-1,
73
+ )
74
+
75
+ # Ensure forecast_df has the correct columns
76
+ forecast_df.rename(columns={"timesfm": "forecast"}, inplace=True)
77
+
78
+ # Create an interactive plot with Plotly
79
+ fig = go.Figure()
80
+
81
+ # Add Actual Prices Line
82
+ fig.add_trace(go.Scatter(x=df["ds"], y=df["y"],
83
+ mode="lines", name="Actual Prices",
84
+ line=dict(color="#00FFFF", width=2))) # Brighter cyan
85
+
86
+ # Add Forecasted Prices Line
87
+ fig.add_trace(go.Scatter(x=forecast_df["ds"], y=forecast_df["forecast"],
88
+ mode="lines", name="Forecasted Prices",
89
+ line=dict(color="#FF00FF", width=2, dash="dash"))) # Brighter magenta
90
+
91
+ # Layout Customization
92
+ fig.update_layout(
93
+ title=f"{ticker} Stock Price Forecast (TimesFM)",
94
+ xaxis_title="Date",
95
+ yaxis_title="Price",
96
+ template="plotly_dark", # Dark Theme
97
+ hovermode="x unified", # Show all values on hover
98
+ legend=dict(bgcolor="rgba(0,0,0,0.8)", bordercolor="white", borderwidth=1),
99
+ plot_bgcolor="#111111", # Slightly lighter than black for contrast
100
+ paper_bgcolor="#111111",
101
+ font=dict(color="white", size=12),
102
+ margin=dict(l=40, r=40, t=50, b=40),
103
+ )
104
+
105
+ # Add grid lines for better readability
106
+ fig.update_xaxes(showgrid=True, gridcolor="rgba(255,255,255,0.1)")
107
+ fig.update_yaxes(showgrid=True, gridcolor="rgba(255,255,255,0.1)")
108
+
109
+ return fig # Return the Plotly figure for Gradio
110
+
111
+ except Exception as e:
112
+ return f"Error: {str(e)}"
113
+
114
+ # Function for Prophet forecasting
115
+ def prophet_forecast(ticker, start_date, end_date):
116
+ try:
117
+ # Download stock market data using the StockDataFetcher class
118
+ df = StockDataFetcher.fetch_stock_data(ticker, start_date, end_date)
119
+
120
+ # Reset the index to get 'Date' back as a column
121
+ df_plot = df.reset_index()
122
+
123
+ # Prepare the data for Prophet
124
+ df1 = df_plot[['Date', 'Close']].rename(columns={'Date': 'ds', 'Close': 'y'})
125
+
126
+ # Fit the model
127
+ m = Prophet()
128
+ m.fit(df1)
129
+
130
+ # Create future dataframe and make predictions
131
+ future = m.make_future_dataframe(periods=30, freq='D')
132
+ forecast = m.predict(future)
133
+
134
+ # Plotting stock closing prices with trend
135
+ fig1 = go.Figure()
136
+
137
+ # Add actual closing prices
138
+ fig1.add_trace(go.Scatter(
139
+ x=df1['ds'],
140
+ y=df1['y'],
141
+ mode='lines',
142
+ name='Actual Price',
143
+ line=dict(color='#36D7B7', width=2)
144
+ ))
145
+
146
+ # Add trend component
147
+ fig1.add_trace(go.Scatter(
148
+ x=forecast['ds'],
149
+ y=forecast['trend'],
150
+ mode='lines',
151
+ name='Trend',
152
+ line=dict(color='#FF6B6B', width=2)
153
+ ))
154
+
155
+ fig1.update_layout(
156
+ title=f'{ticker} Price and Trend',
157
+ plot_bgcolor='#111111',
158
+ paper_bgcolor='#111111',
159
+ font=dict(color='white', size=12),
160
+ margin=dict(l=40, r=40, t=50, b=40),
161
+ xaxis=dict(showgrid=True, gridcolor="rgba(255,255,255,0.1)"),
162
+ yaxis=dict(showgrid=True, gridcolor="rgba(255,255,255,0.1)"),
163
+ legend=dict(bgcolor="rgba(0,0,0,0.8)", bordercolor="white", borderwidth=1)
164
+ )
165
+
166
+ # Plotting forecast with confidence interval
167
+ forecast_40 = forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].tail(40)
168
+ fig2 = go.Figure()
169
+
170
+ # Add forecast line
171
+ fig2.add_trace(go.Scatter(
172
+ x=forecast_40['ds'],
173
+ y=forecast_40['yhat'],
174
+ mode='lines',
175
+ name='Forecast',
176
+ line=dict(color='#FF6B6B', width=2)
177
+ ))
178
+
179
+ # Add confidence interval
180
+ fig2.add_trace(go.Scatter(
181
+ x=forecast_40["ds"].tolist() + forecast_40["ds"].tolist()[::-1],
182
+ y=forecast_40["yhat_upper"].tolist() + forecast_40["yhat_lower"].tolist()[::-1],
183
+ fill="toself",
184
+ fillcolor="rgba(78, 205, 196, 0.2)",
185
+ line=dict(color="rgba(255,255,255,0)"),
186
+ name="Confidence Interval"
187
+ ))
188
+
189
+ fig2.update_layout(
190
+ title=f'{ticker} 30 Days Forecast (Prophet)',
191
+ plot_bgcolor='#111111',
192
+ paper_bgcolor='#111111',
193
+ font=dict(color='white', size=12),
194
+ margin=dict(l=40, r=40, t=50, b=40),
195
+ xaxis=dict(showgrid=True, gridcolor="rgba(255,255,255,0.1)"),
196
+ yaxis=dict(showgrid=True, gridcolor="rgba(255,255,255,0.1)"),
197
+ legend=dict(bgcolor="rgba(0,0,0,0.8)", bordercolor="white", borderwidth=1)
198
+ )
199
+
200
+ # Create components figure
201
+ components_fig = go.Figure()
202
+
203
+ # Add components if they exist in the forecast
204
+ if 'yearly' in forecast.columns:
205
+ yearly_pattern = forecast.iloc[-365:] if len(forecast) > 365 else forecast
206
+ components_fig.add_trace(go.Scatter(
207
+ x=yearly_pattern['ds'],
208
+ y=yearly_pattern['yearly'],
209
+ mode='lines',
210
+ name='Yearly Pattern',
211
+ line=dict(color='#4ECDC4', width=2)
212
+ ))
213
+
214
+
215
+ components_fig.update_layout(
216
+ title=f'{ticker} Forecast Components',
217
+ xaxis_title='Date',
218
+ yaxis_title='Value',
219
+ plot_bgcolor='#111111',
220
+ paper_bgcolor='#111111',
221
+ font=dict(color='white', size=12),
222
+ legend=dict(bgcolor="rgba(0,0,0,0.8)", bordercolor="white", borderwidth=1),
223
+ margin=dict(l=40, r=40, t=50, b=40),
224
+ xaxis=dict(showgrid=True, gridcolor="rgba(255,255,255,0.1)"),
225
+ yaxis=dict(showgrid=True, gridcolor="rgba(255,255,255,0.1)")
226
+ )
227
+
228
+ # For backwards compatibility, still create the matplotlib figure
229
+ try:
230
+ plt.style.use('dark_background')
231
+ fig, ax = plt.subplots(figsize=(10, 8), facecolor='#111111')
232
+
233
+ plt.rcParams.update({
234
+ 'text.color': 'white',
235
+ 'axes.labelcolor': 'white',
236
+ 'axes.edgecolor': 'white',
237
+ 'xtick.color': 'white',
238
+ 'ytick.color': 'white',
239
+ 'grid.color': 'gray',
240
+ 'figure.facecolor': '#111111',
241
+ 'axes.facecolor': '#111111',
242
+ 'savefig.facecolor': '#111111',
243
+ })
244
+
245
+ m.plot_components(forecast, ax=ax)
246
+
247
+ for ax in plt.gcf().get_axes():
248
+ ax.set_facecolor('#111111')
249
+ for spine in ax.spines.values():
250
+ spine.set_color('white')
251
+ ax.tick_params(colors='white')
252
+ ax.title.set_color('white')
253
+ for line in ax.get_lines():
254
+ if line.get_color() == 'b':
255
+ line.set_color('#C678DD')
256
+ else:
257
+ line.set_color('#FF6B6B')
258
+
259
+ plt.tight_layout()
260
+
261
+ buf = io.BytesIO()
262
+ plt.savefig(buf, format='png', facecolor='#111111')
263
+ buf.seek(0)
264
+ plt.close(fig)
265
+
266
+ img = Image.open(buf)
267
+
268
+ return fig1, fig2, components_fig
269
+ except Exception as e:
270
+ print(f"Error with Matplotlib components: {e}")
271
+ return fig1, fig2, components_fig
272
+
273
+ except Exception as e:
274
+ return f"Error: {str(e)}", f"Error: {str(e)}", None
275
+
276
+ # Functions for technical analysis
277
+ def calculate_sma(df, window):
278
+ return df['Close'].rolling(window=window).mean()
279
+
280
+ def calculate_ema(df, window):
281
+ return df['Close'].ewm(span=window, adjust=False).mean()
282
+
283
+ def calculate_macd(df):
284
+ short_ema = df['Close'].ewm(span=12, adjust=False).mean()
285
+ long_ema = df['Close'].ewm(span=26, adjust=False).mean()
286
+ macd = short_ema - long_ema
287
+ signal = macd.ewm(span=9, adjust=False).mean()
288
+ return macd, signal
289
+
290
+ def calculate_rsi(df):
291
+ delta = df['Close'].diff()
292
+ gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
293
+ loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
294
+ rs = gain / loss
295
+ rsi = 100 - (100 / (1 + rs))
296
+ return rsi
297
+
298
+ def calculate_bollinger_bands(df):
299
+ middle_bb = df['Close'].rolling(window=20).mean()
300
+ upper_bb = middle_bb + 2 * df['Close'].rolling(window=20).std()
301
+ lower_bb = middle_bb - 2 * df['Close'].rolling(window=20).std()
302
+ return middle_bb, upper_bb, lower_bb
303
+
304
+ def calculate_stochastic_oscillator(df):
305
+ lowest_low = df['Low'].rolling(window=14).min()
306
+ highest_high = df['High'].rolling(window=14).max()
307
+ slowk = ((df['Close'] - lowest_low) / (highest_high - lowest_low)) * 100
308
+ slowd = slowk.rolling(window=3).mean()
309
+ return slowk, slowd
310
+
311
+ def calculate_cmf(df, window=20):
312
+ mfv = ((df['Close'] - df['Low']) - (df['High'] - df['Close'])) / (df['High'] - df['Low']) * df['Volume']
313
+ cmf = mfv.rolling(window=window).sum() / df['Volume'].rolling(window=window).sum()
314
+ return cmf
315
+
316
+ def calculate_cci(df, window=20):
317
+ """Calculate Commodity Channel Index (CCI)."""
318
+ typical_price = (df['High'] + df['Low'] + df['Close']) / 3
319
+ sma = typical_price.rolling(window=window).mean()
320
+ mean_deviation = (typical_price - sma).abs().rolling(window=window).mean()
321
+ cci = (typical_price - sma) / (0.015 * mean_deviation)
322
+ return cci
323
+
324
+ def generate_trading_signals(df):
325
+ # Calculate various indicators
326
+ df['SMA_30'] = calculate_sma(df, 30)
327
+ df['SMA_100'] = calculate_sma(df, 100)
328
+ df['EMA_12'] = calculate_ema(df, 12)
329
+ df['EMA_26'] = calculate_ema(df, 26)
330
+ df['RSI'] = calculate_rsi(df)
331
+ df['MiddleBB'], df['UpperBB'], df['LowerBB'] = calculate_bollinger_bands(df)
332
+ df['SlowK'], df['SlowD'] = calculate_stochastic_oscillator(df)
333
+ df['CMF'] = calculate_cmf(df)
334
+ df['CCI'] = calculate_cci(df)
335
+
336
+ # Generate trading signals
337
+ df['SMA_Signal'] = np.where(df['SMA_30'] > df['SMA_100'], 1, 0)
338
+
339
+ macd, signal = calculate_macd(df)
340
+ df['MACD_Signal'] = np.select([(macd > signal) & (macd.shift(1) <= signal.shift(1)),
341
+ (macd < signal) & (macd.shift(1) >= signal.shift(1))],[1, -1], default=0)
342
+
343
+ df['RSI_Signal'] = np.where(df['RSI'] < 20, 1, 0)
344
+ df['RSI_Signal'] = np.where(df['RSI'] > 90, -1, df['RSI_Signal'])
345
+
346
+ df['BB_Signal'] = np.where(df['Close'] < df['LowerBB'], 0, 0)
347
+ df['BB_Signal'] = np.where(df['Close'] > df['UpperBB'], -1, df['BB_Signal'])
348
+
349
+ df['Stochastic_Signal'] = np.where((df['SlowK'] < 10) & (df['SlowD'] < 15), 1, 0)
350
+ df['Stochastic_Signal'] = np.where((df['SlowK'] > 90) & (df['SlowD'] > 85), -1, df['Stochastic_Signal'])
351
+
352
+ df['CMF_Signal'] = np.where(df['CMF'] > 0.3, -1, np.where(df['CMF'] < -0.3, 1, 0))
353
+
354
+ df['CCI_Signal'] = np.where(df['CCI'] < -180, 1, 0)
355
+ df['CCI_Signal'] = np.where(df['CCI'] > 150, -1, df['CCI_Signal'])
356
+
357
+ # Combined signal for stronger buy/sell points
358
+ df['Combined_Signal'] = df[['RSI_Signal', 'BB_Signal',
359
+ 'Stochastic_Signal', 'CMF_Signal',
360
+ 'CCI_Signal']].sum(axis=1)
361
+
362
+ return df
363
+
364
+ def plot_combined_signals(df, ticker):
365
+ # Create a figure
366
+ fig = go.Figure()
367
+
368
+ # Add closing price trace
369
+ fig.add_trace(go.Scatter(
370
+ x=df.index, y=df['Close'],
371
+ mode='lines',
372
+ name='Closing Price',
373
+ line=dict(color='#36D7B7', width=2) # Brighter pink
374
+ ))
375
+
376
+ # Add buy signals
377
+ buy_signals = df[df['Combined_Signal'] >= 3]
378
+ fig.add_trace(go.Scatter(
379
+ x=buy_signals.index, y=buy_signals['Close'],
380
+ mode='markers',
381
+ marker=dict(symbol='triangle-up', size=12, color='green'),
382
+ name='Buy Signal'
383
+ ))
384
+
385
+ # Add sell signals
386
+ sell_signals = df[df['Combined_Signal'] <= -4]
387
+ fig.add_trace(go.Scatter(
388
+ x=sell_signals.index, y=sell_signals['Close'],
389
+ mode='markers',
390
+ marker=dict(symbol='triangle-down', size=12, color='red'),
391
+ name='Sell Signal'
392
+ ))
393
+
394
+ # Combined signal trace
395
+ fig.add_trace(go.Scatter(
396
+ x=df.index, y=df['Combined_Signal'],
397
+ mode='lines',
398
+ name='Combined Signal',
399
+ line=dict(color='#36A2EB', width=1), # Blue
400
+ yaxis='y2'
401
+ ))
402
+
403
+ # Update layout
404
+ fig.update_layout(
405
+ title=f'{ticker}: Stock Price and Combined Trading Signal (Last 120 Days)',
406
+ xaxis=dict(
407
+ title='Date',
408
+ showgrid=True,
409
+ gridcolor="rgba(255,255,255,0.1)"
410
+ ),
411
+ yaxis=dict(
412
+ title='Price',
413
+ side='left',
414
+ showgrid=True,
415
+ gridcolor="rgba(255,255,255,0.1)"
416
+ ),
417
+ yaxis2=dict(
418
+ title='Combined Signal',
419
+ overlaying='y',
420
+ side='right',
421
+ showgrid=False
422
+ ),
423
+ plot_bgcolor='#111111',
424
+ paper_bgcolor='#111111',
425
+ font=dict(color='white', size=12),
426
+ legend=dict(
427
+ bgcolor="rgba(0,0,0,0.8)",
428
+ bordercolor="white",
429
+ borderwidth=1
430
+ ),
431
+ margin=dict(l=40, r=40, t=50, b=40)
432
+ )
433
+
434
+ return fig
435
+
436
+ def plot_individual_signals(df, ticker):
437
+ # Create a figure
438
+ fig = go.Figure()
439
+
440
+ # Add closing price line
441
+ fig.add_trace(go.Scatter(
442
+ x=df.index, y=df['Close'],
443
+ mode='lines',
444
+ name='Closing Price',
445
+ line=dict(color='#D4B2FF', width=2) # Brighter pink #D4B2FF
446
+ ))
447
+
448
+ # Define colors for different signals
449
+
450
+
451
+
452
+
453
+ signal_colors = {
454
+ 'RSI_Signal': {'buy': '#36D7B7', 'sell': 'red'}, # Light purple / Pale butter
455
+ 'BB_Signal': {'buy': '#36D7B7', 'sell': 'red'}, # Purple / Chiffon yellow
456
+ 'Stochastic_Signal': {'buy': '#36D7B7', 'sell': 'red'}, # Purple / Corn silk
457
+ 'CMF_Signal': {'buy': '#36D7B7', 'sell': 'red'}, # Deep purple / Lemon chiffon
458
+ 'CCI_Signal': {'buy': '#36D7B7', 'sell': 'red'} # Dark purple / Soft maize
459
+ }
460
+
461
+
462
+
463
+ # Add buy/sell signals for each indicator
464
+ signal_names = ['RSI_Signal', 'BB_Signal',
465
+ 'Stochastic_Signal', 'CMF_Signal',
466
+ 'CCI_Signal']
467
+
468
+ for signal in signal_names:
469
+ buy_signals = df[df[signal] == 1]
470
+ sell_signals = df[df[signal] == -1]
471
+
472
+ fig.add_trace(go.Scatter(
473
+ x=buy_signals.index, y=buy_signals['Close'],
474
+ mode='markers',
475
+ marker=dict(
476
+ symbol='triangle-up',
477
+ size=12,
478
+ color=signal_colors[signal]['buy']
479
+ ),
480
+ name=f'{signal} Buy Signal'
481
+ ))
482
+
483
+ fig.add_trace(go.Scatter(
484
+ x=sell_signals.index, y=sell_signals['Close'],
485
+ mode='markers',
486
+ marker=dict(
487
+ symbol='triangle-down',
488
+ size=12,
489
+ color=signal_colors[signal]['sell']
490
+ ),
491
+ name=f'{signal} Sell Signal'
492
+ ))
493
+
494
+ fig.update_layout(
495
+ title=f'{ticker}: Individual Trading Signals',
496
+ xaxis=dict(
497
+ title='Date',
498
+ showgrid=True,
499
+ gridcolor="rgba(255,255,255,0.1)"
500
+ ),
501
+ yaxis=dict(
502
+ title='Price',
503
+ side='left',
504
+ showgrid=True,
505
+ gridcolor="rgba(255,255,255,0.1)"
506
+ ),
507
+ plot_bgcolor='#111111',
508
+ paper_bgcolor='#111111',
509
+ font=dict(color='white', size=12),
510
+ legend=dict(
511
+ bgcolor="rgba(0,0,0,0.8)",
512
+ bordercolor="white",
513
+ borderwidth=1
514
+ ),
515
+ margin=dict(l=40, r=40, t=50, b=40)
516
+ )
517
+
518
+ return fig
519
+
520
+ def technical_analysis(ticker, start_date, end_date):
521
+ try:
522
+ # Download stock data using the StockDataFetcher class
523
+ df = StockDataFetcher.fetch_stock_data(ticker, start_date, end_date)
524
+
525
+ # Generate signals
526
+ df = generate_trading_signals(df)
527
+
528
+ # Last 120 days for plotting
529
+ df_last_120 = df.tail(120)
530
+
531
+ # Plot combined signals
532
+ fig_signals = plot_combined_signals(df_last_120, ticker)
533
+
534
+ # Plot individual signals
535
+ fig_individual_signals = plot_individual_signals(df_last_120, ticker)
536
+
537
+ return fig_signals, fig_individual_signals
538
+
539
+ except Exception as e:
540
+ return f"Error: {str(e)}", f"Error: {str(e)}"
541
+
542
+ # Create Gradio interface
543
+ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
544
+ gr.Markdown("# Advanced Stock Analysis & Forecasting App")
545
+ gr.Markdown("Enter a stock ticker, start date, and end date to analyze and forecast stock prices.")
546
+
547
+ with gr.Row():
548
+ ticker_input = gr.Textbox(label="Enter Stock Ticker", value="NVDA")
549
+ start_date_input = gr.Textbox(label="Enter Start Date (YYYY-MM-DD)", value="2022-01-01")
550
+ end_date_input = gr.Textbox(label="Enter End Date (YYYY-MM-DD)", value="2026-01-01")
551
+
552
+ # Create tabs for different analysis types
553
+ with gr.Tabs() as tabs:
554
+
555
+ with gr.TabItem("Technical Analysis"):
556
+ analysis_button = gr.Button("Generate Technical Analysis")
557
+
558
+ individual_signals = gr.Plot(label="Individual Trading Signals")
559
+ combined_signals = gr.Plot(label="Combined Trading Signals")
560
+
561
+ # Connect button to function
562
+ analysis_button.click(
563
+ technical_analysis,
564
+ inputs=[ticker_input, start_date_input, end_date_input],
565
+ outputs=[combined_signals, individual_signals]
566
+ )
567
+
568
+ with gr.TabItem("TimesFM Forecast"):
569
+ timesfm_button = gr.Button("Generate TimesFM Forecast")
570
+ timesfm_plot = gr.Plot(label="TimesFM Stock Price Forecast")
571
+
572
+ # Connect button to function
573
+ timesfm_button.click(
574
+ timesfm_forecast,
575
+ inputs=[ticker_input, start_date_input, end_date_input],
576
+ outputs=timesfm_plot
577
+ )
578
+
579
+ with gr.TabItem("Prophet Forecast"):
580
+ prophet_button = gr.Button("Generate Prophet Forecast")
581
+ prophet_recent_plot = gr.Plot(label="Recent Stock Prices")
582
+ prophet_forecast_plot = gr.Plot(label="Prophet 30-Day Forecast")
583
+ prophet_components = gr.Plot(label="Forecast Components") # Changed from gr.Image to gr.Plot
584
+
585
+ # Connect button to function
586
+ prophet_button.click(
587
+ prophet_forecast,
588
+ inputs=[ticker_input, start_date_input, end_date_input],
589
+ outputs=[prophet_recent_plot, prophet_forecast_plot, prophet_components]
590
+ )
591
+
592
+
593
+
594
+ # Launch the app
595
+ demo.launch()