"""Fin-o1-8B Gradio tools: short-term price prediction and equity research.

Serves a llama.cpp GGUF build of Fin-o1-8B behind a two-tab Gradio UI.
Market data comes from Finnhub (primary, multi-key) with Alpha Vantage via
RapidAPI as fallback; shape-compatible mock candles keep the UI responsive
when no API keys are configured.
"""

import os

# Force writable home and cache paths BEFORE importing huggingface_hub /
# llama_cpp, which resolve their cache directories at import time.
os.environ.setdefault("HOME", "/app")
os.environ.setdefault("HF_HOME", "/app/hf_cache")
os.environ.setdefault("HF_HUB_CACHE", "/app/hf_cache")
os.environ.setdefault("XDG_CACHE_HOME", "/app/.cache")
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
os.makedirs(os.environ["XDG_CACHE_HOME"], exist_ok=True)

import glob
import json
import random
import time
from typing import Dict, Generator, List

import gradio as gr
import requests
from dotenv import load_dotenv
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

load_dotenv()


# -------- Keys (multi-key support) --------
def _parse_keys(raw: str) -> List[str]:
    """Split a newline-separated env var value into stripped, non-empty keys."""
    return [k.strip() for k in raw.split("\n") if k.strip()] if raw else []


FINNHUB_KEYS_RAW = os.getenv("FINNHUB_KEYS", "")
FINNHUB_KEYS = _parse_keys(FINNHUB_KEYS_RAW)
FINNHUB_API_KEY = os.getenv("FINNHUB_API_KEY", "")
if FINNHUB_API_KEY and FINNHUB_API_KEY.strip():
    # Single-key var is only used when the multi-key var provided nothing.
    FINNHUB_KEYS = FINNHUB_KEYS or [FINNHUB_API_KEY.strip()]

RAPIDAPI_KEYS_RAW = os.getenv("RAPIDAPI_KEYS", "")
RAPIDAPI_KEYS = _parse_keys(RAPIDAPI_KEYS_RAW)
RAPIDAPI_KEY = os.getenv("RAPIDAPI_KEY", "")
if RAPIDAPI_KEY and RAPIDAPI_KEY.strip():
    RAPIDAPI_KEYS = RAPIDAPI_KEYS or [RAPIDAPI_KEY.strip()]

RAPIDAPI_HOST = "alpha-vantage.p.rapidapi.com"

# -------- llama.cpp GGUF model --------
MODEL_REPO = "mradermacher/Fin-o1-8B-GGUF"
GGUF_OVERRIDE = os.getenv("GGUF_FILENAME", "").strip()
N_THREADS = int(os.getenv("LLAMA_CPP_THREADS", str(os.cpu_count() or 4)))
CTX_LEN = int(os.getenv("LLAMA_CPP_CTX", "3072"))  # CPU-friendly default
N_BATCH = int(os.getenv("LLAMA_CPP_BATCH", "128"))

from huggingface_hub import snapshot_download
from llama_cpp import Llama

_llm = None  # lazily-loaded singleton Llama instance (see load_model)


def _pick_gguf_file(root_dir: str, override: str | None) -> str:
    """Locate the GGUF weights file inside a downloaded snapshot.

    Honours an explicit *override* filename first (direct path, then a
    recursive search for non-empty matches), then falls back to the best
    quantization found, preferring Q4_K_M downwards.

    Raises:
        FileNotFoundError: if the snapshot contains no .gguf files at all.
    """
    if override:
        direct = os.path.join(root_dir, override)
        if os.path.isfile(direct) and os.path.getsize(direct) > 0:
            return direct
        for candidate in glob.glob(os.path.join(root_dir, "**", override), recursive=True):
            if os.path.getsize(candidate) > 0:
                return candidate
    # Quantization preference order, best quality first.
    preferred: List[str] = [
        "Fin-o1-8B.Q4_K_M.gguf",  # explicit 8B file name first
        "Q4_K_M", "Q4_K_S", "Q4_0",
        "Q3_K_M", "Q3_K_S", "Q3_0",
        "Q2_K", "Q2_0",
    ]
    ggufs = glob.glob(os.path.join(root_dir, "**", "*.gguf"), recursive=True)
    if not ggufs:
        raise FileNotFoundError("No .gguf files found in snapshot")
    for key in preferred:
        for f in ggufs:
            if key in os.path.basename(f):
                return f
    return ggufs[0]


