# app/core/llm_factory.py
"""
Centralized LLM client factory for provider management.
Features:
- Centralized provider configuration
- Rate limiting to avoid quota errors
- LangSmith tracing integration
- Role-specific model configurations
- Easy provider switching
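
Typical usage (illustrative sketch; assumes NVIDIA API credentials and, optionally,
LangSmith variables are configured in the environment; `messages` stands in for any
chat prompt):
    >>> llm = get_chat_model(role=TeamRole.PRODUCT_OWNER)
    >>> response = await llm.ainvoke(messages)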
"""
import os
from functools import lru_cache
from typing import Any
from dotenv import load_dotenv
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
from .observability import get_logger
from .schemas import TeamRole
load_dotenv()
logger = get_logger("llm_factory")
DEFAULT_CHAT_MODEL = "meta/llama-3.1-70b-instruct"
DEFAULT_EMBEDDING_MODEL = "nvidia/nv-embedqa-e5-v5"
AGENT_CONFIGS: dict[TeamRole, dict[str, Any]] = {
# Phase 1
TeamRole.PROJECT_REFINER: {"temperature": 0.3, "max_tokens": 2048},
TeamRole.PRODUCT_OWNER: {"temperature": 0.5, "max_tokens": 4096},
# Phase 2
TeamRole.BUSINESS_ANALYST: {"temperature": 0.3, "max_tokens": 3072},
TeamRole.SOLUTION_ARCHITECT: {"temperature": 0.4, "max_tokens": 3072},
TeamRole.DATA_ARCHITECT: {"temperature": 0.3, "max_tokens": 4096},
TeamRole.SECURITY_ANALYST: {"temperature": 0.2, "max_tokens": 2048},
# Phase 3
TeamRole.UX_DESIGNER: {"temperature": 0.8, "max_tokens": 2048},
TeamRole.API_DESIGNER: {"temperature": 0.2, "max_tokens": 4096},
TeamRole.QA_STRATEGIST: {"temperature": 0.3, "max_tokens": 4096},
TeamRole.DEVOPS_ARCHITECT: {"temperature": 0.3, "max_tokens": 2048},
# Phase 4
TeamRole.ENVIRONMENT_ENGINEER: {"temperature": 0.3, "max_tokens": 2048},
TeamRole.TECHNICAL_WRITER: {"temperature": 0.5, "max_tokens": 3072},
# Phase 5 / Judge
TeamRole.SPEC_COORDINATOR: {"temperature": 0.3, "max_tokens": 4096},
TeamRole.JUDGE: {"temperature": 0.1, "max_tokens": 2048},
}
# Fallback configuration for roles without a specific entry (or when no role is given)
DEFAULT_CONFIG = {"temperature": 0.7, "max_tokens": 2048}
def get_langsmith_callbacks() -> list[BaseCallbackHandler]:
"""
Get LangSmith callbacks if tracing is enabled.
Reads from environment:
- LANGSMITH_TRACING: "true" to enable
- LANGSMITH_API_KEY: API key for LangSmith
- LANGSMITH_PROJECT: Project name for traces
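    Example (illustrative sketch; assumes a valid LANGSMITH_API_KEY is already set):
        >>> os.environ["LANGSMITH_TRACING"] = "true"
        >>> os.environ["LANGSMITH_PROJECT"] = "specs-before-code"
        >>> get_langsmith_callbacks()  # [LangChainTracer(...)] when the key is present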
"""
if os.getenv("LANGSMITH_TRACING", "").lower() != "true":
return []
if not os.getenv("LANGSMITH_API_KEY"):
logger.warning("LANGSMITH_TRACING enabled but LANGSMITH_API_KEY not set")
return []
try:
from langchain_core.tracers import LangChainTracer
tracer = LangChainTracer(
project_name=os.getenv("LANGSMITH_PROJECT", "specs-before-code"),
)
logger.info(
"LangSmith tracing enabled",
data={"project": os.getenv("LANGSMITH_PROJECT")},
)
return [tracer]
except Exception as e:
logger.warning(f"Failed to initialize LangSmith tracer: {e}")
return []
def get_chat_model(
role: TeamRole | None = None,
model: str | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
streaming: bool = False,
callbacks: list[BaseCallbackHandler] | None = None,
) -> BaseChatModel:
"""
Factory function for chat model instances with centralized config.
Args:
role: Agent role for role-specific configuration
model: Model identifier (overrides env default)
temperature: Temperature (overrides role config)
max_tokens: Max tokens (overrides role config)
streaming: Enable streaming mode
callbacks: Additional callbacks (LangSmith added automatically)
Returns:
Configured ChatNVIDIA instance
Example:
        >>> llm = get_chat_model(role=TeamRole.SOLUTION_ARCHITECT)
>>> response = await llm.ainvoke(messages)
>>> # With streaming
        >>> llm = get_chat_model(role=TeamRole.BUSINESS_ANALYST, streaming=True)
>>> async for chunk in llm.astream(messages):
... print(chunk.content, end="")
"""
# Get role-specific config or defaults
config = AGENT_CONFIGS.get(role, DEFAULT_CONFIG) if role else DEFAULT_CONFIG
# Build callback list (LangSmith + custom)
all_callbacks = get_langsmith_callbacks()
if callbacks:
all_callbacks.extend(callbacks)
# Resolve model from env or parameter
resolved_model = model or os.getenv("CHAT_MODEL", DEFAULT_CHAT_MODEL)
llm = ChatNVIDIA(
model=resolved_model,
temperature=temperature if temperature is not None else config["temperature"],
max_tokens=max_tokens if max_tokens is not None else config["max_tokens"],
streaming=streaming,
callbacks=all_callbacks if all_callbacks else None,
)
logger.debug(
"Created chat model",
data={
"model": resolved_model,
"role": role.value if role else "default",
"streaming": streaming,
"temperature": llm.temperature,
},
)
return llm
def get_judge_model(
model: str | None = None,
callbacks: list[BaseCallbackHandler] | None = None,
) -> BaseChatModel:
"""
Get a chat model configured for judge/evaluation tasks.
Uses low temperature for consistent, deterministic evaluations.
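    Example (illustrative; `evaluation_messages` is a placeholder for the judge prompt):
        >>> judge = get_judge_model()
        >>> verdict = await judge.ainvoke(evaluation_messages)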
"""
return get_chat_model(
role=TeamRole.JUDGE,
model=model,
temperature=0.1,
max_tokens=1024,
streaming=False, # Judges don't need streaming
callbacks=callbacks,
)
def get_summary_model(
model: str | None = None,
callbacks: list[BaseCallbackHandler] | None = None,
) -> BaseChatModel:
"""
Get a chat model configured for summarization tasks.
Uses the SUMMARY_MODEL from environment if set.
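    Example (illustrative; `long_report` is a placeholder for the text to condense):
        >>> summarizer = get_summary_model()
        >>> summary = await summarizer.ainvoke("Summarize this text: " + long_report)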
"""
resolved_model = model or os.getenv("SUMMARY_MODEL", os.getenv("CHAT_MODEL"))
return get_chat_model(
model=resolved_model,
temperature=0.3,
max_tokens=2048,
callbacks=callbacks,
)
@lru_cache(maxsize=1)
def get_embeddings_model() -> Embeddings:
"""
Get the embeddings model (singleton, cached).
Uses NVIDIA embeddings by default with truncation enabled.
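    Example (illustrative; vector dimensionality depends on the configured model):
        >>> embeddings = get_embeddings_model()
        >>> vectors = embeddings.embed_documents(["design doc", "API spec"])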
"""
model = os.getenv("EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL)
embeddings = NVIDIAEmbeddings(
model=model,
truncate="END",
)
logger.info("Initialized embeddings model", data={"model": model})
return embeddings
def get_model_info() -> dict[str, str]:
"""Get current model configuration from environment."""
return {
"chat_model": os.getenv("CHAT_MODEL", DEFAULT_CHAT_MODEL),
"embedding_model": os.getenv("EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL),
"summary_model": os.getenv("SUMMARY_MODEL", DEFAULT_CHAT_MODEL),
"langsmith_enabled": os.getenv("LANGSMITH_TRACING", "false"),
"langsmith_project": os.getenv("LANGSMITH_PROJECT", ""),
}
def estimate_tokens(text: str) -> int:
"""
Rough estimate of tokens in text.
Uses ~4 characters per token average (good for English text).
For precise counts, use tiktoken or model-specific tokenizer.
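    Example:
        >>> estimate_tokens("hello world")  # 11 characters // 4
        2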
"""
return len(text) // 4