from typing import List, Union

import requests

from config import settings


class BGEM3Embedder:
    """Wraps the Modal-deployed BGE-M3 embedding endpoint.

    All requests go through a single ``requests.Session`` created in
    ``__init__``. Call ``close()`` when done, or use the instance as a
    context manager so the session is released automatically::

        with BGEM3Embedder() as embedder:
            vector = embedder.embed("hello")
    """

    # Base URL of the embedding service, read once from project settings
    # at class-definition time.
    BASE_URL = settings.BGE_M3_URL

    def __init__(self, timeout: int = 120):
        """Create the client.

        Args:
            timeout: Timeout in seconds passed to every ``/embed`` request.
        """
        self.timeout = timeout
        self.session = requests.Session()

    def close(self) -> None:
        """Close the underlying HTTP session."""
        self.session.close()

    def __enter__(self) -> "BGEM3Embedder":
        """Return the embedder itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the session on leaving the ``with`` block."""
        self.close()

    def _post_embed(
        self,
        texts: List[str],
        normalize: bool,
        max_length: int,
    ) -> List[List[float]]:
        """POST one list of texts to ``/embed`` and return its embeddings.

        Raises:
            requests.HTTPError: If the service answers with an error status.
        """
        payload = {
            "input": texts,
            "normalize_embeddings": normalize,
            "max_length": max_length,
        }
        response = self.session.post(
            f"{self.BASE_URL}/embed",
            json=payload,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()["embeddings"]

    def embed(
        self,
        text: str,
        normalize: bool = True,
        max_length: int = 8192,
    ) -> List[float]:
        """Embed a single text.

        Args:
            text: The text to embed.
            normalize: Sent to the service as ``normalize_embeddings``.
            max_length: Sent to the service as ``max_length``.

        Returns:
            The first (and only requested) entry of the response's
            ``"embeddings"`` list.

        Raises:
            requests.HTTPError: If the service answers with an error status.
        """
        return self._post_embed([text], normalize, max_length)[0]

    def embed_many(
        self,
        texts: List[str],
        normalize: bool = True,
        max_length: int = 8192,
        batch_size: int = 16,  # Optional: split large lists to avoid timeout
    ) -> List[List[float]]:
        """Embed multiple text strings with optional batching.

        The texts are sent in consecutive slices of at most ``batch_size``
        items, one request per slice, and the results are concatenated in
        input order. An empty ``texts`` list returns ``[]`` without making
        any request.

        Args:
            texts: The texts to embed.
            normalize: Sent to the service as ``normalize_embeddings``.
            max_length: Sent to the service as ``max_length``.
            batch_size: Maximum number of texts per request; must be >= 1.

        Returns:
            One embedding per input text, in input order.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
            requests.HTTPError: If the service answers with an error status.
        """
        # Validate up front: a zero step makes range() fail with an obscure
        # message, and a negative step would silently yield no batches at
        # all, dropping every input.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        all_embeddings: List[List[float]] = []
        # Process in batches to stay within timeout limits.
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            all_embeddings.extend(
                self._post_embed(batch, normalize, max_length)
            )
        return all_embeddings

    def health_check(self) -> bool:
        """Check if the Modal service is healthy.

        Returns:
            ``True`` if ``GET /health`` answers with status 200; ``False``
            for any other status or if the request raises
            ``requests.RequestException``.
        """
        try:
            response = self.session.get(f"{self.BASE_URL}/health", timeout=30)
        except requests.RequestException:
            # Best-effort probe: any transport failure means "not healthy".
            return False
        return response.status_code == 200

    def get_model_info(self) -> dict:
        """Fetch model metadata from the service.

        Returns:
            The decoded JSON body of ``GET /model_info``.

        Raises:
            requests.HTTPError: If the service answers with an error status.
        """
        response = self.session.get(f"{self.BASE_URL}/model_info", timeout=30)
        response.raise_for_status()
        return response.json()