cdpearlman Cursor commited on
Commit
ddd91a5
·
1 Parent(s): 689669f

Migrate from google-generativeai to google-genai SDK

Browse files

Co-authored-by: Cursor <cursoragent@cursor.com>

requirements.txt CHANGED
@@ -18,4 +18,4 @@ numpy>=1.24.0
18
  pytest>=7.0.0
19
 
20
  # AI Chatbot dependencies
21
- google-generativeai>=0.8.0
 
18
  pytest>=7.0.0
19
 
20
  # AI Chatbot dependencies
21
+ google-genai>=1.0.0
tests/test_gemini_connection.py CHANGED
@@ -4,8 +4,7 @@ Tests for Gemini API connection.
4
  Verifies that the API key is configured correctly and can connect
5
  to the Gemini API without consuming generation tokens.
6
 
7
- Note: These tests avoid importing utils.gemini_client where possible
8
- to prevent slow tensorflow/jax imports from the google-generativeai package.
9
  """
10
 
11
  import os
@@ -33,13 +32,13 @@ class TestGeminiConnection:
33
  Test API connectivity by listing available models.
34
  This verifies the API key is valid without consuming generation tokens.
35
  """
36
- import google.generativeai as genai
37
 
38
  api_key = os.environ.get("GEMINI_API_KEY")
39
- genai.configure(api_key=api_key)
40
 
41
  # List models - this is a read-only API call that validates the key
42
- models = list(genai.list_models())
43
 
44
  assert len(models) > 0, "No models returned - API key may be invalid"
45
 
@@ -51,12 +50,12 @@ class TestGeminiConnection:
51
  @pytest.mark.timeout(30)
52
  def test_flash_model_available(self):
53
  """Verify a Gemini Flash model (used by default) is available."""
54
- import google.generativeai as genai
55
 
56
  api_key = os.environ.get("GEMINI_API_KEY")
57
- genai.configure(api_key=api_key)
58
 
59
- models = list(genai.list_models())
60
  model_names = [m.name for m in models]
61
 
62
  # Check for flash model variants (our default is gemini-2.0-flash)
@@ -65,3 +64,21 @@ class TestGeminiConnection:
65
  f"No Gemini Flash models available. "
66
  f"Available models: {model_names[:10]}..."
67
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  Verifies that the API key is configured correctly and can connect
5
  to the Gemini API without consuming generation tokens.
6
 
7
+ Uses the new google-genai SDK.
 
8
  """
9
 
10
  import os
 
32
  Test API connectivity by listing available models.
33
  This verifies the API key is valid without consuming generation tokens.
34
  """
35
+ from google import genai
36
 
37
  api_key = os.environ.get("GEMINI_API_KEY")
38
+ client = genai.Client(api_key=api_key)
39
 
40
  # List models - this is a read-only API call that validates the key
41
+ models = list(client.models.list())
42
 
43
  assert len(models) > 0, "No models returned - API key may be invalid"
44
 
 
50
  @pytest.mark.timeout(30)
51
  def test_flash_model_available(self):
52
  """Verify a Gemini Flash model (used by default) is available."""
53
+ from google import genai
54
 
55
  api_key = os.environ.get("GEMINI_API_KEY")
56
+ client = genai.Client(api_key=api_key)
57
 
58
+ models = list(client.models.list())
59
  model_names = [m.name for m in models]
60
 
61
  # Check for flash model variants (our default is gemini-2.0-flash)
 
64
  f"No Gemini Flash models available. "
65
  f"Available models: {model_names[:10]}..."
66
  )
67
+
68
+ @pytest.mark.timeout(30)
69
+ def test_embedding_model_available(self):
70
+ """Verify the embedding model is available."""
71
+ from google import genai
72
+
73
+ api_key = os.environ.get("GEMINI_API_KEY")
74
+ client = genai.Client(api_key=api_key)
75
+
76
+ models = list(client.models.list())
77
+ model_names = [m.name for m in models]
78
+
79
+ # Check for embedding model (gemini-embedding-001)
80
+ has_embedding_model = any("embedding" in name.lower() for name in model_names)
81
+ assert has_embedding_model, (
82
+ f"No embedding models available. "
83
+ f"Available models: {model_names[:10]}..."
84
+ )
todo.md CHANGED
@@ -150,3 +150,16 @@
150
  - [x] Create `tests/test_gemini_connection.py` to verify API key connectivity
