# app.py — Fin-o1-8B Gradio tools (price prediction + equity research).
import os
# Force writable home and cache paths before other imports
# (container hosts such as HF Spaces may mount a read-only $HOME).
os.environ.setdefault("HOME", "/app")
os.environ.setdefault("HF_HOME", "/app/hf_cache")        # huggingface_hub cache root
os.environ.setdefault("HF_HUB_CACHE", "/app/hf_cache")   # hub download cache
os.environ.setdefault("XDG_CACHE_HOME", "/app/.cache")   # generic XDG cache dir
# Create the cache directories up front so later downloads cannot fail on a missing path.
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
os.makedirs(os.environ["XDG_CACHE_HOME"], exist_ok=True)
import json
from typing import Dict, List, Generator
import gradio as gr
import requests
from dotenv import load_dotenv
# Load variables from a local .env file into os.environ (no-op if the file is absent).
load_dotenv()
# -------- Keys (multi-key support) --------
def _parse_key_list(raw: str) -> List[str]:
    """Split a newline- or comma-separated env value into a list of non-empty keys.

    Accepting commas as well as newlines generalizes the original
    newline-only format while remaining backward compatible.
    """
    if not raw:
        return []
    return [k.strip() for k in raw.replace(",", "\n").split("\n") if k.strip()]


# Preferred: FINNHUB_KEYS holding several keys (one per line) for rotation.
FINNHUB_KEYS_RAW = os.getenv("FINNHUB_KEYS", "")
FINNHUB_KEYS = _parse_key_list(FINNHUB_KEYS_RAW)
# Back-compat: single-key FINNHUB_API_KEY is used only when no key list is set.
FINNHUB_API_KEY = os.getenv("FINNHUB_API_KEY", "")
if FINNHUB_API_KEY.strip():
    FINNHUB_KEYS = FINNHUB_KEYS or [FINNHUB_API_KEY.strip()]

RAPIDAPI_KEYS_RAW = os.getenv("RAPIDAPI_KEYS", "")
RAPIDAPI_KEYS = _parse_key_list(RAPIDAPI_KEYS_RAW)
RAPIDAPI_KEY = os.getenv("RAPIDAPI_KEY", "")
if RAPIDAPI_KEY.strip():
    RAPIDAPI_KEYS = RAPIDAPI_KEYS or [RAPIDAPI_KEY.strip()]

RAPIDAPI_HOST = "alpha-vantage.p.rapidapi.com"
# -------- llama.cpp GGUF model --------
MODEL_REPO = "mradermacher/Fin-o1-8B-GGUF"  # HF repo holding the quantized GGUF files
GGUF_OVERRIDE = os.getenv("GGUF_FILENAME", "").strip()  # optional explicit file pick
# llama.cpp tuning knobs, overridable via env; defaults target CPU-only hosts.
N_THREADS = int(os.getenv("LLAMA_CPP_THREADS", str(os.cpu_count() or 4)))
CTX_LEN = int(os.getenv("LLAMA_CPP_CTX", "3072")) # CPU-friendly default
N_BATCH = int(os.getenv("LLAMA_CPP_BATCH", "128"))
from huggingface_hub import snapshot_download
from llama_cpp import Llama
# Module-level cache: load_model() creates the Llama instance once and reuses it.
_llm = None
def _pick_gguf_file(root_dir: str, override: str | None) -> str:
import glob
if override:
path = os.path.join(root_dir, override)
if os.path.isfile(path) and os.path.getsize(path) > 0:
return path
candidates = glob.glob(os.path.join(root_dir, "**", override), recursive=True)
for c in candidates:
if os.path.getsize(c) > 0:
return c
preferred: List[str] = [
"Fin-o1-8B.Q4_K_M.gguf", # explicit 8B file name first
"Q4_K_M", "Q4_K_S", "Q4_0", "Q3_K_M", "Q3_K_S", "Q3_0", "Q2_K", "Q2_0",
]
import glob as _glob
ggufs = _glob.glob(os.path.join(root_dir, "**", "*.gguf"), recursive=True)
if not ggufs:
raise FileNotFoundError("No .gguf files found in snapshot")
for key in preferred:
for f in ggufs:
if key in os.path.basename(f):
return f
return ggufs[0]
def load_model():
    """Lazily download and construct the llama.cpp model, cached in ``_llm``.

    Returns:
        The shared Llama instance (created on first call, reused afterwards).

    Raises:
        RuntimeError: if no GGUF file is found or llama.cpp fails to load it.
    """
    global _llm
    if _llm is not None:
        return _llm
    # Fetch only the GGUF weights into the writable cache configured at import time.
    repo_dir = snapshot_download(
        repo_id=MODEL_REPO,
        allow_patterns=["*.gguf"],
        cache_dir=os.getenv("HF_HOME", "/app/hf_cache"),
        local_files_only=False,
        resume_download=True,
    )
    try:
        model_path = _pick_gguf_file(repo_dir, GGUF_OVERRIDE or None)
    except Exception as e:
        # Chain the cause so the selection failure's traceback stays visible.
        raise RuntimeError(f"GGUF not found: {e}") from e
    try:
        _llm = Llama(
            model_path=model_path,
            n_ctx=CTX_LEN,
            n_threads=N_THREADS,
            n_batch=N_BATCH,
            use_mlock=False,  # don't pin memory on small hosts
            use_mmap=True,    # mmap keeps RSS low while the OS pages weights in
            verbose=False,
        )
    except Exception as e:
        raise RuntimeError(
            f"Failed to load GGUF: {e}. Set GGUF_FILENAME to an available 8B file if needed."
        ) from e
    return _llm
