# NOTE: the "Spaces: Sleeping" banner that preceded this file is Hugging Face
# Space status-page residue from the capture, not part of the application code.
import os

# Force writable home and cache paths before any library that touches the
# filesystem is imported (HF Hub and Gradio both write under these).
os.environ.setdefault("HOME", "/app")
os.environ.setdefault("HF_HOME", "/app/hf_cache")
os.environ.setdefault("HF_HUB_CACHE", "/app/hf_cache")
os.environ.setdefault("XDG_CACHE_HOME", "/app/.cache")
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
os.makedirs(os.environ["XDG_CACHE_HOME"], exist_ok=True)

import json
from typing import Dict, List, Generator

import gradio as gr
import requests
from dotenv import load_dotenv

load_dotenv()


# -------- Keys (multi-key support) --------
def _split_env_keys(raw: str) -> List[str]:
    """Split a newline-separated env value into stripped, non-empty keys."""
    if not raw:
        return []
    return [entry.strip() for entry in raw.split("\n") if entry.strip()]


# Multi-key variables hold one API key per line; the single-key variables act
# as a fallback only when no multi-key list is configured.
FINNHUB_KEYS_RAW = os.getenv("FINNHUB_KEYS", "")
FINNHUB_KEYS = _split_env_keys(FINNHUB_KEYS_RAW)
FINNHUB_API_KEY = os.getenv("FINNHUB_API_KEY", "")
if FINNHUB_API_KEY.strip():
    FINNHUB_KEYS = FINNHUB_KEYS or [FINNHUB_API_KEY.strip()]

RAPIDAPI_KEYS_RAW = os.getenv("RAPIDAPI_KEYS", "")
RAPIDAPI_KEYS = _split_env_keys(RAPIDAPI_KEYS_RAW)
RAPIDAPI_KEY = os.getenv("RAPIDAPI_KEY", "")
if RAPIDAPI_KEY.strip():
    RAPIDAPI_KEYS = RAPIDAPI_KEYS or [RAPIDAPI_KEY.strip()]

RAPIDAPI_HOST = "alpha-vantage.p.rapidapi.com"

# -------- llama.cpp GGUF model --------
MODEL_REPO = "mradermacher/Fin-o1-8B-GGUF"
GGUF_OVERRIDE = os.getenv("GGUF_FILENAME", "").strip()
N_THREADS = int(os.getenv("LLAMA_CPP_THREADS", str(os.cpu_count() or 4)))
CTX_LEN = int(os.getenv("LLAMA_CPP_CTX", "3072"))  # CPU-friendly default
N_BATCH = int(os.getenv("LLAMA_CPP_BATCH", "128"))

from huggingface_hub import snapshot_download
from llama_cpp import Llama

# Lazily-populated singleton model handle (see load_model).
_llm = None
| def _pick_gguf_file(root_dir: str, override: str | None) -> str: | |
| import glob | |
| if override: | |
| path = os.path.join(root_dir, override) | |
| if os.path.isfile(path) and os.path.getsize(path) > 0: | |
| return path | |
| candidates = glob.glob(os.path.join(root_dir, "**", override), recursive=True) | |
| for c in candidates: | |
| if os.path.getsize(c) > 0: | |
| return c | |
| preferred: List[str] = [ | |
| "Fin-o1-8B.Q4_K_M.gguf", # explicit 8B file name first | |
| "Q4_K_M", "Q4_K_S", "Q4_0", "Q3_K_M", "Q3_K_S", "Q3_0", "Q2_K", "Q2_0", | |
| ] | |
| import glob as _glob | |
| ggufs = _glob.glob(os.path.join(root_dir, "**", "*.gguf"), recursive=True) | |
| if not ggufs: | |
| raise FileNotFoundError("No .gguf files found in snapshot") | |
| for key in preferred: | |
| for f in ggufs: | |
| if key in os.path.basename(f): | |
| return f | |
| return ggufs[0] | |
def load_model():
    """Download (if needed) and load the GGUF model, caching it in `_llm`.

    Returns the module-level singleton `Llama` instance; subsequent calls
    are free. Raises RuntimeError when no GGUF file can be located or the
    model fails to load.
    """
    global _llm
    if _llm is not None:
        return _llm
    # Fetch only the .gguf weights into the HF cache. `resume_download`
    # was dropped: it is deprecated in huggingface_hub and resuming is
    # now the default behavior.
    repo_dir = snapshot_download(
        repo_id=MODEL_REPO,
        allow_patterns=["*.gguf"],
        cache_dir=os.getenv("HF_HOME", "/app/hf_cache"),
        local_files_only=False,
    )
    try:
        model_path = _pick_gguf_file(repo_dir, GGUF_OVERRIDE or None)
    except Exception as e:
        raise RuntimeError(f"GGUF not found: {e}") from e
    try:
        _llm = Llama(
            model_path=model_path,
            n_ctx=CTX_LEN,
            n_threads=N_THREADS,
            n_batch=N_BATCH,
            use_mlock=False,
            use_mmap=True,  # mmap keeps resident memory low on CPU-only hosts
            verbose=False,
        )
    except Exception as e:
        raise RuntimeError(
            f"Failed to load GGUF: {e}. Set GGUF_FILENAME to an available 8B file if needed."
        ) from e
    return _llm