151
  - [x] Tests verify: API key is set, can list models, flash model available
152
  - Note: On Hugging Face Spaces, set `GEMINI_API_KEY` in Repository Secrets
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  - [x] Create `tests/test_gemini_connection.py` to verify API key connectivity
151
  - [x] Tests verify: API key is set, can list models, flash model available
152
  - Note: On Hugging Face Spaces, set `GEMINI_API_KEY` in Repository Secrets
153
+
154
+ ## Completed: Migrate to New Google GenAI SDK
155
+
156
+ - [x] Update `requirements.txt`: `google-generativeai` → `google-genai>=1.0.0`
157
+ - [x] Rewrite `utils/gemini_client.py` using new centralized Client architecture
158
+ - New import: `from google import genai` and `from google.genai import types`
159
+ - Client-based API: `client = genai.Client(api_key=...)`
160
+ - Chat via: `client.chats.create(model=..., config=..., history=...)`
161
+ - Embeddings via: `client.models.embed_content(model=..., contents=..., config=...)`
162
+ - [x] Update embedding model: `models/text-embedding-004` → `gemini-embedding-001`
163
+ - [x] Update `tests/test_gemini_connection.py` to use new SDK
164
+ - [x] All 4 connection tests pass
165
+ - [x] Verified: embeddings work (3072 dimensions), chat generation works
utils/gemini_client.py CHANGED
@@ -3,16 +3,19 @@ Gemini API Client
3
 
4
  Wrapper for Google Gemini API providing text generation and embedding capabilities
5
  for the AI chatbot feature.
 
 
6
  """
7
 
8
  import os
9
  from typing import List, Dict, Optional
10
- import google.generativeai as genai
 
11
 
12
 
13
  # Default model configuration
14
  DEFAULT_GENERATION_MODEL = "gemini-2.0-flash"
15
- DEFAULT_EMBEDDING_MODEL = "models/text-embedding-004"
16
 
17
  # System prompt for the chatbot
18
  SYSTEM_PROMPT = """You are a helpful AI assistant integrated into a Transformer Explanation Dashboard.
@@ -45,23 +48,19 @@ class GeminiClient:
45
  """
46
  self.api_key = api_key or os.environ.get("GEMINI_API_KEY")
47
  self._initialized = False
48
- self._generation_model = None
49
- self._embedding_model = None
50
 
51
  if self.api_key:
52
  self._initialize()
53
 
54
  def _initialize(self):
55
- """Initialize the Gemini API with the API key."""
56
  if not self.api_key:
57
  return
58
 
59
  try:
60
- genai.configure(api_key=self.api_key)
61
- self._generation_model = genai.GenerativeModel(
62
- model_name=DEFAULT_GENERATION_MODEL,
63
- system_instruction=SYSTEM_PROMPT
64
- )
65
  self._initialized = True
66
  except Exception as e:
67
  print(f"Error initializing Gemini client: {e}")
@@ -98,26 +97,32 @@ class GeminiClient:
98
  # Build the full prompt with context
99
  full_message = self._build_prompt(user_message, rag_context, dashboard_context)
100
 
101
- # Convert chat history to Gemini format
102
  history = []
103
  if chat_history:
104
  for msg in chat_history[-10:]: # Keep last 10 messages for context
105
  role = "user" if msg.get("role") == "user" else "model"
106
  history.append({
107
  "role": role,
108
- "parts": [msg.get("content", "")]
109
  })
110
 
111
- # Create chat session and send message
112
- chat = self._generation_model.start_chat(history=history)
113
- response = chat.send_message(full_message)
 
 
 
 
 
 
114
 
115
  return response.text
116
 
117
  except Exception as e:
118
  error_msg = str(e)
119
  if "quota" in error_msg.lower() or "rate" in error_msg.lower():
120
- return "The AI service is currently rate limited. Please try again in a moment."
121
  elif "invalid" in error_msg.lower() and "key" in error_msg.lower():
122
  return "Invalid API key. Please check your GEMINI_API_KEY configuration."
123
  else:
@@ -187,12 +192,15 @@ class GeminiClient:
187
  return None
188
 
189
  try:
190
- result = genai.embed_content(
191
  model=DEFAULT_EMBEDDING_MODEL,
192
- content=text,
193
- task_type="retrieval_document"
 
 
194
  )
195
- return result['embedding']
 
196
  except Exception as e:
197
  print(f"Embedding error: {e}")
198
  return None
@@ -211,12 +219,15 @@ class GeminiClient:
211
  return None
212
 
213
  try:
214
- result = genai.embed_content(
215
  model=DEFAULT_EMBEDDING_MODEL,
216
- content=query,
217
- task_type="retrieval_query"
 
 
218
  )
219
- return result['embedding']
 
220
  except Exception as e:
221
  print(f"Query embedding error: {e}")
222
  return None
 
3
 
4
  Wrapper for Google Gemini API providing text generation and embedding capabilities
5
  for the AI chatbot feature.
6
+
7
+ Uses the new google-genai SDK (migrated from deprecated google-generativeai).
8
  """
