File size: 5,588 Bytes
36bd0e2
da1806a
 
 
9ccf981
36bd0e2
da1806a
36bd0e2
da1806a
 
 
 
 
36bd0e2
 
da1806a
36bd0e2
 
 
da1806a
fba6df5
da1806a
 
 
 
 
 
 
36bd0e2
da1806a
9ccf981
da1806a
fba6df5
9ccf981
fba6df5
da1806a
9ccf981
fba6df5
9ccf981
 
 
 
 
fba6df5
9ccf981
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da1806a
 
 
 
 
fba6df5
36bd0e2
fba6df5
36bd0e2
da1806a
 
 
 
36bd0e2
 
9ccf981
da1806a
fba6df5
da1806a
 
 
 
 
 
 
 
9ccf981
da1806a
00e837f
da1806a
36bd0e2
da1806a
fba6df5
00e837f
 
36bd0e2
9ccf981
da1806a
 
 
36bd0e2
 
fba6df5
36bd0e2
fba6df5
da1806a
36bd0e2
00e837f
36bd0e2
fba6df5
da1806a
36bd0e2
fba6df5
36bd0e2
fba6df5
36bd0e2
00e837f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5896acb
 
 
 
 
 
 
3d04e63
 
 
 
 
 
 
 
 
5896acb
 
3d04e63
 
5896acb
3d04e63
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import logging
import math
import os
import shutil
import threading

import numpy as np
from datasets import load_dataset, load_from_disk
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse

# ASGI application; the @app.get(...) route decorators in this module register
# their endpoints on it.
app = FastAPI()

# ---------------------------------------------------------
# Logging
# ---------------------------------------------------------
# Configure the root logger: INFO and above, with a timestamp and level prefix.
# Note: logging.basicConfig is a no-op if the root logger already has handlers,
# so whoever configures logging first wins.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s"
)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------
# CORS
# ---------------------------------------------------------
# Accept cross-origin requests from any origin, with any method and any header.
# NOTE(review): wildcard origins are reasonable for a public read-only API;
# narrow allow_origins if this ever serves non-public data.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ---------------------------------------------------------
# Dataset caching configuration
# ---------------------------------------------------------
# Hugging Face Hub dataset id; downloaded with load_dataset() when no on-disk cache exists.
DATASET_NAME = "kurry/sp500_earnings_transcripts"
CACHE_PATH = "/data/hf_dataset"   # persistent bucket mount; save_to_disk/load_from_disk location
# In-process cache set by load_hf_dataset(); None until the first successful load.
dataset_cache = None


# Serializes the first load: sync FastAPI endpoints run in a worker thread
# pool, so several requests can reach load_hf_dataset() at the same time.
_dataset_lock = threading.Lock()


def load_hf_dataset():
    """
    Load the HF dataset, caching it both in memory and on disk.

    Resolution order:
    1. Return the module-level ``dataset_cache`` if it is already populated.
    2. If ``CACHE_PATH`` exists, load it with ``load_from_disk`` (fast, offline).
    3. Otherwise download the ``train`` split of ``DATASET_NAME``, try to save
       it to ``CACHE_PATH`` for next time, and return it.

    Returns:
        The loaded dataset.

    Raises:
        Whatever ``load_from_disk`` / ``load_dataset`` raise when loading fails.
        A failure to *write* the on-disk cache is logged and does not fail the call.
    """
    global dataset_cache

    # Fast path: no locking once the dataset is in memory.
    if dataset_cache is not None:
        return dataset_cache

    with _dataset_lock:
        # Another thread may have finished loading while we waited for the lock.
        if dataset_cache is not None:
            return dataset_cache

        if os.path.exists(CACHE_PATH):
            logger.info(f"Loading dataset from cache at {CACHE_PATH}")
            dataset_cache = load_from_disk(CACHE_PATH)
            logger.info(f"Loaded {len(dataset_cache)} rows from cached dataset")
            return dataset_cache

        logger.info(f"Downloading HF dataset: {DATASET_NAME}")
        ds = load_dataset(DATASET_NAME, split="train")

        logger.info(f"Saving dataset to cache at {CACHE_PATH}")
        _save_cache_atomically(ds)

        dataset_cache = ds
        logger.info(f"Dataset cached and loaded ({len(ds)} rows)")
        return ds


def _save_cache_atomically(ds):
    """Write ``ds`` to CACHE_PATH via a temp dir + rename; log and continue on failure."""
    tmp_path = f"{CACHE_PATH}.tmp"
    try:
        # Discard leftovers from an earlier interrupted save.
        shutil.rmtree(tmp_path, ignore_errors=True)
        ds.save_to_disk(tmp_path)
        # Publish only a fully written directory, so the os.path.exists()
        # check in load_hf_dataset() never picks up a half-written cache.
        os.replace(tmp_path, CACHE_PATH)
    except OSError:
        # Best effort: the dataset is still usable from memory this run.
        logger.exception(f"Could not save dataset cache to {CACHE_PATH}; continuing without it")
        shutil.rmtree(tmp_path, ignore_errors=True)


