dzezzefezfz committed on
Commit
467f028
·
verified ·
1 Parent(s): b60854a

Update backend_hf_api.py

Browse files
Files changed (1) hide show
  1. backend_hf_api.py +44 -11
backend_hf_api.py CHANGED
@@ -21,24 +21,49 @@ def is_hf_api_available() -> bool:
21
  return bool(get_hf_token())
22
 
23
 
 
 
 
 
 
 
 
24
  class HFInferenceBackend:
25
  """
26
- Hugging Face Serverless client with safe fallback:
27
- - Try text_generation (stream).
28
- - If provider reports 'Supported task: conversational', call HTTP conversational endpoint and chunk output.
 
29
  """
30
 
31
  def __init__(self, model_name: str):
32
  token = get_hf_token()
33
  if not token:
34
  raise RuntimeError("HF_TOKEN not set")
35
- self.model = model_name
36
  self.token = token
37
  self.client = InferenceClient(model=self.model, token=token) if InferenceClient else None
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  # ---------- Prompt Builders ----------
40
  def _build_tg_prompt(self, system_prompt: str, history: List[Tuple[str, str]], user_msg: str) -> str:
41
- # Generic instruct-style prompt; works widely including Nemotron chat variants
42
  parts = [f"<s>[SYSTEM]\n{system_prompt}\n[/SYSTEM]\n"]
43
  for u, a in history:
44
  if u:
@@ -94,7 +119,7 @@ class HFInferenceBackend:
94
  buf.append(delta)
95
  yield "".join(buf)
96
 
97
- # ---------- Conversational via raw HTTP (non-stream; chunked to UI) ----------
98
  def _call_conversational_http(
99
  self, system_prompt: str, history: List[Tuple[str, str]], user_msg: str, temperature: float, max_new_tokens: int
100
  ) -> Iterator[str]:
@@ -108,9 +133,8 @@ class HFInferenceBackend:
108
  "inputs": self._build_conv_inputs(system_prompt, history, user_msg),
109
  "parameters": {"temperature": float(temperature), "max_new_tokens": int(max_new_tokens)},
110
  }
111
-
112
  try:
113
- resp = requests.post(url, headers=headers, data=json.dumps(payload), timeout=90)
114
  except Exception as e:
115
  yield f"[error] network: {type(e).__name__}: {e}"
116
  return
@@ -138,11 +162,9 @@ class HFInferenceBackend:
138
  item = data[-1]
139
  if isinstance(item, dict):
140
  text = item.get("generated_text") or ""
141
-
142
  if not text:
143
- text = json.dumps(data) # visibility fallback
144
 
145
- # Chunk to simulate streaming and keep UI responsive
146
  buf: List[str] = []
147
  for i in range(0, len(text), 48):
148
  buf.append(text[i : i + 48])
@@ -157,7 +179,18 @@ class HFInferenceBackend:
157
  temperature: float,
158
  max_new_tokens: int,
159
  ) -> Iterator[str]:
 
 
 
 
 
 
160
  try:
 
 
 
 
 
161
  yield from self._stream_text_generation(system_prompt, history, user_msg, temperature, max_new_tokens)
162
  except Exception as e:
163
  msg = str(e).lower()
 
21
  return bool(get_hf_token())
22
 
23
 
24
+ def _suggest_repo(bad_repo: str) -> str:
25
+ # why: common Nemotron typo rescue
26
+ if "nemotron" in bad_repo.lower():
27
+ return "NVIDIA/Nemotron-3-8B-Instruct"
28
+ return "mistralai/Mistral-7B-Instruct-v0.2"
29
+
30
+
31
  class HFInferenceBackend:
32
  """
33
+ Robust HF Serverless client:
34
+ - Preflight: verify repo exists (fast) to avoid long blocking errors.
35
+ - Try text_generation streaming via huggingface_hub.
36
+ - If provider says 'conversational' only, call HTTP conversational and chunk output.
37
  """
38
 
39
  def __init__(self, model_name: str):
40
  token = get_hf_token()
41
  if not token:
42
  raise RuntimeError("HF_TOKEN not set")
43
+ self.model = model_name.strip()
44
  self.token = token
45
  self.client = InferenceClient(model=self.model, token=token) if InferenceClient else None
46
 
47
+ # ---------- Preflight ----------
48
+ def _preflight(self) -> tuple[bool, Optional[str]]:
49
+ """Returns (exists, pipeline_tag_or_None)."""
50
+ url = f"https://huggingface.co/api/models/{self.model}"
51
+ headers = {"Authorization": f"Bearer {self.token}"}
52
+ try:
53
+ r = requests.get(url, headers=headers, timeout=8)
54
+ if r.status_code == 404:
55
+ return False, None
56
+ if r.ok:
57
+ data = r.json()
58
+ # 'pipeline_tag' when known; otherwise None
59
+ return True, data.get("pipeline_tag")
60
+ return True, None
61
+ except Exception:
62
+ # If API unreachable, don't block the chat; proceed and catch later.
63
+ return True, None
64
+
65
  # ---------- Prompt Builders ----------
66
  def _build_tg_prompt(self, system_prompt: str, history: List[Tuple[str, str]], user_msg: str) -> str:
 
67
  parts = [f"<s>[SYSTEM]\n{system_prompt}\n[/SYSTEM]\n"]
68
  for u, a in history:
69
  if u:
 
119
  buf.append(delta)
120
  yield "".join(buf)
121
 
122
+ # ---------- Conversational via raw HTTP (non-stream; chunked) ----------
123
  def _call_conversational_http(
124
  self, system_prompt: str, history: List[Tuple[str, str]], user_msg: str, temperature: float, max_new_tokens: int
125
  ) -> Iterator[str]:
 
133
  "inputs": self._build_conv_inputs(system_prompt, history, user_msg),
134
  "parameters": {"temperature": float(temperature), "max_new_tokens": int(max_new_tokens)},
135
  }
 
136
  try:
137
+ resp = requests.post(url, headers=headers, data=json.dumps(payload), timeout=40)
138
  except Exception as e:
139
  yield f"[error] network: {type(e).__name__}: {e}"
140
  return
 
162
  item = data[-1]
163
  if isinstance(item, dict):
164
  text = item.get("generated_text") or ""
 
165
  if not text:
166
+ text = json.dumps(data)
167
 
 
168
  buf: List[str] = []
169
  for i in range(0, len(text), 48):
170
  buf.append(text[i : i + 48])
 
179
  temperature: float,
180
  max_new_tokens: int,
181
  ) -> Iterator[str]:
182
+ exists, pipeline_tag = self._preflight()
183
+ if not exists:
184
+ suggestion = _suggest_repo(self.model)
185
+ yield f"[error] Model repository not found: {self.model}. Try: `{suggestion}`"
186
+ return
187
+
188
  try:
189
+ # If API says conversational, skip straight to conversational fallback.
190
+ if (pipeline_tag or "").lower() == "conversational":
191
+ yield from self._call_conversational_http(system_prompt, history, user_msg, temperature, max_new_tokens)
192
+ return
193
+
194
  yield from self._stream_text_generation(system_prompt, history, user_msg, temperature, max_new_tokens)
195
  except Exception as e:
196
  msg = str(e).lower()