kaurm43 committed on
Commit
7b735c8
·
verified ·
1 Parent(s): e6dcf3c

Update PolyAgent/gradio_interface.py

Browse files
Files changed (1) hide show
  1. PolyAgent/gradio_interface.py +38 -24
PolyAgent/gradio_interface.py CHANGED
@@ -1201,7 +1201,9 @@ def gpt_only_answer(state: Dict[str, Any], prompt: str) -> str:
1201
  # ----------------------------- Other LLMs (Hugging Face Inference) ----------------------------- #
1202
  def llm_only_answer(state: Dict[str, Any], model_name: str, prompt: str) -> str:
1203
  """
1204
- LLM-only responses using Hugging Face Inference API for non-GPT models.
 
 
1205
  """
1206
  ensure_orch(state)
1207
 
@@ -1210,7 +1212,7 @@ def llm_only_answer(state: Dict[str, Any], model_name: str, prompt: str) -> str:
1210
 
1211
  HF_TOKEN = (os.getenv("HF_TOKEN") or "").strip()
1212
  if not HF_TOKEN:
1213
- return pretty_json({"ok": False, "error": "HF_TOKEN is not set. Add HF_TOKEN=hf_... to your .env or env vars."})
1214
 
1215
  HF_MODEL_MAP = {
1216
  "mixtral-8x22b-instruct": "mistralai/Mixtral-8x22B-Instruct-v0.1",
@@ -1228,18 +1230,22 @@ def llm_only_answer(state: Dict[str, Any], model_name: str, prompt: str) -> str:
1228
  if not model_id:
1229
  return pretty_json({"ok": False, "error": f"Unsupported model selection: {m}", "supported": list(HF_MODEL_MAP.keys())})
1230
 
1231
- client = InferenceClient(model=model_id, token=HF_TOKEN)
 
 
 
 
 
 
 
1232
 
1233
  system = (
1234
  "You are a polymer R&D assistant. Answer directly and clearly. "
1235
  "Do not call tools or run web searches. If you are uncertain, state uncertainty."
1236
  )
1237
 
1238
- # A simple instruct-style prompt that works for text-generation endpoints
1239
- flat_prompt = f"{system}\n\nUser:\n{p}\n\nAssistant:\n"
1240
-
1241
  try:
1242
- # Try chat endpoint first (works only if the provider exposes the model as chat)
1243
  resp = client.chat_completion(
1244
  messages=[
1245
  {"role": "system", "content": system},
@@ -1249,25 +1255,33 @@ def llm_only_answer(state: Dict[str, Any], model_name: str, prompt: str) -> str:
1249
  temperature=0.7,
1250
  )
1251
  return resp.choices[0].message.content or ""
1252
-
1253
- except Exception as e:
1254
- msg = str(e)
1255
-
1256
- # If provider says it's not a chat model, fall back to text generation.
1257
- if ("not a chat model" in msg.lower()) or ("model_not_supported" in msg.lower()):
1258
- try:
1259
- out = client.text_generation(
1260
- flat_prompt,
1261
- max_new_tokens=900,
1262
- temperature=0.7,
1263
- do_sample=True,
1264
- return_full_text=False,
1265
  )
1266
- return out if isinstance(out, str) else str(out)
1267
- except Exception as e2:
1268
- return pretty_json({"ok": False, "error": str(e2), "model_id": model_id, "mode": "text_generation"})
1269
 
1270
- return pretty_json({"ok": False, "error": msg, "model_id": model_id, "mode": "chat_completion"})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1271
 
1272
 
1273
  def build_ui() -> gr.Blocks:
 
1201
  # ----------------------------- Other LLMs (Hugging Face Inference) ----------------------------- #
1202
  def llm_only_answer(state: Dict[str, Any], model_name: str, prompt: str) -> str:
1203
  """
1204
+ LLM-only responses via huggingface_hub.InferenceClient.
1205
+ - Forces provider to avoid unwanted auto-routing (e.g., fireworks-ai).
1206
+ - Tries chat_completion first; if model/provider doesn't support chat, falls back to text_generation.
1207
  """
1208
  ensure_orch(state)
1209
 
 
1212
 
1213
  HF_TOKEN = (os.getenv("HF_TOKEN") or "").strip()
1214
  if not HF_TOKEN:
1215
+ return pretty_json({"ok": False, "error": "HF_TOKEN is not set. Add HF_TOKEN=hf_... to Space Secrets."})
1216
 
1217
  HF_MODEL_MAP = {
1218
  "mixtral-8x22b-instruct": "mistralai/Mixtral-8x22B-Instruct-v0.1",
 
1230
  if not model_id:
1231
  return pretty_json({"ok": False, "error": f"Unsupported model selection: {m}", "supported": list(HF_MODEL_MAP.keys())})
1232
 
1233
+ # IMPORTANT: force provider so HF doesn't auto-route you to a provider that lacks the needed task
1234
+ provider = (os.getenv("HF_PROVIDER") or "hf-inference").strip()
1235
+
1236
+ client = InferenceClient(
1237
+ provider=provider,
1238
+ model=model_id,
1239
+ api_key=HF_TOKEN, # api_key works for both HF token + provider keys
1240
+ )
1241
 
1242
  system = (
1243
  "You are a polymer R&D assistant. Answer directly and clearly. "
1244
  "Do not call tools or run web searches. If you are uncertain, state uncertainty."
1245
  )
1246
 
1247
+ # 1) Try chat (conversational)
 
 
1248
  try:
 
1249
  resp = client.chat_completion(
1250
  messages=[
1251
  {"role": "system", "content": system},
 
1255
  temperature=0.7,
1256
  )
1257
  return resp.choices[0].message.content or ""
1258
+ except Exception as e_chat:
1259
+ # 2) Fallback to plain text-generation (works on hf-inference; many providers don't support it)
1260
+ try:
1261
+ if provider != "hf-inference":
1262
+ # text_generation is not universally supported across providers
1263
+ raise RuntimeError(
1264
+ f"Chat failed and provider='{provider}' may not support text_generation. "
1265
+ f"Set HF_PROVIDER=hf-inference (recommended) or choose a compatible model/provider."
 
 
 
 
 
1266
  )
 
 
 
1267
 
1268
+ # A simple prompt wrapper for non-chat models / non-chat endpoints
1269
+ wrapped = f"{system}\n\nUser: {p}\nAssistant:"
1270
+ out = client.text_generation(
1271
+ wrapped,
1272
+ max_new_tokens=900,
1273
+ temperature=0.7,
1274
+ do_sample=True,
1275
+ return_full_text=False,
1276
+ )
1277
+ return out if isinstance(out, str) else str(out)
1278
+ except Exception as e_gen:
1279
+ return pretty_json({
1280
+ "ok": False,
1281
+ "error": f"chat_completion failed: {e_chat}; text_generation failed: {e_gen}",
1282
+ "model_id": model_id,
1283
+ "provider": provider,
1284
+ })
1285
 
1286
 
1287
  def build_ui() -> gr.Blocks: