Spaces:
Sleeping
Sleeping
File size: 5,588 Bytes
36bd0e2 da1806a 9ccf981 36bd0e2 da1806a 36bd0e2 da1806a 36bd0e2 da1806a 36bd0e2 da1806a fba6df5 da1806a 36bd0e2 da1806a 9ccf981 da1806a fba6df5 9ccf981 fba6df5 da1806a 9ccf981 fba6df5 9ccf981 fba6df5 9ccf981 da1806a fba6df5 36bd0e2 fba6df5 36bd0e2 da1806a 36bd0e2 9ccf981 da1806a fba6df5 da1806a 9ccf981 da1806a 00e837f da1806a 36bd0e2 da1806a fba6df5 00e837f 36bd0e2 9ccf981 da1806a 36bd0e2 fba6df5 36bd0e2 fba6df5 da1806a 36bd0e2 00e837f 36bd0e2 fba6df5 da1806a 36bd0e2 fba6df5 36bd0e2 fba6df5 36bd0e2 00e837f 5896acb 3d04e63 5896acb 3d04e63 5896acb 3d04e63 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 | import logging
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from datasets import load_dataset, load_from_disk
import numpy as np
import os
# FastAPI application instance; the route decorators further down register on it.
app = FastAPI()
# ---------------------------------------------------------
# Logging
# ---------------------------------------------------------
# Configure logging once at import time: INFO level, with a timestamp and
# the level name in front of every message.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s"
)
# Module-level logger used by the dataset loader and the endpoints.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------
# CORS
# ---------------------------------------------------------
# Wildcards for origins, methods and headers: cross-origin requests are
# accepted from anywhere.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
# ---------------------------------------------------------
# Dataset caching configuration
# ---------------------------------------------------------
# Dataset identifier handed to load_dataset() when no on-disk copy exists.
DATASET_NAME = "kurry/sp500_earnings_transcripts"
# Directory the dataset is saved to and loaded from between runs.
CACHE_PATH = "/data/hf_dataset" # persistent bucket mount
# In-process memo of the loaded dataset; None until load_hf_dataset() fills it.
dataset_cache = None
def load_hf_dataset():
    """
    Return the transcripts dataset, loading it at most once per process.

    Resolution order:
    1. The in-memory ``dataset_cache`` when it is already populated.
    2. The on-disk copy at ``CACHE_PATH`` when that path exists.
    3. A download of ``DATASET_NAME`` (``train`` split), which is then
       written to ``CACHE_PATH`` on a best-effort basis.

    A failure to write the on-disk copy (``OSError``, e.g. the mount is
    missing or read-only) is logged and otherwise ignored: the downloaded
    dataset is still memoised and returned.

    Returns:
        The loaded dataset object.
    """
    global dataset_cache
    if dataset_cache is not None:
        return dataset_cache

    if os.path.exists(CACHE_PATH):
        logger.info(f"Loading dataset from cache at {CACHE_PATH}")
        dataset_cache = load_from_disk(CACHE_PATH)
        logger.info(f"Loaded {len(dataset_cache)} rows from cached dataset")
        return dataset_cache

    logger.info(f"Downloading HF dataset: {DATASET_NAME}")
    ds = load_dataset(DATASET_NAME, split="train")
    # Memoise before touching the disk so that a failed save cannot send
    # every later call back through the download path.
    dataset_cache = ds

    logger.info(f"Saving dataset to cache at {CACHE_PATH}")
    try:
        ds.save_to_disk(CACHE_PATH)
    except OSError:
        logger.warning(
            f"Could not save dataset to {CACHE_PATH}; serving from memory only",
            exc_info=True,
        )
    else:
        logger.info(f"Dataset cached and loaded ({len(ds)} rows)")
    return ds
# ---------------------------------------------------------
# JSON-safe conversion
# ---------------------------------------------------------
def to_json_safe(obj):
    """
    Recursively convert *obj* into plain Python values that encode as JSON.

    Conversions:
    - NumPy booleans, integers and floats become ``bool``, ``int``, ``float``.
    - Non-finite floats (NaN, +/-Inf) become ``None``, since strict JSON
      has no representation for them.
    - NumPy arrays (including 0-d arrays), lists and tuples become lists,
      with every element converted.
    - Dict values are converted; keys are left untouched.

    Anything else is returned unchanged.
    """
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, (float, np.floating)):
        value = float(obj)
        return value if np.isfinite(value) else None
    if isinstance(obj, np.ndarray):
        # tolist() also copes with 0-d arrays, which cannot be iterated;
        # recurse afterwards to catch NaN and nested containers.
        return to_json_safe(obj.tolist())
    if isinstance(obj, (list, tuple)):
        return [to_json_safe(item) for item in obj]
    if isinstance(obj, dict):
        return {key: to_json_safe(value) for key, value in obj.items()}
    return obj
