# Spaces:
# Sleeping
# Sleeping
import json
import logging
import os
from typing import Any, Dict, List, Optional

import requests
from openai import OpenAI
from openai._exceptions import AuthenticationError, BadRequestError, NotFoundError
from openai.types.chat import ChatCompletion
# ---- Fireworks native (defaults match your .env) ----
FIREWORKS_MODEL = os.getenv("FIREWORKS_MODEL_ID", "accounts/fireworks/models/gpt-oss-20b")
FIREWORKS_URL_CHAT = os.getenv("FIREWORKS_URL_CHAT", "https://api.fireworks.ai/inference/v1/chat/completions")
FIREWORKS_URL_COMP = os.getenv("FIREWORKS_URL_COMP", "https://api.fireworks.ai/inference/v1/completions")

# ---- Hugging Face router ----
# The "hf_router"/"hf" branch of _call_llm reads HF_ROUTER_URL and
# HF_ROUTER_MODEL; without module-level definitions that branch fails with a
# NameError. Both can be overridden from the environment, like the Fireworks
# settings above.
# NOTE(review): the default values are assumptions (the router's
# OpenAI-compatible chat endpoint, and the Hub id of the same gpt-oss-20b model
# used as the Fireworks default) -- confirm against your .env.
HF_ROUTER_URL = os.getenv("HF_ROUTER_URL", "https://router.huggingface.co/v1/chat/completions")
HF_ROUTER_MODEL = os.getenv("HF_ROUTER_MODEL", "openai/gpt-oss-20b")
def _norm_content(val: Any) -> Optional[str]:
    """Normalize a provider ``content`` value to stripped text.

    Args:
        val: ``None``, a string, or a list of segments. A segment is either a
            plain string or a dict carrying its text under ``"text"``,
            ``"content"`` or ``"value"`` (the first truthy key wins).
            Segments of any other type are ignored.

    Returns:
        The stripped text (list segments joined with newlines, blank segments
        dropped), or ``None`` when there is no non-blank text or ``val`` has
        an unsupported type.
    """
    if val is None:
        return None
    if isinstance(val, str):
        # A blank string collapses to None so the str and list shapes report
        # "no text" the same way.
        return val.strip() or None
    if isinstance(val, list):  # sometimes providers return list segments
        parts: List[str] = []
        for seg in val:
            if isinstance(seg, str):
                parts.append(seg)
            elif isinstance(seg, dict):
                # Common keys used by various providers
                txt = seg.get("text") or seg.get("content") or seg.get("value")
                if isinstance(txt, str):
                    parts.append(txt)
        return "\n".join(p for p in parts if p.strip()).strip() or None
    # Unknown structure
    return None


def _extract_text_from_choices(data: Dict[str, Any]) -> Optional[str]:
    """Extract assistant text from a Fireworks/OpenAI-style response dict.

    Tries, in order:
      1. ``choices[0].message.content`` (str or list of segments)
      2. ``choices[0].text`` (completions-style)
      3. a best-effort scan of every choice, concatenating with newlines any
         ``message.content``, ``text`` and ``delta.content`` found

    Args:
        data: Decoded response body. Anything that is not a dict, or whose
            ``"choices"`` is missing, empty, or not a list/tuple, yields
            ``None``.

    Returns:
        The extracted, stripped text, or ``None`` if no text was found.
    """
    choices = data.get("choices") if isinstance(data, dict) else None
    if not isinstance(choices, (list, tuple)) or not choices:
        return None

    first = choices[0]
    # Guard the fast paths the same way the scan below guards each entry, so a
    # null or otherwise malformed first choice cannot raise AttributeError.
    if isinstance(first, dict):
        # 1) chat style
        msg = first.get("message")
        if isinstance(msg, dict):
            content = _norm_content(msg.get("content"))
            if content:
                return content
        # 2) completions style
        text = first.get("text")
        if isinstance(text, str) and text.strip():
            return text.strip()

    # 3) best-effort scan
    buf: List[str] = []
    for ch in choices:
        if not isinstance(ch, dict):
            continue
        # message.content path
        msg = ch.get("message")
        if isinstance(msg, dict):
            c = _norm_content(msg.get("content"))
            if c:
                buf.append(c)
        # text path
        t = ch.get("text")
        if isinstance(t, str) and t.strip():
            buf.append(t.strip())
        # delta.content (streaming shards)
        delta = ch.get("delta")
        if isinstance(delta, dict):
            dc = _norm_content(delta.get("content"))
            if dc:
                buf.append(dc)
    # Every collected piece is already stripped and non-empty.
    return "\n".join(buf) or None
