document-qa-dev / document_qa /custom_embeddings.py
lfoppiano's picture
Upload folder using huggingface_hub
6f06d5d verified
Raw
History Blame Contribute Delete
3.25 kB
"""Custom LangChain-compatible embedding client.
Provides :class:`ModalEmbeddings`, a drop-in ``Embeddings`` implementation
that calls any service exposing an ``/embeddings`` endpoint (OpenAI,
vLLM, Modal, LM Studio, etc.).
"""
from typing import List, Optional

import requests
from langchain_core.embeddings import Embeddings
class ModalEmbeddings(Embeddings):
    """LangChain ``Embeddings`` backed by a custom HTTP embedding endpoint.

    Every request is a form-encoded ``POST`` to ``url`` carrying one field,
    ``text``, which holds all input strings joined by newlines. The decoded
    JSON body of the response is returned unchanged; the service is expected
    to answer with one embedding vector per input string, in input order.

    ``model_name`` is *not* sent to the service. It is only kept for
    bookkeeping and is exposed through :meth:`get_model_name`.

    Args:
        url: Full URL of the embedding endpoint. Requests are posted to it
            verbatim; no path such as ``/embeddings`` is appended.
        model_name: Identifier of the model served by the endpoint
            (e.g. ``"intfloat/multilingual-e5-large-instruct"``).
        api_key: Optional key, sent in the ``x-api-key`` request header
            when non-empty.
        timeout: Optional ``requests`` timeout in seconds. ``None`` (the
            default) waits indefinitely, matching the previous behavior.
    """

    def __init__(
        self,
        url: str,
        model_name: str,
        api_key: Optional[str] = None,
        timeout: Optional[float] = None,
    ):
        self.url = url
        self.model_name = model_name
        self.api_key = api_key
        self.timeout = timeout

    def embed(self, text: List[str]) -> List[List[float]]:
        """Embed a list of texts via the configured endpoint.

        The inputs are joined with newlines into a single payload. Newlines
        inside an individual string are therefore replaced with spaces before
        sending; otherwise the service would presumably read them as a
        boundary between two inputs and return the wrong number of vectors.

        Args:
            text: Strings to embed.

        Returns:
            list[list[float]]: The decoded JSON response, expected to hold one
            embedding vector per input string. An empty input returns ``[]``
            without contacting the service.

        Raises:
            requests.HTTPError: If the service answers with a 4xx/5xx status.
        """
        if not text:
            # Nothing to embed; don't send an empty payload to the service.
            return []

        # "\n" separates the inputs in the payload, so it must not appear
        # inside any single input.
        cleaned_text = [t.replace("\n", " ") for t in text]
        payload = {"text": "\n".join(cleaned_text)}
        headers = {"x-api-key": self.api_key} if self.api_key else {}

        response = requests.post(
            self.url, data=payload, headers=headers, timeout=self.timeout
        )
        response.raise_for_status()
        return response.json()

    def embed_documents(self, text: List[str]) -> List[List[float]]:
        """Embed multiple documents (LangChain interface).

        Args:
            text: Document strings to embed.

        Returns:
            list[list[float]]: One embedding vector per document.
        """
        return self.embed(text)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string (LangChain interface).

        Args:
            text: The query string.

        Returns:
            list[float]: The embedding vector for *text*.
        """
        return self.embed([text])[0]

    def get_model_name(self) -> str:
        """Return the model identifier given at construction time."""
        return self.model_name
if __name__ == "__main__":
    # Manual smoke test against a deployed Modal embedding endpoint, configured
    # through the environment (see document_qa/deployment/README.md):
    #   EMBEDS_URL      -- required; URL of the embedding endpoint.
    #   EMBEDS_API_KEY  -- read with .get() so it may be unset here, but the
    #                      endpoint expects it in the x-api-key header.
    import os

    endpoint_url = os.environ["EMBEDS_URL"]
    endpoint_key = os.environ.get("EMBEDS_API_KEY")

    client = ModalEmbeddings(
        url=endpoint_url,
        model_name="intfloat/multilingual-e5-large-instruct",
        api_key=endpoint_key,
    )

    sample_sentences = [
        "We are surrounded by stupid kids",
        "We are interested in the future of AI",
    ]
    vectors = client.embed(sample_sentences)
    print(vectors)