Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
|
| 2 |
import matplotlib.pyplot as plt
|
| 3 |
import numpy as np
|
| 4 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
@@ -9,14 +9,84 @@ import gradio as gr
|
|
| 9 |
from cachetools import cached, TTLCache
|
| 10 |
import cProfile
|
| 11 |
import pstats
|
| 12 |
-
import timesfm # Import the TimeSFM module
|
| 13 |
|
| 14 |
# Global fontsize variable
|
| 15 |
FONT_SIZE = 32
|
| 16 |
|
| 17 |
# Company ticker mapping
|
| 18 |
COMPANY_TICKERS = {
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
}
|
| 21 |
|
| 22 |
# Cache with 1-day TTL
|
|
@@ -56,29 +126,6 @@ def plot_to_image(plt, title, market_cap):
|
|
| 56 |
buf.seek(0)
|
| 57 |
return Image.open(buf)
|
| 58 |
|
| 59 |
-
def generate_forecast(data, horizon=30):
|
| 60 |
-
"""Generate a 30-day close price forecast using TimeSFM."""
|
| 61 |
-
tfm = timesfm.TimesFm(
|
| 62 |
-
context_len=128,
|
| 63 |
-
horizon_len=horizon,
|
| 64 |
-
input_patch_len=32,
|
| 65 |
-
output_patch_len=128,
|
| 66 |
-
num_layers=20,
|
| 67 |
-
model_dims=1280,
|
| 68 |
-
backend='cpu', # or 'gpu'
|
| 69 |
-
)
|
| 70 |
-
tfm.load_from_checkpoint('path_to_checkpoint')
|
| 71 |
-
|
| 72 |
-
# Prepare data for forecasting
|
| 73 |
-
close_prices = data['Close'].values[-128:] # Use the last 128 days for context
|
| 74 |
-
forecast, lower_bound, upper_bound = tfm.predict(close_prices)
|
| 75 |
-
|
| 76 |
-
# Generate future dates
|
| 77 |
-
last_date = data.index[-1]
|
| 78 |
-
future_dates = [last_date + np.timedelta64(i, 'D') for i in range(1, horizon + 1)]
|
| 79 |
-
|
| 80 |
-
return future_dates, forecast, lower_bound, upper_bound
|
| 81 |
-
|
| 82 |
def plot_indicator(data, company_name, ticker, indicator, market_cap):
|
| 83 |
"""Plot selected technical indicator for a single company."""
|
| 84 |
plt.figure(figsize=(16, 10))
|
|
@@ -98,12 +145,6 @@ def plot_indicator(data, company_name, ticker, indicator, market_cap):
|
|
| 98 |
plt.plot(data.index, signal, label='Signal Line')
|
| 99 |
plt.bar(data.index, macd - signal, label='MACD Histogram')
|
| 100 |
plt.ylabel('MACD', fontsize=FONT_SIZE)
|
| 101 |
-
elif indicator == "Forecast":
|
| 102 |
-
future_dates, forecast, lower_bound, upper_bound = generate_forecast(data)
|
| 103 |
-
plt.plot(data.index, data['Close'], label='Historical Close Price')
|
| 104 |
-
plt.plot(future_dates, forecast, label='Forecasted Close Price')
|
| 105 |
-
plt.fill_between(future_dates, lower_bound, upper_bound, color='gray', alpha=0.3, label='Confidence Interval')
|
| 106 |
-
plt.ylabel('Price', fontsize=FONT_SIZE)
|
| 107 |
|
| 108 |
return plot_to_image(plt, f'{company_name} ({ticker}) {indicator}', market_cap)
|
| 109 |
|
|
@@ -137,13 +178,13 @@ def plot_indicators(company_names, indicator_types):
|
|
| 137 |
|
| 138 |
def select_all_indicators(select_all):
|
| 139 |
"""Select or deselect all indicators based on the select_all flag."""
|
| 140 |
-
indicators = ["SMA", "MACD"
|
| 141 |
return indicators if select_all else []
|
| 142 |
|
| 143 |
def launch_gradio_app():
|
| 144 |
"""Launch the Gradio app for interactive plotting."""
|
| 145 |
company_choices = list(COMPANY_TICKERS.keys())
|
| 146 |
-
indicators = ["SMA", "MACD"
|
| 147 |
|
| 148 |
def fetch_and_plot(company_names, indicator_types):
|
| 149 |
images, error_message, total_market_cap = plot_indicators(company_names, indicator_types)
|
|
|
|
| 1 |
+
mport yfinance as yf
|
| 2 |
import matplotlib.pyplot as plt
|
| 3 |
import numpy as np
|
| 4 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
|
|
| 9 |
from cachetools import cached, TTLCache
|
| 10 |
import cProfile
|
| 11 |
import pstats
|
|
|
|
| 12 |
|
| 13 |
# Global fontsize variable
|
| 14 |
FONT_SIZE = 32
|
| 15 |
|
| 16 |
# Company ticker mapping
|
| 17 |
COMPANY_TICKERS = {
|
| 18 |
+
'Google': 'GOOGL',
|
| 19 |
+
'NVIDIA': 'NVDA',
|
| 20 |
+
'Microsoft': 'MSFT',
|
| 21 |
+
'Twilio': 'TWLO',
|
| 22 |
+
'Meta': 'META',
|
| 23 |
+
'Workday': 'WDAY',
|
| 24 |
+
'SAP': 'SAP',
|
| 25 |
+
'Apple': 'AAPL',
|
| 26 |
+
'Tesla': 'TSLA',
|
| 27 |
+
'Amazon': 'AMZN',
|
| 28 |
+
'Oracle': 'ORCL',
|
| 29 |
+
'Uber': 'UBER',
|
| 30 |
+
'ADP': 'ADP',
|
| 31 |
+
'Adobe':'ADBE',
|
| 32 |
+
'Cadence Design Systems': 'CDNS',
|
| 33 |
+
'Salesforce':'CRM',
|
| 34 |
+
'Constellation Software': 'CNSWF',
|
| 35 |
+
'Palo Alto': 'PANW',
|
| 36 |
+
'Autodesk': 'ADSK',
|
| 37 |
+
'Intuit': 'INTU',
|
| 38 |
+
'Confluent': 'CFLT',
|
| 39 |
+
'CrowdStrike': 'CRWD',
|
| 40 |
+
'UIPath':'PATH',
|
| 41 |
+
'Synopsys': 'SNPS',
|
| 42 |
+
'Palantir': 'PLTR',
|
| 43 |
+
'Mongodb': 'MDB',
|
| 44 |
+
'Reddit':'RDDT',
|
| 45 |
+
'Cloudflare': 'NET',
|
| 46 |
+
'DoorDash':'DASH',
|
| 47 |
+
'Datadog': 'DDOG',
|
| 48 |
+
'Duolingo': 'DUOL',
|
| 49 |
+
'HubSpot': 'HUBS',
|
| 50 |
+
'Trade Desk': 'TTD',
|
| 51 |
+
'Samsara': 'IOT',
|
| 52 |
+
'ServiceNow': 'NOW',
|
| 53 |
+
'ANSYS': 'ANSS',
|
| 54 |
+
'Zeta Global Holdings': 'ZETA',
|
| 55 |
+
'Veeva Systems' : 'VEEV',
|
| 56 |
+
'Box Inc': 'BOX',
|
| 57 |
+
'Airbnb': 'ABNB',
|
| 58 |
+
'AppFolio': 'APPF',
|
| 59 |
+
'Fortinet': 'FTNT',
|
| 60 |
+
'Snowflake': 'SNOW',
|
| 61 |
+
'Zscaler': 'ZS',
|
| 62 |
+
'Okta': 'OKTA',
|
| 63 |
+
'Docusign':'DOCU',
|
| 64 |
+
'Elastic NV': 'ESTC',
|
| 65 |
+
'NetApp': 'NTAP',
|
| 66 |
+
'Guidewire': 'GWRE',
|
| 67 |
+
'Monday.com': 'MNDY',
|
| 68 |
+
'Atlassian': 'TEAM',
|
| 69 |
+
'Shopify': 'SHOP',
|
| 70 |
+
'HashiCorp': 'HCP',
|
| 71 |
+
'Qualys': 'QLYS',
|
| 72 |
+
'Gitlab': 'GTLB',
|
| 73 |
+
'JFrog': 'FROG',
|
| 74 |
+
'Procore': 'PCOR',
|
| 75 |
+
'C3.ai': 'AI',
|
| 76 |
+
'Dynatrace': 'DT',
|
| 77 |
+
'Rubrik': 'RBRK',
|
| 78 |
+
'nCino': 'NCNO',
|
| 79 |
+
'SentinelOne': 'S',
|
| 80 |
+
'Klaviyo': 'KVYO',
|
| 81 |
+
'Braze': 'BRZE',
|
| 82 |
+
'Q2': 'QTWO',
|
| 83 |
+
'Tenable': 'TENB',
|
| 84 |
+
'DigitalOcean': 'DOCN',
|
| 85 |
+
'Workiva': 'WK',
|
| 86 |
+
'Smartsheet': 'SMAR',
|
| 87 |
+
'Unity Software': 'U',
|
| 88 |
+
'Squarespace': 'SQSP',
|
| 89 |
+
'Wix.com': 'WIX'
|
| 90 |
}
|
| 91 |
|
| 92 |
# Cache with 1-day TTL
|
|
|
|
| 126 |
buf.seek(0)
|
| 127 |
return Image.open(buf)
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
def plot_indicator(data, company_name, ticker, indicator, market_cap):
|
| 130 |
"""Plot selected technical indicator for a single company."""
|
| 131 |
plt.figure(figsize=(16, 10))
|
|
|
|
| 145 |
plt.plot(data.index, signal, label='Signal Line')
|
| 146 |
plt.bar(data.index, macd - signal, label='MACD Histogram')
|
| 147 |
plt.ylabel('MACD', fontsize=FONT_SIZE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
return plot_to_image(plt, f'{company_name} ({ticker}) {indicator}', market_cap)
|
| 150 |
|
|
|
|
| 178 |
|
| 179 |
def select_all_indicators(select_all):
|
| 180 |
"""Select or deselect all indicators based on the select_all flag."""
|
| 181 |
+
indicators = ["SMA", "MACD"]
|
| 182 |
return indicators if select_all else []
|
| 183 |
|
| 184 |
def launch_gradio_app():
|
| 185 |
"""Launch the Gradio app for interactive plotting."""
|
| 186 |
company_choices = list(COMPANY_TICKERS.keys())
|
| 187 |
+
indicators = ["SMA", "MACD"]
|
| 188 |
|
| 189 |
def fetch_and_plot(company_names, indicator_types):
|
| 190 |
images, error_message, total_market_cap = plot_indicators(company_names, indicator_types)
|