File size: 11,822 Bytes
1d2b720
163685a
 
 
 
 
 
 
 
 
1d2b720
76d55bc
1d2b720
 
 
 
 
 
 
807243a
1d2b720
 
 
 
 
 
 
 
 
 
 
 
 
 
807243a
4cc5fe9
 
807243a
4cc5fe9
 
807243a
4cc5fe9
807243a
 
 
 
4cc5fe9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
807243a
 
 
 
4cc5fe9
 
 
 
 
 
807243a
4cc5fe9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
807243a
 
 
76d55bc
 
 
 
 
 
 
 
 
 
 
 
807243a
76d55bc
807243a
 
 
 
 
 
 
 
 
1d2b720
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05412f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d2b720
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05412f4
1d2b720
 
05412f4
1d2b720
 
 
 
05412f4
1d2b720
 
 
 
05412f4
 
 
1d2b720
 
 
05412f4
1d2b720
 
807243a
1d2b720
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05412f4
 
1d2b720
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76d55bc
 
1d2b720
 
 
 
 
 
 
 
 
1bfc9c7
1d2b720
 
 
 
 
 
 
 
faa18d1
 
1d2b720
faa18d1
 
1d2b720
76d55bc
1d2b720
 
 
 
 
 
1bfc9c7
1d2b720
 
 
 
 
faa18d1
 
 
 
 
1d2b720
76d55bc
1d2b720
ed1d9e5
 
1d2b720
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
import os

# Force writable home and cache paths before other imports
# (presumably for a containerized host where $HOME is not writable -- TODO confirm).
os.environ.setdefault("HOME", "/app")
os.environ.setdefault("HF_HOME", "/app/hf_cache")
os.environ.setdefault("HF_HUB_CACHE", "/app/hf_cache")
os.environ.setdefault("XDG_CACHE_HOME", "/app/.cache")
# Create the cache directories up front so downstream libraries never fail on them.
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
os.makedirs(os.environ["XDG_CACHE_HOME"], exist_ok=True)

import json
from typing import Dict, List, Generator

import gradio as gr
import requests
from dotenv import load_dotenv

# Load an optional .env file so local runs can supply API keys without exporting them.
load_dotenv()

# -------- Keys (multi-key support) --------
# FINNHUB_KEYS holds one API key per line (newline-separated).
FINNHUB_KEYS_RAW = os.getenv("FINNHUB_KEYS", "")
FINNHUB_KEYS = [k.strip() for k in FINNHUB_KEYS_RAW.split("\n") if k.strip()] if FINNHUB_KEYS_RAW else []
# Single-key fallback: only used when the multi-key list came out empty.
FINNHUB_API_KEY = os.getenv("FINNHUB_API_KEY", "")
if FINNHUB_API_KEY and FINNHUB_API_KEY.strip():
	FINNHUB_KEYS = FINNHUB_KEYS or [FINNHUB_API_KEY.strip()]

# Same scheme for RapidAPI (Alpha Vantage proxy) keys.
RAPIDAPI_KEYS_RAW = os.getenv("RAPIDAPI_KEYS", "")
RAPIDAPI_KEYS = [k.strip() for k in RAPIDAPI_KEYS_RAW.split("\n") if k.strip()] if RAPIDAPI_KEYS_RAW else []
RAPIDAPI_KEY = os.getenv("RAPIDAPI_KEY", "")
if RAPIDAPI_KEY and RAPIDAPI_KEY.strip():
	RAPIDAPI_KEYS = RAPIDAPI_KEYS or [RAPIDAPI_KEY.strip()]

# RapidAPI host serving the Alpha Vantage endpoints used by the fallback fetchers.
RAPIDAPI_HOST = "alpha-vantage.p.rapidapi.com"