def load_model() -> Llama:
    """Download (if needed) and load the GGUF model; cached after first call."""
    global _llm
    if _llm is not None:
        return _llm
    # NOTE: resume_download was dropped — it is a deprecated no-op in current
    # huggingface_hub; resuming partial downloads is the default behavior.
    repo_dir = snapshot_download(
        repo_id=MODEL_REPO,
        allow_patterns=["*.gguf"],
        cache_dir=os.getenv("HF_HOME", "/app/hf_cache"),
        local_files_only=False,
    )
    try:
        model_path = _pick_gguf_file(repo_dir, GGUF_OVERRIDE or None)
    except Exception as e:
        raise RuntimeError(f"GGUF not found: {e}") from e
    try:
        _llm = Llama(
            model_path=model_path,
            n_ctx=CTX_LEN,
            n_threads=N_THREADS,
            n_batch=N_BATCH,
            use_mlock=False,
            use_mmap=True,
            verbose=False,
        )
    except Exception as e:
        raise RuntimeError(
            f"Failed to load GGUF: {e}. Set GGUF_FILENAME to an available 8B file if needed."
        ) from e
    return _llm


def generate_response_stream(prompt: str, temperature: float = 0.2, max_new_tokens: int = 384) -> Generator[str, None, None]:
    """Yield the growing completion text; the first yields are status messages."""
    yield "Initializing model..."
    llm = load_model()
    yield "Model loaded. Generating..."
    accum = ""
    for chunk in llm(prompt=prompt, max_tokens=max_new_tokens, temperature=temperature, stream=True):
        text = chunk.get("choices", [{}])[0].get("text", "")
        if text:
            accum += text
            yield accum


def generate_response(prompt: str, temperature: float = 0.2, max_new_tokens: int = 384) -> str:
    """Non-streaming fallback: return the full completion as one string."""
    llm = load_model()
    # BUG FIX: the stop list previously contained an empty string, which
    # matches at position 0 in llama-cpp-python and terminates generation
    # immediately, always returning empty text. Only real stop tokens remain.
    res = llm(
        prompt=prompt,
        max_tokens=max_new_tokens,
        temperature=temperature,
        stop=["<|eot_id|>"],
    )
    return res.get("choices", [{}])[0].get("text", "")


# -------- Robust requests session with retry --------
def create_session() -> requests.Session:
    """Build a requests session that retries transient HTTP failures."""
    session = requests.Session()
    retry_strategy = Retry(
        total=3,
        backoff_factor=1.0,
        status_forcelist=[429, 500, 502, 503, 504],
    )
    adapter = HTTPAdapter(max_retries=retry_strategy)
    session.mount("http://", adapter)
    session.mount("https://", adapter)
    return session


http = create_session()


# -------- Helpers for mock candles --------
def _create_mock_candles(symbol: str, count: int = 60) -> Dict:
    """Synthesize a random daily OHLCV series shaped like a Finnhub candle
    response ({"s","t","o","h","l","c","v"}), so the UI works without keys."""
    now = int(time.time())
    day = 86400
    t, o, h, l, c, v = [], [], [], [], [], []
    # NOTE: hash() is salted per process, so the base price differs between
    # runs — acceptable for mock data, which is random anyway.
    price = 100.0 + (hash(symbol) % 50)
    for i in range(count, 0, -1):
        ts = now - i * day
        op = max(1.0, price + random.uniform(-2, 2))
        hi = op + random.uniform(0, 2)
        lo = max(0.5, op - random.uniform(0, 2))
        cl = max(0.5, lo + random.uniform(0, (hi - lo) or 1))
        vol = abs(int(random.gauss(1_000_000, 250_000)))
        t.append(ts)
        o.append(op)
        h.append(hi)
        l.append(lo)
        c.append(cl)
        v.append(vol)
        price = cl  # random walk: next bar opens near the previous close
    return {"s": "ok", "t": t, "o": o, "h": h, "l": l, "c": c, "v": v, "source": "mock"}