def _fireworks_client() -> OpenAI:
    """Build an OpenAI-SDK client pointed at the Fireworks inference API.

    The API key is read from ``FIREWORKS_API_KEY``. The base URL is read from
    ``FIREWORKS_BASE_URL`` and defaults to
    ``https://api.fireworks.ai/inference/v1`` when that variable is unset or
    empty.

    Returns:
        A configured ``OpenAI`` client.

    Raises:
        RuntimeError: If ``FIREWORKS_API_KEY`` is unset or blank.
    """
    # Strip so a key pasted with a trailing newline/space is not sent verbatim
    # in the Authorization header, and a whitespace-only value counts as missing.
    key = (os.getenv("FIREWORKS_API_KEY") or "").strip()
    if not key:
        raise RuntimeError("Set FIREWORKS_API_KEY for Fireworks provider")
    base_url = os.getenv("FIREWORKS_BASE_URL") or "https://api.fireworks.ai/inference/v1"
    return OpenAI(api_key=key, base_url=base_url)
def _post_fireworks_chat(model: str, prompt: str, max_tokens: int, temperature: float, key: str) -> requests.Response:
    """POST one non-streaming, single-user-turn chat request to ``FIREWORKS_URL_CHAT``.

    Args:
        model: Model id placed in the request body.
        prompt: Text sent as the only ``user`` message.
        max_tokens: Coerced with ``int()`` before sending.
        temperature: Coerced with ``float()`` before sending.
        key: Bearer token used in the ``Authorization`` header.

    Returns:
        The raw ``requests.Response``; the status code is not checked here.
    """
    auth_headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    single_turn = [{"role": "user", "content": prompt}]
    payload = {
        "model": model,
        "messages": single_turn,
        "max_tokens": int(max_tokens),
        "temperature": float(temperature),
        "stream": False,
    }
    # 60-second timeout so a stalled connection cannot block forever.
    return requests.post(FIREWORKS_URL_CHAT, headers=auth_headers, json=payload, timeout=60)
def _post_fireworks_completions(model: str, prompt: str, max_tokens: int, temperature: float, key: str) -> requests.Response:
    """POST one non-streaming text-completion request to ``FIREWORKS_URL_COMP``.

    Args:
        model: Model id placed in the request body.
        prompt: Raw prompt text sent under ``"prompt"``.
        max_tokens: Coerced with ``int()`` before sending.
        temperature: Coerced with ``float()`` before sending.
        key: Bearer token used in the ``Authorization`` header.

    Returns:
        The raw ``requests.Response``; the status code is not checked here.
    """
    auth_headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "prompt": prompt,
        "max_tokens": int(max_tokens),
        "temperature": float(temperature),
        "stream": False,
    }
    # 60-second timeout so a stalled connection cannot block forever.
    return requests.post(FIREWORKS_URL_COMP, headers=auth_headers, json=payload, timeout=60)
def _call_llm(
    prompt: str,
    max_tokens: int = 512,
    temperature: float = 0.2,
    model: Optional[str] = None,
    provider: Optional[str] = None,
) -> str:
    """Send ``prompt`` as a single user turn and return the model's text.

    Args:
        prompt: The user message / completion prompt.
        max_tokens: Generation cap; coerced with ``int()``.
        temperature: Sampling temperature; coerced with ``float()``.
        model: Model id. Falls back to ``FIREWORKS_MODEL`` for the Fireworks
            provider and to ``HF_ROUTER_MODEL`` for the HF router.
        provider: ``"fireworks"``, ``"hf_router"`` or ``"hf"``
            (case-insensitive). Falls back to the ``LLM_PROVIDER`` environment
            variable, then to ``"fireworks"``.

    Returns:
        The stripped response text.

    Raises:
        RuntimeError: For an unsupported provider, missing credentials, a
            Fireworks authentication/completions failure, a non-200 HF router
            reply, or a reply that contains no text.
    """
    effective_provider = (provider or os.getenv("LLM_PROVIDER") or "fireworks").lower()
    if effective_provider == "fireworks":
        return _call_fireworks(prompt, max_tokens, temperature, model or FIREWORKS_MODEL)
    if effective_provider in ("hf_router", "hf"):
        return _call_hf_router(prompt, max_tokens, temperature, model)
    raise RuntimeError(f"Unsupported LLM_PROVIDER={effective_provider}")