# -------- llama.cpp GGUF model --------
MODEL_REPO = "mradermacher/Fin-o1-8B-GGUF"  # HF repo containing quantized GGUF files
GGUF_OVERRIDE = os.getenv("GGUF_FILENAME", "").strip()  # optional explicit file choice
N_THREADS = int(os.getenv("LLAMA_CPP_THREADS", str(os.cpu_count() or 4)))
CTX_LEN = int(os.getenv("LLAMA_CPP_CTX", "3072"))  # CPU-friendly default
N_BATCH = int(os.getenv("LLAMA_CPP_BATCH", "128"))

from huggingface_hub import snapshot_download
from llama_cpp import Llama

# Lazily-initialized model singleton; populated by load_model() on first use.
_llm = None


def _pick_gguf_file(root_dir: str, override: str | None) -> str:
	import glob
	if override:
		path = os.path.join(root_dir, override)
		if os.path.isfile(path) and os.path.getsize(path) > 0:
			return path
		candidates = glob.glob(os.path.join(root_dir, "**", override), recursive=True)
		for c in candidates:
			if os.path.getsize(c) > 0:
				return c
	preferred: List[str] = [
		"Fin-o1-8B.Q4_K_M.gguf",  # explicit 8B file name first
		"Q4_K_M", "Q4_K_S", "Q4_0", "Q3_K_M", "Q3_K_S", "Q3_0", "Q2_K", "Q2_0",
	]
	import glob as _glob
	ggufs = _glob.glob(os.path.join(root_dir, "**", "*.gguf"), recursive=True)
	if not ggufs:
		raise FileNotFoundError("No .gguf files found in snapshot")
	for key in preferred:
		for f in ggufs:
			if key in os.path.basename(f):
				return f
	return ggufs[0]


def load_model():
	"""Download (if needed) and construct the llama.cpp model, caching it.

	The Llama instance is stored in the module-global ``_llm`` so repeated
	calls are cheap. Raises RuntimeError when no usable GGUF file is found
	or when llama.cpp fails to load it.
	"""
	global _llm
	if _llm is None:
		# Pull only the GGUF files from the repo into the HF cache.
		snapshot_dir = snapshot_download(
			repo_id=MODEL_REPO,
			allow_patterns=["*.gguf"],
			cache_dir=os.getenv("HF_HOME", "/app/hf_cache"),
			local_files_only=False,
			resume_download=True,
		)
		try:
			gguf_path = _pick_gguf_file(snapshot_dir, GGUF_OVERRIDE or None)
		except Exception as exc:
			raise RuntimeError(f"GGUF not found: {exc}")
		try:
			_llm = Llama(
				model_path=gguf_path,
				n_ctx=CTX_LEN,
				n_threads=N_THREADS,
				n_batch=N_BATCH,
				use_mlock=False,
				use_mmap=True,
				verbose=False,
			)
		except Exception as exc:
			raise RuntimeError(f"Failed to load GGUF: {exc}. Set GGUF_FILENAME to an available 8B file if needed.")
	return _llm


def generate_response_stream(prompt: str, temperature: float = 0.2, max_new_tokens: int = 384) -> Generator[str, None, None]:
	"""Stream the model's completion for *prompt*.

	Yields two progress messages first, then the accumulated completion
	after every chunk so the UI textbox always shows the full answer so far.
	"""
	yield "Initializing model..."
	llm = load_model()
	yield "Model loaded. Generating..."
	accum = ""
	# Use the same stop sequences as generate_response() so the streaming
	# and non-streaming paths terminate identically.
	for chunk in llm(
		prompt=prompt,
		max_tokens=max_new_tokens,
		temperature=temperature,
		stream=True,
		stop=["</s>", "<|eot_id|>"],
	):
		text = chunk.get("choices", [{}])[0].get("text", "")
		if text:
			accum += text
			yield accum


def generate_response(prompt: str, temperature: float = 0.2, max_new_tokens: int = 384) -> str:
	"""Non-streaming fallback: return the whole completion as one string."""
	llm = load_model()
	completion = llm(
		prompt=prompt,
		max_tokens=max_new_tokens,
		temperature=temperature,
		stop=["</s>", "<|eot_id|>"],
	)
	choices = completion.get("choices", [{}])
	return choices[0].get("text", "")


