soupstick committed on
Commit
0c36357
·
1 Parent(s): 32115de

Fix HF API ERROR

Browse files
Files changed (4) hide show
  1. .env.example +3 -2
  2. api/inference.py +27 -17
  3. api/providers.py +3 -4
  4. app.py +5 -5
.env.example CHANGED
@@ -5,7 +5,8 @@ AGENT_API_URL=http://localhost:7861
5
  DBT_PROFILES_DIR=./dbt_project/profiles
6
 
7
  # LLM Provider Configuration
8
- LLM_PROVIDER=hf
 
9
  HF_TOKEN=YOUR_TOKEN
10
  HF_ROUTER_MODEL=Qwen/Qwen3-Coder-30B-A3B-Instruct:fireworks-ai
11
 
@@ -14,7 +15,7 @@ LLM_MODEL_GEN=Qwen/Qwen3-Coder-30B-A3B-Instruct:fireworks-ai
14
  LLM_MODEL_REV=Qwen/Qwen3-Coder-30B-A3B-Instruct:fireworks-ai
15
 
16
  # Optional Backend API URL (for existing functionality)
17
- API_URL=
18
 
19
  # Cache Configuration
20
  TRANSFORMERS_CACHE=/tmp/cache/transformers
 
5
  DBT_PROFILES_DIR=./dbt_project/profiles
6
 
7
  # LLM Provider Configuration
8
+ LLM_PROVIDER=hf_router
9
+ API_URL=https://router.huggingface.co/v1/chat/completions
10
  HF_TOKEN=YOUR_TOKEN
11
  HF_ROUTER_MODEL=Qwen/Qwen3-Coder-30B-A3B-Instruct:fireworks-ai
12
 
 
15
  LLM_MODEL_REV=Qwen/Qwen3-Coder-30B-A3B-Instruct:fireworks-ai
16
 
17
  # Optional Backend API URL (for existing functionality)
18
+ # API_URL can be used to point to a compatible backend, but is also used by the HF Router client above.
19
 
20
  # Cache Configuration
21
  TRANSFORMERS_CACHE=/tmp/cache/transformers
api/inference.py CHANGED
@@ -1,25 +1,35 @@
1
- import os
2
- import requests
3
- from typing import Optional
4
 
5
- API_URL = "https://router.huggingface.co/v1/chat/completions"
 
6
 
7
 
8
- def _call_llm(prompt: str, max_tokens: int = 512, temperature: float = 0.2, model: Optional[str] = None) -> str:
9
- hf_token = os.getenv("HF_TOKEN")
10
- if not hf_token:
11
- raise RuntimeError("Set HF_TOKEN in env")
12
- headers = {"Authorization": f"Bearer {hf_token}"}
13
- payload = {
14
- "model": model or os.getenv("HF_ROUTER_MODEL", "Qwen/Qwen3-Coder-30B-A3B-Instruct:fireworks-ai"),
 
15
  "messages": [{"role": "user", "content": prompt}],
16
  "max_tokens": max_tokens,
17
- "temperature": temperature,
18
- "stream": False,
19
  }
20
- resp = requests.post(API_URL, headers=headers, json=payload, timeout=60)
 
21
  if resp.status_code != 200:
22
- print("HF Router error:", resp.text)
23
- resp.raise_for_status()
24
- return resp.json()["choices"][0]["message"]["content"]
 
 
 
 
25
 
 
 
 
 
 
 
1
+ import os, requests, json
 
 
2
 
3
+ ROUTER_URL = os.getenv("API_URL", "https://router.huggingface.co/v1/chat/completions")
4
+ DEFAULT_MODEL = os.getenv("HF_ROUTER_MODEL", "Qwen/Qwen3-Coder-30B-A3B-Instruct:fireworks-ai")
5
 
6
 
7
+ def _call_llm(prompt: str, max_tokens: int = 512, temperature: float = 0.2, model: str | None = None) -> str:
8
+ token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
9
+ if not token:
10
+ raise RuntimeError("Set HF_TOKEN (or HUGGINGFACEHUB_API_TOKEN) in env")
11
+
12
+ headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
13
+ body = {
14
+ "model": model or DEFAULT_MODEL,
15
  "messages": [{"role": "user", "content": prompt}],
16
  "max_tokens": max_tokens,
17
+ "temperature": float(temperature),
18
+ "stream": False
19
  }
20
+
21
+ resp = requests.post(ROUTER_URL, headers=headers, json=body, timeout=60)
22
  if resp.status_code != 200:
23
+ # Surface router/provider error clearly
24
+ try:
25
+ failing_model = body.get("model")
26
+ except Exception:
27
+ failing_model = None
28
+ print(f"HF Router error {resp.status_code} for model={failing_model}: {resp.text}")
29
+ raise RuntimeError(f"HF Router error {resp.status_code}: {resp.text}")
30
 
