tch-ai / app.py
Kasher13's picture
fix(ai): add 26B as Gemini fallback before llama-cpp
51503b5 verified
"""
TwoCentsHustler AI Space.
Primary backend: Google Gemini API (Gemma 4 31B via AI Studio).
- Gemma 4 31B is a thinking model — the response has a thought part + final answer part.
- 500 errors are transient (Google-side); retried with exponential backoff (3x).
- Falls through to the Gemini fallback model (Gemma 4 26B), then to llama-cpp, on persistent failure.
Fallback backend: local llama-cpp (gemma-4-E4B GGUF) — slow (20-40s).
POST /api/ai { "operation": "analyze"|"summarize"|"cluster", "payload": {...} }
"""
import os
import json
import re
import time
import gradio as gr
from fastapi import Request
from fastapi.responses import JSONResponse
# ── Gemini setup (optional — only when GEMINI_API_KEY is set) ─────────────────
# Key and model names come from the environment; the model names have defaults.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemma-4-31b-it")
GEMINI_FALLBACK_MODEL = os.getenv("GEMINI_FALLBACK_MODEL", "gemma-4-26b-a4b-it")
GEMINI_RETRIES = 3  # attempts made per model by _call_gemini_model

# Both handles stay None unless the SDK imports and configures without error.
_gemini_primary = _gemini_fallback = None

if GEMINI_API_KEY:
    try:
        import google.generativeai as genai

        genai.configure(api_key=GEMINI_API_KEY)
        _gemini_primary = genai.GenerativeModel(GEMINI_MODEL)
        _gemini_fallback = genai.GenerativeModel(GEMINI_FALLBACK_MODEL)
        print(f"Gemini backend ready: primary={GEMINI_MODEL}, fallback={GEMINI_FALLBACK_MODEL}")
    except Exception as init_error:
        # Best effort: any init problem just disables the Gemini chain.
        print(f"Gemini init failed ({init_error}), will use llama-cpp")
        _gemini_primary = None
        _gemini_fallback = None
# ── llama-cpp setup (always loaded as fallback) ───────────────────────────────
from huggingface_hub import hf_hub_download
from llama_cpp import Llama

# Model location and runtime knobs, each overridable through the environment.
REPO_ID = os.getenv("GGUF_REPO", "unsloth/gemma-4-E4B-it-GGUF")
GGUF_FILE = os.getenv("GGUF_FILE", "gemma-4-E4B-it-Q4_K_M.gguf")
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HF_ACCESS_TOKEN")
N_CTX = int(os.getenv("N_CTX", "4096"))
N_THREADS = int(os.getenv("N_THREADS", "2"))

# Fetch the GGUF file at import time; hf_hub_download returns the local path.
print(f"Downloading {REPO_ID}/{GGUF_FILE} …")
model_path = hf_hub_download(repo_id=REPO_ID, filename=GGUF_FILE, token=HF_TOKEN)

print(f"Loading llama-cpp model from {model_path} …")
_llm = Llama(
    model_path=model_path,
    n_ctx=N_CTX,
    n_threads=N_THREADS,
    n_gpu_layers=0,  # no layers offloaded to a GPU
    verbose=False,
)
print("llama-cpp model ready.")
# ── Inference ─────────────────────────────────────────────────────────────────
def _call_gemini_model(model, model_name: str, prompt: str) -> str:
    """Generate text with one Gemini model, retrying 5xx failures with backoff.

    The parts of the first candidate are read one by one and any part flagged
    ``thought=True`` is dropped, so the returned text is the final answer and
    not the reasoning chain. If that leaves nothing, ``response.text`` is
    returned instead.

    Up to ``GEMINI_RETRIES`` attempts are made. An error is treated as 5xx when
    it is an ``InternalServerError`` / ``ServiceUnavailable`` or its message
    contains "500" or "503"; such errors are retried after 1s, 2s, ... Any
    other error, or a 5xx on the last attempt, is re-raised to the caller.
    """
    from google.generativeai.types import GenerationConfig
    import google.api_core.exceptions as gapi_exc

    for attempt in range(GEMINI_RETRIES):
        try:
            response = model.generate_content(
                prompt,
                generation_config=GenerationConfig(temperature=0.0),
            )
            answer_parts = response.candidates[0].content.parts
            answer_text = "".join(
                part.text
                for part in answer_parts
                if not getattr(part, "thought", False)
            )
            return answer_text or response.text
        except Exception as err:
            is_5xx = isinstance(
                err, (gapi_exc.InternalServerError, gapi_exc.ServiceUnavailable)
            ) or any(code in str(err) for code in ("500", "503"))

            # Non-transient errors and the final attempt go straight to the caller.
            if not is_5xx or attempt >= GEMINI_RETRIES - 1:
                raise

            wait = 2 ** attempt
            print(f"{model_name} 5xx (attempt {attempt+1}/{GEMINI_RETRIES}), retry in {wait}s: {err}")
            time.sleep(wait)
