File size: 4,150 Bytes
461adca | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 | """
Custom Gemini embedding class for LlamaIndex integration.
Provides an alternative to OpenAI embeddings using Google's Gemini API.
"""
from typing import List
from llama_index.core.embeddings import BaseEmbedding
from google import genai
class GeminiEmbedding(BaseEmbedding):
    """
    Gemini embedding model integration for LlamaIndex.

    Uses Google's gemini-embedding-001 model for generating embeddings,
    providing an alternative to OpenAI embeddings.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str = "gemini-embedding-001",
        **kwargs
    ):
        """
        Initialize the Gemini embedding model.

        Args:
            api_key: Google API key for Gemini.
            model_name: Model name (default: gemini-embedding-001).
            **kwargs: Additional arguments forwarded to BaseEmbedding.
        """
        super().__init__(**kwargs)
        # Store on private attributes: BaseEmbedding is a Pydantic model,
        # and underscore-prefixed attrs bypass field validation.
        self._client = genai.Client(api_key=api_key)
        self._model_name = model_name

    def _embed(self, text: str, kind: str) -> List[float]:
        """
        Shared implementation for single-string embedding requests.

        Args:
            text: Text to embed.
            kind: Label used in error messages ("query" or "text"),
                so callers keep their original error wording.

        Returns:
            The embedding vector as a list of floats.

        Raises:
            RuntimeError: If the API call fails or the response does not
                have the expected ``result.embeddings[0].values`` shape.
        """
        try:
            result = self._client.models.embed_content(
                model=self._model_name,
                contents=text,
            )
            # Expected response structure: result.embeddings[0].values
            if hasattr(result, 'embeddings') and len(result.embeddings) > 0:
                embedding = result.embeddings[0]
                if hasattr(embedding, 'values'):
                    return list(embedding.values)
            raise ValueError("Unexpected response structure from Gemini embedding API")
        except Exception as e:
            # Chain the original exception (`from e`) so the underlying
            # traceback is preserved for debugging instead of discarded.
            raise RuntimeError(f"Error getting {kind} embedding from Gemini: {str(e)}") from e

    def _get_query_embedding(self, query: str) -> List[float]:
        """
        Get embedding for a query string.

        Args:
            query: Query text to embed.

        Returns:
            List of floats representing the embedding vector.

        Raises:
            RuntimeError: On API failure or unexpected response structure.
        """
        return self._embed(query, "query")

    def _get_text_embedding(self, text: str) -> List[float]:
        """
        Get embedding for a document text string.

        Args:
            text: Text to embed.

        Returns:
            List of floats representing the embedding vector.

        Raises:
            RuntimeError: On API failure or unexpected response structure.
        """
        return self._embed(text, "text")

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """
        Async version of _get_query_embedding.

        NOTE(review): delegates to the synchronous client, so this blocks
        the event loop for the duration of the request — acceptable as a
        stopgap until an async Gemini call is wired in.

        Args:
            query: Query text to embed.

        Returns:
            List of floats representing the embedding vector.
        """
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """
        Async version of _get_text_embedding.

        NOTE(review): delegates to the synchronous client, so this blocks
        the event loop for the duration of the request.

        Args:
            text: Text to embed.

        Returns:
            List of floats representing the embedding vector.
        """
        return self._get_text_embedding(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """
        Get embeddings for a list of texts (one API call per text).

        Args:
            texts: List of texts to embed.

        Returns:
            List of embedding vectors, in input order.
        """
        return [self._get_text_embedding(text) for text in texts]
|