# ---------------------------------------------------------
# JSON-safe conversion
# ---------------------------------------------------------
def to_json_safe(obj):
    """
    Recursively convert ``obj`` into values a strict JSON encoder accepts.

    - NumPy scalars (integers, floats, bools, ...) become native Python scalars.
    - NumPy arrays (including 0-d arrays), lists and tuples become lists.
    - Dicts are converted value by value; keys are left unchanged.
    - Non-finite floats (NaN, +/-inf) become ``None``, since JSON has no
      representation for them.
    - Anything else is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        # tolist() yields native Python scalars and, unlike iteration,
        # also works for 0-d arrays (it returns a bare scalar).
        obj = obj.tolist()
    elif isinstance(obj, np.generic):
        # Covers np.integer, np.floating and np.bool_ (which is not an np.integer).
        obj = obj.item()

    if isinstance(obj, float):
        return obj if math.isfinite(obj) else None
    if isinstance(obj, (list, tuple)):
        return [to_json_safe(x) for x in obj]
    if isinstance(obj, dict):
        return {k: to_json_safe(v) for k, v in obj.items()}
    return obj


# ---------------------------------------------------------
# Serve index.html
# ---------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
def serve_index():
    """
    Serve ``index.html`` (resolved relative to the working directory) as HTML.

    Returns a small placeholder page when the file does not exist.
    """
    try:
        # Decode explicitly as UTF-8 instead of the platform's locale encoding.
        with open("index.html", "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return "<h1>index.html not found</h1>"


# ---------------------------------------------------------
# List all symbols
# ---------------------------------------------------------
@app.get("/tickers")
def get_tickers():
    """Return every distinct ticker symbol in the dataset, upper-cased and sorted."""
    unique_upper = {sym.upper() for sym in load_hf_dataset()["symbol"]}
    return {"tickers": sorted(unique_upper)}


# ---------------------------------------------------------
# Get transcript for a symbol
# ---------------------------------------------------------
@app.get("/transcript/{symbol}")
def get_transcript(symbol: str):
    """
    Return every dataset row whose ``symbol`` matches (case-insensitive).

    Args:
        symbol: Ticker from the URL path; upper-cased before matching.

    Returns:
        ``{"symbol": <upper-cased symbol>, "records": [<row dict>, ...]}``,
        with each row passed through ``to_json_safe``.

    Raises:
        HTTPException: 404 when no row matches.
    """
    ds = load_hf_dataset()
    symbol = symbol.upper()

    logger.info(f"Fetching transcript for: {symbol}")

    # Match on the symbol column alone, then materialize only matching rows;
    # iterating `ds` directly would decode every column of every row.
    indices = [i for i, s in enumerate(ds["symbol"]) if s.upper() == symbol]

    if not indices:
        raise HTTPException(status_code=404, detail=f"No transcript found for {symbol}")

    safe_rows = [to_json_safe(r) for r in ds.select(indices)]

    return {"symbol": symbol, "records": safe_rows}


# ---------------------------------------------------------
# Dataset info (size + columns)
# ---------------------------------------------------------
@app.get("/dataset-info")
def dataset_info():
    """Report the dataset's row count, column names and on-disk cache location."""
    dataset = load_hf_dataset()
    return dict(
        num_rows=len(dataset),
        columns=dataset.column_names,
        cache_path=CACHE_PATH,
    )


# ---------------------------------------------------------
# Dataset summary (high-level stats)
# ---------------------------------------------------------
@app.get("/dataset-summary")
def dataset_summary():
    """
    Return high-level statistics about the dataset.

    Null ``year``, ``quarter`` and ``date`` values are ignored. A range is
    reported as ``None`` when its column has no non-null values (e.g. an
    empty dataset) instead of failing with ``ValueError``.
    """
    ds = load_hf_dataset()

    symbols = {s.upper() for s in ds["symbol"]}
    # Drop nulls: min()/max()/sorted() cannot compare None with real values.
    years = {y for y in ds["year"] if y is not None}
    quarters = {q for q in ds["quarter"] if q is not None}
    dates = [d for d in ds["date"] if d is not None]

    return {
        "total_rows": len(ds),
        "unique_symbols": len(symbols),
        "symbols_sample": sorted(symbols)[:20],
        "year_range": {
            "min_year": min(years, default=None),
            "max_year": max(years, default=None),
        },
        "quarters_present": sorted(quarters),
        "date_range": {
            "min_date": min(dates, default=None),
            "max_date": max(dates, default=None),
        },
        "company_count": len(set(ds["company_id"])),
    }

@app.get("/check/{symbol}")
def check_symbol(symbol: str):
    """
    Report whether ``symbol`` (case-insensitive) occurs in the dataset.

    Always responds normally (no 404); the ``exists`` flag and ``message``
    describe the result.
    """
    ds = load_hf_dataset()
    symbol = symbol.upper()

    # Scan only the symbol column; iterating `ds` itself decodes full rows.
    exists = any(s.upper() == symbol for s in ds["symbol"])

    if not exists:
        logger.warning(f"Symbol not found: {symbol}")
        return {
            "symbol": symbol,
            "exists": False,
            "message": f"Symbol '{symbol}' does not exist in the dataset."
        }

    logger.info(f"Symbol exists: {symbol}")
    return {
        "symbol": symbol,
        "exists": True,
        "message": f"Symbol '{symbol}' exists in the dataset."
    }