kaurm43 commited on
Commit
491816e
·
verified ·
1 Parent(s): 455f399

Update PolyAgent/gradio_interface.py

Browse files
Files changed (1) hide show
  1. PolyAgent/gradio_interface.py +15 -26
PolyAgent/gradio_interface.py CHANGED
@@ -1240,34 +1240,23 @@ def llm_only_answer(state: Dict[str, Any], model_name: str, prompt: str) -> str:
1240
  }
1241
  )
1242
 
1243
- client = InferenceClient(model=model_id, token=HF_TOKEN)
1244
 
1245
  try:
1246
- if model_id.startswith("mistralai/"):
1247
- # Mixtral: use text-generation, not chat
1248
- prompt_text = (
1249
- "You are a polymer R&D assistant. Answer directly and clearly.\n\n"
1250
- f"User: {p}\nAssistant:"
1251
- )
1252
- resp = client.text_generation(
1253
- prompt_text,
1254
- max_new_tokens=900,
1255
- temperature=0.7,
1256
- top_p=0.95,
1257
- return_full_text=False,
1258
- )
1259
- return resp
1260
- else:
1261
- # Llama: chat endpoint works
1262
- resp = client.chat_completion(
1263
- messages=[
1264
- {"role": "system", "content": "You are a polymer R&D assistant..."},
1265
- {"role": "user", "content": p},
1266
- ],
1267
- max_tokens=900,
1268
- temperature=0.7,
1269
- )
1270
- return resp.choices[0].message.content or ""
1271
  return resp.choices[0].message.content or ""
1272
  except Exception as e:
1273
  return pretty_json({"ok": False, "error": str(e), "model_id": model_id})
 
1240
  }
1241
  )
1242
 
1243
+ client = InferenceClient(model=model_id, token=HF_TOKEN, provider="fireworks-ai")
1244
 
1245
  try:
1246
+ resp = client.chat_completion(
1247
+ messages=[
1248
+ {
1249
+ "role": "system",
1250
+ "content": (
1251
+ "You are a polymer R&D assistant. Answer directly and clearly. "
1252
+ "Do not call tools or run web searches. If you are uncertain, state uncertainty."
1253
+ ),
1254
+ },
1255
+ {"role": "user", "content": p},
1256
+ ],
1257
+ max_tokens=900,
1258
+ temperature=0.7,
1259
+ )
 
 
 
 
 
 
 
 
 
 
 
1260
  return resp.choices[0].message.content or ""
1261
  except Exception as e:
1262
  return pretty_json({"ok": False, "error": str(e), "model_id": model_id})