Spaces:
Paused
Paused
| import json | |
| import logging | |
| import os | |
| import shutil | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Generic, Optional, TypeVar | |
| from urllib.parse import urlparse | |
| import chromadb | |
| import requests | |
| import yaml | |
| from open_webui.apps.webui.internal.db import Base, get_db | |
| from open_webui.env import ( | |
| OPEN_WEBUI_DIR, | |
| DATA_DIR, | |
| ENV, | |
| FRONTEND_BUILD_DIR, | |
| WEBUI_AUTH, | |
| WEBUI_FAVICON_URL, | |
| WEBUI_NAME, | |
| log, | |
| DATABASE_URL, | |
| ) | |
| from pydantic import BaseModel | |
| from sqlalchemy import JSON, Column, DateTime, Integer, func | |
class EndpointFilter(logging.Filter):
    """Logging filter that rejects records whose message mentions ``/health``."""

    def filter(self, record: logging.LogRecord) -> bool:
        """Return ``False`` when the formatted message contains "/health"."""
        message = record.getMessage()
        return "/health" not in message


# Keep /health requests out of the uvicorn access log
logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
####################################
# Config helpers
####################################


# Function to run the alembic migrations
def run_migrations():
    """Upgrade the database schema to the latest Alembic revision ("head").

    The Alembic configuration is loaded from ``OPEN_WEBUI_DIR / "alembic.ini"``
    and its ``script_location`` is pointed at ``OPEN_WEBUI_DIR / "migrations"``.

    Failures are logged together with their traceback and then swallowed, so
    a migration problem never makes importing this module raise.
    """
    log.info("Running migrations")
    try:
        # Imported lazily, and aliased so the Alembic class cannot be confused
        # with the SQLAlchemy ``Config`` model declared in this module.
        from alembic import command
        from alembic.config import Config as AlembicConfig

        alembic_cfg = AlembicConfig(OPEN_WEBUI_DIR / "alembic.ini")

        # Set the script location dynamically
        migrations_path = OPEN_WEBUI_DIR / "migrations"
        alembic_cfg.set_main_option("script_location", str(migrations_path))

        command.upgrade(alembic_cfg, "head")
    except Exception as e:
        # Deliberately best-effort: record the full traceback and continue.
        log.exception(f"Error running migrations: {e}")


run_migrations()
class Config(Base):
    """ORM model for the ``config`` table, which stores configuration as JSON."""

    __tablename__ = "config"

    id = Column(Integer, primary_key=True)
    # JSON document holding the configuration data
    data = Column(JSON, nullable=False)
    version = Column(Integer, nullable=False, default=0)
    # Defaulted by the database server (``func.now()``) on insert
    created_at = Column(DateTime, nullable=False, server_default=func.now())
    # Nullable; refreshed through ``onupdate`` when the row is updated
    updated_at = Column(DateTime, nullable=True, onupdate=func.now())
def load_json_config():
    """Parse and return the contents of ``{DATA_DIR}/config.json``.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's default text encoding.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file does not contain valid JSON.
    """
    with open(f"{DATA_DIR}/config.json", "r", encoding="utf-8") as file:
        return json.load(file)
def save_to_db(data):
    """Persist ``data`` as the stored configuration.

    The newest ``Config`` row (highest ``id``) is updated in place, which is
    the same row ``get_config`` reads. A first row is inserted when the
    table is empty.

    Args:
        data: configuration to store in the JSON ``data`` column.
    """
    with get_db() as db:
        # Select the row exactly as get_config() does, so that a save is
        # always visible to the next load even if several rows exist.
        existing_config = db.query(Config).order_by(Config.id.desc()).first()
        if not existing_config:
            new_config = Config(data=data, version=0)
            db.add(new_config)
        else:
            existing_config.data = data
            existing_config.updated_at = datetime.now()
            db.add(existing_config)
        db.commit()
def reset_config():
    """Delete every row of the ``config`` table and commit the change."""
    with get_db() as db:
        db.query(Config).delete()
        db.commit()
# When initializing, check if config.json exists and migrate it to the database.
# After a successful save the file is renamed to old_config.json, so the
# migration is not repeated on the next import.
if os.path.exists(f"{DATA_DIR}/config.json"):
    data = load_json_config()
    save_to_db(data)
    os.rename(f"{DATA_DIR}/config.json", f"{DATA_DIR}/old_config.json")
# Fallback configuration used by get_config() when the database holds no
# config row. Each prompt suggestion has a two-part "title" and the "content"
# text that goes with it.
DEFAULT_CONFIG = {
    "version": 0,
    "ui": {
        "default_locale": "",
        "prompt_suggestions": [
            {
                "title": [
                    "Help me study",
                    "vocabulary for a college entrance exam",
                ],
                "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
            },
            {
                "title": [
                    "Give me ideas",
                    "for what to do with my kids' art",
                ],
                "content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter.",
            },
            {
                "title": ["Tell me a fun fact", "about the Roman Empire"],
                "content": "Tell me a random fun fact about the Roman Empire",
            },
            {
                "title": [
                    "Show me a code snippet",
                    "of a website's sticky header",
                ],
                "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.",
            },
            {
                "title": [
                    "Explain options trading",
                    "if I'm familiar with buying and selling stocks",
                ],
                "content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks.",
            },
            {
                "title": ["Overcome procrastination", "give me tips"],
                "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
            },
            {
                "title": [
                    "Grammar check",
                    "rewrite it for better readability ",
                ],
                "content": 'Check the following sentence for grammar and clarity: "[sentence]". Rewrite it for better readability while maintaining its original meaning.',
            },
        ],
    },
}
def get_config():
    """Return the configuration data of the newest ``Config`` row.

    When the table is empty, a deep copy of ``DEFAULT_CONFIG`` is returned
    instead. A copy is used because the result becomes ``CONFIG_DATA``, which
    ``PersistentConfig.save`` mutates in place. Returning the constant itself
    would let saved values leak into the defaults.
    """
    import copy  # local import: only needed for the empty-table fallback

    with get_db() as db:
        config_entry = db.query(Config).order_by(Config.id.desc()).first()
        return config_entry.data if config_entry else copy.deepcopy(DEFAULT_CONFIG)


CONFIG_DATA = get_config()
def get_config_value(config_path: str):
    """Look up a dotted path (e.g. ``"ui.default_locale"``) in ``CONFIG_DATA``.

    Args:
        config_path: keys joined by ".", walked from the top of ``CONFIG_DATA``.

    Returns:
        The value stored at the path, or ``None`` when any key is missing or
        the path tries to descend into something that is not a dict.
    """
    node = CONFIG_DATA
    for key in config_path.split("."):
        # Guard against walking into a leaf value (str, int, list, ...):
        # `key in "text"` is a substring test and `key in 0` raises.
        if not isinstance(node, dict) or key not in node:
            return None
        node = node[key]
    return node
PERSISTENT_CONFIG_REGISTRY = []


def save_config(config):
    """Store ``config`` in the database and make it the live ``CONFIG_DATA``.

    Every entry in ``PERSISTENT_CONFIG_REGISTRY`` is then asked to refresh
    itself via ``update()``.

    Returns:
        ``True`` on success; ``False`` if any step raised (the exception is
        logged, not propagated).
    """
    global CONFIG_DATA

    try:
        save_to_db(config)
        CONFIG_DATA = config

        # Let each registered PersistentConfig pick up its new value
        for entry in PERSISTENT_CONFIG_REGISTRY:
            entry.update()
    except Exception as err:
        log.exception(err)
        return False
    else:
        return True
T = TypeVar("T")