9
 
10
  import os
11
  from typing import List, Dict, Optional
12
+ from google import genai
13
+ from google.genai import types
14
 
15
 
16
  # Default model configuration
17
  DEFAULT_GENERATION_MODEL = "gemini-2.0-flash"
18
+ DEFAULT_EMBEDDING_MODEL = "gemini-embedding-001"
19
 
20
  # System prompt for the chatbot
21
  SYSTEM_PROMPT = """You are a helpful AI assistant integrated into a Transformer Explanation Dashboard.
 
48
  """
49
  self.api_key = api_key or os.environ.get("GEMINI_API_KEY")
50
  self._initialized = False
51
+ self._client = None
 
52
 
53
  if self.api_key:
54
  self._initialize()
55
 
56
  def _initialize(self):
57
+ """Initialize the Gemini API client."""
58
  if not self.api_key:
59
  return
60
 
61
  try:
62
+ # Create the centralized client object (new SDK architecture)
63
+ self._client = genai.Client(api_key=self.api_key)
 
 
 
64
  self._initialized = True
65
  except Exception as e:
66
  print(f"Error initializing Gemini client: {e}")
 
97
  # Build the full prompt with context
98
  full_message = self._build_prompt(user_message, rag_context, dashboard_context)
99
 
100
+ # Convert chat history to new SDK format
101
  history = []
102
  if chat_history:
103
  for msg in chat_history[-10:]: # Keep last 10 messages for context
104
  role = "user" if msg.get("role") == "user" else "model"
105
  history.append({
106
  "role": role,
107
+ "parts": [{"text": msg.get("content", "")}]
108
  })
109
 
110
+ # Create chat session with system instruction and send message
111
+ chat = self._client.chats.create(
112
+ model=DEFAULT_GENERATION_MODEL,
113
+ config=types.GenerateContentConfig(
114
+ system_instruction=SYSTEM_PROMPT,
115
+ ),
116
+ history=history
117
+ )
118
+ response = chat.send_message(message=full_message)
119
 
120
  return response.text
121
 
122
  except Exception as e:
123
  error_msg = str(e)
124
  if "quota" in error_msg.lower() or "rate" in error_msg.lower():
125
+ return f"The AI service is currently rate limited. Please try again in a moment. {error_msg}"
126
  elif "invalid" in error_msg.lower() and "key" in error_msg.lower():
127
  return "Invalid API key. Please check your GEMINI_API_KEY configuration."
128
  else:
 
192
  return None
193
 
194
  try:
195
+ result = self._client.models.embed_content(
196
  model=DEFAULT_EMBEDDING_MODEL,
197
+ contents=text,
198
+ config=types.EmbedContentConfig(
199
+ task_type="RETRIEVAL_DOCUMENT"
200
+ )
201
  )
202
+ # New SDK returns embeddings as a list, get the first one
203
+ return result.embeddings[0].values
204
  except Exception as e:
205
  print(f"Embedding error: {e}")
206
  return None
 
219
  return None
220
 
221
  try:
222
+ result = self._client.models.embed_content(
223
  model=DEFAULT_EMBEDDING_MODEL,
224
+ contents=query,
225
+ config=types.EmbedContentConfig(
226
+ task_type="RETRIEVAL_QUERY"
227
+ )
228
  )
229
+ # New SDK returns embeddings as a list, get the first one
230
+ return result.embeddings[0].values
231
  except Exception as e:
232
  print(f"Query embedding error: {e}")
233
  return None