""" Custom Gemini embedding class for LlamaIndex integration. Provides an alternative to OpenAI embeddings using Google's Gemini API. """ from typing import List from llama_index.core.embeddings import BaseEmbedding from google import genai class GeminiEmbedding(BaseEmbedding): """ Gemini embedding model integration for LlamaIndex. Uses Google's gemini-embedding-001 model for generating embeddings. This provides an alternative to OpenAI embeddings. """ def __init__( self, api_key: str, model_name: str = "gemini-embedding-001", **kwargs ): """ Initialize Gemini embedding model. Args: api_key: Google API key for Gemini model_name: Model name (default: gemini-embedding-001) **kwargs: Additional arguments for BaseEmbedding """ super().__init__(**kwargs) # Use private attribute to store client (Pydantic compatibility) self._client = genai.Client(api_key=api_key) self._model_name = model_name def _get_query_embedding(self, query: str) -> List[float]: """ Get embedding for a query string. Args: query: Query text to embed Returns: List of floats representing the embedding vector """ try: result = self._client.models.embed_content( model=self._model_name, contents=query ) # Extract embedding values from the response # The response structure is: result.embeddings[0].values if hasattr(result, 'embeddings') and len(result.embeddings) > 0: embedding = result.embeddings[0] if hasattr(embedding, 'values'): return list(embedding.values) raise ValueError("Unexpected response structure from Gemini embedding API") except Exception as e: raise RuntimeError(f"Error getting query embedding from Gemini: {str(e)}") def _get_text_embedding(self, text: str) -> List[float]: """ Get embedding for a text string. Args: text: Text to embed Returns: List of floats representing the embedding vector """ try: result = self._client.models.embed_content( model=self._model_name, contents=text ) # Extract embedding values from the response if hasattr(result, 'embeddings') and len(result.embeddings) > 0: embedding = result.embeddings[0] if hasattr(embedding, 'values'): return list(embedding.values) raise ValueError("Unexpected response structure from Gemini embedding API") except Exception as e: raise RuntimeError(f"Error getting text embedding from Gemini: {str(e)}") async def _aget_query_embedding(self, query: str) -> List[float]: """ Async version of _get_query_embedding. Note: Currently uses synchronous API as Gemini SDK doesn't have async support yet. Args: query: Query text to embed Returns: List of floats representing the embedding vector """ return self._get_query_embedding(query) async def _aget_text_embedding(self, text: str) -> List[float]: """ Async version of _get_text_embedding. Note: Currently uses synchronous API as Gemini SDK doesn't have async support yet. Args: text: Text to embed Returns: List of floats representing the embedding vector """ return self._get_text_embedding(text) def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: """ Get embeddings for a list of texts. Args: texts: List of texts to embed Returns: List of embedding vectors """ embeddings = [] for text in texts: embeddings.append(self._get_text_embedding(text)) return embeddings