# Author: soupstick
# Commit message: Update api/inference.py
# Commit: 2df7beb (verified)
import os, json, requests
from typing import Any, Dict, Optional, List
from openai import OpenAI
from openai.types.chat import ChatCompletion
from openai._exceptions import NotFoundError, BadRequestError, AuthenticationError
# ---- Fireworks native (defaults match your .env) ----
# Each setting is read from the environment once, at import time; the literal
# is only the fallback used when the variable is unset.
FIREWORKS_MODEL = os.getenv("FIREWORKS_MODEL_ID", "accounts/fireworks/models/gpt-oss-20b")
FIREWORKS_URL_CHAT = os.getenv("FIREWORKS_URL_CHAT", "https://api.fireworks.ai/inference/v1/chat/completions")
FIREWORKS_URL_COMP = os.getenv("FIREWORKS_URL_COMP", "https://api.fireworks.ai/inference/v1/completions")

# ---- Hugging Face router ----
# The "hf" / "hf_router" branch of _call_llm reads these two names from module
# scope; without a definition that branch fails with NameError.
# NOTE(review): the fallback values below are assumptions (public router
# chat-completions endpoint, same model family as the Fireworks default) --
# confirm them against the deployment's .env before relying on them.
HF_ROUTER_MODEL = os.getenv("HF_ROUTER_MODEL", "openai/gpt-oss-20b")
HF_ROUTER_URL = os.getenv("HF_ROUTER_URL", "https://router.huggingface.co/v1/chat/completions")
def _norm_content(val) -> Optional[str]:
    """Normalize a provider ``content`` value to plain text.

    Args:
        val: ``None``, a string, or a list of segments. A segment is either a
            string or a dict carrying its text under ``"text"``, ``"content"``
            or ``"value"`` (checked in that order).

    Returns:
        The stripped text, with list segments joined by newlines, or ``None``
        when there is no non-blank text (including unsupported input types).
    """
    if isinstance(val, str):
        # Blank strings collapse to None so both input shapes share one
        # "nothing usable" result.
        return val.strip() or None
    if not isinstance(val, list):
        return None

    parts = []
    for seg in val:
        if isinstance(seg, dict):
            # Take the first key that actually holds non-blank text; a truthy
            # non-string (or whitespace-only) value under an earlier key must
            # not hide a usable later one.
            seg = next(
                (
                    seg[k]
                    for k in ("text", "content", "value")
                    if isinstance(seg.get(k), str) and seg[k].strip()
                ),
                None,
            )
        if isinstance(seg, str) and seg.strip():
            parts.append(seg)
    # Strip the joined text so list input is trimmed like string input.
    return "\n".join(parts).strip() or None
def _extract_text_from_choices(data: Dict[str, Any]) -> Optional[str]:
    """
    Robustly extract assistant text from Fireworks/OpenAI-style responses.

    Tries (in order):
      - choices[0].message.content (str or list of segments)
      - choices[0].text (completions-style)
      - choices[*] concatenate any available text fields (best-effort)

    Args:
        data: Decoded response payload; only its ``"choices"`` entry is read.

    Returns:
        The stripped text, or ``None`` when ``choices`` is missing, empty, not
        a list/tuple, or holds no non-blank text. Entries of ``choices`` that
        are not dicts are ignored rather than raising.
    """

    def _as_text(val: Any) -> Optional[str]:
        # Normalize content that may be a str or a list of str/dict segments.
        if isinstance(val, str):
            return val.strip()
        if isinstance(val, list):
            parts: List[str] = []
            for seg in val:
                if isinstance(seg, str):
                    parts.append(seg)
                elif isinstance(seg, dict):
                    # Common keys used by various providers
                    txt = seg.get("text") or seg.get("content") or seg.get("value")
                    if isinstance(txt, str):
                        parts.append(txt)
            return "\n".join(p for p in parts if p.strip()) or None
        # None or unknown structure
        return None

    choices = data.get("choices") if isinstance(data, dict) else None
    if not isinstance(choices, (list, tuple)) or not choices:
        return None

    first = choices[0]
    # A malformed first entry must not abort extraction; the scan below may
    # still find text in a later choice.
    if isinstance(first, dict):
        # 1) chat style
        msg = first.get("message")
        if isinstance(msg, dict):
            content = _as_text(msg.get("content"))
            if content:
                return content

        # 2) completions style
        text = first.get("text")
        if isinstance(text, str) and text.strip():
            return text.strip()

    # 3) best-effort scan over every choice
    buf: List[str] = []
    for ch in choices:
        if not isinstance(ch, dict):
            continue
        # message.content path
        if isinstance(ch.get("message"), dict):
            c = _as_text(ch["message"].get("content"))
            if c:
                buf.append(c)
        # text path
        t = ch.get("text")
        if isinstance(t, str) and t.strip():
            buf.append(t.strip())
        # delta.content (streaming shards)
        delta = ch.get("delta")
        if isinstance(delta, dict):
            dc = _as_text(delta.get("content"))
            if dc:
                buf.append(dc)

    return "\n".join(buf).strip() or None
