adityaverma977 committed on
Commit
d4d710c
·
1 Parent(s): 97270d8

Add Groq->HF fallback mapping and accept HUGGINGFACE_API_TOKEN

Browse files
Files changed (1) hide show
  1. backend/app/groq_client.py +49 -1
backend/app/groq_client.py CHANGED
@@ -9,7 +9,8 @@ from dotenv import load_dotenv
9
  load_dotenv()
10
 
11
  _GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
12
- _HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
 
13
  _groq_client = AsyncGroq(api_key=_GROQ_API_KEY) if _GROQ_API_KEY else None
14
  _HF_API_BASE = "https://api-inference.huggingface.co/models"
15
 
@@ -45,6 +46,13 @@ HF_MODELS = [
45
  "EleutherAI/gpt-j-6B",
46
  ]
47
 
 
 
 
 
 
 
 
48
 
49
  def is_ready():
50
  """Check if we have at least one backend available."""
@@ -171,6 +179,46 @@ Respond with ONLY valid JSON on a single line (no markdown, no code block):
171
  timeout=3.0
172
  )
173
  decision = json.loads(completion.choices[0].message.content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  elif _is_hf_model(agent.model_name) and _HF_API_TOKEN:
175
  # Use HF Inference API for open-source models
176
  async with httpx.AsyncClient(timeout=10.0) as client:
 
9
  load_dotenv()
10
 
11
  _GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
12
+ # Accept either HF_API_TOKEN or HUGGINGFACE_API_TOKEN for compatibility
13
+ _HF_API_TOKEN = os.environ.get("HF_API_TOKEN") or os.environ.get("HUGGINGFACE_API_TOKEN")
14
  _groq_client = AsyncGroq(api_key=_GROQ_API_KEY) if _GROQ_API_KEY else None
15
  _HF_API_BASE = "https://api-inference.huggingface.co/models"
16
 
 
46
  "EleutherAI/gpt-j-6B",
47
  ]
48
 
49
+ # Mapping from premium Groq models to reasonable HF fallback model IDs
50
+ # Used when the Groq client is not configured but an HF token is available.
51
+ GROQ_TO_HF_FALLBACK = {
52
+ "mixtral-8x7b-32768": "mistralai/Mistral-7B-Instruct-v0.2",
53
+ "llama2-70b-4096": "meta-llama/Llama-2-13b-chat-hf",
54
+ }
55
+
56
 
57
  def is_ready():
58
  """Check if we have at least one backend available."""
 
179
  timeout=3.0
180
  )
181
  decision = json.loads(completion.choices[0].message.content)
182
+ # If the agent requested a premium Groq model but Groq client is not configured,
183
+ # try to route the decision to an HF fallback model when possible.
184
+ elif _is_groq_model(agent.model_name) and not _groq_client and _HF_API_TOKEN:
185
+ fallback_model = GROQ_TO_HF_FALLBACK.get(agent.model_name)
186
+ if not fallback_model:
187
+ return _fallback_escape(agent, fire)
188
+
189
+ async with httpx.AsyncClient(timeout=10.0) as client:
190
+ response = await client.post(
191
+ f"{_HF_API_BASE}/{fallback_model}",
192
+ headers={"Authorization": f"Bearer {_HF_API_TOKEN}"},
193
+ json={
194
+ "inputs": system_prompt,
195
+ "parameters": {
196
+ "max_new_tokens": 200,
197
+ "temperature": 0.7,
198
+ "top_p": 0.9,
199
+ }
200
+ }
201
+ )
202
+ response.raise_for_status()
203
+ data = response.json()
204
+
205
+ if isinstance(data, list) and len(data) > 0:
206
+ text = data[0].get("generated_text", "")
207
+ else:
208
+ text = data.get("generated_text", "")
209
+
210
+ text = text[len(system_prompt):].strip() if text.startswith(system_prompt) else text
211
+
212
+ try:
213
+ json_start = text.find('{')
214
+ json_end = text.rfind('}') + 1
215
+ if json_start >= 0 and json_end > json_start:
216
+ json_str = text[json_start:json_end]
217
+ decision = json.loads(json_str)
218
+ else:
219
+ decision = {}
220
+ except json.JSONDecodeError:
221
+ decision = {}
222
  elif _is_hf_model(agent.model_name) and _HF_API_TOKEN:
223
  # Use HF Inference API for open-source models
224
  async with httpx.AsyncClient(timeout=10.0) as client: