Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| from typing import Iterator, List, Tuple, Any, Optional | |
| import requests | |
| _HF_ERR = None | |
| try: | |
| from huggingface_hub import InferenceClient | |
| except Exception as e: # noqa: BLE001 | |
| _HF_ERR = e | |
| InferenceClient = None # type: ignore | |
def get_hf_token() -> Optional[str]:
    """Return the Hugging Face API token from the environment.

    Prefers ``HF_TOKEN``; accepts ``HUGGINGFACEHUB_API_TOKEN`` as a fallback.
    Returns None when neither variable yields a usable value.
    """
    primary = os.getenv("HF_TOKEN")
    if primary:
        return primary
    return os.getenv("HUGGINGFACEHUB_API_TOKEN")
def is_hf_api_available() -> bool:
    """True when a non-empty HF API token is configured in the environment."""
    if get_hf_token():
        return True
    return False
| def _suggest_repo(bad_repo: str) -> str: | |
| # why: common Nemotron typo rescue | |
| if "nemotron" in bad_repo.lower(): | |
| return "NVIDIA/Nemotron-3-8B-Instruct" | |
| return "mistralai/Mistral-7B-Instruct-v0.2" | |
class HFInferenceBackend:
    """
    Robust HF Serverless client:
    - Preflight: verify repo exists (fast) to avoid long blocking errors.
    - Try text_generation streaming via huggingface_hub.
    - If provider says 'conversational' only, call HTTP conversational and
      yield the reply as a growing prefix in fixed-size chunks so callers
      still receive incremental updates.

    All public/fallback paths report failures as yielded '[error] ...' /
    '[info] ...' strings rather than raising, so a chat UI stays responsive.
    """

    def __init__(self, model_name: str):
        """Bind the backend to *model_name* (whitespace-stripped).

        Raises:
            RuntimeError: when no HF token is present in the environment.
        """
        token = get_hf_token()
        if not token:
            raise RuntimeError("HF_TOKEN not set")
        self.model = model_name.strip()
        self.token = token
        # None when huggingface_hub failed to import; HTTP fallback still works.
        self.client = InferenceClient(model=self.model, token=token) if InferenceClient else None

    # ---------- Preflight ----------
    def _preflight(self) -> Tuple[bool, Optional[str]]:
        """Returns (exists, pipeline_tag_or_None) for the configured repo."""
        url = f"https://huggingface.co/api/models/{self.model}"
        headers = {"Authorization": f"Bearer {self.token}"}
        try:
            r = requests.get(url, headers=headers, timeout=8)
            if r.status_code == 404:
                return False, None
            if r.ok:
                data = r.json()
                # 'pipeline_tag' when known; otherwise None
                return True, data.get("pipeline_tag")
            return True, None
        except Exception:
            # If API unreachable, don't block the chat; proceed and catch later.
            return True, None

    # ---------- Prompt Builders ----------
    def _build_tg_prompt(self, system_prompt: str, history: List[Tuple[str, str]], user_msg: str) -> str:
        """Flatten the system prompt plus chat history into one tagged prompt."""
        parts = [f"<s>[SYSTEM]\n{system_prompt}\n[/SYSTEM]\n"]
        for u, a in history:
            if u:
                parts.append(f"[USER]\n{u}\n[/USER]\n")
            if a:
                parts.append(f"[ASSISTANT]\n{a}\n[/ASSISTANT]\n")
        parts.append(f"[USER]\n{user_msg}\n[/USER]\n[ASSISTANT]\n")
        return "".join(parts)

    def _build_conv_inputs(self, system_prompt: str, history: List[Tuple[str, str]], user_msg: str) -> dict:
        """Build the 'inputs' payload for the HTTP conversational task."""
        past_user_inputs: List[str] = []
        generated_responses: List[str] = []
        for u, a in history:
            past_user_inputs.append(u or "")
            generated_responses.append(a or "")
        # why: the conversational task has no system slot; prepend it to the turn.
        current = f"{system_prompt}\n\n{user_msg}".strip()
        return {
            "past_user_inputs": past_user_inputs,
            "generated_responses": generated_responses,
            "text": current,
        }

    # ---------- Event helper ----------
    def _extract_text_from_event(self, event: Any) -> str:
        """Pull the text delta out of a streaming event (plain str or object)."""
        if isinstance(event, str):
            return event
        token = getattr(event, "token", None)
        if token is not None:
            return getattr(token, "text", "") or ""
        return getattr(event, "generated_text", "") or ""

    # ---------- Streaming text-generation ----------
    def _stream_text_generation(
        self, system_prompt: str, history: List[Tuple[str, str]], user_msg: str, temperature: float, max_new_tokens: int
    ) -> Iterator[str]:
        """Yield the growing response text from a streamed text_generation call.

        Raises:
            RuntimeError: when huggingface_hub is not importable.
        """
        if not self.client:
            raise RuntimeError("huggingface_hub not installed")
        prompt = self._build_tg_prompt(system_prompt, history, user_msg)
        stream = self.client.text_generation(
            prompt,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=0.95,
            repetition_penalty=1.05,
            do_sample=temperature > 0,
            stream=True,
            return_full_text=False,
        )
        buf: List[str] = []
        for event in stream:
            delta = self._extract_text_from_event(event)
            if delta:
                buf.append(delta)
                # why: yield only when text actually grew; the original also
                # yielded on empty deltas, emitting duplicate identical strings.
                yield "".join(buf)

    # ---------- Conversational via raw HTTP (non-stream; chunked) ----------
    def _call_conversational_http(
        self, system_prompt: str, history: List[Tuple[str, str]], user_msg: str, temperature: float, max_new_tokens: int
    ) -> Iterator[str]:
        """Call the conversational endpoint (non-streaming) and yield the reply
        as a growing prefix in 48-char chunks to simulate streaming.

        Failures are yielded as '[error] ...'/'[info] ...' strings, never raised.
        """
        url = f"https://api-inference.huggingface.co/models/{self.model}"
        headers = {
            "Authorization": f"Bearer {self.token}",
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        payload = {
            "inputs": self._build_conv_inputs(system_prompt, history, user_msg),
            "parameters": {"temperature": float(temperature), "max_new_tokens": int(max_new_tokens)},
        }
        try:
            resp = requests.post(url, headers=headers, data=json.dumps(payload), timeout=40)
        except Exception as e:
            yield f"[error] network: {type(e).__name__}: {e}"
            return
        if resp.status_code == 503:
            # why: 503 means the model is cold-starting, not a hard failure.
            yield "[info] Model is loading on the provider. Please try again shortly."
            return
        try:
            resp.raise_for_status()
        except Exception:
            yield f"[error] provider: HTTP {resp.status_code}: {resp.text[:500]}"
            return
        try:
            data = resp.json()
        except ValueError:
            # FIX: resp.json() previously raised through on non-JSON bodies,
            # breaking the yield-errors-don't-raise contract of this method.
            yield f"[error] provider: HTTP {resp.status_code}: {resp.text[:500]}"
            return
        text = self._pick_reply_text(data)
        # FIX: yield inside the loop. The original yielded "".join(buf) once
        # after the loop, so the advertised chunked output never happened.
        buf: List[str] = []
        for i in range(0, len(text), 48):
            buf.append(text[i : i + 48])
            yield "".join(buf)

    def _pick_reply_text(self, data: Any) -> str:
        """Best-effort extraction of the generated reply from provider JSON."""
        text = ""
        if isinstance(data, dict):
            text = data.get("generated_text") or ""
            if not text:
                conv = data.get("conversation") or {}
                gen = conv.get("generated_responses") or []
                if gen:
                    text = gen[-1] or ""
        elif isinstance(data, list) and data:
            item = data[-1]
            if isinstance(item, dict):
                text = item.get("generated_text") or ""
        if not text:
            # why: surface the raw payload rather than silently yielding nothing.
            text = json.dumps(data)
        return text

    # ---------- Public ----------
    def generate_stream(
        self,
        system_prompt: str,
        history: List[Tuple[str, str]],
        user_msg: str,
        temperature: float,
        max_new_tokens: int,
    ) -> Iterator[str]:
        """Yield the growing assistant reply for *user_msg*.

        Errors never raise; they arrive as '[error] ...'/'[info] ...' strings.
        """
        exists, pipeline_tag = self._preflight()
        if not exists:
            suggestion = _suggest_repo(self.model)
            yield f"[error] Model repository not found: {self.model}. Try: `{suggestion}`"
            return
        try:
            # If API says conversational, skip straight to conversational fallback.
            if (pipeline_tag or "").lower() == "conversational":
                yield from self._call_conversational_http(system_prompt, history, user_msg, temperature, max_new_tokens)
                return
            yield from self._stream_text_generation(system_prompt, history, user_msg, temperature, max_new_tokens)
        except Exception as e:
            msg = str(e).lower()
            # why: 'conversational' alone subsumes the longer provider phrase
            # "supported task: conversational" that the original also checked.
            if "conversational" in msg:
                yield from self._call_conversational_http(system_prompt, history, user_msg, temperature, max_new_tokens)
            else:
                yield f"[error] text_generation: {type(e).__name__}: {e}"