jibsn committed on
Commit
e581c57
·
verified ·
1 Parent(s): 8ed3aaa

Update utils/r2_utils.py

Browse files
Files changed (1) hide show
  1. utils/r2_utils.py +197 -197
utils/r2_utils.py CHANGED
@@ -1,197 +1,197 @@
1
- import boto3
2
- import io
3
- import json
4
- import asyncio
5
- import pandas as pd
6
-
7
- from docx import Document
8
- from loguru import logger
9
- from entities.task import Task, task_factory
10
-
11
-
12
- BUCKET_NAME = "ai-scientist"
13
-
14
- r2_endpoint = "https://468d92a3c903c841bc2de3b413e45072.r2.cloudflarestorage.com/ai-scientist"
15
-
16
- TOKEN = "KhGGD1ZJI_YTlLaZ0nSMfBJSLnOhgYN6cwq1De7G"
17
- R2_ACCESS_KEY_ID = "b9bc4becece838742ae1dc161be92de3"
18
- R2_SECRET_ACCESS_KEY = "f68eb82bd1c00528f26c6ac9b57d737fe0e4729ac7c429030fbc22a17dc8f105"
19
-
20
- def get_client():
21
- return boto3.client(
22
- "s3",
23
- endpoint_url=r2_endpoint,
24
- aws_access_key_id=R2_ACCESS_KEY_ID,
25
- aws_secret_access_key=R2_SECRET_ACCESS_KEY,
26
- region_name="auto" # R2 需要设置为 auto
27
- )
28
-
29
-
30
- async def get_task_from_minio(
31
- uuid: str,
32
- customer_name: str,
33
- client=None
34
- ) -> Task:
35
- if client is None:
36
- client = get_client()
37
-
38
- response = await asyncio.to_thread(
39
- lambda: client.list_objects_v2(
40
- Bucket=BUCKET_NAME,
41
- Prefix=f"{customer_name}/"
42
- )
43
- )
44
-
45
- objects = response.get("Contents", [])
46
- if not objects:
47
- raise FileNotFoundError(f"No task found for customer {customer_name}")
48
-
49
- object_names = [obj["Key"].split("/")[1] for obj in objects]
50
- if uuid not in object_names:
51
- raise FileNotFoundError(f"No task found for customer {customer_name} with uuid {uuid}")
52
-
53
- json_file = await get_file_from_minio(
54
- bucket_name=BUCKET_NAME,
55
- object_name=f"{customer_name}/{uuid}/task.json",
56
- client=client
57
- )
58
-
59
- json_data = json_file.decode("utf-8")
60
- json_data = json.loads(json_data)
61
- return task_factory[json_data["task_type"]].load_from_json(json_data)
62
-
63
-
64
- async def get_all_tasks_from_minio(
65
- customer_name: str,
66
- client=None
67
- ) -> list[Task]:
68
- if client is None:
69
- client = get_client()
70
-
71
- response = await asyncio.to_thread(
72
- lambda: client.list_objects_v2(
73
- Bucket=BUCKET_NAME,
74
- Prefix=f"{customer_name}/"
75
- )
76
- )
77
- objects = response.get("Contents", [])
78
- if not objects:
79
- return []
80
-
81
- task_ids = list(set([obj["Key"].split("/")[1] for obj in objects]))
82
- task_jsons = await asyncio.gather(
83
- *(get_task_from_minio(uuid=task_id, customer_name=customer_name, client=client) for task_id in task_ids)
84
- )
85
- return task_jsons
86
-
87
-
88
- async def upload_task_json_to_minio(task: Task, client=None) -> Task:
89
- if client is None:
90
- client = get_client()
91
-
92
- json_data = task.save_to_json()
93
- byte_data = io.BytesIO(json_data.encode("utf-8"))
94
-
95
- await asyncio.to_thread(
96
- lambda: client.put_object(
97
- Bucket=BUCKET_NAME,
98
- Key=f"{task.customer_name}/{task.uuid}/task.json",
99
- Body=byte_data,
100
- ContentType="application/json"
101
- )
102
- )
103
- return task
104
-
105
-
106
- async def upload_text_to_minio(
107
- bucket_name: str,
108
- object_name: str,
109
- file_content: str,
110
- client=None,
111
- ):
112
- if client is None:
113
- client = get_client()
114
-
115
- file_data = io.BytesIO(file_content.encode("utf-8"))
116
-
117
- await asyncio.to_thread(
118
- lambda: client.put_object(
119
- Bucket=bucket_name,
120
- Key=object_name,
121
- Body=file_data
122
- )
123
- )
124
-
125
-
126
- async def upload_dataframe_to_minio(
127
- bucket_name: str,
128
- object_name: str,
129
- df: pd.DataFrame,
130
- client=None,
131
- ):
132
- buffer = io.BytesIO()
133
- df.to_csv(buffer, index=False)
134
- await upload_text_to_minio(
135
- bucket_name=bucket_name,
136
- object_name=object_name,
137
- file_content=buffer.getvalue().decode("utf-8"),
138
- client=client
139
- )
140
-
141
-
142
- async def upload_document_to_minio(
143
- bucket_name: str,
144
- object_name: str,
145
- document: Document,
146
- client=None,
147
- ):
148
- if client is None:
149
- client = get_client()
150
-
151
- buffer = io.BytesIO()
152
- document.save(buffer)
153
- buffer.seek(0)
154
-
155
- await asyncio.to_thread(
156
- lambda: client.put_object(
157
- Bucket=bucket_name,
158
- Key=object_name,
159
- Body=buffer,
160
- ContentType="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
161
- )
162
- )
163
-
164
-
165
- async def get_file_from_minio(
166
- bucket_name: str,
167
- object_name: str,
168
- client=None,
169
- ):
170
- if client is None:
171
- client = get_client()
172
-
173
- try:
174
- response = await asyncio.to_thread(
175
- lambda: client.get_object(Bucket=bucket_name, Key=object_name)
176
- )
177
- return response["Body"].read()
178
- except Exception as e:
179
- raise Exception(f"Error getting file from minio: {e}")
180
-
181
-
182
- async def get_dataframe_from_minio(
183
- bucket_name: str,
184
- object_name: str,
185
- client=None,
186
- ):
187
- file_data = await get_file_from_minio(
188
- bucket_name=bucket_name,
189
- object_name=object_name,
190
- client=client
191
- )
192
-
193
- if object_name.endswith(".csv"):
194
- df = pd.read_csv(io.BytesIO(file_data))
195
- elif object_name.endswith(".xlsx") or object_name.endswith(".xls"):
196
- df = pd.read_excel(io.BytesIO(file_data))
197
- return df
 