# -------- Robust requests session with retry --------
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

def create_session() -> requests.Session:
	"""Build a requests Session that retries transient HTTP failures.

	Retries up to 3 times with exponential backoff on 429/5xx responses,
	for both http:// and https:// URLs.
	"""
	session = requests.Session()
	retries = Retry(
		total=3,
		backoff_factor=1.0,
		status_forcelist=[429, 500, 502, 503, 504],
	)
	adapter = HTTPAdapter(max_retries=retries)
	for scheme in ("http://", "https://"):
		session.mount(scheme, adapter)
	return session

# Module-wide session shared by all fetch helpers (connection pooling + retry policy).
http = create_session()

# -------- Helpers for mock candles --------

def _create_mock_candles(symbol: str, count: int = 60) -> Dict:
	import time as tmod
	import random
	now = int(tmod.time())
	day = 86400
	t, o, h, l, c, v = [], [], [], [], [], []
	price = 100.0 + (hash(symbol) % 50)
	for i in range(count, 0, -1):
		ts = now - i * day
		op = max(1.0, price + random.uniform(-2, 2))
		hi = op + random.uniform(0, 2)
		lo = max(0.5, op - random.uniform(0, 2))
		cl = max(0.5, lo + random.uniform(0, (hi - lo) or 1))
		vol = abs(int(random.gauss(1_000_000, 250_000)))
		t.append(ts); o.append(op); h.append(hi); l.append(lo); c.append(cl); v.append(vol)
		price = cl
	return {"s": "ok", "t": t, "o": o, "h": h, "l": l, "c": c, "v": v, "source": "mock"}

# -------- Data helpers (Finnhub with fallback to Alpha Vantage) --------

def fetch_finnhub_candles(symbol: str, resolution: str = "D", count: int = 60) -> Dict:
	"""Fetch OHLCV candles from Finnhub, cycling through all configured keys.

	Tries each key in FINNHUB_KEYS in order; auth failures (401/403),
	transport errors and non-"ok" payload statuses all advance to the next
	key. A successful payload is tagged with ``"source": "finnhub"``.

	Raises:
		ValueError: when no Finnhub key is configured.
		Exception: the last error seen, when every key fails.
	"""
	if not FINNHUB_KEYS:
		raise ValueError("Missing FINNHUB_KEYS/FINNHUB_API_KEY")
	import time as _time
	# Use the module we just imported instead of the original __import__("time") hack.
	end = int(_time.time())
	start = end - count * 86400
	last_error: Exception | None = None
	for api_key in FINNHUB_KEYS:
		url = (
			f"https://finnhub.io/api/v1/stock/candle?symbol={symbol}"
			f"&resolution={resolution}&from={start}&to={end}&token={api_key}"
		)
		try:
			r = http.get(url, timeout=30)
			if r.status_code in (401, 403):
				# Bad/expired key -> record and fall through to the next one.
				last_error = requests.HTTPError(f"Finnhub auth error {r.status_code}")
			else:
				r.raise_for_status()
				data = r.json()
				if data.get("s") == "ok":
					data["source"] = "finnhub"
					return data
				last_error = RuntimeError(f"Finnhub returned status: {data.get('s')}")
		except Exception as e:
			last_error = e
		# Pace key rotation; the success path returns above without sleeping
		# (the original slept in a finally clause, delaying even successes).
		_time.sleep(0.3)
	raise last_error or RuntimeError("Finnhub candles failed")