def _call_fireworks(prompt: str, max_tokens: int, temperature: float, mdl: str) -> str:
    """Query Fireworks via the OpenAI SDK: chat endpoint first, completions as fallback."""
    client = _fireworks_client()
    chat_error: Optional[Exception] = None

    # 1) try chat
    try:
        resp: ChatCompletion = client.chat.completions.create(
            model=mdl,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=int(max_tokens),
            temperature=float(temperature),
        )
        txt = _norm_content(resp.choices[0].message.content)
        if txt:
            return txt
    except AuthenticationError as e:
        # A bad key will not get better on the other endpoint; fail now.
        raise RuntimeError(f"Fireworks auth error: {e}") from e
    except Exception as e:
        # Deliberate best-effort: any other chat failure (bad request, unknown
        # model, transient error, empty choices) falls through to the
        # completions endpoint, but the failure is recorded rather than lost.
        chat_error = e
        logging.getLogger(__name__).warning(
            "Fireworks chat call failed for model %r; trying completions: %s", mdl, e
        )

    # 2) completions fallback
    try:
        comp = client.completions.create(
            model=mdl,
            prompt=prompt,
            max_tokens=int(max_tokens),
            temperature=float(temperature),
        )
        # OpenAI SDK for completions returns .choices[0].text
        choices = getattr(comp, "choices", None) or []
        text = getattr(choices[0], "text", None) if choices else None
        if isinstance(text, str) and text.strip():
            return text.strip()
        # last resort: jsonify & parse
        data = comp.model_dump()
        txt = _extract_text_from_choices(data)
        if txt:
            return txt
        # default=str keeps the diagnostic dump from failing on odd values.
        detail = json.dumps(data, default=str)[:800]
    except Exception as e:
        earlier = f"; chat attempt failed first: {chat_error}" if chat_error is not None else ""
        raise RuntimeError(f"Fireworks completions error: {e} (model='{mdl}'){earlier}") from e
    # Raised outside the try so it is not re-wrapped as a "completions error".
    raise RuntimeError(f"Fireworks completions returned no text: {detail}")


def _call_hf_router(prompt: str, max_tokens: int, temperature: float, model: Optional[str]) -> str:
    """Query the HF router chat endpoint (``HF_ROUTER_URL``) with ``requests``."""
    token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
    if not token:
        raise RuntimeError("Set HF_TOKEN (or HUGGINGFACEHUB_API_TOKEN) for hf_router provider")
    headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
    body = {
        "model": model or HF_ROUTER_MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": int(max_tokens),
        "temperature": float(temperature),
        "stream": False,
    }
    resp = requests.post(HF_ROUTER_URL, headers=headers, json=body, timeout=60)
    if resp.status_code != 200:
        raise RuntimeError(f"HF Router error {resp.status_code}: {resp.text}")
    data = resp.json()
    txt = _extract_text_from_choices(data)
    if txt:
        return txt
    raise RuntimeError(f"HF Router returned 200 but no text: {json.dumps(data)[:800]}")
# Back-compat shim for old callers
def _router_call(prompt: str, max_tokens: int = 512, temperature: float = 0.2, model: Optional[str] = None) -> str:
    """Forward to ``_call_llm`` without a ``provider`` argument and return its result.

    ``max_tokens``, ``temperature`` and ``model`` are passed through unchanged
    as keyword arguments.
    """
    forwarded = {
        "max_tokens": max_tokens,
        "temperature": temperature,
        "model": model,
    }
    return _call_llm(prompt, **forwarded)