DSN / app /shared_models.py
nexusbert's picture
Refactor agent workflow and update documentation for Gemini-first implementation
d47b370
Raw
History Blame Contribute Delete
1.73 kB
from __future__ import annotations
import logging
import os
import threading
from typing import Any
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Cache of already-loaded embedder objects, keyed by the stripped model name.
# Populated lazily by get_embedder().
_embedders: dict[str, Any] = {}
# Single module-wide lock handed out by inference_lock().
# NOTE(review): presumably callers hold it to serialize model inference -- confirm at call sites.
_inference_lock = threading.Lock()
def local_embedding_model() -> str:
    """Return the embedding model name configured through the environment.

    The variables are consulted in this order and the first one holding a
    non-blank value (after stripping whitespace) wins:

    1. ``LOCAL_EMBEDDING_MODEL``
    2. ``TASK_B_LOCAL_EMBEDDING_MODEL``
    3. ``TASK_A_EMBEDDING_MODEL``

    When none of them is set to a non-blank value, ``"all-MiniLM-L6-v2"``
    is returned.
    """
    env_vars_by_priority = (
        "LOCAL_EMBEDDING_MODEL",
        "TASK_B_LOCAL_EMBEDDING_MODEL",
        "TASK_A_EMBEDDING_MODEL",
    )
    for env_var in env_vars_by_priority:
        candidate = os.environ.get(env_var, "").strip()
        if candidate:
            return candidate
    return "all-MiniLM-L6-v2"
def embedding_model_name_task_a() -> str:
    """Return the embedding model name to use for task A.

    A non-blank ``TASK_A_EMBEDDING_MODEL`` environment variable (whitespace
    stripped) takes precedence; otherwise the result of
    ``local_embedding_model()`` is returned.
    """
    task_a_model = os.environ.get("TASK_A_EMBEDDING_MODEL", "").strip()
    if task_a_model:
        return task_a_model
    return local_embedding_model()
def embedding_model_name_task_b() -> str:
    """Return the embedding model name to use for task B.

    A non-blank ``TASK_B_LOCAL_EMBEDDING_MODEL`` environment variable
    (whitespace stripped) takes precedence; otherwise the result of
    ``local_embedding_model()`` is returned.
    """
    task_b_model = os.environ.get("TASK_B_LOCAL_EMBEDDING_MODEL", "").strip()
    if task_b_model:
        return task_b_model
    return local_embedding_model()
def unique_embedding_model_names() -> list[str]:
    """Return the distinct embedding model names for tasks A and B, sorted.

    The list has one entry when both tasks resolve to the same model name
    and two entries (in ascending string order) otherwise.
    """
    task_a_name = embedding_model_name_task_a()
    task_b_name = embedding_model_name_task_b()
    if task_a_name == task_b_name:
        return [task_a_name]
    return sorted((task_a_name, task_b_name))
def get_embedder(model_name: str) -> Any:
    """Return the embedder for *model_name*, loading and caching it on first use.

    The name is stripped of surrounding whitespace and used as the key into
    the module-level ``_embedders`` cache. On a cache miss a
    ``SentenceTransformer`` is constructed for that name and stored, so
    later calls with the same (stripped) name return the same object.

    Args:
        model_name: Model identifier passed to ``SentenceTransformer``.

    Returns:
        The cached ``SentenceTransformer`` instance for the stripped name.

    Raises:
        ValueError: If *model_name* is empty or only whitespace.
        RuntimeError: If the ``sentence_transformers`` package cannot be
            imported.
    """
    key = model_name.strip()
    if not key:
        # Fail early with a clear message rather than handing "" to
        # SentenceTransformer or reporting a misleading missing-dependency error.
        raise ValueError("model_name must be a non-empty string")
    if key not in _embedders:
        # Imported lazily so this module can be imported without the
        # dependency installed; it is only needed on an actual load.
        try:
            from sentence_transformers import SentenceTransformer  # type: ignore[import-untyped]
        except ImportError as e:
            raise RuntimeError("sentence-transformers required") from e
        logger.info("Loading shared embedding model %s", key)
        _embedders[key] = SentenceTransformer(key)
    return _embedders[key]
def inference_lock() -> threading.Lock:
    """Return the module-wide lock object.

    Every call returns the same ``threading.Lock`` instance
    (``_inference_lock``); this function does not acquire it.
    """
    return _inference_lock
def warm_shared_weights() -> None:
    """Load every configured embedding model into the shared cache.

    Calls ``get_embedder`` once for each name returned by
    ``unique_embedding_model_names()`` and then logs how many embedders the
    cache holds.
    """
    model_names = unique_embedding_model_names()
    for model_name in model_names:
        get_embedder(model_name)

    cached_count = len(_embedders)
    logger.info(
        "Shared weights ready (%d embedder(s); generation via Gemini API)",
        cached_count,
    )