def fetch_alpha_vantage_series_daily(symbol: str, outputsize: str = "compact", count: int = 60) -> Dict:
	"""Fallback: Alpha Vantage TIME_SERIES_DAILY via RapidAPI, shaped like Finnhub candles.

	Cycles through the configured RapidAPI keys, skipping keys that return
	a rate-limit/error payload or an empty series. If no keys are
	configured, or every key fails, returns mock candles so the UI stays
	responsive. Successful data is tagged ``"source": "alpha_vantage"``.
	"""
	if not RAPIDAPI_KEYS:
		return _create_mock_candles(symbol, count)
	import time as _time  # hoisted: the original re-imported time on every key iteration

	def _to_candles(series: Dict) -> Dict:
		# Convert the Alpha Vantage per-date map into Finnhub-style parallel arrays,
		# keeping only the most recent *count* trading days.
		t, o, h, l, c, v = [], [], [], [], [], []
		for d in sorted(series.keys())[-count:]:
			row = series[d]
			try:
				op_v = float(row.get("1. open"))
				h_v = float(row.get("2. high"))
				l_v = float(row.get("3. low"))
				c_v = float(row.get("4. close"))
				v_v = float(row.get("5. volume"))
				ts = int(_time.mktime(_time.strptime(d, "%Y-%m-%d")))
			except Exception:
				# Skip malformed rows rather than failing the whole series.
				continue
			t.append(ts); o.append(op_v); h.append(h_v); l.append(l_v); c.append(c_v); v.append(v_v)
		return {"s": "ok", "t": t, "o": o, "h": h, "l": l, "c": c, "v": v, "source": "alpha_vantage"}

	url = f"https://{RAPIDAPI_HOST}/query"
	params = {"function": "TIME_SERIES_DAILY", "symbol": symbol, "outputsize": outputsize}
	for api_key in RAPIDAPI_KEYS:
		series = None
		try:
			headers = {"X-RapidAPI-Key": api_key, "X-RapidAPI-Host": RAPIDAPI_HOST}
			r = http.get(url, headers=headers, params=params, timeout=30)
			r.raise_for_status()
			data = r.json()
			# "Note"/"Information" signal a rate limit, "Error Message" a bad
			# request -- either way this key is unusable right now.
			if not (isinstance(data, dict) and any(k in data for k in ("Note", "Error Message", "Information"))):
				series = data.get("Time Series (Daily)") or None
		except Exception:
			series = None
		if series:
			return _to_candles(series)
		# Pace key rotation; the success path returns above without sleeping
		# (the original slept in a finally clause, delaying even successes).
		_time.sleep(0.5)
	# If all keys failed or no data, return mock to keep UI responsive
	return _create_mock_candles(symbol, count)


def fetch_alpha_vantage_overview(symbol: str) -> Dict:
	"""Fetch company fundamentals (Alpha Vantage OVERVIEW) via RapidAPI.

	Cycles through the configured keys, skipping rate-limit/error payloads.

	Raises:
		ValueError: when no RapidAPI key is configured.
		RuntimeError: when every key fails.
	"""
	if not RAPIDAPI_KEYS:
		raise ValueError("Missing RAPIDAPI_KEYS/RAPIDAPI_KEY")
	for api_key in RAPIDAPI_KEYS:
		try:
			url = f"https://{RAPIDAPI_HOST}/query"
			headers = {"x-rapidapi-key": api_key, "x-rapidapi-host": RAPIDAPI_HOST}
			params = {"function": "OVERVIEW", "symbol": symbol}
			r = http.get(url, headers=headers, params=params, timeout=30)
			r.raise_for_status()
			data = r.json()
			# Rate-limit/error payloads are not fundamentals -- try the next key
			# instead of returning them (mirrors fetch_alpha_vantage_series_daily;
			# the original returned any non-empty dict, including error payloads).
			if isinstance(data, dict) and any(k in data for k in ("Note", "Error Message", "Information")):
				continue
			if data:
				return data
		except Exception:
			continue
	raise RuntimeError("Alpha Vantage OVERVIEW failed")


# -------- Prompts --------

def build_price_prediction_prompt(symbol: str, candles: Dict) -> str:
	"""Render the analyst prompt for a short-term prediction on *symbol*.

	The candle payload is JSON-serialized and truncated to 10k characters
	so the prompt stays within the model's context window.
	"""
	payload = json.dumps(candles)[:10000]
	provider = candles.get("source", "finnhub")
	parts = [
		f"You are a financial analyst agent. Analyze recent OHLCV candles for {symbol} (source: {provider}) and provide a short-term price prediction. ",
		f"Explain key drivers in bullet points and give a 1-2 sentence forecast.\n\nData JSON: {payload}\n\n",
	]
	return "".join(parts)


