| from sqlalchemy.dialects.postgresql import JSONB |
| from sqlalchemy.orm import declared_attr |
| from pgvector.sqlalchemy import Vector |
| from sqlalchemy import Column |
| from datetime import datetime |
| from util import snake_case |
| import uuid as uuid_pkg |
|
|
| from sqlmodel import ( |
| UniqueConstraint, |
| create_engine, |
| Relationship, |
| SQLModel, |
| Session, |
| select, |
| Field, |
| ) |
| from typing import ( |
| Optional, |
| Union, |
| List, |
| Dict, |
| Any |
| ) |
| from config import ( |
| LLM_DEFAULT_DISTANCE_STRATEGY, |
| VECTOR_EMBEDDINGS_COUNT, |
| LLM_MAX_OUTPUT_TOKENS, |
| DISTANCE_STRATEGIES, |
| LLM_MIN_NODE_LIMIT, |
| PGVECTOR_ADD_INDEX, |
| ENTITY_STATUS, |
| CHANNEL_TYPE, |
| LLM_MODELS, |
| DB_USER, |
| SU_DSN, |
| logger, |
| ) |
|
|
|
|
| |
| |
| |
class BaseModel(SQLModel):
    """Common base for the models in this module.

    Derives the table name from the class name and offers small persistence
    helpers. Each helper opens its own ``Session`` on the engine returned by
    ``get_engine()``.
    """

    @declared_attr
    def __tablename__(cls) -> str:
        # The table name is the snake_case form of the model class name.
        return snake_case(cls.__name__)

    @classmethod
    def by_uuid(cls, _uuid: uuid_pkg.UUID):
        """Return the first row whose ``uuid`` equals ``_uuid``, or ``None``."""
        with Session(get_engine()) as session:
            statement = select(cls).where(cls.uuid == _uuid)
            # ``first()`` already returns ``None`` when nothing matches.
            return session.exec(statement).first()

    def update(self, o: Union[SQLModel, dict] = None):
        """Copy the values from ``o`` onto this instance and persist it.

        Args:
            o: A dict of attribute values, or a model; for a model only the
                fields reported by ``o.dict(exclude_unset=True)`` are used.

        Raises:
            ValueError: If ``o`` is ``None`` or an empty dict.
        """
        if not o:
            raise ValueError("Must provide a model or dict to update values")

        values = o if isinstance(o, dict) else o.dict(exclude_unset=True)
        for key, value in values.items():
            setattr(self, key, value)

        with Session(get_engine()) as session:
            session.add(self)
            session.commit()
            session.refresh(self)

    def delete(self):
        """Soft-delete: set ``status`` to ``ENTITY_STATUS.DELETED`` and persist."""
        with Session(get_engine()) as session:
            self.status = ENTITY_STATUS.DELETED
            # Use the same clock as the ``created_at``/``updated_at`` column
            # defaults in this module (``datetime.now``) so timestamps on a
            # row stay comparable.
            self.updated_at = datetime.now()
            session.add(self)
            session.commit()
            session.refresh(self)

    @classmethod
    def create(cls, o: Union[SQLModel, dict] = None):
        """Build an instance of this model from ``o``, persist it and return it.

        Args:
            o: A model (converted with ``from_orm``) or a dict of keyword
                arguments for the model constructor.

        Raises:
            ValueError: If ``o`` is ``None`` or an empty dict.
        """
        if not o:
            raise ValueError("Must provide a model or dict to create a record")

        with Session(get_engine()) as session:
            obj = cls.from_orm(o) if isinstance(o, SQLModel) else cls(**o)
            session.add(obj)
            session.commit()
            session.refresh(obj)

        return obj
