|
|
from enum import StrEnum |
|
|
from json import loads |
|
|
from typing import Annotated, Any |
|
|
|
|
|
from dotenv import find_dotenv |
|
|
from pydantic import ( |
|
|
BeforeValidator, |
|
|
Field, |
|
|
HttpUrl, |
|
|
SecretStr, |
|
|
TypeAdapter, |
|
|
computed_field, |
|
|
) |
|
|
from pydantic_settings import BaseSettings, SettingsConfigDict |
|
|
|
|
|
from schema.models import ( |
|
|
AllModelEnum, |
|
|
AnthropicModelName, |
|
|
AWSModelName, |
|
|
AzureOpenAIModelName, |
|
|
DeepseekModelName, |
|
|
FakeModelName, |
|
|
GoogleModelName, |
|
|
GroqModelName, |
|
|
OllamaModelName, |
|
|
OpenAICompatibleName, |
|
|
OpenAIModelName, |
|
|
OpenRouterModelName, |
|
|
Provider, |
|
|
VertexAIModelName, |
|
|
AllEmbeddingModelEnum, |
|
|
OpenAIEmbeddingModelName, |
|
|
GoogleEmbeddingModelName, |
|
|
OllamaEmbeddingModelName, |
|
|
) |
|
|
|
|
|
|
|
|
class DatabaseType(StrEnum): |
|
|
SQLITE = "sqlite" |
|
|
POSTGRES = "postgres" |
|
|
MONGO = "mongo" |
|
|
|
|
|
|
|
|
class LogLevel(StrEnum): |
|
|
DEBUG = "DEBUG" |
|
|
INFO = "INFO" |
|
|
WARNING = "WARNING" |
|
|
ERROR = "ERROR" |
|
|
CRITICAL = "CRITICAL" |
|
|
|
|
|
def to_logging_level(self) -> int: |
|
|
"""Convert to Python logging level constant.""" |
|
|
import logging |
|
|
|
|
|
mapping = { |
|
|
LogLevel.DEBUG: logging.DEBUG, |
|
|
LogLevel.INFO: logging.INFO, |
|
|
LogLevel.WARNING: logging.WARNING, |
|
|
LogLevel.ERROR: logging.ERROR, |
|
|
LogLevel.CRITICAL: logging.CRITICAL, |
|
|
} |
|
|
return mapping[self] |
|
|
|
|
|
|
|
|
# Built once at import time: constructing a TypeAdapter compiles a pydantic
# core validation schema, which is too expensive to repeat on every call
# (this validator runs for every LANGCHAIN_ENDPOINT / LANGFUSE_HOST value).
_HTTP_URL_ADAPTER = TypeAdapter(HttpUrl)


def check_str_is_http(x: str) -> str:
    """Validate that *x* parses as an HTTP(S) URL and return its normalized string form.

    Raises:
        pydantic.ValidationError: if *x* is not a valid HTTP URL.
    """
    return str(_HTTP_URL_ADAPTER.validate_python(x))