def generate_response_stream(prompt: str, temperature: float = 0.2, max_new_tokens: int = 384) -> Generator[str, None, None]:
    """Stream a completion, yielding progressively longer accumulated text.

    Two status strings are yielded first ("Initializing model...",
    "Model loaded. Generating...") so the UI shows progress while the
    model loads; each subsequent yield is the full text generated so far.
    """
    yield "Initializing model..."
    llm = load_model()
    yield "Model loaded. Generating..."
    accum = ""
    # Use the same stop tokens as generate_response() so the streaming and
    # non-streaming paths terminate identically (previously only the
    # non-streaming path stopped on these).
    for chunk in llm(
        prompt=prompt,
        max_tokens=max_new_tokens,
        temperature=temperature,
        stream=True,
        stop=["</s>", "<|eot_id|>"],
    ):
        text = chunk.get("choices", [{}])[0].get("text", "")
        if text:
            accum += text
            yield accum
def generate_response(prompt: str, temperature: float = 0.2, max_new_tokens: int = 384) -> str:
    """Run a single non-streaming completion and return the generated text."""
    # Non-streaming fallback path.
    llm = load_model()
    completion = llm(
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_new_tokens,
        stop=["</s>", "<|eot_id|>"],
    )
    choices = completion.get("choices", [{}])
    return choices[0].get("text", "")
# -------- Robust requests session with retry --------
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry


def create_session() -> requests.Session:
    """Build a requests session that retries transient HTTP failures.

    Up to 3 retries with exponential backoff on 429 and common 5xx codes.
    """
    retries = Retry(
        total=3,
        backoff_factor=1.0,
        status_forcelist=[429, 500, 502, 503, 504],
    )
    sess = requests.Session()
    for scheme in ("http://", "https://"):
        sess.mount(scheme, HTTPAdapter(max_retries=retries))
    return sess


# Shared module-level session used by every data fetcher below.
http = create_session()
| # -------- Helpers for mock candles -------- | |
| def _create_mock_candles(symbol: str, count: int = 60) -> Dict: | |
| import time as tmod | |
| import random | |
| now = int(tmod.time()) | |
| day = 86400 | |
| t, o, h, l, c, v = [], [], [], [], [], [] | |
| price = 100.0 + (hash(symbol) % 50) | |
| for i in range(count, 0, -1): | |
| ts = now - i * day | |
| op = max(1.0, price + random.uniform(-2, 2)) | |
| hi = op + random.uniform(0, 2) | |
| lo = max(0.5, op - random.uniform(0, 2)) | |
| cl = max(0.5, lo + random.uniform(0, (hi - lo) or 1)) | |
| vol = abs(int(random.gauss(1_000_000, 250_000))) | |
| t.append(ts); o.append(op); h.append(hi); l.append(lo); c.append(cl); v.append(vol) | |
| price = cl | |
| return {"s": "ok", "t": t, "o": o, "h": h, "l": l, "c": c, "v": v, "source": "mock"} | |
# -------- Data helpers (Finnhub with fallback to Alpha Vantage) --------
def fetch_finnhub_candles(symbol: str, resolution: str = "D", count: int = 60) -> Dict:
    """Fetch OHLCV candles from Finnhub, cycling through configured keys.

    Auth failures (401/403), transport errors, and non-"ok" payloads move
    on to the next key.

    Raises:
        ValueError: when no Finnhub keys are configured.
        Exception: the last error seen when every key fails.
    """
    if not FINNHUB_KEYS:
        raise ValueError("Missing FINNHUB_KEYS/FINNHUB_API_KEY")
    import time as _time

    end = int(_time.time())
    start = end - count * 86400
    last_error: Exception | None = None
    for idx, api_key in enumerate(FINNHUB_KEYS):
        # Pause between key attempts only. The old `finally: sleep(0.3)`
        # also ran on the successful `return`, delaying every success.
        if idx:
            _time.sleep(0.3)
        url = (
            f"https://finnhub.io/api/v1/stock/candle?symbol={symbol}"
            f"&resolution={resolution}&from={start}&to={end}&token={api_key}"
        )
        try:
            r = http.get(url, timeout=30)
            if r.status_code in (401, 403):
                last_error = requests.HTTPError(f"Finnhub auth error {r.status_code}")
                continue
            r.raise_for_status()
            data = r.json()
            if data.get("s") == "ok":
                data["source"] = "finnhub"
                return data
            last_error = RuntimeError(f"Finnhub returned status: {data.get('s')}")
        except Exception as e:
            last_error = e
    raise last_error or RuntimeError("Finnhub candles failed")