# -------- Data helpers (Finnhub with fallback to Alpha Vantage) --------
def fetch_finnhub_candles(symbol: str, resolution: str = "D", count: int = 60) -> Dict:
    """Try Finnhub first cycling keys; on 401/403 or exhaustion, raise to caller."""
    if not FINNHUB_KEYS:
        raise ValueError("Missing FINNHUB_KEYS/FINNHUB_API_KEY")
    end = int(time.time())
    start = end - count * 86400
    last_error: Exception | None = None
    for i, api_key in enumerate(FINNHUB_KEYS):
        if i:
            # BUG FIX: this pause used to run in a finally clause, which also
            # delayed every *successful* return; now it only separates retries.
            time.sleep(0.3)
        url = (
            f"https://finnhub.io/api/v1/stock/candle?symbol={symbol}"
            f"&resolution={resolution}&from={start}&to={end}&token={api_key}"
        )
        try:
            r = http.get(url, timeout=30)
            if r.status_code in (401, 403):
                last_error = requests.HTTPError(f"Finnhub auth error {r.status_code}")
                continue
            r.raise_for_status()
            data = r.json()
            if data.get("s") == "ok":
                data["source"] = "finnhub"
                return data
            last_error = RuntimeError(f"Finnhub returned status: {data.get('s')}")
        except Exception as e:
            last_error = e
    raise last_error or RuntimeError("Finnhub candles failed")


def fetch_alpha_vantage_series_daily(symbol: str, outputsize: str = "compact", count: int = 60) -> Dict:
    """Fallback: Alpha Vantage TIME_SERIES_DAILY via RapidAPI, formatted like
    Finnhub candles. Returns mock candles if no keys exist or all keys fail,
    so the caller (and UI) always gets a renderable series."""
    if not RAPIDAPI_KEYS:
        return _create_mock_candles(symbol, count)
    for i, api_key in enumerate(RAPIDAPI_KEYS):
        if i:
            time.sleep(0.5)  # pause only between key attempts, not after success
        try:
            r = http.get(
                f"https://{RAPIDAPI_HOST}/query",
                headers={"X-RapidAPI-Key": api_key, "X-RapidAPI-Host": RAPIDAPI_HOST},
                params={"function": "TIME_SERIES_DAILY", "symbol": symbol, "outputsize": outputsize},
                timeout=30,
            )
            r.raise_for_status()
            data = r.json()
            if isinstance(data, dict) and any(k in data for k in ("Note", "Error Message", "Information")):
                # Rate limit or API error; try the next key.
                continue
            series = data.get("Time Series (Daily)") or {}
            if not series:
                continue
            # ISO dates sort lexicographically == chronologically; keep newest `count`.
            dates = sorted(series.keys())[-count:]
            t, o, h, l, c, v = [], [], [], [], [], []
            for d in dates:
                row = series[d]
                try:
                    op_v = float(row.get("1. open"))
                    h_v = float(row.get("2. high"))
                    l_v = float(row.get("3. low"))
                    c_v = float(row.get("4. close"))
                    v_v = float(row.get("5. volume"))
                    ts = int(time.mktime(time.strptime(d, "%Y-%m-%d")))
                except Exception:
                    continue  # skip malformed rows
                t.append(ts)
                o.append(op_v)
                h.append(h_v)
                l.append(l_v)
                c.append(c_v)
                v.append(v_v)
            if not t:
                # BUG FIX: previously an all-malformed response was returned as
                # an "ok" series of empty arrays; try the next key instead.
                continue
            return {"s": "ok", "t": t, "o": o, "h": h, "l": l, "c": c, "v": v, "source": "alpha_vantage"}
        except Exception:
            continue
    # If all keys failed or no data, return mock to keep UI responsive.
    return _create_mock_candles(symbol, count)


def fetch_alpha_vantage_overview(symbol: str) -> Dict:
    """Fetch company fundamentals (OVERVIEW) via RapidAPI, cycling keys.

    Raises:
        ValueError: when no RapidAPI keys are configured.
        RuntimeError: when every key fails or returns an empty payload.
    """
    if not RAPIDAPI_KEYS:
        raise ValueError("Missing RAPIDAPI_KEYS/RAPIDAPI_KEY")
    for api_key in RAPIDAPI_KEYS:
        try:
            r = http.get(
                f"https://{RAPIDAPI_HOST}/query",
                headers={"x-rapidapi-key": api_key, "x-rapidapi-host": RAPIDAPI_HOST},
                params={"function": "OVERVIEW", "symbol": symbol},
                timeout=30,
            )
            r.raise_for_status()
            data = r.json()
            if data:
                return data
        except Exception:
            continue
    raise RuntimeError("Alpha Vantage OVERVIEW failed")


