# dashboard/src/datasource.py
# Commit 4fef3fe by Arrechenash: "Support more timeframes in charts"
import os
import duckdb
import requests_cache
import streamlit as st
from alpaca.data.historical import StockHistoricalDataClient
from alpaca.data.requests import StockBarsRequest
from alpaca.data.timeframe import TimeFrame, TimeFrameUnit
from dotenv import load_dotenv
# Pull variables from a local .env file (if one exists) into the process
# environment before the credentials are read below.
load_dotenv()
ALPACA_API_KEY = os.getenv("ALPACA_API_KEY")
ALPACA_SECRET_KEY = os.getenv("ALPACA_SECRET_KEY")

# Both credentials are required. If either is unset or empty, show an error on
# the page and halt this script run via st.stop().
if not (ALPACA_API_KEY and ALPACA_SECRET_KEY):
    st.error(
        "API keys missing. Set ALPACA_API_KEY and ALPACA_SECRET_KEY in your .env or as secrets/environment variables."
    )
    st.stop()

# Globally patch `requests` so responses are cached under the name
# "alpaca_api_cache" and expire after 120 seconds (presumably this is what
# throttles repeated Alpaca HTTP calls — confirm the client uses `requests`).
requests_cache.install_cache("alpaca_api_cache", expire_after=120)
def get_dataset_path(dataset_name="alpaca"):
    """Return the Parquet location for ``dataset_name``.

    When the ``APP_ENV`` environment variable equals ``"development"`` the path
    is a local file in the working directory; otherwise it points at the
    ``Arrechenash/stocks`` dataset on the Hugging Face Hub (``hf://`` URL).
    """
    is_development = os.getenv("APP_ENV") == "development"
    return (
        f"{dataset_name}.parquet"
        if is_development
        else f"hf://datasets/Arrechenash/stocks/{dataset_name}.parquet"
    )
@st.cache_data
def load_symbols(dataset):
    """Return the sorted list of distinct ticker symbols in a Parquet dataset.

    Args:
        dataset: Path or URL of the Parquet data, read with DuckDB's
            ``read_parquet``.

    Returns:
        A list of the distinct values of the ``symbol`` column, ascending.
    """
    # The path is embedded in a SQL string literal; double any single quote so
    # a path containing "'" cannot terminate the literal early.
    source = str(dataset).replace("'", "''")
    query = f"SELECT DISTINCT symbol FROM read_parquet('{source}') ORDER BY symbol"
    return duckdb.query(query).to_df()["symbol"].tolist()
@st.cache_data
def get_data(dataset, filters=None):
    """Load every row of a Parquet dataset, optionally filtered, newest first.

    Args:
        dataset: Path or URL of the Parquet data, read with DuckDB's
            ``read_parquet``.
        filters: Optional iterable of SQL boolean expressions (for example
            ``"symbol = 'AAPL'"``). They are combined with ``AND`` into the
            WHERE clause and inserted verbatim, so build them only from
            trusted values, never from raw user input.

    Returns:
        A pandas DataFrame with all columns of the dataset, ordered by the
        ``date`` column descending.
    """
    # Escape single quotes: the path lives inside a SQL string literal.
    source = str(dataset).replace("'", "''")
    query = f"SELECT * FROM read_parquet('{source}')"
    if filters:
        # Parenthesize each clause so an OR inside one filter cannot bind
        # across the AND that joins it to the others.
        query += " WHERE " + " AND ".join(f"({clause})" for clause in filters)
    return duckdb.query(query + " ORDER BY date DESC").to_df()
@st.cache_resource
def get_client():
    """Return the Alpaca historical stock-data client.

    ``st.cache_resource`` keeps the first instance and returns it on later
    calls, so the client is constructed once rather than on every rerun.
    """
    client = StockHistoricalDataClient(
        api_key=ALPACA_API_KEY,
        secret_key=ALPACA_SECRET_KEY,
    )
    return client
# (amount, unit) for each supported timeframe string. The table is built once
# at import; a new TimeFrame is still constructed on every call.
_TIMEFRAME_SPECS = {
    "1m": (1, TimeFrameUnit.Minute),
    "1min": (1, TimeFrameUnit.Minute),
    "5m": (5, TimeFrameUnit.Minute),
    "15m": (15, TimeFrameUnit.Minute),
    "30m": (30, TimeFrameUnit.Minute),
    "1h": (1, TimeFrameUnit.Hour),
    "4h": (4, TimeFrameUnit.Hour),
    "1d": (1, TimeFrameUnit.Day),
    "1w": (1, TimeFrameUnit.Week),
    "1mo": (1, TimeFrameUnit.Month),
}
# Used for any string not found in _TIMEFRAME_SPECS.
_DEFAULT_TIMEFRAME_SPEC = (1, TimeFrameUnit.Minute)


def parse_timeframe(timeframe_str):
    """Parse a timeframe string such as ``"5m"`` or ``"1d"`` into a TimeFrame.

    Supported keys: ``1m``/``1min``, ``5m``, ``15m``, ``30m``, ``1h``, ``4h``,
    ``1d``, ``1w`` and ``1mo``. Surrounding whitespace is ignored.

    Args:
        timeframe_str: The timeframe key.

    Returns:
        A new ``TimeFrame``. Unrecognized values (including ``None``) fall back
        to one minute, matching the previous behavior.
    """
    key = timeframe_str.strip() if isinstance(timeframe_str, str) else timeframe_str
    amount, unit = _TIMEFRAME_SPECS.get(key, _DEFAULT_TIMEFRAME_SPEC)
    return TimeFrame(amount=amount, unit=unit)
def get_stock_bars(symbol_or_symbols, date_start, date_end, interval="1m"):
    """Fetch historical stock bars from Alpaca as a DataFrame.

    Args:
        symbol_or_symbols: A ticker symbol or a list of ticker symbols.
        date_start: Start of the range. It is converted with ``str()`` before
            being passed to ``StockBarsRequest``; ``None`` leaves it unset.
        date_end: End of the range, handled the same way as ``date_start``.
        interval: Timeframe key understood by ``parse_timeframe`` (for example
            ``"1m"``, ``"5m"``, ``"1h"`` or ``"1d"``). The default used to be
            ``"1min"``, which was not a key and only produced one-minute bars
            through the parser's fallback. ``"1m"`` selects the same timeframe
            explicitly.

    Returns:
        The ``.df`` of the Alpaca response, a pandas DataFrame.
    """
    req = StockBarsRequest(
        symbol_or_symbols=symbol_or_symbols,
        timeframe=parse_timeframe(interval),
        # str(None) would yield the literal string "None"; pass None through so
        # an omitted bound stays unset.
        start=None if date_start is None else str(date_start),
        end=None if date_end is None else str(date_end),
    )
    return get_client().get_stock_bars(req).df