| | """ |
| | Custom Gemini embedding class for LlamaIndex integration. |
| | Provides an alternative to OpenAI embeddings using Google's Gemini API. |
| | """ |
| | from typing import List |
| | from llama_index.core.embeddings import BaseEmbedding |
| | from google import genai |
| |
|
| |
|
class GeminiEmbedding(BaseEmbedding):
    """
    Gemini embedding model integration for LlamaIndex.

    Uses Google's gemini-embedding-001 model (by default) for generating
    embeddings. This provides an alternative to OpenAI embeddings.

    Every failure while talking to the Gemini API -- transport errors as well
    as malformed responses -- is surfaced as a ``RuntimeError`` whose
    ``__cause__`` is the underlying exception.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str = "gemini-embedding-001",
        **kwargs
    ):
        """
        Initialize Gemini embedding model.

        Args:
            api_key: Google API key for Gemini
            model_name: Model name (default: gemini-embedding-001)
            **kwargs: Additional arguments for BaseEmbedding
        """
        # Forward the model name so the base class's ``model_name`` field
        # describes the model actually in use instead of its default value.
        super().__init__(model_name=model_name, **kwargs)

        # Private attributes are assigned only after the base initializer has
        # run, matching the original ordering.
        self._client = genai.Client(api_key=api_key)
        self._model_name = model_name

    @staticmethod
    def _first_vector(result) -> List[float]:
        """
        Extract the first embedding vector from an ``embed_content`` response.

        Args:
            result: Response object returned by ``embed_content``

        Returns:
            List of floats representing the embedding vector

        Raises:
            ValueError: If the response has no embeddings or the first
                embedding carries no ``values``
        """
        # ``embeddings`` may be missing, None or empty; treat all three alike.
        candidates = getattr(result, "embeddings", None)
        if candidates:
            values = getattr(candidates[0], "values", None)
            if values is not None:
                return list(values)

        raise ValueError("Unexpected response structure from Gemini embedding API")

    def _embed(self, contents: str, kind: str) -> List[float]:
        """
        Embed a single string with the synchronous Gemini client.

        Args:
            contents: Text to embed
            kind: Label ("query" or "text") interpolated into the error message

        Returns:
            List of floats representing the embedding vector

        Raises:
            RuntimeError: If the API call fails or the response is malformed
        """
        try:
            result = self._client.models.embed_content(
                model=self._model_name,
                contents=contents
            )
            return self._first_vector(result)
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(
                f"Error getting {kind} embedding from Gemini: {str(e)}"
            ) from e

    async def _aembed(self, contents: str, kind: str) -> List[float]:
        """
        Embed a single string with the asynchronous Gemini client.

        Uses the SDK's ``client.aio`` interface so the event loop is not
        blocked while the request is in flight.

        Args:
            contents: Text to embed
            kind: Label ("query" or "text") interpolated into the error message

        Returns:
            List of floats representing the embedding vector

        Raises:
            RuntimeError: If the API call fails or the response is malformed
        """
        try:
            result = await self._client.aio.models.embed_content(
                model=self._model_name,
                contents=contents
            )
            return self._first_vector(result)
        except Exception as e:
            raise RuntimeError(
                f"Error getting {kind} embedding from Gemini: {str(e)}"
            ) from e

    def _get_query_embedding(self, query: str) -> List[float]:
        """
        Get embedding for a query string.

        Args:
            query: Query text to embed

        Returns:
            List of floats representing the embedding vector

        Raises:
            RuntimeError: If the API call fails or the response is malformed
        """
        return self._embed(query, "query")

    def _get_text_embedding(self, text: str) -> List[float]:
        """
        Get embedding for a text string.

        Args:
            text: Text to embed

        Returns:
            List of floats representing the embedding vector

        Raises:
            RuntimeError: If the API call fails or the response is malformed
        """
        return self._embed(text, "text")

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """
        Async version of _get_query_embedding.

        Args:
            query: Query text to embed

        Returns:
            List of floats representing the embedding vector

        Raises:
            RuntimeError: If the API call fails or the response is malformed
        """
        return await self._aembed(query, "query")

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """
        Async version of _get_text_embedding.

        Args:
            text: Text to embed

        Returns:
            List of floats representing the embedding vector

        Raises:
            RuntimeError: If the API call fails or the response is malformed
        """
        return await self._aembed(text, "text")

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """
        Get embeddings for a list of texts.

        Texts are embedded one request at a time, in order; an empty input
        list yields an empty result without contacting the API.

        Args:
            texts: List of texts to embed

        Returns:
            List of embedding vectors, one per input text

        Raises:
            RuntimeError: If embedding any of the texts fails
        """
        return [self._get_text_embedding(text) for text in texts]
| |
|