def fetch_alpha_vantage_series_daily(symbol: str, outputsize: str = "compact", count: int = 60) -> Dict:
    """Fallback candles: Alpha Vantage TIME_SERIES_DAILY via RapidAPI.

    Returns data shaped like a Finnhub candle payload. When no RapidAPI
    keys are configured, or every key fails or yields no data, returns
    mock candles so the UI stays responsive.
    """
    if not RAPIDAPI_KEYS:
        return _create_mock_candles(symbol, count)
    import time as _time

    for idx, api_key in enumerate(RAPIDAPI_KEYS):
        # Pause between key attempts only. The old `finally: sleep(0.5)`
        # also ran on the successful `return`, delaying every success.
        if idx:
            _time.sleep(0.5)
        try:
            url = f"https://{RAPIDAPI_HOST}/query"
            headers = {"X-RapidAPI-Key": api_key, "X-RapidAPI-Host": RAPIDAPI_HOST}
            params = {"function": "TIME_SERIES_DAILY", "symbol": symbol, "outputsize": outputsize}
            r = http.get(url, headers=headers, params=params, timeout=30)
            r.raise_for_status()
            data = r.json()
            # Alpha Vantage reports rate limits/errors inside a 200 body.
            if isinstance(data, dict) and any(k in data for k in ("Note", "Error Message", "Information")):
                continue
            series = data.get("Time Series (Daily)") or {}
            if not series:
                continue
            dates = sorted(series.keys())[-count:]
            t, o, h, l, c, v = [], [], [], [], [], []
            for d in dates:
                row = series[d]
                try:
                    op_v = float(row.get("1. open"))
                    h_v = float(row.get("2. high"))
                    l_v = float(row.get("3. low"))
                    c_v = float(row.get("4. close"))
                    v_v = float(row.get("5. volume"))
                    # NOTE(review): mktime interprets the date in the local
                    # timezone — presumably acceptable for daily candles.
                    ts = int(_time.mktime(_time.strptime(d, "%Y-%m-%d")))
                except Exception:
                    continue  # skip malformed rows
                t.append(ts); o.append(op_v); h.append(h_v); l.append(l_v); c.append(c_v); v.append(v_v)
            if not t:
                # Every row was malformed; treat this key as failed rather
                # than returning an "ok" payload with empty arrays.
                continue
            return {"s": "ok", "t": t, "o": o, "h": h, "l": l, "c": c, "v": v, "source": "alpha_vantage"}
        except Exception:
            continue
    # If all keys failed or no data, return mock to keep UI responsive.
    return _create_mock_candles(symbol, count)
def fetch_alpha_vantage_overview(symbol: str) -> Dict:
    """Fetch company fundamentals (OVERVIEW) via RapidAPI, cycling keys.

    Raises:
        ValueError: when no RapidAPI keys are configured.
        RuntimeError: when every key fails.
    """
    if not RAPIDAPI_KEYS:
        raise ValueError("Missing RAPIDAPI_KEYS/RAPIDAPI_KEY")
    for api_key in RAPIDAPI_KEYS:
        try:
            url = f"https://{RAPIDAPI_HOST}/query"
            headers = {"x-rapidapi-key": api_key, "x-rapidapi-host": RAPIDAPI_HOST}
            params = {"function": "OVERVIEW", "symbol": symbol}
            r = http.get(url, headers=headers, params=params, timeout=30)
            r.raise_for_status()
            data = r.json()
            # Match the series fetcher: Alpha Vantage reports rate limits
            # and errors inside a 200 body — don't return those as
            # fundamentals; try the next key instead.
            if isinstance(data, dict) and any(k in data for k in ("Note", "Error Message", "Information")):
                continue
            if data:
                return data
        except Exception:
            continue
    raise RuntimeError("Alpha Vantage OVERVIEW failed")
