import base64
import enum
import hashlib
import hmac
import json
import logging
import os
import pickle
import re
import time
from json import JSONDecodeError

from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB

from configs import dify_config
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from extensions.ext_storage import storage

from .account import Account
from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID


class DatasetPermissionEnum(str, enum.Enum):
    ONLY_ME = "only_me"
    ALL_TEAM = "all_team_members"
    PARTIAL_TEAM = "partial_members"


class Dataset(db.Model):
    __tablename__ = "datasets"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_pkey"),
        db.Index("dataset_tenant_idx", "tenant_id"),
        db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
    )

    INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
    PROVIDER_LIST = ["vendor", "external", None]

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=True)
    provider = db.Column(db.String(255), nullable=False, server_default=db.text("'vendor'::character varying"))
    permission = db.Column(db.String(255), nullable=False, server_default=db.text("'only_me'::character varying"))
    data_source_type = db.Column(db.String(255))
    indexing_technique = db.Column(db.String(255), nullable=True)
    index_struct = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    embedding_model = db.Column(db.String(255), nullable=True)
    embedding_model_provider = db.Column(db.String(255), nullable=True)
    collection_binding_id = db.Column(StringUUID, nullable=True)
    retrieval_model = db.Column(JSONB, nullable=True)

    @property
    def dataset_keyword_table(self):
        # None when no keyword table has been built for this dataset yet.
        return db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == self.id).first()

    @property
    def index_struct_dict(self):
        return json.loads(self.index_struct) if self.index_struct else None

    @property
    def external_retrieval_model(self):
        default_retrieval_model = {
            "top_k": 2,
            "score_threshold": 0.0,
        }
        return self.retrieval_model or default_retrieval_model

    @property
    def created_by_account(self):
        return db.session.get(Account, self.created_by)

    @property
    def latest_process_rule(self):
        return (
            DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id)
            .order_by(DatasetProcessRule.created_at.desc())
            .first()
        )

    @property
    def app_count(self):
        return (
            db.session.query(func.count(AppDatasetJoin.id))
            .filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
            .scalar()
        )

    @property
    def document_count(self):
        return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()

    @property
    def available_document_count(self):
        return (
            db.session.query(func.count(Document.id))
            .filter(
                Document.dataset_id == self.id,
                Document.indexing_status == "completed",
                Document.enabled == True,
                Document.archived == False,
            )
            .scalar()
        )

    @property
    def available_segment_count(self):
        return (
            db.session.query(func.count(DocumentSegment.id))
            .filter(
                DocumentSegment.dataset_id == self.id,
                DocumentSegment.status == "completed",
                DocumentSegment.enabled == True,
            )
            .scalar()
        )

    @property
    def word_count(self):
        # COALESCE to 0 so an empty dataset reports a zero word count instead of NULL.
        return (
            Document.query.with_entities(func.coalesce(func.sum(Document.word_count), 0))
            .filter(Document.dataset_id == self.id)
            .scalar()
        )

    @property
    def doc_form(self):
        document = db.session.query(Document).filter(Document.dataset_id == self.id).first()
        if document:
            return document.doc_form
        return None

    @property
    def retrieval_model_dict(self):
        default_retrieval_model = {
            "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
            "reranking_enable": False,
            "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
            "top_k": 2,
            "score_threshold_enabled": False,
        }
        return self.retrieval_model or default_retrieval_model
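
    # A minimal sketch of the fallback above, assuming a dataset whose
    # retrieval_model column is still NULL (and assuming
    # RetrievalMethod.SEMANTIC_SEARCH.value == "semantic_search"):
    #
    #     dataset = db.session.get(Dataset, dataset_id)
    #     dataset.retrieval_model_dict["search_method"]  # -> "semantic_search"
    #     dataset.retrieval_model_dict["top_k"]          # -> 2
    #
    # Once retrieval_model is populated, the stored JSONB value is returned
    # as-is; the defaults are not merged key by key.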

    @property
    def tags(self):
        tags = (
            db.session.query(Tag)
            .join(TagBinding, Tag.id == TagBinding.tag_id)
            .filter(
                TagBinding.target_id == self.id,
                TagBinding.tenant_id == self.tenant_id,
                Tag.tenant_id == self.tenant_id,
                Tag.type == "knowledge",
            )
            .all()
        )
        return tags or []

    @property
    def external_knowledge_info(self):
        if self.provider != "external":
            return None
        external_knowledge_binding = (
            db.session.query(ExternalKnowledgeBindings).filter(ExternalKnowledgeBindings.dataset_id == self.id).first()
        )
        if not external_knowledge_binding:
            return None
        external_knowledge_api = (
            db.session.query(ExternalKnowledgeApis)
            .filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
            .first()
        )
        if not external_knowledge_api:
            return None
        return {
            "external_knowledge_id": external_knowledge_binding.external_knowledge_id,
            "external_knowledge_api_id": external_knowledge_api.id,
            "external_knowledge_api_name": external_knowledge_api.name,
            "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
        }

    @staticmethod
    def gen_collection_name_by_id(dataset_id: str) -> str:
        normalized_dataset_id = dataset_id.replace("-", "_")
        return f"Vector_index_{normalized_dataset_id}_Node"


class DatasetProcessRule(db.Model):
    __tablename__ = "dataset_process_rules"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
        db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False)
    mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
    rules = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))

    MODES = ["automatic", "custom"]
    PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
    AUTOMATIC_RULES = {
        "pre_processing_rules": [
            {"id": "remove_extra_spaces", "enabled": True},
            {"id": "remove_urls_emails", "enabled": False},
        ],
        "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
    }

    def to_dict(self):
        return {
            "id": self.id,
            "dataset_id": self.dataset_id,
            "mode": self.mode,
            "rules": self.rules_dict,
            "created_by": self.created_by,
            "created_at": self.created_at,
        }

    @property
    def rules_dict(self):
        try:
            return json.loads(self.rules) if self.rules else None
        except JSONDecodeError:
            return None


class Document(db.Model):
    __tablename__ = "documents"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="document_pkey"),
        db.Index("document_dataset_id_idx", "dataset_id"),
        db.Index("document_is_paused_idx", "is_paused"),
        db.Index("document_tenant_idx", "tenant_id"),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    data_source_type = db.Column(db.String(255), nullable=False)
    data_source_info = db.Column(db.Text, nullable=True)
    dataset_process_rule_id = db.Column(StringUUID, nullable=True)
    batch = db.Column(db.String(255), nullable=False)
    name = db.Column(db.String(255), nullable=False)
    created_from = db.Column(db.String(255), nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_api_request_id = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))

    # start processing
    processing_started_at = db.Column(db.DateTime, nullable=True)

    # parsing
    file_id = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=True)
    parsing_completed_at = db.Column(db.DateTime, nullable=True)

    # cleaning
    cleaning_completed_at = db.Column(db.DateTime, nullable=True)

    # split
    splitting_completed_at = db.Column(db.DateTime, nullable=True)

    # indexing
    tokens = db.Column(db.Integer, nullable=True)
    indexing_latency = db.Column(db.Float, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)

    # pause
    is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
    paused_by = db.Column(StringUUID, nullable=True)
    paused_at = db.Column(db.DateTime, nullable=True)

    # error
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    # basic fields
    indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    archived = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    archived_reason = db.Column(db.String(255), nullable=True)
    archived_by = db.Column(StringUUID, nullable=True)
    archived_at = db.Column(db.DateTime, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    doc_type = db.Column(db.String(40), nullable=True)
    doc_metadata = db.Column(db.JSON, nullable=True)
    doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
    doc_language = db.Column(db.String(255), nullable=True)

    DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]

    @property
    def display_status(self):
        status = None
        if self.indexing_status == "waiting":
            status = "queuing"
        elif self.indexing_status not in {"completed", "error", "waiting"} and self.is_paused:
            status = "paused"
        elif self.indexing_status in {"parsing", "cleaning", "splitting", "indexing"}:
            status = "indexing"
        elif self.indexing_status == "error":
            status = "error"
        elif self.indexing_status == "completed" and not self.archived and self.enabled:
            status = "available"
        elif self.indexing_status == "completed" and not self.archived and not self.enabled:
            status = "disabled"
        elif self.indexing_status == "completed" and self.archived:
            status = "archived"
        return status
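
    # Quick reference for the mapping above (indexing_status + flags -> display_status):
    #
    #     waiting                                     -> queuing
    #     in progress + is_paused                     -> paused
    #     parsing / cleaning / splitting / indexing   -> indexing
    #     error                                       -> error
    #     completed, not archived, enabled            -> available
    #     completed, not archived, not enabled        -> disabled
    #     completed, archived                         -> archived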

    @property
    def data_source_info_dict(self):
        if self.data_source_info:
            try:
                data_source_info_dict = json.loads(self.data_source_info)
            except JSONDecodeError:
                data_source_info_dict = {}
            return data_source_info_dict
        return None

    @property
    def data_source_detail_dict(self):
        if self.data_source_info:
            if self.data_source_type == "upload_file":
                data_source_info_dict = json.loads(self.data_source_info)
                file_detail = (
                    db.session.query(UploadFile)
                    .filter(UploadFile.id == data_source_info_dict["upload_file_id"])
                    .one_or_none()
                )
                if file_detail:
                    return {
                        "upload_file": {
                            "id": file_detail.id,
                            "name": file_detail.name,
                            "size": file_detail.size,
                            "extension": file_detail.extension,
                            "mime_type": file_detail.mime_type,
                            "created_by": file_detail.created_by,
                            "created_at": file_detail.created_at.timestamp(),
                        }
                    }
            elif self.data_source_type in {"notion_import", "website_crawl"}:
                return json.loads(self.data_source_info)
        return {}

    @property
    def average_segment_length(self):
        if self.word_count and self.segment_count:
            return self.word_count // self.segment_count
        return 0

    @property
    def dataset_process_rule(self):
        if self.dataset_process_rule_id:
            return db.session.get(DatasetProcessRule, self.dataset_process_rule_id)
        return None

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()

    @property
    def segment_count(self):
        return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()

    @property
    def hit_count(self):
        # COALESCE to 0 so a document with no segments reports zero hits instead of NULL.
        return (
            DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
            .filter(DocumentSegment.document_id == self.id)
            .scalar()
        )

    def to_dict(self):
        return {
            "id": self.id,
            "tenant_id": self.tenant_id,
            "dataset_id": self.dataset_id,
            "position": self.position,
            "data_source_type": self.data_source_type,
            "data_source_info": self.data_source_info,
            "dataset_process_rule_id": self.dataset_process_rule_id,
            "batch": self.batch,
            "name": self.name,
            "created_from": self.created_from,
            "created_by": self.created_by,
            "created_api_request_id": self.created_api_request_id,
            "created_at": self.created_at,
            "processing_started_at": self.processing_started_at,
            "file_id": self.file_id,
            "word_count": self.word_count,
            "parsing_completed_at": self.parsing_completed_at,
            "cleaning_completed_at": self.cleaning_completed_at,
            "splitting_completed_at": self.splitting_completed_at,
            "tokens": self.tokens,
            "indexing_latency": self.indexing_latency,
            "completed_at": self.completed_at,
            "is_paused": self.is_paused,
            "paused_by": self.paused_by,
            "paused_at": self.paused_at,
            "error": self.error,
            "stopped_at": self.stopped_at,
            "indexing_status": self.indexing_status,
            "enabled": self.enabled,
            "disabled_at": self.disabled_at,
            "disabled_by": self.disabled_by,
            "archived": self.archived,
            "archived_reason": self.archived_reason,
            "archived_by": self.archived_by,
            "archived_at": self.archived_at,
            "updated_at": self.updated_at,
            "doc_type": self.doc_type,
            "doc_metadata": self.doc_metadata,
            "doc_form": self.doc_form,
            "doc_language": self.doc_language,
            "display_status": self.display_status,
            "data_source_info_dict": self.data_source_info_dict,
            "average_segment_length": self.average_segment_length,
            "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
            "dataset": self.dataset.to_dict() if self.dataset else None,
            "segment_count": self.segment_count,
            "hit_count": self.hit_count,
        }

    @classmethod
    def from_dict(cls, data: dict):
        return cls(
            id=data.get("id"),
            tenant_id=data.get("tenant_id"),
            dataset_id=data.get("dataset_id"),
            position=data.get("position"),
            data_source_type=data.get("data_source_type"),
            data_source_info=data.get("data_source_info"),
            dataset_process_rule_id=data.get("dataset_process_rule_id"),
            batch=data.get("batch"),
            name=data.get("name"),
            created_from=data.get("created_from"),
            created_by=data.get("created_by"),
            created_api_request_id=data.get("created_api_request_id"),
            created_at=data.get("created_at"),
            processing_started_at=data.get("processing_started_at"),
            file_id=data.get("file_id"),
            word_count=data.get("word_count"),
            parsing_completed_at=data.get("parsing_completed_at"),
            cleaning_completed_at=data.get("cleaning_completed_at"),
            splitting_completed_at=data.get("splitting_completed_at"),
            tokens=data.get("tokens"),
            indexing_latency=data.get("indexing_latency"),
            completed_at=data.get("completed_at"),
            is_paused=data.get("is_paused"),
            paused_by=data.get("paused_by"),
            paused_at=data.get("paused_at"),
            error=data.get("error"),
            stopped_at=data.get("stopped_at"),
            indexing_status=data.get("indexing_status"),
            enabled=data.get("enabled"),
            disabled_at=data.get("disabled_at"),
            disabled_by=data.get("disabled_by"),
            archived=data.get("archived"),
            archived_reason=data.get("archived_reason"),
            archived_by=data.get("archived_by"),
            archived_at=data.get("archived_at"),
            updated_at=data.get("updated_at"),
            doc_type=data.get("doc_type"),
            doc_metadata=data.get("doc_metadata"),
            doc_form=data.get("doc_form"),
            doc_language=data.get("doc_language"),
        )
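
    # Usage sketch (key values are illustrative): from_dict() reads only the
    # column fields via data.get(), so extra keys such as "display_status" or
    # "segment_count" emitted by to_dict() are silently ignored:
    #
    #     doc = Document.from_dict({"dataset_id": dataset.id, "name": "intro.md", "position": 1})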


class DocumentSegment(db.Model):
    __tablename__ = "document_segments"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
        db.Index("document_segment_dataset_id_idx", "dataset_id"),
        db.Index("document_segment_document_id_idx", "document_id"),
        db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"),
        db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"),
        db.Index("document_segment_dataset_node_idx", "dataset_id", "index_node_id"),
        db.Index("document_segment_tenant_idx", "tenant_id"),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    content = db.Column(db.Text, nullable=False)
    answer = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=False)
    tokens = db.Column(db.Integer, nullable=False)

    # indexing fields
    keywords = db.Column(db.JSON, nullable=True)
    index_node_id = db.Column(db.String(255), nullable=True)
    index_node_hash = db.Column(db.String(255), nullable=True)

    # basic fields
    hit_count = db.Column(db.Integer, nullable=False, default=0)
    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    indexing_at = db.Column(db.DateTime, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()

    @property
    def document(self):
        return db.session.query(Document).filter(Document.id == self.document_id).first()

    @property
    def previous_segment(self):
        return (
            db.session.query(DocumentSegment)
            .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1)
            .first()
        )

    @property
    def next_segment(self):
        return (
            db.session.query(DocumentSegment)
            .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1)
            .first()
        )

    def get_sign_content(self):
        signed_urls = []
        text = self.content

        # Sign both the pre-v0.10.0 image-preview URLs and the post-v0.10.0
        # file-preview URLs embedded in the segment content.
        for prefix, pattern in (
            ("image-preview", r"/files/([a-f0-9\-]+)/image-preview"),  # data before v0.10.0
            ("file-preview", r"/files/([a-f0-9\-]+)/file-preview"),  # data after v0.10.0
        ):
            for match in re.finditer(pattern, text):
                upload_file_id = match.group(1)
                nonce = os.urandom(16).hex()
                timestamp = str(int(time.time()))
                data_to_sign = f"{prefix}|{upload_file_id}|{timestamp}|{nonce}"
                secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
                sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
                encoded_sign = base64.urlsafe_b64encode(sign).decode()
                params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
                signed_urls.append((match.start(), match.end(), f"{match.group(0)}?{params}"))

        # Reconstruct the text with signed URLs. Substitutions must be applied
        # in document order for the running offset to stay correct, so sort by
        # match start in case both URL styles appear interleaved.
        offset = 0
        for start, end, signed_url in sorted(signed_urls):
            text = text[: start + offset] + signed_url + text[end + offset :]
            offset += len(signed_url) - (end - start)
        return text
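
    # Verification sketch (hypothetical helper; the real check lives in the
    # file preview endpoint, not in this model). The server recomputes the HMAC
    # over the same four fields and compares it to the "sign" query parameter:
    #
    #     def is_valid_sign(prefix, upload_file_id, timestamp, nonce, sign):
    #         data = f"{prefix}|{upload_file_id}|{timestamp}|{nonce}"
    #         secret = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
    #         expected = base64.urlsafe_b64encode(
    #             hmac.new(secret, data.encode(), hashlib.sha256).digest()
    #         ).decode()
    #         return hmac.compare_digest(expected, sign)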


class AppDatasetJoin(db.Model):
    __tablename__ = "app_dataset_joins"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
        db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
    app_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())

    @property
    def app(self):
        return db.session.get(App, self.app_id)


class DatasetQuery(db.Model):
    __tablename__ = "dataset_queries"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
        db.Index("dataset_query_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False)
    content = db.Column(db.Text, nullable=False)
    source = db.Column(db.String(255), nullable=False)
    source_app_id = db.Column(StringUUID, nullable=True)
    created_by_role = db.Column(db.String, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())


class DatasetKeywordTable(db.Model):
    __tablename__ = "dataset_keyword_tables"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
        db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False, unique=True)
    keyword_table = db.Column(db.Text, nullable=False)
    data_source_type = db.Column(
        db.String(255), nullable=False, server_default=db.text("'database'::character varying")
    )

    @property
    def keyword_table_dict(self):
        class SetDecoder(json.JSONDecoder):
            def __init__(self, *args, **kwargs):
                super().__init__(object_hook=self.object_hook, *args, **kwargs)

            def object_hook(self, dct):
                if isinstance(dct, dict):
                    for keyword, node_idxs in dct.items():
                        if isinstance(node_idxs, list):
                            dct[keyword] = set(node_idxs)
                return dct

        # get dataset
        dataset = Dataset.query.filter_by(id=self.dataset_id).first()
        if not dataset:
            return None
        if self.data_source_type == "database":
            return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None
        else:
            file_key = "keyword_files/" + dataset.tenant_id + "/" + self.dataset_id + ".txt"
            try:
                keyword_table_text = storage.load_once(file_key)
                if keyword_table_text:
                    return json.loads(keyword_table_text.decode("utf-8"), cls=SetDecoder)
                return None
            except Exception:
                logging.exception("Failed to load keyword table from storage: %s", file_key)
                return None
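
    # Decoding sketch: JSON has no set type, so keyword tables are persisted
    # with list values and converted back to sets on load. Assuming a stored
    # payload of '{"foo": ["node-1", "node-2"]}':
    #
    #     json.loads('{"foo": ["node-1", "node-2"]}', cls=SetDecoder)
    #     # -> {"foo": {"node-1", "node-2"}}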


class Embedding(db.Model):
    __tablename__ = "embeddings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="embedding_pkey"),
        db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"),
        db.Index("created_at_idx", "created_at"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    model_name = db.Column(
        db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying")
    )
    hash = db.Column(db.String(64), nullable=False)
    embedding = db.Column(db.LargeBinary, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying"))

    def set_embedding(self, embedding_data: list[float]):
        self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)

    def get_embedding(self) -> list[float]:
        # Only unpickle values this application wrote itself; pickle is not
        # safe for untrusted input.
        return pickle.loads(self.embedding)
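
    # Round-trip sketch (the hash, provider name, and vector values are
    # illustrative only):
    #
    #     emb = Embedding(hash=content_hash, provider_name="openai")
    #     emb.set_embedding([0.12, -0.05, 0.33])
    #     emb.get_embedding()  # -> [0.12, -0.05, 0.33]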


class DatasetCollectionBinding(db.Model):
    __tablename__ = "dataset_collection_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
        db.Index("provider_model_name_idx", "provider_name", "model_name"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    provider_name = db.Column(db.String(40), nullable=False)
    model_name = db.Column(db.String(255), nullable=False)
    type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
    collection_name = db.Column(db.String(64), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class TidbAuthBinding(db.Model):
    __tablename__ = "tidb_auth_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
        db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),
        db.Index("tidb_auth_bindings_active_idx", "active"),
        db.Index("tidb_auth_bindings_created_at_idx", "created_at"),
        db.Index("tidb_auth_bindings_status_idx", "status"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=True)
    cluster_id = db.Column(db.String(255), nullable=False)
    cluster_name = db.Column(db.String(255), nullable=False)
    active = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    # The default must be a quoted string literal; an unquoted CREATING would
    # be parsed by PostgreSQL as a column reference, not a value.
    status = db.Column(db.String(255), nullable=False, server_default=db.text("'CREATING'::character varying"))
    account = db.Column(db.String(255), nullable=False)
    password = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class Whitelist(db.Model):
    __tablename__ = "whitelists"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
        db.Index("whitelists_tenant_idx", "tenant_id"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=True)
    category = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class DatasetPermission(db.Model):
    __tablename__ = "dataset_permissions"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
        db.Index("idx_dataset_permissions_dataset_id", "dataset_id"),
        db.Index("idx_dataset_permissions_account_id", "account_id"),
        db.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True)
    dataset_id = db.Column(StringUUID, nullable=False)
    account_id = db.Column(StringUUID, nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class ExternalKnowledgeApis(db.Model):
    __tablename__ = "external_knowledge_apis"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
        db.Index("external_knowledge_apis_tenant_idx", "tenant_id"),
        db.Index("external_knowledge_apis_name_idx", "name"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.String(255), nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    settings = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))

    def to_dict(self):
        return {
            "id": self.id,
            "tenant_id": self.tenant_id,
            "name": self.name,
            "description": self.description,
            "settings": self.settings_dict,
            "dataset_bindings": self.dataset_bindings,
            "created_by": self.created_by,
            "created_at": self.created_at.isoformat(),
        }

    @property
    def settings_dict(self):
        try:
            return json.loads(self.settings) if self.settings else None
        except JSONDecodeError:
            return None

    @property
    def dataset_bindings(self):
        external_knowledge_bindings = (
            db.session.query(ExternalKnowledgeBindings)
            .filter(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
            .all()
        )
        dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
        datasets = db.session.query(Dataset).filter(Dataset.id.in_(dataset_ids)).all()
        return [{"id": dataset.id, "name": dataset.name} for dataset in datasets]


class ExternalKnowledgeBindings(db.Model):
    __tablename__ = "external_knowledge_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
        db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
        db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
        db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
        db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    external_knowledge_api_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    external_knowledge_id = db.Column(db.Text, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))