|
|
|
|
| |
| |
| |
class Organization(BaseModel, table=True):
    """Table model for an organization, linked to its projects and documents."""

    id: Optional[int] = Field(default=None, primary_key=True)
    uuid: Optional[uuid_pkg.UUID] = Field(unique=True, default_factory=uuid_pkg.uuid4)
    display_name: Optional[str] = Field(default="Untitled Organization 😊", index=True)
    namespace: str = Field(unique=True, index=True)
    bot_url: Optional[str] = Field(default=None)
    status: Optional[ENTITY_STATUS] = Field(default=ENTITY_STATUS.ACTIVE.value)
    created_at: Optional[datetime] = Field(default_factory=datetime.now)
    updated_at: Optional[datetime] = Field(default_factory=datetime.now)

    # Relationships (the other side is declared on Project / Document).
    projects: Optional[List["Project"]] = Relationship(back_populates="organization")
    documents: Optional[List["Document"]] = Relationship(back_populates="organization")

    @property
    def project_count(self) -> int:
        """Number of entries in ``projects``."""
        return len(self.projects)

    @property
    def document_count(self) -> int:
        """Number of entries in ``documents``."""
        return len(self.documents)

    def __repr__(self):
        return f"<Organization id={self.id} name={self.display_name} namespace={self.namespace} uuid={self.uuid}>"
|
|
|
|
class OrganizationCreate(SQLModel):
    """Non-table schema with the organization's name, namespace and bot URL.

    NOTE(review): presumably the payload for creating an Organization;
    usage is not visible in this file.
    """

    display_name: Optional[str]
    namespace: Optional[str]
    bot_url: Optional[str]
|
|
|
|
class OrganizationRead(SQLModel):
    """Non-table schema exposing an organization's identifiers, naming and timestamps."""

    id: int
    uuid: uuid_pkg.UUID
    display_name: str
    namespace: Optional[str]
    bot_url: Optional[str]
    created_at: datetime
    updated_at: datetime
|
|
|
|
class OrganizationUpdate(SQLModel):
    """Non-table schema with the organization fields ``display_name``, ``namespace`` and ``bot_url``.

    NOTE(review): presumably the payload for updating an Organization;
    usage is not visible in this file.
    """

    display_name: Optional[str]
    namespace: Optional[str]
    bot_url: Optional[str]
|
|
|
|
| |
| |
| |
class User(BaseModel, table=True):
    """Table model for a user, identified by ``identifier`` and ``identifier_type``."""

    id: Optional[int] = Field(default=None, primary_key=True)
    identifier: str = Field(default=None, unique=True, index=True)
    identifier_type: Optional[CHANNEL_TYPE] = Field(default=None)
    uuid: Optional[uuid_pkg.UUID] = Field(unique=True, default_factory=uuid_pkg.uuid4)
    first_name: Optional[str] = Field(default=None)
    last_name: Optional[str] = Field(default=None)
    email: Optional[str] = Field(default=None)
    phone: Optional[str] = Field(default=None)
    dob: Optional[datetime] = Field(default=None)
    device_fingerprint: Optional[str] = Field(default=None)
    created_at: Optional[datetime] = Field(default_factory=datetime.now)
    updated_at: Optional[datetime] = Field(default_factory=datetime.now)

    # Relationship (the other side is ChatSession.user).
    chat_sessions: Optional[List["ChatSession"]] = Relationship(back_populates="user")

    @property
    def chat_session_count(self) -> int:
        """Number of entries in ``chat_sessions``."""
        return len(self.chat_sessions)

    __table_args__ = (
        UniqueConstraint("identifier", "identifier_type", name="unq_id_idtype"),
    )

    def __repr__(self):
        # User has no ``project_id`` attribute; the previous repr referenced
        # it and raised AttributeError. Show the identifier instead.
        return f"<User id={self.id} uuid={self.uuid} identifier={self.identifier} device_fingerprint={self.device_fingerprint}>"
|
|
|
|
class UserCreate(SQLModel):
    """Non-table schema with a required identifier pair and optional profile fields.

    NOTE(review): presumably the payload for creating a User; usage is
    not visible in this file.
    """

    # Required identity.
    identifier: str
    identifier_type: CHANNEL_TYPE

    # Optional profile details.
    device_fingerprint: Optional[str]
    first_name: Optional[str]
    last_name: Optional[str]
    email: Optional[str]
    phone: Optional[str]
    dob: Optional[datetime]
|
|
|
|
class UserReadList(SQLModel):
    """Non-table schema for a user with a chat-session count but no session list."""

    id: int
    identifier: Optional[str]
    identifier_type: Optional[CHANNEL_TYPE]
    uuid: uuid_pkg.UUID
    device_fingerprint: Optional[str]
    first_name: Optional[str]
    last_name: Optional[str]
    email: Optional[str]
    phone: Optional[str]
    dob: Optional[datetime]
    chat_session_count: int
    created_at: datetime
    updated_at: datetime
|
|
|
|
class UserUpdate(SQLModel):
    """Non-table schema with the optional user profile fields.

    NOTE(review): presumably the payload for updating a User; usage is
    not visible in this file.
    """

    # ``device_fingerprint`` was previously declared twice; the duplicate
    # annotation had no effect and has been removed.
    device_fingerprint: Optional[str]
    first_name: Optional[str]
    last_name: Optional[str]
    email: Optional[str]
    phone: Optional[str]
    dob: Optional[datetime]
|
|
|
|
| |
| |
| |
class Project(BaseModel, table=True):
    """Table model for a project belonging to an organization."""

    id: Optional[int] = Field(default=None, primary_key=True)
    uuid: Optional[uuid_pkg.UUID] = Field(unique=True, default_factory=uuid_pkg.uuid4)
    organization_id: int = Field(default=None, foreign_key="organization.id")
    display_name: str = Field(default="📝 Untitled Project")
    status: Optional[ENTITY_STATUS] = Field(default=ENTITY_STATUS.ACTIVE.value)
    created_at: Optional[datetime] = Field(default_factory=datetime.now)
    updated_at: Optional[datetime] = Field(default_factory=datetime.now)

    # Relationships (the other sides are declared on Organization, Document
    # and ChatSession).
    organization: Optional["Organization"] = Relationship(back_populates="projects")
    documents: Optional[List["Document"]] = Relationship(back_populates="project")
    chat_sessions: Optional[List["ChatSession"]] = Relationship(
        back_populates="project"
    )

    @property
    def document_count(self) -> int:
        """Number of entries in ``documents``."""
        return len(self.documents)

    def __repr__(self):
        # The previous repr printed ``self.uuid`` a second time under the
        # label ``project_id``; show the organization foreign key instead.
        return f"<Project id={self.id} name={self.display_name} uuid={self.uuid} organization_id={self.organization_id}>"
|
|
|
|
class ProjectCreate(SQLModel):
    """Non-table schema carrying only a project's ``display_name``."""

    display_name: Optional[str]
|
|
|
|
class ProjectReadListOrganization(SQLModel):
    """Non-table schema with a uuid, display name, namespace and document count."""

    uuid: uuid_pkg.UUID
    display_name: str
    namespace: Optional[str]
    document_count: int
|
|
|
|
class ProjectUpdate(SQLModel):
    """Non-table schema with a project's ``display_name`` and ``status``."""

    display_name: Optional[str]
    status: Optional[ENTITY_STATUS]
|
|
|
|
| |
| |
| |
class Document(BaseModel, table=True):
    """Table model for a document owned by an organization and a project."""

    id: Optional[int] = Field(default=None, primary_key=True)
    uuid: Optional[uuid_pkg.UUID] = Field(unique=True, default_factory=uuid_pkg.uuid4)
    organization_id: int = Field(default=None, foreign_key="organization.id")
    project_id: int = Field(default=None, foreign_key="project.id")
    display_name: str = Field(default="Untitled Document 😊")
    url: str = Field(default="")
    data: Optional[bytes] = Field(default=None)
    hash: str = Field(default=None)
    version: Optional[int] = Field(default=1)
    status: Optional[ENTITY_STATUS] = Field(default=ENTITY_STATUS.ACTIVE.value)
    created_at: Optional[datetime] = Field(default_factory=datetime.now)
    updated_at: Optional[datetime] = Field(default_factory=datetime.now)

    # Relationships (the other sides are declared on Node, Organization and
    # Project).
    nodes: Optional[List["Node"]] = Relationship(back_populates="document")
    organization: Optional["Organization"] = Relationship(back_populates="documents")
    project: Optional["Project"] = Relationship(back_populates="documents")

    @property
    def node_count(self) -> int:
        """Number of entries in ``nodes``."""
        return len(self.nodes)

    # NOTE(review): the constraint name mentions "org" but the columns are
    # ``uuid`` + ``hash``, and ``uuid`` is already unique on its own above.
    # Confirm whether ``organization_id`` + ``hash`` was intended.
    __table_args__ = (UniqueConstraint("uuid", "hash", name="unq_org_document"),)

    def __repr__(self):
        return f"<Document id={self.id} name={self.display_name} uuid={self.uuid}>"
