Spaces:
Running
Running
Asish Karthikeya Gogineni commited on
Commit ·
a1acb59
1
Parent(s): 3508757
feat: Add Gemini model fallback logic - tries multiple models in sequence
Browse files- code_chatbot/rag.py +44 -11
code_chatbot/rag.py
CHANGED
|
@@ -118,18 +118,51 @@ class ChatEngine:
|
|
| 118 |
if not os.getenv("GOOGLE_API_KEY"):
|
| 119 |
raise ValueError("Google API Key is required for Gemini")
|
| 120 |
|
| 121 |
-
#
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
elif self.provider == "groq":
|
| 134 |
if not api_key:
|
| 135 |
if not os.getenv("GROQ_API_KEY"):
|
|
|
|
| 118 |
if not os.getenv("GOOGLE_API_KEY"):
|
| 119 |
raise ValueError("Google API Key is required for Gemini")
|
| 120 |
|
| 121 |
+
# Fallback list of Gemini models to try in order
|
| 122 |
+
GEMINI_MODELS_TO_TRY = [
|
| 123 |
+
"gemini-2.5-flash",
|
| 124 |
+
"gemini-2.5-pro",
|
| 125 |
+
"gemini-2.0-flash",
|
| 126 |
+
"gemini-2.0-flash-lite",
|
| 127 |
+
"gemini-1.5-flash",
|
| 128 |
+
"gemini-1.5-pro",
|
| 129 |
+
"gemini-pro",
|
| 130 |
+
]
|
| 131 |
|
| 132 |
+
# If user specified a model, try it first
|
| 133 |
+
if self.model_name:
|
| 134 |
+
model_name = self.model_name
|
| 135 |
+
if model_name.startswith("models/"):
|
| 136 |
+
model_name = model_name.replace("models/", "")
|
| 137 |
+
if model_name not in GEMINI_MODELS_TO_TRY:
|
| 138 |
+
GEMINI_MODELS_TO_TRY.insert(0, model_name)
|
| 139 |
+
else:
|
| 140 |
+
# Move specified model to front
|
| 141 |
+
GEMINI_MODELS_TO_TRY.remove(model_name)
|
| 142 |
+
GEMINI_MODELS_TO_TRY.insert(0, model_name)
|
| 143 |
+
|
| 144 |
+
# Try each model until one works
|
| 145 |
+
last_error = None
|
| 146 |
+
for model_name in GEMINI_MODELS_TO_TRY:
|
| 147 |
+
try:
|
| 148 |
+
logger.info(f"Attempting to use Gemini model: {model_name}")
|
| 149 |
+
llm = ChatGoogleGenerativeAI(
|
| 150 |
+
model=model_name,
|
| 151 |
+
google_api_key=api_key,
|
| 152 |
+
temperature=0.2,
|
| 153 |
+
convert_system_message_to_human=True
|
| 154 |
+
)
|
| 155 |
+
# Test the model with a simple call
|
| 156 |
+
llm.invoke("test")
|
| 157 |
+
logger.info(f"Successfully initialized Gemini model: {model_name}")
|
| 158 |
+
return llm
|
| 159 |
+
except Exception as e:
|
| 160 |
+
logger.warning(f"Model {model_name} failed: {str(e)[:100]}")
|
| 161 |
+
last_error = e
|
| 162 |
+
continue
|
| 163 |
+
|
| 164 |
+
# If all models failed, raise the last error
|
| 165 |
+
raise ValueError(f"All Gemini models failed. Last error: {last_error}")
|
| 166 |
elif self.provider == "groq":
|
| 167 |
if not api_key:
|
| 168 |
if not os.getenv("GROQ_API_KEY"):
|