Spaces:
Running
Running
# NOTE(review): upstream extraction wrapped each line in table pipes;
# formatting restored here — token order unchanged.
import os
from typing import Literal

from dotenv import load_dotenv
from langchain_core.language_models import BaseChatModel
from langchain_google_genai import GoogleGenerativeAIEmbeddings

from src.llm import LLMFactory, CostTrackingCallback
from src.llm.tiers import ModelTier, model_for_agent

# Populate os.environ from a local .env file BEFORE the os.getenv()
# lookups in Config's class body below — order matters here.
load_dotenv()

# Type alias for providers
Provider = Literal["google", "openai", "anthropic"]
class Config:
    """Application configuration with multi-provider LLM support.

    All state is class-level: constants are read from the environment at
    import time, and LLM/embedding instances are cached on the class so the
    whole process shares singletons. Every accessor is a ``@classmethod``.

    NOTE(review): the extracted source showed ``cls``-first methods without
    ``@classmethod`` decorators (almost certainly lost in extraction); they
    are restored here so ``Config.get_llm(...)`` works as callers expect.
    """

    # Default model configuration — gemini-2.5-flash everywhere
    DEFAULT_MODEL = os.getenv("DEFAULT_LLM_MODEL", "gemini-2.5-flash")
    DEFAULT_TEMPERATURE = float(os.getenv("DEFAULT_LLM_TEMPERATURE", "0.7"))
    DEFAULT_PROVIDER: Provider = "google"

    # Embedding configuration
    EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "models/gemini-embedding-001")

    # Application configuration
    MAX_UPLOAD_LENGTH = 16 * 1024 * 1024  # 16 MiB
    MAX_CONVERSATION_LENGTH = 100
    MAX_CONTEXT_MESSAGES = 10

    # Agent configuration — lower priority number wins at routing time
    # (presumably; verify against the router that consumes this).
    AGENTS_CONFIG = {
        "agents": [
            {
                "name": "crypto_data",
                "description": "Handles cryptocurrency-related queries",
                "type": "specialized",
                "enabled": True,
                "priority": 1,
            },
            {
                "name": "general",
                "description": "Handles general conversation and queries",
                "type": "general",
                "enabled": True,
                "priority": 2,
            },
        ]
    }

    # LangGraph configuration
    LANGGRAPH_CONFIG = {
        "max_iterations": 10,
        "timeout": 30,
        "memory_window": 10,
        "enable_memory": True,
    }

    # Conversation configuration
    CONVERSATION_CONFIG = {
        "default_user_id": "anonymous",
        "max_conversations_per_user": 50,
        "conversation_timeout_hours": 24,
        "enable_context_extraction": True,
    }

    # Instance caches (process-wide singletons, keyed per tier)
    _llm_instance: BaseChatModel | None = None
    _llm_fast_instance: BaseChatModel | None = None
    _llm_reasoning_instance: BaseChatModel | None = None
    _embeddings_instance: GoogleGenerativeAIEmbeddings | None = None
    _cost_tracker: CostTrackingCallback | None = None

    @classmethod
    def _build_callbacks(cls, with_cost_tracking: bool) -> list | None:
        """Return the callback list for ``LLMFactory.create``, or ``None``.

        Shared by every ``get_*_llm`` accessor so cost-tracking wiring
        lives in exactly one place.
        """
        return [cls.get_cost_tracker()] if with_cost_tracking else None

    @classmethod
    def get_llm(
        cls,
        model: str | None = None,
        temperature: float | None = None,
        with_cost_tracking: bool = True,
    ) -> BaseChatModel:
        """Get an LLM, caching the fully-default configuration.

        Args:
            model: Model name; defaults to ``DEFAULT_MODEL``.
            temperature: Sampling temperature; defaults to ``DEFAULT_TEMPERATURE``.
            with_cost_tracking: Attach the shared cost-tracking callback.

        Returns:
            A chat model instance; the default (model, temperature,
            tracking-on) combination is created once and reused.
        """
        model = model or cls.DEFAULT_MODEL
        temperature = temperature if temperature is not None else cls.DEFAULT_TEMPERATURE
        # Bug fix: the cache key must include with_cost_tracking, otherwise a
        # non-tracked instance created first would later be served to tracked
        # callers. Only the cost-tracked default configuration is cached.
        use_cache = (
            model == cls.DEFAULT_MODEL
            and temperature == cls.DEFAULT_TEMPERATURE
            and with_cost_tracking
        )
        if use_cache and cls._llm_instance is not None:
            return cls._llm_instance
        llm = LLMFactory.create(
            model=model,
            temperature=temperature,
            callbacks=cls._build_callbacks(with_cost_tracking),
            use_cache=False,  # caching is handled here, not in the factory
        )
        if use_cache:
            cls._llm_instance = llm
        return llm

    @classmethod
    def _get_tier_llm(
        cls,
        tier,
        cache_attr: str,
        with_cost_tracking: bool,
    ) -> BaseChatModel:
        """Shared cached-singleton logic for tier-based LLM accessors.

        Only the cost-tracked instance is cached (and served from cache),
        mirroring the ``get_llm`` cache-key fix above.
        """
        cached = getattr(cls, cache_attr)
        if with_cost_tracking and cached is not None:
            return cached
        llm = LLMFactory.create(
            model=tier,
            temperature=cls.DEFAULT_TEMPERATURE,
            callbacks=cls._build_callbacks(with_cost_tracking),
            use_cache=False,
        )
        if with_cost_tracking:
            setattr(cls, cache_attr, llm)
        return llm

    @classmethod
    def get_fast_llm(cls, with_cost_tracking: bool = True) -> BaseChatModel:
        """Return a cached FAST-tier LLM (gemini-2.5-flash)."""
        return cls._get_tier_llm(ModelTier.FAST, "_llm_fast_instance", with_cost_tracking)

    @classmethod
    def get_llm_for_agent(
        cls,
        agent_name: str,
        with_cost_tracking: bool = True,
    ) -> BaseChatModel:
        """Return the optimal LLM for *agent_name* based on its tier."""
        model = model_for_agent(agent_name)
        if model == ModelTier.FAST:
            return cls.get_fast_llm(with_cost_tracking=with_cost_tracking)
        return cls.get_llm(model=model, with_cost_tracking=with_cost_tracking)

    @classmethod
    def get_reasoning_llm(cls, with_cost_tracking: bool = True) -> BaseChatModel:
        """Return a cached REASONING-tier LLM (gemini-3-flash-preview)."""
        return cls._get_tier_llm(
            ModelTier.REASONING, "_llm_reasoning_instance", with_cost_tracking
        )

    @classmethod
    def get_llm_for_mode(
        cls,
        mode: str = "fast",
        with_cost_tracking: bool = True,
    ) -> BaseChatModel:
        """Return the LLM for a given response mode ('fast' or 'reasoning').

        Any value other than ``"reasoning"`` falls back to the fast tier.
        """
        if mode == "reasoning":
            return cls.get_reasoning_llm(with_cost_tracking=with_cost_tracking)
        return cls.get_fast_llm(with_cost_tracking=with_cost_tracking)

    @classmethod
    def get_embeddings(cls) -> GoogleGenerativeAIEmbeddings:
        """Get or create embeddings instance (singleton)."""
        if cls._embeddings_instance is None:
            cls._embeddings_instance = GoogleGenerativeAIEmbeddings(
                model=cls.EMBEDDING_MODEL,
                google_api_key=os.getenv("GEMINI_API_KEY"),
            )
        return cls._embeddings_instance

    @classmethod
    def get_cost_tracker(cls) -> CostTrackingCallback:
        """Get or create cost tracker instance (singleton)."""
        if cls._cost_tracker is None:
            cls._cost_tracker = CostTrackingCallback(log_calls=True)
        return cls._cost_tracker

    @classmethod
    def get_agent_config(cls, agent_name: str) -> dict | None:
        """Return the config dict for *agent_name*, or None if unknown."""
        for agent in cls.AGENTS_CONFIG["agents"]:
            if agent["name"] == agent_name:
                return agent
        return None

    @classmethod
    def get_enabled_agents(cls) -> list[dict]:
        """Return all agent configs whose 'enabled' flag is truthy (default True)."""
        return [
            agent
            for agent in cls.AGENTS_CONFIG["agents"]
            if agent.get("enabled", True)
        ]

    @classmethod
    def list_available_models(cls) -> list[str]:
        """Delegate to the factory's model registry."""
        return LLMFactory.list_models()

    @classmethod
    def list_available_providers(cls) -> list[str]:
        """Delegate to the factory's provider registry."""
        return LLMFactory.list_providers()

    @classmethod
    def validate_config(cls) -> bool:
        """Smoke-test the configuration by constructing an LLM and embeddings.

        Returns True on success; on any failure, prints the error and
        returns False (best-effort check — deliberately never raises).
        """
        try:
            cls.get_llm(with_cost_tracking=False)
            cls.get_embeddings()
            return True
        except Exception as e:  # broad on purpose: any construction failure = invalid
            print(f"Configuration validation failed: {e}")
            return False

    @classmethod
    def reset_instances(cls) -> None:
        """Drop all cached singletons and reset cost-tracking state.

        Intended for tests / reconfiguration; the next accessor call
        rebuilds fresh instances.
        """
        cls._llm_instance = None
        cls._llm_fast_instance = None
        cls._llm_reasoning_instance = None
        cls._embeddings_instance = None
        if cls._cost_tracker:
            cls._cost_tracker.reset()
        LLMFactory.clear_cache()