def generate_response_stream(prompt: str, temperature: float = 0.2, max_new_tokens: int = 384) -> Generator[str, None, None]:
    """Stream a completion for ``prompt``, yielding the accumulated text so far.

    Yields short status strings first (first-time model load can be slow),
    then progressively longer prefixes of the generated answer.
    """
    yield "Initializing model..."
    llm = load_model()
    yield "Model loaded. Generating..."
    accum = ""
    # Pass the same stop tokens as the non-streaming path (generate_response)
    # so both endpoints terminate generation identically.
    for chunk in llm(
        prompt=prompt,
        max_tokens=max_new_tokens,
        temperature=temperature,
        stop=["</s>", "<|eot_id|>"],
        stream=True,
    ):
        text = chunk.get("choices", [{}])[0].get("text", "")
        if text:
            accum += text
            yield accum
def generate_response(prompt: str, temperature: float = 0.2, max_new_tokens: int = 384) -> str:
    """Non-streaming fallback: return the full completion text in one call."""
    model = load_model()
    result = model(
        prompt=prompt,
        max_tokens=max_new_tokens,
        temperature=temperature,
        stop=["</s>", "<|eot_id|>"],
    )
    choices = result.get("choices", [{}])
    return choices[0].get("text", "")
# -------- Robust requests session with retry --------
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
def create_session() -> requests.Session:
    """Build a requests Session that retries transient HTTP failures.

    Retries up to 3 times with exponential backoff on 429 and common
    5xx responses, for both http:// and https:// URLs.
    """
    retrying = Retry(
        total=3,
        backoff_factor=1.0,
        status_forcelist=[429, 500, 502, 503, 504],
    )
    sess = requests.Session()
    for scheme in ("http://", "https://"):
        sess.mount(scheme, HTTPAdapter(max_retries=retrying))
    return sess


# Shared session used by every data-fetching helper below.
http = create_session()
# -------- Helpers for mock candles --------
def _create_mock_candles(symbol: str, count: int = 60) -> Dict:
    """Generate ``count`` days of synthetic OHLCV data for ``symbol``.

    Used when no API key is configured or all providers fail, so the UI
    stays responsive. The base price is derived from a CRC32 of the symbol
    (stable across processes, unlike built-in hash(), which is salted via
    PYTHONHASHSEED), then random-walked day by day.
    """
    import time as tmod
    import random
    import zlib
    now = int(tmod.time())
    day = 86400
    t, o, h, l, c, v = [], [], [], [], [], []
    price = 100.0 + (zlib.crc32(symbol.encode("utf-8")) % 50)
    for i in range(count, 0, -1):
        ts = now - i * day
        op = max(1.0, price + random.uniform(-2, 2))
        hi = op + random.uniform(0, 2)
        lo = max(0.5, op - random.uniform(0, 2))
        cl = max(0.5, lo + random.uniform(0, (hi - lo) or 1))
        vol = abs(int(random.gauss(1_000_000, 250_000)))
        t.append(ts); o.append(op); h.append(hi); l.append(lo); c.append(cl); v.append(vol)
        price = cl  # random walk: the next day opens near today's close
    return {"s": "ok", "t": t, "o": o, "h": h, "l": l, "c": c, "v": v, "source": "mock"}