31
+ data = resp.json()
32
+ try:
33
+ return data["choices"][0]["message"]["content"]
34
+ except Exception:
35
+ raise RuntimeError(f"Unexpected HF Router response: {json.dumps(data)[:800]}")
api/providers.py CHANGED
@@ -42,9 +42,8 @@ def extract_json(text: str):
42
  # ---------- Unified provider call via HF Router ----------
43
 
44
  def llm_call(kind: str, prompt: str) -> str:
45
- """Unified HF Router call. kind: 'gen' or 'rev' chooses env model."""
46
- model = os.getenv("LLM_MODEL_GEN") if kind == "gen" else os.getenv("LLM_MODEL_REV")
47
  if not model:
48
- raise RuntimeError(f"Missing model id for kind={kind}. Set LLM_MODEL_GEN/REV.")
49
  return _router_call(prompt, max_tokens=256, temperature=0.0, model=model)
50
-
 
42
  # ---------- Unified provider call via HF Router ----------
43
 
44
  def llm_call(kind: str, prompt: str) -> str:
45
+ """Unified HF Router call. Model is taken from HF_ROUTER_MODEL env."""
46
+ model = os.getenv("HF_ROUTER_MODEL")
47
  if not model:
48
+ raise RuntimeError("Set HF_ROUTER_MODEL in env for HF Router.")
49
  return _router_call(prompt, max_tokens=256, temperature=0.0, model=model)
 
app.py CHANGED
@@ -119,8 +119,8 @@ def _provider_has_creds(provider: str) -> bool:
119
 
120
 
121
  def _call_llm(provider: str, model: str, prompt: str) -> str:
122
- # Delegate to unified HF Router call; keep signature for compatibility
123
- return _router_call(prompt, max_tokens=400, temperature=0.0, model=model)
124
 
125
 
126
  def _extract_sql_from_text(text: str) -> str:
@@ -171,7 +171,7 @@ def _gen_sql(question: str, schema: str, provider: str, model: str, api_url: str
171
  return _enforce_limits(candidate)
172
  except Exception:
173
  st.warning(REMOTE_ERROR_HINT)
174
- llm_output = _router_call(prompt, max_tokens=400, temperature=0.0, model=model)
175
  return _enforce_limits(_extract_sql_from_text(llm_output))
176
 
177
 
@@ -227,7 +227,7 @@ def _review_sql(question: str, sql: str, schema: str, provider: str, model: str)
227
  "Return JSON with keys reasoning, ok (true/false), fixed_sql."
228
  )
229
  try:
230
- llm_response = _router_call(prompt, max_tokens=400, temperature=0.0, model=model)
231
  parsed = _extract_json(llm_response)
232
  if parsed:
233
  return parsed
@@ -268,7 +268,7 @@ def _suggest_chart(df: pd.DataFrame, provider: str, model: str) -> Optional[Dict
268
  column_info = ", ".join(f"{col} ({df[col].dtype})" for col in df.columns)
269
  prompt = PROMPT_DASHBOARD.format(column_info=column_info)
270
  try:
271
- raw = _router_call(prompt, max_tokens=400, temperature=0.0, model=model)
272
  parsed = _extract_json(raw)
273
  if isinstance(parsed, dict):
274
  chart_type = parsed.get("chart_type")
 
119
 
120
 
121
  def _call_llm(provider: str, model: str, prompt: str) -> str:
122
+ # Delegate to unified HF Router call; always use HF_ROUTER_MODEL with the router
123
+ return _router_call(prompt, max_tokens=400, temperature=0.0, model=os.getenv('HF_ROUTER_MODEL'))
124
 
125
 
126
  def _extract_sql_from_text(text: str) -> str:
 
171
  return _enforce_limits(candidate)
172
  except Exception:
173
  st.warning(REMOTE_ERROR_HINT)
174
+ llm_output = _router_call(prompt, max_tokens=400, temperature=0.0, model=os.getenv('HF_ROUTER_MODEL'))
175
  return _enforce_limits(_extract_sql_from_text(llm_output))
176
 
177
 
 
227
  "Return JSON with keys reasoning, ok (true/false), fixed_sql."
228
  )
229
  try:
230
+ llm_response = _router_call(prompt, max_tokens=400, temperature=0.0, model=os.getenv('HF_ROUTER_MODEL'))
231
  parsed = _extract_json(llm_response)
232
  if parsed:
233
  return parsed
 
268
  column_info = ", ".join(f"{col} ({df[col].dtype})" for col in df.columns)
269
  prompt = PROMPT_DASHBOARD.format(column_info=column_info)
270
  try:
271
+ raw = _router_call(prompt, max_tokens=400, temperature=0.0, model=os.getenv('HF_ROUTER_MODEL'))
272
  parsed = _extract_json(raw)
273
  if isinstance(parsed, dict):
274
  chart_type = parsed.get("chart_type")