Spaces:
Paused
Paused
| import os | |
| import sys | |
| import logging | |
| import importlib.metadata | |
| import pkgutil | |
| import chromadb | |
| from chromadb import Settings | |
| from bs4 import BeautifulSoup | |
| from typing import TypeVar, Generic | |
| from pydantic import BaseModel | |
| from typing import Optional | |
| from pathlib import Path | |
| import json | |
| import yaml | |
| import markdown | |
| import requests | |
| import shutil | |
| from constants import ERROR_MESSAGES | |
####################################
# Load .env file
####################################

BACKEND_DIR = Path(__file__).parent  # directory containing this module
BASE_DIR = Path(__file__).parent.parent  # directory containing backend/
print(BASE_DIR)

# Optional dependency: python-dotenv. When it is installed, BASE_DIR/.env is
# located and loaded before any environment variable below is read.
try:
    from dotenv import find_dotenv, load_dotenv

    load_dotenv(find_dotenv(str(BASE_DIR / ".env")))
except ImportError:
    print("dotenv not installed, skipping...")
####################################
# LOGGING
####################################

log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]

# A recognised GLOBAL_LOG_LEVEL reconfigures the root logger to write to
# stdout; anything else (including unset) only records "INFO" as the level.
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
if GLOBAL_LOG_LEVEL not in log_levels:
    GLOBAL_LOG_LEVEL = "INFO"
else:
    logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)

log = logging.getLogger(__name__)
log.info("GLOBAL_LOG_LEVEL: %s", GLOBAL_LOG_LEVEL)

log_sources = [
    "AUDIO",
    "COMFYUI",
    "CONFIG",
    "DB",
    "IMAGES",
    "MAIN",
    "MODELS",
    "OLLAMA",
    "OPENAI",
    "RAG",
    "WEBHOOK",
]

# Per-subsystem levels come from <SOURCE>_LOG_LEVEL; unset or unrecognised
# values fall back to the global level.
SRC_LOG_LEVELS = {}
for source in log_sources:
    log_env_var = f"{source}_LOG_LEVEL"
    SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
    if SRC_LOG_LEVELS[source] not in log_levels:
        SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
    log.info("%s: %s", log_env_var, SRC_LOG_LEVELS[source])

# This module logs at the CONFIG subsystem's level.
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
class EndpointFilter(logging.Filter):
    """Logging filter that drops any record whose message mentions ``/health``."""

    def filter(self, record: logging.LogRecord) -> bool:
        # A False return value suppresses the record.
        return "/health" not in record.getMessage()


# Keep health-check requests out of the uvicorn access log.
logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
WEBUI_NAME = os.getenv("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
    # Custom instance names keep an " (Open WebUI)" suffix.
    WEBUI_NAME = f"{WEBUI_NAME} (Open WebUI)"

WEBUI_URL = os.getenv("WEBUI_URL", "http://localhost:3000")
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
####################################
# ENV (dev,test,prod)
####################################

ENV = os.environ.get("ENV", "dev")

# Version resolution order: the repository's package.json, then the installed
# "open-webui" distribution metadata, then a "0.0.0" placeholder.
try:
    PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text())
except (OSError, ValueError):
    # OSError: package.json missing/unreadable. ValueError covers both
    # json.JSONDecodeError and UnicodeDecodeError. A bare except here would
    # also have swallowed KeyboardInterrupt/SystemExit.
    try:
        PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")}
    except importlib.metadata.PackageNotFoundError:
        PACKAGE_DATA = {"version": "0.0.0"}

VERSION = PACKAGE_DATA["version"]
# Function to parse each section
def parse_section(section):
    """Convert the ``<li>`` items of one changelog section into dicts.

    Each item's plain text is split on the first ": " into a title and
    content. Items without ": " get an empty title and their full text as
    content. The item's original HTML is kept under ``"raw"``.

    Args:
        section: The BeautifulSoup element holding the items (the ``<ul>``
            after an ``<h3>``), or ``None`` when the heading has no list.

    Returns:
        A list of ``{"title", "content", "raw"}`` dicts. The list is empty
        when *section* is ``None``.
    """
    if section is None:
        # find_next_sibling("ul") yields None for a heading with no list;
        # calling find_all on it would raise AttributeError.
        return []

    items = []
    for li in section.find_all("li"):
        raw_html = str(li)
        text = li.get_text(separator=" ", strip=True)
        title, separator, content = text.partition(": ")
        if separator:
            items.append(
                {"title": title.strip(), "content": content.strip(), "raw": raw_html}
            )
        else:
            items.append({"title": "", "content": text, "raw": raw_html})
    return items
try:
    changelog_path = BASE_DIR / "CHANGELOG.md"
    with open(str(changelog_path.absolute()), "r", encoding="utf8") as file:
        changelog_content = file.read()
except (OSError, ValueError):
    # CHANGELOG.md is missing/unreadable next to the backend: fall back to the
    # CHANGELOG.md resource of the open_webui package.
    changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()


def _parse_changelog(soup):
    """Build ``{version: {"date": ..., <section>: [items]}}`` from changelog HTML.

    Each release is expected as an ``<h2>`` of the form ``[X.Y.Z] - DATE``,
    followed by ``<h3>`` section headings (e.g. "added", "fixed"). Each of
    those headings is followed by a ``<ul>`` of items. An ``<h2>`` without
    " - " (such as an "[Unreleased]" heading) is skipped instead of raising
    IndexError while the module is being imported.
    """
    changelog = {}
    for version_heading in soup.find_all("h2"):
        heading_text = version_heading.get_text().strip()
        if " - " not in heading_text:
            continue
        parts = heading_text.split(" - ")
        version_number = parts[0][1:-1]  # Remove brackets
        version_data = {"date": parts[1]}

        # Walk the siblings up to the next release heading.
        current = version_heading.find_next_sibling()
        while current and current.name != "h2":
            if current.name == "h3":
                section_title = current.get_text().lower()  # e.g., "added", "fixed"
                section_list = current.find_next_sibling("ul")
                version_data[section_title] = (
                    parse_section(section_list) if section_list is not None else []
                )
            current = current.find_next_sibling()
        changelog[version_number] = version_data
    return changelog


# Convert markdown content to HTML, then parse it.
html_content = markdown.markdown(changelog_content)
soup = BeautifulSoup(html_content, "html.parser")

changelog_json = _parse_changelog(soup)
CHANGELOG = changelog_json
####################################
# SAFE_MODE
####################################

# Enabled only by a case-insensitive "true".
SAFE_MODE = os.getenv("SAFE_MODE", "false").lower() == "true"

####################################
# WEBUI_BUILD_HASH
####################################

