# Chatbot / backend_hf_api.py
# Author: dzezzefezfz — "Update backend_hf_api.py" (commit 467f028, verified)
import os
import json
from typing import Iterator, List, Tuple, Any, Optional
import requests
_HF_ERR = None
try:
from huggingface_hub import InferenceClient
except Exception as e: # noqa: BLE001
_HF_ERR = e
InferenceClient = None # type: ignore
def get_hf_token() -> Optional[str]:
    """Return the Hugging Face API token from the environment.

    HF_TOKEN takes precedence; HUGGINGFACEHUB_API_TOKEN is the fallback.
    """
    primary = os.getenv("HF_TOKEN")
    if primary:
        return primary
    return os.getenv("HUGGINGFACEHUB_API_TOKEN")
def is_hf_api_available() -> bool:
    """Report whether an HF API token is configured in the environment."""
    token = get_hf_token()
    return token is not None and token != ""
def _suggest_repo(bad_repo: str) -> str:
# why: common Nemotron typo rescue
if "nemotron" in bad_repo.lower():
return "NVIDIA/Nemotron-3-8B-Instruct"
return "mistralai/Mistral-7B-Instruct-v0.2"
class HFInferenceBackend:
    """
    Robust HF Serverless client.

    Strategy:
      - Preflight: verify the model repo exists (fast Hub API call) to avoid
        long blocking provider errors.
      - Try streaming ``text_generation`` via ``huggingface_hub``.
      - If the provider says the model is 'conversational' only, call the
        conversational HTTP endpoint and chunk the output.
    """

    def __init__(self, model_name: str):
        """Bind the backend to *model_name*.

        Raises:
            RuntimeError: if no HF token is configured in the environment.
        """
        token = get_hf_token()
        if not token:
            raise RuntimeError("HF_TOKEN not set")
        self.model = model_name.strip()
        self.token = token
        # why: InferenceClient may be None if huggingface_hub failed to import;
        # text-generation then fails fast and the HTTP fallback still works.
        self.client = InferenceClient(model=self.model, token=token) if InferenceClient else None

    # ---------- Preflight ----------
    def _preflight(self) -> tuple[bool, Optional[str]]:
        """Return (repo_exists, pipeline_tag_or_None) from the Hub model API."""
        url = f"https://huggingface.co/api/models/{self.model}"
        headers = {"Authorization": f"Bearer {self.token}"}
        try:
            r = requests.get(url, headers=headers, timeout=8)
            if r.status_code == 404:
                return False, None
            if r.ok:
                data = r.json()
                # 'pipeline_tag' is present when the Hub knows the task; else None.
                return True, data.get("pipeline_tag")
            return True, None
        except Exception:
            # why: if the Hub API is unreachable, don't block the chat;
            # proceed optimistically and let the inference call surface errors.
            return True, None

    # ---------- Prompt Builders ----------
    def _build_tg_prompt(self, system_prompt: str, history: List[Tuple[str, str]], user_msg: str) -> str:
        """Flatten the system prompt and chat history into one tagged prompt."""
        parts = [f"<s>[SYSTEM]\n{system_prompt}\n[/SYSTEM]\n"]
        for u, a in history:
            if u:
                parts.append(f"[USER]\n{u}\n[/USER]\n")
            if a:
                parts.append(f"[ASSISTANT]\n{a}\n[/ASSISTANT]\n")
        parts.append(f"[USER]\n{user_msg}\n[/USER]\n[ASSISTANT]\n")
        return "".join(parts)

    def _build_conv_inputs(self, system_prompt: str, history: List[Tuple[str, str]], user_msg: str) -> dict:
        """Build the legacy conversational-task payload from the chat history."""
        past_user_inputs: List[str] = []
        generated_responses: List[str] = []
        for u, a in history:
            past_user_inputs.append(u or "")
            generated_responses.append(a or "")
        # why: the conversational schema has no system field, so the system
        # prompt is prepended to the current turn's text.
        current = f"{system_prompt}\n\n{user_msg}".strip()
        return {
            "past_user_inputs": past_user_inputs,
            "generated_responses": generated_responses,
            "text": current,
        }

    # ---------- Event helper ----------
    def _extract_text_from_event(self, event: Any) -> str:
        """Extract the text delta from a streaming event of any known shape."""
        if isinstance(event, str):
            return event
        token = getattr(event, "token", None)
        if token is not None:
            return getattr(token, "text", "") or ""
        return getattr(event, "generated_text", "") or ""

    # ---------- Streaming text-generation ----------
    def _stream_text_generation(
        self, system_prompt: str, history: List[Tuple[str, str]], user_msg: str, temperature: float, max_new_tokens: int
    ) -> Iterator[str]:
        """Yield the cumulative generated text after each streamed token.

        Raises:
            RuntimeError: if huggingface_hub is not installed.
        """
        if not self.client:
            raise RuntimeError("huggingface_hub not installed")
        prompt = self._build_tg_prompt(system_prompt, history, user_msg)
        stream = self.client.text_generation(
            prompt,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=0.95,
            repetition_penalty=1.05,
            do_sample=temperature > 0,  # greedy decoding at temperature 0
            stream=True,
            return_full_text=False,
        )
        buf: List[str] = []
        for event in stream:
            delta = self._extract_text_from_event(event)
            if delta:
                buf.append(delta)
                # FIX: yield inside the loop so the UI updates incrementally.
                # Yielding only after the loop completed defeated streaming —
                # the caller saw nothing until generation finished.
                yield "".join(buf)

    # ---------- Conversational via raw HTTP (non-stream; chunked) ----------
    def _call_conversational_http(
        self, system_prompt: str, history: List[Tuple[str, str]], user_msg: str, temperature: float, max_new_tokens: int
    ) -> Iterator[str]:
        """Call the conversational task over HTTP and yield the reply in chunks.

        Errors are reported as yielded '[error] ...' / '[info] ...' strings
        instead of raised, so the chat UI can display them inline.
        """
        url = f"https://api-inference.huggingface.co/models/{self.model}"
        headers = {
            "Authorization": f"Bearer {self.token}",
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        payload = {
            "inputs": self._build_conv_inputs(system_prompt, history, user_msg),
            "parameters": {"temperature": float(temperature), "max_new_tokens": int(max_new_tokens)},
        }
        try:
            # why: json= lets requests serialize the body consistently.
            resp = requests.post(url, headers=headers, json=payload, timeout=40)
        except Exception as e:
            yield f"[error] network: {type(e).__name__}: {e}"
            return
        if resp.status_code == 503:
            # why: 503 means the model is cold and still loading on the provider.
            yield "[info] Model is loading on the provider. Please try again shortly."
            return
        try:
            resp.raise_for_status()
        except Exception:
            yield f"[error] provider: HTTP {resp.status_code}: {resp.text[:500]}"
            return
        data = resp.json()
        text = ""
        if isinstance(data, dict):
            text = data.get("generated_text") or ""
            if not text:
                conv = data.get("conversation") or {}
                gen = conv.get("generated_responses") or []
                if gen:
                    text = gen[-1] or ""
        elif isinstance(data, list) and data:
            item = data[-1]
            if isinstance(item, dict):
                text = item.get("generated_text") or ""
        if not text:
            # Fall back to raw JSON so the user sees something diagnosable.
            text = json.dumps(data)
        # FIX: yield growing prefixes in 48-char steps to simulate streaming.
        # The original built the chunks but yielded only one final string,
        # making the chunking loop dead work.
        buf: List[str] = []
        for i in range(0, len(text), 48):
            buf.append(text[i : i + 48])
            yield "".join(buf)

    # ---------- Public ----------
    def generate_stream(
        self,
        system_prompt: str,
        history: List[Tuple[str, str]],
        user_msg: str,
        temperature: float,
        max_new_tokens: int,
    ) -> Iterator[str]:
        """Yield cumulative assistant text for one chat turn.

        Chooses the transport from preflight info and degrades to the
        conversational HTTP endpoint when the provider requires it.
        """
        exists, pipeline_tag = self._preflight()
        if not exists:
            suggestion = _suggest_repo(self.model)
            yield f"[error] Model repository not found: {self.model}. Try: `{suggestion}`"
            return
        try:
            # If the Hub says conversational, skip straight to the fallback.
            if (pipeline_tag or "").lower() == "conversational":
                yield from self._call_conversational_http(system_prompt, history, user_msg, temperature, max_new_tokens)
                return
            yield from self._stream_text_generation(system_prompt, history, user_msg, temperature, max_new_tokens)
        except Exception as e:
            msg = str(e).lower()
            # why: some providers reject text-generation for chat-only models;
            # retry through the conversational endpoint instead of failing.
            if "supported task: conversational" in msg or "conversational" in msg:
                yield from self._call_conversational_http(system_prompt, history, user_msg, temperature, max_new_tokens)
            else:
                yield f"[error] text_generation: {type(e).__name__}: {e}"