|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
|
from google.ai.generativelanguage_v1beta.types import BatchEmbedContentsRequest |
|
|
from langchain_core.embeddings import Embeddings |
|
|
from langchain_google_genai import GoogleGenerativeAIEmbeddings |
|
|
from langchain_google_genai._common import GoogleGenerativeAIError |
|
|
|
|
|
from langflow.custom import Component |
|
|
from langflow.io import MessageTextInput, Output, SecretStrInput |
|
|
|
|
|
|
|
|
class GoogleGenerativeAIEmbeddingsComponent(Component):
    """Langflow component that builds a Google Generative AI embeddings client.

    The component takes an API key and a model name as inputs and exposes a
    single ``embeddings`` output produced by :meth:`build_embeddings`.
    """

    display_name = "Google Generative AI Embeddings"
    description = (
        "Connect to Google's generative AI embeddings service using the GoogleGenerativeAIEmbeddings class, "
        "found in the langchain-google-genai package."
    )
    documentation: str = "https://python.langchain.com/v0.2/docs/integrations/text_embedding/google_generative_ai/"
    icon = "Google"
    name = "Google Generative AI Embeddings"

    inputs = [
        SecretStrInput(name="api_key", display_name="API Key"),
        MessageTextInput(name="model_name", display_name="Model Name", value="models/text-embedding-004"),
    ]

    outputs = [
        Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
    ]

    def build_embeddings(self) -> Embeddings:
        """Build the embeddings client configured from this component's inputs.

        Returns:
            A ``GoogleGenerativeAIEmbeddings`` subclass instance created with
            ``self.model_name`` and ``self.api_key``. Its ``embed_documents``
            zero-pads returned vectors up to the requested
            ``output_dimensionality``.

        Raises:
            ValueError: If no API key has been provided.
        """
        if not self.api_key:
            msg = "API Key is required"
            raise ValueError(msg)

        class HotaGoogleGenerativeAIEmbeddings(GoogleGenerativeAIEmbeddings):
            """Embeddings client whose vectors are zero-padded to a fixed length."""

            def __init__(self, *args, **kwargs) -> None:
                # NOTE(review): the two-argument super() starts the MRO lookup
                # *after* GoogleGenerativeAIEmbeddings, so an __init__ defined
                # on that class itself would be skipped -- confirm intended.
                super(GoogleGenerativeAIEmbeddings, self).__init__(*args, **kwargs)

            def embed_documents(
                self,
                texts: list[str],
                *,
                batch_size: int = 100,
                task_type: str | None = None,
                titles: list[str] | None = None,
                output_dimensionality: int | None = 1536,
            ) -> list[list[float]]:
                """Embed a list of strings.

                Google Generative AI currently sets a max batch size of 100 strings.

                Args:
                    texts: The list of strings to embed.
                    batch_size: The batch size of embeddings to send to the model.
                    task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType)
                    titles: An optional list of titles for texts provided.
                        Only applicable when TaskType is RETRIEVAL_DOCUMENT.
                    output_dimensionality: Dimension requested from the service and
                        the length every returned vector is zero-padded up to.
                        Vectors that are already at least this long are returned
                        unchanged; ``None`` disables padding.
                        https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest

                Returns:
                    List of embeddings, one for each text.

                Raises:
                    GoogleGenerativeAIError: If the batch embedding call fails.
                    ValueError: If a batch and its slice of ``titles`` differ in
                        length (raised by ``zip(..., strict=True)``).
                """
                embeddings: list[list[float]] = []
                batch_start_index = 0
                for batch in GoogleGenerativeAIEmbeddings._prepare_batches(texts, batch_size):
                    if titles:
                        titles_batch = titles[batch_start_index : batch_start_index + len(batch)]
                        batch_start_index += len(batch)
                    else:
                        titles_batch = [None] * len(batch)

                    requests = [
                        self._prepare_request(
                            text=text,
                            task_type=task_type,
                            title=title,
                            output_dimensionality=output_dimensionality,
                        )
                        for text, title in zip(batch, titles_batch, strict=True)
                    ]

                    try:
                        result = self.client.batch_embed_contents(
                            BatchEmbedContentsRequest(requests=requests, model=self.model)
                        )
                    except Exception as e:
                        msg = f"Error embedding content: {e}"
                        raise GoogleGenerativeAIError(msg) from e

                    for embedding in result.embeddings:
                        # Pad only by the shortfall to the requested dimension
                        # (previously a fixed 768 zeros, which overshot whenever
                        # the model did not return exactly 768 values).
                        if output_dimensionality is None:
                            pad_width = 0
                        else:
                            pad_width = max(output_dimensionality - len(embedding.values), 0)
                        # .tolist() yields plain Python floats, matching the
                        # declared return type.
                        embeddings.append(np.pad(embedding.values, (0, pad_width), "constant").tolist())
                return embeddings

            def embed_query(
                self,
                text: str,
                task_type: str | None = None,
                title: str | None = None,
                output_dimensionality: int | None = 1536,
            ) -> list[float]:
                """Embed a text.

                Args:
                    text: The text to embed.
                    task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType).
                        Defaults to ``"RETRIEVAL_QUERY"`` when not given.
                    title: An optional title for the text.
                        Only applicable when TaskType is RETRIEVAL_DOCUMENT.
                    output_dimensionality: Dimension requested from the service and
                        the length the returned vector is zero-padded up to.
                        https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest

                Returns:
                    Embedding for the text.
                """
                task_type = task_type or "RETRIEVAL_QUERY"
                return self.embed_documents(
                    [text],
                    task_type=task_type,
                    titles=[title] if title else None,
                    output_dimensionality=output_dimensionality,
                )[0]

        return HotaGoogleGenerativeAIEmbeddings(model=self.model_name, google_api_key=self.api_key)
|
|
|