# -------- Data helpers (Finnhub with fallback to Alpha Vantage) --------
def fetch_finnhub_candles(symbol: str, resolution: str = "D", count: int = 60) -> Dict:
    """Fetch OHLCV candles from Finnhub, rotating through all configured keys.

    Raises the last encountered error (ValueError if no key is configured)
    so the caller can fall back to Alpha Vantage / mock data.
    """
    if not FINNHUB_KEYS:
        raise ValueError("Missing FINNHUB_KEYS/FINNHUB_API_KEY")
    import time as _time
    end = int(_time.time())  # use the already-imported module, not __import__("time")
    start = end - count * 86400  # day-sized window; approximate for intraday resolutions
    last_error: Exception | None = None
    for api_key in FINNHUB_KEYS:
        url = (
            f"https://finnhub.io/api/v1/stock/candle?symbol={symbol}"
            f"&resolution={resolution}&from={start}&to={end}&token={api_key}"
        )
        try:
            r = http.get(url, timeout=30)
            if r.status_code in (401, 403):
                # Bad or exhausted key: remember the error and rotate to the next key.
                last_error = requests.HTTPError(f"Finnhub auth error {r.status_code}")
                continue
            r.raise_for_status()
            data = r.json()
            if data.get("s") == "ok":
                data["source"] = "finnhub"
                return data
            last_error = RuntimeError(f"Finnhub returned status: {data.get('s')}")
        except Exception as e:
            last_error = e
            continue
        finally:
            _time.sleep(0.3)  # brief pause to stay under per-key rate limits
    raise last_error or RuntimeError("Finnhub candles failed")
def fetch_alpha_vantage_series_daily(symbol: str, outputsize: str = "compact", count: int = 60) -> Dict:
    """Fallback: Alpha Vantage TIME_SERIES_DAILY via RapidAPI, formatted like Finnhub candles.

    Rotates through all RapidAPI keys; on rate-limit/error payloads or empty
    data it tries the next key. If everything fails (or no key is set),
    returns mock candles so the UI stays responsive.
    """
    if not RAPIDAPI_KEYS:
        return _create_mock_candles(symbol, count)
    import time as _time
    for api_key in RAPIDAPI_KEYS:
        try:
            url = f"https://{RAPIDAPI_HOST}/query"
            headers = {"X-RapidAPI-Key": api_key, "X-RapidAPI-Host": RAPIDAPI_HOST}
            params = {"function": "TIME_SERIES_DAILY", "symbol": symbol, "outputsize": outputsize}
            r = http.get(url, headers=headers, params=params, timeout=30)
            r.raise_for_status()
            data = r.json()
            if isinstance(data, dict) and any(k in data for k in ("Note", "Error Message", "Information")):
                # rate limit or error; try next key
                continue
            series = data.get("Time Series (Daily)") or {}
            if not series:
                continue
            dates = sorted(series.keys())[-count:]  # most recent `count` trading days
            import time as tmod
            t, o, h, l, c, v = [], [], [], [], [], []
            for d in dates:
                row = series[d]
                try:
                    op_v = float(row.get("1. open"))
                    h_v = float(row.get("2. high"))
                    l_v = float(row.get("3. low"))
                    c_v = float(row.get("4. close"))
                    # int volume, matching the Finnhub and mock payload shapes.
                    v_v = int(float(row.get("5. volume")))
                    # NOTE(review): mktime interprets the date in the server's local
                    # timezone, so timestamps shift with TZ — confirm whether exact
                    # UTC alignment with Finnhub timestamps is required.
                    ts = int(tmod.mktime(tmod.strptime(d, "%Y-%m-%d")))
                except Exception:
                    continue  # skip malformed rows rather than failing the batch
                t.append(ts); o.append(op_v); h.append(h_v); l.append(l_v); c.append(c_v); v.append(v_v)
            if not t:
                # Every row was malformed: don't return an empty "ok" payload;
                # rotate to the next key instead.
                continue
            return {"s": "ok", "t": t, "o": o, "h": h, "l": l, "c": c, "v": v, "source": "alpha_vantage"}
        except Exception:
            continue
        finally:
            _time.sleep(0.5)
    # If all keys failed or no data, return mock to keep UI responsive
    return _create_mock_candles(symbol, count)
def fetch_alpha_vantage_overview(symbol: str) -> Dict:
    """Fetch company fundamentals (OVERVIEW) from Alpha Vantage via RapidAPI.

    Rotates through all configured keys and skips rate-limit/error payloads
    (consistent with fetch_alpha_vantage_series_daily) instead of returning
    them as if they were fundamentals.

    Raises:
        ValueError: if no RapidAPI key is configured.
        RuntimeError: if every key fails or returns no usable data.
    """
    if not RAPIDAPI_KEYS:
        raise ValueError("Missing RAPIDAPI_KEYS/RAPIDAPI_KEY")
    for api_key in RAPIDAPI_KEYS:
        try:
            url = f"https://{RAPIDAPI_HOST}/query"
            headers = {"x-rapidapi-key": api_key, "x-rapidapi-host": RAPIDAPI_HOST}
            params = {"function": "OVERVIEW", "symbol": symbol}
            r = http.get(url, headers=headers, params=params, timeout=30)
            r.raise_for_status()
            data = r.json()
            # Alpha Vantage signals rate limits / errors as a dict with these
            # marker keys; treat that as a failed key, not valid fundamentals.
            if isinstance(data, dict) and any(k in data for k in ("Note", "Error Message", "Information")):
                continue
            if data:
                return data
        except Exception:
            continue
    raise RuntimeError("Alpha Vantage OVERVIEW failed")
