Spaces:
Paused
Paused
| import datetime | |
| import json | |
| from flask import request | |
| from flask_login import current_user | |
| from flask_restful import Resource, marshal_with, reqparse | |
| from werkzeug.exceptions import NotFound | |
| from controllers.console import api | |
| from controllers.console.wraps import account_initialization_required, setup_required | |
| from core.indexing_runner import IndexingRunner | |
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |
| from core.rag.extractor.notion_extractor import NotionExtractor | |
| from extensions.ext_database import db | |
| from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields | |
| from libs.login import login_required | |
| from models import DataSourceOauthBinding, Document | |
| from services.dataset_service import DatasetService, DocumentService | |
| from tasks.document_indexing_sync_task import document_indexing_sync_task | |
class DataSourceApi(Resource):
    """List, enable and disable the current tenant's OAuth data-source bindings."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(integrate_list_fields)
    def get(self):
        """Return the current tenant's data-source integrations.

        For each supported provider, emits one entry per enabled (``disabled == False``) OAuth
        binding. If the provider has no enabled binding, emits a single placeholder entry with
        ``is_bound=False``. Every entry carries the provider's OAuth link under this console's
        URL root.

        Returns:
            A ``({"data": [...]}, 200)`` tuple, serialized through ``integrate_list_fields``.
        """
        data_source_integrates = (
            db.session.query(DataSourceOauthBinding)
            .filter(
                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                # `== False` (not `is False`) so SQLAlchemy builds a SQL comparison.
                DataSourceOauthBinding.disabled == False,
            )
            .all()
        )

        base_url = request.url_root.rstrip("/")
        data_source_oauth_base_path = "/console/api/oauth/data-source"
        providers = ["notion"]

        integrate_data = []
        for provider in providers:
            link = f"{base_url}{data_source_oauth_base_path}/{provider}"
            # Materialize the matches: a bare `filter(...)` object is always truthy, which made
            # the "not bound" placeholder branch below unreachable.
            existing_integrates = [item for item in data_source_integrates if item.provider == provider]
            if existing_integrates:
                for existing_integrate in existing_integrates:
                    integrate_data.append(
                        {
                            "id": existing_integrate.id,
                            "provider": provider,
                            "created_at": existing_integrate.created_at,
                            "is_bound": True,
                            "disabled": existing_integrate.disabled,
                            "source_info": existing_integrate.source_info,
                            "link": link,
                        }
                    )
            else:
                integrate_data.append(
                    {
                        "id": None,
                        "provider": provider,
                        "created_at": None,
                        "source_info": None,
                        "is_bound": False,
                        "disabled": None,
                        "link": link,
                    }
                )
        return {"data": integrate_data}, 200

    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, binding_id, action):
        """Enable or disable one of the current tenant's data-source bindings.

        Args:
            binding_id: id of the ``DataSourceOauthBinding`` (a UUID from the URL).
            action: ``"enable"`` or ``"disable"``.

        Returns:
            ``({"result": "success"}, 200)``.

        Raises:
            NotFound: no binding with this id exists in the current tenant.
            ValueError: the binding is already in the requested state, or ``action`` is not
                ``"enable"``/``"disable"``.
        """
        binding_id = str(binding_id)
        action = str(action)
        # Scope the lookup to the caller's tenant so a binding id belonging to another
        # workspace cannot be toggled.
        data_source_binding = DataSourceOauthBinding.query.filter_by(
            id=binding_id, tenant_id=current_user.current_tenant_id
        ).first()
        if data_source_binding is None:
            raise NotFound("Data source binding not found.")

        if action == "enable":
            if not data_source_binding.disabled:
                raise ValueError("Data source is not disabled.")
            data_source_binding.disabled = False
        elif action == "disable":
            if data_source_binding.disabled:
                raise ValueError("Data source is disabled.")
            data_source_binding.disabled = True
        else:
            # Previously an unknown action fell through and reported success without doing anything.
            raise ValueError(f"Unsupported action: {action}.")

        # Stored as a naive UTC timestamp (tzinfo stripped), as before.
        data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        db.session.add(data_source_binding)
        db.session.commit()
        return {"result": "success"}, 200
class DataSourceNotionListApi(Resource):
    """List the Notion pages the current tenant has authorized, flagging already-imported ones."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(integrate_notion_info_list_fields)
    def get(self):
        """Return authorized Notion workspaces and pages for a pre-import picker.

        Query args:
            dataset_id: optional. When given, the dataset must be of type ``notion_import``.
                Pages already imported as enabled documents of that dataset are flagged
                ``is_bound=True``.

        Returns:
            ``({"notion_info": [...]}, 200)`` with one entry per enabled Notion binding of the
            current tenant.

        Raises:
            NotFound: ``dataset_id`` was given but no such dataset exists.
            ValueError: the dataset is not a Notion-import dataset.
        """
        dataset_id = request.args.get("dataset_id", default=None, type=str)
        # A set, because every page of every binding is membership-tested against it.
        exist_page_ids = set()
        # Importing Notion into an existing dataset: collect the pages it already holds.
        if dataset_id:
            dataset = DatasetService.get_dataset(dataset_id)
            if not dataset:
                raise NotFound("Dataset not found.")
            if dataset.data_source_type != "notion_import":
                raise ValueError("Dataset is not notion type.")
            documents = Document.query.filter_by(
                dataset_id=dataset_id,
                tenant_id=current_user.current_tenant_id,
                data_source_type="notion_import",
                enabled=True,
            ).all()
            for document in documents:
                data_source_info = json.loads(document.data_source_info)
                exist_page_ids.add(data_source_info["notion_page_id"])

        # All pages authorized through the tenant's enabled Notion bindings.
        data_source_bindings = DataSourceOauthBinding.query.filter_by(
            tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
        ).all()
        if not data_source_bindings:
            return {"notion_info": []}, 200

        pre_import_info_list = []
        for data_source_binding in data_source_bindings:
            source_info = data_source_binding.source_info
            # Build copies instead of writing "is_bound" into source_info in place. source_info
            # is the binding's loaded column value, and an in-place edit could be persisted on a
            # later commit if the column tracks mutations.
            pages = [{**page, "is_bound": page["page_id"] in exist_page_ids} for page in source_info["pages"]]
            pre_import_info_list.append(
                {
                    "workspace_name": source_info["workspace_name"],
                    "workspace_icon": source_info["workspace_icon"],
                    "workspace_id": source_info["workspace_id"],
                    "pages": pages,
                }
            )
        return {"notion_info": pre_import_info_list}, 200
