import os
import json
from typing import Iterator, List, Tuple, Any, Optional

# Optional dependencies are imported defensively so the module stays
# importable (and the token helpers keep working) even when a package is
# missing; the failure surfaces at call time with a clear message instead.
_HF_ERR = None
try:
    from huggingface_hub import InferenceClient
except Exception as e:  # noqa: BLE001
    _HF_ERR = e
    InferenceClient = None  # type: ignore

_REQUESTS_ERR = None
try:
    import requests
except Exception as e:  # noqa: BLE001
    _REQUESTS_ERR = e
    requests = None  # type: ignore


def get_hf_token() -> Optional[str]:
    """Prefer HF_TOKEN; accept HUGGINGFACEHUB_API_TOKEN as fallback."""
    return os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")


def is_hf_api_available() -> bool:
    """Return True when an HF API token is configured in the environment."""
    return bool(get_hf_token())


def _suggest_repo(bad_repo: str) -> str:
    """Suggest a known-good repo id to show when *bad_repo* does not exist."""
    # why: common Nemotron typo rescue
    if "nemotron" in bad_repo.lower():
        return "NVIDIA/Nemotron-3-8B-Instruct"
    return "mistralai/Mistral-7B-Instruct-v0.2"


class HFInferenceBackend:
    """
    Robust HF Serverless client:
    - Preflight: verify repo exists (fast) to avoid long blocking errors.
    - Try text_generation streaming via huggingface_hub.
    - If provider says 'conversational' only, call HTTP conversational and
      chunk output.
    """

    def __init__(self, model_name: str):
        """Bind a model repo id and token.

        Raises:
            RuntimeError: if no HF token is configured in the environment.
        """
        token = get_hf_token()
        if not token:
            raise RuntimeError("HF_TOKEN not set")
        self.model = model_name.strip()
        self.token = token
        # client stays None when huggingface_hub failed to import; the
        # text-generation path checks for this before use.
        self.client = InferenceClient(model=self.model, token=token) if InferenceClient else None

    # ---------- Preflight ----------
    def _preflight(self) -> Tuple[bool, Optional[str]]:
        """Check that the model repository exists on the Hub.

        Returns:
            (exists, pipeline_tag_or_None). Any failure other than an
            explicit 404 is treated as "exists" so a flaky hub API never
            blocks the chat; the real call will surface errors later.
        """
        if requests is None:
            # why: no HTTP client available — skip the check, don't block.
            return True, None
        url = f"https://huggingface.co/api/models/{self.model}"
        headers = {"Authorization": f"Bearer {self.token}"}
        try:
            r = requests.get(url, headers=headers, timeout=8)
            if r.status_code == 404:
                return False, None
            if r.ok:
                # 'pipeline_tag' when known; otherwise None
                return True, r.json().get("pipeline_tag")
            return True, None
        except Exception:
            # If API unreachable, don't block the chat; proceed and catch later.
            return True, None

    # ---------- Prompt Builders ----------
    def _build_tg_prompt(self, system_prompt: str, history: List[Tuple[str, str]], user_msg: str) -> str:
        """Flatten system prompt + (user, assistant) history into one
        bracket-tagged prompt ending with an open [ASSISTANT] turn."""
        parts = [f"[SYSTEM]\n{system_prompt}\n[/SYSTEM]\n"]
        for u, a in history:
            if u:
                parts.append(f"[USER]\n{u}\n[/USER]\n")
            if a:
                parts.append(f"[ASSISTANT]\n{a}\n[/ASSISTANT]\n")
        parts.append(f"[USER]\n{user_msg}\n[/USER]\n[ASSISTANT]\n")
        return "".join(parts)

    def _build_conv_inputs(self, system_prompt: str, history: List[Tuple[str, str]], user_msg: str) -> dict:
        """Build the 'inputs' payload for the HF conversational task.

        The system prompt has no dedicated slot in this schema, so it is
        prepended to the current user message.
        """
        past_user_inputs: List[str] = []
        generated_responses: List[str] = []
        for u, a in history:
            past_user_inputs.append(u or "")
            generated_responses.append(a or "")
        current = f"{system_prompt}\n\n{user_msg}".strip()
        return {
            "past_user_inputs": past_user_inputs,
            "generated_responses": generated_responses,
            "text": current,
        }

    # ---------- Event helper ----------
    def _extract_text_from_event(self, event: Any) -> str:
        """Pull the text delta out of a streaming event of any shape
        (plain str, object with .token.text, or .generated_text)."""
        if isinstance(event, str):
            return event
        token = getattr(event, "token", None)
        if token is not None:
            return getattr(token, "text", "") or ""
        return getattr(event, "generated_text", "") or ""

    # ---------- Streaming text-generation ----------
    def _stream_text_generation(
        self, system_prompt: str, history: List[Tuple[str, str]], user_msg: str,
        temperature: float, max_new_tokens: int
    ) -> Iterator[str]:
        """Stream text_generation events, yielding the cumulative text.

        Raises:
            RuntimeError: if huggingface_hub is not installed.
        """
        if not self.client:
            raise RuntimeError("huggingface_hub not installed")
        prompt = self._build_tg_prompt(system_prompt, history, user_msg)
        sample = temperature > 0
        stream = self.client.text_generation(
            prompt,
            max_new_tokens=int(max_new_tokens),
            # why: the endpoint rejects temperature <= 0; omit it when greedy.
            temperature=float(temperature) if sample else None,
            top_p=0.95,
            repetition_penalty=1.05,
            do_sample=sample,
            stream=True,
            return_full_text=False,
        )
        buf: List[str] = []
        for event in stream:
            delta = self._extract_text_from_event(event)
            if delta:
                buf.append(delta)
                yield "".join(buf)

    # ---------- Conversational via raw HTTP (non-stream; chunked) ----------
    def _call_conversational_http(
        self, system_prompt: str, history: List[Tuple[str, str]], user_msg: str,
        temperature: float, max_new_tokens: int
    ) -> Iterator[str]:
        """Call the conversational task over plain HTTP (non-streaming) and
        emulate streaming by yielding the reply in growing 48-char chunks."""
        if requests is None:
            yield f"[error] requests not installed: {_REQUESTS_ERR}"
            return
        url = f"https://api-inference.huggingface.co/models/{self.model}"
        headers = {
            "Authorization": f"Bearer {self.token}",
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        payload = {
            "inputs": self._build_conv_inputs(system_prompt, history, user_msg),
            "parameters": {"temperature": float(temperature), "max_new_tokens": int(max_new_tokens)},
        }
        try:
            resp = requests.post(url, headers=headers, data=json.dumps(payload), timeout=40)
        except Exception as e:
            yield f"[error] network: {type(e).__name__}: {e}"
            return
        if resp.status_code == 503:
            # why: 503 means the model is cold-starting, not a hard failure.
            yield "[info] Model is loading on the provider. Please try again shortly."
            return
        try:
            resp.raise_for_status()
        except Exception:
            yield f"[error] provider: HTTP {resp.status_code}: {resp.text[:500]}"
            return
        data = resp.json()
        # Providers answer with several shapes; try each known one in turn.
        text = ""
        if isinstance(data, dict):
            text = data.get("generated_text") or ""
            if not text:
                conv = data.get("conversation") or {}
                gen = conv.get("generated_responses") or []
                if gen:
                    text = gen[-1] or ""
        elif isinstance(data, list) and data:
            item = data[-1]
            if isinstance(item, dict):
                text = item.get("generated_text") or ""
        if not text:
            # why: surface the raw payload rather than an empty reply.
            text = json.dumps(data)
        # Fix: yield inside the loop so callers see the text grow; the
        # original accumulated chunks but yielded only once at the end,
        # which defeated the chunked "pseudo-streaming".
        buf: List[str] = []
        for i in range(0, len(text), 48):
            buf.append(text[i : i + 48])
            yield "".join(buf)

    # ---------- Public ----------
    def generate_stream(
        self,
        system_prompt: str,
        history: List[Tuple[str, str]],
        user_msg: str,
        temperature: float,
        max_new_tokens: int,
    ) -> Iterator[str]:
        """Yield cumulative assistant text, picking the best available path.

        Order: preflight the repo, prefer streaming text_generation, fall
        back to the conversational HTTP task when the provider demands it.
        """
        exists, pipeline_tag = self._preflight()
        if not exists:
            suggestion = _suggest_repo(self.model)
            yield f"[error] Model repository not found: {self.model}. Try: `{suggestion}`"
            return
        try:
            # If API says conversational, skip straight to conversational fallback.
            if (pipeline_tag or "").lower() == "conversational":
                yield from self._call_conversational_http(system_prompt, history, user_msg, temperature, max_new_tokens)
                return
            yield from self._stream_text_generation(system_prompt, history, user_msg, temperature, max_new_tokens)
        except Exception as e:
            # why: errors like "Supported task: conversational" mean the repo
            # only serves the conversational task — retry on that path.
            if "conversational" in str(e).lower():
                yield from self._call_conversational_http(system_prompt, history, user_msg, temperature, max_new_tokens)
            else:
                yield f"[error] text_generation: {type(e).__name__}: {e}"