File size: 2,367 Bytes
89dd6ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
816b723
 
 
 
 
 
 
89dd6ec
b7478c5
89dd6ec
 
b7478c5
89dd6ec
 
 
 
 
 
 
b7478c5
 
89dd6ec
 
 
 
 
 
 
 
 
 
4fef3fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89dd6ec
 
4fef3fe
89dd6ec
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os

import duckdb
import requests_cache
import streamlit as st
from alpaca.data.historical import StockHistoricalDataClient
from alpaca.data.requests import StockBarsRequest
from alpaca.data.timeframe import TimeFrame, TimeFrameUnit
from dotenv import load_dotenv

# Load variables from a .env file (if any) before the credentials are read.
load_dotenv()

# Alpaca credentials; each is None when the variable is not set.
ALPACA_API_KEY = os.environ.get("ALPACA_API_KEY")
ALPACA_SECRET_KEY = os.environ.get("ALPACA_SECRET_KEY")

# Both values are required (unset or empty counts as missing): show an error
# in the Streamlit UI and halt the script instead of failing later.
if not (ALPACA_API_KEY and ALPACA_SECRET_KEY):
    st.error(
        "API keys missing. Set ALPACA_API_KEY and ALPACA_SECRET_KEY in your .env or as secrets/environment variables."
    )
    st.stop()

# Install a global requests cache named "alpaca_api_cache"; entries expire
# after 120 (seconds, per requests_cache's expire_after convention).
requests_cache.install_cache("alpaca_api_cache", expire_after=120)


def get_dataset_path(dataset_name="alpaca"):
    """Return the location of the parquet file for *dataset_name*.

    When the ``APP_ENV`` environment variable equals ``"development"`` the
    result is a bare ``<dataset_name>.parquet`` file name; for any other
    value (or when unset) it is the ``hf://`` URL of that file under the
    ``Arrechenash/stocks`` dataset.
    """
    filename = f"{dataset_name}.parquet"
    in_development = os.getenv("APP_ENV") == "development"
    if in_development:
        return filename
    return "hf://datasets/Arrechenash/stocks/" + filename


@st.cache_data
def load_symbols(dataset):
    """Return the distinct ``symbol`` values of a parquet dataset, sorted.

    Args:
        dataset: Path or URL handed to DuckDB's ``read_parquet``. It is
            interpolated into the SQL text as-is.

    Returns:
        A list of the distinct values of the ``symbol`` column, in the order
        produced by ``ORDER BY symbol``. The result is cached by Streamlit's
        ``st.cache_data``.
    """
    sql = f"SELECT DISTINCT symbol FROM read_parquet('{dataset}') ORDER BY symbol"
    symbols_df = duckdb.query(sql).to_df()
    return symbols_df["symbol"].tolist()


@st.cache_data
def get_data(dataset, filters=None):
    """Load rows from a parquet dataset as a DataFrame, newest ``date`` first.

    Args:
        dataset: Path or URL handed to DuckDB's ``read_parquet``.
        filters: Optional iterable of SQL condition strings. When non-empty
            they are joined with ``AND`` into a ``WHERE`` clause; the
            fragments are inserted into the query text verbatim.

    Returns:
        The query result converted with ``.to_df()``, ordered by ``date``
        descending. The result is cached by Streamlit's ``st.cache_data``.
    """
    sql_parts = [f"SELECT * FROM read_parquet('{dataset}')"]
    if filters:
        conditions = " AND ".join(filters)
        sql_parts.append(" WHERE " + conditions)
    sql_parts.append(" ORDER BY date DESC")
    return duckdb.query("".join(sql_parts)).to_df()


@st.cache_resource
def get_client():
    """Return a ``StockHistoricalDataClient`` built from the module credentials.

    Decorated with ``st.cache_resource`` so Streamlit can reuse the client
    object rather than constructing a new one on every call.
    """
    client = StockHistoricalDataClient(
        api_key=ALPACA_API_KEY,
        secret_key=ALPACA_SECRET_KEY,
    )
    return client


def parse_timeframe(timeframe_str):
    """Translate a short interval label into an Alpaca ``TimeFrame``.

    Recognized labels are "1m", "5m", "15m", "30m", "1h" and "1d". Any
    other value falls back to a one-minute timeframe.
    """
    one_minute = (1, TimeFrameUnit.Minute)
    # label -> (amount, unit); only the selected entry is instantiated.
    known_intervals = {
        "1m": one_minute,
        "5m": (5, TimeFrameUnit.Minute),
        "15m": (15, TimeFrameUnit.Minute),
        "30m": (30, TimeFrameUnit.Minute),
        "1h": (1, TimeFrameUnit.Hour),
        "1d": (1, TimeFrameUnit.Day),
    }
    amount, unit = known_intervals.get(timeframe_str, one_minute)
    return TimeFrame(amount=amount, unit=unit)


def get_stock_bars(symbol_or_symbols, date_start, date_end, interval="1m"):
    """Fetch historical stock bars from Alpaca.

    Args:
        symbol_or_symbols: A symbol or collection of symbols, forwarded
            unchanged to ``StockBarsRequest``.
        date_start: Start of the requested range; converted with ``str()``
            before being placed on the request.
        date_end: End of the requested range; converted with ``str()``
            before being placed on the request.
        interval: Interval label understood by ``parse_timeframe`` ("1m",
            "5m", "15m", "30m", "1h", "1d"). Defaults to "1m". The previous
            default "1min" was not a recognized label and only produced
            one-minute bars through the unknown-label fallback, so the
            effective default is unchanged.

    Returns:
        The ``.df`` attribute of the response returned by the client's
        ``get_stock_bars`` call.
    """
    request = StockBarsRequest(
        symbol_or_symbols=symbol_or_symbols,
        timeframe=parse_timeframe(interval),
        start=str(date_start),
        end=str(date_end),
    )
    response = get_client().get_stock_bars(request)
    return response.df