1
import asyncio
import io
import json
import os

import boto3
import pandas as pd
from docx import Document
from loguru import logger

from entities.task import Task, task_factory
10
+
11
+
12
# Bucket that holds every customer's task data.
BUCKET_NAME = "ai-scientist"

# Account-level R2 endpoint; the bucket is selected per request, not in the URL.
r2_endpoint = "https://468d92a3c903c841bc2de3b413e45072.r2.cloudflarestorage.com/"

# Credentials are read from the environment so that no secret lives in source
# control. Each name is None when its variable is unset.
# NOTE(review): the key pair and token previously committed in this file
# should be treated as leaked and rotated.
TOKEN = os.environ.get("R2_TOKEN")
R2_ACCESS_KEY_ID = os.environ.get("R2_ACCESS_KEY_ID")
R2_SECRET_ACCESS_KEY = os.environ.get("R2_SECRET_ACCESS_KEY")
19
+
20
def get_client():
    """Create a boto3 S3 client pointed at the configured R2 endpoint.

    The endpoint and key pair come from the module-level ``r2_endpoint``,
    ``R2_ACCESS_KEY_ID`` and ``R2_SECRET_ACCESS_KEY``. A new client is built
    on every call.
    """
    connection_options = {
        "endpoint_url": r2_endpoint,
        "aws_access_key_id": R2_ACCESS_KEY_ID,
        "aws_secret_access_key": R2_SECRET_ACCESS_KEY,
        # R2 requires the region to be set to "auto".
        "region_name": "auto",
    }
    return boto3.client("s3", **connection_options)