class PersistentConfig(Generic[T]):
    """A setting whose environment-derived value can be overridden by the
    stored configuration.

    On construction the value found at ``config_path`` in the stored
    configuration wins over ``env_value`` when it is not ``None``. Every
    instance adds itself to ``PERSISTENT_CONFIG_REGISTRY``.
    """

    def __init__(self, env_name: str, config_path: str, env_value: T):
        self.env_name = env_name
        self.config_path = config_path
        self.env_value = env_value

        stored = get_config_value(config_path)
        self.config_value = stored
        if stored is None:
            self.value = env_value
        else:
            log.info(f"'{env_name}' loaded from the latest database entry")
            self.value = stored

        PERSISTENT_CONFIG_REGISTRY.append(self)

    def __str__(self):
        """Return the string form of the current value."""
        return str(self.value)

    def __dict__(self):
        raise TypeError(
            "PersistentConfig object cannot be converted to dict, use config_get or .value instead."
        )

    def __getattribute__(self, item):
        """Behave like normal attribute access, except that reading
        ``__dict__`` raises ``TypeError``."""
        if item != "__dict__":
            return super().__getattribute__(item)
        raise TypeError(
            "PersistentConfig object cannot be converted to dict, use config_get or .value instead."
        )

    def update(self):
        """Re-read the stored value; keep the current one if nothing is stored."""
        refreshed = get_config_value(self.config_path)
        if refreshed is None:
            return
        self.value = refreshed
        log.info(f"Updated {self.env_name} to new value {self.value}")

    def save(self):
        """Write the current value into ``CONFIG_DATA`` at ``config_path``
        (creating missing intermediate dicts) and persist ``CONFIG_DATA``."""
        log.info(f"Saving '{self.env_name}' to the database")
        *parents, leaf = self.config_path.split(".")
        node = CONFIG_DATA
        for key in parents:
            if key not in node:
                node[key] = {}
            node = node[key]
        node[leaf] = self.value
        save_to_db(CONFIG_DATA)
        self.config_value = self.value
class AppConfig:
    """Attribute-style facade over a set of ``PersistentConfig`` entries.

    * Assigning a ``PersistentConfig`` registers it under the attribute name.
    * Assigning any other value updates the registered entry's ``value`` and
      calls its ``save()``.
    * Reading an attribute returns the registered entry's ``value``.
    """

    _state: dict[str, "PersistentConfig"]

    def __init__(self):
        # Bypass our own __setattr__, which only handles config entries.
        super().__setattr__("_state", {})

    def __setattr__(self, key, value):
        if isinstance(value, PersistentConfig):
            self._state[key] = value
        else:
            self._state[key].value = value
            self._state[key].save()

    def __getattr__(self, key):
        """Return the value of the entry registered as ``key``.

        Raises:
            AttributeError: if no entry is registered under ``key``. Raising
                ``AttributeError`` (not ``KeyError``) is what ``hasattr`` and
                ``getattr(obj, name, default)`` rely on.
        """
        # Read _state from the instance dict directly: going through
        # ``self._state`` would re-enter __getattr__ and recurse forever on
        # an instance whose __init__ has not run.
        state = self.__dict__.get("_state", {})
        if key not in state:
            raise AttributeError(f"Config key '{key}' not found")
        return state[key].value
####################################
# WEBUI_AUTH (Required for security)
####################################

# Boolean: true only when the environment value equals "true" (case-insensitive).
# Config path: auth.api_key.enable
ENABLE_API_KEY = PersistentConfig(
    "ENABLE_API_KEY",
    "auth.api_key.enable",
    os.environ.get("ENABLE_API_KEY", "True").lower() == "true",
)

# Kept as the raw environment string (default "-1"). Config path: auth.jwt_expiry
JWT_EXPIRES_IN = PersistentConfig(
    "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
)
####################################
# OAuth config
####################################

# Boolean flags: true only when the environment value equals "true"
# (case-insensitive); both default to False.
ENABLE_OAUTH_SIGNUP = PersistentConfig(
    "ENABLE_OAUTH_SIGNUP",
    "oauth.enable_signup",
    os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true",
)

OAUTH_MERGE_ACCOUNTS_BY_EMAIL = PersistentConfig(
    "OAUTH_MERGE_ACCOUNTS_BY_EMAIL",
    "oauth.merge_accounts_by_email",
    os.environ.get("OAUTH_MERGE_ACCOUNTS_BY_EMAIL", "False").lower() == "true",
)

# Filled by load_oauth_providers() from the settings below
OAUTH_PROVIDERS = {}

# --- Google (config paths under oauth.google.*) ---
GOOGLE_CLIENT_ID = PersistentConfig(
    "GOOGLE_CLIENT_ID",
    "oauth.google.client_id",
    os.environ.get("GOOGLE_CLIENT_ID", ""),
)

GOOGLE_CLIENT_SECRET = PersistentConfig(
    "GOOGLE_CLIENT_SECRET",
    "oauth.google.client_secret",
    os.environ.get("GOOGLE_CLIENT_SECRET", ""),
)

GOOGLE_OAUTH_SCOPE = PersistentConfig(
    "GOOGLE_OAUTH_SCOPE",
    "oauth.google.scope",
    os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"),
)

GOOGLE_REDIRECT_URI = PersistentConfig(
    "GOOGLE_REDIRECT_URI",
    "oauth.google.redirect_uri",
    os.environ.get("GOOGLE_REDIRECT_URI", ""),
)

# --- Microsoft (config paths under oauth.microsoft.*) ---
MICROSOFT_CLIENT_ID = PersistentConfig(
    "MICROSOFT_CLIENT_ID",
    "oauth.microsoft.client_id",
    os.environ.get("MICROSOFT_CLIENT_ID", ""),
)

MICROSOFT_CLIENT_SECRET = PersistentConfig(
    "MICROSOFT_CLIENT_SECRET",
    "oauth.microsoft.client_secret",
    os.environ.get("MICROSOFT_CLIENT_SECRET", ""),
)

MICROSOFT_CLIENT_TENANT_ID = PersistentConfig(
    "MICROSOFT_CLIENT_TENANT_ID",
    "oauth.microsoft.tenant_id",
    os.environ.get("MICROSOFT_CLIENT_TENANT_ID", ""),
)

MICROSOFT_OAUTH_SCOPE = PersistentConfig(
    "MICROSOFT_OAUTH_SCOPE",
    "oauth.microsoft.scope",
    os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"),
)

MICROSOFT_REDIRECT_URI = PersistentConfig(
    "MICROSOFT_REDIRECT_URI",
    "oauth.microsoft.redirect_uri",
    os.environ.get("MICROSOFT_REDIRECT_URI", ""),
)

# --- Generic OpenID Connect provider (config paths under oauth.oidc.*) ---
OAUTH_CLIENT_ID = PersistentConfig(
    "OAUTH_CLIENT_ID",
    "oauth.oidc.client_id",
    os.environ.get("OAUTH_CLIENT_ID", ""),
)

OAUTH_CLIENT_SECRET = PersistentConfig(
    "OAUTH_CLIENT_SECRET",
    "oauth.oidc.client_secret",
    os.environ.get("OAUTH_CLIENT_SECRET", ""),
)

# Used as the provider's "server_metadata_url" in load_oauth_providers()
OPENID_PROVIDER_URL = PersistentConfig(
    "OPENID_PROVIDER_URL",
    "oauth.oidc.provider_url",
    os.environ.get("OPENID_PROVIDER_URL", ""),
)

OPENID_REDIRECT_URI = PersistentConfig(
    "OPENID_REDIRECT_URI",
    "oauth.oidc.redirect_uri",
    os.environ.get("OPENID_REDIRECT_URI", ""),
)

OAUTH_SCOPES = PersistentConfig(
    "OAUTH_SCOPES",
    "oauth.oidc.scopes",
    os.environ.get("OAUTH_SCOPES", "openid email profile"),
)

OAUTH_PROVIDER_NAME = PersistentConfig(
    "OAUTH_PROVIDER_NAME",
    "oauth.oidc.provider_name",
    os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
)

# Claim names (defaults: "name", "picture", "email")
OAUTH_USERNAME_CLAIM = PersistentConfig(
    "OAUTH_USERNAME_CLAIM",
    "oauth.oidc.username_claim",
    os.environ.get("OAUTH_USERNAME_CLAIM", "name"),
)

# NOTE: the env name says PICTURE while the config path says avatar_claim
OAUTH_PICTURE_CLAIM = PersistentConfig(
    "OAUTH_PICTURE_CLAIM",
    "oauth.oidc.avatar_claim",
    os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
)

OAUTH_EMAIL_CLAIM = PersistentConfig(
    "OAUTH_EMAIL_CLAIM",
    "oauth.oidc.email_claim",
    os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
)

# --- Role settings ---
ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig(
    "ENABLE_OAUTH_ROLE_MANAGEMENT",
    "oauth.enable_role_mapping",
    os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true",
)

