# Process-wide registry of sentence-embedding models.
#
# Model names are resolved from environment variables, loaded lazily through
# ``sentence_transformers.SentenceTransformer`` and cached per name so every
# caller asking for the same model shares one instance.

from __future__ import annotations

import logging
import os
import threading
from typing import Any

logger = logging.getLogger(__name__)

# Loaded embedders keyed by the stripped model name.
_embedders: dict[str, Any] = {}
# Serializes first-time loads so concurrent callers never build the same
# model twice. Kept separate from the inference lock so a caller that holds
# the inference lock can still call get_embedder() without deadlocking.
_load_lock = threading.Lock()
# Lock handed out by inference_lock().
_inference_lock = threading.Lock()

# Model used when none of the environment variables provides a name.
_DEFAULT_EMBEDDING_MODEL = "all-MiniLM-L6-v2"

# Environment variables consulted, in priority order, for the shared model.
_SHARED_MODEL_ENV_VARS = (
    "LOCAL_EMBEDDING_MODEL",
    "TASK_B_LOCAL_EMBEDDING_MODEL",
    "TASK_A_EMBEDDING_MODEL",
)


def _env(name: str) -> str:
    """Return environment variable *name* stripped of whitespace ('' if unset)."""
    return os.environ.get(name, "").strip()


def local_embedding_model() -> str:
    """Return the shared embedding model name.

    The first non-blank value among ``LOCAL_EMBEDDING_MODEL``,
    ``TASK_B_LOCAL_EMBEDDING_MODEL`` and ``TASK_A_EMBEDDING_MODEL`` wins;
    if all are unset or blank the result is ``"all-MiniLM-L6-v2"``.
    """
    for variable in _SHARED_MODEL_ENV_VARS:
        value = _env(variable)
        if value:
            return value
    return _DEFAULT_EMBEDDING_MODEL


def embedding_model_name_task_a() -> str:
    """Return the task A model name.

    ``TASK_A_EMBEDDING_MODEL`` takes precedence when it is non-blank;
    otherwise the shared name from :func:`local_embedding_model` is used.
    """
    override = _env("TASK_A_EMBEDDING_MODEL")
    return override or local_embedding_model()


def embedding_model_name_task_b() -> str:
    """Return the task B model name.

    ``TASK_B_LOCAL_EMBEDDING_MODEL`` takes precedence when it is non-blank;
    otherwise the shared name from :func:`local_embedding_model` is used.
    """
    override = _env("TASK_B_LOCAL_EMBEDDING_MODEL")
    return override or local_embedding_model()


def unique_embedding_model_names() -> list[str]:
    """Return the distinct task A / task B model names, sorted."""
    names = {embedding_model_name_task_a(), embedding_model_name_task_b()}
    return sorted(names)


def get_embedder(model_name: str) -> Any:
    """Return the cached embedder for *model_name*, loading it on first use.

    Args:
        model_name: Model identifier; surrounding whitespace is ignored and
            the stripped value is both the cache key and the name passed to
            ``SentenceTransformer``.

    Returns:
        The ``SentenceTransformer`` instance shared by every caller that
        requests the same (stripped) name.

    Raises:
        RuntimeError: If ``sentence_transformers`` cannot be imported.
    """
    key = model_name.strip()

    # Fast path: an already-loaded model needs no locking.
    try:
        return _embedders[key]
    except KeyError:
        pass

    with _load_lock:
        # Re-check under the lock: another thread may have finished loading
        # this model while we were waiting.
        if key not in _embedders:
            try:
                from sentence_transformers import SentenceTransformer  # type: ignore[import-untyped]
            except ImportError as e:
                raise RuntimeError("sentence-transformers required") from e
            logger.info("Loading shared embedding model %s", key)
            _embedders[key] = SentenceTransformer(key)
        return _embedders[key]


def inference_lock() -> threading.Lock:
    """Return the module-level inference lock (the same object on every call).

    NOTE(review): presumably callers hold this while running model
    inference; nothing in this module acquires it -- confirm at call sites.
    """
    return _inference_lock


def warm_shared_weights() -> None:
    """Eagerly load every model named by :func:`unique_embedding_model_names`.

    Logs the total number of cached embedders once loading has finished.
    """
    for name in unique_embedding_model_names():
        get_embedder(name)
    logger.info(
        "Shared weights ready (%d embedder(s); generation via Gemini API)",
        len(_embedders),
    )