# ---------------------------------------------------------
# Serve index.html
# ---------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
def serve_index():
    """
    Serve ``index.html`` from the current working directory.

    Returns:
        The file contents, or a short HTML notice when the file is missing.
    """
    try:
        # Explicit encoding: the default would otherwise depend on the
        # locale of the host process.
        with open("index.html", "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return "<h1>index.html not found</h1>"
# ---------------------------------------------------------
# List all symbols
# ---------------------------------------------------------
@app.get("/tickers")
def get_tickers():
    """Return every distinct ticker symbol, upper-cased and sorted."""
    dataset = load_hf_dataset()
    unique_symbols = {raw.upper() for raw in dataset["symbol"]}
    return {"tickers": sorted(unique_symbols)}
# ---------------------------------------------------------
# Get transcript for a symbol
# ---------------------------------------------------------
@app.get("/transcript/{symbol}")
def get_transcript(symbol: str):
    """
    Return all dataset rows whose symbol matches *symbol*, case-insensitively.

    Args:
        symbol: Ticker symbol taken from the URL path.

    Returns:
        A dict with the upper-cased symbol and the matching rows, each
        passed through ``to_json_safe``.

    Raises:
        HTTPException: 404 when no row matches.
    """
    ds = load_hf_dataset()
    symbol = symbol.upper()
    logger.info(f"Fetching transcript for: {symbol}")

    # Compare against the symbol column only, then fetch just the matching
    # rows by index, instead of materialising every full row for the test.
    matching_indices = [
        index
        for index, candidate in enumerate(ds["symbol"])
        if candidate.upper() == symbol
    ]
    if not matching_indices:
        raise HTTPException(status_code=404, detail=f"No transcript found for {symbol}")

    safe_rows = [to_json_safe(ds[index]) for index in matching_indices]
    return {"symbol": symbol, "records": safe_rows}
# ---------------------------------------------------------
# Dataset info (size + columns)
# ---------------------------------------------------------
@app.get("/dataset-info")
def dataset_info():
    """Report the dataset's row count, column names and on-disk cache path."""
    dataset = load_hf_dataset()
    return {
        "num_rows": len(dataset),
        "columns": dataset.column_names,
        "cache_path": CACHE_PATH,
    }
# ---------------------------------------------------------
# Dataset summary (high-level stats)
# ---------------------------------------------------------
@app.get("/dataset-summary")
def dataset_summary():
    """
    Return high-level statistics about the dataset.

    The summary holds the row count, the number of distinct upper-cased
    symbols with a sorted sample of at most 20, the min/max year, the
    sorted distinct quarters, the min/max date and the number of distinct
    ``company_id`` values.

    ``None`` entries in the year, quarter and date columns are ignored.
    When a column has no usable values its min/max are reported as
    ``None`` rather than raising.
    """
    ds = load_hf_dataset()

    symbols = {s.upper() for s in ds["symbol"]}
    # Drop missing values up front so min/max/sorted never compare None
    # against real values.
    years = {y for y in ds["year"] if y is not None}
    quarters = {q for q in ds["quarter"] if q is not None}
    dates = [d for d in ds["date"] if d is not None]

    return {
        "total_rows": len(ds),
        "unique_symbols": len(symbols),
        "symbols_sample": sorted(symbols)[:20],
        "year_range": {
            "min_year": min(years, default=None),
            "max_year": max(years, default=None),
        },
        "quarters_present": sorted(quarters),
        "date_range": {
            "min_date": min(dates, default=None),
            "max_date": max(dates, default=None),
        },
        "company_count": len(set(ds["company_id"])),
    }
# ---------------------------------------------------------
# Check whether a symbol exists
# ---------------------------------------------------------
@app.get("/check/{symbol}")
def check_symbol(symbol: str):
    """
    Report whether *symbol* occurs in the dataset, case-insensitively.

    Args:
        symbol: Ticker symbol taken from the URL path.

    Returns:
        A dict with the upper-cased symbol, an ``exists`` flag and a
        human-readable message. A missing symbol is reported in the
        payload, not as an HTTP error.
    """
    ds = load_hf_dataset()
    symbol = symbol.upper()

    # Only the symbol column is needed for a membership test; iterating
    # whole rows would materialise every other column as well.
    exists = any(candidate.upper() == symbol for candidate in ds["symbol"])

    if not exists:
        logger.warning(f"Symbol not found: {symbol}")
        return {
            "symbol": symbol,
            "exists": False,
            "message": f"Symbol '{symbol}' does not exist in the dataset.",
        }

    logger.info(f"Symbol exists: {symbol}")
    return {
        "symbol": symbol,
        "exists": True,
        "message": f"Symbol '{symbol}' exists in the dataset.",
    }
|