|
|
from urllib.parse import urlparse |
|
|
|
|
|
import requests |
|
|
from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings |
|
|
from pydantic.v1.types import SecretStr |
|
|
from tenacity import retry, stop_after_attempt, wait_fixed |
|
|
|
|
|
from langflow.base.embeddings.model import LCEmbeddingsModel |
|
|
from langflow.field_typing import Embeddings |
|
|
from langflow.io import MessageTextInput, Output, SecretStrInput |
|
|
|
|
|
|
|
|
class HuggingFaceInferenceAPIEmbeddingsComponent(LCEmbeddingsModel):
    """Langflow component that produces embeddings via HuggingFace TEI.

    Works against either the hosted HuggingFace Inference API (an API key is
    required) or a self-hosted local Text Embeddings Inference endpoint (no
    key required; a placeholder key is substituted because the underlying
    langchain client demands one).
    """

    display_name = "HuggingFace Embeddings Inference"
    description = "Generate embeddings using HuggingFace Text Embeddings Inference (TEI)"
    documentation = "https://huggingface.co/docs/text-embeddings-inference/index"
    icon = "HuggingFace"
    name = "HuggingFaceInferenceAPIEmbeddings"

    inputs = [
        SecretStrInput(
            name="api_key",
            display_name="API Key",
            advanced=True,
            info="Required for non-local inference endpoints. Local inference does not require an API Key.",
        ),
        MessageTextInput(
            name="inference_endpoint",
            display_name="Inference Endpoint",
            required=True,
            value="https://api-inference.huggingface.co/models/",
            info="Custom inference endpoint URL.",
        ),
        MessageTextInput(
            name="model_name",
            display_name="Model Name",
            value="BAAI/bge-large-en-v1.5",
            info="The name of the model to use for text embeddings.",
        ),
    ]

    outputs = [
        Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
    ]

    def validate_inference_endpoint(self, inference_endpoint: str) -> bool:
        """Validate the endpoint URL format and probe its `/health` route.

        Args:
            inference_endpoint: Base URL of the (local) TEI deployment.

        Returns:
            True when the URL is well-formed and the service reports healthy.

        Raises:
            ValueError: If the URL is malformed, the service does not
                respond, or the health check returns a non-200 status.
        """
        parsed_url = urlparse(inference_endpoint)
        # A usable URL needs both a scheme and a network location.
        if not all([parsed_url.scheme, parsed_url.netloc]):
            msg = (
                f"Invalid inference endpoint format: '{self.inference_endpoint}'. "
                "Please ensure the URL includes both a scheme (e.g., 'http://' or 'https://') and a domain name. "
                "Example: 'http://localhost:8080' or 'https://api.example.com'"
            )
            raise ValueError(msg)

        try:
            # rstrip avoids a double slash when the endpoint ends with "/".
            response = requests.get(f"{inference_endpoint.rstrip('/')}/health", timeout=5)
        except requests.RequestException as e:
            msg = (
                f"Inference endpoint '{inference_endpoint}' is not responding. "
                "Please ensure the URL is correct and the service is running."
            )
            raise ValueError(msg) from e

        if response.status_code != requests.codes.ok:
            msg = f"HuggingFace health check failed: {response.status_code}"
            raise ValueError(msg)

        return True

    def get_api_url(self) -> str:
        """Return the effective API URL for the embeddings client.

        For the hosted HuggingFace service the model name is appended to the
        base URL; custom/local endpoints are used as-is.
        """
        if "huggingface" in self.inference_endpoint.lower():
            return f"{self.inference_endpoint}{self.model_name}"
        return self.inference_endpoint

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
    def create_huggingface_embeddings(
        self, api_key: SecretStr, api_url: str, model_name: str
    ) -> HuggingFaceInferenceAPIEmbeddings:
        """Instantiate the langchain embeddings client (retried up to 3x, 2s apart)."""
        return HuggingFaceInferenceAPIEmbeddings(api_key=api_key, api_url=api_url, model_name=model_name)

    def build_embeddings(self) -> Embeddings:
        """Build and return the configured embeddings object.

        Raises:
            ValueError: If an API key is missing for a non-local endpoint,
                a local endpoint fails validation, or the client cannot be
                created after retries.
        """
        api_url = self.get_api_url()

        is_local_url = api_url.startswith(("http://localhost", "http://127.0.0.1"))

        if not self.api_key and is_local_url:
            # Local TEI deployments need no real key, but the client
            # requires one, so validate the endpoint and use a dummy key.
            self.validate_inference_endpoint(api_url)
            api_key = SecretStr("DummyAPIKeyForLocalDeployment")
        elif not self.api_key:
            msg = "API Key is required for non-local inference endpoints"
            raise ValueError(msg)
        else:
            # Wrap in SecretStr so the value matches the declared parameter
            # type of create_huggingface_embeddings (previously the raw
            # string from .get_secret_value() was passed).
            api_key = SecretStr(self.api_key)

        try:
            return self.create_huggingface_embeddings(api_key, api_url, self.model_name)
        except Exception as e:
            msg = "Could not connect to HuggingFace Inference API."
            raise ValueError(msg) from e
|
|
|