def _fireworks_client() -> OpenAI:
    """Build an OpenAI SDK client pointed at the Fireworks inference endpoint.

    The API key is read from the ``FIREWORKS_API_KEY`` environment variable.

    Raises:
        RuntimeError: If the variable is unset or empty.
    """
    api_key = os.environ.get("FIREWORKS_API_KEY")
    if api_key:
        return OpenAI(api_key=api_key, base_url="https://api.fireworks.ai/inference/v1")
    raise RuntimeError("Set FIREWORKS_API_KEY for Fireworks provider")
def _post_fireworks_chat(model: str, prompt: str, max_tokens: int, temperature: float, key: str) -> requests.Response:
    """POST a single-turn, non-streaming chat request to ``FIREWORKS_URL_CHAT``.

    The prompt is sent as one ``user`` message and ``key`` as a bearer token.
    ``max_tokens`` and ``temperature`` are coerced with ``int()`` / ``float()``.
    The raw ``requests.Response`` is returned; the status code is not checked
    here. The request uses a 60-second timeout.
    """
    auth_headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    payload = dict(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=int(max_tokens),
        temperature=float(temperature),
        stream=False,
    )
    return requests.post(FIREWORKS_URL_CHAT, headers=auth_headers, json=payload, timeout=60)
def _post_fireworks_completions(model: str, prompt: str, max_tokens: int, temperature: float, key: str) -> requests.Response:
    """POST a non-streaming text-completion request to ``FIREWORKS_URL_COMP``.

    The prompt is sent verbatim under ``"prompt"`` and ``key`` as a bearer
    token. ``max_tokens`` and ``temperature`` are coerced with ``int()`` /
    ``float()``. The raw ``requests.Response`` is returned; the status code is
    not checked here. The request uses a 60-second timeout.
    """
    auth_headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    payload = dict(
        model=model,
        prompt=prompt,
        max_tokens=int(max_tokens),
        temperature=float(temperature),
        stream=False,
    )
    return requests.post(FIREWORKS_URL_COMP, headers=auth_headers, json=payload, timeout=60)
def _call_llm(
    prompt: str,
    max_tokens: int = 512,
    temperature: float = 0.2,
    model: Optional[str] = None,
    provider: Optional[str] = None,
) -> str:
    """Send a single-turn prompt to the selected LLM provider and return its text.

    Args:
        prompt: User prompt, sent as the only message.
        max_tokens: Generation limit; coerced with ``int()``.
        temperature: Sampling temperature; coerced with ``float()``.
        model: Model id. Falls back to ``FIREWORKS_MODEL`` (Fireworks) or
            ``HF_ROUTER_MODEL`` (HF router) when not given.
        provider: ``"fireworks"``, ``"hf_router"`` or ``"hf"`` (case-insensitive).
            Falls back to the ``LLM_PROVIDER`` environment variable, then to
            ``"fireworks"``.

    Returns:
        The generated text, stripped of surrounding whitespace.

    Raises:
        RuntimeError: Unsupported provider, missing credentials, a provider
            error, or a response that carries no text. On the HF router path,
            exceptions raised by ``requests`` itself propagate unchanged.
    """
    effective_provider = (provider or os.getenv("LLM_PROVIDER") or "fireworks").lower()
    if effective_provider == "fireworks":
        return _fireworks_generate(prompt, max_tokens, temperature, model or FIREWORKS_MODEL)
    if effective_provider in ("hf_router", "hf"):
        return _hf_router_generate(prompt, max_tokens, temperature, model)
    raise RuntimeError(f"Unsupported LLM_PROVIDER={effective_provider}")


