Asish Karthikeya Gogineni commited on
Commit
a1acb59
·
1 Parent(s): 3508757

feat: Add Gemini model fallback logic - tries multiple models in sequence

Browse files
Files changed (1) hide show
  1. code_chatbot/rag.py +44 -11
code_chatbot/rag.py CHANGED
@@ -118,18 +118,51 @@ class ChatEngine:
118
  if not os.getenv("GOOGLE_API_KEY"):
119
  raise ValueError("Google API Key is required for Gemini")
120
 
121
- # Use model name without prefix - langchain handles it
122
- model_name = self.model_name or "gemini-2.5-flash"
123
- # Remove models/ prefix if present (langchain adds it)
124
- if model_name.startswith("models/"):
125
- model_name = model_name.replace("models/", "")
 
 
 
 
 
126
 
127
- return ChatGoogleGenerativeAI(
128
- model=model_name,
129
- google_api_key=api_key,
130
- temperature=0.2, # Low temp for agents
131
- convert_system_message_to_human=True
132
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  elif self.provider == "groq":
134
  if not api_key:
135
  if not os.getenv("GROQ_API_KEY"):
 
118
  if not os.getenv("GOOGLE_API_KEY"):
119
  raise ValueError("Google API Key is required for Gemini")
120
 
121
+ # Fallback list of Gemini models to try in order
122
+ GEMINI_MODELS_TO_TRY = [
123
+ "gemini-2.5-flash",
124
+ "gemini-2.5-pro",
125
+ "gemini-2.0-flash",
126
+ "gemini-2.0-flash-lite",
127
+ "gemini-1.5-flash",
128
+ "gemini-1.5-pro",
129
+ "gemini-pro",
130
+ ]
131
 
132
+ # If user specified a model, try it first
133
+ if self.model_name:
134
+ model_name = self.model_name
135
+ if model_name.startswith("models/"):
136
+ model_name = model_name.replace("models/", "")
137
+ if model_name not in GEMINI_MODELS_TO_TRY:
138
+ GEMINI_MODELS_TO_TRY.insert(0, model_name)
139
+ else:
140
+ # Move specified model to front
141
+ GEMINI_MODELS_TO_TRY.remove(model_name)
142
+ GEMINI_MODELS_TO_TRY.insert(0, model_name)
143
+
144
+ # Try each model until one works
145
+ last_error = None
146
+ for model_name in GEMINI_MODELS_TO_TRY:
147
+ try:
148
+ logger.info(f"Attempting to use Gemini model: {model_name}")
149
+ llm = ChatGoogleGenerativeAI(
150
+ model=model_name,
151
+ google_api_key=api_key,
152
+ temperature=0.2,
153
+ convert_system_message_to_human=True
154
+ )
155
+ # Test the model with a simple call
156
+ llm.invoke("test")
157
+ logger.info(f"Successfully initialized Gemini model: {model_name}")
158
+ return llm
159
+ except Exception as e:
160
+ logger.warning(f"Model {model_name} failed: {str(e)[:100]}")
161
+ last_error = e
162
+ continue
163
+
164
+ # If all models failed, raise the last error
165
+ raise ValueError(f"All Gemini models failed. Last error: {last_error}")
166
  elif self.provider == "groq":
167
  if not api_key:
168
  if not os.getenv("GROQ_API_KEY"):