|
|
|
|
class ProjectRead(SQLModel):
    """Non-table schema for a project with its organization and documents."""

    id: int
    uuid: uuid_pkg.UUID
    organization: Organization
    document_count: int
    documents: Optional[List[Document]] = None
    display_name: str
    created_at: datetime
    updated_at: datetime
|
|
|
|
class DocumentCreate(SQLModel):
    """Non-table schema with a ``Project`` plus optional document attributes.

    NOTE(review): presumably the payload for creating a Document; usage
    is not visible in this file.
    """

    project: Project
    display_name: Optional[str]
    url: Optional[str]
    # NOTE(review): typed as ``str`` here while ``Document.version`` is
    # ``Optional[int]`` — confirm which type is intended.
    version: Optional[str]
    data: Optional[bytes]
    hash: Optional[str]
|
|
|
|
class DocumentUpdate(SQLModel):
    """Non-table schema carrying only a document's ``status``."""

    status: Optional[ENTITY_STATUS]
|
|
|
|
| |
| |
| |
class Node(BaseModel, table=True):
    """Table model for a text node of a document, with an embedding vector."""

    class Config:
        arbitrary_types_allowed = True

    id: Optional[int] = Field(default=None, primary_key=True)
    document_id: int = Field(default=None, foreign_key="document.id")
    uuid: Optional[uuid_pkg.UUID] = Field(unique=True, default_factory=uuid_pkg.uuid4)
    # Stored in a pgvector column of VECTOR_EMBEDDINGS_COUNT dimensions.
    embeddings: Optional[List[float]] = Field(
        sa_column=Column(Vector(VECTOR_EMBEDDINGS_COUNT))
    )
    # Stored in a JSONB column.
    meta: Optional[Dict] = Field(default=None, sa_column=Column(JSONB))
    token_count: Optional[int] = Field(default=None)
    text: str = Field(default=None, nullable=False)
    status: Optional[ENTITY_STATUS] = Field(default=ENTITY_STATUS.ACTIVE.value)
    created_at: Optional[datetime] = Field(default_factory=datetime.now)
    updated_at: Optional[datetime] = Field(default_factory=datetime.now)

    # Relationship (the other side is Document.nodes).
    document: Optional["Document"] = Relationship(back_populates="nodes")

    def __repr__(self):
        return f"<Node id={self.id} uuid={self.uuid} document={self.document_id}>"
|
|
|
|
class NodeCreate(SQLModel):
    """Non-table schema with a ``Document``, embeddings, text and optional extras.

    NOTE(review): presumably the payload for creating a Node; usage is
    not visible in this file.
    """

    document: Document
    embeddings: List[float]
    token_count: Optional[int]
    text: str
    status: Optional[ENTITY_STATUS]
|
|
|
|
class NodeRead(SQLModel):
    """Non-table schema for a node together with its full ``Document``."""

    id: int
    document: Document
    embeddings: Optional[List[float]]
    token_count: Optional[int]
    text: str
    created_at: datetime
|
|
|
|
class DocumentReadNodeList(SQLModel):
    """Non-table schema with a document's id, uuid, display name and node count."""

    id: int
    uuid: uuid_pkg.UUID
    display_name: str
    node_count: int
|
|
|
|
class NodeReadResult(SQLModel):
    """Non-table schema with a node's id, token count, text and meta dict."""

    id: int
    token_count: Optional[int]
    text: str
    meta: Optional[Dict]
|
|
|
|
class ProjectReadListDocumentList(SQLModel):
    """Non-table schema with a uuid, display name and optional node count."""

    uuid: uuid_pkg.UUID
    display_name: str
    node_count: Optional[int]
|
|
|
|
class ProjectReadList(SQLModel):
    """Non-table schema for a project with summarized documents."""

    id: int
    # Documents are summarized as DocumentReadNodeList entries.
    documents: Optional[List[DocumentReadNodeList]]
    document_count: int
    uuid: uuid_pkg.UUID
    display_name: str
    created_at: datetime
    updated_at: datetime
|
|
|
|
class NodeReadList(SQLModel):
    """Non-table schema for a node with a summarized document."""

    id: int
    document: DocumentReadNodeList
    embeddings: Optional[List[float]]
    token_count: Optional[int]
    text: str
    created_at: datetime
|
|
|
|
class NodeUpdate(SQLModel):
    """Non-table schema carrying a node's ``status``.

    The default is ``ENTITY_STATUS.ACTIVE.value``.
    """

    status: Optional[ENTITY_STATUS] = Field(default=ENTITY_STATUS.ACTIVE.value)
|
|
|
|
class NodeReadListDocumentRead(SQLModel):
    """Non-table schema with a node's uuid, token count and creation time."""

    uuid: uuid_pkg.UUID
    token_count: Optional[int]
    created_at: datetime
|
|
|
|
class DocumentReadList(SQLModel):
    """Non-table schema for a document with summarized nodes."""

    id: int
    uuid: uuid_pkg.UUID
    display_name: str
    version: int
    # Nodes are summarized as NodeReadListDocumentRead entries.
    nodes: Optional[List[NodeReadListDocumentRead]] = None
    node_count: int
    hash: str
    created_at: datetime
    updated_at: datetime
|
|
|
|
| |
| |
| |
class ChatSession(BaseModel, table=True):
    """Table model for one chat exchange: a user message and its response."""

    class Config:
        arbitrary_types_allowed = True

    id: Optional[int] = Field(default=None, primary_key=True)
    session_id: Optional[uuid_pkg.UUID] = Field(
        index=True, default_factory=uuid_pkg.uuid4
    )
    user_id: int = Field(default=None, foreign_key="user.id")
    project_id: int = Field(default=None, foreign_key="project.id")
    channel: CHANNEL_TYPE = Field(default=CHANNEL_TYPE.TELEGRAM)
    user_message: str = Field(default=None)
    token_count: Optional[int] = Field(default=None)
    # Stored in a pgvector column of VECTOR_EMBEDDINGS_COUNT dimensions.
    embeddings: Optional[List[float]] = Field(
        sa_column=Column(Vector(VECTOR_EMBEDDINGS_COUNT))
    )
    response: Optional[str] = Field(default=None)
    # Stored in a JSONB column.
    meta: Optional[Dict] = Field(default=None, sa_column=Column(JSONB))
    created_at: datetime = Field(default_factory=datetime.now)

    # Relationships (the other sides are User.chat_sessions and
    # Project.chat_sessions).
    user: Optional["User"] = Relationship(back_populates="chat_sessions")
    project: Optional["Project"] = Relationship(back_populates="chat_sessions")

    def __repr__(self):
        # This model has no ``uuid`` field (its identifier is ``session_id``);
        # the previous repr read ``self.uuid`` and raised AttributeError.
        return f"<ChatSession id={self.id} session_id={self.session_id} project_id={self.project_id} user_id={self.user_id} message={self.user_message}>"
|
|
|
|
class ChatSessionCreatePost(SQLModel):
    """Non-table schema for a chat request; every field has a default value.

    NOTE(review): several defaults look like example/placeholder values
    (e.g. the sample query and "@username") — confirm they are intended.
    """

    # Target and sender.
    project_id: Optional[str] = ""
    organization_id: Optional[str] = "pepe"
    channel: Optional[CHANNEL_TYPE] = CHANNEL_TYPE.TELEGRAM
    query: Optional[str] = "What is the weather like in London right now?"
    identifier: Optional[str] = "@username"

    # LLM / retrieval settings, defaulting to values from config.
    distance_strategy: Optional[str] = LLM_DEFAULT_DISTANCE_STRATEGY
    max_output_tokens: Optional[int] = LLM_MAX_OUTPUT_TOKENS
    node_limit: Optional[int] = LLM_MIN_NODE_LIMIT
    model: Optional[str] = LLM_MODELS.GPT_35_TURBO

    session_id: Optional[str] = ""