OAUTH_ROLES_CLAIM = PersistentConfig(
    "OAUTH_ROLES_CLAIM",
    "oauth.roles_claim",
    os.environ.get("OAUTH_ROLES_CLAIM", "roles"),
)

# Comma-separated environment values, split into lists of stripped names
OAUTH_ALLOWED_ROLES = PersistentConfig(
    "OAUTH_ALLOWED_ROLES",
    "oauth.allowed_roles",
    [
        role.strip()
        for role in os.environ.get("OAUTH_ALLOWED_ROLES", "user,admin").split(",")
    ],
)

OAUTH_ADMIN_ROLES = PersistentConfig(
    "OAUTH_ADMIN_ROLES",
    "oauth.admin_roles",
    [role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
)
def load_oauth_providers():
    """Rebuild ``OAUTH_PROVIDERS`` from the current OAuth settings.

    The dict is emptied first. A provider entry ("google", "microsoft",
    "oidc") is added only when all of the settings it requires are truthy.
    """
    OAUTH_PROVIDERS.clear()

    google_ready = GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value
    if google_ready:
        OAUTH_PROVIDERS["google"] = dict(
            client_id=GOOGLE_CLIENT_ID.value,
            client_secret=GOOGLE_CLIENT_SECRET.value,
            server_metadata_url="https://accounts.google.com/.well-known/openid-configuration",
            scope=GOOGLE_OAUTH_SCOPE.value,
            redirect_uri=GOOGLE_REDIRECT_URI.value,
        )

    microsoft_ready = (
        MICROSOFT_CLIENT_ID.value
        and MICROSOFT_CLIENT_SECRET.value
        and MICROSOFT_CLIENT_TENANT_ID.value
    )
    if microsoft_ready:
        # The tenant id is part of the metadata URL
        OAUTH_PROVIDERS["microsoft"] = dict(
            client_id=MICROSOFT_CLIENT_ID.value,
            client_secret=MICROSOFT_CLIENT_SECRET.value,
            server_metadata_url=f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration",
            scope=MICROSOFT_OAUTH_SCOPE.value,
            redirect_uri=MICROSOFT_REDIRECT_URI.value,
        )

    oidc_ready = (
        OAUTH_CLIENT_ID.value
        and OAUTH_CLIENT_SECRET.value
        and OPENID_PROVIDER_URL.value
    )
    if oidc_ready:
        OAUTH_PROVIDERS["oidc"] = dict(
            client_id=OAUTH_CLIENT_ID.value,
            client_secret=OAUTH_CLIENT_SECRET.value,
            server_metadata_url=OPENID_PROVIDER_URL.value,
            scope=OAUTH_SCOPES.value,
            name=OAUTH_PROVIDER_NAME.value,
            redirect_uri=OPENID_REDIRECT_URI.value,
        )


load_oauth_providers()
####################################
# Static DIR
####################################

# STATIC_DIR env var wins; otherwise OPEN_WEBUI_DIR/static. Resolved to an absolute path.
STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")).resolve()

# Copy the frontend build's favicon into STATIC_DIR. A failed copy is
# logged as an error, a missing source file as a warning; neither raises.
frontend_favicon = FRONTEND_BUILD_DIR / "static" / "favicon.png"

if frontend_favicon.exists():
    try:
        shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
    except Exception as e:
        logging.error(f"An error occurred: {e}")
else:
    logging.warning(f"Frontend favicon not found at {frontend_favicon}")

# Same procedure for the splash image
frontend_splash = FRONTEND_BUILD_DIR / "static" / "splash.png"

if frontend_splash.exists():
    try:
        shutil.copyfile(frontend_splash, STATIC_DIR / "splash.png")
    except Exception as e:
        logging.error(f"An error occurred: {e}")
else:
    logging.warning(f"Frontend splash not found at {frontend_splash}")
####################################
# CUSTOM_NAME
####################################

CUSTOM_NAME = os.environ.get("CUSTOM_NAME", "")

if CUSTOM_NAME:
    # Best-effort branding lookup performed at import time. Every request
    # carries a timeout so an unreachable server cannot hang start-up, and any
    # failure is logged and otherwise ignored.
    try:
        r = requests.get(
            f"https://api.openwebui.com/api/v1/custom/{CUSTOM_NAME}", timeout=10
        )
        data = r.json()
        if r.ok:
            if "logo" in data:
                # Paths starting with "/" are relative to api.openwebui.com
                WEBUI_FAVICON_URL = url = (
                    f"https://api.openwebui.com{data['logo']}"
                    if data["logo"][0] == "/"
                    else data["logo"]
                )

                # The context manager releases the streamed connection
                with requests.get(url, stream=True, timeout=10) as r:
                    if r.status_code == 200:
                        with open(f"{STATIC_DIR}/favicon.png", "wb") as f:
                            r.raw.decode_content = True
                            shutil.copyfileobj(r.raw, f)

            if "splash" in data:
                url = (
                    f"https://api.openwebui.com{data['splash']}"
                    if data["splash"][0] == "/"
                    else data["splash"]
                )

                with requests.get(url, stream=True, timeout=10) as r:
                    if r.status_code == 200:
                        with open(f"{STATIC_DIR}/splash.png", "wb") as f:
                            r.raw.decode_content = True
                            shutil.copyfileobj(r.raw, f)

            WEBUI_NAME = data["name"]
    except Exception as e:
        log.exception(e)
####################################
# STORAGE PROVIDER
####################################

STORAGE_PROVIDER = os.environ.get("STORAGE_PROVIDER", "")  # defaults to local, s3

# S3 settings; each is None when its environment variable is unset
S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID", None)
S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY", None)
S3_REGION_NAME = os.environ.get("S3_REGION_NAME", None)
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None)
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None)

####################################
# File Upload DIR
####################################

# Created at import time (including parents) if it does not exist yet
UPLOAD_DIR = f"{DATA_DIR}/uploads"
Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)

####################################
# Cache DIR
####################################

# Created at import time (including parents) if it does not exist yet
CACHE_DIR = f"{DATA_DIR}/cache"
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
####################################
# OLLAMA_BASE_URL
####################################

ENABLE_OLLAMA_API = PersistentConfig(
    "ENABLE_OLLAMA_API",
    "ollama.enable",
    os.environ.get("ENABLE_OLLAMA_API", "True").lower() == "true",
)

OLLAMA_API_BASE_URL = os.environ.get(
    "OLLAMA_API_BASE_URL", "http://localhost:11434/api"
)

OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
K8S_FLAG = os.environ.get("K8S_FLAG", "")
USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false")

# When OLLAMA_BASE_URL is not set, derive it from OLLAMA_API_BASE_URL by
# dropping a trailing "/api" (4 characters).
if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "":
    OLLAMA_BASE_URL = (
        OLLAMA_API_BASE_URL[:-4]
        if OLLAMA_API_BASE_URL.endswith("/api")
        else OLLAMA_API_BASE_URL
    )

# In the "prod" environment the URL is rewritten: the "/ollama" placeholder
# (without K8S_FLAG) becomes a localhost or host.docker.internal address;
# otherwise, when K8S_FLAG is set, the in-cluster service address is used.
if ENV == "prod":
    if OLLAMA_BASE_URL == "/ollama" and not K8S_FLAG:
        if USE_OLLAMA_DOCKER.lower() == "true":
            # if you use all-in-one docker container (Open WebUI + Ollama)
            # with the docker build arg USE_OLLAMA=true (--build-arg="USE_OLLAMA=true") this only works with http://localhost:11434
            OLLAMA_BASE_URL = "http://localhost:11434"
        else:
            OLLAMA_BASE_URL = "http://host.docker.internal:11434"
    elif K8S_FLAG:
        OLLAMA_BASE_URL = "http://ollama-service.open-webui.svc.cluster.local:11434"

# Semicolon-separated list of base URLs; falls back to the single
# OLLAMA_BASE_URL computed above. Each entry is stripped of whitespace.
OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")
OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL

OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
OLLAMA_BASE_URLS = PersistentConfig(
    "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
)

# No environment source: starts as an empty dict unless a stored value exists
OLLAMA_API_CONFIGS = PersistentConfig(
    "OLLAMA_API_CONFIGS",
    "ollama.api_configs",
    {},
)
####################################
# OPENAI_API
####################################