def _generate_gemini(prompt: str) -> str:
    """Run the prompt through the Gemini chain: primary model, then fallback.

    If the primary call raises for any reason, the failure is logged and the
    same prompt is sent to the fallback model. Errors from the fallback model
    propagate to the caller.
    """
    try:
        return _call_gemini_model(_gemini_primary, GEMINI_MODEL, prompt)
    except Exception as primary_error:
        notice = f"{GEMINI_MODEL} exhausted retries ({primary_error}), trying {GEMINI_FALLBACK_MODEL}"
        print(notice)
        return _call_gemini_model(_gemini_fallback, GEMINI_FALLBACK_MODEL, prompt)
def _generate_llama(prompt: str) -> str:
    """Answer the prompt with the local llama-cpp model and return its text.

    The prompt is sent as a single user message at temperature 0, capped at
    1024 tokens, with a JSON-object response format requested.
    """
    chat = [{"role": "user", "content": prompt}]
    # NOTE(review): "json_object" may constrain output to a top-level object,
    # while the cluster prompt asks for a JSON array — confirm this is intended.
    completion = _llm.create_chat_completion(
        messages=chat,
        max_tokens=1024,
        temperature=0.0,
        response_format={"type": "json_object"},
    )
    first_choice = completion["choices"][0]
    return first_choice["message"]["content"]
def _generate(prompt: str) -> str:
    """Produce model output for the prompt using the best available backend.

    When a Gemini primary model is configured, the Gemini chain (31B → 26B) is
    tried first; if it raises, the error is logged and llama-cpp answers
    instead. Without Gemini, llama-cpp is used directly.
    """
    if _gemini_primary is None:
        return _generate_llama(prompt)

    try:
        return _generate_gemini(prompt)
    except Exception as chain_error:
        print(f"Gemini chain exhausted ({chain_error}), falling back to llama-cpp")
    return _generate_llama(prompt)
# ── Prompt builders (mirrors lib/ai/prompts.ts) ───────────────────────────────
_ANALYSIS_SCHEMA = """\
Respond ONLY with valid JSON:
{
"sentiment": "positive"|"negative"|"neutral"|"mixed",
"sentimentScore": integer -100..100,
"marketRelevance": integer 0..100,
"impactReasoning": string <=200 chars,
"impactOverride": "HIGH"|"MEDIUM"|"LOW",
"entities": [{"entityType":"ticker"|"company"|"person"|"place"|"commodity"|"currency"|"central_bank","value":string,"normalized":string|null,"confidence":integer 0..100}]
}
Impact classification criteria:
HIGH β€” market-moving: central bank decisions, major earnings beats/misses, geopolitical crisis,
large corporate bankruptcies/acquisitions, systemic financial events
MEDIUM β€” earnings reports, regulatory updates, M&A rumours, sector policy changes, key executive moves
LOW β€” routine company updates, minor personnel changes, informational/educational content,
press releases with no immediate market consequence"""
def _build_analysis_prompt(p: dict) -> str:
lines = [
"You are a financial news analyst. Analyze the article and output structured JSON.",
"",
f"ARTICLE CATEGORY: {p.get('category', 'unknown')}",
f"HEADLINE: {p.get('headline', '')}",
]
if p.get("summary"):
lines.append(f"SUMMARY: {p['summary']}")
lines += [
"",
"Extract: market sentiment, market relevance (0-100), impact level, impact reasoning,",
"and all named entities. Prefer normalized ticker symbols (e.g. 'AAPL') in normalized field.",
"impactOverride is REQUIRED β€” always classify as HIGH, MEDIUM, or LOW.",
"",
_ANALYSIS_SCHEMA,
]
return "\n".join(lines)
def _build_summary_prompt(p: dict) -> str:
items = p.get("items", [])
max_bullets = p.get("maxBullets", 6)
scope = p.get("scope", "daily")
article_lines = "\n".join(
f"{i+1}. [{it.get('category','?')}|{it.get('impact','?')}|{it.get('publishedAt','')}] "
f"{it.get('headline','')}"
+ (f" β€” {it.get('summary','')[:200]}" if it.get("summary") else "")
for i, it in enumerate(items[:60])
)
return "\n".join([
f"You are writing a {scope} market brief for active traders.",
f"Synthesize the following {len(items)} articles into a concise brief.",
"",
article_lines,
"",
f'Output JSON: {{"content": string (markdown <=400 words), "highlights": string[] (<={max_bullets} bullets each <=120 chars)}}',
])
def _fmt_entity(e: dict) -> str:
return f"{e.get('entityType','?')}:{e.get('normalized') or e.get('value','?')}"
def _build_cluster_prompt(p: dict) -> str:
items = p.get("items", [])
article_lines = "\n".join(
f"{i+1}. [id:{it.get('id','?')}|{it.get('category','?')}] {it.get('headline','')} "
f"(entities: {', '.join(_fmt_entity(e) for e in it.get('entities', [])) or 'none'})"
for i, it in enumerate(items[:40])
)
return "\n".join([
"Cluster these financial news articles into market events.",
"Group into 0..N events where each is a coherent story thread.",
"Skip articles that don't belong to any multi-article event.",
"",
article_lines,
"",
'Output JSON: [{"title":string<=80,"description":string|null,"category":"MACRO"|"STOCKS"|"CRYPTO"|"FOREX"|"COMMODITIES","itemIds":string[]>=2,"keyEntities":string[],"relevanceScores":{itemId:0..100}}]',
])
# ── JSON extractor ────────────────────────────────────────────────────────────
def _extract_json(text: str):
text = text.strip()
try:
return json.loads(text)
except json.JSONDecodeError:
pass
text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.IGNORECASE)
text = re.sub(r"\s*```$", "", text)
try:
return json.loads(text)
except json.JSONDecodeError:
pass
candidates = [(text.find("{"), "}"), (text.find("["), "]")]
candidates = [(i, c) for i, c in candidates if i != -1]
if candidates:
first = min(candidates, key=lambda x: x[0])[0]
last = max(text.rfind("}"), text.rfind("]"))
if last > first:
return json.loads(text[first : last + 1])
raise ValueError(f"No JSON found: {text[:200]}")
# ── Dispatcher ────────────────────────────────────────────────────────────────
def _dispatch(operation: str, payload: dict):
if operation == "analyze":
prompt = _build_analysis_prompt(payload)
elif operation == "summarize":
prompt = _build_summary_prompt(payload)
elif operation == "cluster":
prompt = _build_cluster_prompt(payload)
else:
raise ValueError(f"Unknown operation: {operation!r}")
return _extract_json(_generate(prompt))
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Label for the backend that is tried first: the Gemini chain when a primary
# model was configured, otherwise the local GGUF model.
_backend_label = (
    f"Gemini ({GEMINI_MODEL} → {GEMINI_FALLBACK_MODEL})"
    if _gemini_primary
    else f"llama-cpp ({GGUF_FILE})"
)

