LeonceNsh commited on
Commit
92d2642
·
verified ·
1 Parent(s): 0880026

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -34
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import yfinance as yf
2
  import matplotlib.pyplot as plt
3
  import numpy as np
4
  from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -9,14 +9,84 @@ import gradio as gr
9
  from cachetools import cached, TTLCache
10
  import cProfile
11
  import pstats
12
- import timesfm # Import the TimeSFM module
13
 
14
  # Global fontsize variable
15
  FONT_SIZE = 32
16
 
17
  # Company ticker mapping
18
  COMPANY_TICKERS = {
19
- # [same as provided]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  }
21
 
22
  # Cache with 1-day TTL
@@ -56,29 +126,6 @@ def plot_to_image(plt, title, market_cap):
56
  buf.seek(0)
57
  return Image.open(buf)
58
 
59
- def generate_forecast(data, horizon=30):
60
- """Generate a 30-day close price forecast using TimeSFM."""
61
- tfm = timesfm.TimesFm(
62
- context_len=128,
63
- horizon_len=horizon,
64
- input_patch_len=32,
65
- output_patch_len=128,
66
- num_layers=20,
67
- model_dims=1280,
68
- backend='cpu', # or 'gpu'
69
- )
70
- tfm.load_from_checkpoint('path_to_checkpoint')
71
-
72
- # Prepare data for forecasting
73
- close_prices = data['Close'].values[-128:] # Use the last 128 days for context
74
- forecast, lower_bound, upper_bound = tfm.predict(close_prices)
75
-
76
- # Generate future dates
77
- last_date = data.index[-1]
78
- future_dates = [last_date + np.timedelta64(i, 'D') for i in range(1, horizon + 1)]
79
-
80
- return future_dates, forecast, lower_bound, upper_bound
81
-
82
  def plot_indicator(data, company_name, ticker, indicator, market_cap):
83
  """Plot selected technical indicator for a single company."""
84
  plt.figure(figsize=(16, 10))
@@ -98,12 +145,6 @@ def plot_indicator(data, company_name, ticker, indicator, market_cap):
98
  plt.plot(data.index, signal, label='Signal Line')
99
  plt.bar(data.index, macd - signal, label='MACD Histogram')
100
  plt.ylabel('MACD', fontsize=FONT_SIZE)
101
- elif indicator == "Forecast":
102
- future_dates, forecast, lower_bound, upper_bound = generate_forecast(data)
103
- plt.plot(data.index, data['Close'], label='Historical Close Price')
104
- plt.plot(future_dates, forecast, label='Forecasted Close Price')
105
- plt.fill_between(future_dates, lower_bound, upper_bound, color='gray', alpha=0.3, label='Confidence Interval')
106
- plt.ylabel('Price', fontsize=FONT_SIZE)
107
 
108
  return plot_to_image(plt, f'{company_name} ({ticker}) {indicator}', market_cap)
109
 
@@ -137,13 +178,13 @@ def plot_indicators(company_names, indicator_types):
137
 
138
  def select_all_indicators(select_all):
139
  """Select or deselect all indicators based on the select_all flag."""
140
- indicators = ["SMA", "MACD", "Forecast"]
141
  return indicators if select_all else []
142
 
143
  def launch_gradio_app():
144
  """Launch the Gradio app for interactive plotting."""
145
  company_choices = list(COMPANY_TICKERS.keys())
146
- indicators = ["SMA", "MACD", "Forecast"]
147
 
148
  def fetch_and_plot(company_names, indicator_types):
149
  images, error_message, total_market_cap = plot_indicators(company_names, indicator_types)
 
1
+ mport yfinance as yf
2
  import matplotlib.pyplot as plt
3
  import numpy as np
4
  from concurrent.futures import ThreadPoolExecutor, as_completed
 
9
  from cachetools import cached, TTLCache
10
  import cProfile
11
  import pstats
 
12
 
13
  # Global fontsize variable
14
  FONT_SIZE = 32
15
 
16
  # Company ticker mapping
17
  COMPANY_TICKERS = {
18
+ 'Google': 'GOOGL',
19
+ 'NVIDIA': 'NVDA',
20
+ 'Microsoft': 'MSFT',
21
+ 'Twilio': 'TWLO',
22
+ 'Meta': 'META',
23
+ 'Workday': 'WDAY',
24
+ 'SAP': 'SAP',
25
+ 'Apple': 'AAPL',
26
+ 'Tesla': 'TSLA',
27
+ 'Amazon': 'AMZN',
28
+ 'Oracle': 'ORCL',
29
+ 'Uber': 'UBER',
30
+ 'ADP': 'ADP',
31
+ 'Adobe':'ADBE',
32
+ 'Cadence Design Systems': 'CDNS',
33
+ 'Salesforce':'CRM',
34
+ 'Constellation Software': 'CNSWF',
35
+ 'Palo Alto': 'PANW',
36
+ 'Autodesk': 'ADSK',
37
+ 'Intuit': 'INTU',
38
+ 'Confluent': 'CFLT',
39
+ 'CrowdStrike': 'CRWD',
40
+ 'UIPath':'PATH',
41
+ 'Synopsys': 'SNPS',
42
+ 'Palantir': 'PLTR',
43
+ 'Mongodb': 'MDB',
44
+ 'Reddit':'RDDT',
45
+ 'Cloudflare': 'NET',
46
+ 'DoorDash':'DASH',
47
+ 'Datadog': 'DDOG',
48
+ 'Duolingo': 'DUOL',
49
+ 'HubSpot': 'HUBS',
50
+ 'Trade Desk': 'TTD',
51
+ 'Samsara': 'IOT',
52
+ 'ServiceNow': 'NOW',
53
+ 'ANSYS': 'ANSS',
54
+ 'Zeta Global Holdings': 'ZETA',
55
+ 'Veeva Systems' : 'VEEV',
56
+ 'Box Inc': 'BOX',
57
+ 'Airbnb': 'ABNB',
58
+ 'AppFolio': 'APPF',
59
+ 'Fortinet': 'FTNT',
60
+ 'Snowflake': 'SNOW',
61
+ 'Zscaler': 'ZS',
62
+ 'Okta': 'OKTA',
63
+ 'Docusign':'DOCU',
64
+ 'Elastic NV': 'ESTC',
65
+ 'NetApp': 'NTAP',
66
+ 'Guidewire': 'GWRE',
67
+ 'Monday.com': 'MNDY',
68
+ 'Atlassian': 'TEAM',
69
+ 'Shopify': 'SHOP',
70
+ 'HashiCorp': 'HCP',
71
+ 'Qualys': 'QLYS',
72
+ 'Gitlab': 'GTLB',
73
+ 'JFrog': 'FROG',
74
+ 'Procore': 'PCOR',
75
+ 'C3.ai': 'AI',
76
+ 'Dynatrace': 'DT',
77
+ 'Rubrik': 'RBRK',
78
+ 'nCino': 'NCNO',
79
+ 'SentinelOne': 'S',
80
+ 'Klaviyo': 'KVYO',
81
+ 'Braze': 'BRZE',
82
+ 'Q2': 'QTWO',
83
+ 'Tenable': 'TENB',
84
+ 'DigitalOcean': 'DOCN',
85
+ 'Workiva': 'WK',
86
+ 'Smartsheet': 'SMAR',
87
+ 'Unity Software': 'U',
88
+ 'Squarespace': 'SQSP',
89
+ 'Wix.com': 'WIX'
90
  }
91
 
92
  # Cache with 1-day TTL
 
126
  buf.seek(0)
127
  return Image.open(buf)
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  def plot_indicator(data, company_name, ticker, indicator, market_cap):
130
  """Plot selected technical indicator for a single company."""
131
  plt.figure(figsize=(16, 10))
 
145
  plt.plot(data.index, signal, label='Signal Line')
146
  plt.bar(data.index, macd - signal, label='MACD Histogram')
147
  plt.ylabel('MACD', fontsize=FONT_SIZE)
 
 
 
 
 
 
148
 
149
  return plot_to_image(plt, f'{company_name} ({ticker}) {indicator}', market_cap)
150
 
 
178
 
179
  def select_all_indicators(select_all):
180
  """Select or deselect all indicators based on the select_all flag."""
181
+ indicators = ["SMA", "MACD"]
182
  return indicators if select_all else []
183
 
184
  def launch_gradio_app():
185
  """Launch the Gradio app for interactive plotting."""
186
  company_choices = list(COMPANY_TICKERS.keys())
187
+ indicators = ["SMA", "MACD"]
188
 
189
  def fetch_and_plot(company_names, indicator_types):
190
  images, error_message, total_market_cap = plot_indicators(company_names, indicator_types)