ENABLE_OPENAI_API = PersistentConfig(
    "ENABLE_OPENAI_API",
    "openai.enable",
    os.environ.get("ENABLE_OPENAI_API", "True").lower() == "true",
)

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")

if OPENAI_API_BASE_URL == "":
    OPENAI_API_BASE_URL = "https://api.openai.com/v1"

# Semicolon-separated list of keys; falls back to the single OPENAI_API_KEY
OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "")
OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY

OPENAI_API_KEYS = [key.strip() for key in OPENAI_API_KEYS.split(";")]
OPENAI_API_KEYS = PersistentConfig(
    "OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS
)

# Semicolon-separated list of base URLs; falls back to OPENAI_API_BASE_URL
OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "")
OPENAI_API_BASE_URLS = (
    OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL
)

# Empty AND whitespace-only entries (e.g. "https://a; ;https://b") fall back
# to the official endpoint; testing before stripping used to turn " " into "".
OPENAI_API_BASE_URLS = [
    url.strip() or "https://api.openai.com/v1"
    for url in OPENAI_API_BASE_URLS.split(";")
]
OPENAI_API_BASE_URLS = PersistentConfig(
    "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
)

# No environment source: starts as an empty dict unless a stored value exists
OPENAI_API_CONFIGS = PersistentConfig(
    "OPENAI_API_CONFIGS",
    "openai.api_configs",
    {},
)

# Get the actual OpenAI API key based on the base URL
OPENAI_API_KEY = ""

try:
    OPENAI_API_KEY = OPENAI_API_KEYS.value[
        OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
    ]
except Exception:
    # Deliberately best-effort: when the official endpoint is not configured
    # (or has no matching key) the key simply stays empty.
    pass
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
####################################
# WEBUI
####################################

# Forced to False when WEBUI_AUTH is falsy; otherwise read from the environment
ENABLE_SIGNUP = PersistentConfig(
    "ENABLE_SIGNUP",
    "ui.enable_signup",
    (
        False
        if not WEBUI_AUTH
        else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true"
    ),
)

# NOTE: unlike its neighbours, this config path uses an upper-case key
ENABLE_LOGIN_FORM = PersistentConfig(
    "ENABLE_LOGIN_FORM",
    "ui.ENABLE_LOGIN_FORM",
    os.environ.get("ENABLE_LOGIN_FORM", "True").lower() == "true",
)

DEFAULT_LOCALE = PersistentConfig(
    "DEFAULT_LOCALE",
    "ui.default_locale",
    os.environ.get("DEFAULT_LOCALE", ""),
)

# None when the environment variable is unset
DEFAULT_MODELS = PersistentConfig(
    "DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None)
)

# Built-in suggestions, used unless a value is stored at ui.prompt_suggestions
DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
    "DEFAULT_PROMPT_SUGGESTIONS",
    "ui.prompt_suggestions",
    [
        {
            "title": ["Help me study", "vocabulary for a college entrance exam"],
            "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
        },
        {
            "title": ["Give me ideas", "for what to do with my kids' art"],
            "content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter.",
        },
        {
            "title": ["Tell me a fun fact", "about the Roman Empire"],
            "content": "Tell me a random fun fact about the Roman Empire",
        },
        {
            "title": ["Show me a code snippet", "of a website's sticky header"],
            "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.",
        },
        {
            "title": [
                "Explain options trading",
                "if I'm familiar with buying and selling stocks",
            ],
            "content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks.",
        },
        {
            "title": ["Overcome procrastination", "give me tips"],
            "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
        },
    ],
)

DEFAULT_USER_ROLE = PersistentConfig(
    "DEFAULT_USER_ROLE",
    "ui.default_user_role",
    os.getenv("DEFAULT_USER_ROLE", "pending"),
)
# Plain booleans read once from the environment; they only seed the default
# value of USER_PERMISSIONS below. Workspace flags default to False.
USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS = (
    os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS", "False").lower()
    == "true"
)

USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS = (
    os.environ.get("USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS", "False").lower()
    == "true"
)

USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS = (
    os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS", "False").lower()
    == "true"
)

USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS = (
    os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS", "False").lower() == "true"
)

# Chat flags default to True
USER_PERMISSIONS_CHAT_FILE_UPLOAD = (
    os.environ.get("USER_PERMISSIONS_CHAT_FILE_UPLOAD", "True").lower() == "true"
)

USER_PERMISSIONS_CHAT_DELETE = (
    os.environ.get("USER_PERMISSIONS_CHAT_DELETE", "True").lower() == "true"
)

USER_PERMISSIONS_CHAT_EDIT = (
    os.environ.get("USER_PERMISSIONS_CHAT_EDIT", "True").lower() == "true"
)

USER_PERMISSIONS_CHAT_TEMPORARY = (
    os.environ.get("USER_PERMISSIONS_CHAT_TEMPORARY", "True").lower() == "true"
)

# Nested permission map; config path: user.permissions
USER_PERMISSIONS = PersistentConfig(
    "USER_PERMISSIONS",
    "user.permissions",
    {
        "workspace": {
            "models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS,
            "knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS,
            "prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS,
            "tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS,
        },
        "chat": {
            "file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD,
            "delete": USER_PERMISSIONS_CHAT_DELETE,
            "edit": USER_PERMISSIONS_CHAT_EDIT,
            "temporary": USER_PERMISSIONS_CHAT_TEMPORARY,
        },
    },
)
ENABLE_EVALUATION_ARENA_MODELS = PersistentConfig(
    "ENABLE_EVALUATION_ARENA_MODELS",
    "evaluation.arena.enable",
    os.environ.get("ENABLE_EVALUATION_ARENA_MODELS", "True").lower() == "true",
)
# No environment source: starts as an empty list unless a stored value exists
EVALUATION_ARENA_MODELS = PersistentConfig(
    "EVALUATION_ARENA_MODELS",
    "evaluation.arena.models",
    [],
)

# Plain dict constant (not a PersistentConfig)
DEFAULT_ARENA_MODEL = {
    "id": "arena-model",
    "name": "Arena Model",
    "meta": {
        "profile_image_url": "/favicon.png",
        "description": "Submit your questions to anonymous AI chatbots and vote on the best response.",
        "model_ids": None,
    },
}

# NOTE: top-level config path (no section prefix)
WEBHOOK_URL = PersistentConfig(
    "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "")
)

# Environment-only booleans (not PersistentConfig, so never read from the database)
ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"

ENABLE_ADMIN_CHAT_ACCESS = (
    os.environ.get("ENABLE_ADMIN_CHAT_ACCESS", "True").lower() == "true"
)

ENABLE_COMMUNITY_SHARING = PersistentConfig(
    "ENABLE_COMMUNITY_SHARING",
    "ui.enable_community_sharing",
    os.environ.get("ENABLE_COMMUNITY_SHARING", "True").lower() == "true",
)

ENABLE_MESSAGE_RATING = PersistentConfig(
    "ENABLE_MESSAGE_RATING",
    "ui.enable_message_rating",
    os.environ.get("ENABLE_MESSAGE_RATING", "True").lower() == "true",
)
def validate_cors_origins(origins):
    """Validate every entry of ``origins`` except the wildcard ``"*"``.

    Raises:
        ValueError: propagated from ``validate_cors_origin`` for the first
            entry that is not acceptable.
    """
    for candidate in origins:
        if candidate == "*":
            continue
        validate_cors_origin(candidate)
def validate_cors_origin(origin):
    """Check that ``origin`` is an http(s) URL with a network location.

    Raises:
        ValueError: if the scheme is not "http"/"https", or if the URL has no
            netloc (host[:port]) part.
    """
    parts = urlparse(origin)

    # Only plain web schemes are accepted
    scheme_allowed = parts.scheme in ["http", "https"]
    if not scheme_allowed:
        raise ValueError(
            f"Invalid scheme in CORS_ALLOW_ORIGIN: '{origin}'. Only 'http' and 'https' are allowed."
        )

    # A URL without a netloc (domain + port) has no host to allow
    if parts.netloc:
        return
    raise ValueError(f"Invalid URL structure in CORS_ALLOW_ORIGIN: '{origin}'.")
