| from __future__ import annotations |
|
|
| import logging |
| import os |
| import threading |
| from typing import Any |
|
|
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)


# Loaded embedder instances keyed by stripped model name; filled lazily by
# get_embedder() and never evicted in this module.
_embedders: dict[str, Any] = {}
# The single lock object handed out by inference_lock().
# NOTE(review): presumably callers hold it around model inference -- confirm
# against the call sites.
_inference_lock: threading.Lock = threading.Lock()
|
|
|
|
def local_embedding_model() -> str:
    """Return the default embedding model name resolved from the environment.

    The variables ``LOCAL_EMBEDDING_MODEL``, ``TASK_B_LOCAL_EMBEDDING_MODEL``
    and ``TASK_A_EMBEDDING_MODEL`` are consulted in that order. The first one
    whose value is non-empty after stripping whitespace wins.

    Returns:
        The stripped value of the first non-blank variable, or
        ``"all-MiniLM-L6-v2"`` when all three are unset or blank.
    """
    lookup_order = (
        "LOCAL_EMBEDDING_MODEL",
        "TASK_B_LOCAL_EMBEDDING_MODEL",
        "TASK_A_EMBEDDING_MODEL",
    )
    for variable in lookup_order:
        candidate = os.environ.get(variable, "").strip()
        if candidate:
            return candidate
    return "all-MiniLM-L6-v2"
|
|
|
|
def embedding_model_name_task_a() -> str:
    """Return the embedding model name to use for task A.

    Returns:
        The stripped value of ``TASK_A_EMBEDDING_MODEL`` when it is non-blank,
        otherwise whatever ``local_embedding_model()`` resolves to.
    """
    task_specific = os.environ.get("TASK_A_EMBEDDING_MODEL", "").strip()
    if task_specific:
        return task_specific
    return local_embedding_model()
|
|
|
|
def embedding_model_name_task_b() -> str:
    """Return the embedding model name to use for task B.

    Returns:
        The stripped value of ``TASK_B_LOCAL_EMBEDDING_MODEL`` when it is
        non-blank, otherwise whatever ``local_embedding_model()`` resolves to.
    """
    task_specific = os.environ.get("TASK_B_LOCAL_EMBEDDING_MODEL", "").strip()
    return task_specific if task_specific else local_embedding_model()
|
|
|
|
def unique_embedding_model_names() -> list[str]:
    """Return the distinct embedding model names used by tasks A and B.

    Returns:
        A sorted list with one or two entries: one when both tasks resolve to
        the same model name, two otherwise.
    """
    distinct: set[str] = set()
    distinct.add(embedding_model_name_task_a())
    distinct.add(embedding_model_name_task_b())
    return sorted(distinct)
|
|
|
|
# Guards first-time loads only. Kept separate from the inference lock so a
# caller that already holds that (non-reentrant) lock can still call
# get_embedder() without deadlocking.
_embedder_load_lock = threading.Lock()


def get_embedder(model_name: str) -> Any:
    """Return the cached embedder for *model_name*, loading it on first use.

    The cache key is the name with surrounding whitespace stripped, so
    ``"x"`` and ``" x "`` share one instance. Loading is serialized so that
    concurrent first calls for the same model construct it once and all
    receive the same object.

    Args:
        model_name: Model identifier passed to ``SentenceTransformer``.

    Returns:
        The shared ``SentenceTransformer`` instance for the stripped name.

    Raises:
        ValueError: If *model_name* is empty or whitespace-only.
        RuntimeError: If the ``sentence_transformers`` package cannot be
            imported.
    """
    key = model_name.strip()
    if not key:
        # A blank name cannot identify a model; fail here instead of handing
        # it to the loader and caching the result under "".
        raise ValueError("model_name must be a non-empty string")

    # Unlocked fast path: an already-loaded model never waits behind another
    # model's (potentially slow) first load.
    if key not in _embedders:
        with _embedder_load_lock:
            # Re-check under the lock: another thread may have finished
            # loading this model while we were waiting.
            if key not in _embedders:
                try:
                    from sentence_transformers import SentenceTransformer
                except ImportError as e:
                    raise RuntimeError("sentence-transformers required") from e
                logger.info("Loading shared embedding model %s", key)
                _embedders[key] = SentenceTransformer(key)
    return _embedders[key]
|
|
|
|
def inference_lock() -> threading.Lock:
    """Return the module's single shared inference lock.

    Every call returns the same lock object.
    """
    shared_lock = _inference_lock
    return shared_lock
|
|
|
|
def warm_shared_weights() -> None:
    """Eagerly load every embedder named by ``unique_embedding_model_names()``.

    Once all of them have been requested, logs the total number of embedders
    currently held in the module cache. That count covers the whole cache,
    not only the models requested by this call.
    """
    model_names = unique_embedding_model_names()
    for model_name in model_names:
        get_embedder(model_name)

    cached_count = len(_embedders)
    logger.info(
        "Shared weights ready (%d embedder(s); generation via Gemini API)",
        cached_count,
    )
|
|