# Build identifier; "dev-build" when the variable is not provided.
WEBUI_BUILD_HASH = os.getenv("WEBUI_BUILD_HASH", "dev-build")
####################################
# DATA/FRONTEND BUILD DIR
####################################

DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()

RESET_CONFIG_ON_START = (
    os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)

if RESET_CONFIG_ON_START:
    # Opening with "w" truncates an existing file, so no separate remove is
    # needed. This also works when config.json does not exist yet.
    try:
        with open(f"{DATA_DIR}/config.json", "w") as f:
            f.write("{}")
    except OSError as e:
        # Best effort (e.g. DATA_DIR missing): the load below then falls back
        # to an empty config anyway.
        log.warning(f"Could not reset config.json: {e}")

# Settings persisted by PersistentConfig.save(). Fall back to an empty config
# when the file is missing, unreadable, invalid JSON, or not a JSON object
# (get_config_value/save expect a dict at the top level).
try:
    CONFIG_DATA = json.loads((DATA_DIR / "config.json").read_text())
except (OSError, ValueError):
    CONFIG_DATA = {}
if not isinstance(CONFIG_DATA, dict):
    CONFIG_DATA = {}
####################################
# Config helpers
####################################


def save_config():
    """Write the in-memory CONFIG_DATA to DATA_DIR/config.json, tab-indented.

    Any exception is logged with its traceback and otherwise ignored.
    """
    try:
        config_file = f"{DATA_DIR}/config.json"
        with open(config_file, "w") as handle:
            json.dump(CONFIG_DATA, handle, indent="\t")
    except Exception as exc:
        log.exception(exc)
def get_config_value(config_path: str, config: Optional[dict] = None):
    """Look up a dotted path such as ``"oauth.google.client_id"``.

    Args:
        config_path: Dot-separated keys, outermost first.
        config: Mapping to search. Defaults to the module-level CONFIG_DATA.

    Returns:
        The value stored at the path, or ``None`` if any segment is missing
        or an intermediate value is not a dict.
    """
    cur_config = CONFIG_DATA if config is None else config
    for key in config_path.split("."):
        # Only descend into dicts. On a str, ``key in`` would be a substring
        # test, and indexing it with a str key would raise TypeError.
        if not isinstance(cur_config, dict) or key not in cur_config:
            return None
        cur_config = cur_config[key]
    return cur_config
T = TypeVar("T")


class PersistentConfig(Generic[T]):
    """A single setting backed by config.json, with an environment-derived default.

    On construction, the dotted ``config_path`` is looked up in CONFIG_DATA via
    ``get_config_value``. A non-None stored value becomes ``value``; otherwise
    ``env_value`` (computed by the caller, typically from ``os.environ``) does.
    ``value`` may be reassigned later and written back with ``save()``.
    """

    def __init__(self, env_name: str, config_path: str, env_value: T):
        # Within this class, env_name is only used in log messages.
        self.env_name = env_name
        # Dot-separated key path inside config.json, e.g. "oauth.google.client_id".
        self.config_path = config_path
        self.env_value = env_value
        # None when config.json has no entry at config_path.
        self.config_value = get_config_value(config_path)
        if self.config_value is not None:
            log.info(f"'{env_name}' loaded from config.json")
            self.value = self.config_value
        else:
            self.value = env_value

    def __str__(self):
        return str(self.value)

    # Both overrides below make reading ``__dict__`` raise TypeError. Per the
    # message, callers are expected to use ``.value`` instead of the wrapper.
    def __dict__(self):
        raise TypeError(
            "PersistentConfig object cannot be converted to dict, use config_get or .value instead."
        )

    def __getattribute__(self, item):
        if item == "__dict__":
            raise TypeError(
                "PersistentConfig object cannot be converted to dict, use config_get or .value instead."
            )
        return super().__getattribute__(item)

    def save(self):
        """Store ``value`` at ``config_path`` in CONFIG_DATA and rewrite config.json.

        This is a no-op when ``value`` equals both ``env_value`` and the last
        stored ``config_value``. Missing intermediate keys along the path are
        created as empty dicts. Afterwards ``config_value`` mirrors ``value``.
        """
        # Don't save if the value is the same as the env value and the config value
        if self.env_value == self.value:
            if self.config_value == self.value:
                return
        log.info(f"Saving '{self.env_name}' to config.json")
        path_parts = self.config_path.split(".")
        config = CONFIG_DATA
        for key in path_parts[:-1]:
            if key not in config:
                config[key] = {}
            config = config[key]
        config[path_parts[-1]] = self.value
        # save_config() logs (rather than raises) on write errors.
        save_config()
        self.config_value = self.value
class AppConfig:
    """Attribute-style facade over a set of PersistentConfig settings.

    Assigning a PersistentConfig registers it under that attribute name.
    Assigning any other value updates the registered setting's ``value`` and
    persists it via ``save()``. Reading an attribute returns the setting's
    current ``value``.
    """

    # Quoted so the annotation is not evaluated when the class is created.
    _state: "dict[str, PersistentConfig]"

    def __init__(self):
        # Bypass our own __setattr__, which expects _state to exist already.
        super().__setattr__("_state", {})

    def __setattr__(self, key, value):
        if isinstance(value, PersistentConfig):
            self._state[key] = value
        else:
            self._state[key].value = value
            self._state[key].save()

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails. Unknown names must
        # raise AttributeError (not KeyError) so that hasattr(),
        # getattr(obj, name, default) and copy/pickle probes behave correctly.
        if key == "_state":
            # _state itself is missing (instance built without __init__);
            # looking it up again below would recurse forever.
            raise AttributeError(key)
        try:
            return self._state[key].value
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{key}'"
            ) from None
####################################
# WEBUI_AUTH (Required for security)
####################################

WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"

# Optional names of request headers carrying the user's email / display name;
# None when unset.
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_EMAIL_HEADER")
WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER")

JWT_EXPIRES_IN = PersistentConfig(
    env_name="JWT_EXPIRES_IN",
    config_path="auth.jwt_expiry",
    env_value=os.environ.get("JWT_EXPIRES_IN", "-1"),
)
####################################
# OAuth config
####################################

# Every PersistentConfig below takes its value from config.json (at the given
# dotted path) when stored there, and otherwise from the environment variable.

ENABLE_OAUTH_SIGNUP = PersistentConfig(
    "ENABLE_OAUTH_SIGNUP",
    "oauth.enable_signup",
    os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true",
)

OAUTH_MERGE_ACCOUNTS_BY_EMAIL = PersistentConfig(
    "OAUTH_MERGE_ACCOUNTS_BY_EMAIL",
    "oauth.merge_accounts_by_email",
    os.environ.get("OAUTH_MERGE_ACCOUNTS_BY_EMAIL", "False").lower() == "true",
)