def build_equity_research_prompt(symbol: str, overview: Dict) -> str:
	"""Render the equity-research-note prompt for *symbol*.

	Fundamentals are JSON-serialized and truncated to 10k characters to
	fit the model context.
	"""
	payload = json.dumps(overview)[:10000]
	header = (
		"You are an equity research analyst. Using the fundamentals overview, write a concise equity research note including: "
		"Business summary, recent performance, profitability, leverage, valuation multiples, key risks, and an investment view (Buy/Hold/Sell) with rationale.\n\n"
	)
	return header + f"Ticker: {symbol}\nFundamentals JSON: {payload}\n"


# -------- Gradio UI --------

def ui_app():
	"""Build the two-tab Gradio app (Price Prediction + Equity Research).

	Both tabs stream model output into a textbox via generator callbacks.
	Returns the queued gr.Blocks instance; the caller launches it.
	"""
	with gr.Blocks(title="Fin-o1-8B Tools") as demo:
		gr.Markdown("""# Fin-o1-8B Tools
Two tabs: Price Prediction (Finnhub with Alpha Vantage fallback) and Equity Research (Alpha Vantage via RapidAPI).""")

		with gr.Tab("Price Prediction"):
			symbol = gr.Textbox(label="Ticker (e.g., AAPL)", value="AAPL")
			resolution = gr.Dropdown(["D", "60", "30", "15", "5"], value="D", label="Resolution")
			count = gr.Slider(20, 160, value=60, step=5, label="Num candles")
			temp = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
			max_new = gr.Slider(64, 768, value=384, step=16, label="Max new tokens")
			btn = gr.Button("Predict")
			out = gr.Textbox(lines=30, show_copy_button=True)

			# Generator handler: yields progressively longer text so the
			# textbox updates as the model streams.
			def on_predict(sym, res, cnt, temperature, max_tokens):
				try:
					candles = fetch_finnhub_candles(sym, res, int(cnt))
				except Exception:
					try:
						# NOTE(review): the fallback ignores the user-selected
						# candle count and uses the default -- confirm intended.
						candles = fetch_alpha_vantage_series_daily(sym, outputsize="compact")
					except Exception as e2:
						yield f"Error fetching candles: {e2}"
						return
				prompt = build_price_prediction_prompt(sym, candles)
				for text in generate_response_stream(prompt, temperature=temperature, max_new_tokens=int(max_tokens)):
					yield text

			btn.click(on_predict, inputs=[symbol, resolution, count, temp, max_new], outputs=out, show_progress=True)

		with gr.Tab("Equity Research Report"):
			symbol2 = gr.Textbox(label="Ticker (e.g., MSFT)", value="MSFT")
			temp2 = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
			max_new2 = gr.Slider(64, 768, value=384, step=16, label="Max new tokens")
			btn2 = gr.Button("Generate Report")
			out2 = gr.Textbox(lines=30, show_copy_button=True)

			# Same streaming pattern as on_predict, driven by fundamentals data.
			def on_report(sym, temperature, max_tokens):
				try:
					overview = fetch_alpha_vantage_overview(sym)
				except Exception as e:
					yield f"Error fetching fundamentals: {e}"
					return
				prompt = build_equity_research_prompt(sym, overview)
				for text in generate_response_stream(prompt, temperature=temperature, max_new_tokens=int(max_tokens)):
					yield text

			btn2.click(on_report, inputs=[symbol2, temp2, max_new2], outputs=out2, show_progress=True)

		# Enable queue with default settings for current Gradio version
		demo.queue()
	return demo


if __name__ == "__main__":
	# Bind address/port are overridable via env for container deployments.
	app = ui_app()
	app.launch(server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"), server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")))