with gr.Blocks(title="TwoCentsHustler AI") as demo:
    gr.Markdown(
        f"## TwoCentsHustler AI\n"
        f"Primary: `{_backend_label}` · Fallback: `{GGUF_FILE}` (llama-cpp)"
    )
    with gr.Row():
        op = gr.Dropdown(["analyze", "summarize", "cluster"], value="analyze", label="Operation")
        payload_box = gr.Code(
            value='{"headline":"Fed raises rates by 25bps","category":"MACRO"}',
            language="json",
            label="Payload",
        )
    out = gr.JSON(label="Result")
    btn = gr.Button("Run")

    def _gradio_run(operation: str, payload_str: str):
        """Run one operation from the UI.

        Parses the payload box as JSON (an empty box counts as ``{}``) and
        returns the dispatch result. Any error is returned as
        ``{"error": message}`` so it is shown in the result panel.
        """
        try:
            return _dispatch(operation, json.loads(payload_str or "{}"))
        except Exception as e:
            return {"error": str(e)}

    btn.click(_gradio_run, inputs=[op, payload_box], outputs=out)
# ── REST route ────────────────────────────────────────────────────────────────
# Register the REST route on the app object exposed by the Gradio Blocks instance.
app = demo.app


@app.post("/api/ai")
async def ai_endpoint(request: Request):
    """Handle POST /api/ai: run one AI operation and return its JSON result.

    Expects a JSON object body with ``operation`` and an optional ``payload``
    (default ``{}``). Responds with the dispatch result on success,
    ``{"error": ...}`` with status 400 on ValueError (a non-object body, an
    unknown operation, or model output with no JSON), and ``{"error": ...}``
    with status 500 on any other exception.
    """
    try:
        body = await request.json()
        # A JSON array or scalar body has no .get(); report it as a client error.
        if not isinstance(body, dict):
            raise ValueError("Request body must be a JSON object")
        # NOTE(review): _dispatch is synchronous (it may sleep between retries and
        # run local inference), so it blocks the event loop while it runs.
        result = _dispatch(body.get("operation", ""), body.get("payload", {}))
        return JSONResponse(content=result)
    except ValueError as exc:
        return JSONResponse(content={"error": str(exc)}, status_code=400)
    except Exception as exc:
        return JSONResponse(content={"error": str(exc)}, status_code=500)
if __name__ == "__main__":
    # When run as a script, serve on all interfaces at port 7860.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )