Spaces:
Runtime error
Runtime error
File size: 2,082 Bytes
129cd69 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 | from typing import Optional, Tuple
import sqlalchemy
from pgvector.sqlalchemy import Vector
from sqlalchemy.dialects.postgresql import JSON, UUID
from sqlalchemy.orm import Session, relationship
from langchain.vectorstores.pgvector import BaseModel
class CollectionStore(BaseModel):
"""Collection store."""
__tablename__ = "langchain_pg_collection"
name = sqlalchemy.Column(sqlalchemy.String)
cmetadata = sqlalchemy.Column(JSON)
embeddings = relationship(
"EmbeddingStore",
back_populates="collection",
passive_deletes=True,
)
@classmethod
def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
return session.query(cls).filter(cls.name == name).first() # type: ignore
@classmethod
def get_or_create(
cls,
session: Session,
name: str,
cmetadata: Optional[dict] = None,
) -> Tuple["CollectionStore", bool]:
"""
Get or create a collection.
Returns [Collection, bool] where the bool is True if the collection was created.
"""
created = False
collection = cls.get_by_name(session, name)
if collection:
return collection, created
collection = cls(name=name, cmetadata=cmetadata)
session.add(collection)
session.commit()
created = True
return collection, created
class EmbeddingStore(BaseModel):
"""Embedding store."""
__tablename__ = "langchain_pg_embedding"
collection_id = sqlalchemy.Column(
UUID(as_uuid=True),
sqlalchemy.ForeignKey(
f"{CollectionStore.__tablename__}.uuid",
ondelete="CASCADE",
),
)
collection = relationship(CollectionStore, back_populates="embeddings")
embedding: Vector = sqlalchemy.Column(Vector(None))
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
cmetadata = sqlalchemy.Column(JSON, nullable=True)
# custom_id : any user defined id
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|