# For production, you should only need one host as
# fastapi serves the svelte-kit built frontend and backend from the same host and port.
# To test CORS_ALLOW_ORIGIN locally, you can set something like
# CORS_ALLOW_ORIGIN=http://localhost:5173;http://localhost:8080
# in your .env file depending on your frontend port, 5173 in this case.
# The value is split on ";" into a list; the default is the wildcard ["*"].
CORS_ALLOW_ORIGIN = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";")

if "*" in CORS_ALLOW_ORIGIN:
    log.warning(
        "\n\nWARNING: CORS_ALLOW_ORIGIN IS SET TO '*' - NOT RECOMMENDED FOR PRODUCTION DEPLOYMENTS.\n"
    )

# Raises ValueError at import time if any non-wildcard entry is invalid
validate_cors_origins(CORS_ALLOW_ORIGIN)
class BannerModel(BaseModel):
    """Pydantic model for one entry of the ``WEBUI_BANNERS`` JSON list."""

    id: str
    type: str
    # The only optional field; defaults to None
    title: Optional[str] = None
    content: str
    dismissible: bool
    timestamp: int
# Parse WEBUI_BANNERS (a JSON list of banner objects) into BannerModel
# instances. On any parsing or validation error the problem is printed and
# the banner list falls back to empty.
try:
    banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
    banners = [BannerModel(**banner) for banner in banners]
except Exception as e:
    print(f"Error loading WEBUI_BANNERS: {e}")
    banners = []

WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners)

SHOW_ADMIN_DETAILS = PersistentConfig(
    "SHOW_ADMIN_DETAILS",
    "auth.admin.show",
    os.environ.get("SHOW_ADMIN_DETAILS", "true").lower() == "true",
)

# None when the environment variable is unset
ADMIN_EMAIL = PersistentConfig(
    "ADMIN_EMAIL",
    "auth.admin.email",
    os.environ.get("ADMIN_EMAIL", None),
)
####################################
# TASKS
####################################

# Model names and prompt templates default to the empty string
TASK_MODEL = PersistentConfig(
    "TASK_MODEL",
    "task.model.default",
    os.environ.get("TASK_MODEL", ""),
)

TASK_MODEL_EXTERNAL = PersistentConfig(
    "TASK_MODEL_EXTERNAL",
    "task.model.external",
    os.environ.get("TASK_MODEL_EXTERNAL", ""),
)

TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    "TITLE_GENERATION_PROMPT_TEMPLATE",
    "task.title.prompt_template",
    os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""),
)

TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    "TAGS_GENERATION_PROMPT_TEMPLATE",
    "task.tags.prompt_template",
    os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""),
)

# Feature toggles default to enabled ("True")
ENABLE_TAGS_GENERATION = PersistentConfig(
    "ENABLE_TAGS_GENERATION",
    "task.tags.enable",
    os.environ.get("ENABLE_TAGS_GENERATION", "True").lower() == "true",
)

ENABLE_SEARCH_QUERY_GENERATION = PersistentConfig(
    "ENABLE_SEARCH_QUERY_GENERATION",
    "task.query.search.enable",
    os.environ.get("ENABLE_SEARCH_QUERY_GENERATION", "True").lower() == "true",
)

ENABLE_RETRIEVAL_QUERY_GENERATION = PersistentConfig(
    "ENABLE_RETRIEVAL_QUERY_GENERATION",
    "task.query.retrieval.enable",
    os.environ.get("ENABLE_RETRIEVAL_QUERY_GENERATION", "True").lower() == "true",
)

QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    "QUERY_GENERATION_PROMPT_TEMPLATE",
    "task.query.prompt_template",
    os.environ.get("QUERY_GENERATION_PROMPT_TEMPLATE", ""),
)

# Plain string constant (not a PersistentConfig). It contains the
# {{CURRENT_DATE}} and {{MESSAGES:END:6}} placeholders, which are not
# substituted in this section.
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE = """### Task:
Based on the chat history, determine whether a search is necessary, and if so, generate a 1-3 broad search queries to retrieve comprehensive and updated information. If no search is required, return an empty list.
### Guidelines:
- Respond exclusively with a JSON object.
- If a search query is needed, return an object like: { "queries": ["query1", "query2"] } where each query is distinct and concise.
- If no search query is necessary, output should be: { "queries": [] }
- Default to suggesting a search query to ensure accurate and updated information, unless it is definitively clear no search is required.
- Be concise, focusing strictly on composing search queries with no additional commentary or text.
- When in doubt, prefer to suggest a search for comprehensiveness.
- Today's date is: {{CURRENT_DATE}}
### Output:
JSON format: {
"queries": ["query1", "query2"]
}
### Chat History:
<chat_history>
{{MESSAGES:END:6}}
</chat_history>
"""

TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
    "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    "task.tools.prompt_template",
    os.environ.get("TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", ""),
)
####################################
# Vector Database
####################################

# Backend selector; defaults to "chroma" ("pgvector" is checked further down)
VECTOR_DB = os.environ.get("VECTOR_DB", "chroma")

# Chroma
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
# Tenant and database default to chromadb's own defaults
CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "")
# Converted with int(): a non-numeric value raises ValueError at import time
CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000"))
CHROMA_CLIENT_AUTH_PROVIDER = os.environ.get("CHROMA_CLIENT_AUTH_PROVIDER", "")
CHROMA_CLIENT_AUTH_CREDENTIALS = os.environ.get("CHROMA_CLIENT_AUTH_CREDENTIALS", "")
# Comma-separated list of header=value pairs
CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
if CHROMA_HTTP_HEADERS:
    # Split each pair on the FIRST "=" only, so values that themselves contain
    # "=" (e.g. base64 credentials) stay intact. Whitespace around names and
    # values is dropped. A pair with no "=" still raises ValueError from
    # dict(), exactly as before.
    CHROMA_HTTP_HEADERS = dict(
        [part.strip() for part in pair.split("=", 1)]
        for pair in CHROMA_HTTP_HEADERS.split(",")
    )
else:
    # None (not an empty dict) when no headers are configured
    CHROMA_HTTP_HEADERS = None
| CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" | |
| # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2) | |
# Milvus
# Falls back to a database file under DATA_DIR/vector_db.
MILVUS_URI = os.getenv("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")

# Qdrant
# Both values are None when the corresponding variable is not set.
QDRANT_URI = os.getenv("QDRANT_URI")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
# OpenSearch
OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
# Parse the flags as real booleans. Returning the raw environment string made
# any non-empty value (including "false") truthy, so SSL / certificate
# verification could not be switched via the environment as intended.
# Defaults are unchanged: SSL on, certificate verification off.
OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", "true").lower() == "true"
OPENSEARCH_CERT_VERIFY = (
    os.environ.get("OPENSEARCH_CERT_VERIFY", "false").lower() == "true"
)
# Credentials are None when not provided.
OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
# Pgvector
# Reuses the primary DATABASE_URL unless PGVECTOR_DB_URL is given explicitly.
PGVECTOR_DB_URL = os.getenv("PGVECTOR_DB_URL", DATABASE_URL)

# Fail at import time when pgvector is selected but the URL is not a
# "postgres..." URL.
if VECTOR_DB == "pgvector":
    if not PGVECTOR_DB_URL.startswith("postgres"):
        raise ValueError(
            "Pgvector requires setting PGVECTOR_DB_URL or using Postgres with vector extension as the primary database."
        )
####################################
# Information Retrieval (RAG)
####################################

# RAG Content Extraction
# The engine name from the environment is lower-cased; empty when unset.
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
    "CONTENT_EXTRACTION_ENGINE",
    "rag.CONTENT_EXTRACTION_ENGINE",
    os.getenv("CONTENT_EXTRACTION_ENGINE", "").lower(),
)

# Default URL targets a sidecar deployment reachable as "tika" on port 9998.
TIKA_SERVER_URL = PersistentConfig(
    "TIKA_SERVER_URL",
    "rag.tika_server_url",
    os.environ.get("TIKA_SERVER_URL", "http://tika:9998"),
)
# Integer setting, default 3.
RAG_TOP_K = PersistentConfig(
    "RAG_TOP_K",
    "rag.top_k",
    int(os.getenv("RAG_TOP_K", "3")),
)