# -------- Prompts --------
def build_price_prediction_prompt(symbol: str, candles: Dict) -> str:
    """Compose the LLM prompt for a short-term price prediction from candle data."""
    payload = json.dumps(candles)[:10000]  # cap serialized data to fit the context window
    data_source = candles.get("source", "finnhub")
    parts = [
        f"You are a financial analyst agent. Analyze recent OHLCV candles for {symbol} (source: {data_source}) and provide a short-term price prediction. ",
        f"Explain key drivers in bullet points and give a 1-2 sentence forecast.\n\nData JSON: {payload}\n\n",
    ]
    return "".join(parts)
def build_equity_research_prompt(symbol: str, overview: Dict) -> str:
    """Compose the LLM prompt for a concise equity research note from fundamentals."""
    fundamentals = json.dumps(overview)[:10000]  # cap serialized data to fit the context window
    header = (
        "You are an equity research analyst. Using the fundamentals overview, write a concise equity research note including: "
        "Business summary, recent performance, profitability, leverage, valuation multiples, key risks, and an investment view (Buy/Hold/Sell) with rationale.\n\n"
    )
    return header + f"Ticker: {symbol}\nFundamentals JSON: {fundamentals}\n"
# -------- Gradio UI --------
def ui_app():
    """Build the two-tab Gradio Blocks app (price prediction + equity research).

    Returns the Blocks instance with queueing enabled; the caller launches it.
    """
    with gr.Blocks(title="Fin-o1-8B Tools") as demo:
        gr.Markdown("""# Fin-o1-8B Tools
Two tabs: Price Prediction (Finnhub with Alpha Vantage fallback) and Equity Research (Alpha Vantage via RapidAPI).""")
        with gr.Tab("Price Prediction"):
            symbol = gr.Textbox(label="Ticker (e.g., AAPL)", value="AAPL")
            resolution = gr.Dropdown(["D", "60", "30", "15", "5"], value="D", label="Resolution")
            count = gr.Slider(20, 160, value=60, step=5, label="Num candles")
            temp = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
            max_new = gr.Slider(64, 768, value=384, step=16, label="Max new tokens")
            btn = gr.Button("Predict")
            out = gr.Textbox(lines=30, show_copy_button=True)

            def on_predict(sym, res, cnt, temperature, max_tokens):
                # Generator handler: yields partial text so the textbox streams.
                try:
                    candles = fetch_finnhub_candles(sym, res, int(cnt))
                except Exception:
                    try:
                        # Honor the user-selected candle count in the fallback too
                        # (previously the fallback silently used the default 60).
                        candles = fetch_alpha_vantage_series_daily(sym, outputsize="compact", count=int(cnt))
                    except Exception as e2:
                        yield f"Error fetching candles: {e2}"
                        return
                prompt = build_price_prediction_prompt(sym, candles)
                for text in generate_response_stream(prompt, temperature=temperature, max_new_tokens=int(max_tokens)):
                    yield text

            btn.click(on_predict, inputs=[symbol, resolution, count, temp, max_new], outputs=out, show_progress=True)
        with gr.Tab("Equity Research Report"):
            symbol2 = gr.Textbox(label="Ticker (e.g., MSFT)", value="MSFT")
            temp2 = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
            max_new2 = gr.Slider(64, 768, value=384, step=16, label="Max new tokens")
            btn2 = gr.Button("Generate Report")
            out2 = gr.Textbox(lines=30, show_copy_button=True)

            def on_report(sym, temperature, max_tokens):
                try:
                    overview = fetch_alpha_vantage_overview(sym)
                except Exception as e:
                    yield f"Error fetching fundamentals: {e}"
                    return
                prompt = build_equity_research_prompt(sym, overview)
                for text in generate_response_stream(prompt, temperature=temperature, max_new_tokens=int(max_tokens)):
                    yield text

            btn2.click(on_report, inputs=[symbol2, temp2, max_new2], outputs=out2, show_progress=True)
        # Enable queue with default settings for current Gradio version
        demo.queue()
    return demo
if __name__ == "__main__":
    # Bind to all interfaces by default so the app is reachable inside a container.
    demo_app = ui_app()
    demo_app.launch(
        server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
    )