class DataSourceNotionApi(Resource):
    """Preview a single Notion page (GET) and estimate indexing cost for Notion pages (POST)."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, workspace_id, page_id, page_type):
        """Return the extracted text of one Notion page or database.

        Args:
            workspace_id: Notion workspace id (UUID from the URL).
            page_id: Notion object id (UUID from the URL).
            page_type: object type, passed through to ``NotionExtractor`` as ``notion_page_type``.

        Returns:
            ``({"content": <text>}, 200)``, where the extracted documents' ``page_content`` values
            are joined with newlines.

        Raises:
            NotFound: the current tenant has no enabled Notion binding for ``workspace_id``.
        """
        workspace_id = str(workspace_id)
        page_id = str(page_id)
        data_source_binding = DataSourceOauthBinding.query.filter(
            db.and_(
                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                DataSourceOauthBinding.provider == "notion",
                DataSourceOauthBinding.disabled == False,
                # Compared against a double-quoted literal. Presumably source_info is a JSON
                # column whose subscript yields JSON text, so string values keep their quotes.
                # TODO confirm.
                DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
            )
        ).first()
        if not data_source_binding:
            raise NotFound("Data source binding not found.")

        extractor = NotionExtractor(
            notion_workspace_id=workspace_id,
            notion_obj_id=page_id,
            notion_page_type=page_type,
            notion_access_token=data_source_binding.access_token,
            tenant_id=current_user.current_tenant_id,
        )
        text_docs = extractor.extract()
        return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Estimate indexing for the Notion pages listed in the JSON body.

        JSON body:
            notion_info_list: list of ``{"workspace_id": ..., "pages": [{"page_id", "type", ...}]}``.
            process_rule: dict passed to ``IndexingRunner.indexing_estimate``.
            doc_form: document form, default ``"text_model"``.
            doc_language: document language, default ``"English"``.

        Returns:
            ``(<IndexingRunner.indexing_estimate result>, 200)``.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
        parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
        parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
        parser.add_argument(
            "doc_language", type=str, default="English", required=False, nullable=False, location="json"
        )
        args = parser.parse_args()
        # Service-level validation of the estimate arguments.
        DocumentService.estimate_args_validate(args)

        # One extract setting per selected page, tagged with its workspace and the caller's tenant.
        notion_info_list = args["notion_info_list"]
        extract_settings = []
        for notion_info in notion_info_list:
            workspace_id = notion_info["workspace_id"]
            for page in notion_info["pages"]:
                extract_setting = ExtractSetting(
                    datasource_type="notion_import",
                    notion_info={
                        "notion_workspace_id": workspace_id,
                        "notion_obj_id": page["page_id"],
                        "notion_page_type": page["type"],
                        "tenant_id": current_user.current_tenant_id,
                    },
                    document_model=args["doc_form"],
                )
                extract_settings.append(extract_setting)

        indexing_runner = IndexingRunner()
        response = indexing_runner.indexing_estimate(
            current_user.current_tenant_id,
            extract_settings,
            args["process_rule"],
            args["doc_form"],
            args["doc_language"],
        )
        return response, 200
class DataSourceNotionDatasetSyncApi(Resource):
    """Queue a Notion re-sync for every document in a dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Enqueue ``document_indexing_sync_task`` for each document of ``dataset_id``.

        Raises:
            NotFound: the dataset does not exist.
        """
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        documents = DocumentService.get_document_by_dataset_id(dataset_id_str)
        for document in documents:
            document_indexing_sync_task.delay(dataset_id_str, document.id)
        # Bare 200 body kept unchanged for existing clients.
        return 200
class DataSourceNotionDocumentSyncApi(Resource):
    """Queue a Notion re-sync for a single document."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id, document_id):
        """Enqueue ``document_indexing_sync_task`` for one document of a dataset.

        Raises:
            NotFound: the dataset or the document does not exist.
        """
        dataset_id_str = str(dataset_id)
        document_id_str = str(document_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        document = DocumentService.get_document(dataset_id_str, document_id_str)
        if document is None:
            raise NotFound("Document not found.")
        document_indexing_sync_task.delay(dataset_id_str, document_id_str)
        # Bare 200 body kept unchanged for existing clients.
        return 200
# Integration listing (GET) and enable/disable of a single binding (PATCH).
api.add_resource(
    DataSourceApi,
    "/data-source/integrates",
    "/data-source/integrates/<uuid:binding_id>/<string:action>",
)

# Notion pages available for import.
api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages")

# GET takes workspace/page/type arguments, so it pairs with the preview URL;
# POST takes none and pairs with the indexing-estimate URL.
api.add_resource(
    DataSourceNotionApi,
    "/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
    "/datasets/notion-indexing-estimate",
)

# Re-sync triggers: a whole dataset, or one document within it.
api.add_resource(
    DataSourceNotionDatasetSyncApi,
    "/datasets/<uuid:dataset_id>/notion/sync",
)
api.add_resource(
    DataSourceNotionDocumentSyncApi,
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync",
)