# Float setting, default 0.0.
RAG_RELEVANCE_THRESHOLD = PersistentConfig(
    "RAG_RELEVANCE_THRESHOLD",
    "rag.relevance_threshold",
    float(os.getenv("RAG_RELEVANCE_THRESHOLD", "0.0")),
)

# Enabled only when the variable equals "true" (case-insensitive).
ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
    "ENABLE_RAG_HYBRID_SEARCH",
    "rag.enable_hybrid_search",
    os.getenv("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
)
# Both limits are optional: an unset or empty variable yields None, any other
# value is converted with int().
RAG_FILE_MAX_COUNT = PersistentConfig(
    "RAG_FILE_MAX_COUNT",
    "rag.file.max_count",
    int(os.getenv("RAG_FILE_MAX_COUNT")) if os.getenv("RAG_FILE_MAX_COUNT") else None,
)

RAG_FILE_MAX_SIZE = PersistentConfig(
    "RAG_FILE_MAX_SIZE",
    "rag.file.max_size",
    int(os.getenv("RAG_FILE_MAX_SIZE")) if os.getenv("RAG_FILE_MAX_SIZE") else None,
)
# On by default; anything other than "true" (case-insensitive) turns it off.
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = PersistentConfig(
    "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION",
    "rag.enable_web_loader_ssl_verification",
    os.getenv("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true",
)

# Empty string when no engine is named in the environment.
RAG_EMBEDDING_ENGINE = PersistentConfig(
    "RAG_EMBEDDING_ENGINE",
    "rag.embedding_engine",
    os.getenv("RAG_EMBEDDING_ENGINE", ""),
)

# Off by default.
PDF_EXTRACT_IMAGES = PersistentConfig(
    "PDF_EXTRACT_IMAGES",
    "rag.pdf_extract_images",
    os.getenv("PDF_EXTRACT_IMAGES", "False").lower() == "true",
)
# Embedding model identifier; the effective value is logged at import time.
RAG_EMBEDDING_MODEL = PersistentConfig(
    "RAG_EMBEDDING_MODEL",
    "rag.embedding_model",
    os.getenv("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"),
)
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")

# Plain (non-persistent) flags, both defaulting to enabled.
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
    os.getenv("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true"
)
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
    os.getenv("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true"
)

# RAG_EMBEDDING_BATCH_SIZE wins when set and non-empty; otherwise
# RAG_EMBEDDING_OPENAI_BATCH_SIZE is consulted, with "1" as the last resort.
RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
    "RAG_EMBEDDING_BATCH_SIZE",
    "rag.embedding_batch_size",
    int(
        os.getenv("RAG_EMBEDDING_BATCH_SIZE")
        or os.getenv("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")
    ),
)
# Reranking model identifier; empty string when none is configured.
RAG_RERANKING_MODEL = PersistentConfig(
    "RAG_RERANKING_MODEL",
    "rag.reranking_model",
    os.getenv("RAG_RERANKING_MODEL", ""),
)
# Log only when the effective value is not the empty string.
if RAG_RERANKING_MODEL.value != "":
    log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")

