Paperbag committed on
Commit
09be979
·
1 Parent(s): 47b5c71

feat: implement LLM provider client with automated fallback and Gemini model support

Browse files
__pycache__/agent.cpython-39.pyc CHANGED
Binary files a/__pycache__/agent.cpython-39.pyc and b/__pycache__/agent.cpython-39.pyc differ
 
llm/client.py CHANGED
@@ -4,7 +4,7 @@ from typing import List
4
  from langchain_core.messages import AIMessage
5
  from llm.providers import PROVIDERS
6
 
7
- PROVIDER_ORDER = os.getenv("LLM_PROVIDER_ORDER", "gemini_gemma, gemini, groq").split(",")
8
 
9
  _degraded_providers = {}
10
 
@@ -50,14 +50,14 @@ def invoke_llm(messages: List, tools: List, fallback_count: int = 0) -> AIMessag
50
  except Exception as e:
51
  error_msg = str(e).lower()
52
 
53
- if "rate limit" in error_msg or "429" in error_msg:
54
  print(f"{provider_name} rate limit hit. Waiting before retry...")
55
  import time
56
  wait_time = 10 * (fallback_count + 1)
57
  time.sleep(wait_time)
58
-
59
- print(f"{provider_name} failed: {e}. Marking as degraded.")
60
- _degraded_providers[provider_name] = True
61
 
62
  remaining = [n for n in PROVIDER_ORDER if n not in _degraded_providers]
63
  if remaining:
 
4
  from langchain_core.messages import AIMessage
5
  from llm.providers import PROVIDERS
6
 
7
+ PROVIDER_ORDER = os.getenv("LLM_PROVIDER_ORDER", "groq, gemini, gemini_gemma").split(",")
8
 
9
  _degraded_providers = {}
10
 
 
50
  except Exception as e:
51
  error_msg = str(e).lower()
52
 
53
+ if "rate limit" in error_msg or "429" in error_msg or "quota" in error_msg:
54
  print(f"{provider_name} rate limit hit. Waiting before retry...")
55
  import time
56
  wait_time = 10 * (fallback_count + 1)
57
  time.sleep(wait_time)
58
+ _degraded_providers[provider_name] = True
59
+ else:
60
+ print(f"{provider_name} error: {e}. Trying next provider.")
61
 
62
  remaining = [n for n in PROVIDER_ORDER if n not in _degraded_providers]
63
  if remaining:
llm/providers/gemini.py CHANGED
@@ -1,9 +1,14 @@
 
 
1
  from langchain_google_genai import ChatGoogleGenerativeAI
2
 
 
 
 
3
 
4
  def invoke(messages, tools, model_name: str = "gemini-2.0-flash"):
5
  """Invoke Gemini models (free tier)."""
6
- model = ChatGoogleGenerativeAI(model=model_name, temperature=0)
7
  model_with_tools = model.bind_tools(tools)
8
  return model_with_tools.invoke(messages)
9
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
  from langchain_google_genai import ChatGoogleGenerativeAI
4
 
5
+ load_dotenv()
6
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
7
+
8
 
9
  def invoke(messages, tools, model_name: str = "gemini-2.0-flash"):
10
  """Invoke Gemini models (free tier)."""
11
+ model = ChatGoogleGenerativeAI(model=model_name, temperature=0, google_api_key=GOOGLE_API_KEY)
12
  model_with_tools = model.bind_tools(tools)
13
  return model_with_tools.invoke(messages)
14
 
llm/providers/gemini_gemma.py CHANGED
@@ -1,9 +1,14 @@
 
 
1
  from langchain_google_genai import ChatGoogleGenerativeAI
2
 
 
 
 
3
 
4
  def invoke(messages, tools, model_name: str = "gemma-2-27b-it"):
5
  """Invoke Google Gemma models (free tier)."""
6
- model = ChatGoogleGenerativeAI(model=model_name, temperature=0)
7
  model_with_tools = model.bind_tools(tools)
8
  return model_with_tools.invoke(messages)
9
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
  from langchain_google_genai import ChatGoogleGenerativeAI
4
 
5
+ load_dotenv()
6
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
7
+
8
 
9
  def invoke(messages, tools, model_name: str = "gemma-2-27b-it"):
10
  """Invoke Google Gemma models (free tier)."""
11
+ model = ChatGoogleGenerativeAI(model=model_name, temperature=0, google_api_key=GOOGLE_API_KEY)
12
  model_with_tools = model.bind_tools(tools)
13
  return model_with_tools.invoke(messages)
14