28
+
29
+
30
async def get_task_from_minio(
    uuid: str,
    customer_name: str,
    client=None
) -> Task:
    """Load one task from ``<customer_name>/<uuid>/task.json`` in the bucket.

    Existence is checked with a one-key listing narrowed to the task's own
    prefix. Listing the whole customer prefix returns a single, size-capped
    page, so a task could be reported missing for a customer with many
    objects.

    Args:
        uuid: Identifier of the task (second path segment of the object key).
        customer_name: Customer the task belongs to (first path segment).
        client: Optional S3-compatible client; one is created when omitted.

    Returns:
        The task rebuilt by the ``task_factory`` entry selected by the stored
        ``task_type`` field.

    Raises:
        FileNotFoundError: If the customer has no objects at all, or has
            objects but none under the requested uuid.
        Exception: If ``task.json`` cannot be downloaded.
    """
    if client is None:
        client = get_client()

    customer_prefix = f"{customer_name}/"
    task_prefix = f"{customer_prefix}{uuid}/"

    task_listing = await asyncio.to_thread(
        client.list_objects_v2,
        Bucket=BUCKET_NAME,
        Prefix=task_prefix,
        MaxKeys=1,
    )
    if not task_listing.get("Contents"):
        # Only on the failure path: probe the customer prefix to decide which
        # of the two error messages applies.
        customer_listing = await asyncio.to_thread(
            client.list_objects_v2,
            Bucket=BUCKET_NAME,
            Prefix=customer_prefix,
            MaxKeys=1,
        )
        if not customer_listing.get("Contents"):
            raise FileNotFoundError(f"No task found for customer {customer_name}")
        raise FileNotFoundError(f"No task found for customer {customer_name} with uuid {uuid}")

    raw_json = await get_file_from_minio(
        bucket_name=BUCKET_NAME,
        object_name=f"{task_prefix}task.json",
        client=client
    )

    task_data = json.loads(raw_json.decode("utf-8"))
    return task_factory[task_data["task_type"]].load_from_json(task_data)
62
+
63
+
64
async def get_all_tasks_from_minio(
    customer_name: str,
    client=None
) -> list[Task]:
    """Load every task stored under ``<customer_name>/`` in the bucket.

    Task ids are discovered with a paginated listing grouped by the ``/``
    delimiter. This means results are not cut off at one response page, and
    only "directory" entries directly below the customer prefix count as
    tasks. Stray objects stored directly under the prefix are ignored.

    Args:
        customer_name: Customer whose tasks should be loaded.
        client: Optional S3-compatible client; one is created when omitted.

    Returns:
        The loaded tasks ordered by task id, or an empty list when the
        customer has none.

    Raises:
        Whatever ``get_task_from_minio`` raises for any individual task; one
        failure fails the whole call.
    """
    if client is None:
        client = get_client()

    customer_prefix = f"{customer_name}/"

    def _list_task_ids() -> list[str]:
        # Runs in a worker thread: the paginator performs blocking requests.
        paginator = client.get_paginator("list_objects_v2")
        found = set()
        for page in paginator.paginate(
            Bucket=BUCKET_NAME,
            Prefix=customer_prefix,
            Delimiter="/",
        ):
            for group in page.get("CommonPrefixes", []):
                # "<customer>/<task_id>/" -> "<task_id>"
                task_id = group["Prefix"][len(customer_prefix):].rstrip("/")
                if task_id:
                    found.add(task_id)
        # Sorted so the result order is deterministic.
        return sorted(found)

    task_ids = await asyncio.to_thread(_list_task_ids)
    if not task_ids:
        return []

    tasks = await asyncio.gather(
        *(
            get_task_from_minio(uuid=task_id, customer_name=customer_name, client=client)
            for task_id in task_ids
        )
    )
    return list(tasks)
86
+
87
+
88
async def upload_task_json_to_minio(task: Task, client=None) -> Task:
    """Serialize ``task`` and store it as ``<customer>/<uuid>/task.json``.

    Args:
        task: Task to persist; its ``save_to_json()`` output is uploaded.
        client: Optional S3-compatible client; one is created when omitted.

    Returns:
        The same ``task`` object that was passed in.
    """
    s3 = get_client() if client is None else client

    serialized = task.save_to_json()
    payload = io.BytesIO(serialized.encode("utf-8"))

    def _put():
        # Executed in a worker thread because put_object blocks.
        return s3.put_object(
            Bucket=BUCKET_NAME,
            Key=f"{task.customer_name}/{task.uuid}/task.json",
            Body=payload,
            ContentType="application/json",
        )

    await asyncio.to_thread(_put)
    return task
