# api-embedding/vectordb/embedding_client.py
# Author: fahmiaziz98
# Commit 6d882b2: add example using pinecone vectordb, add pubmed data sample,
# eval retrieval (MRR, hit rate, precision@k)
from httpx import AsyncClient, HTTPError, TimeoutException
from typing import List, Dict
import asyncio
class EmbeddingAPIClient:
    """
    Async client for an embedding/reranking HTTP API.

    Wraps an ``httpx.AsyncClient`` and adds exponential-backoff retry on
    transport errors and non-2xx responses. Can be used as an async context
    manager (``async with EmbeddingAPIClient(...) as client: ...``), which
    closes the underlying HTTP client automatically.

    Attributes:
        base_url (str): The base URL of the embedding API.
        timeout (int): The timeout duration for requests in seconds.
        max_retries (int): The maximum number of retry attempts for failed requests.
        client (AsyncClient): An instance of AsyncClient for making HTTP requests.
    """

    def __init__(self, base_url: str, timeout: int = 60, max_retries: int = 3) -> None:
        """
        Initializes the EmbeddingAPIClient with the specified parameters.

        Args:
            base_url (str): Base URL of the embedding API.
            timeout (int): Per-request timeout in seconds. Defaults to 60.
            max_retries (int): Maximum retry attempts on failure. Defaults to 3.
        """
        self.base_url = base_url
        self.timeout = timeout
        self.max_retries = max_retries
        self.client = AsyncClient(base_url=base_url, timeout=timeout)

    async def __aenter__(self) -> "EmbeddingAPIClient":
        """Enter the async context manager; returns self."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Exit the async context manager, closing the HTTP client."""
        await self.close()

    async def _make_request_with_retry(
        self, endpoint: str, payload: Dict, retry_count: int = 0
    ) -> Dict:
        """
        Helper method to make a POST request with retry logic.

        Retries on HTTP/transport errors with exponential backoff
        (2**attempt seconds). Implemented iteratively so long retry chains
        do not grow the call stack.

        Args:
            endpoint (str): The endpoint URL to which the request is sent.
            payload (Dict): The JSON data to be sent in the request.
            retry_count (int, optional): The attempt count to start from.
                Defaults to 0 (kept for backward compatibility).

        Returns:
            Dict: The JSON response from the API.

        Raises:
            Exception: If the request fails after the maximum number of
                retries; the original httpx error is attached as __cause__.
        """
        attempt = retry_count
        while True:
            try:
                response = await self.client.post(endpoint, json=payload)
                # raise_for_status turns 4xx/5xx responses into HTTPError,
                # so they go through the same retry path as transport errors.
                response.raise_for_status()
                return response.json()
            except (HTTPError, TimeoutException) as e:
                if attempt >= self.max_retries:
                    # Chain the original error so callers can inspect __cause__.
                    raise Exception(
                        f"❌ Failed after {self.max_retries} retries: {str(e)}"
                    ) from e
                wait_time = 2**attempt
                print(
                    f"⚠️ Request failed, retrying in {wait_time}s... (attempt {attempt + 1}/{self.max_retries})"
                )
                await asyncio.sleep(wait_time)
                attempt += 1

    async def get_dense_embeddings(
        self, texts: List[str], model: str = "qwen3-0.6b"
    ) -> List[List[float]]:
        """
        Retrieve dense embeddings from the API.

        Args:
            texts (List[str]): A list of texts for which to retrieve embeddings.
            model (str): The model to use for generating embeddings. Defaults to "qwen3-0.6b".

        Returns:
            List[List[float]]: A list of dense embeddings corresponding to the input texts.
        """
        data = await self._make_request_with_retry(
            "/embeddings", {"input": texts, "model": model}
        )
        # OpenAI-style response shape: {"data": [{"embedding": [...]}, ...]}
        return [item["embedding"] for item in data["data"]]

    async def get_sparse_embeddings(
        self, texts: List[str], model: str = "splade-large-query"
    ) -> List[Dict[str, List]]:
        """
        Retrieve sparse embeddings from the API.

        Args:
            texts (List[str]): A list of texts for which to retrieve embeddings.
            model (str): The model to use for generating embeddings. Defaults to "splade-large-query".

        Returns:
            List[Dict[str, List]]: A list of sparse embeddings corresponding to the input texts.
        """
        data = await self._make_request_with_retry(
            "/embed_sparse", {"input": texts, "model": model}
        )
        return data["embeddings"]

    async def rerank_documents(
        self, query: str, documents: List[str], top_k: int = 5, model: str = "bge-v2-m3"
    ) -> List[Dict]:
        """
        Rerank a list of documents based on a query.

        Args:
            query (str): The query string used for reranking.
            documents (List[str]): A list of documents to be reranked.
            top_k (int): The number of top documents to return. Defaults to 5.
            model (str): The model to use for reranking. Defaults to "bge-v2-m3".

        Returns:
            List[Dict]: A list of reranked documents with their scores.
        """
        data = await self._make_request_with_retry(
            "/rerank",
            {"query": query, "documents": documents, "top_k": top_k, "model": model},
        )
        return data["results"]

    async def close(self):
        """
        Close the HTTP client.

        This method closes the AsyncClient instance to free up resources.
        Idempotent on httpx's side; safe to call from __aexit__.
        """
        await self.client.aclose()