| | import logging |
| | import os |
| | import time |
| |
|
| | import docker |
| | import pytest |
| | from docker import DockerClient |
| | from pytest_docker.plugin import get_docker_ip |
| | from fastapi.testclient import TestClient |
| | from sqlalchemy import text, create_engine |
| |
|
| |
|
| | log = logging.getLogger(__name__) |
| |
|
| |
|
| | def get_fast_api_client(): |
| | from main import app |
| |
|
| | with TestClient(app) as c: |
| | return c |
| |
|
| |
|
| | class AbstractIntegrationTest: |
| | BASE_PATH = None |
| |
|
| | def create_url(self, path="", query_params=None): |
| | if self.BASE_PATH is None: |
| | raise Exception("BASE_PATH is not set") |
| | parts = self.BASE_PATH.split("/") |
| | parts = [part.strip() for part in parts if part.strip() != ""] |
| | path_parts = path.split("/") |
| | path_parts = [part.strip() for part in path_parts if part.strip() != ""] |
| | query_parts = "" |
| | if query_params: |
| | query_parts = "&".join( |
| | [f"{key}={value}" for key, value in query_params.items()] |
| | ) |
| | query_parts = f"?{query_parts}" |
| | return "/".join(parts + path_parts) + query_parts |
| |
|
| | @classmethod |
| | def setup_class(cls): |
| | pass |
| |
|
| | def setup_method(self): |
| | pass |
| |
|
| | @classmethod |
| | def teardown_class(cls): |
| | pass |
| |
|
| | def teardown_method(self): |
| | pass |
| |
|
| |
|
| | class AbstractPostgresTest(AbstractIntegrationTest): |
| | DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted" |
| | docker_client: DockerClient |
| |
|
| | @classmethod |
| | def _create_db_url(cls, env_vars_postgres: dict) -> str: |
| | host = get_docker_ip() |
| | user = env_vars_postgres["POSTGRES_USER"] |
| | pw = env_vars_postgres["POSTGRES_PASSWORD"] |
| | port = 8081 |
| | db = env_vars_postgres["POSTGRES_DB"] |
| | return f"postgresql://{user}:{pw}@{host}:{port}/{db}" |
| |
|
| | @classmethod |
| | def setup_class(cls): |
| | super().setup_class() |
| | try: |
| | env_vars_postgres = { |
| | "POSTGRES_USER": "user", |
| | "POSTGRES_PASSWORD": "example", |
| | "POSTGRES_DB": "openwebui", |
| | } |
| | cls.docker_client = docker.from_env() |
| | cls.docker_client.containers.run( |
| | "postgres:16.2", |
| | detach=True, |
| | environment=env_vars_postgres, |
| | name=cls.DOCKER_CONTAINER_NAME, |
| | ports={5432: ("0.0.0.0", 8081)}, |
| | command="postgres -c log_statement=all", |
| | ) |
| | time.sleep(0.5) |
| |
|
| | database_url = cls._create_db_url(env_vars_postgres) |
| | os.environ["DATABASE_URL"] = database_url |
| | retries = 10 |
| | db = None |
| | while retries > 0: |
| | try: |
| | from open_webui.config import OPEN_WEBUI_DIR |
| |
|
| | db = create_engine(database_url, pool_pre_ping=True) |
| | db = db.connect() |
| | log.info("postgres is ready!") |
| | break |
| | except Exception as e: |
| | log.warning(e) |
| | time.sleep(3) |
| | retries -= 1 |
| |
|
| | if db: |
| | |
| | cls.fast_api_client = get_fast_api_client() |
| | db.close() |
| | else: |
| | raise Exception("Could not connect to Postgres") |
| | except Exception as ex: |
| | log.error(ex) |
| | cls.teardown_class() |
| | pytest.fail(f"Could not setup test environment: {ex}") |
| |
|
| | def _check_db_connection(self): |
| | from open_webui.internal.db import Session |
| |
|
| | retries = 10 |
| | while retries > 0: |
| | try: |
| | Session.execute(text("SELECT 1")) |
| | Session.commit() |
| | break |
| | except Exception as e: |
| | Session.rollback() |
| | log.warning(e) |
| | time.sleep(3) |
| | retries -= 1 |
| |
|
| | def setup_method(self): |
| | super().setup_method() |
| | self._check_db_connection() |
| |
|
| | @classmethod |
| | def teardown_class(cls) -> None: |
| | super().teardown_class() |
| | cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True) |
| |
|
| | def teardown_method(self): |
| | from open_webui.internal.db import Session |
| |
|
| | |
| | Session.commit() |
| |
|
| | |
| | tables = [ |
| | "auth", |
| | "chat", |
| | "chatidtag", |
| | "document", |
| | "memory", |
| | "model", |
| | "prompt", |
| | "tag", |
| | '"user"', |
| | ] |
| | for table in tables: |
| | Session.execute(text(f"TRUNCATE TABLE {table}")) |
| | Session.commit() |
| |
|