# src/agents/config.py
import os
from typing import Literal

from dotenv import load_dotenv
from langchain_core.language_models import BaseChatModel
from langchain_google_genai import GoogleGenerativeAIEmbeddings

from src.llm import LLMFactory, CostTrackingCallback
from src.llm.tiers import ModelTier, model_for_agent

load_dotenv()

# Type alias for providers
Provider = Literal["google", "openai", "anthropic"]


class Config:
    """Application configuration with multi-provider LLM support."""

    # Default model configuration: gemini-2.5-flash everywhere
    DEFAULT_MODEL = os.getenv("DEFAULT_LLM_MODEL", "gemini-2.5-flash")
    DEFAULT_TEMPERATURE = float(os.getenv("DEFAULT_LLM_TEMPERATURE", "0.7"))
    DEFAULT_PROVIDER: Provider = "google"

    # Embedding configuration
    EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "models/gemini-embedding-001")

    # Application configuration
    MAX_UPLOAD_LENGTH = 16 * 1024 * 1024  # 16 MiB
    MAX_CONVERSATION_LENGTH = 100
    MAX_CONTEXT_MESSAGES = 10

    # Agent configuration
    AGENTS_CONFIG = {
        "agents": [
            {
                "name": "crypto_data",
                "description": "Handles cryptocurrency-related queries",
                "type": "specialized",
                "enabled": True,
                "priority": 1,
            },
            {
                "name": "general",
                "description": "Handles general conversation and queries",
                "type": "general",
                "enabled": True,
                "priority": 2,
            },
        ]
    }
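    # Usage sketch (illustrative, not part of the config contract): entries are
    # looked up by name through the accessors defined further down, e.g.
    #   Config.get_agent_config("crypto_data")  # -> the first dict above
    #   Config.get_enabled_agents()             # -> both dicts (both enabled)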

    # LangGraph configuration
    LANGGRAPH_CONFIG = {
        "max_iterations": 10,
        "timeout": 30,
        "memory_window": 10,
        "enable_memory": True,
    }

    # Conversation configuration
    CONVERSATION_CONFIG = {
        "default_user_id": "anonymous",
        "max_conversations_per_user": 50,
        "conversation_timeout_hours": 24,
        "enable_context_extraction": True,
    }

    # Instance caches
    _llm_instance: BaseChatModel | None = None
    _llm_fast_instance: BaseChatModel | None = None
    _llm_reasoning_instance: BaseChatModel | None = None
    _embeddings_instance: GoogleGenerativeAIEmbeddings | None = None
    _cost_tracker: CostTrackingCallback | None = None

    @classmethod
    def get_llm(
        cls,
        model: str | None = None,
        temperature: float | None = None,
        with_cost_tracking: bool = True,
    ) -> BaseChatModel:
        """Get an LLM instance, caching only the default model/temperature pair."""
        model = model or cls.DEFAULT_MODEL
        temperature = temperature if temperature is not None else cls.DEFAULT_TEMPERATURE
        use_cache = model == cls.DEFAULT_MODEL and temperature == cls.DEFAULT_TEMPERATURE
        if use_cache and cls._llm_instance is not None:
            return cls._llm_instance
        callbacks = []
        if with_cost_tracking:
            callbacks.append(cls.get_cost_tracker())
        llm = LLMFactory.create(
            model=model,
            temperature=temperature,
            callbacks=callbacks if callbacks else None,
            use_cache=False,
        )
        if use_cache:
            cls._llm_instance = llm
        return llm
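    # Example (illustrative): a default call reuses the cached instance, while
    # overriding either knob builds a fresh, uncached model.
    #   Config.get_llm()                 # cached gemini-2.5-flash at temp 0.7
    #   Config.get_llm(temperature=0.0)  # new instance, bypasses the cache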

    @classmethod
    def get_fast_llm(cls, with_cost_tracking: bool = True) -> BaseChatModel:
        """Return a cached FAST-tier LLM (gemini-2.5-flash)."""
        if cls._llm_fast_instance is not None:
            return cls._llm_fast_instance
        callbacks = []
        if with_cost_tracking:
            callbacks.append(cls.get_cost_tracker())
        llm = LLMFactory.create(
            model=ModelTier.FAST,
            temperature=cls.DEFAULT_TEMPERATURE,
            callbacks=callbacks if callbacks else None,
            use_cache=False,
        )
        cls._llm_fast_instance = llm
        return llm

    @classmethod
    def get_llm_for_agent(
        cls,
        agent_name: str,
        with_cost_tracking: bool = True,
    ) -> BaseChatModel:
        """Return the optimal LLM for *agent_name* based on its tier."""
        model = model_for_agent(agent_name)
        if model == ModelTier.FAST:
            return cls.get_fast_llm(with_cost_tracking=with_cost_tracking)
        return cls.get_llm(model=model, with_cost_tracking=with_cost_tracking)
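    # Routing sketch (assumes model_for_agent maps agent names to tiers in
    # src.llm.tiers; the concrete mapping below is illustrative, not authoritative):
    #   Config.get_llm_for_agent("crypto_data")  # FAST tier -> cached fast LLM
    #   Config.get_llm_for_agent("general")      # other tier -> get_llm(model=...)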

    @classmethod
    def get_reasoning_llm(cls, with_cost_tracking: bool = True) -> BaseChatModel:
        """Return a cached REASONING-tier LLM (gemini-3-flash-preview)."""
        if cls._llm_reasoning_instance is not None:
            return cls._llm_reasoning_instance
        callbacks = []
        if with_cost_tracking:
            callbacks.append(cls.get_cost_tracker())
        llm = LLMFactory.create(
            model=ModelTier.REASONING,
            temperature=cls.DEFAULT_TEMPERATURE,
            callbacks=callbacks if callbacks else None,
            use_cache=False,
        )
        cls._llm_reasoning_instance = llm
        return llm

    @classmethod
    def get_llm_for_mode(
        cls,
        mode: str = "fast",
        with_cost_tracking: bool = True,
    ) -> BaseChatModel:
        """Return the LLM for a given response mode ('fast' or 'reasoning')."""
        if mode == "reasoning":
            return cls.get_reasoning_llm(with_cost_tracking=with_cost_tracking)
        return cls.get_fast_llm(with_cost_tracking=with_cost_tracking)
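    # Mode sketch (illustrative): anything other than the literal string
    # "reasoning" falls through to the fast tier.
    #   Config.get_llm_for_mode("reasoning")  # REASONING-tier singleton
    #   Config.get_llm_for_mode("fast")       # FAST-tier singleton
    #   Config.get_llm_for_mode("typo")       # also FAST -- the permissive default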

    @classmethod
    def get_embeddings(cls) -> GoogleGenerativeAIEmbeddings:
        """Get or create embeddings instance (singleton)."""
        if cls._embeddings_instance is None:
            cls._embeddings_instance = GoogleGenerativeAIEmbeddings(
                model=cls.EMBEDDING_MODEL,
                google_api_key=os.getenv("GEMINI_API_KEY"),
            )
        return cls._embeddings_instance
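    # Usage sketch: the singleton follows the standard LangChain Embeddings
    # interface, so callers can do e.g.
    #   vector = Config.get_embeddings().embed_query("hello")  # -> list[float]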

    @classmethod
    def get_cost_tracker(cls) -> CostTrackingCallback:
        """Get or create cost tracker instance (singleton)."""
        if cls._cost_tracker is None:
            cls._cost_tracker = CostTrackingCallback(log_calls=True)
        return cls._cost_tracker

    @classmethod
    def get_agent_config(cls, agent_name: str) -> dict | None:
        """Return the config dict for *agent_name*, or None if it is unknown."""
        for agent in cls.AGENTS_CONFIG["agents"]:
            if agent["name"] == agent_name:
                return agent
        return None

    @classmethod
    def get_enabled_agents(cls) -> list[dict]:
        """Return the configs of all agents whose 'enabled' flag is truthy."""
        return [
            agent
            for agent in cls.AGENTS_CONFIG["agents"]
            if agent.get("enabled", True)
        ]

    @classmethod
    def list_available_models(cls) -> list[str]:
        """List the model names known to LLMFactory."""
        return LLMFactory.list_models()

    @classmethod
    def list_available_providers(cls) -> list[str]:
        """List the provider names known to LLMFactory."""
        return LLMFactory.list_providers()

    @classmethod
    def validate_config(cls) -> bool:
        """Eagerly build the default LLM and embeddings to surface config errors."""
        try:
            cls.get_llm(with_cost_tracking=False)
            cls.get_embeddings()
            return True
        except Exception as e:
            print(f"Configuration validation failed: {e}")
            return False

    @classmethod
    def reset_instances(cls) -> None:
        """Drop every cached instance so the next accessor call rebuilds it."""
        cls._llm_instance = None
        cls._llm_fast_instance = None
        cls._llm_reasoning_instance = None
        cls._embeddings_instance = None
        if cls._cost_tracker:
            cls._cost_tracker.reset()
        LLMFactory.clear_cache()
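

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the app's entry points):
    # exercises only methods defined above; validate_config builds a real LLM
    # and embeddings client, so it assumes the relevant API keys are set.
    print("providers:", Config.list_available_providers())
    print("models:", Config.list_available_models())
    print("config valid:", Config.validate_config())
    Config.reset_instances()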