# Provider name -> client settings dict; cleared and refilled by
# load_oauth_providers().
OAUTH_PROVIDERS = {}

# --- Google: registered only when client id and secret are both set ---
GOOGLE_CLIENT_ID = PersistentConfig(
    "GOOGLE_CLIENT_ID",
    "oauth.google.client_id",
    os.environ.get("GOOGLE_CLIENT_ID", ""),
)

GOOGLE_CLIENT_SECRET = PersistentConfig(
    "GOOGLE_CLIENT_SECRET",
    "oauth.google.client_secret",
    os.environ.get("GOOGLE_CLIENT_SECRET", ""),
)

GOOGLE_OAUTH_SCOPE = PersistentConfig(
    "GOOGLE_OAUTH_SCOPE",
    "oauth.google.scope",
    os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"),
)

GOOGLE_REDIRECT_URI = PersistentConfig(
    "GOOGLE_REDIRECT_URI",
    "oauth.google.redirect_uri",
    os.environ.get("GOOGLE_REDIRECT_URI", ""),
)

# --- Microsoft: also needs the tenant id, which is part of its metadata URL ---
MICROSOFT_CLIENT_ID = PersistentConfig(
    "MICROSOFT_CLIENT_ID",
    "oauth.microsoft.client_id",
    os.environ.get("MICROSOFT_CLIENT_ID", ""),
)

MICROSOFT_CLIENT_SECRET = PersistentConfig(
    "MICROSOFT_CLIENT_SECRET",
    "oauth.microsoft.client_secret",
    os.environ.get("MICROSOFT_CLIENT_SECRET", ""),
)

MICROSOFT_CLIENT_TENANT_ID = PersistentConfig(
    "MICROSOFT_CLIENT_TENANT_ID",
    "oauth.microsoft.tenant_id",
    os.environ.get("MICROSOFT_CLIENT_TENANT_ID", ""),
)

MICROSOFT_OAUTH_SCOPE = PersistentConfig(
    "MICROSOFT_OAUTH_SCOPE",
    "oauth.microsoft.scope",
    os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"),
)

MICROSOFT_REDIRECT_URI = PersistentConfig(
    "MICROSOFT_REDIRECT_URI",
    "oauth.microsoft.redirect_uri",
    os.environ.get("MICROSOFT_REDIRECT_URI", ""),
)

# --- Generic OpenID Connect provider ("oidc") ---
OAUTH_CLIENT_ID = PersistentConfig(
    "OAUTH_CLIENT_ID",
    "oauth.oidc.client_id",
    os.environ.get("OAUTH_CLIENT_ID", ""),
)

OAUTH_CLIENT_SECRET = PersistentConfig(
    "OAUTH_CLIENT_SECRET",
    "oauth.oidc.client_secret",
    os.environ.get("OAUTH_CLIENT_SECRET", ""),
)

# Used as the provider's server metadata URL.
OPENID_PROVIDER_URL = PersistentConfig(
    "OPENID_PROVIDER_URL",
    "oauth.oidc.provider_url",
    os.environ.get("OPENID_PROVIDER_URL", ""),
)

OPENID_REDIRECT_URI = PersistentConfig(
    "OPENID_REDIRECT_URI",
    "oauth.oidc.redirect_uri",
    os.environ.get("OPENID_REDIRECT_URI", ""),
)

OAUTH_SCOPES = PersistentConfig(
    "OAUTH_SCOPES",
    "oauth.oidc.scopes",
    os.environ.get("OAUTH_SCOPES", "openid email profile"),
)

# Display name for the generic provider.
OAUTH_PROVIDER_NAME = PersistentConfig(
    "OAUTH_PROVIDER_NAME",
    "oauth.oidc.provider_name",
    os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
)

# Name of the claim read as the username; "name" by default.
OAUTH_USERNAME_CLAIM = PersistentConfig(
    "OAUTH_USERNAME_CLAIM",
    "oauth.oidc.username_claim",
    os.environ.get("OAUTH_USERNAME_CLAIM", "name"),
)
# Name of the claim read as the user's avatar; "picture" by default.
OAUTH_PICTURE_CLAIM = PersistentConfig(
    # env_name must name the variable actually read below; it previously
    # reused "OAUTH_USERNAME_CLAIM", which mislabelled this setting in logs.
    "OAUTH_PICTURE_CLAIM",
    "oauth.oidc.avatar_claim",
    os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
)
def load_oauth_providers():
    """Repopulate OAUTH_PROVIDERS in place from the current OAuth settings.

    A provider is registered only when its required settings are all
    truthy: Google needs client id and secret; Microsoft additionally needs
    the tenant id; the generic "oidc" entry needs client id, secret and
    OPENID_PROVIDER_URL, which becomes its server metadata URL.
    """
    OAUTH_PROVIDERS.clear()

    if all((GOOGLE_CLIENT_ID.value, GOOGLE_CLIENT_SECRET.value)):
        OAUTH_PROVIDERS["google"] = dict(
            client_id=GOOGLE_CLIENT_ID.value,
            client_secret=GOOGLE_CLIENT_SECRET.value,
            server_metadata_url="https://accounts.google.com/.well-known/openid-configuration",
            scope=GOOGLE_OAUTH_SCOPE.value,
            redirect_uri=GOOGLE_REDIRECT_URI.value,
        )

    microsoft_required = (
        MICROSOFT_CLIENT_ID.value,
        MICROSOFT_CLIENT_SECRET.value,
        MICROSOFT_CLIENT_TENANT_ID.value,
    )
    if all(microsoft_required):
        tenant_id = MICROSOFT_CLIENT_TENANT_ID.value
        OAUTH_PROVIDERS["microsoft"] = dict(
            client_id=MICROSOFT_CLIENT_ID.value,
            client_secret=MICROSOFT_CLIENT_SECRET.value,
            server_metadata_url=f"https://login.microsoftonline.com/{tenant_id}/v2.0/.well-known/openid-configuration",
            scope=MICROSOFT_OAUTH_SCOPE.value,
            redirect_uri=MICROSOFT_REDIRECT_URI.value,
        )

    oidc_required = (
        OAUTH_CLIENT_ID.value,
        OAUTH_CLIENT_SECRET.value,
        OPENID_PROVIDER_URL.value,
    )
    if all(oidc_required):
        OAUTH_PROVIDERS["oidc"] = dict(
            client_id=OAUTH_CLIENT_ID.value,
            client_secret=OAUTH_CLIENT_SECRET.value,
            server_metadata_url=OPENID_PROVIDER_URL.value,
            scope=OAUTH_SCOPES.value,
            name=OAUTH_PROVIDER_NAME.value,
            redirect_uri=OPENID_REDIRECT_URI.value,
        )