# -------- Prompts --------
def build_price_prediction_prompt(symbol: str, candles: Dict) -> str:
    """Build the price-prediction prompt; candle JSON is capped at 10k chars
    to stay within the CPU-friendly context window."""
    context = json.dumps(candles)[:10000]
    source = candles.get("source", "finnhub")
    return (
        f"You are a financial analyst agent. Analyze recent OHLCV candles for {symbol} (source: {source}) and provide a short-term price prediction. "
        f"Explain key drivers in bullet points and give a 1-2 sentence forecast.\n\nData JSON: {context}\n\n"
    )


def build_equity_research_prompt(symbol: str, overview: Dict) -> str:
    """Build the equity-research prompt; fundamentals JSON capped at 10k chars."""
    context = json.dumps(overview)[:10000]
    return (
        "You are an equity research analyst. Using the fundamentals overview, write a concise equity research note including: "
        "Business summary, recent performance, profitability, leverage, valuation multiples, key risks, and an investment view (Buy/Hold/Sell) with rationale.\n\n"
        f"Ticker: {symbol}\nFundamentals JSON: {context}\n"
    )


# -------- Gradio UI --------
def ui_app():
    """Assemble the two-tab Gradio Blocks app (queue enabled for streaming)."""
    with gr.Blocks(title="Fin-o1-8B Tools") as demo:
        gr.Markdown("""# Fin-o1-8B Tools
Two tabs: Price Prediction (Finnhub with Alpha Vantage fallback) and Equity Research (Alpha Vantage via RapidAPI).""")

        with gr.Tab("Price Prediction"):
            symbol = gr.Textbox(label="Ticker (e.g., AAPL)", value="AAPL")
            resolution = gr.Dropdown(["D", "60", "30", "15", "5"], value="D", label="Resolution")
            count = gr.Slider(20, 160, value=60, step=5, label="Num candles")
            temp = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
            max_new = gr.Slider(64, 768, value=384, step=16, label="Max new tokens")
            btn = gr.Button("Predict")
            out = gr.Textbox(lines=30, show_copy_button=True)

            def on_predict(sym, res, cnt, temperature, max_tokens):
                """Stream a prediction; fall back to Alpha Vantage on Finnhub failure."""
                try:
                    candles = fetch_finnhub_candles(sym, res, int(cnt))
                except Exception:
                    try:
                        # BUG FIX: forward the user-selected candle count to
                        # the fallback instead of silently using the default 60.
                        candles = fetch_alpha_vantage_series_daily(sym, outputsize="compact", count=int(cnt))
                    except Exception as e2:
                        yield f"Error fetching candles: {e2}"
                        return
                prompt = build_price_prediction_prompt(sym, candles)
                for text in generate_response_stream(prompt, temperature=temperature, max_new_tokens=int(max_tokens)):
                    yield text

            btn.click(on_predict, inputs=[symbol, resolution, count, temp, max_new], outputs=out, show_progress=True)

        with gr.Tab("Equity Research Report"):
            symbol2 = gr.Textbox(label="Ticker (e.g., MSFT)", value="MSFT")
            temp2 = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
            max_new2 = gr.Slider(64, 768, value=384, step=16, label="Max new tokens")
            btn2 = gr.Button("Generate Report")
            out2 = gr.Textbox(lines=30, show_copy_button=True)

            def on_report(sym, temperature, max_tokens):
                """Stream an equity research note built from fundamentals."""
                try:
                    overview = fetch_alpha_vantage_overview(sym)
                except Exception as e:
                    yield f"Error fetching fundamentals: {e}"
                    return
                prompt = build_equity_research_prompt(sym, overview)
                for text in generate_response_stream(prompt, temperature=temperature, max_new_tokens=int(max_tokens)):
                    yield text

            btn2.click(on_report, inputs=[symbol2, temp2, max_new2], outputs=out2, show_progress=True)

    # Enable queue with default settings for current Gradio version
    demo.queue()
    return demo


if __name__ == "__main__":
    app = ui_app()
    app.launch(
        server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
    )