|
|
|
|
|
|
|
|
class Settings(BaseSettings):
    """Application configuration loaded from the environment / a ``.env`` file.

    Default values are not validated (``validate_default=False``); the
    cross-field logic in :meth:`model_post_init` runs after construction and
    derives the default chat/embedding models from whichever provider
    credentials are present.
    """

    model_config = SettingsConfigDict(
        env_file=find_dotenv(),
        env_file_encoding="utf-8",
        env_ignore_empty=True,  # empty env vars are treated as unset
        extra="ignore",  # unknown env vars are silently ignored
        validate_default=False,
    )
    # Deployment mode; is_dev() treats the literal string "dev" as development.
    MODE: str | None = None

    # --- HTTP service binding ---
    HOST: str = "0.0.0.0"
    PORT: int = 7860
    # NOTE(review): assumed to be seconds — confirm against the server's
    # shutdown handling; units are not visible in this file.
    GRACEFUL_SHUTDOWN_TIMEOUT: int = 30
    LOG_LEVEL: LogLevel = LogLevel.WARNING

    # --- Auth / CORS ---
    AUTH_SECRET: SecretStr | None = None
    # Accepts a list, or a single comma-separated string from the environment.
    CORS_ORIGINS: Annotated[Any, BeforeValidator(lambda x: x.split(",") if isinstance(x, str) else x)] = [
        "http://localhost:3000",
        "http://localhost:8081",
        "http://localhost:5173",
    ]

    # --- Provider credentials / activation flags ---
    # A truthy value activates the corresponding provider in model_post_init.
    OPENAI_API_KEY: SecretStr | None = None
    DEEPSEEK_API_KEY: SecretStr | None = None
    ANTHROPIC_API_KEY: SecretStr | None = None
    GOOGLE_API_KEY: SecretStr | None = None
    GOOGLE_APPLICATION_CREDENTIALS: SecretStr | None = None
    GROQ_API_KEY: SecretStr | None = None
    USE_AWS_BEDROCK: bool = False
    OLLAMA_MODEL: str | None = None
    OLLAMA_BASE_URL: str | None = None
    USE_FAKE_MODEL: bool = False
    OPENROUTER_API_KEY: str | None = None

    # Populated by model_post_init from the active providers.
    DEFAULT_MODEL: AllModelEnum | None = None
    AVAILABLE_MODELS: set[AllModelEnum] = set()

    # Embedding-model counterparts, also populated by model_post_init.
    DEFAULT_EMBEDDING_MODEL: AllEmbeddingModelEnum | None = None
    AVAILABLE_EMBEDDING_MODELS: set[AllEmbeddingModelEnum] = set()
    OLLAMA_EMBEDDING_MODEL: str | None = None

    # --- OpenAI-compatible endpoint (BASE_URL and MODEL both required to activate) ---
    COMPATIBLE_MODEL: str | None = None
    COMPATIBLE_API_KEY: SecretStr | None = None
    COMPATIBLE_BASE_URL: str | None = None

    OPENWEATHERMAP_API_KEY: SecretStr | None = None

    # --- GitHub MCP integration ---
    GITHUB_PAT: SecretStr | None = None
    MCP_GITHUB_SERVER_URL: str = "https://api.githubcopilot.com/mcp/"

    # --- LangSmith tracing ---
    LANGCHAIN_TRACING_V2: bool = False
    LANGCHAIN_PROJECT: str = "default"
    LANGCHAIN_ENDPOINT: Annotated[str, BeforeValidator(check_str_is_http)] = (
        "https://api.smith.langchain.com"
    )
    LANGCHAIN_API_KEY: SecretStr | None = None

    # --- Langfuse tracing ---
    LANGFUSE_TRACING: bool = False
    LANGFUSE_HOST: Annotated[str, BeforeValidator(check_str_is_http)] = "https://cloud.langfuse.com"
    LANGFUSE_PUBLIC_KEY: SecretStr | None = None
    LANGFUSE_SECRET_KEY: SecretStr | None = None

    # --- Storage backend selection ---
    DATABASE_TYPE: DatabaseType = (
        DatabaseType.SQLITE
    )
    SQLITE_DB_PATH: str = "checkpoints.db"

    # --- PostgreSQL (used when DATABASE_TYPE is "postgres") ---
    POSTGRES_URL: SecretStr | None = None
    POSTGRES_USER: str | None = None
    POSTGRES_PASSWORD: SecretStr | None = None
    POSTGRES_HOST: str | None = None
    POSTGRES_PORT: int | None = None
    POSTGRES_DB: str | None = None
    POSTGRES_APPLICATION_NAME: str = "agent-service-toolkit"
    POSTGRES_MIN_CONNECTIONS_PER_POOL: int = 1
    POSTGRES_MAX_CONNECTIONS_PER_POOL: int = 1

    # --- MongoDB (used when DATABASE_TYPE is "mongo") ---
    MONGO_HOST: str | None = None
    MONGO_PORT: int | None = None
    MONGO_DB: str | None = None
    MONGO_USER: str | None = None
    MONGO_PASSWORD: SecretStr | None = None
    MONGO_AUTH_SOURCE: str | None = None

    # --- Azure OpenAI ---
    AZURE_OPENAI_API_KEY: SecretStr | None = None
    AZURE_OPENAI_ENDPOINT: str | None = None
    AZURE_OPENAI_API_VERSION: str = "2024-02-15-preview"
    AZURE_OPENAI_DEPLOYMENT_MAP: dict[str, str] = Field(
        default_factory=dict, description="Map of model names to Azure deployment IDs"
    )

    def model_post_init(self, __context: Any) -> None:
        """Derive available/default chat and embedding models from the active providers.

        Raises:
            ValueError: if no provider is configured, the Azure OpenAI
                configuration is incomplete or invalid, or an unknown
                provider is encountered.
        """
        # Truthy value => provider is active. Note that not every entry is
        # literally an API key: AWS/FAKE are boolean flags, OLLAMA is a model
        # name, and OPENAI_COMPATIBLE needs both a base URL and a model.
        api_keys = {
            Provider.OPENAI: self.OPENAI_API_KEY,
            Provider.OPENAI_COMPATIBLE: self.COMPATIBLE_BASE_URL and self.COMPATIBLE_MODEL,
            Provider.DEEPSEEK: self.DEEPSEEK_API_KEY,
            Provider.ANTHROPIC: self.ANTHROPIC_API_KEY,
            Provider.GOOGLE: self.GOOGLE_API_KEY,
            Provider.VERTEXAI: self.GOOGLE_APPLICATION_CREDENTIALS,
            Provider.GROQ: self.GROQ_API_KEY,
            Provider.AWS: self.USE_AWS_BEDROCK,
            Provider.OLLAMA: self.OLLAMA_MODEL,
            Provider.FAKE: self.USE_FAKE_MODEL,
            Provider.AZURE_OPENAI: self.AZURE_OPENAI_API_KEY,
            Provider.OPENROUTER: self.OPENROUTER_API_KEY,
        }
        active_keys = [k for k, v in api_keys.items() if v]
        if not active_keys:
            raise ValueError("At least one LLM API key must be provided.")

        # First pass: register chat models. The first active provider (in the
        # dict's insertion order above) wins DEFAULT_MODEL unless one was set
        # explicitly via the environment.
        for provider in active_keys:
            match provider:
                case Provider.OPENAI:
                    if self.DEFAULT_MODEL is None:
                        self.DEFAULT_MODEL = OpenAIModelName.GPT_5_NANO
                    self.AVAILABLE_MODELS.update(set(OpenAIModelName))
                case Provider.OPENAI_COMPATIBLE:
                    if self.DEFAULT_MODEL is None:
                        self.DEFAULT_MODEL = OpenAICompatibleName.OPENAI_COMPATIBLE
                    self.AVAILABLE_MODELS.update(set(OpenAICompatibleName))
                case Provider.DEEPSEEK:
                    if self.DEFAULT_MODEL is None:
                        self.DEFAULT_MODEL = DeepseekModelName.DEEPSEEK_CHAT
                    self.AVAILABLE_MODELS.update(set(DeepseekModelName))
                case Provider.ANTHROPIC:
                    if self.DEFAULT_MODEL is None:
                        self.DEFAULT_MODEL = AnthropicModelName.HAIKU_45
                    self.AVAILABLE_MODELS.update(set(AnthropicModelName))
                case Provider.GOOGLE:
                    if self.DEFAULT_MODEL is None:
                        self.DEFAULT_MODEL = GoogleModelName.GEMINI_20_FLASH
                    self.AVAILABLE_MODELS.update(set(GoogleModelName))
                case Provider.VERTEXAI:
                    if self.DEFAULT_MODEL is None:
                        self.DEFAULT_MODEL = VertexAIModelName.GEMINI_20_FLASH
                    self.AVAILABLE_MODELS.update(set(VertexAIModelName))
                case Provider.GROQ:
                    if self.DEFAULT_MODEL is None:
                        self.DEFAULT_MODEL = GroqModelName.LLAMA_31_8B
                    self.AVAILABLE_MODELS.update(set(GroqModelName))
                case Provider.AWS:
                    if self.DEFAULT_MODEL is None:
                        self.DEFAULT_MODEL = AWSModelName.BEDROCK_HAIKU
                    self.AVAILABLE_MODELS.update(set(AWSModelName))
                case Provider.OLLAMA:
                    if self.DEFAULT_MODEL is None:
                        self.DEFAULT_MODEL = OllamaModelName.OLLAMA_GENERIC
                    self.AVAILABLE_MODELS.update(set(OllamaModelName))
                case Provider.OPENROUTER:
                    if self.DEFAULT_MODEL is None:
                        self.DEFAULT_MODEL = OpenRouterModelName.GEMINI_25_FLASH
                    self.AVAILABLE_MODELS.update(set(OpenRouterModelName))
                case Provider.FAKE:
                    if self.DEFAULT_MODEL is None:
                        self.DEFAULT_MODEL = FakeModelName.FAKE
                    self.AVAILABLE_MODELS.update(set(FakeModelName))
                case Provider.AZURE_OPENAI:
                    if self.DEFAULT_MODEL is None:
                        self.DEFAULT_MODEL = AzureOpenAIModelName.AZURE_GPT_4O_MINI
                    self.AVAILABLE_MODELS.update(set(AzureOpenAIModelName))

                    # Azure needs more than a key: an endpoint and a
                    # deployment map are mandatory.
                    if not self.AZURE_OPENAI_API_KEY:
                        raise ValueError("AZURE_OPENAI_API_KEY must be set")
                    if not self.AZURE_OPENAI_ENDPOINT:
                        raise ValueError("AZURE_OPENAI_ENDPOINT must be set")
                    if not self.AZURE_OPENAI_DEPLOYMENT_MAP:
                        raise ValueError("AZURE_OPENAI_DEPLOYMENT_MAP must be set")

                    # The env var may arrive as a raw JSON string rather than
                    # an already-parsed dict; coerce it here.
                    if isinstance(self.AZURE_OPENAI_DEPLOYMENT_MAP, str):
                        try:
                            self.AZURE_OPENAI_DEPLOYMENT_MAP = loads(
                                self.AZURE_OPENAI_DEPLOYMENT_MAP
                            )
                        # NOTE(review): consider "raise ... from e" here to
                        # preserve the original exception as the cause.
                        except Exception as e:
                            raise ValueError(f"Invalid AZURE_OPENAI_DEPLOYMENT_MAP JSON: {e}")

                    # These two deployments must always be mapped.
                    required_models = {"gpt-4o", "gpt-4o-mini"}
                    missing_models = required_models - set(self.AZURE_OPENAI_DEPLOYMENT_MAP.keys())
                    if missing_models:
                        raise ValueError(f"Missing required Azure deployments: {missing_models}")
                case _:
                    raise ValueError(f"Unknown provider: {provider}")

        # Second pass: embedding models. Only OpenAI, Google and Ollama
        # contribute embeddings; other active providers are skipped.
        for provider in active_keys:
            match provider:
                case Provider.OPENAI:
                    if self.DEFAULT_EMBEDDING_MODEL is None:
                        self.DEFAULT_EMBEDDING_MODEL = OpenAIEmbeddingModelName.TEXT_EMBEDDING_3_SMALL
                    self.AVAILABLE_EMBEDDING_MODELS.update(set(OpenAIEmbeddingModelName))
                case Provider.GOOGLE:
                    if self.DEFAULT_EMBEDDING_MODEL is None:
                        self.DEFAULT_EMBEDDING_MODEL = GoogleEmbeddingModelName.TEXT_EMBEDDING_004
                    self.AVAILABLE_EMBEDDING_MODELS.update(set(GoogleEmbeddingModelName))
                case Provider.OLLAMA:
                    if self.DEFAULT_EMBEDDING_MODEL is None:
                        self.DEFAULT_EMBEDDING_MODEL = OllamaEmbeddingModelName.NOMIC_EMBED_TEXT
                    self.AVAILABLE_EMBEDDING_MODELS.update(set(OllamaEmbeddingModelName))
                    # Fall back to the default Ollama embedding model when none
                    # was configured explicitly.
                    if not self.OLLAMA_EMBEDDING_MODEL:
                        self.OLLAMA_EMBEDDING_MODEL = OllamaEmbeddingModelName.NOMIC_EMBED_TEXT

    @computed_field
    @property
    def BASE_URL(self) -> str:
        """Base URL of this service, derived from HOST and PORT."""
        return f"http://{self.HOST}:{self.PORT}"

    def is_dev(self) -> bool:
        """Return True when MODE is the literal string "dev"."""
        return self.MODE == "dev"
|
|
|
|
|
|
|
|
# Module-level singleton. Constructing it runs validation and model_post_init,
# so importing this module raises ValueError if no LLM provider is configured.
settings = Settings()
|
|
|