load_oauth_providers()
####################################
# Static DIR
####################################

STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve()


def _copy_frontend_asset(source: Path, label: str) -> None:
    """Copy a frontend build asset into STATIC_DIR under the same file name.

    A missing source is logged as a warning and a failed copy as an error;
    neither is fatal. The module logger is used on purpose: the module-level
    ``logging.error``/``logging.warning`` helpers implicitly run
    ``basicConfig()`` on an unconfigured root logger.
    """
    if not source.exists():
        log.warning(f"Frontend {label} not found at {source}")
        return
    try:
        shutil.copyfile(source, STATIC_DIR / source.name)
    except OSError as e:
        log.error(f"An error occurred: {e}")


frontend_favicon = FRONTEND_BUILD_DIR / "static" / "favicon.png"
_copy_frontend_asset(frontend_favicon, "favicon")

frontend_splash = FRONTEND_BUILD_DIR / "static" / "splash.png"
_copy_frontend_asset(frontend_splash, "splash")
| #################################### | |
| # CUSTOM_NAME | |
| #################################### | |
| CUSTOM_NAME = os.environ.get("CUSTOM_NAME", "") | |
| if CUSTOM_NAME: | |
| try: | |
| r = requests.get(f"https://api.openwebui.com/api/v1/custom/{CUSTOM_NAME}") | |
| data = r.json() | |
| if r.ok: | |
| if "logo" in data: | |
| WEBUI_FAVICON_URL = url = ( | |
| f"https://api.openwebui.com{data['logo']}" | |
| if data["logo"][0] == "/" | |
| else data["logo"] | |
| ) | |
| r = requests.get(url, stream=True) | |
| if r.status_code == 200: | |
| with open(f"{STATIC_DIR}/favicon.png", "wb") as f: | |
| r.raw.decode_content = True | |
| shutil.copyfileobj(r.raw, f) | |
| if "splash" in data: | |
| url = ( | |
| f"https://api.openwebui.com{data['splash']}" | |
| if data["splash"][0] == "/" | |
| else data["splash"] | |
| ) | |
| r = requests.get(url, stream=True) | |
| if r.status_code == 200: | |
| with open(f"{STATIC_DIR}/splash.png", "wb") as f: | |
| r.raw.decode_content = True | |
| shutil.copyfileobj(r.raw, f) | |
| WEBUI_NAME = data["name"] | |
| except Exception as e: | |
| log.exception(e) | |
| pass | |
####################################
# File Upload / Cache / Docs / Tools / Functions DIRs
####################################

# Uploads and cache always live under DATA_DIR. The docs, tools and
# functions locations can be overridden through environment variables.
UPLOAD_DIR = f"{DATA_DIR}/uploads"
CACHE_DIR = f"{DATA_DIR}/cache"
DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs")
TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")

# Create each directory (and any missing parents) up front.
for _dir in (UPLOAD_DIR, CACHE_DIR, DOCS_DIR, TOOLS_DIR, FUNCTIONS_DIR):
    Path(_dir).mkdir(parents=True, exist_ok=True)
del _dir
####################################
# LITELLM_CONFIG
####################################


def create_config_file(file_path):
    """Write a skeleton LiteLLM YAML config to *file_path*.

    The file contains empty ``general_settings``, ``litellm_settings`` and
    ``router_settings`` mappings and an empty ``model_list``. Parent
    directories are created as needed, and an existing file is overwritten.
    """
    directory = os.path.dirname(file_path)
    # dirname() is "" for a bare file name, and os.makedirs("") raises
    # FileNotFoundError. exist_ok also avoids a check-then-create race.
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Data to write into the YAML file
    config_data = {
        "general_settings": {},
        "litellm_settings": {},
        "model_list": [],
        "router_settings": {},
    }

    with open(file_path, "w") as file:
        yaml.dump(config_data, file)


LITELLM_CONFIG_PATH = f"{DATA_DIR}/litellm/config.yaml"

# NOTE: automatic creation of the LiteLLM config at startup is disabled:
# if not os.path.exists(LITELLM_CONFIG_PATH):
#     log.info("Config file doesn't exist. Creating...")
#     create_config_file(LITELLM_CONFIG_PATH)
#     log.info("Config file created successfully.")
####################################
# OLLAMA_BASE_URL
####################################

ENABLE_OLLAMA_API = PersistentConfig(
    "ENABLE_OLLAMA_API",
    "ollama.enable",
    os.environ.get("ENABLE_OLLAMA_API", "True").lower() == "true",
)

# Variant with a trailing "/api"; only consulted when OLLAMA_BASE_URL is empty.
OLLAMA_API_BASE_URL = os.environ.get(
    "OLLAMA_API_BASE_URL", "http://localhost:11434/api"
)

OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")

# Empty/unset -> None. A value that is not an integer falls back to 300
# (presumably seconds — confirm with the aiohttp callers).
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")

if AIOHTTP_CLIENT_TIMEOUT == "":
    AIOHTTP_CLIENT_TIMEOUT = None
else:
    try:
        AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT)
    except ValueError:
        AIOHTTP_CLIENT_TIMEOUT = 300

K8S_FLAG = os.environ.get("K8S_FLAG", "")
USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false")

if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "":
    # Derive the base URL by dropping a trailing "/api".
    OLLAMA_BASE_URL = OLLAMA_API_BASE_URL.removesuffix("/api")

if ENV == "prod":
    if OLLAMA_BASE_URL == "/ollama" and not K8S_FLAG:
        if USE_OLLAMA_DOCKER.lower() == "true":
            # if you use all-in-one docker container (Open WebUI + Ollama)
            # with the docker build arg USE_OLLAMA=true (--build-arg="USE_OLLAMA=true") this only works with http://localhost:11434
            OLLAMA_BASE_URL = "http://localhost:11434"
        else:
            OLLAMA_BASE_URL = "http://host.docker.internal:11434"
    elif K8S_FLAG:
        OLLAMA_BASE_URL = "http://ollama-service.open-webui.svc.cluster.local:11434"

# Semicolon-separated list of endpoints; defaults to the single OLLAMA_BASE_URL.
OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")
OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL

OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
OLLAMA_BASE_URLS = PersistentConfig(
    "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
)
####################################
# OPENAI_API
####################################

ENABLE_OPENAI_API = PersistentConfig(
    "ENABLE_OPENAI_API",
    "openai.enable",
    os.environ.get("ENABLE_OPENAI_API", "True").lower() == "true",
)

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")

