cecilia-uu
commited on
Commit
·
090b2e7
1
Parent(s):
2ef3748
create list_dataset api and tests (#1138)
Browse files### What problem does this PR solve?
This PR have completed both HTTP API and Python SDK for 'list_dataset".
In addition, there are tests for it.
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- api/apps/dataset_api.py +20 -6
- api/db/services/knowledgebase_service.py +23 -0
- sdk/python/ragflow/__init__.py +2 -0
- sdk/python/ragflow/dataset.py +1 -1
- sdk/python/ragflow/ragflow.py +36 -10
- sdk/python/test/common.py +1 -1
- sdk/python/test/test_dataset.py +87 -7
api/apps/dataset_api.py
CHANGED
|
@@ -46,7 +46,7 @@ from api.contants import NAME_LENGTH_LIMIT
|
|
| 46 |
|
| 47 |
# ------------------------------ create a dataset ---------------------------------------
|
| 48 |
@manager.route('/', methods=['POST'])
|
| 49 |
-
@login_required
|
| 50 |
@validate_request("name") # check name key
|
| 51 |
def create_dataset():
|
| 52 |
# Check if Authorization header is present
|
|
@@ -111,10 +111,27 @@ def create_dataset():
|
|
| 111 |
if not KnowledgebaseService.save(**request_body):
|
| 112 |
# failed to create new dataset
|
| 113 |
return construct_result()
|
| 114 |
-
return construct_json_result(data={"
|
| 115 |
except Exception as e:
|
| 116 |
return construct_error_response(e)
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
@manager.route('/<dataset_id>', methods=['DELETE'])
|
| 120 |
@login_required
|
|
@@ -135,8 +152,5 @@ def get_dataset(dataset_id):
|
|
| 135 |
return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to get detail of dataset: {dataset_id}")
|
| 136 |
|
| 137 |
|
| 138 |
-
|
| 139 |
-
@login_required
|
| 140 |
-
def list_datasets():
|
| 141 |
-
return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to list datasets")
|
| 142 |
|
|
|
|
| 46 |
|
| 47 |
# ------------------------------ create a dataset ---------------------------------------
|
| 48 |
@manager.route('/', methods=['POST'])
|
| 49 |
+
@login_required # use login
|
| 50 |
@validate_request("name") # check name key
|
| 51 |
def create_dataset():
|
| 52 |
# Check if Authorization header is present
|
|
|
|
| 111 |
if not KnowledgebaseService.save(**request_body):
|
| 112 |
# failed to create new dataset
|
| 113 |
return construct_result()
|
| 114 |
+
return construct_json_result(data={"dataset_name": request_body["name"]})
|
| 115 |
except Exception as e:
|
| 116 |
return construct_error_response(e)
|
| 117 |
|
| 118 |
+
# -----------------------------list datasets-------------------------------------------------------
|
| 119 |
+
@manager.route('/', methods=['GET'])
|
| 120 |
+
@login_required
|
| 121 |
+
def list_datasets():
|
| 122 |
+
offset = request.args.get("offset", 0)
|
| 123 |
+
count = request.args.get("count", -1)
|
| 124 |
+
orderby = request.args.get("orderby", "create_time")
|
| 125 |
+
desc = request.args.get("desc", True)
|
| 126 |
+
try:
|
| 127 |
+
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
| 128 |
+
kbs = KnowledgebaseService.get_by_tenant_ids(
|
| 129 |
+
[m["tenant_id"] for m in tenants], current_user.id, int(offset), int(count), orderby, desc)
|
| 130 |
+
return construct_json_result(data=kbs, code=RetCode.DATA_ERROR, message=f"attempt to list datasets")
|
| 131 |
+
except Exception as e:
|
| 132 |
+
return construct_error_response(e)
|
| 133 |
+
|
| 134 |
+
# ---------------------------------delete a dataset ----------------------------
|
| 135 |
|
| 136 |
@manager.route('/<dataset_id>', methods=['DELETE'])
|
| 137 |
@login_required
|
|
|
|
| 152 |
return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to get detail of dataset: {dataset_id}")
|
| 153 |
|
| 154 |
|
| 155 |
+
|
|
|
|
|
|
|
|
|
|
| 156 |
|
api/db/services/knowledgebase_service.py
CHANGED
|
@@ -40,6 +40,29 @@ class KnowledgebaseService(CommonService):
|
|
| 40 |
|
| 41 |
return list(kbs.dicts())
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
@classmethod
|
| 44 |
@DB.connection_context()
|
| 45 |
def get_detail(cls, kb_id):
|
|
|
|
| 40 |
|
| 41 |
return list(kbs.dicts())
|
| 42 |
|
| 43 |
+
@classmethod
|
| 44 |
+
@DB.connection_context()
|
| 45 |
+
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
| 46 |
+
offset, count, orderby, desc):
|
| 47 |
+
kbs = cls.model.select().where(
|
| 48 |
+
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
|
| 49 |
+
TenantPermission.TEAM.value)) | (
|
| 50 |
+
cls.model.tenant_id == user_id))
|
| 51 |
+
& (cls.model.status == StatusEnum.VALID.value)
|
| 52 |
+
)
|
| 53 |
+
if desc:
|
| 54 |
+
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
|
| 55 |
+
else:
|
| 56 |
+
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
|
| 57 |
+
|
| 58 |
+
kbs = list(kbs.dicts())
|
| 59 |
+
|
| 60 |
+
kbs_length = len(kbs)
|
| 61 |
+
if offset < 0 or offset > kbs_length:
|
| 62 |
+
raise IndexError("Offset is out of the valid range.")
|
| 63 |
+
|
| 64 |
+
return kbs[offset:offset+count]
|
| 65 |
+
|
| 66 |
@classmethod
|
| 67 |
@DB.connection_context()
|
| 68 |
def get_detail(cls, kb_id):
|
sdk/python/ragflow/__init__.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
| 1 |
import importlib.metadata
|
| 2 |
|
| 3 |
__version__ = importlib.metadata.version("ragflow")
|
|
|
|
|
|
|
|
|
| 1 |
import importlib.metadata
|
| 2 |
|
| 3 |
__version__ = importlib.metadata.version("ragflow")
|
| 4 |
+
|
| 5 |
+
from .ragflow import RAGFlow
|
sdk/python/ragflow/dataset.py
CHANGED
|
@@ -18,4 +18,4 @@ class DataSet:
|
|
| 18 |
self.user_key = user_key
|
| 19 |
self.dataset_url = dataset_url
|
| 20 |
self.uuid = uuid
|
| 21 |
-
self.name = name
|
|
|
|
| 18 |
self.user_key = user_key
|
| 19 |
self.dataset_url = dataset_url
|
| 20 |
self.uuid = uuid
|
| 21 |
+
self.name = name
|
sdk/python/ragflow/ragflow.py
CHANGED
|
@@ -17,7 +17,10 @@ import os
|
|
| 17 |
import requests
|
| 18 |
import json
|
| 19 |
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
| 21 |
def __init__(self, user_key, base_url, version = 'v1'):
|
| 22 |
'''
|
| 23 |
api_url: http://<host_address>/api/v1
|
|
@@ -36,16 +39,39 @@ class RAGFLow:
|
|
| 36 |
result_dict = json.loads(res.text)
|
| 37 |
return result_dict
|
| 38 |
|
| 39 |
-
def delete_dataset(self, dataset_name
|
| 40 |
return dataset_name
|
| 41 |
|
| 42 |
-
def list_dataset(self):
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
def get_dataset(self, dataset_id):
|
| 51 |
endpoint = f"{self.dataset_url}/{dataset_id}"
|
|
@@ -61,4 +87,4 @@ class RAGFLow:
|
|
| 61 |
if response.status_code == 200:
|
| 62 |
return True
|
| 63 |
else:
|
| 64 |
-
return False
|
|
|
|
| 17 |
import requests
|
| 18 |
import json
|
| 19 |
|
| 20 |
+
from httpx import HTTPError
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class RAGFlow:
|
| 24 |
def __init__(self, user_key, base_url, version = 'v1'):
|
| 25 |
'''
|
| 26 |
api_url: http://<host_address>/api/v1
|
|
|
|
| 39 |
result_dict = json.loads(res.text)
|
| 40 |
return result_dict
|
| 41 |
|
| 42 |
+
def delete_dataset(self, dataset_name=None, dataset_id=None):
|
| 43 |
return dataset_name
|
| 44 |
|
| 45 |
+
def list_dataset(self, offset=0, count=-1, orderby="create_time", desc=True):
|
| 46 |
+
params = {
|
| 47 |
+
"offset": offset,
|
| 48 |
+
"count": count,
|
| 49 |
+
"orderby": orderby,
|
| 50 |
+
"desc": desc
|
| 51 |
+
}
|
| 52 |
+
try:
|
| 53 |
+
response = requests.get(url=self.dataset_url, params=params, headers=self.authorization_header)
|
| 54 |
+
response.raise_for_status() # if it is not 200
|
| 55 |
+
original_data = response.json()
|
| 56 |
+
# TODO: format the data
|
| 57 |
+
# print(original_data)
|
| 58 |
+
# # Process the original data into the desired format
|
| 59 |
+
# formatted_data = {
|
| 60 |
+
# "datasets": [
|
| 61 |
+
# {
|
| 62 |
+
# "id": dataset["id"],
|
| 63 |
+
# "created": dataset["create_time"], # Adjust the key based on the actual response
|
| 64 |
+
# "fileCount": dataset["doc_num"], # Adjust the key based on the actual response
|
| 65 |
+
# "name": dataset["name"]
|
| 66 |
+
# }
|
| 67 |
+
# for dataset in original_data
|
| 68 |
+
# ]
|
| 69 |
+
# }
|
| 70 |
+
return response.status_code, original_data
|
| 71 |
+
except HTTPError as http_err:
|
| 72 |
+
print(f"HTTP error occurred: {http_err}")
|
| 73 |
+
except Exception as err:
|
| 74 |
+
print(f"An error occurred: {err}")
|
| 75 |
|
| 76 |
def get_dataset(self, dataset_id):
|
| 77 |
endpoint = f"{self.dataset_url}/{dataset_id}"
|
|
|
|
| 87 |
if response.status_code == 200:
|
| 88 |
return True
|
| 89 |
else:
|
| 90 |
+
return False
|
sdk/python/test/common.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
|
| 2 |
|
| 3 |
-
API_KEY = '
|
| 4 |
HOST_ADDRESS = 'http://127.0.0.1:9380'
|
|
|
|
| 1 |
|
| 2 |
|
| 3 |
+
API_KEY = 'ImFmNWQ3YTY0Mjg5NjExZWZhNTdjMzA0M2Q3ZWU1MzdlIg.ZmldwA.9oP9pVtuEQSpg-Z18A2eOkWO-3E'
|
| 4 |
HOST_ADDRESS = 'http://127.0.0.1:9380'
|
sdk/python/test/test_dataset.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
from test_sdkbase import TestSdk
|
| 2 |
-
import
|
| 3 |
-
from ragflow.ragflow import RAGFLow
|
| 4 |
import pytest
|
| 5 |
-
from unittest.mock import MagicMock
|
| 6 |
from common import API_KEY, HOST_ADDRESS
|
| 7 |
|
|
|
|
|
|
|
| 8 |
class TestDataset(TestSdk):
|
| 9 |
|
| 10 |
def test_create_dataset(self):
|
|
@@ -15,12 +15,92 @@ class TestDataset(TestSdk):
|
|
| 15 |
4. update the kb
|
| 16 |
5. delete the kb
|
| 17 |
'''
|
| 18 |
-
ragflow = RAGFLow(API_KEY, HOST_ADDRESS)
|
| 19 |
|
|
|
|
| 20 |
# create a kb
|
| 21 |
res = ragflow.create_dataset("kb1")
|
| 22 |
assert res['code'] == 0 and res['message'] == 'success'
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
-
# TODO: list the kb
|
|
|
|
| 1 |
from test_sdkbase import TestSdk
|
| 2 |
+
from ragflow import RAGFlow
|
|
|
|
| 3 |
import pytest
|
|
|
|
| 4 |
from common import API_KEY, HOST_ADDRESS
|
| 5 |
|
| 6 |
+
|
| 7 |
+
|
| 8 |
class TestDataset(TestSdk):
|
| 9 |
|
| 10 |
def test_create_dataset(self):
|
|
|
|
| 15 |
4. update the kb
|
| 16 |
5. delete the kb
|
| 17 |
'''
|
|
|
|
| 18 |
|
| 19 |
+
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
| 20 |
# create a kb
|
| 21 |
res = ragflow.create_dataset("kb1")
|
| 22 |
assert res['code'] == 0 and res['message'] == 'success'
|
| 23 |
+
dataset_name = res['data']['dataset_name']
|
| 24 |
+
|
| 25 |
+
def test_list_dataset_success(self):
|
| 26 |
+
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
| 27 |
+
# Call the list_datasets method
|
| 28 |
+
response = ragflow.list_dataset()
|
| 29 |
+
|
| 30 |
+
code, datasets = response
|
| 31 |
+
|
| 32 |
+
assert code == 200
|
| 33 |
+
|
| 34 |
+
def test_list_dataset_with_checking_size_and_name(self):
|
| 35 |
+
datasets_to_create = ["dataset1", "dataset2", "dataset3"]
|
| 36 |
+
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
| 37 |
+
created_response = [ragflow.create_dataset(name) for name in datasets_to_create]
|
| 38 |
+
|
| 39 |
+
real_name_to_create = set()
|
| 40 |
+
for response in created_response:
|
| 41 |
+
assert 'data' in response, "Response is missing 'data' key"
|
| 42 |
+
dataset_name = response['data']['dataset_name']
|
| 43 |
+
real_name_to_create.add(dataset_name)
|
| 44 |
+
|
| 45 |
+
status_code, listed_data = ragflow.list_dataset(0, 3)
|
| 46 |
+
listed_data = listed_data['data']
|
| 47 |
+
|
| 48 |
+
listed_names = {d['name'] for d in listed_data}
|
| 49 |
+
assert listed_names == real_name_to_create
|
| 50 |
+
assert status_code == 200
|
| 51 |
+
assert len(listed_data) == len(datasets_to_create)
|
| 52 |
+
|
| 53 |
+
def test_list_dataset_with_getting_empty_result(self):
|
| 54 |
+
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
| 55 |
+
datasets_to_create = []
|
| 56 |
+
created_response = [ragflow.create_dataset(name) for name in datasets_to_create]
|
| 57 |
+
|
| 58 |
+
real_name_to_create = set()
|
| 59 |
+
for response in created_response:
|
| 60 |
+
assert 'data' in response, "Response is missing 'data' key"
|
| 61 |
+
dataset_name = response['data']['dataset_name']
|
| 62 |
+
real_name_to_create.add(dataset_name)
|
| 63 |
+
|
| 64 |
+
status_code, listed_data = ragflow.list_dataset(0, 0)
|
| 65 |
+
listed_data = listed_data['data']
|
| 66 |
+
|
| 67 |
+
listed_names = {d['name'] for d in listed_data}
|
| 68 |
+
assert listed_names == real_name_to_create
|
| 69 |
+
assert status_code == 200
|
| 70 |
+
assert len(listed_data) == 0
|
| 71 |
+
|
| 72 |
+
def test_list_dataset_with_creating_100_knowledge_bases(self):
|
| 73 |
+
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
| 74 |
+
datasets_to_create = ["dataset1"] * 100
|
| 75 |
+
created_response = [ragflow.create_dataset(name) for name in datasets_to_create]
|
| 76 |
+
|
| 77 |
+
real_name_to_create = set()
|
| 78 |
+
for response in created_response:
|
| 79 |
+
assert 'data' in response, "Response is missing 'data' key"
|
| 80 |
+
dataset_name = response['data']['dataset_name']
|
| 81 |
+
real_name_to_create.add(dataset_name)
|
| 82 |
+
|
| 83 |
+
status_code, listed_data = ragflow.list_dataset(0, 100)
|
| 84 |
+
listed_data = listed_data['data']
|
| 85 |
+
|
| 86 |
+
listed_names = {d['name'] for d in listed_data}
|
| 87 |
+
assert listed_names == real_name_to_create
|
| 88 |
+
assert status_code == 200
|
| 89 |
+
assert len(listed_data) == 100
|
| 90 |
+
|
| 91 |
+
def test_list_dataset_with_showing_one_dataset(self):
|
| 92 |
+
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
| 93 |
+
response = ragflow.list_dataset(0, 1)
|
| 94 |
+
code, response = response
|
| 95 |
+
datasets = response['data']
|
| 96 |
+
assert len(datasets) == 1
|
| 97 |
+
|
| 98 |
+
def test_list_dataset_failure(self):
|
| 99 |
+
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
| 100 |
+
response = ragflow.list_dataset(-1, -1)
|
| 101 |
+
_, res = response
|
| 102 |
+
assert "IndexError" in res['message']
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
|
| 106 |
|
|
|