|
|
|
|
class ChatSessionCreate(SQLModel):
    """Non-table schema with a chat exchange's channel, message, embeddings and response.

    NOTE(review): presumably the payload for creating a ChatSession;
    usage is not visible in this file.
    """

    channel: CHANNEL_TYPE
    token_count: Optional[int]
    user_message: str
    embeddings: List[float]
    response: Optional[str]
|
|
|
|
class ChatSessionRead(SQLModel):
    """Non-table schema for a chat exchange with its user and a summarized project."""

    id: int
    user: User
    project: Optional[ProjectReadListDocumentList]
    token_count: Optional[int]
    channel: CHANNEL_TYPE
    user_message: str
    embeddings: List[float]
    response: Optional[str]
    meta: Optional[dict]
    created_at: datetime = Field(default_factory=datetime.now)
|
|
|
|
class ChatSessionResponse(SQLModel):
    """Non-table schema with a chat exchange's meta, response and user message."""

    meta: Optional[dict]
    response: Optional[str]
    user_message: Optional[str]
|
|
|
|
class ProjectReadChatSessionRead(SQLModel):
    """Non-table schema with a chat exchange's id, token count, channel and creation time."""

    id: int
    token_count: Optional[int]
    channel: CHANNEL_TYPE
    created_at: datetime = Field(default_factory=datetime.now)
|
|
|
|
class ChatSessionReadUserRead(SQLModel):
    """Non-table schema for a chat exchange without its user or embeddings."""

    id: int
    project: Optional[ProjectReadListDocumentList]
    token_count: Optional[int]
    channel: CHANNEL_TYPE
    user_message: str
    response: Optional[str]
    created_at: datetime = Field(default_factory=datetime.now)
|
|
|
|
class UserRead(SQLModel):
    """Non-table schema for a user with summarized chat sessions."""

    id: int
    identifier: Optional[str]
    identifier_type: Optional[CHANNEL_TYPE]
    uuid: uuid_pkg.UUID
    # NOTE(review): the ``User`` table model in this file declares no
    # ``language`` field — confirm where this value comes from.
    language: Optional[str]
    device_fingerprint: Optional[str]
    first_name: Optional[str]
    last_name: Optional[str]
    email: Optional[str]
    phone: Optional[str]
    dob: Optional[datetime]
    chat_session_count: int
    chat_sessions: Optional[List[ChatSessionReadUserRead]]
    created_at: datetime
    updated_at: datetime
|
|
|
|
class DocumentReadProjectRead(SQLModel):
    """Non-table schema with a uuid, display name, namespace and document count."""

    uuid: uuid_pkg.UUID
    display_name: str
    namespace: Optional[str]
    document_count: int
|
|
|
|
class DocumentRead(SQLModel):
    """Non-table schema for a document with its project summary and organization."""

    id: int
    uuid: uuid_pkg.UUID
    project: DocumentReadProjectRead
    organization: OrganizationRead
    display_name: str
    node_count: int
    url: Optional[str]
    version: int
    data: bytes
    hash: str
    created_at: datetime
    updated_at: datetime
|
|
|
|
class WebhookCreate(SQLModel):
    """Non-table schema with an ``update_id`` string and a free-form ``message`` dict."""

    update_id: str
    message: Dict[str, Any]
|
|
|
|
class WebhookResponse(SQLModel):
    """Non-table schema with an ``update_id`` string and a free-form ``message`` dict."""

    update_id: str
    message: Dict[str, Any]
|
|
|
|
| |
| |
| |
# Engines already created by get_engine(), keyed by DSN.
_ENGINES: Dict[str, Any] = {}


def get_engine(dsn: str = SU_DSN):
    """Return the engine for ``dsn``, creating it on first use.

    An engine owns a connection pool, and the helpers in this module call
    this function for every operation. Building a new engine per call
    discards the pool each time, so one engine is kept per DSN and reused.

    Args:
        dsn: Database connection string; defaults to ``SU_DSN``.

    Returns:
        The engine created by ``create_engine(dsn)``; repeated calls with
        the same ``dsn`` return the same object.
    """
    engine = _ENGINES.get(dsn)
    if engine is None:
        engine = create_engine(dsn)
        _ENGINES[dsn] = engine
    return engine