if OPENAI_API_BASE_URL == "":
    OPENAI_API_BASE_URL = "https://api.openai.com/v1"

# Semicolon-separated keys, paired by position with OPENAI_API_BASE_URLS.
# The single-value OPENAI_API_KEY is the fallback.
OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "")
OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY

OPENAI_API_KEYS = [key.strip() for key in OPENAI_API_KEYS.split(";")]
OPENAI_API_KEYS = PersistentConfig(
    "OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS
)

OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "")
OPENAI_API_BASE_URLS = (
    OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL
)

# Blank or whitespace-only entries fall back to the public OpenAI endpoint.
# The old `url != ""` test let "  " through and produced an empty URL.
OPENAI_API_BASE_URLS = [
    url.strip() or "https://api.openai.com/v1"
    for url in OPENAI_API_BASE_URLS.split(";")
]
OPENAI_API_BASE_URLS = PersistentConfig(
    "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
)

# Key configured for the official OpenAI endpoint, if any.
OPENAI_API_KEY = ""

try:
    OPENAI_API_KEY = OPENAI_API_KEYS.value[
        OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
    ]
except (ValueError, IndexError, TypeError, AttributeError):
    # ValueError: no URL equals the OpenAI endpoint. IndexError: fewer keys
    # than URLs. TypeError/AttributeError: unexpected types loaded from
    # config.json. Everything else (e.g. KeyboardInterrupt) propagates.
    pass

OPENAI_API_BASE_URL = "https://api.openai.com/v1"
####################################
# WEBUI
####################################

# The environment default is False whenever WEBUI_AUTH is disabled; a value
# stored in config.json still takes precedence (see PersistentConfig).
ENABLE_SIGNUP = PersistentConfig(
    "ENABLE_SIGNUP",
    "ui.enable_signup",
    (
        False
        if not WEBUI_AUTH
        else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true"
    ),
)

# NOTE(review): this config path is upper-case ("ui.ENABLE_LOGIN_FORM"),
# unlike the other paths. Renaming it would orphan values already stored
# in config.json.
ENABLE_LOGIN_FORM = PersistentConfig(
    "ENABLE_LOGIN_FORM",
    "ui.ENABLE_LOGIN_FORM",
    os.environ.get("ENABLE_LOGIN_FORM", "True").lower() == "true",
)

DEFAULT_LOCALE = PersistentConfig(
    "DEFAULT_LOCALE",
    "ui.default_locale",
    os.environ.get("DEFAULT_LOCALE", ""),
)

# None when DEFAULT_MODELS is unset.
DEFAULT_MODELS = PersistentConfig(
    "DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None)
)

# Built-in prompt suggestions: each has a two-part "title" and the prompt
# "content". There is no environment override, only config.json.
DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
    "DEFAULT_PROMPT_SUGGESTIONS",
    "ui.prompt_suggestions",
    [
        {
            "title": ["Help me study", "vocabulary for a college entrance exam"],
            "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
        },
        {
            "title": ["Give me ideas", "for what to do with my kids' art"],
            "content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter.",
        },
        {
            "title": ["Tell me a fun fact", "about the Roman Empire"],
            "content": "Tell me a random fun fact about the Roman Empire",
        },
        {
            "title": ["Show me a code snippet", "of a website's sticky header"],
            "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.",
        },
        {
            "title": [
                "Explain options trading",
                "if I'm familiar with buying and selling stocks",
            ],
            "content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks.",
        },
        {
            "title": ["Overcome procrastination", "give me tips"],
            "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
        },
    ],
)

DEFAULT_USER_ROLE = PersistentConfig(
    "DEFAULT_USER_ROLE",
    "ui.default_user_role",
    os.getenv("DEFAULT_USER_ROLE", "pending"),
)

USER_PERMISSIONS_CHAT_DELETION = (
    os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
)

USER_PERMISSIONS = PersistentConfig(
    "USER_PERMISSIONS",
    "ui.user_permissions",
    {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}},
)

ENABLE_MODEL_FILTER = PersistentConfig(
    "ENABLE_MODEL_FILTER",
    "model_filter.enable",
    os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true",
)

# Semicolon-separated model ids. NOTE(review): an unset variable yields [""]
# (one empty entry) rather than an empty list.
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
MODEL_FILTER_LIST = PersistentConfig(
    "MODEL_FILTER_LIST",
    "model_filter.list",
    [model.strip() for model in MODEL_FILTER_LIST.split(";")],
)

WEBHOOK_URL = PersistentConfig(
    "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "")
)

# Plain environment flag (not persisted to config.json).
ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"

ENABLE_COMMUNITY_SHARING = PersistentConfig(
    "ENABLE_COMMUNITY_SHARING",
    "ui.enable_community_sharing",
    os.environ.get("ENABLE_COMMUNITY_SHARING", "True").lower() == "true",
)
class BannerModel(BaseModel):
    """Schema for one entry of the WEBUI_BANNERS JSON array."""

    id: str
    type: str
    title: Optional[str] = None
    content: str
    dismissible: bool
    timestamp: int


# WEBUI_BANNERS holds a JSON list of banner objects. Invalid JSON, or any
# entry failing validation, discards the whole list and prints the error.
try:
    banners = [
        BannerModel(**banner)
        for banner in json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
    ]
except Exception as e:
    print(f"Error loading WEBUI_BANNERS: {e}")
    banners = []

WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners)

SHOW_ADMIN_DETAILS = PersistentConfig(
    env_name="SHOW_ADMIN_DETAILS",
    config_path="auth.admin.show",
    env_value=os.environ.get("SHOW_ADMIN_DETAILS", "true").lower() == "true",
)

# None when ADMIN_EMAIL is unset.
ADMIN_EMAIL = PersistentConfig(
    env_name="ADMIN_EMAIL",
    config_path="auth.admin.email",
    env_value=os.environ.get("ADMIN_EMAIL"),
)
####################################
# TASKS
####################################

# Model ids for tasks; empty string when unset. NOTE(review): the split
# between "default" and "external" is not explained here — confirm with callers.
TASK_MODEL = PersistentConfig(
    "TASK_MODEL",
    "task.model.default",
    os.environ.get("TASK_MODEL", ""),
)

TASK_MODEL_EXTERNAL = PersistentConfig(
    "TASK_MODEL_EXTERNAL",
    "task.model.external",
    os.environ.get("TASK_MODEL_EXTERNAL", ""),
)