# -------- Prompts --------
def build_price_prediction_prompt(symbol: str, candles: Dict) -> str:
    """Compose the short-term price-prediction prompt.

    The candle JSON is truncated to 10,000 characters to keep the prompt
    within the model's context window.
    """
    payload = json.dumps(candles)[:10000]
    origin = candles.get("source", "finnhub")
    parts = [
        f"You are a financial analyst agent. Analyze recent OHLCV candles for {symbol} (source: {origin}) and provide a short-term price prediction. ",
        f"Explain key drivers in bullet points and give a 1-2 sentence forecast.\n\nData JSON: {payload}\n\n",
    ]
    return "".join(parts)
def build_equity_research_prompt(symbol: str, overview: Dict) -> str:
    """Compose the equity-research prompt from a fundamentals overview.

    The overview JSON is truncated to 10,000 characters to keep the
    prompt within the model's context window.
    """
    payload = json.dumps(overview)[:10000]
    intro = (
        "You are an equity research analyst. Using the fundamentals overview, write a concise equity research note including: "
        "Business summary, recent performance, profitability, leverage, valuation multiples, key risks, and an investment view (Buy/Hold/Sell) with rationale.\n\n"
    )
    return intro + f"Ticker: {symbol}\nFundamentals JSON: {payload}\n"
# -------- Gradio UI --------
def ui_app():
    """Build the two-tab Gradio app: price prediction and equity research.

    Returns the gr.Blocks demo with queueing enabled; the caller launches it.
    """
    with gr.Blocks(title="Fin-o1-8B Tools") as demo:
        gr.Markdown("""# Fin-o1-8B Tools
Two tabs: Price Prediction (Finnhub with Alpha Vantage fallback) and Equity Research (Alpha Vantage via RapidAPI).""")
        with gr.Tab("Price Prediction"):
            symbol = gr.Textbox(label="Ticker (e.g., AAPL)", value="AAPL")
            resolution = gr.Dropdown(["D", "60", "30", "15", "5"], value="D", label="Resolution")
            count = gr.Slider(20, 160, value=60, step=5, label="Num candles")
            temp = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
            max_new = gr.Slider(64, 768, value=384, step=16, label="Max new tokens")
            btn = gr.Button("Predict")
            out = gr.Textbox(lines=30, show_copy_button=True)

            def on_predict(sym, res, cnt, temperature, max_tokens):
                """Fetch candles (Finnhub, then Alpha Vantage) and stream the analysis."""
                try:
                    candles = fetch_finnhub_candles(sym, res, int(cnt))
                except Exception:
                    try:
                        # BUGFIX: honor the user's candle count in the fallback
                        # too (it previously always used the default of 60).
                        # NOTE(review): the fallback is daily-only, so a
                        # sub-daily resolution is effectively ignored here.
                        candles = fetch_alpha_vantage_series_daily(sym, outputsize="compact", count=int(cnt))
                    except Exception as e2:
                        yield f"Error fetching candles: {e2}"
                        return
                prompt = build_price_prediction_prompt(sym, candles)
                for text in generate_response_stream(prompt, temperature=temperature, max_new_tokens=int(max_tokens)):
                    yield text

            btn.click(on_predict, inputs=[symbol, resolution, count, temp, max_new], outputs=out, show_progress=True)
        with gr.Tab("Equity Research Report"):
            symbol2 = gr.Textbox(label="Ticker (e.g., MSFT)", value="MSFT")
            temp2 = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
            max_new2 = gr.Slider(64, 768, value=384, step=16, label="Max new tokens")
            btn2 = gr.Button("Generate Report")
            out2 = gr.Textbox(lines=30, show_copy_button=True)

            def on_report(sym, temperature, max_tokens):
                """Fetch fundamentals and stream the research note."""
                try:
                    overview = fetch_alpha_vantage_overview(sym)
                except Exception as e:
                    yield f"Error fetching fundamentals: {e}"
                    return
                prompt = build_equity_research_prompt(sym, overview)
                for text in generate_response_stream(prompt, temperature=temperature, max_new_tokens=int(max_tokens)):
                    yield text

            btn2.click(on_report, inputs=[symbol2, temp2, max_new2], outputs=out2, show_progress=True)
    # Enable queue with default settings for current Gradio version
    # (required for generator/streaming callbacks).
    demo.queue()
    return demo
if __name__ == "__main__":
    # Bind to all interfaces by default so the app is reachable inside a
    # container; host/port remain overridable via environment variables.
    host = os.getenv("GRADIO_SERVER_NAME", "0.0.0.0")
    port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
    app = ui_app()
    app.launch(server_name=host, server_port=port)