|
|
|
|
def get_session():
    """Yield a ``Session`` on the default engine.

    This is a generator; the session's ``with`` block exits when the
    generator is resumed after the yield or is closed.
    """
    with Session(get_engine()) as db_session:
        yield db_session
|
|
|
|
def create_db():
    """Set up the database: extension, tables, grants, then vector indexes.

    The steps run in this order: ``enable_vector()``, ``create_all`` on the
    models' metadata, ``create_user_permissions()``, ``create_vector_index()``.
    """
    logger.info("...Enabling pgvector and creating database tables")
    enable_vector()

    su_engine = get_engine(dsn=SU_DSN)
    BaseModel.metadata.create_all(su_engine)

    create_user_permissions()
    create_vector_index()
|
|
|
|
def create_user_permissions():
    """Grant ``DB_USER`` SELECT/INSERT/UPDATE/DELETE on all tables in schema ``public``.

    Runs on the ``SU_DSN`` engine. The session is opened as a context
    manager so it is closed even if the statement or the commit raises;
    previously an exception skipped the ``close()`` call.
    """
    query = f"GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA public TO {DB_USER};"
    with Session(get_engine(dsn=SU_DSN)) as session:
        session.execute(query)
        session.commit()
|
|
|
|
def drop_db():
    """Drop every table registered on the models' metadata, using the ``SU_DSN`` engine."""
    su_engine = get_engine(dsn=SU_DSN)
    BaseModel.metadata.drop_all(su_engine)
|
|
|
|
def create_vector_index():
    """Run the index statement of each distance strategy, if enabled.

    Does nothing unless ``PGVECTOR_ADD_INDEX`` is exactly ``True``. For each
    entry of ``DISTANCE_STRATEGIES`` the statement at index 3 is executed
    and committed on the ``SU_DSN`` engine.

    The session is opened as a context manager so it is always closed;
    previously it was never closed.
    """
    # Identity check kept on purpose: only the literal ``True`` enables this.
    if PGVECTOR_ADD_INDEX is not True:
        return

    with Session(get_engine(dsn=SU_DSN)) as session:
        for strategy in DISTANCE_STRATEGIES:
            session.execute(strategy[3])
            session.commit()
|
|
|
|
def enable_vector():
    """Create the ``vector`` extension if missing, then install the match functions.

    Runs on the ``SU_DSN`` engine. The session is opened as a context
    manager so it is closed even if a statement raises; previously an
    exception skipped the ``close()`` call.
    """
    with Session(get_engine(dsn=SU_DSN)) as session:
        session.execute("CREATE EXTENSION IF NOT EXISTS vector;")
        session.commit()
        add_vector_distance_fn(session)
|
|
|
|
def add_vector_distance_fn(session: Session):
    """Create or replace one ``match_node_<name>`` SQL function per distance strategy.

    For each entry of ``DISTANCE_STRATEGIES``, index 1 supplies the function
    name suffix and index 2 the distance operator spliced into the SQL.
    Each generated function takes a query vector, a threshold and a row
    limit, and returns ``uuid``, ``text`` and ``similarity`` for rows of
    ``node`` whose similarity (``1 - distance``) exceeds the threshold,
    ordered by similarity descending.

    Args:
        session: Session used to execute and commit each statement. It is
            also closed here.
    """
    for strategy in DISTANCE_STRATEGIES:
        fn_suffix, distance_op = strategy[1], strategy[2]

        create_fn_sql = f"""create or replace function match_node_{fn_suffix} (
query_embeddings vector({VECTOR_EMBEDDINGS_COUNT}),
match_threshold float,
match_count int
) returns table (
uuid uuid,
text varchar,
similarity float
)
language plpgsql
as $$
begin
return query
select
node.uuid,
node.text,
1 - (node.embeddings {distance_op} query_embeddings) as similarity
from node
where 1 - (node.embeddings {distance_op} query_embeddings) > match_threshold
order by similarity desc
limit match_count;
end;
$$;"""

        session.execute(create_fn_sql)
        session.commit()
        session.close()
|
|
|
|
# Running this module as a script sets up the database.
if __name__ == "__main__":
    create_db()
|
|