# Prompt templates below: the {{...}} placeholders are presumably filled in
# by templating code outside this file (TODO confirm). The literal text is
# part of the behavior, so edit with care.
TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    "TITLE_GENERATION_PROMPT_TEMPLATE",
    "task.title.prompt_template",
    os.environ.get(
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        """Here is the query:
{{prompt:middletruncate:8000}}
Create a concise, 3-5 word phrase with an emoji as a title for the previous query. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
Examples of titles:
📉 Stock Market Trends
🍪 Perfect Chocolate Chip Recipe
Evolution of Music Streaming
Remote Work Productivity Tips
Artificial Intelligence in Healthcare
🎮 Video Game Development Insights""",
    ),
)

SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
    "task.search.prompt_template",
    os.environ.get(
        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
        """You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is {{CURRENT_DATE}}.
Question:
{{prompt:end:4000}}""",
    ),
)

# int() is applied to either the environment string or the int default 100.
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
    "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
    "task.search.prompt_length_threshold",
    int(
        os.environ.get(
            "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
            100,
        )
    ),
)

TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
    "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    "task.tools.prompt_template",
    os.environ.get(
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
        """Tools: {{TOOLS}}
If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks. Only return the object. Do not return any other text.""",
    ),
)
####################################
# WEBUI_SECRET_KEY
####################################

# Falls back to the deprecated WEBUI_JWT_SECRET_KEY, then to a placeholder.
# NOTE(review): "t0p-s3cr3t" is a publicly known placeholder; deployments
# should always set WEBUI_SECRET_KEY explicitly.
WEBUI_SECRET_KEY = os.environ.get(
    "WEBUI_SECRET_KEY",
    os.environ.get(
        "WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t"
    ),  # DEPRECATED: remove at next major version
)

WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax")

# Always a bool. The previous form returned the raw string whenever the
# variable was set, so WEBUI_SESSION_COOKIE_SECURE=false produced the truthy
# string "false".
WEBUI_SESSION_COOKIE_SECURE = (
    os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true"
)

if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
    raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
####################################
# RAG document content extraction
####################################

# The environment value is lower-cased; empty string when unset.
# NOTE(review): upper-case config path, unlike most other keys.
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
    "CONTENT_EXTRACTION_ENGINE",
    "rag.CONTENT_EXTRACTION_ENGINE",
    os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
)

TIKA_SERVER_URL = PersistentConfig(
    "TIKA_SERVER_URL",
    "rag.tika_server_url",
    os.getenv("TIKA_SERVER_URL", "http://tika:9998"),  # Default for sidecar deployment
)

####################################
# RAG
####################################

# On-disk path for the embedded Chroma client (used when no HTTP host is set).
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
# A non-empty host switches to a remote Chroma HTTP client.
CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "")
CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000"))
| # Comma-separated list of header=value pairs | |
| CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "") | |
| if CHROMA_HTTP_HEADERS: | |
| CHROMA_HTTP_HEADERS = dict( | |
| [pair.split("=") for pair in CHROMA_HTTP_HEADERS.split(",")] | |
| ) | |
| else: | |
| CHROMA_HTTP_HEADERS = None | |
| CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" | |
RAG_TOP_K = PersistentConfig(
    "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5"))
)
RAG_RELEVANCE_THRESHOLD = PersistentConfig(
    "RAG_RELEVANCE_THRESHOLD",
    "rag.relevance_threshold",
    float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")),
)

ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
    "ENABLE_RAG_HYBRID_SEARCH",
    "rag.enable_hybrid_search",
    os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
)

ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = PersistentConfig(
    "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION",
    "rag.enable_web_loader_ssl_verification",
    os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true",
)

RAG_EMBEDDING_ENGINE = PersistentConfig(
    "RAG_EMBEDDING_ENGINE",
    "rag.embedding_engine",
    os.environ.get("RAG_EMBEDDING_ENGINE", ""),
)

PDF_EXTRACT_IMAGES = PersistentConfig(
    "PDF_EXTRACT_IMAGES",
    "rag.pdf_extract_images",
    os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true",
)

# This uses the model defined in the Dockerfile ENV variable. If you don't use
# Docker or Docker-based deployments such as k8s, the default embedding model
# is used (sentence-transformers/all-MiniLM-L6-v2).
RAG_EMBEDDING_MODEL = PersistentConfig(
    "RAG_EMBEDDING_MODEL",
    "rag.embedding_model",
    os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"),
)
# No trailing comma: it turned this call into a discarded one-element tuple.
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")

RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
)

RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
    os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
