| | import asyncio |
| | import os |
| | from dotenv import load_dotenv |
| | from sqlalchemy.ext.asyncio import create_async_engine |
| | from sqlalchemy import text |
| |
|
# Load variables from a local .env file (if any) before DATABASE_URL is read below.
load_dotenv()
| |
|
def _to_asyncpg_url(database_url: str) -> str:
    """Return *database_url* with a bare PostgreSQL scheme switched to asyncpg.

    Only the leading scheme is rewritten, so credentials or database names
    that happen to contain ``postgresql://`` or ``asyncpg`` are untouched.
    URLs that already name a driver (``postgresql+asyncpg://`` etc.) are
    returned unchanged.
    """
    for scheme in ("postgresql://", "postgres://"):
        if database_url.startswith(scheme):
            return "postgresql+asyncpg://" + database_url[len(scheme):]
    return database_url


async def setup_rag_db(*, drop_existing: bool = True) -> None:
    """Create the pgvector-backed ``documents`` table and ``match_documents`` function.

    Reads the connection string from the ``DATABASE_URL`` environment
    variable. If it is missing or empty, a message is printed and the
    function returns without touching the database.

    All statements run inside a single transaction, which is rolled back if
    any statement fails.

    Args:
        drop_existing: When true (the default, matching the historical
            behaviour of this script) the ``documents`` table is dropped with
            ``CASCADE`` before being recreated, DESTROYING ANY STORED ROWS.
            Pass ``False`` to keep an existing table; in that case
            ``CREATE TABLE IF NOT EXISTS`` leaves its current schema as is.
    """
    database_url = os.getenv("DATABASE_URL")
    if not database_url:
        print("DATABASE_URL not found")
        return

    engine = create_async_engine(
        _to_asyncpg_url(database_url),
        echo=True,
        connect_args={
            # Disables asyncpg's prepared-statement cache; presumably needed
            # because the target sits behind a transaction-mode pooler
            # (e.g. PgBouncer) -- TODO confirm.
            "statement_cache_size": 0,
            "server_settings": {
                "jit": "off",
            },
        },
    )

    try:
        async with engine.begin() as conn:
            print("Enabling pgvector extension...")
            await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
            await conn.execute(text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp"'))

            print("Checking/Creating documents table...")
            if drop_existing:
                # Destructive: removes the table, its rows and dependent objects.
                await conn.execute(text("DROP TABLE IF EXISTS documents CASCADE"))

            await conn.execute(text("""
                CREATE TABLE IF NOT EXISTS documents (
                    id uuid PRIMARY KEY DEFAULT uuid_generate_v4(),
                    content text,
                    metadata jsonb,
                    embedding vector(3072) -- Ensure 3072 dimensions
                )
            """))

            print("Creating match_documents function...")
            # The vector(3072) argument must stay in sync with the
            # documents.embedding column declared above.
            await conn.execute(text("""
                CREATE OR REPLACE FUNCTION match_documents (
                    query_embedding vector(3072),
                    match_threshold float,
                    match_count int
                )
                RETURNS TABLE (
                    id uuid,
                    content text,
                    metadata jsonb,
                    similarity float
                )
                LANGUAGE plpgsql
                AS $$
                BEGIN
                    RETURN QUERY
                    SELECT
                        documents.id,
                        documents.content,
                        documents.metadata,
                        1 - (documents.embedding <=> query_embedding) AS similarity
                    FROM documents
                    WHERE 1 - (documents.embedding <=> query_embedding) > match_threshold
                    ORDER BY similarity DESC
                    LIMIT match_count;
                END;
                $$;
            """))
    finally:
        # Close pooled connections while the event loop is still running;
        # otherwise they leak until interpreter shutdown.
        await engine.dispose()

    print("Done setup.")
| |
|
if __name__ == "__main__":
    # Script entry point: run the one-off database setup on a fresh event loop.
    setup_coro = setup_rag_db()
    asyncio.run(setup_coro)
| |
|