def _fireworks_generate(prompt: str, max_tokens: int, temperature: float, mdl: str) -> str:
    """Query Fireworks via the chat endpoint, falling back to completions."""
    client = _fireworks_client()

    # 1) try chat
    chat_error: Optional[Exception] = None
    try:
        resp: ChatCompletion = client.chat.completions.create(
            model=mdl,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=int(max_tokens),
            temperature=float(temperature),
        )
        chat_choices = getattr(resp, "choices", None) or []
        if chat_choices:
            message = getattr(chat_choices[0], "message", None)
            txt = _norm_content(getattr(message, "content", None))
            if txt:
                return txt
    except AuthenticationError as e:
        # A rejected key will not work on the fallback endpoint either.
        raise RuntimeError(f"Fireworks auth error: {e}") from e
    except Exception as e:
        # Best effort: any other chat failure (bad request, unknown model,
        # transient hiccup) falls through to completions. The error is kept so
        # it can be reported if the fallback fails too.
        chat_error = e

    chat_note = f"; chat attempt failed first: {chat_error}" if chat_error is not None else ""

    # 2) completions fallback
    try:
        comp = client.completions.create(
            model=mdl,
            prompt=prompt,
            max_tokens=int(max_tokens),
            temperature=float(temperature),
        )
        # OpenAI SDK for completions returns .choices[0].text
        comp_choices = getattr(comp, "choices", None) or []
        text = getattr(comp_choices[0], "text", None) if comp_choices else None
        if isinstance(text, str) and text.strip():
            return text.strip()
        # last resort: jsonify & parse
        data = comp.model_dump()
        txt = _extract_text_from_choices(data)
    except AuthenticationError as e:
        raise RuntimeError(f"Fireworks auth error: {e}") from e
    except Exception as e:
        raise RuntimeError(f"Fireworks completions error: {e} (model='{mdl}'){chat_note}") from e

    if txt:
        return txt
    # Raised outside the try so it is not re-wrapped as a "completions error".
    raise RuntimeError(
        f"Fireworks completions returned no text: {json.dumps(data, default=str)[:800]}{chat_note}"
    )


def _hf_router_generate(prompt: str, max_tokens: int, temperature: float, model: Optional[str]) -> str:
    """Query the Hugging Face router chat endpoint (``HF_ROUTER_URL``)."""
    token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
    if not token:
        raise RuntimeError("Set HF_TOKEN (or HUGGINGFACEHUB_API_TOKEN) for hf_router provider")

    headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
    body = {
        "model": model or HF_ROUTER_MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": int(max_tokens),
        "temperature": float(temperature),
        "stream": False,
    }
    resp = requests.post(HF_ROUTER_URL, headers=headers, json=body, timeout=60)
    if resp.status_code != 200:
        raise RuntimeError(f"HF Router error {resp.status_code}: {resp.text}")

    data = resp.json()
    txt = _extract_text_from_choices(data)
    if txt:
        return txt
    raise RuntimeError(f"HF Router returned 200 but no text: {json.dumps(data)[:800]}")
# Back-compat shim for old callers
def _router_call(prompt: str, max_tokens: int = 512, temperature: float = 0.2, model: Optional[str] = None) -> str:
    """Forward to ``_call_llm`` without a ``provider`` argument.

    Provider selection is therefore left to ``_call_llm``'s own default.
    """
    forwarded = {"max_tokens": max_tokens, "temperature": temperature, "model": model}
    return _call_llm(prompt, **forwarded)