104
+
105
+
106
async def upload_text_to_minio(
    bucket_name: str,
    object_name: str,
    file_content: str,
    client=None,
):
    """Upload ``file_content`` as a UTF-8 encoded object.

    Args:
        bucket_name: Destination bucket.
        object_name: Destination object key.
        file_content: Text to store.
        client: Optional S3-compatible client; one is created when omitted.
    """
    s3 = get_client() if client is None else client

    encoded = io.BytesIO(file_content.encode("utf-8"))

    def _put():
        # Executed in a worker thread because put_object blocks.
        return s3.put_object(Bucket=bucket_name, Key=object_name, Body=encoded)

    await asyncio.to_thread(_put)
124
+
125
+
126
async def upload_dataframe_to_minio(
    bucket_name: str,
    object_name: str,
    df: pd.DataFrame,
    client=None,
):
    """Upload ``df`` as CSV text (without the index column).

    Args:
        bucket_name: Destination bucket.
        object_name: Destination object key.
        df: DataFrame to serialize.
        client: Optional S3-compatible client, forwarded to
            ``upload_text_to_minio``.
    """
    # With no target given, to_csv returns the CSV document as a string.
    csv_text = df.to_csv(index=False)
    await upload_text_to_minio(
        bucket_name=bucket_name,
        object_name=object_name,
        file_content=csv_text,
        client=client,
    )
140
+
141
+
142
async def upload_document_to_minio(
    bucket_name: str,
    object_name: str,
    document: Document,
    client=None,
):
    """Upload a Word document object as a ``.docx`` payload.

    Args:
        bucket_name: Destination bucket.
        object_name: Destination object key.
        document: Document whose ``save()`` output is uploaded.
        client: Optional S3-compatible client; one is created when omitted.
    """
    s3 = get_client() if client is None else client

    docx_stream = io.BytesIO()
    document.save(docx_stream)
    # Rewind so the upload reads from the start of what was just written.
    docx_stream.seek(0)

    def _put():
        # Executed in a worker thread because put_object blocks.
        return s3.put_object(
            Bucket=bucket_name,
            Key=object_name,
            Body=docx_stream,
            ContentType="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        )

    await asyncio.to_thread(_put)
163
+
164
+
165
async def get_file_from_minio(
    bucket_name: str,
    object_name: str,
    client=None,
) -> bytes:
    """Download an object and return its full content as bytes.

    Both the request and the read of the response body run in a worker
    thread, so the blocking download never stalls the event loop. The body
    stream is closed once it has been read.

    Args:
        bucket_name: Bucket to read from.
        object_name: Key of the object to download.
        client: Optional S3-compatible client; one is created when omitted.

    Returns:
        The raw bytes of the object.

    Raises:
        Exception: If the request or the read fails; the original error is
            attached as ``__cause__``.
    """
    if client is None:
        client = get_client()

    def _download() -> bytes:
        response = client.get_object(Bucket=bucket_name, Key=object_name)
        body = response["Body"]
        try:
            return body.read()
        finally:
            body.close()

    try:
        return await asyncio.to_thread(_download)
    except Exception as e:
        # Keep the original exception chained for debugging.
        raise Exception(f"Error getting file from minio: {e}") from e
180
+
181
+
182
async def get_dataframe_from_minio(
    bucket_name: str,
    object_name: str,
    client=None,
) -> pd.DataFrame:
    """Download a CSV or Excel object and parse it into a DataFrame.

    The parser is chosen from the object name's extension, compared
    case-insensitively. The extension is validated before anything is
    downloaded, so an unsupported name fails fast with a clear error.

    Args:
        bucket_name: Bucket to read from.
        object_name: Key of the object; must end in ``.csv``, ``.xlsx`` or
            ``.xls``.
        client: Optional S3-compatible client, forwarded to
            ``get_file_from_minio``.

    Returns:
        The parsed DataFrame.

    Raises:
        ValueError: If the extension is not one of the supported ones.
        Exception: If the download fails.
    """
    lowered_name = object_name.lower()
    if lowered_name.endswith(".csv"):
        parse = pd.read_csv
    elif lowered_name.endswith((".xlsx", ".xls")):
        parse = pd.read_excel
    else:
        raise ValueError(
            f"Unsupported file type for {object_name!r}: expected .csv, .xlsx or .xls"
        )

    file_data = await get_file_from_minio(
        bucket_name=bucket_name,
        object_name=object_name,
        client=client
    )
    return parse(io.BytesIO(file_data))