# Plain (non-persistent) flags, both defaulting to enabled.
RAG_RERANKING_MODEL_AUTO_UPDATE = (
    os.getenv("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
)
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
    os.getenv("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true"
)
# Empty string when no splitter is named in the environment.
RAG_TEXT_SPLITTER = PersistentConfig(
    "RAG_TEXT_SPLITTER",
    "rag.text_splitter",
    os.getenv("RAG_TEXT_SPLITTER", ""),
)

# tiktoken cache location (under CACHE_DIR by default) and encoding name.
TIKTOKEN_CACHE_DIR = os.getenv("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken")
TIKTOKEN_ENCODING_NAME = PersistentConfig(
    "TIKTOKEN_ENCODING_NAME",
    "rag.tiktoken_encoding_name",
    os.getenv("TIKTOKEN_ENCODING_NAME", "cl100k_base"),
)

# Integer chunking parameters: size defaults to 1000, overlap to 100.
CHUNK_SIZE = PersistentConfig(
    "CHUNK_SIZE",
    "rag.chunk_size",
    int(os.getenv("CHUNK_SIZE", "1000")),
)
CHUNK_OVERLAP = PersistentConfig(
    "CHUNK_OVERLAP",
    "rag.chunk_overlap",
    int(os.getenv("CHUNK_OVERLAP", "100")),
)
# Built-in RAG prompt. The {{CONTEXT}} and {{QUERY}} placeholders are part of
# the literal text; nothing in this block substitutes them.
DEFAULT_RAG_TEMPLATE = """### Task:
Respond to the user query using the provided context, incorporating inline citations in the format [source_id] **only when the <source_id> tag is explicitly provided** in the context.
### Guidelines:
- If you don't know the answer, clearly state that.
- If uncertain, ask the user for clarification.
- Respond in the same language as the user's query.
- If the context is unreadable or of poor quality, inform the user and provide the best possible answer.
- If the answer isn't present in the context but you possess the knowledge, explain this to the user and provide the answer using your own understanding.
- **Only include inline citations using [source_id] when a <source_id> tag is explicitly provided in the context.**
- Do not cite if the <source_id> tag is not provided in the context.
- Do not use XML tags in your response.
- Ensure citations are concise and directly related to the information provided.
### Example of Citation:
If the user asks about a specific topic and the information is found in "whitepaper.pdf" with a provided <source_id>, the response should include the citation like so:
* "According to the study, the proposed method increases efficiency by 20% [whitepaper.pdf]."
If no <source_id> is present, the response should omit the citation.
### Output:
Provide a clear and direct response to the user's query, including inline citations in the format [source_id] only when the <source_id> tag is present in the context.
<context>
{{CONTEXT}}
</context>
<user_query>
{{QUERY}}
</user_query>
"""
# Falls back to DEFAULT_RAG_TEMPLATE when RAG_TEMPLATE is not set.
RAG_TEMPLATE = PersistentConfig(
    "RAG_TEMPLATE",
    "rag.template",
    os.getenv("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE),
)

# RAG-specific OpenAI endpoint/key; each inherits the general OpenAI value
# when its own variable is absent.
RAG_OPENAI_API_BASE_URL = PersistentConfig(
    "RAG_OPENAI_API_BASE_URL",
    "rag.openai_api_base_url",
    os.environ.get("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
RAG_OPENAI_API_KEY = PersistentConfig(
    "RAG_OPENAI_API_KEY",
    "rag.openai_api_key",
    os.environ.get("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
)

# RAG-specific Ollama endpoint (inherits OLLAMA_BASE_URL) and key (empty by
# default).
RAG_OLLAMA_BASE_URL = PersistentConfig(
    "RAG_OLLAMA_BASE_URL",
    "rag.ollama.url",
    os.environ.get("RAG_OLLAMA_BASE_URL", OLLAMA_BASE_URL),
)
RAG_OLLAMA_API_KEY = PersistentConfig(
    "RAG_OLLAMA_API_KEY",
    "rag.ollama.key",
    os.environ.get("RAG_OLLAMA_API_KEY", ""),
)
# Plain (non-persistent) flag, off by default.
ENABLE_RAG_LOCAL_WEB_FETCH = (
    os.environ.get("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
)

# Comma-separated language codes split into a list; entries are not trimmed.
YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
    "YOUTUBE_LOADER_LANGUAGE",
    "rag.youtube_loader_language",
    os.environ.get("YOUTUBE_LOADER_LANGUAGE", "en").split(","),
)

# Web search is off by default and no engine is preselected.
ENABLE_RAG_WEB_SEARCH = PersistentConfig(
    "ENABLE_RAG_WEB_SEARCH",
    "rag.web.search.enable",
    os.environ.get("ENABLE_RAG_WEB_SEARCH", "False").lower() == "true",
)
RAG_WEB_SEARCH_ENGINE = PersistentConfig(
    "RAG_WEB_SEARCH_ENGINE",
    "rag.web.search.engine",
    os.environ.get("RAG_WEB_SEARCH_ENGINE", ""),
)
# Optional list of your own domains used to filter results after a web
# search, so that only trusted sources are kept. Empty by default.
# NOTE(review): the config path starts with "rag.rag." unlike the sibling
# "rag.web.search.*" keys. It is left untouched here because it is the key the
# value is stored under - confirm a migration path before renaming it.
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
    "RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
    "rag.rag.web.search.domain.filter_list",
    [
        # Example entries:
        # "wikipedia.org",
        # "wikimedia.org",
        # "wikidata.org",
    ],
)
# Web search provider settings. Every value below falls back to an empty
# string when its environment variable is not set.
SEARXNG_QUERY_URL = PersistentConfig(
    "SEARXNG_QUERY_URL",
    "rag.web.search.searxng_query_url",
    os.environ.get("SEARXNG_QUERY_URL", ""),
)

GOOGLE_PSE_API_KEY = PersistentConfig(
    "GOOGLE_PSE_API_KEY",
    "rag.web.search.google_pse_api_key",
    os.environ.get("GOOGLE_PSE_API_KEY", ""),
)

GOOGLE_PSE_ENGINE_ID = PersistentConfig(
    "GOOGLE_PSE_ENGINE_ID",
    "rag.web.search.google_pse_engine_id",
    os.environ.get("GOOGLE_PSE_ENGINE_ID", ""),
)

BRAVE_SEARCH_API_KEY = PersistentConfig(
    "BRAVE_SEARCH_API_KEY",
    "rag.web.search.brave_search_api_key",
    os.environ.get("BRAVE_SEARCH_API_KEY", ""),
)

MOJEEK_SEARCH_API_KEY = PersistentConfig(
    "MOJEEK_SEARCH_API_KEY",
    "rag.web.search.mojeek_search_api_key",
    os.environ.get("MOJEEK_SEARCH_API_KEY", ""),
)

SERPSTACK_API_KEY = PersistentConfig(
    "SERPSTACK_API_KEY",
    "rag.web.search.serpstack_api_key",
    os.environ.get("SERPSTACK_API_KEY", ""),
)
# Boolean flag, on by default.
SERPSTACK_HTTPS = PersistentConfig(
    "SERPSTACK_HTTPS",
    "rag.web.search.serpstack_https",
    os.environ.get("SERPSTACK_HTTPS", "True").lower() == "true",
)

# Remaining provider keys/identifiers, each empty when its variable is unset.
SERPER_API_KEY = PersistentConfig(
    "SERPER_API_KEY",
    "rag.web.search.serper_api_key",
    os.environ.get("SERPER_API_KEY", ""),
)

SERPLY_API_KEY = PersistentConfig(
    "SERPLY_API_KEY",
    "rag.web.search.serply_api_key",
    os.environ.get("SERPLY_API_KEY", ""),
)

TAVILY_API_KEY = PersistentConfig(
    "TAVILY_API_KEY",
    "rag.web.search.tavily_api_key",
    os.environ.get("TAVILY_API_KEY", ""),
)

JINA_API_KEY = PersistentConfig(
    "JINA_API_KEY",
    "rag.web.search.jina_api_key",
    os.environ.get("JINA_API_KEY", ""),
)

SEARCHAPI_API_KEY = PersistentConfig(
    "SEARCHAPI_API_KEY",
    "rag.web.search.searchapi_api_key",
    os.environ.get("SEARCHAPI_API_KEY", ""),
)

SEARCHAPI_ENGINE = PersistentConfig(
    "SEARCHAPI_ENGINE",
    "rag.web.search.searchapi_engine",
    os.environ.get("SEARCHAPI_ENGINE", ""),
)
# Bing Search v7: endpoint defaults to the public API URL, key to "".
BING_SEARCH_V7_ENDPOINT = PersistentConfig(
    "BING_SEARCH_V7_ENDPOINT",
    "rag.web.search.bing_search_v7_endpoint",
    os.getenv("BING_SEARCH_V7_ENDPOINT", "https://api.bing.microsoft.com/v7.0/search"),
)
BING_SEARCH_V7_SUBSCRIPTION_KEY = PersistentConfig(
    "BING_SEARCH_V7_SUBSCRIPTION_KEY",
    "rag.web.search.bing_search_v7_subscription_key",
    os.getenv("BING_SEARCH_V7_SUBSCRIPTION_KEY", ""),
)

# Integer limits: 3 results and 10 concurrent requests by default.
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
    "RAG_WEB_SEARCH_RESULT_COUNT",
    "rag.web.search.result_count",
    int(os.environ.get("RAG_WEB_SEARCH_RESULT_COUNT", "3")),
)
RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
    "RAG_WEB_SEARCH_CONCURRENT_REQUESTS",
    "rag.web.search.concurrent_requests",
    int(os.environ.get("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
)
####################################
# Images
####################################

# Engine name defaults to "openai"; generation itself is off unless
# ENABLE_IMAGE_GENERATION equals "true" (case-insensitive).
IMAGE_GENERATION_ENGINE = PersistentConfig(
    "IMAGE_GENERATION_ENGINE",
    "image_generation.engine",
    os.environ.get("IMAGE_GENERATION_ENGINE", "openai"),
)
ENABLE_IMAGE_GENERATION = PersistentConfig(
    "ENABLE_IMAGE_GENERATION",
    "image_generation.enable",
    os.getenv("ENABLE_IMAGE_GENERATION", "").lower() == "true",
)

# Automatic1111 connection settings, both empty by default.
AUTOMATIC1111_BASE_URL = PersistentConfig(
    "AUTOMATIC1111_BASE_URL",
    "image_generation.automatic1111.base_url",
    os.environ.get("AUTOMATIC1111_BASE_URL", ""),
)
AUTOMATIC1111_API_AUTH = PersistentConfig(
    "AUTOMATIC1111_API_AUTH",
    "image_generation.automatic1111.api_auth",
    os.environ.get("AUTOMATIC1111_API_AUTH", ""),
)
# Optional Automatic1111 tuning values: an unset or empty variable yields
# None. The CFG scale is converted to float; sampler and scheduler stay
# strings.
AUTOMATIC1111_CFG_SCALE = PersistentConfig(
    "AUTOMATIC1111_CFG_SCALE",
    "image_generation.automatic1111.cfg_scale",
    (
        float(os.getenv("AUTOMATIC1111_CFG_SCALE"))
        if os.getenv("AUTOMATIC1111_CFG_SCALE")
        else None
    ),
)
AUTOMATIC1111_SAMPLER = PersistentConfig(
    "AUTOMATIC1111_SAMPLER",
    "image_generation.automatic1111.sampler",
    os.getenv("AUTOMATIC1111_SAMPLER") or None,
)
AUTOMATIC1111_SCHEDULER = PersistentConfig(
    "AUTOMATIC1111_SCHEDULER",
    "image_generation.automatic1111.scheduler",
    os.getenv("AUTOMATIC1111_SCHEDULER") or None,
)

# ComfyUI server URL, empty by default.
COMFYUI_BASE_URL = PersistentConfig(
    "COMFYUI_BASE_URL",
    "image_generation.comfyui.base_url",
    os.environ.get("COMFYUI_BASE_URL", ""),
)
# Default ComfyUI workflow, kept as a JSON string. Nodes are keyed "3".."9":
# KSampler (3), CheckpointLoaderSimple (4), EmptyLatentImage (5), two
# CLIPTextEncode nodes (6, 7), VAEDecode (8) and SaveImage (9).
COMFYUI_DEFAULT_WORKFLOW = """
{
"3": {
"inputs": {
"seed": 0,
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": [
"4",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"4": {
"inputs": {
"ckpt_name": "model.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"5": {
"inputs": {
"width": 512,
"height": 512,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "Prompt",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": "",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
}
"""
# Workflow JSON string; COMFYUI_DEFAULT_WORKFLOW is used when the
# COMFYUI_WORKFLOW variable is not set.
COMFYUI_WORKFLOW = PersistentConfig(
    "COMFYUI_WORKFLOW",
    "image_generation.comfyui.workflow",
    os.environ.get("COMFYUI_WORKFLOW", COMFYUI_DEFAULT_WORKFLOW),
)
# Node list for the ComfyUI workflow; empty by default and not read from the
# environment. The first argument used to be "COMFYUI_WORKFLOW" (copied from
# the setting above), which registered two different settings under the same
# name. The config path is unchanged, so stored values are still found.
COMFYUI_WORKFLOW_NODES = PersistentConfig(
    "COMFYUI_WORKFLOW_NODES",
    "image_generation.comfyui.nodes",
    [],
)
# Image-specific OpenAI endpoint/key; each inherits the general OpenAI value
# when its own variable is absent.
IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
    "IMAGES_OPENAI_API_BASE_URL",
    "image_generation.openai.api_base_url",
    os.environ.get("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
IMAGES_OPENAI_API_KEY = PersistentConfig(
    "IMAGES_OPENAI_API_KEY",
    "image_generation.openai.api_key",
    os.environ.get("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
)

# Size string defaults to "512x512"; step count is an int defaulting to 50.
IMAGE_SIZE = PersistentConfig(
    "IMAGE_SIZE",
    "image_generation.size",
    os.environ.get("IMAGE_SIZE", "512x512"),
)
IMAGE_STEPS = PersistentConfig(
    "IMAGE_STEPS",
    "image_generation.steps",
    int(os.environ.get("IMAGE_STEPS", 50)),
)

# Empty string when no model is named in the environment.
IMAGE_GENERATION_MODEL = PersistentConfig(
    "IMAGE_GENERATION_MODEL",
    "image_generation.model",
    os.environ.get("IMAGE_GENERATION_MODEL", ""),
)
####################################
# Audio
####################################

# Transcription
# Whisper model name defaults to "base".
WHISPER_MODEL = PersistentConfig(
    "WHISPER_MODEL",
    "audio.stt.whisper_model",
    os.environ.get("WHISPER_MODEL", "base"),
)
# Model directory defaults to a folder under CACHE_DIR; auto-update is a
# plain flag that is off unless set to "true" (case-insensitive).
WHISPER_MODEL_DIR = os.environ.get("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
WHISPER_MODEL_AUTO_UPDATE = (
    os.getenv("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)
# Speech-to-text OpenAI endpoint/key; each inherits the general OpenAI value
# when its own variable is absent.
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_STT_OPENAI_API_BASE_URL",
    "audio.stt.openai.api_base_url",
    os.environ.get("AUDIO_STT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_STT_OPENAI_API_KEY = PersistentConfig(
    "AUDIO_STT_OPENAI_API_KEY",
    "audio.stt.openai.api_key",
    os.environ.get("AUDIO_STT_OPENAI_API_KEY", OPENAI_API_KEY),
)

# Engine and model names are empty strings unless configured.
AUDIO_STT_ENGINE = PersistentConfig(
    "AUDIO_STT_ENGINE",
    "audio.stt.engine",
    os.environ.get("AUDIO_STT_ENGINE", ""),
)
AUDIO_STT_MODEL = PersistentConfig(
    "AUDIO_STT_MODEL",
    "audio.stt.model",
    os.environ.get("AUDIO_STT_MODEL", ""),
)
# Text-to-speech OpenAI endpoint/key; each inherits the general OpenAI value
# when its own variable is absent.
AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_TTS_OPENAI_API_BASE_URL",
    "audio.tts.openai.api_base_url",
    os.environ.get("AUDIO_TTS_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_TTS_OPENAI_API_KEY = PersistentConfig(
    "AUDIO_TTS_OPENAI_API_KEY",
    "audio.tts.openai.api_key",
    os.environ.get("AUDIO_TTS_OPENAI_API_KEY", OPENAI_API_KEY),
)

# Separate TTS API key and engine name, both empty by default.
AUDIO_TTS_API_KEY = PersistentConfig(
    "AUDIO_TTS_API_KEY",
    "audio.tts.api_key",
    os.environ.get("AUDIO_TTS_API_KEY", ""),
)
AUDIO_TTS_ENGINE = PersistentConfig(
    "AUDIO_TTS_ENGINE",
    "audio.tts.engine",
    os.environ.get("AUDIO_TTS_ENGINE", ""),
)

# Defaults follow OpenAI's default model ("tts-1") and voice ("alloy").
AUDIO_TTS_MODEL = PersistentConfig(
    "AUDIO_TTS_MODEL",
    "audio.tts.model",
    os.environ.get("AUDIO_TTS_MODEL", "tts-1"),
)
AUDIO_TTS_VOICE = PersistentConfig(
    "AUDIO_TTS_VOICE",
    "audio.tts.voice",
    os.environ.get("AUDIO_TTS_VOICE", "alloy"),
)

# Defaults to "punctuation".
AUDIO_TTS_SPLIT_ON = PersistentConfig(
    "AUDIO_TTS_SPLIT_ON",
    "audio.tts.split_on",
    os.environ.get("AUDIO_TTS_SPLIT_ON", "punctuation"),
)
# Azure speech settings: region defaults to "eastus" and the output format to
# "audio-24khz-160kbitrate-mono-mp3".
AUDIO_TTS_AZURE_SPEECH_REGION = PersistentConfig(
    "AUDIO_TTS_AZURE_SPEECH_REGION",
    "audio.tts.azure.speech_region",
    os.environ.get("AUDIO_TTS_AZURE_SPEECH_REGION", "eastus"),
)
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT = PersistentConfig(
    "AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT",
    "audio.tts.azure.speech_output_format",
    os.environ.get(
        "AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT",
        "audio-24khz-160kbitrate-mono-mp3",
    ),
)
####################################
# LDAP
####################################

# LDAP is off unless ENABLE_LDAP equals "true" (case-insensitive).
ENABLE_LDAP = PersistentConfig(
    "ENABLE_LDAP",
    "ldap.enable",
    os.getenv("ENABLE_LDAP", "false").lower() == "true",
)

# Server label, host and port; the port is converted to int (default 389).
LDAP_SERVER_LABEL = PersistentConfig(
    "LDAP_SERVER_LABEL",
    "ldap.server.label",
    os.getenv("LDAP_SERVER_LABEL", "LDAP Server"),
)
LDAP_SERVER_HOST = PersistentConfig(
    "LDAP_SERVER_HOST",
    "ldap.server.host",
    os.getenv("LDAP_SERVER_HOST", "localhost"),
)
LDAP_SERVER_PORT = PersistentConfig(
    "LDAP_SERVER_PORT",
    "ldap.server.port",
    int(os.getenv("LDAP_SERVER_PORT", "389")),
)
# Attribute name defaults to "uid".
LDAP_ATTRIBUTE_FOR_USERNAME = PersistentConfig(
    "LDAP_ATTRIBUTE_FOR_USERNAME",
    "ldap.server.attribute_for_username",
    os.getenv("LDAP_ATTRIBUTE_FOR_USERNAME", "uid"),
)

# Application DN and password, both empty by default.
LDAP_APP_DN = PersistentConfig(
    "LDAP_APP_DN",
    "ldap.server.app_dn",
    os.getenv("LDAP_APP_DN", ""),
)
LDAP_APP_PASSWORD = PersistentConfig(
    "LDAP_APP_PASSWORD",
    "ldap.server.app_password",
    os.getenv("LDAP_APP_PASSWORD", ""),
)

# Search base is stored under the "ldap.server.users_dn" config path.
LDAP_SEARCH_BASE = PersistentConfig(
    "LDAP_SEARCH_BASE",
    "ldap.server.users_dn",
    os.getenv("LDAP_SEARCH_BASE", ""),
)
# The Python name is plural, but the environment variable and config path use
# the singular LDAP_SEARCH_FILTER / "search_filter".
LDAP_SEARCH_FILTERS = PersistentConfig(
    "LDAP_SEARCH_FILTER",
    "ldap.server.search_filter",
    os.getenv("LDAP_SEARCH_FILTER", ""),
)
# TLS is on by default; CA certificate file is empty and ciphers default to
# "ALL".
LDAP_USE_TLS = PersistentConfig(
    "LDAP_USE_TLS",
    "ldap.server.use_tls",
    os.getenv("LDAP_USE_TLS", "True").lower() == "true",
)
LDAP_CA_CERT_FILE = PersistentConfig(
    "LDAP_CA_CERT_FILE",
    "ldap.server.ca_cert_file",
    os.getenv("LDAP_CA_CERT_FILE", ""),
)
LDAP_CIPHERS = PersistentConfig(
    "LDAP_CIPHERS",
    "ldap.server.ciphers",
    os.getenv("LDAP_CIPHERS", "ALL"),
)