# Env RAG_EMBEDDING_OPENAI_BATCH_SIZE, default 1. os.environ.get returns a str
# when the variable is set, so coerce with int() (as RAG_TOP_K does). This way
# the value is an int regardless of whether the env var is present.
RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig(
    "RAG_EMBEDDING_OPENAI_BATCH_SIZE",
    "rag.embedding_openai_batch_size",
    int(os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")),
)
# Optional reranking model name; empty by default.
RAG_RERANKING_MODEL = PersistentConfig(
    "RAG_RERANKING_MODEL",
    "rag.reranking_model",
    os.environ.get("RAG_RERANKING_MODEL", ""),
)
# Only log when a reranking model name has actually been configured.
if RAG_RERANKING_MODEL.value != "":
    log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")
# Plain booleans (not wrapped in PersistentConfig); opt-in via "true".
RAG_RERANKING_MODEL_AUTO_UPDATE = (
    os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
)
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
    os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
# Choose the Chroma client. With no CHROMA_HTTP_HOST, use an on-disk
# PersistentClient rooted at CHROMA_DATA_PATH. Otherwise, connect to the
# remote server with HttpClient. Both variants share tenant/database and the
# same Settings (reset allowed, anonymized telemetry off).
if CHROMA_HTTP_HOST == "":
    CHROMA_CLIENT = chromadb.PersistentClient(
        path=CHROMA_DATA_PATH,
        tenant=CHROMA_TENANT,
        database=CHROMA_DATABASE,
        settings=Settings(allow_reset=True, anonymized_telemetry=False),
    )
else:
    CHROMA_CLIENT = chromadb.HttpClient(
        host=CHROMA_HTTP_HOST,
        port=CHROMA_HTTP_PORT,
        ssl=CHROMA_HTTP_SSL,
        headers=CHROMA_HTTP_HEADERS,
        tenant=CHROMA_TENANT,
        database=CHROMA_DATABASE,
        settings=Settings(allow_reset=True, anonymized_telemetry=False),
    )
# Device type for embedding models. This code only ever picks "cuda" (when
# USE_CUDA_DOCKER is "true", case-insensitive) or "cpu". It does not detect
# other backends such as Apple "mps". An nvidia GPU is required for "cuda".
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
DEVICE_TYPE = "cuda" if USE_CUDA.lower() == "true" else "cpu"
# Integer chunking settings: CHUNK_SIZE (default 1500) and CHUNK_OVERLAP
# (default 100), persisted under rag.chunk_size / rag.chunk_overlap.
CHUNK_SIZE = PersistentConfig(
    "CHUNK_SIZE",
    "rag.chunk_size",
    int(os.getenv("CHUNK_SIZE", "1500")),
)
CHUNK_OVERLAP = PersistentConfig(
    "CHUNK_OVERLAP", "rag.chunk_overlap", int(os.getenv("CHUNK_OVERLAP", "100"))
)
# Built-in RAG prompt, used as the RAG_TEMPLATE fallback when the RAG_TEMPLATE
# env var is unset. "[context]" and "[query]" look like placeholders that are
# substituted elsewhere — TODO confirm against the template renderer.
# NOTE(review): some wording is ungrammatical ("When answer to user",
# "If you don't know when you are not sure"). It is left as-is because this
# text is the runtime prompt and changing it alters model behavior.
DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
[context]
</context>
When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
Given the context information, answer the query.
Query: [query]"""
# Prompt template persisted under rag.template; env RAG_TEMPLATE overrides the default.
RAG_TEMPLATE = PersistentConfig(
    "RAG_TEMPLATE",
    "rag.template",
    os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE),
)
# OpenAI-compatible base URL and key for RAG. Each falls back to the global
# OPENAI_API_BASE_URL / OPENAI_API_KEY when its own env var is unset.
RAG_OPENAI_API_BASE_URL = PersistentConfig(
    "RAG_OPENAI_API_BASE_URL",
    "rag.openai_api_base_url",
    os.environ.get("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
RAG_OPENAI_API_KEY = PersistentConfig(
    "RAG_OPENAI_API_KEY",
    "rag.openai_api_key",
    os.environ.get("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
)
# Plain boolean (not wrapped in PersistentConfig); off unless set to "true".
ENABLE_RAG_LOCAL_WEB_FETCH = (
    os.environ.get("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
)
# Languages for the YouTube loader, given as a comma-separated env var
# (default "en"). Whitespace around each entry is stripped, so "en, de"
# yields ["en", "de"] instead of ["en", " de"].
YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
    "YOUTUBE_LOADER_LANGUAGE",
    "rag.youtube_loader_language",
    [
        language.strip()
        for language in os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(",")
    ],
)
# Web search is off unless ENABLE_RAG_WEB_SEARCH is "true" (case-insensitive).
ENABLE_RAG_WEB_SEARCH = PersistentConfig(
    "ENABLE_RAG_WEB_SEARCH",
    "rag.web.search.enable",
    os.getenv("ENABLE_RAG_WEB_SEARCH", "False").lower() == "true",
)
RAG_WEB_SEARCH_ENGINE = PersistentConfig(
    "RAG_WEB_SEARCH_ENGINE",
    "rag.web.search.engine",
    os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
)
# Optional list of trusted domains used to filter web-search results. It is
# empty by default; uncomment or add entries to restrict results to those sources.
# NOTE(review): the config path repeats "rag." ("rag.rag.web..."). It is kept
# unchanged because renaming it would likely orphan values already persisted
# under this path. Confirm before migrating.
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
    "RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
    "rag.rag.web.search.domain.filter_list",
    [
        # "wikipedia.com",
        # "wikimedia.org",
        # "wikidata.org",
    ],
)
# Per-provider web-search settings. Every value defaults to an empty string
# except SERPSTACK_HTTPS, a boolean that defaults to True.
SEARXNG_QUERY_URL = PersistentConfig(
    "SEARXNG_QUERY_URL",
    "rag.web.search.searxng_query_url",
    os.environ.get("SEARXNG_QUERY_URL", ""),
)
GOOGLE_PSE_API_KEY = PersistentConfig(
    "GOOGLE_PSE_API_KEY",
    "rag.web.search.google_pse_api_key",
    os.environ.get("GOOGLE_PSE_API_KEY", ""),
)
GOOGLE_PSE_ENGINE_ID = PersistentConfig(
    "GOOGLE_PSE_ENGINE_ID",
    "rag.web.search.google_pse_engine_id",
    os.environ.get("GOOGLE_PSE_ENGINE_ID", ""),
)
BRAVE_SEARCH_API_KEY = PersistentConfig(
    "BRAVE_SEARCH_API_KEY",
    "rag.web.search.brave_search_api_key",
    os.environ.get("BRAVE_SEARCH_API_KEY", ""),
)
SERPSTACK_API_KEY = PersistentConfig(
    "SERPSTACK_API_KEY",
    "rag.web.search.serpstack_api_key",
    os.environ.get("SERPSTACK_API_KEY", ""),
)
# True unless SERPSTACK_HTTPS is set to something other than "true" (any case).
SERPSTACK_HTTPS = PersistentConfig(
    "SERPSTACK_HTTPS",
    "rag.web.search.serpstack_https",
    os.environ.get("SERPSTACK_HTTPS", "True").lower() == "true",
)
SERPER_API_KEY = PersistentConfig(
    "SERPER_API_KEY",
    "rag.web.search.serper_api_key",
    os.environ.get("SERPER_API_KEY", ""),
)
SERPLY_API_KEY = PersistentConfig(
    "SERPLY_API_KEY",
    "rag.web.search.serply_api_key",
    os.environ.get("SERPLY_API_KEY", ""),
)
TAVILY_API_KEY = PersistentConfig(
    "TAVILY_API_KEY",
    "rag.web.search.tavily_api_key",
    os.environ.get("TAVILY_API_KEY", ""),
)
# Integer web-search settings, per their names: result count (default 3) and
# number of concurrent requests (default 10).
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
    "RAG_WEB_SEARCH_RESULT_COUNT",
    "rag.web.search.result_count",
    int(os.environ.get("RAG_WEB_SEARCH_RESULT_COUNT", "3")),
)
RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
    "RAG_WEB_SEARCH_CONCURRENT_REQUESTS",
    "rag.web.search.concurrent_requests",
    int(os.environ.get("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
)
####################################
# Transcribe
####################################
# Whisper settings: model name (default "base"), model directory (default
# under CACHE_DIR), and an opt-in auto-update flag (plain boolean).
WHISPER_MODEL = os.environ.get("WHISPER_MODEL", "base")
WHISPER_MODEL_DIR = os.environ.get(
    "WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models"
)
WHISPER_MODEL_AUTO_UPDATE = (
    os.getenv("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)
####################################
# Images
####################################
# Backend selection plus the global on/off switch (off unless "true").
IMAGE_GENERATION_ENGINE = PersistentConfig(
    "IMAGE_GENERATION_ENGINE",
    "image_generation.engine",
    os.environ.get("IMAGE_GENERATION_ENGINE", ""),
)
ENABLE_IMAGE_GENERATION = PersistentConfig(
    "ENABLE_IMAGE_GENERATION",
    "image_generation.enable",
    os.getenv("ENABLE_IMAGE_GENERATION", "").lower() == "true",
)
# AUTOMATIC1111 connection settings; both default to empty strings.
AUTOMATIC1111_BASE_URL = PersistentConfig(
    "AUTOMATIC1111_BASE_URL",
    "image_generation.automatic1111.base_url",
    os.environ.get("AUTOMATIC1111_BASE_URL", ""),
)
AUTOMATIC1111_API_AUTH = PersistentConfig(
    "AUTOMATIC1111_API_AUTH",
    "image_generation.automatic1111.api_auth",
    os.environ.get("AUTOMATIC1111_API_AUTH", ""),
)
# ComfyUI settings. All are raw strings defaulting to "" except COMFYUI_SD3,
# a boolean enabled only by "true".
# NOTE(review): COMFYUI_CFG_SCALE stays a string here; presumably the consumer
# converts it — confirm.
COMFYUI_BASE_URL = PersistentConfig(
    "COMFYUI_BASE_URL",
    "image_generation.comfyui.base_url",
    os.environ.get("COMFYUI_BASE_URL", ""),
)
COMFYUI_CFG_SCALE = PersistentConfig(
    "COMFYUI_CFG_SCALE",
    "image_generation.comfyui.cfg_scale",
    os.environ.get("COMFYUI_CFG_SCALE", ""),
)
COMFYUI_SAMPLER = PersistentConfig(
    "COMFYUI_SAMPLER",
    "image_generation.comfyui.sampler",
    os.environ.get("COMFYUI_SAMPLER", ""),
)
COMFYUI_SCHEDULER = PersistentConfig(
    "COMFYUI_SCHEDULER",
    "image_generation.comfyui.scheduler",
    os.environ.get("COMFYUI_SCHEDULER", ""),
)
COMFYUI_SD3 = PersistentConfig(
    "COMFYUI_SD3",
    "image_generation.comfyui.sd3",
    os.getenv("COMFYUI_SD3", "").lower() == "true",
)
# OpenAI image settings fall back to the global OPENAI_API_BASE_URL /
# OPENAI_API_KEY when their own env vars are unset.
IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
    "IMAGES_OPENAI_API_BASE_URL",
    "image_generation.openai.api_base_url",
    os.environ.get("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
IMAGES_OPENAI_API_KEY = PersistentConfig(
    "IMAGES_OPENAI_API_KEY",
    "image_generation.openai.api_key",
    os.environ.get("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
)
# Generation defaults: size string "512x512", integer step count 50, and an
# optional model name (empty by default).
IMAGE_SIZE = PersistentConfig(
    "IMAGE_SIZE",
    "image_generation.size",
    os.environ.get("IMAGE_SIZE", "512x512"),
)
IMAGE_STEPS = PersistentConfig(
    "IMAGE_STEPS",
    "image_generation.steps",
    int(os.environ.get("IMAGE_STEPS", "50")),
)
IMAGE_GENERATION_MODEL = PersistentConfig(
    "IMAGE_GENERATION_MODEL",
    "image_generation.model",
    os.environ.get("IMAGE_GENERATION_MODEL", ""),
)
####################################
# Audio
####################################
# Speech-to-text. The OpenAI base URL and key fall back to the global
# OPENAI_API_BASE_URL / OPENAI_API_KEY. The engine defaults to "" and the
# model defaults to "whisper-1".
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_STT_OPENAI_API_BASE_URL",
    "audio.stt.openai.api_base_url",
    os.environ.get("AUDIO_STT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_STT_OPENAI_API_KEY = PersistentConfig(
    "AUDIO_STT_OPENAI_API_KEY",
    "audio.stt.openai.api_key",
    os.environ.get("AUDIO_STT_OPENAI_API_KEY", OPENAI_API_KEY),
)
AUDIO_STT_ENGINE = PersistentConfig(
    "AUDIO_STT_ENGINE",
    "audio.stt.engine",
    os.environ.get("AUDIO_STT_ENGINE", ""),
)
AUDIO_STT_MODEL = PersistentConfig(
    "AUDIO_STT_MODEL",
    "audio.stt.model",
    os.environ.get("AUDIO_STT_MODEL", "whisper-1"),
)
# Text-to-speech. The OpenAI base URL and key fall back to the global
# OPENAI_API_BASE_URL / OPENAI_API_KEY. AUDIO_TTS_API_KEY is a separate key
# defaulting to "". Model and voice default to "tts-1" and "alloy".
AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_TTS_OPENAI_API_BASE_URL",
    "audio.tts.openai.api_base_url",
    os.environ.get("AUDIO_TTS_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_TTS_OPENAI_API_KEY = PersistentConfig(
    "AUDIO_TTS_OPENAI_API_KEY",
    "audio.tts.openai.api_key",
    os.environ.get("AUDIO_TTS_OPENAI_API_KEY", OPENAI_API_KEY),
)
AUDIO_TTS_API_KEY = PersistentConfig(
    "AUDIO_TTS_API_KEY",
    "audio.tts.api_key",
    os.environ.get("AUDIO_TTS_API_KEY", ""),
)
AUDIO_TTS_ENGINE = PersistentConfig(
    "AUDIO_TTS_ENGINE",
    "audio.tts.engine",
    os.environ.get("AUDIO_TTS_ENGINE", ""),
)
AUDIO_TTS_MODEL = PersistentConfig(
    "AUDIO_TTS_MODEL",
    "audio.tts.model",
    os.environ.get("AUDIO_TTS_MODEL", "tts-1"),
)
AUDIO_TTS_VOICE = PersistentConfig(
    "AUDIO_TTS_VOICE",
    "audio.tts.voice",
    os.environ.get("AUDIO_TTS_VOICE", "alloy"),
)
####################################
# Database
####################################
# Database URL; defaults to a SQLite file inside DATA_DIR.
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
# Normalize the legacy "postgres://" scheme to "postgresql://". Presumably this
# is for SQLAlchemy, which (since 1.4) no longer accepts the "postgres" alias.
# Only a leading scheme is rewritten. A "postgres://" substring elsewhere in
# the URL (e.g. inside a password or query parameter) is left untouched.
if DATABASE_URL.startswith("postgres://"):
    DATABASE_URL = "postgresql://" + DATABASE_URL[len("postgres://"):]