jibsn committed
Commit 163107e · verified · 1 Parent(s): d522381

Upload 20 files
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (155 Bytes)
utils/__pycache__/api_utils.cpython-311.pyc ADDED
Binary file (10.9 kB)
utils/__pycache__/arxiv_utils.cpython-311.pyc ADDED
Binary file (27.9 kB)
utils/__pycache__/common_utils.cpython-311.pyc ADDED
Binary file (2.2 kB)
utils/__pycache__/minio_utils.cpython-311.pyc ADDED
Binary file (10.5 kB)
utils/__pycache__/paper_plus_utils.cpython-311.pyc ADDED
Binary file (60.7 kB)
utils/__pycache__/paper_utils.cpython-311.pyc ADDED
Binary file (28.4 kB)
utils/__pycache__/pubmed_plus_utils.cpython-311.pyc ADDED
Binary file (22.3 kB)
utils/__pycache__/pubmed_utils.cpython-311.pyc ADDED
Binary file (42.7 kB)
utils/__pycache__/r2_utils.cpython-311.pyc ADDED
Binary file (9.77 kB)
utils/api_utils.py ADDED
@@ -0,0 +1,225 @@
+ import asyncio
+
+ from loguru import logger
+ from openai import AsyncOpenAI
+ from functools import partial
+ from typing import Callable
+
+ CLIENTS = {
+     "glm-4-plus": {
+         "api_key": "3a3c9f497e34a0514da974a4ccb886e.urkW20Nz3aklp3Mk",
+         "base_url": "https://open.bigmodel.cn/api/paas/v4",
+     },
+     "glm-4": {
+         "api_key": "3a3c9f497e34a0514da974a4ccb886e.urkW20Nz3aklp3Mk",
+         "base_url": "https://open.bigmodel.cn/api/paas/v4",
+     },
+     "glm-4-airx": {
+         "api_key": "3a3c9f497e34a0514da974a4ccb886e.urkW20Nz3aklp3Mk",
+         "base_url": "https://open.bigmodel.cn/api/paas/v4",
+     },
+     "glm-4-flash": {
+         "api_key": "4541d6f6421cb131eba8c8390d956237.V1W9TkfzupwCQmOU",
+         "base_url": "https://open.bigmodel.cn/api/paas/v4",
+     },
+     "gpt-4o-mini": {
+         "api_key": "sk-RqmH8qL4MUxDlvJtE6045a9931474629B11015Df08D3C915",
+         "base_url": "https://api.qqslyx.com/v1",
+     },
+     "deepseek-chat": {
+         "api_key": "sk-253d4686221e41618f239b064ada3d21",
+         "base_url": "https://api.deepseek.com/v1"
+     },
+     "deepseek/deepseek-chat-v3-0324:free": {
+         "api_key": "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",
+         "base_url": "https://openrouter.ai/api/v1"
+     }
+ }
+
+
+ def get_chat_func(model_names: list[str]) -> list[Callable]:
+     """
+     Get a list of chat functions for the specified model names.
+
+     Args:
+         model_names (list[str]): A list of model names.
+
+     Returns:
+         list[Callable]: A list of chat functions. Model names that are
+         not configured in CLIENTS are silently skipped.
+     """
+     chat_funcs = []
+     for model_name in model_names:
+         if model_name not in CLIENTS:
+             continue
+         chat_funcs.append(partial(chat_completion, model_name=model_name))
+     return chat_funcs
+
+
+ async def chat_completion(prompt: str, model_name: str):
+     """
+     Perform a chat completion using the specified model.
+
+     Args:
+         prompt (str): The prompt to send to the model.
+         model_name (str): The name of the model to use.
+
+     Returns:
+         ChatCompletion: The full completion object returned by the API
+         (callers read `completion.choices[0].message.content`).
+     """
+     assert model_name in CLIENTS, f"Model {model_name} not found"
+
+     api_key = CLIENTS[model_name]["api_key"]
+     base_url = CLIENTS[model_name]["base_url"]
+
+     # Use the async client so concurrent calls via asyncio.gather
+     # do not block the event loop.
+     client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+     completion = await client.chat.completions.create(
+         model=model_name,
+         messages=[
+             {"role": "user", "content": prompt}
+         ]
+     )
+     return completion
+
+
+ async def retry_operation(func, task, max_retries=5, delay=0.5, *args, **kwargs):
+     """
+     Retry an async operation with a linearly increasing delay between attempts.
+
+     Args:
+         func (Callable): The function to be retried.
+         task (Task): The task object to update the status.
+         max_retries (int, optional): The maximum number of retries. Defaults to 5.
+         delay (float, optional): The base delay between retries. Defaults to 0.5.
+         *args: Additional positional arguments to pass to the function.
+         **kwargs: Additional keyword arguments to pass to the function.
+
+     Returns:
+         tuple: (result, None) on success, or (None, joined exception
+         messages) after all retries fail.
+     """
+     retries = 0
+     exceptions = []
+     while retries < max_retries:
+         try:
+             return await func(*args, **kwargs), None
+         except Exception as e:
+             exceptions.append(f"retry {retries}: {e}")
+             retries += 1
+             logger.error(e)
+             await asyncio.sleep(delay * retries)
+             continue
+     return None, "\n".join(exceptions)
+
+
+ async def chat_completion_multiple_models(
+     prompt: str,
+     model_names: list[str] = [],
+     chat_funcs: list[Callable] = []
+ ):
+     """
+     Perform chat completions with multiple models concurrently.
+
+     Args:
+         prompt (str): The prompt to send to the models.
+         model_names (list[str], optional): A list of model names. Defaults to [].
+         chat_funcs (list[Callable], optional): A list of chat functions. Defaults to [].
+
+     Returns:
+         list: Completion objects, one per model, in input order.
+     """
+     if not chat_funcs:
+         chat_funcs = get_chat_func(model_names)
+     return await asyncio.gather(
+         *(chat_func(prompt=prompt)
+           for chat_func in chat_funcs)
+     )
+
+
+ async def func_wrap_multiple_models(
+     wrap_func: Callable,
+     model_names: list[str] = [],
+     chat_funcs: list[Callable] = [],
+     model_weights: list[float] = [],
+     *args,
+ ):
+     """
+     Run a wrapped function concurrently, once per model.
+
+     Args:
+         wrap_func (Callable): The function to be wrapped.
+         model_names (list[str], optional): A list of model names. Defaults to [].
+         chat_funcs (list[Callable], optional): A list of chat functions. Defaults to [].
+         model_weights (list[float], optional): A list of model weights. Defaults to [].
+         *args: Additional positional arguments to pass to the function.
+
+     Returns:
+         list: Results of wrap_func, one per model.
+     """
+     if not chat_funcs:
+         chat_funcs = get_chat_func(model_names)
+     if not model_weights:
+         model_weights = [1.0 for _ in range(len(chat_funcs))]
+     assert len(chat_funcs) == len(model_weights), \
+         "model_weights must be same length as chat_funcs"
+
+     return await asyncio.gather(
+         *(wrap_func(*args, chat_func=chat_func)
+           for chat_func in chat_funcs)
+     )
+
+
+ async def compare_chat_choices(
+     contents: list[str],
+     model_names: list[str] = [],
+     chat_funcs: list[Callable] = [],
+     model_weights: list[float] = []
+ ):
+     """
+     Cross-rank candidate contents: each model ranks every candidate
+     except its own, and the ranks are accumulated into weighted scores.
+     A lower total score means a better candidate.
+     """
+     if not chat_funcs:
+         chat_funcs = get_chat_func(model_names)
+     if not model_weights:
+         model_weights = [1.0 for _ in range(len(chat_funcs))]
+     assert len(chat_funcs) == len(model_weights), \
+         "model_weights must be same length as chat_funcs"
+
+     prompts = []
+     eval_chat_funcs = []
+     for i in range(len(contents)):
+         prompt = f"""
+ You are provided with {len(contents)-1} choices, and you are asked to rank them based on quality and relevance.
+ Rank 1 is the best.
+ Just output Index and the corresponding Rank in the format Index:Rank.
+ Just numbers, no text. For example: "0:1" is correct, "Index 0:1" and "0: 1" are wrong.
+ One line for each rank.
+ Just output like "Index:Rank\nIndex:Rank\nIndex:Rank\n"
+ No other output is allowed.
+
+ """
+         for j, content in enumerate(contents):
+             if i == j:  # skip self evaluation
+                 continue
+             prompt += f"""
+ Index {j}:
+ {content}
+ ----------
+
+ """
+         prompts.append(prompt)
+         eval_chat_funcs.append(chat_funcs[i])
+     compares = await asyncio.gather(
+         *(chat_func(prompt=prompt)
+           for prompt, chat_func in zip(prompts, eval_chat_funcs))
+     )
+
+     rank_scores = {i: 0 for i in range(len(contents))}
+     for i, comp in enumerate(compares):
+         for line in comp.choices[0].message.content.strip().split("\n"):
+             index, rank = line.split(":")
+             rank_scores[int(index)] += int(rank) * model_weights[i]
+     return rank_scores
+
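
A minimal usage sketch for the helpers above, assuming the CLIENTS credentials are valid and both listed models respond (the prompt text is illustrative):

import asyncio

from utils.api_utils import chat_completion_multiple_models, compare_chat_choices

async def main():
    models = ["glm-4-flash", "deepseek-chat"]
    # Fan one prompt out to several models concurrently.
    completions = await chat_completion_multiple_models(
        prompt="Summarize the transformer architecture in two sentences.",
        model_names=models,
    )
    contents = [c.choices[0].message.content for c in completions]
    # Cross-rank the candidates; a lower weighted rank sum is better.
    scores = await compare_chat_choices(contents=contents, model_names=models)
    print(contents[min(scores, key=scores.get)])

asyncio.run(main())
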
utils/arxiv_utils.py ADDED
@@ -0,0 +1,738 @@
+ import re
+ import asyncio
+ import aiohttp
+
+ from minio import Minio
+ from loguru import logger
+ from lxml import etree
+
+ from utils.api_utils import (
+     retry_operation,
+     get_chat_func,
+     compare_chat_choices
+ )
+ from utils.r2_utils import (
+     get_client,
+     get_file_from_minio,
+     get_dataframe_from_minio,
+     upload_text_to_minio,
+     upload_task_json_to_minio,
+ )
+ from utils.common_utils import escape_csv_field
+ from utils.paper_utils import (
+     process_papers,
+     generate_subheadings,
+     assign_subheadings_to_summaries,
+     create_paragraphs_by_subheading,
+     enhance_language_readability,
+     translate_to_chinese_before_references
+ )
+ from entities.task import ArxivTask
+
+
+ BUCKET_NAME = "ai-scientist"
+
+
+ # =================================
+ # Function Groups: Pipeline for Arxiv
+ #
+ # 1. pipeline
+ # 2. single model chat
+ # =================================
+
+ async def arxiv_pipeline(
+     task: ArxivTask,
+     client: Minio = None,
+     max_retries: int = 5,
+     delay: float = 0.5
+ ):
+     """
+     Run the full arXiv review pipeline for every model on the task.
+
+     Args:
+         task: ArxivTask, the task object
+         client: Minio, the MinIO client object
+         max_retries: int, the maximum number of retries
+         delay: float, the delay between retries
+
+     Returns:
+         None
+     """
+     if client is None:
+         client = get_client()
+
+     customer_name = task.customer_name
+     uuid = task.uuid
+     model_names = task.model_names
+
+     task.status_string["overall"] = "processing"
+
+     await asyncio.gather(
+         *(process_arxiv_single_chat(
+             task, model_name, client, max_retries, delay
+         ) for model_name in model_names)
+     )
+
+     # Compare between models only when at least 3 models are selected
+     # and the comparison has not already been done.
+     logger.info("Check Compare...")
+     if task.do_compare and len(task.model_names) >= 3 \
+             and task.status.get("compare", 0) == 0:
+         contents = await asyncio.gather(
+             *(get_file_from_minio(
+                 bucket_name=BUCKET_NAME,
+                 object_name=f"{customer_name}/{uuid}/{model_name}/review_paper.txt",
+             ) for model_name in model_names)
+         )
+         contents = [c.data.decode("utf-8") for c in contents]
+         task.status_string["overall"] = "Start Compare"
+
+         rank_scores = await compare_chat_choices(
+             contents=contents,
+             model_names=model_names
+         )
+         # Lower weighted rank sum means better; keep the best review.
+         best_content = contents[min(rank_scores, key=rank_scores.get)]
+         await upload_text_to_minio(
+             bucket_name=BUCKET_NAME,
+             object_name=f"{customer_name}/{uuid}/compared_review_paper.txt",
+             file_content=best_content
+         )
+
+     task.status_string["overall"] = "Finished"
+     await upload_task_json_to_minio(task, client)
+
+
+ async def process_arxiv_single_chat(
+     task: ArxivTask,
+     model_name: str,
+     client: Minio = None,
+     max_retries: int = 5,
+     delay: float = 0.5
+ ):
+     """
+     Process an arXiv task with a single model, resuming from the last
+     completed step recorded in task.status.
+
+     Args:
+         task: ArxivTask, the task object
+         model_name: str, the model name
+         client: Minio, the MinIO client object
+         max_retries: int, the maximum number of retries
+         delay: float, the delay between retries
+
+     Returns:
+         None
+     """
+     # get minio client
+     if client is None:
+         client = get_client()
+
+     # add status for <model_name>
+     if model_name not in task.status:
+         task.status[model_name] = 0
+
+     # set task status string
+     task.status_string["overall"] = "processing"
+
+     process_steps = {
+         0: process_arxiv_fetch_arxiv_data,
+         1: process_arxiv_process_papers,
+         2: process_arxiv_generate_subheadings,
+         3: process_arxiv_assign_subheadings_to_summaries,
+         4: process_arxiv_create_paragraphs_by_subheading,
+         5: process_arxiv_enhance_language_readability,
+         6: process_arxiv_translate
+     }
+
+     state_description = {
+         0: "Finished fetching data.",
+         1: "Finished paper processing.",
+         2: "Finished subheading generation.",
+         3: "Finished subheading assignment.",
+         4: "Finished paragraph generation.",
+         5: "Finished review language readability enhancement.",
+         6: "Finished review translation.",
+     }
+
+     # Execute phase: run the remaining steps in order, checkpointing
+     # progress to MinIO after each one so the task can resume.
+     current_state = task.status[model_name]
+     for state in range(current_state, len(process_steps)):
+         await process_steps[state](
+             task=task,
+             model_name=model_name,
+             save_name=model_name,
+             prev_name=model_name,
+             client=client,
+             max_retries=max_retries, delay=delay
+         )
+         task.status_string[model_name] = state_description[state]
+         task.status[model_name] = state + 1
+         await upload_task_json_to_minio(task, client)
+
+     task.status_string[model_name] = "Finished."
+     await upload_task_json_to_minio(task, client)
+
+
+ # =================================
+ # Function Groups: process_arxiv_*
+ #
+ # One wrapper per pipeline step, in execution order:
+ # _fetch_arxiv_data, _process_papers, _generate_subheadings,
+ # _assign_subheadings_to_summaries, _create_paragraphs_by_subheading,
+ # _enhance_language_readability, _translate
+ # =================================
+
+ async def process_arxiv_fetch_arxiv_data(
+     task: ArxivTask,
+     model_name: str,
+     save_name: str,
+     prev_name: str = None,
+     client: Minio = None,
+     max_retries: int = 5,
+     delay: float = 0.5
+ ):
+     """
+     Fetch arXiv data for the task and store it as a CSV in MinIO.
+
+     Args:
+         task: ArxivTask, the task object
+         model_name: str, the model name
+         save_name: str, the name used for saved outputs
+         prev_name: str, the name of the previous step's outputs
+         client: Minio, the MinIO client object
+         max_retries: int, the maximum number of retries
+         delay: float, the delay between retries
+
+     Returns:
+         None
+     """
+     if client is None:
+         client = get_client()
+
+     if prev_name is not None:
+         logger.warning("For the first step, prev_name is not used.")
+
+     query = task.query
+     customer_name = task.customer_name
+     uuid = task.uuid
+     start_date = task.start_date
+     end_date = task.end_date
+     total_page = task.size / 50  # arXiv search returns 50 results per page
+
+     results, exceptions = await retry_operation(
+         get_arxiv_df, task,
+         model_name=save_name,
+         start_date=start_date, end_date=end_date,
+         initial_query=query, total_page=total_page,
+         uuid=uuid, customer_name=customer_name,
+         max_retries=max_retries, delay=delay
+     )
+     if results is None:  # no valid result after max retries
+         task.status_string[model_name] = exceptions  # store exception strings in status
+         await upload_task_json_to_minio(task, client)
+         raise RuntimeError("Arxiv Paper Crawl Failed.")  # exit
+
+
+ async def process_arxiv_process_papers(
+     task: ArxivTask,
+     model_name: str,
+     save_name: str,
+     prev_name: str = None,
+     client: Minio = None,
+     max_retries: int = 5,
+     delay: float = 0.5
+ ):
+     """
+     Process the fetched arXiv papers (relevance filtering and summarization).
+
+     Args:
+         task: ArxivTask, the task object
+         model_name: str, the model name
+         save_name: str, the name used for saved outputs
+         prev_name: str, the name of the previous step's outputs
+         client: Minio, the MinIO client object
+         max_retries: int, the maximum number of retries
+         delay: float, the delay between retries
+
+     Returns:
+         None
+     """
+     if client is None:
+         client = get_client()
+
+     query = task.query
+     direction = task.query  # the research direction currently reuses the query string
+     customer_name = task.customer_name
+     uuid = task.uuid
+
+     chat_func = get_chat_func(model_names=[model_name])[0]
+
+     review_arxiv_df = await get_dataframe_from_minio(
+         bucket_name=BUCKET_NAME,
+         object_name=f"{customer_name}/{uuid}/{prev_name}/arxiv_results.csv",
+         client=client
+     )
+     results, exceptions = await retry_operation(
+         process_papers, task,
+         dataframe=review_arxiv_df,
+         topic=query, direction=direction,
+         uuid=uuid, customer_name=customer_name, model_name=save_name,
+         max_retries=max_retries, delay=delay,
+         chat_func=chat_func
+     )
+     if results is None:  # no valid result after max retries
+         task.status_string[model_name] = exceptions  # store exception strings in status
+         await upload_task_json_to_minio(task, client)
+         raise RuntimeError("Arxiv Paper Processing Failed.")  # exit
+
+
+ async def process_arxiv_generate_subheadings(
+     task: ArxivTask,
+     model_name: str,
+     save_name: str,
+     prev_name: str = None,
+     client: Minio = None,
+     max_retries: int = 5,
+     delay: float = 0.5
+ ):
+     """
+     Generate subheadings for the review.
+
+     Args:
+         task: ArxivTask, the task object
+         model_name: str, the model name
+         save_name: str, the name used for saved outputs
+         prev_name: str, the name of the previous step's outputs
+         client: Minio, the MinIO client object
+         max_retries: int, the maximum number of retries
+         delay: float, the delay between retries
+
+     Returns:
+         None
+     """
+     if client is None:
+         client = get_client()
+
+     customer_name = task.customer_name
+     uuid = task.uuid
+
+     chat_func = get_chat_func(model_names=[model_name])[0]
+
+     review_arxiv_df = await get_dataframe_from_minio(
+         bucket_name=BUCKET_NAME,
+         object_name=f"{customer_name}/{uuid}/{prev_name}/arxiv_results.csv",
+         client=client
+     )
+
+     results, exceptions = await retry_operation(
+         generate_subheadings, task,
+         dataframe=review_arxiv_df,
+         uuid=uuid, customer_name=customer_name, model_name=save_name,
+         max_retries=max_retries, delay=delay,
+         chat_func=chat_func
+     )
+     if results is None:  # no valid result after max retries
+         task.status_string[model_name] = exceptions  # store exception strings in status
+         await upload_task_json_to_minio(task, client)
+         raise RuntimeError("Arxiv Generate Subheadings Failed.")  # exit
+
+
+ async def process_arxiv_assign_subheadings_to_summaries(
+     task: ArxivTask,
+     model_name: str,
+     save_name: str,
+     prev_name: str = None,
+     client: Minio = None,
+     max_retries: int = 5,
+     delay: float = 0.5
+ ):
+     """
+     Assign subheadings to paper summaries.
+
+     Args:
+         task: ArxivTask, the task object
+         model_name: str, the model name
+         save_name: str, the name used for saved outputs
+         prev_name: str, the name of the previous step's outputs
+         client: Minio, the MinIO client object
+         max_retries: int, the maximum number of retries
+         delay: float, the delay between retries
+
+     Returns:
+         None
+     """
+     if client is None:
+         client = get_client()
+
+     customer_name = task.customer_name
+     uuid = task.uuid
+
+     chat_func = get_chat_func(model_names=[model_name])[0]
+
+     subheadings = await get_file_from_minio(
+         bucket_name=BUCKET_NAME,
+         object_name=f"{customer_name}/{uuid}/{prev_name}/generated_subheadings.txt",
+         client=client
+     )
+     subheadings = subheadings.data.decode("utf-8").split("\n")
+
+     review_arxiv_df = await get_dataframe_from_minio(
+         bucket_name=BUCKET_NAME,
+         object_name=f"{customer_name}/{uuid}/{prev_name}/arxiv_results.csv",
+         client=client
+     )
+
+     results, exceptions = await retry_operation(
+         assign_subheadings_to_summaries, task,
+         subheadings=subheadings,
+         relevant_papers_df=review_arxiv_df,
+         uuid=uuid, customer_name=customer_name, model_name=save_name,
+         max_retries=max_retries, delay=delay,
+         chat_func=chat_func
+     )
+     if results is None:  # no valid result after max retries
+         task.status_string[model_name] = exceptions  # store exception strings in status
+         await upload_task_json_to_minio(task, client)
+         raise RuntimeError("Arxiv Assign Subheadings Failed.")  # exit
+
+
+ async def process_arxiv_create_paragraphs_by_subheading(
+     task: ArxivTask,
+     model_name: str,
+     save_name: str,
+     prev_name: str = None,
+     client: Minio = None,
+     max_retries: int = 5,
+     delay: float = 0.5
+ ):
+     """
+     Create review paragraphs for each subheading.
+
+     Args:
+         task: ArxivTask, the task object
+         model_name: str, the model name
+         save_name: str, the name used for saved outputs
+         prev_name: str, the name of the previous step's outputs
+         client: Minio, the MinIO client object
+         max_retries: int, the maximum number of retries
+         delay: float, the delay between retries
+
+     Returns:
+         None
+     """
+     if client is None:
+         client = get_client()
+
+     query = task.query
+     customer_name = task.customer_name
+     uuid = task.uuid
+
+     chat_func = get_chat_func(model_names=[model_name])[0]
+
+     subheadings = await get_file_from_minio(
+         bucket_name=BUCKET_NAME,
+         object_name=f"{customer_name}/{uuid}/{prev_name}/generated_subheadings.txt",
+         client=client
+     )
+     subheadings = subheadings.data.decode("utf-8").split("\n")
+
+     review_arxiv_df = await get_dataframe_from_minio(
+         bucket_name=BUCKET_NAME,
+         object_name=f"{customer_name}/{uuid}/{prev_name}/arxiv_results.csv",
+         client=client
+     )
+
+     results, exceptions = await retry_operation(
+         create_paragraphs_by_subheading, task,
+         subheadings=subheadings, main_topic=query,
+         relevant_papers_df=review_arxiv_df,
+         uuid=uuid, customer_name=customer_name, model_name=save_name,
+         max_retries=max_retries, delay=delay,
+         chat_func=chat_func
+     )
+     if results is None:  # no valid result after max retries
+         task.status_string[model_name] = exceptions  # store exception strings in status
+         await upload_task_json_to_minio(task, client)
+         raise RuntimeError("Arxiv Create Paragraphs Failed.")  # exit
+
+
+ async def process_arxiv_enhance_language_readability(
+     task: ArxivTask,
+     model_name: str,
+     save_name: str,
+     prev_name: str = None,
+     client: Minio = None,
+     max_retries: int = 5,
+     delay: float = 0.5
+ ):
+     """
+     Enhance the language readability of the draft review.
+
+     Args:
+         task: ArxivTask, the task object
+         model_name: str, the model name
+         save_name: str, the name used for saved outputs
+         prev_name: str, the name of the previous step's outputs
+         client: Minio, the MinIO client object
+         max_retries: int, the maximum number of retries
+         delay: float, the delay between retries
+
+     Returns:
+         None
+     """
+     if client is None:
+         client = get_client()
+
+     customer_name = task.customer_name
+     uuid = task.uuid
+
+     chat_func = get_chat_func(model_names=[model_name])[0]
+
+     review_content = await get_file_from_minio(
+         bucket_name=BUCKET_NAME,
+         object_name=f"{customer_name}/{uuid}/{prev_name}/review_non_refined.txt",
+         client=client
+     )
+     review_content = review_content.data.decode("utf-8")
+
+     results, exceptions = await retry_operation(
+         enhance_language_readability, task,
+         content=review_content,
+         uuid=uuid, customer_name=customer_name, model_name=save_name,
+         max_retries=max_retries, delay=delay,
+         chat_func=chat_func
+     )
+     if results is None:  # no valid result after max retries
+         task.status_string[model_name] = exceptions  # store exception strings in status
+         await upload_task_json_to_minio(task, client)
+         raise RuntimeError("Arxiv Enhance Language Failed.")  # exit
+
+
+ async def process_arxiv_translate(
+     task: ArxivTask,
+     model_name: str,
+     save_name: str,
+     prev_name: str = None,
+     client: Minio = None,
+     max_retries: int = 5,
+     delay: float = 0.5
+ ):
+     """
+     Translate the review to Chinese (up to the references section).
+
+     Args:
+         task: ArxivTask, the task object
+         model_name: str, the model name
+         save_name: str, the name used for saved outputs
+         prev_name: str, the name of the previous step's outputs
+         client: Minio, the MinIO client object
+         max_retries: int, the maximum number of retries
+         delay: float, the delay between retries
+
+     Returns:
+         None
+     """
+     if client is None:
+         client = get_client()
+
+     customer_name = task.customer_name
+     uuid = task.uuid
+
+     chat_func = get_chat_func(model_names=[model_name])[0]
+
+     review_content = await get_file_from_minio(
+         bucket_name=BUCKET_NAME,
+         object_name=f"{customer_name}/{uuid}/{prev_name}/review_paper.txt",
+         client=client
+     )
+     review_content = review_content.data.decode("utf-8")
+
+     results, exceptions = await retry_operation(
+         translate_to_chinese_before_references, task,
+         text=review_content,
+         uuid=uuid, customer_name=customer_name, model_name=save_name,
+         max_retries=max_retries, delay=delay,
+         chat_func=chat_func
+     )
+     if results is None:  # no valid result after max retries
+         task.status_string[model_name] = exceptions  # store exception strings in status
+         await upload_task_json_to_minio(task, client)
+         raise RuntimeError("Arxiv Translate Failed.")  # exit
+
+
+ # =================================
+ # Function Groups: Arxiv Task
+ #
+ # functions specific to the arxiv task
+ # =================================
+
+ async def get_arxiv_df(
+     start_date, end_date,
+     initial_query, total_page,
+     uuid, customer_name, model_name
+ ):
+     cookie = \
+         'browser=117.174.233.206.1731117480203659; arxiv_labs={%22sameSite%22:%22strict%22%2C%22expires%22:365%2C%22last_tab%22:%22tabone%22}; arxiv-search-parameters="{\"order\": \"-announced_date_first\"\054 \"size\": \"50\"\054 \"abstracts\": \"show\"}"'
+     headers = {
+         "authority": "arxiv.org",
+         "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
+         "accept-language": "zh-CN,zh;q=0.9",
+         "cache-control": "no-cache",
+         "pragma": "no-cache",
+         "referer": "https://arxiv.org/search/physics?query=^%^28^%^28^%^27deep+learning^%^27^%^29+OR+^%^28^%^27machine+learning^%^27^%^29^%^29+AND+^%^28^%^27antibody^%^27^%^29+earch+v0.5.6+released+2020+^%^28^%^28^%^27deep+learning^%^27^%^29+OR+^%^28^%^27machine+learning^%^27^%^29^%^29+AND^%^28^%^27antibody^%^27^%^29&searchtype=all&abstracts=show&order=-announced_date_first&size=50",
+         "sec-ch-ua": "^\\^Chromium^^;v=^\\^104^^, ^\\^",
+         "sec-ch-ua-mobile": "?0",
+         "sec-ch-ua-platform": "^\\^Windows^^",
+         "sec-fetch-dest": "document",
+         "sec-fetch-mode": "navigate",
+         "sec-fetch-site": "same-origin",
+         "sec-fetch-user": "?1",
+         "upgrade-insecure-requests": "1",
+         "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.0.0 Safari/537.36",
+         'cookie': cookie
+     }
+
+     url = "https://arxiv.org/search/advanced"
+     csv_filename = f"{customer_name}/{uuid}/{model_name}/arxiv_results.csv"
+
+     texts = ""
+
+     fieldnames = ['JT', 'DCOM', 'PMID', 'TI',
+                   'FAU', 'FAU-frist', 'AB', 'Full_Text_Links']
+
+     texts += ",".join(fieldnames) + "\n"
+     res_count = 0
+     for page in range(0, int(total_page) + 1):
+         offset = page * 50
+         params = {
+             "advanced": "",
+             "terms-0-operator": "AND",
+             "terms-0-term": initial_query,
+             "terms-0-field": "all",
+             "classification-physics_archives": "all",
+             "classification-include_cross_list": "include",
+             "date-year": "",
+             "date-filter_by": "date_range",
+             "date-from_date": start_date,
+             "date-to_date": end_date,
+             "date-date_type": "submitted_date",
+             "abstracts": "show",
+             "size": "50",
+             'start': offset,
+             "order": "-announced_date_first"
+         }
+         async with aiohttp.ClientSession() as session:
+             async with session.get(url, headers=headers, params=params) as resp:
+                 if resp.status != 200:
+                     logger.error("Failed to retrieve data from arxiv")
+                     raise ConnectionError("Failed to retrieve data from arxiv")
+                 res = await resp.text()
+
+         if "produced no results" in res:
+             logger.warning("No results found")
+             break
+         else:
+             tree = etree.HTML(res.encode('utf-8'))
+             li_list = tree.xpath('//*[@class="breathe-horizontal"]/li')
+             if len(li_list) > 0:
+                 for aa in li_list:
+                     # Extract paper information
+                     tid = ''.join(
+                         aa.xpath(
+                             './/*[@class="list-title is-inline-block"]/a/text()'
+                         )
+                     ).strip()
+                     authors = ''.join(
+                         aa.xpath('.//*[@class="authors"]/a/text()')
+                     ).strip()
+                     first_authors = aa.xpath(
+                         './/*[@class="authors"]/a/text()'
+                     )[0] if len(aa.xpath(
+                         './/*[@class="authors"]/a/text()')
+                     ) > 0 else ''
+                     title = ''.join(
+                         aa.xpath(
+                             './/*[@class="title is-5 mathjax"]//text()')
+                     ).strip()
+                     abstract = ','.join(aa.xpath(
+                         './/*[@class="abstract-full has-text-grey-dark mathjax"]//text()')
+                     ).strip()
+                     pdate = aa.xpath(
+                         ".//p[@class='is-size-7']/text()"
+                     )[0].strip() if len(
+                         aa.xpath(".//p[@class='is-size-7']/text()")
+                     ) > 0 else ''
+                     pdate = re.sub(r'\s*;.*$', '', pdate)
+                     purl = ''.join(
+                         aa.xpath('.//*[@class="list-title is-inline-block"]/a/@href')).strip()
+                     subjects = await get_more_detail(purl)  # fetch additional details
+                     texts += ",".join([
+                         escape_csv_field(x) for x in [
+                             subjects, pdate, tid, title, authors,
+                             first_authors, abstract, purl
+                         ]
+                     ]) + "\n"
+                 res_count += len(li_list)
+             else:
+                 break
+     await upload_text_to_minio(
+         bucket_name=BUCKET_NAME,
+         object_name=csv_filename,
+         file_content=texts,
+     )
+
+     logger.info(f"Saved results to {csv_filename}; fetched {res_count} results in total")
+     return csv_filename
+
+
+ async def get_more_detail(url):
+     """
+     Fetch additional details for a paper, such as its subjects.
+
+     :param url: URL of the paper
+     :return: subjects string
+     """
+     headers = {
+         "authority": "arxiv.org",
+         "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
+         "accept-language": "zh-CN,zh;q=0.9",
+         "cache-control": "no-cache",
+         "pragma": "no-cache",
+         "referer": "https://arxiv.org/search/advanced",
+         "sec-ch-ua": "^\\^Chromium^^;v=^\\^104^^, ^\\^",
+         "sec-ch-ua-mobile": "?0",
+         "sec-ch-ua-platform": "^\\^Windows^^",
+         "sec-fetch-dest": "document",
+         "sec-fetch-mode": "navigate",
+         "sec-fetch-site": "same-origin",
+         "sec-fetch-user": "?1",
+         "upgrade-insecure-requests": "1",
+         "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.0.0 Safari/537.36"
+     }
+     cookies = {
+         "browser": "117.174.233.206.1731117480203659",
+         "arxiv-search-parameters": "^\\^^{^^^^order^\\^\\^\\^:"
+     }
+
+     # Request the paper detail page
+     async with aiohttp.ClientSession() as session:
+         async with session.get(url, headers=headers, cookies=cookies) as resp:
+             if resp.status != 200:
+                 logger.error("Failed to get detail from arxiv")
+                 raise ConnectionError("Failed to get detail from arxiv")
+             res = await resp.text()
+
+     tree = etree.HTML(res.encode("utf-8"))
+
+     # Extract subject information
+     subjects_list = tree.xpath('//*[@class="tablecell subjects"]//text()')
+     subjects = ''
+     if subjects_list:
+         subjects = ''.join([tags.strip() for tags in subjects_list if tags])
+
+     return subjects
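
A hedged driver sketch for the pipeline above. The ArxivTask constructor call is an assumption; the field names are taken from the attributes the pipeline reads (uuid, customer_name, query, start_date, end_date, size, model_names, do_compare), and the actual signature may differ:

import asyncio

from entities.task import ArxivTask
from utils.arxiv_utils import arxiv_pipeline

# Hypothetical task construction; values are placeholders.
task = ArxivTask(
    uuid="demo-uuid",
    customer_name="demo-customer",
    query="machine learning for antibody design",
    start_date="2024-01-01",
    end_date="2024-06-30",
    size=100,                      # 50 results per page -> 2 pages
    model_names=["glm-4-flash"],
    do_compare=False,
)

asyncio.run(arxiv_pipeline(task, max_retries=3, delay=1.0))
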
utils/common_utils.py ADDED
@@ -0,0 +1,41 @@
+ import time
+ import asyncio
+
+ from functools import wraps
+
+
+ def execution_time(func):
+     """Decorator that prints the wall-clock execution time of a sync or async function."""
+     @wraps(func)
+     def sync_wrapper(*args, **kwargs):
+         start_time = time.time()
+         result = func(*args, **kwargs)
+         end_time = time.time()
+         elapsed_time = end_time - start_time
+         print(f"Execution time of {func.__name__}: {elapsed_time:.4f} seconds")
+         return result
+
+     @wraps(func)
+     async def async_wrapper(*args, **kwargs):
+         start_time = time.time()
+         result = await func(*args, **kwargs)
+         end_time = time.time()
+         elapsed_time = end_time - start_time
+         print(f"Execution time of {func.__name__}: {elapsed_time:.4f} seconds")
+         return result
+
+     if asyncio.iscoroutinefunction(func):
+         return async_wrapper
+     else:
+         return sync_wrapper
+
+
+ def escape_csv_field(field):
+     """
+     Escapes fields to ensure proper CSV formatting.
+     - Wraps the field in double quotes if it contains a comma, double quote, or newline.
+     - Escapes double quotes inside the field by doubling them.
+     """
+     field_str = str(field)  # Convert the field to a string
+     if ',' in field_str or '"' in field_str or '\n' in field_str:
+         field_str = '"' + field_str.replace('"', '""') + '"'
+     return field_str
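
Two quick examples of these utilities (illustrative; the printed timing depends on the run):

import asyncio

from utils.common_utils import execution_time, escape_csv_field

@execution_time
async def fetch():
    await asyncio.sleep(0.1)
    return "done"

asyncio.run(fetch())  # prints e.g.: Execution time of fetch: 0.1002 seconds

# Fields containing commas, quotes, or newlines get quoted and escaped.
row = [escape_csv_field(x) for x in ["plain", "a,b", 'say "hi"']]
print(",".join(row))  # plain,"a,b","say ""hi"""
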
utils/minio_utils.py ADDED
@@ -0,0 +1,256 @@
+ import io
+ import json
+ import asyncio
+ import pandas as pd
+
+ from docx import Document
+ from loguru import logger
+ from minio import Minio
+ from entities.task import Task, task_factory
+
+
+ BUCKET_NAME = "ai-scientist"
+
+
+ def get_client():
+     return Minio(
+         endpoint="0.0.0.0:9000",
+         access_key="minioadmin",
+         secret_key="minioadmin",
+         secure=False
+     )
+
+
+ async def get_task_from_minio(
+     uuid: str,
+     customer_name: str,
+     client: Minio = None
+ ) -> Task:
+     """
+     Asynchronously retrieve a task from MinIO.
+
+     Args:
+         uuid (str): Task UUID.
+         customer_name (str): Customer name.
+         client (Minio, optional): MinIO client instance.
+
+     Returns:
+         Task: The task object.
+
+     Raises:
+         FileNotFoundError: If the task or customer data is not found.
+     """
+     if client is None:
+         client = get_client()
+
+     objects = await asyncio.to_thread(
+         lambda: list(client.list_objects(
+             bucket_name=BUCKET_NAME,
+             prefix=f"{customer_name}/"
+         ))
+     )
+
+     logger.info(objects)
+
+     # Check if customer exists
+     if not objects:
+         raise FileNotFoundError(f"No task found for customer {customer_name}")
+
+     # Check if task exists
+     object_names = [obj.object_name.split("/")[1] for obj in objects]
+     if uuid not in object_names:
+         raise FileNotFoundError(
+             f"No task found for customer {customer_name} with uuid {uuid}")
+
+     # If task found
+     json_file = await get_file_from_minio(
+         bucket_name=BUCKET_NAME,
+         object_name=f"{customer_name}/{uuid}/task.json",
+         client=client
+     )
+
+     json_data = json.loads(json_file.data.decode("utf-8"))
+     logger.debug(json_data)
+     return task_factory[json_data["task_type"]].load_from_json(json_data)
+
+
+ async def get_all_tasks_from_minio(
+     customer_name: str,
+     client: Minio = None
+ ) -> list[Task]:
+     """
+     Asynchronously retrieve all tasks for a customer from MinIO.
+
+     Args:
+         customer_name (str): Customer name.
+         client (Minio, optional): MinIO client instance.
+
+     Returns:
+         list[Task]: List of task objects.
+     """
+     if client is None:
+         client = get_client()
+
+     objects = await asyncio.to_thread(
+         lambda: list(client.list_objects(
+             bucket_name=BUCKET_NAME,
+             prefix=f"{customer_name}/"
+         ))
+     )
+
+     # Check if customer exists
+     if not objects:
+         return []
+
+     task_ids = [obj.object_name.split("/")[1] for obj in objects]
+     task_jsons = await asyncio.gather(
+         *(get_task_from_minio(
+             uuid=task_id, customer_name=customer_name
+         ) for task_id in task_ids)
+     )
+     return task_jsons
+
+
+ async def upload_task_json_to_minio(task: Task, client: Minio = None) -> Task:
+     """
+     Asynchronously upload a task's JSON representation to MinIO.
+
+     Args:
+         task (Task): The task object to upload.
+         client (Minio, optional): MinIO client instance.
+
+     Returns:
+         Task: The uploaded task object.
+     """
+     if client is None:
+         client = get_client()
+
+     json_data = task.save_to_json()
+     byte_data = io.BytesIO(json_data.encode("utf-8"))
+
+     await asyncio.to_thread(
+         lambda: client.put_object(
+             bucket_name=BUCKET_NAME,
+             object_name=f"{task.customer_name}/{task.uuid}/task.json",
+             data=byte_data,
+             length=len(byte_data.getvalue()),
+             content_type="application/json"
+         )
+     )
+     return task
+
+
+ async def upload_text_to_minio(
+     bucket_name: str,
+     object_name: str,
+     file_content: str,
+     client: Minio = None,
+ ):
+     if client is None:
+         client = get_client()
+
+     file_data = io.BytesIO(file_content.encode("utf-8"))
+
+     try:
+         await asyncio.to_thread(
+             client.put_object,
+             bucket_name=bucket_name,
+             object_name=object_name,
+             data=file_data,
+             length=len(file_data.getvalue()),
+         )
+     except Exception as e:
+         raise Exception(f"Error uploading file to MinIO: {e}") from e
+
+
+ async def upload_dataframe_to_minio(
+     bucket_name: str,
+     object_name: str,
+     df: pd.DataFrame,
+     client: Minio = None,
+ ):
+     if client is None:
+         client = get_client()
+
+     buffer = io.BytesIO()
+     df.to_csv(buffer, index=False)
+
+     await upload_text_to_minio(
+         bucket_name=bucket_name,
+         object_name=object_name,
+         file_content=buffer.getvalue().decode("utf-8")
+     )
+
+
+ async def upload_document_to_minio(
+     bucket_name: str,
+     object_name: str,
+     document: Document,
+     client: Minio = None,
+ ):
+     if client is None:
+         client = get_client()
+
+     buffer = io.BytesIO()
+     document.save(buffer)
+     buffer.seek(0)
+
+     await asyncio.to_thread(
+         lambda: client.put_object(
+             bucket_name=bucket_name,
+             object_name=object_name,
+             data=buffer,
+             length=buffer.getbuffer().nbytes,
+             content_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
+         )
+     )
+
+
+ async def get_file_from_minio(
+     bucket_name: str,
+     object_name: str,
+     client: Minio = None,
+ ):
+     if client is None:
+         client = get_client()
+
+     try:
+         file_data = await asyncio.to_thread(
+             client.get_object,
+             bucket_name=bucket_name,
+             object_name=object_name
+         )
+         return file_data
+     except Exception as e:
+         raise Exception(f"Error getting file from MinIO: {e}") from e
+
+
+ async def get_dataframe_from_minio(
+     bucket_name: str,
+     object_name: str,
+     client: Minio = None,
+ ):
+     if client is None:
+         client = get_client()
+
+     file_data = await get_file_from_minio(
+         bucket_name=bucket_name,
+         object_name=object_name,
+         client=client
+     )
+
+     if object_name.endswith(".csv"):
+         df = pd.read_csv(io.BytesIO(file_data.data))
+     elif object_name.endswith(".xlsx") or object_name.endswith(".xls"):
+         df = pd.read_excel(io.BytesIO(file_data.data))
+     else:
+         raise ValueError(f"Unsupported file extension for {object_name}")
+     return df
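
A round-trip sketch for the MinIO helpers (assumes a MinIO server reachable at 0.0.0.0:9000 with the credentials hard-coded in get_client() and an existing "ai-scientist" bucket; the object name is a placeholder):

import asyncio
import pandas as pd

from utils.minio_utils import upload_dataframe_to_minio, get_dataframe_from_minio

async def main():
    df = pd.DataFrame({"title": ["Paper A"], "score": [0.9]})
    # Serialize the frame to CSV and store it as an object.
    await upload_dataframe_to_minio(
        bucket_name="ai-scientist",
        object_name="demo/roundtrip.csv",
        df=df,
    )
    # Read it back; the .csv extension selects pd.read_csv.
    back = await get_dataframe_from_minio(
        bucket_name="ai-scientist",
        object_name="demo/roundtrip.csv",
    )
    print(back.head())

asyncio.run(main())
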
utils/paper_plus_utils.py ADDED
@@ -0,0 +1,1265 @@
1
+ import io
2
+ import os
3
+ import re
4
+ import math
5
+ import random
6
+ import asyncio
7
+ import textwrap
8
+ import pandas as pd
9
+
10
+ from docx import Document
11
+ from loguru import logger
12
+
13
+ from .minio_utils import (
14
+ upload_text_to_minio,
15
+ upload_dataframe_to_minio,
16
+ upload_document_to_minio,
17
+ get_file_from_minio
18
+ )
19
+ from .common_utils import escape_csv_field
20
+
21
+
22
+ BUCKET_NAME = "ai-scientist"
23
+
24
+
25
+ # Function to check relevance and obtain keywords as reason
26
+ async def is_relevant(title, abstract, topic, direction, chat_func):
27
+ """
28
+ Check if a paper is relevant to a topic and obtain keywords as reason.
29
+
30
+ Args:
31
+ title (str): Title of the paper.
32
+ abstract (str): Abstract of the paper.
33
+ topic (str): Topic to check relevance against.
34
+ direction (str): Direction to check relevance against.
35
+ chat_func (function): Function to call the chat model.
36
+
37
+ Returns:
38
+ bool: True if the paper is relevant, False otherwise.
39
+ str: Keywords that indicate relevance.
40
+
41
+ """
42
+ relevance_prompt = (
43
+ f"You are an academic expert specializing in the field of {topic}. Your task is to determine if the following paper is relevant to the research direction described as '{direction}'.\n\n"
44
+ "Please follow this reasoning process:\n"
45
+ "1. Carefully read the paper's title and abstract.\n"
46
+ "2. Identify the core research area, methodology, results, or focal points presented in the paper.\n"
47
+ "3. Compare these core elements to the given research direction. Consider whether the paper directly addresses, contributes to, or is closely aligned with the stated direction.\n"
48
+ "4. If the paper aligns conceptually, methodologically, or thematically with the direction, then it is relevant. If it is only tangential or unrelated, it is not relevant.\n"
49
+ "5. From the text, select the main keywords that strongly indicate relevance (if relevant). These keywords should be key concepts, terms, or phrases that link the paper’s content to the given research direction.\n"
50
+ "6. If not relevant, you can provide no keywords or give a brief note indicating no strong linkage.\n\n"
51
+ "You must provide the answer in the following exact format:\n"
52
+ "Relevance: True or False\n"
53
+ "Keywords: [Comma-separated keywords]\n\n"
54
+ f"Title: {title}\n"
55
+ f"Abstract: {abstract}\n"
56
+ )
57
+ response = await chat_func(relevance_prompt)
58
+ if response is None:
59
+ return False, "Relevance check unavailable due to server error."
60
+
61
+ try:
62
+ response_text = response.choices[0].message.content
63
+ relevance = "True" in response_text
64
+ keywords = response_text.split(
65
+ "Keywords:")[-1].strip() if "Keywords:" in response_text else ""
66
+ return relevance, keywords
67
+ except AttributeError:
68
+ logger.error("Error in chat_func response format:", response)
69
+ return False, "Relevance check failed"
70
+
71
+
72
+ # Modified summarize_abstract function with error handling for failed completion requests
73
+ async def summarize_abstract(title, abstract, first_author, chat_func):
74
+ """
75
+ Summarize the abstract of a research paper.
76
+
77
+ Args:
78
+ title (str): Title of the paper.
79
+ abstract (str): Abstract of the paper.
80
+ first_author (str): Name of the first author.
81
+ chat_func (function): Function to call the chat model.
82
+
83
+ Returns:
84
+ str: Summary of the abstract.
85
+
86
+ """
87
+ formatted_author = reformat_author_name(first_author)
88
+
89
+ # decision_prompt仍然维持原有逻辑,用于判断摘要类型
90
+ decision_prompt = (
91
+ f"Your task is to decide the type of summary needed based on the abstract.\n\n"
92
+ f"Instructions:\n"
93
+ f"- If the study primarily introduces, describes, or refines a method, technique, model, or computational approach, "
94
+ f"with its main contribution being methodological rather than a discovery about a phenomenon, then output:\n"
95
+ f"Output: full\n\n"
96
+ f"- If the study primarily reports a new discovery, finding, result, or empirical outcome about a certain phenomenon, "
97
+ f"biological entity, material property, or theoretical insight, then output:\n"
98
+ f"Output: concise\n\n"
99
+ f"Make your decision strictly based on the abstract content. Do not provide explanations or reasoning, "
100
+ f"only the exact output word as instructed.\n\n"
101
+ f"Title: {title}\nAbstract: {abstract}\n"
102
+ )
103
+
104
+ # full_summary_prompt不再要求使用第一作者信息,只需要两句话总结主要发现
105
+ full_summary_prompt = (
106
+ "In exactly two sentences, provide a high-level summary of the study’s key findings, "
107
+ "while maintaining concrete technical terms, methodologies, and specific entities. "
108
+ # "Do not use 'this study', 'the authors', or similar phrases as the subject; instead, use a proper noun or specific entity mentioned or implied in the abstract as the subject. "
109
+ "Use clear and advanced language without generalizing or replacing specific methods with vague terms.\n\n"
110
+ f"The summary should use clear, advanced language and mention the first author {formatted_author} followed by 'et al.':\n\n"
111
+ f"Title: {title}\nAbstract: {abstract}\n\n"
112
+ f"Summary by {formatted_author} et al.:"
113
+ )
114
+
115
+ # concise_summary_prompt不再要求使用第一作者信息,只需要一句话总结主要发现
116
+ concise_summary_prompt = (
117
+ "In two sentence, provide a precise statement of the study’s main finding without generalizing and without making the study itself the subject. "
118
+ "Do not use 'this study', 'the authors', or similar phrases as the subject; instead, use a proper noun or specific entity mentioned or implied in the abstract as the subject of the sentence. "
119
+ "Directly present the finding as the sentence’s focus, using advanced and specific language.\n\n"
120
+ f"Title: {title}\nAbstract: {abstract}\n\n"
121
+ )
122
+
123
+ response_decision = await chat_func(decision_prompt)
124
+ response_decision = response_decision.choices[0].message.content.strip().lower()
125
+
126
+ if response_decision and "full" in response_decision:
127
+ prompt_summary = full_summary_prompt
128
+ else:
129
+ prompt_summary = concise_summary_prompt
130
+
131
+ response = await chat_func(prompt_summary)
132
+
133
+ if response is None:
134
+ return "Summary unavailable due to server error."
135
+
136
+ try:
137
+ result = response.choices[0].message.content.strip()
138
+ result_words = result.split()
139
+ summary = " ".join(result_words)
140
+ return summary
141
+ except AttributeError:
142
+ logger.error("Error in chat_func response format:", response)
143
+ return "Summary unavailable"
144
+
145
+
146
+ # Function to reformat first author name
147
+ def reformat_author_name(author_name):
148
+ """
149
+ Reformat the first author name by removing commas.
150
+
151
+ Args:
152
+ author_name (str): Name of the first author.
153
+
154
+ Returns:
155
+ str: Reformatted name of the first author.
156
+
157
+ """
158
+ try:
159
+ return author_name.replace(",", "")
160
+ except AttributeError:
161
+ return "Unknown Author"
162
+
163
+
164
+ # Function to generate 3-5 hierarchical subheadings related to the main topic
165
+ async def generate_subheadings(
166
+ relevant_papers_df, main_topic,
167
+ uuid, customer_name, model_name,
168
+ chat_func
169
+ ):
170
+ """
171
+ Generate 3-5 hierarchical subheadings related to the main topic based on the summaries of relevant papers.
172
+
173
+ Args:
174
+ relevant_papers_df: DataFrame containing relevant papers.
175
+ main_topic: Main topic of the research.
176
+ chat_func: Function to send chat messages to the chatbot.
177
+
178
+ Returns:
179
+ List[str]: List of generated subheadings.
180
+
181
+ """
182
+ # Determine the number of subheadings based on the number of rows
183
+ num_papers = len(relevant_papers_df)
184
+ if num_papers < 10:
185
+ num_subheadings = 1
186
+ elif num_papers <= 20:
187
+ num_subheadings = 2
188
+ elif num_papers <= 40:
189
+ num_subheadings = 3
190
+ elif num_papers <= 60:
191
+ num_subheadings = 4
192
+ elif num_papers <= 100:
193
+ num_subheadings = 5
194
+ else:
195
+ num_subheadings = 6
196
+
197
+ # Generate the summaries for the prompt
198
+ summaries = " ".join(relevant_papers_df['Summary'].tolist())
199
+
200
+ # Create the improved prompt
201
+ prompt = (
202
+ f"Consider the following main topic: '{main_topic}'. You are given a set of summaries extracted from relevant research papers related to this topic. Your goal is to generate {num_subheadings} hierarchical subheadings that clearly reflect and logically organize the key concepts and themes found in these summaries.\n\n"
203
+ "Instructions:\n"
204
+ "1. Carefully read and analyze the provided summaries.\n"
205
+ "2. Identify broad thematic categories directly mentioned or strongly implied by the summaries. These should serve as the starting points for the subheadings.\n"
206
+ "3. Arrange the subheadings in a hierarchical manner: start with the most general or foundational aspects of the main topic, then move progressively towards more specific, nuanced, or advanced themes.\n"
207
+ "4. Ensure that each subheading is distinct and does not overlap in scope or content with the others. Every subheading should be directly supported by information present in the summaries.\n"
208
+ "5. Do not introduce concepts that are not reflected in the summaries. All subheadings must be grounded in the text provided.\n"
209
+ "6. The final output should be a simple list of subheadings, each preceded by a hyphen, without additional explanation or commentary.\n\n"
210
+ f"Summaries:\n{summaries}\n\n"
211
+ "Output format:\n- Subheading 1\n- Subheading 2\n- Subheading 3\n..."
212
+ )
213
+
214
+ response = await chat_func(prompt)
215
+ subheadings = response.choices[0].message.content.strip().splitlines()
216
+ subheadings = [subheading.replace(r"[-*']", '').strip() for subheading in subheadings]
217
+ subheadings = [subheading.replace(r"- ", '').strip() for subheading in subheadings]
218
+ subheadings = [re.sub(r"^[^\w]+|[^\w]+$", '', subheading).strip()
219
+ for subheading in subheadings]
220
+ subheadings = subheadings[:num_subheadings]
221
+ logger.info("Generated Subheadings:\n" + "\n".join(subheadings))
222
+
223
+ output_filename = f"{customer_name}/{uuid}/{model_name}/generated_subheadings.txt"
224
+ await upload_text_to_minio(
225
+ bucket_name=BUCKET_NAME,
226
+ object_name=output_filename,
227
+ file_content="\n".join(subheadings)
228
+ )
229
+ logger.info(f"Subheadings saved to {output_filename}")
230
+ return subheadings
231
+
232
+
233
+ # Function to assign summaries to subheadings with minimum allocation of references per subheading
234
+ async def assign_subheadings_to_summaries(
235
+ relevant_papers_df,
236
+ subheadings,
237
+ uuid, customer_name, model_name,
238
+ chat_func
239
+ ):
240
+ """
241
+ Assign summaries to subheadings with minimum allocation of references per subheading.
242
+
243
+ Args:
244
+ relevant_papers_df: DataFrame containing relevant papers.
245
+ subheadings: List of subheadings.
246
+ uuid: Unique identifier for the task.
247
+ customer_name: Name of the customer.
248
+ chat_func: Function to send chat messages to the chatbot.
249
+
250
+ Returns:
251
+ DataFrame with assigned subheadings.
252
+
253
+ """
254
+ total_papers = len(relevant_papers_df)
255
+ min_papers_per_subheading = math.ceil(total_papers / (len(subheadings) + 1))
256
+
257
+ assigned_subheadings = []
258
+ prompts = []
259
+ for summary in relevant_papers_df['Summary']:
260
+ prompt = (
261
+ # Make the instruction to the model explicit
262
+ f"Given the following subheadings and a research paper summary, identify the single most appropriate subheading for the provided summary. "
263
+ f"You must carefully analyze the semantic content, thematic focus, and logical structure within the summary. "
264
+ f"Ensure that the chosen subheading closely matches the core topic, key findings, research objectives, or main arguments of the paper summary. "
265
+ f"Do not select a subheading that only partially fits; the chosen subheading should represent a strong and direct thematic alignment with the summary's central ideas. "
266
+ f"Each subheading covers a distinct aspect or theme. Avoid overlaps by choosing the one that best captures the essence of the summary. "
267
+ f"If a subheading does not logically or semantically align with the main theme or content described in the summary, it should not be chosen.\n\n"
268
+
269
+ # Provide the list of subheadings
270
+ f"Subheadings:\n{subheadings}\n\n"
271
+
272
+ # Provide the paper summary
273
+ f"Summary:\n{summary}\n\n"
274
+
275
+ # Specify the required output format
276
+ "Output format:\nSubheading: [Chosen subheading]"
277
+ )
278
+ prompts.append(prompt)
279
+ responses = await asyncio.gather(
280
+ *(chat_func(prompt) for prompt in prompts)
281
+ )
282
+ for response in responses:
283
+ content = response.choices[0].message.content.strip()
+ assigned_subheading = content.split(": ", 1)[1] if ": " in content else content
284
+ assigned_subheadings.append(assigned_subheading)
285
+
286
+ relevant_papers_df['Assigned Subheading'] = assigned_subheadings
287
+
288
+ # Ensure minimum papers per subheading
289
+ for subheading in subheadings:
290
+ # Recompute counts each pass so reassignments made for earlier subheadings are reflected
291
+ counts = relevant_papers_df['Assigned Subheading'].value_counts().to_dict()
292
+ if counts.get(subheading, 0) < min_papers_per_subheading:
293
+ extra_summaries = relevant_papers_df[relevant_papers_df['Assigned Subheading'] != subheading].sample(
294
+ min_papers_per_subheading - counts.get(subheading, 0)
294
+ )
295
+ relevant_papers_df.loc[extra_summaries.index,
296
+ 'Assigned Subheading'] = subheading
297
+
298
+ relevant_papers_df['Assigned Subheading'] = (
299
+ relevant_papers_df['Assigned Subheading']
300
+ .str.replace(r"^[^\w]+|[^\w]+$", '', regex=True)  # strip leading and trailing non-word characters
301
+ .str.strip()  # strip surrounding whitespace
302
+ )
303
+
304
+ prefix = f"{customer_name}/{uuid}/{model_name}/"
305
+ output_dir = prefix
306
+
307
+ csv_filename = os.path.join(output_dir, "assigned_subheadings.csv")
308
+
309
+ # relevant_papers_df.to_csv(csv_filename, index=False, encoding='utf-8')
310
+ await upload_dataframe_to_minio(
311
+ bucket_name=BUCKET_NAME,
312
+ object_name=csv_filename,
313
+ df=relevant_papers_df,
314
+ )
315
+
316
+ logger.info(f"Assigned subheadings saved to {csv_filename}")
317
+ logger.info(f"Found {len(relevant_papers_df)} related papers")
318
+
319
+ return relevant_papers_df
320
+
321
+
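+ # --- Editor's sketch (not part of the original module) ---------------------
+ # The rebalancing loop above tops up under-filled subheadings by sampling rows
+ # from other groups. The same idea in isolation, on a hypothetical DataFrame
+ # with an 'Assigned Subheading' column:
+ def _sketch_rebalance(df, subheadings):
+     import math
+     min_per = math.ceil(len(df) / (len(subheadings) + 1))
+     for sub in subheadings:
+         # Recompute counts each pass so earlier top-ups are visible.
+         counts = df['Assigned Subheading'].value_counts().to_dict()
+         deficit = min_per - counts.get(sub, 0)
+         if deficit > 0:
+             donors = df[df['Assigned Subheading'] != sub].sample(deficit)
+             df.loc[donors.index, 'Assigned Subheading'] = sub
+     return df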
322
+ async def get_sorting_suggestions(subheading, sub_df, chat_func):
323
+ # Add original index column to sub_df to retain original paper number
324
+ sub_df = sub_df.copy() # Avoid SettingWithCopyWarning
325
+ sub_df.reset_index(drop=True, inplace=True)
326
+ sub_df.index = sub_df.index + 1
327
+ sub_df['Original Index'] = sub_df.index
328
+
329
+ paper_num = sub_df.shape[0]
330
+ logger.info(f"Papers under subheading '{subheading}': {paper_num}")
331
+
332
+ if paper_num > 1:
333
+ # Combine summaries into one string, appending author information
334
+ summaries_text = '\n'.join(
335
+ [f"Paper {row['Original Index']} by {row['First Author']}:\nSummary: {row['Summary']}\nRelevance Keywords: {row['Relevance Keywords']}"
336
+ for _, row in sub_df.iterrows()]
337
+ )
338
+ logger.info(summaries_text)
339
+
340
+ prompt = (
341
+ f"You are an experienced scientist tasked with organizing a collection of {paper_num} papers under the subheading '{subheading}' for a scientific review article.\n\n"
342
+
343
+ "You have the following input:\n"
344
+ "1. A set of papers, each with a summary and relevance keywords.\n"
345
+ "2. A need to arrange these papers in a coherent and logical order that supports a narrative flow in a review article.\n\n"
346
+
347
+ "Please address the following tasks:\n\n"
348
+ "1. **Identify Key Themes and Group Papers:**\n"
349
+ "- First, thoroughly read the summaries and relevance keywords of all the provided papers.\n"
350
+ "- Determine distinct thematic groups or categories. A thematic group can be based on shared methodology, a common theoretical framework, a particular type of material, organism, phenomenon, or a progressive line of inquiry.\n"
351
+ "- The grouping should reflect logical subdivisions that a reader of a review article could follow. For instance:\n"
352
+ " - Start with foundational or broadly relevant studies that introduce key concepts, contexts, or basic methods.\n"
353
+ " - Follow with papers that build upon these foundations, introducing more advanced techniques, deeper investigations, specialized findings, or novel approaches.\n"
354
+ " - Conclude with cutting-edge, most specialized, or recently introduced concepts that push the boundaries of the field.\n"
355
+ "- If certain papers align well as a stepping stone from one theme to another, position them accordingly to create a smooth thematic transition.\n\n"
356
+
357
+ "2. **Determine the Logical Order Within Each Group:**\n"
358
+ "- Within each thematic group, arrange the papers in an order that naturally builds understanding. Consider:\n"
359
+ " - Present foundational or earlier conceptual frameworks before more advanced or derivative studies.\n"
360
+ " - Highlight any chronological clues (if provided) or logical sequences, such as a method introduced in one paper being applied or expanded in a later paper.\n"
361
+ " - Move from general to specific, from simpler methodologies to more complex analyses, or from well-established concepts to more tentative or innovative ones.\n\n"
362
+
363
+ "3. **Combine Groups into a Cohesive Narrative:**\n"
364
+ "- After organizing papers within their groups, merge the groups into a single final list.\n"
365
+ "- The final list should read like a storyline: start with a broad, conceptual or methodological foundation, then move through intermediate studies that expand and refine these ideas, and end with the most advanced, specialized, or novel findings.\n"
366
+ "- Ensure that transitions between groups make sense, helping a reader follow a narrative where each section logically paves the way for the next.\n\n"
367
+
368
+ "4. **Provide the Final Ordered List:**\n"
369
+ f"- Present the final ordered list as a numbered list from 1 to {paper_num}.\n"
370
+ "- Each entry should include the original paper number and the first author's name in the following format:\n"
371
+ " <Final Position>. <Original Paper Number>. (<First Author's Last Name>)\n\n"
372
+ "For example:\n"
373
+ "1. 3. (Smith)\n"
374
+ "2. 1. (Johnson)\n"
375
+ "3. 5. (Williams)\n\n"
376
+ "All papers must appear once, and each final position should be unique. Do not omit any papers.\n\n"
377
+
378
+ "Below are the papers:\n\n"
379
+ f"{summaries_text}\n\n"
380
+
381
+ "Please reflect on the thematic connections and carefully arrange the papers according to the instructions above."
382
+ )
383
+
384
+ # Parse the suggested order; count mismatches are repaired below rather than retried
385
+ sorting_order = []
386
+ sorting_response = await chat_func(prompt) # Replace with your chat model interface
387
+ sorting_suggestion = sorting_response.choices[0].message.content.strip()
388
+ logger.info(f'Sorting suggestion: {sorting_suggestion}')
389
+ matches = re.findall(r'(\d+)\.\s*(\d+)\.\s*\((.*?)\)', sorting_suggestion)
390
+
391
+ # Debugging: print out raw matches to verify correctness
392
+ logger.info(f"Matches found: {matches}")
393
+
394
+ for match in matches:
395
+ new_num = int(match[0])  # Final position in the suggested order
396
+ original_num = int(match[1])  # Original paper number, per the requested output format
397
+ author = match[2].strip()  # Author name
398
+ sorting_order.append((original_num, new_num, author))
399
+ else:
400
+ author = sub_df["First Author"].values[0]
401
+ sorting_order.append((1, 1, author))
402
+
403
+ # Ensure no duplicate new numbers and correct count
404
+ new_nums = [x[1] for x in sorting_order]
405
+ if len(sorting_order) == paper_num and len(set(new_nums)) == paper_num:
406
+ pass  # Sorting succeeded; keep the parsed order
407
+ elif abs(len(sorting_order) - paper_num) <= 2:
408
+ logger.info(f"Warning: Sorting order mismatch, difference of {abs(len(sorting_order) - paper_num)}. Assigning missing positions.")
409
+ existing_originals = [x[0] for x in sorting_order]
410
+ missing_positions = set(range(1, paper_num + 1)) - set(new_nums)
411
+
412
+ for original_num in range(1, paper_num + 1):
413
+ if original_num not in existing_originals:
414
+ random_new_num = random.choice(list(missing_positions))
415
+ sorting_order.append((original_num, random_new_num, "Unknown Author"))  # Placeholder author
416
+ missing_positions.remove(random_new_num)
417
+
418
+ # Sort by recommended number
419
+ sorting_order.sort(key=lambda x: x[1]) # Sort by new number
420
+
421
+ # Extract sorted original indices
422
+ final_sorted_order = [item[0] for item in sorting_order]
423
+
424
+ logger.info(f"Final sorted order: {final_sorted_order}")
425
+
426
+ # Reorder sub_df based on the sorted order
427
+ try:
428
+ sorted_indices = [sub_df[sub_df['Original Index'] == idx].index[0] for idx in final_sorted_order]
429
+ sorted_sub_df = sub_df.loc[sorted_indices].reset_index(drop=True)
430
+ except Exception as e:
431
+ logger.error(f"Error in sorting DataFrame: {e}")
432
+ raise ValueError("Reordering of DataFrame failed.")
433
+
434
+ return sorted_sub_df
435
+
436
+
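+ # --- Editor's sketch (not part of the original module) ---------------------
+ # The regex above parses lines of the form "<final>. <original>. (<Author>)".
+ # A standalone helper that inverts such a suggestion back into an ordering:
+ def _sketch_parse_order(suggestion: str) -> list[int]:
+     import re
+     # Each tuple is (final_position, original_number, author).
+     triples = re.findall(r'(\d+)\.\s*(\d+)\.\s*\((.*?)\)', suggestion)
+     ordered = sorted(((int(orig), int(pos)) for pos, orig, _ in triples),
+                      key=lambda pair: pair[1])  # sort by final position
+     return [orig for orig, _ in ordered]
+ # _sketch_parse_order("1. 3. (Smith)\n2. 1. (Johnson)\n3. 5. (Williams)") -> [3, 1, 5]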
437
+ # Function to create expanded paragraphs with required reference count and consistent reference indexing
438
+ async def create_paragraphs_by_subheading(
439
+ relevant_papers_df, subheadings, main_topic,
440
+ uuid, customer_name, model_name,
441
+ chat_func
442
+ ):
443
+ """
444
+ Create expanded paragraphs by subheading with required reference count and consistent reference indexing.
445
+
446
+ Args:
447
+ relevant_papers_df (pd.DataFrame): DataFrame containing relevant papers and their summaries.
448
+ subheadings (list): List of subheadings for the review paper.
449
+ main_topic (str): Main topic of the review paper.
450
+ uuid (str): UUID of the task.
451
+ customer_name (str): Name of the customer.
452
+ chat_func (function): Function to send chat messages to the chatbot.
453
+
454
+ Returns:
455
+ list: List of paragraphs with subheadings and consistent reference indexing.
456
+
457
+ """
458
+ paragraphs = []
459
+
460
+ # Reorder relevant_papers_df based on the subheadings order
461
+ subheading_order = {subheading: idx for idx, subheading in enumerate(subheadings)}
462
+ relevant_papers_df['Subheading Order'] = \
463
+ relevant_papers_df['Assigned Subheading'].map(subheading_order)
464
+
465
+ # Remove rows where 'Subheading Order' is NA
466
+ relevant_papers_df = relevant_papers_df.dropna(subset=['Subheading Order'])
467
+
468
+ relevant_papers_df = relevant_papers_df.sort_values(by='Subheading Order')
469
+
470
+ relevant_papers_df.reset_index(drop=True, inplace=True)
471
+ await upload_dataframe_to_minio(
472
+ bucket_name=BUCKET_NAME,
473
+ object_name=f"{customer_name}/{uuid}/{model_name}/relevant_papers_sort.csv",
474
+ df=relevant_papers_df,
475
+ )
476
+
477
+ # Split relevant_papers_df by 'Assigned Subheading' into separate sub-dataframes
478
+ subheading_groups = relevant_papers_df.groupby('Assigned Subheading')
479
+
480
+ sub_dfs = []
481
+ sorted_sub_dataframes = []
482
+ for subheading in subheadings:
483
+ # Check if subheading exists in subheading_groups
484
+ if subheading in subheading_groups.groups:
485
+ sub_df = subheading_groups.get_group(subheading)
486
+ sub_dfs.append(sub_df)
487
+
488
+ sorted_sub_dataframes = await asyncio.gather(
489
+ *(get_sorting_suggestions(subheading, sub_df, chat_func)
490
+ for sub_df in sub_dfs)
491
+ )
492
+
493
+ sorted_sub_dataframes = [x for x in sorted_sub_dataframes if not x.empty]
494
+
495
+ # Concatenate all sorted sub-dataframes and reset index
496
+ if sorted_sub_dataframes:
497
+ final_relevant_papers_df = pd.concat(sorted_sub_dataframes).reset_index(drop=True)
498
+ final_relevant_papers_df.index = final_relevant_papers_df.index + 1 # Start from index 1
499
+ final_relevant_papers_df['ref_index'] = final_relevant_papers_df.index # Add ref_index column
500
+ else:
501
+ logger.error("Error: No valid sub-dataframes to concatenate.")
502
+ final_relevant_papers_df = pd.DataFrame() # Create an empty DataFrame in case of error
503
+
504
+ final_relevant_papers_df = final_relevant_papers_df.drop_duplicates()
505
+ logger.info(final_relevant_papers_df.head())
506
+
507
+ # Introduction
508
+ intro_prompt = (
509
+ f"Write a concise and advanced introductory paragraph for a scientific review paper on '{main_topic}'. "
510
+ "Introduce the topic, its importance, and the scope of the review. The introduction should provide a logical "
511
+ "setup for the following subheadings.\n\n"
512
+ "Output format:\n[Write introduction here]"
513
+ )
514
+ intro_response = await chat_func(intro_prompt)
515
+ intro_paragraph = intro_response.choices[0].message.content.strip()
516
+ paragraphs.append(f"**Introduction**\n{intro_paragraph}\n")
517
+
518
+ used_titles = set()
519
+ summaries_text_by_subheading = {subheading: [] for subheading in subheadings}
520
+ ref_index_map = {}
521
+
522
+ for subheading in subheadings:
523
+ relevant_summaries = final_relevant_papers_df[
524
+ final_relevant_papers_df['Assigned Subheading'] == subheading
525
+ ]
526
+
527
+ for idx, (summary, title, author, pub_date, ref_index) in relevant_summaries[
528
+ ['Summary', 'Title', 'First Author', 'Publication Date', 'ref_index']
529
+ ].iterrows():
530
+ if title in used_titles:
531
+ continue
532
+ used_titles.add(title)
533
+ ref_index_map[title] = ref_index
534
+ summaries_text_by_subheading[subheading].append(
535
+ f"{summary} [Ref: {ref_index}]"
536
+ )
537
+
538
+ logger.info(summaries_text_by_subheading)
539
+ paragraph_prompts = []
540
+ for subheading in subheadings:
541
+ summaries_text = summaries_text_by_subheading[subheading]
542
+
543
+ # Adjust word_size based on the number of summaries
544
+ num_summaries = len(summaries_text)
545
+ if num_summaries < 10:
546
+ word_size = num_summaries * 200 + 200 # If fewer than 10 summaries
547
+ elif num_summaries > 30:
548
+ word_size = num_summaries * 400 + 800 # If more than 30 summaries
549
+ elif num_summaries > 20:
550
+ word_size = num_summaries * 350 + 500 # If 21-30 summaries
551
+ else:
552
+ word_size = num_summaries * 250 + 300 # 10-20 summaries (default case)
553
+
554
+ # Generate the detailed paragraph for the subheading
555
+ paragraph_prompt = (
556
+ # f"Write a {word_size}-word thematically focused and critical paragraph under the subheading '{subheading}' for a scientific review on '{subheading}'. "
557
+ f"Write a {word_size}-word thematically focused and critical paragraph for a scientific review on '{subheading}'. "
558
+ "Please do the following:\n"
559
+ "1. Begin the paragraph with roughly 100 words that summarize the main findings and objectives of the following studies, providing a clear context for the discussion. You may supplement this introduction with additional relevant knowledge to enhance understanding. "
560
+ "2. Before introducing each piece of literature, provide a connecting sentence or conjunction that links it to the surrounding context. "
561
+ "3. For each study, provide an overview, analyzing its objectives, methodologies, findings, and broader significance. "
562
+ "Ensure that the analysis of each study is presented in sequence, without skipping any, and maintain a logical flow. "
563
+ "4. Relevant literature should be critically discussed, highlighting how it contributes to the field and emphasizing its strengths and limitations. "
564
+ "5. After discussing all studies, provide a concluding paragraph that offers a deep analysis of the collective progress represented by the studies, "
565
+ "identifying overarching trends, advancements, and gaps. Conclude with insightful suggestions for future directions and research areas that need further exploration. "
566
+ "Please meet the following requirements:\n"
567
+ "1. Maintain clear academic language in the style of *Nature*, with a focus on the relationships between studies and their contributions to the subheading's topic. "
568
+ "2. Ensure in-text citations are included in the format [Ref: number], avoid repetition, and provide a critical, objective comparison where relevant. "
569
+ "3. The entire paragraph should be coherent, without empty lines between studies, and flow logically from one point to the next. Each study must be fully represented, with no omission or skipping.\n"
570
+ "4. To prevent the simple stacking of literature, consider how to make the article more readable, logical, and professional. "
571
+ # f"Summaries:{' '.join(summaries_text)}"
572
+ f"Summaries: {' '.join(s.strip() for s in summaries_text)}\n"
573
+ "Output format:\n[Write paragraph here]"
574
+ )
575
+ paragraph_prompts.append(paragraph_prompt)
576
+
577
+ paragraph_responses = await asyncio.gather(
578
+ *(chat_func(para_prompt)
579
+ for para_prompt in paragraph_prompts)
580
+ )
581
+ for subheading, paragraph_response in \
582
+ zip(subheadings, paragraph_responses):
583
+ paragraph_text = paragraph_response.choices[0].message.content.strip()
584
+ paragraph_text = re.sub(r'\(Ref:\s*(\d+)\)', r'[Ref: \1]', paragraph_text)
585
+ paragraph_text = re.sub(r'\n\s*\n', '\n', paragraph_text)
586
+ paragraph_text = paragraph_text.replace('\n', ' ')
587
+ paragraph = f"**{subheading}**\n{paragraph_text}\n"
588
+ paragraphs.append(paragraph)
589
+
590
+ # Conclusion
591
+ conclusion_prompt = (
592
+ f"Write a concluding paragraph for a scientific review on '{main_topic}'. Summarize the main points discussed in the previous sections, "
593
+ "highlight the significance of the research, and suggest possible future directions or applications.\n\n"
594
+ "Output format:\n[Write conclusion here]"
595
+ )
596
+ conclusion_response = await chat_func(conclusion_prompt)
597
+ conclusion_paragraph = conclusion_response.choices[0].message.content.strip()
598
+ paragraphs.append(f"**Conclusion**\n{conclusion_paragraph}\n")
599
+
600
+ used_references = final_relevant_papers_df[
601
+ ['Title', 'First Author', 'Journal Title', 'Publication Date', 'ref_index']
602
+ ].sort_values(by='ref_index')
603
+
604
+ # References section (only used references)
605
+ references = "\n".join([
606
+ f"[Ref: {ref_index}] {author} et al. {title}. {journal_title} ({pub_date})."
607
+ for author, title, journal_title, pub_date, ref_index
608
+ in used_references[
609
+ ['First Author', 'Title', 'Journal Title', 'Publication Date', 'ref_index']
610
+ ].values
611
+ ])
613
+ paragraphs.append(f"**References**\n{references}")
614
+
615
+ # Compile paragraphs into final content
616
+ final_content = "\n".join(paragraphs)
617
+
618
+ # Save grouped summaries to CSV with customer_name and current date
619
+ prefix = f"{customer_name}/{uuid}/{model_name}/"
620
+ output_dir = prefix
621
+
622
+ review_file = os.path.join(output_dir, "review_non_refined.txt")
623
+
624
+ await upload_text_to_minio(
625
+ bucket_name=BUCKET_NAME,
626
+ object_name=review_file,
627
+ file_content=final_content
628
+ )
629
+
630
+ logger.info(f"Non-refined review saved to {review_file}")
631
+ return final_content
632
+
633
+
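+ # --- Editor's sketch (not part of the original module) ---------------------
+ # The regex cleanup above canonicalizes citation variants to [Ref: n] and
+ # flattens each section into a single running paragraph:
+ def _sketch_normalize_citations(text: str) -> str:
+     import re
+     text = re.sub(r'\(Ref:\s*(\d+)\)', r'[Ref: \1]', text)  # (Ref: 3) -> [Ref: 3]
+     text = re.sub(r'\n\s*\n', '\n', text)                   # collapse blank lines
+     return text.replace('\n', ' ')                          # one running paragraph
+ # _sketch_normalize_citations("A (Ref: 3).\n\nB") -> "A [Ref: 3]. B"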
634
+ # Function to enhance language and readability to meet Nature journal style
635
+ async def enhance_language_readability(content, chat_func):
636
+ """
637
+ Enhance the language and readability of the given content to meet the style of the *Nature* journal.
638
+
639
+ Args:
640
+ content (str): The content to enhance.
641
+ chat_func (function): The function to use for the chat completion.
642
+
643
+ Returns:
644
+ str: The enhanced content.
645
+
646
+ """
647
+ # Separate sections based on paragraph breaks
648
+ sections = content.split("\n\n")
649
+ enhanced_sections = []
650
+ prompts = []
651
+ for section in sections:
652
+ prompt = (
653
+ "Enhance the following text to align with the writing style of *Nature* journal. Refine language to be sophisticated and objective, "
654
+ "using advanced vocabulary and a factual tone. Ensure a high level of lexical diversity and rhythm, with alternating sentence lengths "
655
+ "and varied structures for readability. Avoid emotional, speculative, or conversational language, focusing on objective analysis.\n\n"
656
+ f"Text:\n{section}\n\n"
657
+ "Output format:\n[Enhanced text here]"
658
+ )
659
+ prompts.append(prompt)
660
+
661
+ responses = await asyncio.gather(
662
+ *(chat_func(prompt) for prompt in prompts)
663
+ )
664
+ for response in responses:
665
+ enhanced_section = response.choices[0].message.content.strip()
666
+ enhanced_sections.append(enhanced_section)
667
+
668
+ return "\n\n".join(enhanced_sections)
669
+
670
+
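+ # --- Editor's sketch (not part of the original module) ---------------------
+ # The enhance/assign/translate helpers all share one concurrency pattern:
+ # build every prompt first, then await them together with asyncio.gather.
+ # A toy version with a stubbed chat function:
+ async def _sketch_stub_chat(prompt: str) -> str:
+     await asyncio.sleep(0)  # stands in for the network latency of a real LLM call
+     return prompt.upper()
+
+ async def _sketch_fan_out(sections: list[str]) -> list[str]:
+     return list(await asyncio.gather(*(_sketch_stub_chat(s) for s in sections)))
+ # asyncio.run(_sketch_fan_out(["a", "b"])) -> ["A", "B"]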
671
+ async def split_by_section(content):
672
+ """
673
+ Split the given content into sections at bold '**Subheading**' lines.
674
+
675
+ Args:
676
+ content (str): The content to split.
677
+
678
+ Returns:
679
+ list: The list of sections.
680
+
681
+ """
682
+ # Find all bold subheading lines; each section runs from one subheading to the next
683
+ subheading_pattern = r"(?m)^\*\*(.*?)\*\*$"
684
+ matches = list(re.finditer(subheading_pattern, content))
685
+
686
+ sections = []
687
+ references_found = False
688
+ for i, match in enumerate(matches):
689
+ subheading = match.group(1).strip() # Get the subheading text
690
+ if subheading.lower() == "references":
691
+ references_found = True
692
+
693
+ start = match.end() # End of the subheading line
694
+ end = matches[i + 1].start() if i + 1 < len(matches) else len(content)
695
+ paragraph_text = content[start:end].strip()
696
+
697
+ if references_found: # Add everything under "References" as is
698
+ sections.append((subheading, paragraph_text))
699
+ break # Stop further processing
700
+
701
+ sections.append((subheading, paragraph_text))
702
+
703
+ return sections
704
+
705
+
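+ # --- Editor's sketch (not part of the original module) ---------------------
+ # The multiline pattern above matches whole lines of the form **Heading**.
+ # A quick standalone check of the heading extraction:
+ def _sketch_extract_headings(sample: str) -> list[str]:
+     import re
+     return [m.group(1) for m in re.finditer(r"(?m)^\*\*(.*?)\*\*$", sample)]
+ # _sketch_extract_headings("**Introduction**\ntext\n**References**\n[Ref: 1]")
+ # -> ["Introduction", "References"]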
706
+ async def process_sections(sections, chat_func):
707
+ """
708
+ Processes each section (subheading and corresponding text) through the AI model.
709
+ Skips processing the "Introduction", "Conclusion", and "References" sections.
710
+ """
711
+ refined_sections = []
712
+ seen_subheadings = set()
713
+ skip_subheadings = {"introduction", "conclusion", "references"} # Sections to skip
714
+
715
+ prompts = []
716
+ for idx, (subheading, text) in enumerate(sections):
717
+ subheading_clean = subheading.strip("*").strip()
718
+ logger.info(f"Processing section {idx + 1} of {len(sections)}: {subheading_clean}")
719
+
720
+ if subheading_clean.lower() in skip_subheadings:
721
+ logger.info(f"Skipping '{subheading_clean}' section.")
722
+ # refined_sections.append((subheading, text)) # Keep these sections as is
723
+ continue
724
+
725
+ if subheading_clean in seen_subheadings:
726
+ logger.info(f"Duplicate subheading detected: {subheading_clean}. Skipping.")
727
+ continue
728
+
729
+ seen_subheadings.add(subheading_clean)
730
+ if text.strip(): # Skip empty sections
731
+ # Remove extra newlines and ensure no empty lines in the text
732
+ text = re.sub(r'\n\s*\n', ' ', text) # Replace multiple newlines with a single space
733
+ text = text.replace('\n', ' ') # Replace remaining newlines with spaces
734
+ text = re.sub(r'\s+', ' ', text).strip() # Ensure no extra spaces
735
+
736
+ # Updated prompt for higher review quality
737
+ prompt = textwrap.dedent(f"""
738
+ Your task is to refine the following academic section for clarity, depth, and suitability for publication in a high-impact journal.
739
+
740
+ Please adhere to these guidelines:
741
+
742
+ **1. Structure and Organization:**
743
+ - Identify and emphasize key themes or topics within the section.
744
+ - Group related studies together to enhance coherence and logical flow.
745
+ - Reorganize the content to ensure a clear progression of ideas.
746
+ - Use smooth transitions to connect paragraphs and concepts without relying on explicit subheadings.
747
+
748
+ **2. Integration and Analysis of Literature:**
749
+ - Synthesize findings from cited studies, highlighting connections, similarities, and differences.
750
+ - Avoid merely listing studies; focus on comparative analysis and critical evaluation.
751
+ - Highlight significant contributions, novel findings, or implications of each study.
752
+ - Discuss any controversies, differing perspectives, or gaps in the current research.
753
+
754
+ **3. Depth and Critical Insight:**
755
+ - Deepen analytical insights by going beyond surface-level summarization.
756
+ - Provide critical evaluations, discussing strengths, limitations, and areas needing further exploration.
757
+ - Highlight the significance of trends or shifts in the field.
758
+
759
+ **4. Language and Clarity:**
760
+ - Use precise and concise language appropriate for an academic audience.
761
+ - Vary sentence structures to enhance readability and engagement.
762
+ - Eliminate redundant or repetitive statements to streamline the content.
763
+ - Maintain a formal academic tone while ensuring the text is accessible.
764
+
765
+ **5. Consistency and Terminology:**
766
+ - Ensure consistency in terminology, style, and formatting throughout the section.
767
+ - Use technical terms accurately and define specialized terms if necessary.
768
+ - Avoid unnecessary acronyms unless commonly understood in the field.
769
+
770
+ **6. Accuracy and Detail:**
771
+ - Verify that descriptions of studies are accurate and that key findings are correctly represented.
772
+ - Emphasize the most relevant and impactful information from each study.
773
+ - Provide context where needed to aid understanding for a multidisciplinary audience.
774
+
775
+ **7. Conclusion and Future Directions:**
776
+ - Summarize main points and discuss how findings align or diverge from prior work.
777
+ - Suggest areas for future research based on identified gaps or limitations.
778
+ - Discuss practical implications or potential applications if relevant.
779
+
780
+ **8. Citation and Formatting:**
781
+ - Ensure citations are formatted accurately (e.g., [Ref: number]) and integrated smoothly into the text.
782
+ - Do not alter the "References" section or the citation order.
783
+ - Maintain the existing citation positions within the text.
784
+
785
+ **Section to refine:**
786
+ {text}
787
+ """)
788
+
789
+ prompts.append(prompt)
790
+
791
+ # Call the AI model with the updated prompt
792
+ index = 0
793
+ refined_texts = await asyncio.gather(
794
+ *(chat_func(prompt) for prompt in prompts)
795
+ )
796
+
797
+ logger.info(len(refined_texts))
798
+ logger.info(len(prompts))
799
+
800
+ seen_subheadings = set()
801
+ for idx, (subheading, text) in enumerate(sections):
802
+ subheading_clean = subheading.strip("*").strip()
803
+ logger.info(f"Processing section {idx + 1} of {len(sections)}: {subheading_clean}")
804
+
805
+ if subheading_clean.lower() in skip_subheadings:
806
+ refined_sections.append((subheading, text))
807
+ continue
808
+
809
+ if subheading_clean in seen_subheadings:
810
+ logger.info(f"Duplicate subheading detected: {subheading_clean}. Skipping.")
811
+ continue
812
+
813
+ seen_subheadings.add(subheading_clean)
814
+ if text.strip():
815
+ refined_text = refined_texts[index].choices[0].message.content.strip()
816
+ refined_text = re.sub(r'\n\s*\n', ' ', refined_text) # Replace extra newlines with a single space
817
+ refined_text = refined_text.replace('\n', ' ') # Replace remaining newlines with spaces
818
+ refined_text = re.sub(r'\s+', ' ', refined_text).strip() # Ensure no extra spaces
819
+ refined_sections.append((subheading, refined_text))
820
+ index += 1
821
+
822
+ return refined_sections
823
+
824
+
825
+ async def process_papers(
826
+ dataframe, topic, direction,
827
+ uuid, customer_name, model_name,
828
+ chat_func
829
+ ):
830
+ """
831
+ Process the given papers to extract relevant information and save it to a CSV file.
832
+
833
+ Args:
834
+ dataframe (pandas.DataFrame): The DataFrame containing the papers.
835
+ topic (str): The topic to filter the papers by.
836
+ direction (str): The direction to filter the papers by.
837
+ uuid (str): The UUID of the task.
838
+ customer_name (str): The name of the customer.
839
+ chat_func (function): The function to use for the chat completion.
840
+
841
+ Returns:
842
+ pandas.DataFrame: The DataFrame containing the relevant papers.
843
+
844
+ """
845
+ # Duplicate, no need
846
+ # relevant_rows = [] # List to collect relevant rows for DataFrame creation
847
+
848
+ # Set up the output directory and CSV file
849
+ # output_dir = os.path.join(customer_name)
850
+ # os.makedirs(output_dir, exist_ok=True)
851
+ prefix = f"{customer_name}/{uuid}/{model_name}/"
852
+ output_dir = prefix
853
+
854
+ output_path = os.path.join(output_dir, "relevant_papers.csv")
855
+
856
+ # Create or clear the output file at the beginning
857
+ # with open(output_path, 'w', newline='', encoding='utf-8') as f:
858
+ # writer = csv.writer(f, quoting=csv.QUOTE_ALL)
859
+ # writer.writerow(["Journal Title", "Publication Date", "Title", "First Author", "Summary", "Is Relevant", "Relevance Keywords"]) # Writing header
860
+ texts = ""
861
+ fieldnames = ["Journal Title", "Publication Date", "Title",
862
+ "First Author", "Summary", "Is Relevant", "Relevance Keywords"]
863
+ texts += ",".join([escape_csv_field(x) for x in fieldnames]) + "\n"
864
+
865
+ titles = []
866
+ abstracts = []
867
+ journal_titles = []
868
+ pubd_dates = []
869
+ first_authors = []
870
+ summaries = []
871
+ for idx, row in dataframe.iterrows():
872
+ title = row["TI"]
873
+ abstract = row["AB"]
874
+ journal_title = row["JT"]
875
+ pub_date = row["DCOM"]
876
+ first_author = row["FAU-frist"]
877
+
878
+ titles.append(title)
879
+ abstracts.append(abstract)
880
+ journal_titles.append(journal_title)
881
+ pubd_dates.append(pub_date)
882
+ first_authors.append(first_author)
883
+
884
+ relevants = await asyncio.gather(
885
+ *(is_relevant(
886
+ title, abstract, topic, direction, chat_func
887
+ ) for title, abstract in zip(titles, abstracts))
888
+ )
889
+
890
+ is_relevant_flags = [relevant[0] for relevant in relevants]
891
+ relevance_keywords = [relevant[1] for relevant in relevants]
892
+
893
+ rtitles = []
894
+ rabstracts = []
895
+ rjournal_titles = []
896
+ rpubd_dates = []
897
+ rfirst_authors = []
898
+ rflags = []
899
+ rkeywords = []
900
+
901
+ for (
902
+ rflag, rkeyword, title, abstract, first_author, journal_title, pub_date
903
+ ) in zip(
904
+ is_relevant_flags, relevance_keywords,
905
+ titles, abstracts, first_authors, journal_titles, pubd_dates
906
+ ):
907
+ if rflag:
908
+ rtitles.append(title)
909
+ rabstracts.append(abstract)
910
+ rfirst_authors.append(first_author)
911
+ rjournal_titles.append(journal_title)
912
+ rpubd_dates.append(pub_date)
913
+ rflags.append(rflag)
914
+ rkeywords.append(rkeyword)
915
+
916
+ summaries = await asyncio.gather(
917
+ *(summarize_abstract(
918
+ title, abstract, first_author, chat_func
919
+ ) for title, abstract, first_author in
920
+ zip(rtitles, rabstracts, rfirst_authors)
921
+ )
922
+ )
923
+
924
+ for (
925
+ summary,
926
+ journal_title, pub_date, title, first_author,
927
+ rflag, rkeyword
928
+ ) in zip(
929
+ summaries,
930
+ rjournal_titles, rpubd_dates, rtitles, rfirst_authors,
931
+ rflags, rkeywords
932
+ ):
933
+ journal_title = escape_csv_field(journal_title)
934
+ pub_date = escape_csv_field(pub_date)
935
+ title = escape_csv_field(title)
936
+ first_author = escape_csv_field(first_author)
937
+ summary = escape_csv_field(summary)
938
+ rkeyword = escape_csv_field(rkeyword)
939
+
940
+ texts += ",".join([
941
+ str(x) for x in [
942
+ journal_title, pub_date, title, first_author,
943
+ summary, rflag, rkeyword
944
+ ]
945
+ ]) + "\n"
946
+
947
+ # Print the added summary and keywords
948
+ logger.info(f"Added summary: {summary}")
949
+ logger.info(f"Relevance Keywords: {rkeyword}")
950
+
951
+ # Create the relevant DataFrame to return
952
+ # relevant_df = pd.DataFrame(relevant_rows)
953
+ # return relevant_df
954
+ await upload_text_to_minio(
955
+ bucket_name=BUCKET_NAME,
956
+ object_name=output_path,
957
+ file_content=texts
958
+ )
959
+
960
+ return output_path
961
+
962
+
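+ # --- Editor's sketch (not part of the original module) ---------------------
+ # The CSV text above is assembled by hand around escape_csv_field, which is
+ # imported from common_utils and not shown in this diff. A typical
+ # RFC-4180-style implementation, shown here only as an assumption:
+ def _sketch_escape_csv_field(value) -> str:
+     text = "" if value is None else str(value)
+     if any(ch in text for ch in (',', '"', '\n')):
+         text = '"' + text.replace('"', '""') + '"'  # quote field, double embedded quotes
+     return text
+ # _sketch_escape_csv_field('a,"b"') -> '"a,""b"""'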
963
+ async def translate_to_chinese_before_references(
964
+ text,
965
+ uuid, customer_name, model_name,
966
+ chat_func
967
+ ):
968
+ """
969
+ Translates the content of a text file to Chinese, keeping the '**References**' section in English.
970
+
971
+ Args:
972
+ text (str): The content of the text file.
973
+ uuid (str): Unique identifier for the task.
+ customer_name (str): Name of the customer.
+ model_name (str): Name of the model, used in the output path.
974
+ chat_func (function): The function to use for translation.
975
+
976
+ Returns:
977
+ None. The translated content is uploaded to MinIO.
978
+
979
+ """
980
+ lines = text.split("\n")
981
+
982
+ # Step 3: Find the index of the '**References**' line
983
+ references_index = None
984
+ for i, line in enumerate(lines):
985
+ if line.strip() == "**References**":
986
+ references_index = i
987
+ break
988
+
989
+ # Step 4: Split the content at the found index
990
+ if references_index is not None:
991
+ main_content_lines = lines[:references_index]
992
+ references_content_lines = lines[references_index:]
993
+ else:
994
+ # If '**References**' is not found, treat the entire content as body text
995
+ main_content_lines = lines
996
+ references_content_lines = []
997
+
998
+ # Join the body lines into a single string
999
+ main_content = "\n".join(main_content_lines)
1000
+
1001
+ # Step 5: Translate the body content section by section
1002
+ sections = main_content.split("\n\n")
1003
+ translated_sections = []
1004
+
1005
+ prompts = []
1006
+
1007
+ for section in sections:
1008
+ # Keep the prompt simple: only request a translation of the body text
1009
+ prompt = (
1010
+ "Translate the following text to academic Chinese:\n\n"
1011
+ f"Text:\n{section}\n\n"
1012
+ "Output format:\n[Translated Chinese text here]"
1013
+ )
1014
+ prompts.append(prompt)
1015
+
1016
+ responses = await asyncio.gather(
1017
+ *(chat_func(prompt) for prompt in prompts)
1018
+ )
1019
+ for response in responses:
1020
+ translated_section = response.choices[0].message.content.strip()
1021
+ translated_sections.append(translated_section)
1022
+
1023
+ # Step 6: Join the translated body sections
1024
+ translated_content = "\n\n".join(translated_sections)
1025
+
1026
+ # Step 7: Merge the translated body with the References section
1027
+ if references_content_lines:
1028
+ references_content = "\n".join(references_content_lines)
1029
+ final_content = translated_content + "\n\n" + references_content
1030
+ else:
1031
+ final_content = translated_content
1032
+
1033
+ # Step 8: Save the result to a new file
1034
+ output_filename = f"{customer_name}/{uuid}/{model_name}/review_non_refined_translated.txt"
1035
+ await upload_text_to_minio(
1036
+ bucket_name=BUCKET_NAME,
1037
+ object_name=output_filename,
1038
+ file_content=final_content
1039
+ )
1040
+
1041
+ logger.info(f"\nTranslated content saved to {output_filename}")
1042
+
1043
+
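+ # --- Editor's sketch (not part of the original module) ---------------------
+ # The body/references split above, as a standalone, testable helper:
+ def _sketch_split_at_references(text: str) -> tuple[str, str]:
+     lines = text.split("\n")
+     for i, line in enumerate(lines):
+         if line.strip() == "**References**":
+             return "\n".join(lines[:i]), "\n".join(lines[i:])
+     return text, ""  # no references section found
+ # _sketch_split_at_references("body\n**References**\n[Ref: 1]")
+ # -> ("body", "**References**\n[Ref: 1]")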
1044
+ async def translate_refined_review_to_chinese(
1045
+ refined_review_content,
1046
+ uuid, customer_name, model_name,
1047
+ chat_func
1048
+ ):
1049
+
1050
+ # Read the Word document
1051
+ doc = Document(refined_review_content)
1052
+
1053
+ # Prepare to create a new document for the translated content
1054
+ translated_doc = Document()
1055
+
1056
+ # Set of subheadings to skip translation
1057
+ skip_subheadings = {"references"}
1058
+
1059
+ # Keep track of the current section heading
1060
+ current_heading = None
1061
+ in_references_section = False
1062
+
1063
+ prompts = []
1064
+ for para in doc.paragraphs:
1065
+ # Check if the paragraph is a heading
1066
+ if para.style.name.startswith('Heading'):
1067
+ # Get the heading text
1068
+ current_heading = para.text.strip()
1069
+ # Get the heading level
1070
+ heading_level_match = re.findall(r'\d+', para.style.name)
1071
+ heading_level = int(heading_level_match[0]) if heading_level_match else 1
1072
+
1073
+ # Check if the heading text is in skip_subheadings
1074
+ if current_heading.lower() in skip_subheadings:
1075
+ in_references_section = True
1076
+ # Add the heading as is
1077
+ # translated_doc.add_heading(current_heading, level=heading_level)
1078
+ else:
1079
+ in_references_section = False
1080
+ # Translate the heading
1081
+ prompt = f"Translate the following heading to Chinese:\n\n{current_heading}"
1082
+ prompts.append(prompt)
1083
+ # translated_heading = chat_func(prompt)
1084
+ # Add the translated heading
1085
+ # translated_doc.add_heading(translated_heading, level=heading_level)
1086
+ else:
1087
+ if in_references_section:
1088
+ # Add the paragraph as is
1089
+ # translated_doc.add_paragraph(para.text)
1090
+ pass
1091
+ else:
1092
+ # Translate the paragraph text to Chinese, preserving in-text citations
1093
+ text_to_translate = para.text
1094
+ if text_to_translate.strip() == '':
1095
+ # If the paragraph is empty, skip translation
1096
+ translated_doc.add_paragraph('')
1097
+ else:
1098
+ # We need to preserve in-text citations, e.g., [Ref: 38]
1099
+ # Instruct the AI to keep the in-text citations in English
1100
+ prompt = f"""
1101
+ Translate the following text to academic Chinese. Keep any in-text citations (e.g., [Ref: number]) in English.
1102
+
1103
+ Text:
1104
+ {text_to_translate}
1105
+ """
1106
+ prompts.append(prompt)
1107
+
1108
+ translated_texts = await asyncio.gather(
1109
+ *(chat_func(prompt) for prompt in prompts)
1110
+ )
1111
+ translated_texts = [
1112
+ t.choices[0].message.content.strip() for t in translated_texts
1113
+ ]
1114
+
1115
+ index = 0
+ in_references_section = False  # reset state before the second pass over the document
1116
+ for para in doc.paragraphs:
1117
+ # Check if the paragraph is a heading
1118
+ if para.style.name.startswith('Heading'):
1119
+ # Get the heading text
1120
+ current_heading = para.text.strip()
1121
+ # Get the heading level
1122
+ heading_level_match = re.findall(r'\d+', para.style.name)
1123
+ heading_level = int(heading_level_match[0]) if heading_level_match else 1
1124
+
1125
+ # Check if the heading text is in skip_subheadings
1126
+ if current_heading.lower() in skip_subheadings:
1127
+ in_references_section = True
1128
+ # Add the heading as is
1129
+ translated_doc.add_heading(current_heading, level=heading_level)
1130
+ else:
1131
+ in_references_section = False
1132
+ translated_doc.add_heading(translated_texts[index], level=heading_level)
1133
+ index += 1
1134
+ else:
1135
+ if in_references_section:
1136
+ # Add the paragraph as is
1137
+ translated_doc.add_paragraph(para.text)
1138
+ else:
1139
+ # Translate the paragraph text to Chinese, preserving in-text citations
1140
+ text_to_translate = para.text
1141
+ if text_to_translate.strip() == '':
1142
+ # If the paragraph is empty, skip translation
1143
+ translated_doc.add_paragraph('')
1144
+ else:
1145
+ translated_text = translated_texts[index]
1146
+ translated_doc.add_paragraph(translated_text)
1147
+ index += 1
1148
+
1149
+ output_file_path = f"{customer_name}/{uuid}/{model_name}/review_paper_refined_translated.docx"
1150
+ await upload_document_to_minio(
1151
+ bucket_name=BUCKET_NAME,
1152
+ object_name=output_file_path,
1153
+ document=translated_doc
1154
+ )
1155
+ return output_file_path
1156
+
1157
+
1158
+ async def refine_review_content(
1159
+ non_refine_content,
1160
+ uuid, customer_name, model_name,
1161
+ chat_func
1162
+ ):
1163
+ sections = await split_by_section(non_refine_content)
1164
+ refined_sections = await process_sections(sections, chat_func)
1165
+
1166
+ prompt_title = f"""
1167
+ Based on the following literature review, generate an appropriate and concise title:
1168
+ {non_refine_content}
1169
+ """
1170
+ title = await chat_func(prompt_title)
1171
+ title = title.choices[0].message.content.strip()
1172
+ logger.info(f"Generated Title: {title}")
1173
+
1174
+ doc = Document()
1175
+ doc.add_heading(title, level=1)
1176
+
1177
+ for subheading, content in refined_sections:
1178
+ doc.add_heading(subheading, level=2)
1179
+ doc.add_paragraph(content)
1180
+
1181
+ output_file = f"{customer_name}/{uuid}/{model_name}/review_paper_refined.docx"
1182
+ await upload_document_to_minio(
1183
+ bucket_name=BUCKET_NAME,
1184
+ object_name=output_file,
1185
+ document=doc
1186
+ )
1187
+ return output_file
1188
+
1189
+
1190
+ # Main function to automate the review paper creation process with language enhancement step
1191
+ async def create_review_paper(
1192
+ relevant_papers_df,
1193
+ main_topic,
1194
+ uuid, customer_name, model_name,
1195
+ chat_func,
1196
+ translate_to_cn=False,
1197
+ do_refine=False,
1198
+ ):
1199
+ """
1200
+ Main function to automate the review paper creation process with language enhancement step.
1201
+
1202
+ Args:
1203
+ relevant_papers_df (pd.DataFrame): DataFrame containing relevant papers.
1204
+ main_topic (str): Main topic of the review paper.
1205
+ uuid (str): Unique identifier for the review paper.
1206
+ customer_name (str): Name of the customer.
1207
+ model_name (str): Name of the model, used in output paths.
1208
+ chat_func (function): Function to handle chat interactions.
+ translate_to_cn (bool): Flag to indicate if translation to Chinese is required.
+ do_refine (bool): Flag to indicate whether the review should be refined into a Word document.
1209
+
1210
+ Returns:
1211
+ None
1212
+
1213
+ """
1214
+
1215
+ # Step 1: Generate subheadings related to the main topic
1216
+ subheadings = await generate_subheadings(
1217
+ relevant_papers_df, main_topic,
1218
+ uuid, customer_name, model_name,
1219
+ chat_func
+ )
1220
+
1221
+ # Step 2: Assign each summary to a subheading
1222
+ relevant_papers_df = await assign_subheadings_to_summaries(
1223
+ relevant_papers_df, subheadings,
1224
+ uuid, customer_name, model_name,
1225
+ chat_func
1226
+ )
1227
+
1228
+ # Step 3: Create paragraphs by subheading, with introductory and concluding sections, and references
1229
+ review_content = await create_paragraphs_by_subheading(
1230
+ relevant_papers_df, subheadings, main_topic,
1231
+ uuid, customer_name, model_name,
1232
+ chat_func
1233
+ )
1234
+
1235
+ output_filename = f"{customer_name}/{uuid}/{model_name}/review_non_refined.txt"
1236
+
1237
+ if do_refine:
1238
+ # Step 4: Refine Review Content to a Word Document
1239
+ await refine_review_content(
1240
+ review_content,
1241
+ uuid, customer_name, model_name,
1242
+ chat_func
1243
+ )
1244
+ refined_review_content = await get_file_from_minio(
1245
+ bucket_name=BUCKET_NAME,
1246
+ object_name=f"{customer_name}/{uuid}/{model_name}/review_paper_refined.docx",
1247
+ )
1248
+ refined_review_content = io.BytesIO(refined_review_content.data)
1249
+
1250
+ if translate_to_cn:
1251
+ if do_refine:
1252
+ await translate_refined_review_to_chinese(
1253
+ refined_review_content,
1254
+ uuid, customer_name, model_name,
1255
+ chat_func
1256
+ )
1257
+ output_filename = f"{customer_name}/{uuid}/{model_name}/review_paper_refined_translated.docx"
1258
+ else:
1259
+ await translate_to_chinese_before_references(
1260
+ review_content,
1261
+ uuid, customer_name, model_name,
1262
+ chat_func
1263
+ )
1264
+ output_filename = f"{customer_name}/{uuid}/{model_name}/review_non_refined_translated.txt"
1265
+ return output_filename
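+ # --- Editor's sketch (not part of the original module) ---------------------
+ # A minimal end-to-end invocation of the pipeline above. The DataFrame
+ # columns and the chat_func contract (an object exposing
+ # .choices[0].message.content) are assumptions inferred from this module,
+ # not a documented API:
+ async def _sketch_demo(chat_func):
+     import pandas as pd
+     df = pd.DataFrame({  # hypothetical minimal input
+         "Summary": ["Example summary."],
+         "Title": ["Example title"],
+         "First Author": ["Smith, J"],
+         "Journal Title": ["J. Example"],
+         "Publication Date": ["2024"],
+     })
+     return await create_review_paper(
+         df, "example topic",
+         "task-1", "acme", "glm-4-flash",
+         chat_func, translate_to_cn=False, do_refine=False,
+     )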
utils/paper_utils.py ADDED
@@ -0,0 +1,694 @@
1
+ import os
2
+ import re
3
+ import math
4
+ import asyncio
4
+
5
+ from loguru import logger
6
+
7
+ from .minio_utils import (
8
+ upload_text_to_minio,
9
+ upload_dataframe_to_minio,
10
+ )
11
+ from .common_utils import escape_csv_field
12
+
13
+
14
+ BUCKET_NAME = "ai-scientist"
15
+
16
+
17
+ # Function to check relevance and obtain keywords as reason
18
+ async def is_relevant(title, abstract, topic, direction, chat_func):
19
+ """
20
+ Check if a paper is relevant to a topic and obtain keywords as reason.
21
+
22
+ Args:
23
+ title (str): Title of the paper.
24
+ abstract (str): Abstract of the paper.
25
+ topic (str): Topic to check relevance against.
26
+ direction (str): Direction to check relevance against.
27
+ chat_func (function): Function to call the chat model.
28
+
29
+ Returns:
30
+ bool: True if the paper is relevant, False otherwise.
31
+ str: Keywords that indicate relevance.
32
+
33
+ """
34
+ relevance_prompt = (
35
+ f"You are an academic expert in {topic}. Identify if the following paper is "
36
+ f"related to '{direction}' and list only the main keywords that indicate relevance:\n\n"
37
+ f"Title: {title}\nAbstract: {abstract}\n\n"
38
+ "Answer format:\n"
39
+ "Relevance: True or False\n"
40
+ "Keywords: [Comma-separated keywords]"
41
+ )
42
+ response = await chat_func(relevance_prompt)
43
+ if response is None:
44
+ return False, "Relevance check unavailable due to server error."
45
+
46
+ try:
47
+ response_text = response.choices[0].message.content
48
+ relevance = bool(re.search(r"Relevance:\s*True", response_text, re.IGNORECASE))
49
+ keywords = response_text.split(
50
+ "Keywords:")[-1].strip() if "Keywords:" in response_text else ""
51
+ return relevance, keywords
52
+ except AttributeError:
53
+ logger.error(f"Error in chat_func response format: {response}")
54
+ return False, "Relevance check failed"
55
+
56
+
57
+ # Modified summarize_abstract function with error handling for failed completion requests
58
+ async def summarize_abstract(title, abstract, first_author, chat_func):
59
+ """
60
+ Summarize the abstract of a research paper.
61
+
62
+ Args:
63
+ title (str): Title of the paper.
64
+ abstract (str): Abstract of the paper.
65
+ first_author (str): Name of the first author.
66
+ chat_func (function): Function to call the chat model.
67
+
68
+ Returns:
69
+ str: Summary of the abstract.
70
+
71
+ """
72
+ formatted_author = reformat_author_name(first_author)
73
+ summary_prompt = (
74
+ f"Write a concise, high-level summary in 2-3 sentences, highlighting the study's "
75
+ f"purpose, specific methodology, main findings, and significance. Avoid generalizing "
76
+ f"or replacing specific method names or entities with vague language. Retain concrete terms "
77
+ f"and clear descriptions of methodology and findings.\n\n"
78
+ f"Title: {title}\nAbstract: {abstract}\n\n"
79
+ f"Summary by {formatted_author} et al.:"
80
+ )
81
+
82
+ response = await chat_func(summary_prompt)
83
+ if response is None:
84
+ return "Summary unavailable due to server error."
85
+
86
+ try:
87
+ result = response.choices[0].message.content
88
+ result_words = result.split()
89
+ summary = " ".join(result_words)
90
+ return summary
91
+ except AttributeError:
92
+ logger.error(f"Error in chat_func response format: {response}")
93
+ return "Summary unavailable"
94
+
95
+
96
+ # Function to reformat first author name
97
+ def reformat_author_name(author_name):
98
+ """
99
+ Reformat the first author name by removing commas.
100
+
101
+ Args:
102
+ author_name (str): Name of the first author.
103
+
104
+ Returns:
105
+ str: Reformatted name of the first author.
106
+
107
+ """
108
+ try:
109
+ return author_name.replace(",", "")
110
+ except AttributeError:
111
+ return "Unknown Author"
112
+
113
+
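+ # --- Editor's sketch (not part of the original module) ---------------------
+ # A slightly more defensive parser for the "Relevance: ...\nKeywords: ..."
+ # reply format that is_relevant expects:
+ def _sketch_parse_relevance(reply: str) -> tuple[bool, str]:
+     relevant = bool(re.search(r"Relevance:\s*True", reply, re.IGNORECASE))
+     keywords = reply.split("Keywords:", 1)[-1].strip() if "Keywords:" in reply else ""
+     return relevant, keywords
+ # _sketch_parse_relevance("Relevance: True\nKeywords: a, b") -> (True, "a, b")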
114
+ # Function to generate 3-5 hierarchical subheadings related to the main topic
115
+ async def generate_subheadings(
116
+ relevant_papers_df, main_topic,
117
+ uuid, customer_name, model_name,
118
+ chat_func
119
+ ):
120
+ """
121
+ Generate 3-5 hierarchical subheadings related to the main topic based on the summaries of relevant papers.
122
+
123
+ Args:
124
+ relevant_papers_df: DataFrame containing relevant papers.
125
+ main_topic: Main topic of the research.
126
+ chat_func: Function to send chat messages to the chatbot.
127
+
128
+ Returns:
129
+ List[str]: List of generated subheadings.
130
+
131
+ """
132
+ summaries = " ".join(relevant_papers_df['Summary'].tolist())
133
+ prompt = (
134
+ f"The main topic is '{main_topic}'. Based on this topic and the following summaries from relevant research papers, "
135
+ "generate 3-5 hierarchical subheadings that progressively explore the topic. Begin with broader subheadings and "
136
+ "move towards more specific themes, avoiding overlap in scope or content. Subheadings should be distinct and arranged "
137
+ "in a logical order suitable for a structured review.\n\n"
138
+ f"Summaries:\n{summaries}\n\n"
139
+ "Output format:\n- Subheading 1\n- Subheading 2\n- Subheading 3\n..."
140
+ )
141
+ response = await chat_func(prompt)
142
+ subheadings = response.choices[0].message.content.strip().splitlines()
143
+ logger.info("Generated Subheadings:\n" + "\n".join(subheadings))
144
+
145
+ output_filename = f"{customer_name}/{uuid}/{model_name}/generated_subheadings.txt"
146
+ await upload_text_to_minio(
147
+ bucket_name=BUCKET_NAME,
148
+ object_name=output_filename,
149
+ file_content="\n".join(subheadings)
150
+ )
151
+ logger.info(f"Subheadings saved to {output_filename}")
152
+ return subheadings
153
+
154
+
155
+ # Function to assign summaries to subheadings with minimum allocation of references per subheading
156
+ async def assign_subheadings_to_summaries(
157
+ relevant_papers_df,
158
+ subheadings,
159
+ uuid, customer_name, model_name,
160
+ chat_func
161
+ ):
162
+ """
163
+ Assign summaries to subheadings with minimum allocation of references per subheading.
164
+
165
+ Args:
166
+ relevant_papers_df: DataFrame containing relevant papers.
167
+ subheadings: List of subheadings.
168
+ uuid: Unique identifier for the task.
169
+ customer_name: Name of the customer.
170
+ chat_func: Function to send chat messages to the chatbot.
171
+
172
+ Returns:
173
+ DataFrame with assigned subheadings.
174
+
175
+ """
176
+ total_papers = len(relevant_papers_df)
177
+ min_papers_per_subheading = math.ceil(
178
+ total_papers / (len(subheadings) + 1))
179
+
180
+ assigned_subheadings = []
181
+ prompts = []
182
+ for summary in relevant_papers_df['Summary']:
183
+ prompt = (
184
+ "Given the following subheadings and a research paper summary, determine the most appropriate subheading "
185
+ "for this summary. Each subheading should cover a unique aspect of the main topic without overlap. "
186
+ "Select the best-fitting subheading based on thematic relevance and coherence with similar studies.\n\n"
187
+ f"Subheadings:\n{subheadings}\n\n"
188
+ f"Summary:\n{summary}\n\n"
189
+ "Output format:\nSubheading: [Chosen subheading]"
190
+ )
191
+ prompts.append(prompt)
192
+ responses = await asyncio.gather(
193
+ *(chat_func(prompt) for prompt in prompts)
194
+ )
195
+ for response in responses:
196
+ content = response.choices[0].message.content.strip()
+ assigned_subheading = content.split(": ", 1)[1] if ": " in content else content
197
+ assigned_subheadings.append(assigned_subheading)
198
+
199
+ relevant_papers_df['Assigned Subheading'] = assigned_subheadings
200
+
201
+ # Ensure minimum papers per subheading
202
+ counts = relevant_papers_df['Assigned Subheading'].value_counts().to_dict()
203
+ for subheading in subheadings:
204
+ if counts.get(subheading, 0) < min_papers_per_subheading:
205
+ extra_summaries = relevant_papers_df[relevant_papers_df['Assigned Subheading'] != subheading].sample(
206
+ min_papers_per_subheading - counts.get(subheading, 0)
207
+ )
208
+ relevant_papers_df.loc[extra_summaries.index,
209
+ 'Assigned Subheading'] = subheading
210
+
211
+ prefix = f"{customer_name}/{uuid}/{model_name}/"
212
+ output_dir = prefix
213
+
214
+ csv_filename = os.path.join(output_dir, "assigned_subheadings.csv")
215
+
216
+ # relevant_papers_df.to_csv(csv_filename, index=False, encoding='utf-8')
217
+ await upload_dataframe_to_minio(
218
+ bucket_name=BUCKET_NAME,
219
+ object_name=csv_filename,
220
+ df=relevant_papers_df,
221
+ )
222
+
223
+ logger.info(f"Assigned subheadings saved to {csv_filename}")
224
+ logger.info(f"Found {len(relevant_papers_df)} related papers")
225
+
226
+ return relevant_papers_df
227
+
228
+
229
+ # Function to create expanded paragraphs with required reference count and consistent reference indexing
230
+ async def create_paragraphs_by_subheading(
231
+ relevant_papers_df, subheadings, main_topic,
232
+ uuid, customer_name, model_name,
233
+ chat_func
234
+ ):
235
+ """
236
+ Create expanded paragraphs by subheading with required reference count and consistent reference indexing.
237
+
238
+ Args:
239
+ relevant_papers_df (pd.DataFrame): DataFrame containing relevant papers and their summaries.
240
+ subheadings (list): List of subheadings for the review paper.
241
+ main_topic (str): Main topic of the review paper.
242
+ uuid (str): UUID of the task.
243
+ customer_name (str): Name of the customer.
244
+ chat_func (function): Function to send chat messages to the chatbot.
245
+
246
+ Returns:
247
+ list: List of paragraphs with subheadings and consistent reference indexing.
248
+
249
+ """
250
+ paragraphs = []
251
+
252
+ # Introduction
253
+ intro_prompt = (
254
+ f"Write a concise and advanced introductory paragraph for a scientific review paper on '{main_topic}'. "
255
+ "Introduce the topic, its importance, and the scope of the review. The introduction should provide a logical "
256
+ "setup for the following subheadings.\n\n"
257
+ "Output format:\n[Write introduction here]"
258
+ )
259
+ intro_response = await chat_func(intro_prompt)
260
+ intro_paragraph = intro_response.choices[0].message.content.strip()
261
+ paragraphs.append(f"**Introduction**\n{intro_paragraph}\n")
262
+
263
+ # Body paragraphs based on subheadings with consistent reference numbering
264
+ reference_map = {}
265
+ used_references = []
266
+ total_papers = len(relevant_papers_df)
267
+ min_papers_per_subheading = math.ceil(
268
+ total_papers / (len(subheadings) + 1))
269
+ ref_counter = 1
270
+
271
+ paragraph_prompts = []
272
+ for subheading in subheadings:
273
+ relevant_summaries = relevant_papers_df[relevant_papers_df['Assigned Subheading'] == subheading]
274
+
275
+ new_references = []
276
+ summaries_text = []
277
+ for idx, (summary, title, author, pub_date) in relevant_summaries[['Summary', 'Title', 'First Author', 'Publication Date']].iterrows():
278
+ if title not in reference_map:
279
+ reference_map[title] = ref_counter
280
+ ref_counter += 1
281
+ ref_index = reference_map[title]
282
+ summaries_text.append(f"{summary} [Ref: {ref_index}]")
283
+ new_references.append((title, author, pub_date))
284
+
285
+ # Compose prompt to generate an extended paragraph with at least 800 words
286
+ paragraph_prompt = (
287
+ f"Write an 800-word thematic and critical paragraph under the subheading '{subheading}' for a scientific review on '{main_topic}'. "
288
+ f"Combine the following summaries into a coherent, well-structured paragraph discussing the studies’ objectives, findings, "
289
+ "and methodologies. Use advanced academic language, include in-text citations in the format [Ref: number], and avoid repeating "
290
+ "content from previous sections. Provide critical insights and comparative analysis where relevant.\n\n"
291
+ f"Summaries:\n{' '.join(summaries_text)}\n\n"
292
+ "Output format:\n[Write paragraph here]"
293
+ )
294
+
295
+ paragraph_prompts.append(paragraph_prompt)
296
+ used_references.extend(new_references)
297
+
298
+ paragraph_responses = await asyncio.gather(
299
+ *(chat_func(para_prompt)
300
+ for para_prompt in paragraph_prompts)
301
+ )
302
+ for subheading, paragraph_response in \
303
+ zip(subheadings, paragraph_responses):
304
+ paragraph = f"**{subheading}**\n{paragraph_response.choices[0].message.content.strip()}\n"
305
+ paragraphs.append(paragraph)
306
+
307
+ # Conclusion
308
+ conclusion_prompt = (
309
+ f"Write a concluding paragraph for a scientific review on '{main_topic}'. Summarize the main points discussed in the previous sections, "
310
+ "highlight the significance of the research, and suggest possible future directions or applications.\n\n"
311
+ "Output format:\n[Write conclusion here]"
312
+ )
313
+ conclusion_response = await chat_func(conclusion_prompt)
314
+ conclusion_paragraph = conclusion_response.choices[0].message.content.strip()
315
+ paragraphs.append(f"**Conclusion**\n{conclusion_paragraph}\n")
316
+
317
+ # References section (only used references)
318
+ references = "\n".join(
319
+ [f"[Ref: {reference_map[title]}] {title}, {author}, {pub_date}"
320
+ for title, author, pub_date in used_references]
321
+ )
322
+ paragraphs.append(f"**References**\n{references}")
323
+
324
+ # Compile paragraphs into final content
325
+ final_content = "\n\n".join(paragraphs)
326
+
327
+ # Save grouped summaries to CSV with customer_name and current date
328
+ prefix = f"{customer_name}/{uuid}/{model_name}/"
329
+ output_dir = prefix
330
+
331
+ csv_filename = os.path.join(output_dir, "grouped_summaries.csv")
332
+ output_filename = os.path.join(output_dir, "review_non_refined.txt")
333
+ # Prepare data for CSV
334
+ grouped_data = relevant_papers_df[['Assigned Subheading', 'Summary']]
335
+ # grouped_data.to_csv(csv_filename, index=False, encoding='utf-8')
336
+ await upload_dataframe_to_minio(
337
+ bucket_name=BUCKET_NAME,
338
+ object_name=csv_filename,
339
+ df=grouped_data
340
+ )
341
+
342
+ await upload_text_to_minio(
343
+ bucket_name=BUCKET_NAME,
344
+ object_name=output_filename,
345
+ file_content=final_content
346
+ )
347
+
348
+ logger.info(f"\nGrouped summaries saved to {csv_filename}")
349
+ logger.info(f"Non-refined review saved to {output_filename}")
350
+ return final_content
351
+
352
+
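The reference indexing above gives each title a stable number the first time it is seen, so a paper cited under several subheadings keeps a single index. A minimal runnable sketch of the scheme (titles and grouping are hypothetical):

reference_map = {}
ref_counter = 1
grouped_titles = {
    "Mechanisms": ["Paper A", "Paper B"],
    "Applications": ["Paper B", "Paper C"],  # Paper B reappears and keeps index 2
}
for subheading, titles in grouped_titles.items():
    for title in titles:
        if title not in reference_map:
            reference_map[title] = ref_counter
            ref_counter += 1
print(reference_map)  # {'Paper A': 1, 'Paper B': 2, 'Paper C': 3}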
353
+ # Function to enhance language and readability to meet Nature journal style
354
+ async def enhance_language_readability(
355
+ content,
356
+ uuid, customer_name, model_name,
357
+ chat_func
358
+ ):
359
+ """
360
+ Enhance the language and readability of the given content to meet the style of the *Nature* journal.
361
+
362
+ Args:
363
+ content (str): The content to enhance.
364
+ uuid (str): UUID of the task.
+ customer_name (str): Name of the customer.
+ model_name (str): Name of the model, used in the MinIO save path.
+ chat_func (function): The function to use for the chat completion.
365
+
366
+ Returns:
367
+ str: The enhanced content.
368
+
369
+ """
370
+ # Separate sections based on paragraph breaks
371
+ sections = content.split("\n\n")
372
+ enhanced_sections = []
373
+ prompts = []
374
+ for section in sections:
375
+ prompt = (
376
+ "Enhance the following text to align with the writing style of *Nature* journal. Refine language to be sophisticated and objective, "
377
+ "using advanced vocabulary and a factual tone. Ensure a high level of lexical diversity and rhythm, with alternating sentence lengths "
378
+ "and varied structures for readability. Avoid emotional, speculative, or conversational language, focusing on objective analysis.\n\n"
379
+ f"Text:\n{section}\n\n"
380
+ "Output format:\n[Enhanced text here]"
381
+ )
382
+ prompts.append(prompt)
383
+
384
+ responses = await asyncio.gather(
385
+ *(chat_func(prompt) for prompt in prompts)
386
+ )
387
+ for response in responses:
388
+ enhanced_section = response.choices[0].message.content.strip()
389
+ enhanced_sections.append(enhanced_section)
390
+
391
+ enhanced_content = "\n\n".join(enhanced_sections)
392
+ await upload_text_to_minio(
393
+ bucket_name=BUCKET_NAME,
394
+ object_name=f"{customer_name}/{uuid}/{model_name}/review_paper.txt",
395
+ file_content=enhanced_content
396
+ )
397
+
398
+ return enhanced_content
399
+
400
+
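The split-gather-join pattern used by enhance_language_readability can be exercised offline by substituting a stub for the model call; a runnable sketch, where fake_chat is a hypothetical stand-in that simply echoes the text back:

import asyncio

async def fake_chat(prompt: str) -> str:
    # Hypothetical stand-in for a model call: return the text portion unchanged.
    return prompt.split("Text:\n", 1)[-1].split("\n\nOutput", 1)[0]

async def enhance(content: str) -> str:
    sections = content.split("\n\n")
    results = await asyncio.gather(
        *(fake_chat(f"Enhance the text.\n\nText:\n{s}\n\nOutput format:\n[...]")
          for s in sections)
    )
    return "\n\n".join(r.strip() for r in results)

print(asyncio.run(enhance("First section.\n\nSecond section.")))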
401
+ async def process_papers(
402
+ dataframe, topic, direction,
403
+ uuid, customer_name, model_name,
404
+ chat_func
405
+ ):
406
+ """
407
+ Process the given papers to extract relevant information and save it to a CSV file.
408
+
409
+ Args:
410
+ dataframe (pandas.DataFrame): The DataFrame containing the papers.
411
+ topic (str): The topic to filter the papers by.
412
+ direction (str): The research direction used when judging each paper's relevance.
413
+ uuid (str): The UUID of the task.
414
+ customer_name (str): The name of the customer.
415
+ model_name (str): The name of the model, used in the MinIO save path.
+ chat_func (function): The function to use for the chat completion.
416
+
417
+ Returns:
418
+ str: The MinIO object path of the saved relevant-papers CSV.
419
+
420
+ """
421
+ # Duplicate, no need
422
+ # relevant_rows = [] # List to collect relevant rows for DataFrame creation
423
+
424
+ # Set up the output directory and CSV file
425
+ # output_dir = os.path.join(customer_name)
426
+ # os.makedirs(output_dir, exist_ok=True)
427
+ prefix = f"{customer_name}/{uuid}/{model_name}/"
428
+ output_dir = prefix
429
+
430
+ output_path = os.path.join(output_dir, "relevant_papers.csv")
431
+
432
+ # Create or clear the output file at the beginning
433
+ # with open(output_path, 'w', newline='', encoding='utf-8') as f:
434
+ # writer = csv.writer(f, quoting=csv.QUOTE_ALL)
435
+ # writer.writerow(["Journal Title", "Publication Date", "Title", "First Author", "Summary", "Is Relevant", "Relevance Keywords"]) # Writing header
436
+ texts = ""
437
+ fieldnames = ["Journal Title", "Publication Date", "Title",
438
+ "First Author", "Summary", "Is Relevant", "Relevance Keywords"]
439
+ texts += ",".join([escape_csv_field(x) for x in fieldnames]) + "\n"
440
+
441
+ titles = []
442
+ abstracts = []
443
+ journal_titles = []
444
+ pubd_dates = []
445
+ first_authors = []
446
+ summaries = []
447
+ for idx, row in dataframe.iterrows():
448
+ title = row["TI"]
449
+ abstract = row["AB"]
450
+ journal_title = row["JT"]
451
+ pub_date = row["DCOM"]
452
+ first_author = row["FAU-frist"]  # column name spelled exactly as produced upstream
453
+
454
+ titles.append(title)
455
+ abstracts.append(abstract)
456
+ journal_titles.append(journal_title)
457
+ pubd_dates.append(pub_date)
458
+ first_authors.append(first_author)
459
+
460
+ relevants = await asyncio.gather(
461
+ *(is_relevant(
462
+ title, abstract, topic, direction, chat_func
463
+ ) for title, abstract in zip(titles, abstracts))
464
+ )
465
+
466
+ is_relevant_flags = [relevant[0] for relevant in relevants]
467
+ relevance_keywords = [relevant[1] for relevant in relevants]
468
+
469
+ rtitles = []
470
+ rabstracts = []
471
+ rjournal_titles = []
472
+ rpubd_dates = []
473
+ rfirst_authors = []
474
+ rflags = []
475
+ rkeywords = []
476
+
477
+ for (
478
+ rflag, rkeyword, title, abstract, first_author, journal_title, pub_date
479
+ ) in zip(
480
+ is_relevant_flags, relevance_keywords,
481
+ titles, abstracts, first_authors, journal_titles, pubd_dates
482
+ ):
483
+ if rflag:
484
+ rtitles.append(title)
485
+ rabstracts.append(abstract)
486
+ rfirst_authors.append(first_author)
487
+ rjournal_titles.append(journal_title)
488
+ rpubd_dates.append(pub_date)
489
+ rflags.append(rflag)
490
+ rkeywords.append(rkeyword)
491
+
492
+ summaries = await asyncio.gather(
493
+ *(summarize_abstract(
494
+ title, abstract, first_author, chat_func
495
+ ) for title, abstract, first_author in
496
+ zip(rtitles, rabstracts, rfirst_authors)
497
+ )
498
+ )
499
+
500
+ for (
501
+ summary,
502
+ journal_title, pub_date, title, first_author,
503
+ rflag, rkeyword
504
+ ) in zip(
505
+ summaries,
506
+ rjournal_titles, rpubd_dates, rtitles, rfirst_authors,
507
+ rflags, rkeywords
508
+ ):
509
+ journal_title = escape_csv_field(journal_title)
510
+ pub_date = escape_csv_field(pub_date)
511
+ title = escape_csv_field(title)
512
+ first_author = escape_csv_field(first_author)
513
+ summary = escape_csv_field(summary)
514
+ rkeyword = escape_csv_field(rkeyword)
515
+
516
+ texts += ",".join([
517
+ str(x) for x in [
518
+ journal_title, pub_date, title, first_author,
519
+ summary, rflag, rkeyword
520
+ ]
521
+ ]) + "\n"
522
+
523
+ # Print the added summary and keywords
524
+ logger.info(f"Added summary: {summary}")
525
+ logger.info(f"Relevance Keywords: {rkeyword}")
526
+
527
+ # Create the relevant DataFrame to return
528
+ # relevant_df = pd.DataFrame(relevant_rows)
529
+ # return relevant_df
530
+ await upload_text_to_minio(
531
+ bucket_name=BUCKET_NAME,
532
+ object_name=output_path,
533
+ file_content=texts
534
+ )
535
+
536
+ return output_path
537
+
538
+
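process_papers fans out one relevance check per paper and then summarizes only the relevant subset. The shape of that two-stage asyncio.gather can be sketched with stubs; is_relevant and summarize_abstract below are simplified stand-ins for the helpers defined earlier in this module:

import asyncio

async def is_relevant(title, abstract, topic):
    flag = topic.lower() in abstract.lower()  # stub heuristic
    return flag, topic if flag else ""

async def summarize_abstract(title, abstract):
    return f"{title}: {abstract[:40]}"

async def run(papers, topic):
    checks = await asyncio.gather(*(is_relevant(t, a, topic) for t, a in papers))
    kept = [p for p, (flag, _) in zip(papers, checks) if flag]
    return await asyncio.gather(*(summarize_abstract(t, a) for t, a in kept))

papers = [("P1", "Deep learning for medical imaging"), ("P2", "Crop yield trends")]
print(asyncio.run(run(papers, "imaging")))  # only P1 is summarized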
539
+ async def translate_to_chinese_before_references(
540
+ text,
541
+ uuid, customer_name, model_name,
542
+ chat_func
543
+ ):
544
+ """
545
+ Translates the content of a text file to Chinese, keeping the '**References**' section in English.
546
+
547
+ Args:
548
+ text (str): The content of the text file.
549
+ uuid (str): UUID of the task.
+ customer_name (str): Name of the customer.
+ model_name (str): Name of the model, used in the MinIO save path.
550
+ chat_func (function): The function to use for translation.
551
+
552
+ Returns:
553
+ str: The MinIO object path of the translated file.
554
+
555
+ """
556
+ lines = text.split("\n")
557
+
558
+ # Step 1: Find the index of the '**References**' line
559
+ references_index = None
560
+ for i, line in enumerate(lines):
561
+ if line.strip() == "**References**":
562
+ references_index = i
563
+ break
564
+
565
+ # Step 2: Split the content at the found index
566
+ if references_index is not None:
567
+ main_content_lines = lines[:references_index]
568
+ references_content_lines = lines[references_index:]
569
+ else:
570
+ # If '**References**' is not found, treat the entire content as body text
571
+ main_content_lines = lines
572
+ references_content_lines = []
573
+
574
+ # Join the body lines into a single string
575
+ main_content = "\n".join(main_content_lines)
576
+
577
+ # Step 3: Translate the body content section by section
578
+ sections = main_content.split("\n\n")
579
+ translated_sections = []
580
+
581
+ prompts = []
582
+
583
+ for section in sections:
584
+ # Keep the prompt simple: only ask for a translation of the body text
585
+ prompt = (
586
+ "Translate the following text to academic Chinese:\n\n"
587
+ f"Text:\n{section}\n\n"
588
+ "Output format:\n[Translated Chinese text here]"
589
+ )
590
+ prompts.append(prompt)
591
+
592
+ responses = await asyncio.gather(
593
+ *(chat_func(prompt) for prompt in prompts)
594
+ )
595
+ for response in responses:
596
+ translated_section = response.choices[0].message.content.strip()
597
+ translated_sections.append(translated_section)
598
+
599
+ # Step 4: Join the translated sections
600
+ translated_content = "\n\n".join(translated_sections)
601
+
602
+ # Step 5: Re-attach the untranslated References section
603
+ if references_content_lines:
604
+ references_content = "\n".join(references_content_lines)
605
+ final_content = translated_content + "\n\n" + references_content
606
+ else:
607
+ final_content = translated_content
608
+
609
+ # Step 6: Save the result to a new file
610
+ output_filename = f"{customer_name}/{uuid}/{model_name}/review_paper_translated.txt"
611
+ await upload_text_to_minio(
612
+ bucket_name=BUCKET_NAME,
613
+ object_name=output_filename,
614
+ file_content=final_content
615
+ )
616
+
617
+ logger.info(f"\nTranslated content saved to {output_filename}")
618
+ return output_filename
619
+
620
+
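The boundary handling in translate_to_chinese_before_references can be checked in isolation; a small runnable example of splitting at the '**References**' marker:

text = "**Introduction**\nBody text.\n\n**References**\n[Ref: 1] Title, Author, 2024"
lines = text.split("\n")
references_index = next(
    (i for i, line in enumerate(lines) if line.strip() == "**References**"), None)
body = "\n".join(lines[:references_index])        # this part gets translated
references = "\n".join(lines[references_index:])  # this part stays in English
print(body)
print(references)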
621
+ # Main function to automate the review paper creation process with language enhancement step
622
+ async def create_review_paper(
623
+ relevant_papers_df,
624
+ main_topic,
625
+ uuid, customer_name, model_name,
626
+ chat_func,
627
+ translate_to_cn=False
628
+ ):
629
+ """
630
+ Main function to automate the review paper creation process with language enhancement step.
631
+
632
+ Args:
633
+ relevant_papers_df (pd.DataFrame): DataFrame containing relevant papers.
634
+ main_topic (str): Main topic of the review paper.
635
+ uuid (str): Unique identifier for the review paper.
636
+ customer_name (str): Name of the customer.
637
+ model_name (str): Name of the model, used in the MinIO save path.
+ chat_func (function): Function to handle chat interactions.
638
+ translate_to_cn (bool): Flag to indicate if translation to Chinese is required.
639
+
640
+ Returns:
641
+ str: The MinIO object path of the saved review paper.
642
+
643
+ """
644
+
645
+ # Step 1: Generate subheadings related to the main topic
646
+ subheadings = await generate_subheadings(
647
+ relevant_papers_df=relevant_papers_df, main_topic=main_topic,
648
+ uuid=uuid, customer_name=customer_name, model_name=model_name,
649
+ chat_func=chat_func
+ )
650
+
651
+ # Step 2: Assign each summary to a subheading
652
+ relevant_papers_df = await assign_subheadings_to_summaries(
653
+ relevant_papers_df, subheadings,
654
+ uuid, customer_name, model_name,
655
+ chat_func
656
+ )
657
+
658
+ # Step 3: Create paragraphs by subheading, with introductory and concluding sections, and references
659
+ review_content = await create_paragraphs_by_subheading(
660
+ relevant_papers_df, subheadings, main_topic,
661
+ uuid, customer_name, model_name,
662
+ chat_func
663
+ )
664
+
665
+ # Step 4: Enhance language and readability
666
+ enhanced_content = await enhance_language_readability(
667
+ review_content,
668
+ uuid, customer_name, model_name,
669
+ chat_func
+ )
670
+
671
+ prefix = f"{customer_name}/{uuid}/{model_name}/"
672
+ output_dir = prefix
673
+
674
+ output_filename = os.path.join(output_dir, "review_paper.txt")
675
+
676
+ # Step 5: Translate to Chinese (the translated file is saved by the callee)
677
+ if translate_to_cn:
678
+ await translate_to_chinese_before_references(
679
+ enhanced_content,
680
+ uuid, customer_name, model_name,
681
+ chat_func
682
+ )
683
+
684
+ # Step 6: Save the generated content to a text file
685
+ # with open(output_filename, "w", encoding="utf-8") as f:
686
+ # f.write(enhanced_content)
687
+ await upload_text_to_minio(
688
+ bucket_name=BUCKET_NAME,
689
+ object_name=output_filename,
690
+ file_content=enhanced_content
691
+ )
692
+
693
+ logger.info(f"\nReview paper saved to {output_filename}")
694
+ return output_filename
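A hypothetical end-to-end invocation of create_review_paper, as a sketch only: it assumes this module is importable as utils.paper_utils, that live model credentials and MinIO access are configured, and the CSV name and topic are placeholders:

import asyncio
import pandas as pd

from utils.api_utils import get_chat_func
from utils.paper_utils import create_review_paper  # assumed module path

df = pd.read_csv("relevant_papers.csv")  # columns as written by process_papers
chat_func = get_chat_func(["glm-4-flash"])[0]

path = asyncio.run(create_review_paper(
    df, "tumor microenvironment",
    uuid="task-uuid", customer_name="acme", model_name="glm-4-flash",
    chat_func=chat_func, translate_to_cn=False,
))
print(path)  # <customer_name>/<uuid>/<model_name>/review_paper.txt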
utils/pubmed_plus_utils.py ADDED
@@ -0,0 +1,665 @@
1
+ import io
2
+ import asyncio
3
+
4
+ from minio import Minio
5
+ from loguru import logger
6
+
7
+ from entities.task import PubMedPlusTask
8
+ from utils.api_utils import (
9
+ retry_operation,
10
+ get_chat_func,
11
+ compare_chat_chocies
12
+ )
13
+ from utils.r2_utils import (
14
+ get_client,
15
+ get_file_from_minio,
16
+ get_dataframe_from_minio,
17
+ upload_text_to_minio,
18
+ upload_task_json_to_minio,
19
+ )
20
+ from utils.paper_plus_utils import (
21
+ process_papers,
22
+ generate_subheadings,
23
+ assign_subheadings_to_summaries,
24
+ create_paragraphs_by_subheading,
25
+ refine_review_content,
26
+ translate_refined_review_to_chinese,
27
+ translate_to_chinese_before_references
28
+ )
29
+ from utils.pubmed_utils import (
30
+ generate_pubmed_search_string,
31
+ process_pubmed_data
32
+ )
33
+
34
+
35
+ BUCKET_NAME = "ai-scientist"
36
+
37
+
38
+ # =================================
39
+ # Function Groups: Pipeline for PubMed
40
+ #
41
+ # 1. pipeline
42
+ # 2. single model chat
43
+ # =================================
44
+
45
+ async def pubmed_plus_pipeline(
46
+ task: PubMedPlusTask,
47
+ client: Minio = None,
48
+ max_retries: int = 5,
49
+ delay: float = 0.5
50
+ ):
51
+ """
52
+ Pubmed pipeline
53
+
54
+ Args:
55
+ task: PubMedPlusTask object, containing basic information for the task
56
+ client: Minio, minio client
57
+ max_retries: int, max retries for each step
58
+ delay: float, delay between each retry
59
+
60
+ Returns:
61
+ None
62
+
63
+ """
64
+ if client is None:
65
+ client = get_client()
66
+
67
+ customer_name = task.customer_name
68
+ uuid = task.uuid
69
+ model_names = task.model_names
70
+
71
+ task.status_string["overall"] = "processing"
72
+
73
+ await asyncio.gather(
74
+ *(process_pubmed_single_chat(
75
+ task, model_name, client, max_retries, delay
76
+ ) for model_name in model_names)
77
+ )
78
+
79
+ # if compare between models
80
+ # at least 3 models should be selected
81
+ logger.info("Check Compare...")
82
+ if task.do_compare and len(task.model_names) >= 3:
83
+ if task.status.get("compare", 0) == 0:
84
+ contents = await asyncio.gather(
85
+ *(get_file_from_minio(
86
+ bucket_name=BUCKET_NAME,
87
+ object_name=f"{customer_name}/{uuid}/{model_name}/review_paper.txt",
88
+ client=client
89
+ ) for model_name in model_names)
90
+ )
91
+ contents = [c.data.decode("utf-8") for c in contents]
92
+ task.status_string["overall"] = "Start Compare"
93
+
94
+ rank_scores = await compare_chat_chocies(
95
+ contents=contents,
96
+ model_names=model_names
97
+ )
98
+ best_content = contents[min(rank_scores, key=rank_scores.get)]
99
+ await upload_text_to_minio(
100
+ bucket_name=BUCKET_NAME,
101
+ object_name=f"{customer_name}/{uuid}/compared_review_paper.txt",
102
+ file_content=best_content
103
+ )
104
+ task.status_string["overall"] = "Finished"
105
+ task.status["compare"] = 1
106
+ await upload_task_json_to_minio(task, client)
107
+ else:
108
+ task.status_string["overall"] = "Finished"
109
+ await upload_task_json_to_minio(task, client)
110
+ else:
111
+ logger.info("No Compare.")
112
+ task.status_string["overall"] = "Finished"
113
+ await upload_task_json_to_minio(task, client)
114
+
115
+
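The compare branch above keeps the draft whose rank score is lowest; assuming compare_chat_chocies returns a mapping from content index to an average rank where lower is better (as the indexing implies), the selection reduces to the following runnable snippet with made-up scores:

contents = ["draft from model A", "draft from model B", "draft from model C"]
rank_scores = {0: 2.3, 1: 1.7, 2: 2.0}  # hypothetical average ranks; lower wins
best_content = contents[min(rank_scores, key=rank_scores.get)]
print(best_content)  # draft from model B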
116
+ async def process_pubmed_single_chat(
117
+ task: PubMedPlusTask,
118
+ model_name: str,
119
+ client: Minio = None,
120
+ max_retries: int = 5,
121
+ delay: float = 0.5
122
+ ):
123
+ """
124
+ Process PubMed Task
125
+
126
+ Args:
127
+ task: PubMedPlusTask object, containing basic information for the task
128
+ model_name: str, model name, refer to the model used at this step
129
+ client: Minio, minio client
130
+ max_retries: int, max retries for each step
131
+ delay: float, delay between each retry
132
+
133
+ Returns:
134
+ None
135
+
136
+ """
137
+
138
+ # get minio client
139
+ if client is None:
140
+ client = get_client()
141
+
142
+ # add status for <model_name>
143
+ if model_name not in task.status.keys():
144
+ task.status[model_name] = 0
145
+
146
+ # set task status string
147
+ task.status_string["overall"] = "processing"
148
+
149
+ process_steps = {
150
+ 0: process_pubmed_generate_pubmed_string,
151
+ 1: process_pubmed_fetch_data,
152
+ 2: process_pubmed_process_papers,
153
+ 3: process_pubmed_generate_subheadings,
154
+ 4: process_pubmed_assign_subheadings_to_summaries,
155
+ 5: process_pubmed_create_paragraphs_by_subheading,
156
+ 6: process_pubmed_refine,
157
+ 7: process_pubmed_translate,
158
+ }
159
+
160
+ state_description = {
161
+ 0: "Finished pubmed string generation.",
162
+ 1: "Finished fetching data.",
163
+ 2: "Finished paper processing.",
164
+ 3: "Finished subheading generation.",
165
+ 4: "Finished subheading assignment.",
166
+ 5: "Finished paragraph generation.",
167
+ 6: "Finished review refine.",
168
+ 7: "Finished review translate.",
169
+ }
170
+
171
+ # Execute Phase
172
+ current_state = task.status[model_name]
173
+ for state in range(current_state, len(process_steps.keys())):
174
+ await process_steps[state](
175
+ task=task,
176
+ model_name=model_name,
177
+ save_name=model_name,
178
+ prev_name=model_name,
179
+ client=client,
180
+ max_retries=max_retries, delay=delay
181
+ )
182
+ task.status_string[model_name] = state_description[state]
183
+ task.status[model_name] = state + 1
184
+ await upload_task_json_to_minio(task, client)
185
+
186
+ task.status_string[model_name] = "Finished."
187
+ await upload_task_json_to_minio(task, client)
188
+
189
+
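process_pubmed_single_chat is effectively a resumable state machine: task.status[model_name] records the index of the next step, so a task that failed mid-pipeline restarts from that step rather than from zero. A runnable miniature of the pattern (toy steps; the upload_task_json_to_minio persistence is elided):

import asyncio

async def fetch(): print("fetch")
async def summarize(): print("summarize")
async def translate(): print("translate")

async def resume(status: dict, name: str):
    steps = {0: fetch, 1: summarize, 2: translate}
    for state in range(status.get(name, 0), len(steps)):
        await steps[state]()  # an exception here leaves status at `state`
        status[name] = state + 1

status = {"glm-4-flash": 1}  # step 0 finished in an earlier run
asyncio.run(resume(status, "glm-4-flash"))  # prints: summarize, translate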
190
+ # =================================
191
+ # Function Groups: process_pubmed_*
192
+ # 1. _generate_pubmed_string
193
+ # 2. _fetch_data
194
+ # 3. _process_papers
195
+ # 4. _generate_subheadings
196
+ # 5. _assign_subheadings_to_summaries
197
+ # 6. _create_paragraphs_by_subheading
198
+ # 7. _refine
199
+ # 8. _translate
200
+ # =================================
201
+
202
+ async def process_pubmed_generate_pubmed_string(
203
+ task: PubMedPlusTask,
204
+ model_name: str,
205
+ save_name: str,
206
+ prev_name: str = None,
207
+ client: Minio = None,
208
+ max_retries: int = 5,
209
+ delay: float = 0.5
210
+ ):
211
+ """
212
+ Generate pubmed search string step
213
+
214
+ Args:
215
+ task: PubMedTask object, containig basic information for PubMedTask
216
+ prev_model_name: str, previous model name, refer to previous step result
217
+ model_name: str, next model name, refer to the model used at this step
218
+ save_name: str, save name for minio path
219
+ client: Minio, minio client
220
+ max_retries: int, max retries for each step
221
+ delay: float, delay between each retry
222
+
223
+ Returns:
224
+ path to save results
225
+
226
+ """
227
+
228
+ if client is None:
229
+ client = get_client()
230
+
231
+ if prev_name is not None:
232
+ logger.warning("For the first step, prev_name is not used.")
233
+
234
+ query = task.query
235
+ customer_name = task.customer_name
236
+ uuid = task.uuid
237
+
238
+ chat_func = get_chat_func(model_names=[model_name])[0]
239
+
240
+ pubmed_search_string, exceptions = await retry_operation(
241
+ generate_pubmed_search_string, task,
242
+ query=query,
243
+ max_retries=max_retries, delay=delay,
244
+ chat_func=chat_func
245
+ )
246
+ if pubmed_search_string is None: # no valid result after max retries
247
+ # store exception strings in status
248
+ task.status_string[model_name] = exceptions
249
+ await upload_task_json_to_minio(task, client)
250
+ raise RuntimeError("Pubmed Search String Generation Failed.") # exit
251
+
252
+ await upload_text_to_minio(
253
+ bucket_name=BUCKET_NAME,
254
+ object_name=f"{customer_name}/{uuid}/{save_name}/pubmed_search_string.txt",
255
+ file_content=pubmed_search_string
256
+ )
257
+
258
+
259
+ async def process_pubmed_fetch_data(
260
+ task: PubMedPlusTask,
261
+ model_name: str,
262
+ save_name: str,
263
+ prev_name: str = None,
264
+ client: Minio = None,
265
+ max_retries: int = 5,
266
+ delay: float = 0.5
267
+ ):
268
+ """
269
+ Process PubMed Fetch Data
270
+
271
+ Args:
272
+ task: PubMedTask object, containig basic information for PubMedTask
273
+ prev_model_name: str, previous model name, refer to previous step result
274
+ model_name: str, next model name, refer to the model used at this step
275
+ save_name: str, save name for minio path
276
+ client: Minio, minio client
277
+
278
+ Returns:
279
+ path to save results
280
+
281
+ """
282
+
283
+ if client is None:
284
+ client = get_client()
285
+
286
+ customer_name = task.customer_name
287
+ uuid = task.uuid
288
+ start_year = task.start_year
289
+ end_year = task.end_year
290
+ size = task.size
291
+
292
+ pubmed_search_string = await get_file_from_minio(
293
+ bucket_name=BUCKET_NAME,
294
+ object_name=f"{customer_name}/{uuid}/{prev_name}/pubmed_search_string.txt",
295
+ client=client
296
+ )
297
+ pubmed_search_string = pubmed_search_string.data.decode("utf-8")
298
+ results, exceptions = await retry_operation(
299
+ process_pubmed_data, task,
300
+ query=pubmed_search_string,
301
+ model_name=save_name,
302
+ start_year=start_year, end_year=end_year,
303
+ size=size,
304
+ uuid=uuid, customer_name=customer_name,
305
+ max_retries=max_retries, delay=delay
306
+ )
307
+ if results is None: # no valid result after max retries
308
+ # store exception strings in status
309
+ task.status_string[model_name] = exceptions
310
+ await upload_task_json_to_minio(task, client)
311
+ raise ConnectionError("Pubmed Data Fetch Failed.") # exit
312
+
313
+
314
+ async def process_pubmed_process_papers(
315
+ task: PubMedPlusTask,
316
+ model_name: str,
317
+ save_name: str,
318
+ prev_name: str = None,
319
+ client: Minio = None,
320
+ max_retries: int = 5,
321
+ delay: float = 0.5
322
+ ):
323
+ """
324
+ Process PubMed Process Papers
325
+
326
+ Args:
327
+ task: PubMedPlusTask object, containing basic information for the task
328
+ prev_name: str, previous model name, referring to the previous step's result
329
+ model_name: str, name of the model used at this step
330
+ save_name: str, save name for minio path
331
+ client: Minio, minio client
332
+
333
+ Returns:
334
+ None; the processed papers are uploaded to MinIO
335
+
336
+ """
337
+ if client is None:
338
+ client = get_client()
339
+
340
+ query = task.query
341
+ direction = task.direction
342
+ customer_name = task.customer_name
343
+ uuid = task.uuid
344
+
345
+ chat_func = get_chat_func(model_names=[model_name])[0]
346
+
347
+ non_review_pubmed_df = await get_dataframe_from_minio(
348
+ bucket_name=BUCKET_NAME,
349
+ object_name=f"{customer_name}/{uuid}/{prev_name}/pubmed_results_non_reviews.csv",
350
+ client=client
351
+ )
352
+ results, exceptions = await retry_operation(
353
+ process_papers, task,
354
+ dataframe=non_review_pubmed_df,
355
+ topic=query, direction=direction,
356
+ uuid=uuid, customer_name=customer_name, model_name=save_name,
357
+ max_retries=max_retries, delay=delay,
358
+ chat_func=chat_func
359
+ )
360
+ if results is None: # no valid result after max retries
361
+ # store exception strings in status
362
+ task.status_string[model_name] = exceptions
363
+ await upload_task_json_to_minio(task, client)
364
+ raise RuntimeError("Pubmed Paper Processing Failed.") # exit
365
+
366
+
367
+ async def process_pubmed_generate_subheadings(
368
+ task: PubMedPlusTask,
369
+ model_name: str,
370
+ save_name: str,
371
+ prev_name: str = None,
372
+ client: Minio = None,
373
+ max_retries: int = 5,
374
+ delay: float = 0.5
375
+ ):
376
+ """
377
+ Process PubMed Generate Subheadings
378
+ Args:
379
+ task: PubMedPlusTask object, containing basic information for the task
380
+ prev_name: str, previous model name, referring to the previous step's result
381
+ model_name: str, name of the model used at this step
382
+ save_name: str, save name for minio path
383
+
384
+ Returns:
385
+ None; the generated subheadings are uploaded to MinIO
386
+ """
387
+ if client is None:
388
+ client = get_client()
389
+
390
+ query = task.query
391
+ customer_name = task.customer_name
392
+ uuid = task.uuid
393
+
394
+ chat_func = get_chat_func([model_name])[0]
395
+
396
+ relevant_papers_df = await get_dataframe_from_minio(
397
+ bucket_name=BUCKET_NAME,
398
+ object_name=f"{customer_name}/{uuid}/{prev_name}/relevant_papers.csv",
399
+ client=client
400
+ )
401
+
402
+ results, exceptions = await retry_operation(
403
+ generate_subheadings, task,
404
+ relevant_papers_df=relevant_papers_df,
405
+ main_topic=query,
406
+ uuid=uuid, customer_name=customer_name, model_name=save_name,
407
+ chat_func=chat_func,
408
+ max_retries=max_retries, delay=delay
409
+ )
410
+ if results is None: # no valid result after max retries
411
+ # store exception strings in status
412
+ task.status_string[model_name] = exceptions
413
+ await upload_task_json_to_minio(task, client)
414
+ raise RuntimeError("Pubmed Generate Subheadings Failed.") # exit
415
+
416
+
417
+ async def process_pubmed_assign_subheadings_to_summaries(
418
+ task: PubMedPlusTask,
419
+ model_name: str,
420
+ save_name: str,
421
+ prev_name: str = None,
422
+ client: Minio = None,
423
+ max_retries: int = 5,
424
+ delay: float = 0.5
425
+ ):
426
+ """
427
+ Process PubMed Assign Subheadings to Summaries
428
+ Args:
429
+ task: PubMedPlusTask object, containing basic information for the task
430
+ prev_name: str, previous model name, referring to the previous step's result
431
+ model_name: str, name of the model used at this step
432
+ save_name: str, save name for minio path
433
+
434
+ Returns:
435
+ None; the assignments are uploaded to MinIO
436
+ """
437
+
438
+ if client is None:
439
+ client = get_client()
440
+
441
+ customer_name = task.customer_name
442
+ uuid = task.uuid
443
+
444
+ chat_func = get_chat_func([model_name])[0]
445
+
446
+ subheadings = await get_file_from_minio(
447
+ bucket_name=BUCKET_NAME,
448
+ object_name=f"{customer_name}/{uuid}/{prev_name}/generated_subheadings.txt",
449
+ client=client
450
+ )
451
+ subheadings = subheadings.data.decode("utf-8").split("\n")
452
+
453
+ relevant_papers_df = await get_dataframe_from_minio(
454
+ bucket_name=BUCKET_NAME,
455
+ object_name=f"{customer_name}/{uuid}/{prev_name}/relevant_papers.csv",
456
+ client=client
457
+ )
458
+
459
+ results, exceptions = await retry_operation(
460
+ assign_subheadings_to_summaries, task,
461
+ subheadings=subheadings,
462
+ relevant_papers_df=relevant_papers_df,
463
+ uuid=uuid, customer_name=customer_name, model_name=save_name,
464
+ chat_func=chat_func,
465
+ max_retries=max_retries, delay=delay
466
+ )
467
+ if results is None: # no valid result after max retries
468
+ # store exception strings in status
469
+ task.status_string[model_name] = exceptions
470
+ await upload_task_json_to_minio(task, client)
471
+ raise RuntimeError("Pubmed Assign Subheadings Failed.") # exit
472
+
473
+
474
+ async def process_pubmed_create_paragraphs_by_subheading(
475
+ task: PubMedPlusTask,
476
+ model_name: str,
477
+ save_name: str,
478
+ prev_name: str = None,
479
+ client: Minio = None,
480
+ max_retries: int = 5,
481
+ delay: float = 0.5
482
+ ):
483
+ """
484
+ Process PubMed Create Paragraphs by Subheading
485
+ Args:
486
+ task: PubMedPlusTask object, containing basic information for the task
487
+ prev_name: str, previous model name, referring to the previous step's result
488
+ model_name: str, name of the model used at this step
489
+ save_name: str, save name for minio path
490
+ client: Minio, minio client
491
+ max_retries: int, max retries for the operation
492
+ delay: float, delay between retries
493
+
494
+ Returns:
495
+ None; the generated review is uploaded to MinIO
496
+ """
497
+
498
+ if client is None:
499
+ client = get_client()
500
+
501
+ query = task.query
502
+ customer_name = task.customer_name
503
+ uuid = task.uuid
504
+
505
+ chat_func = get_chat_func([model_name])[0]
506
+
507
+ subheadings = await get_file_from_minio(
508
+ bucket_name=BUCKET_NAME,
509
+ object_name=f"{customer_name}/{uuid}/{prev_name}/generated_subheadings.txt",
510
+ client=client
511
+ )
512
+ subheadings = subheadings.data.decode("utf-8").split("\n")
513
+
514
+ relevant_papers_df = await get_dataframe_from_minio(
515
+ bucket_name=BUCKET_NAME,
516
+ object_name=f"{customer_name}/{uuid}/{prev_name}/assigned_subheadings.csv",
517
+ client=client
518
+ )
519
+
520
+ results, exceptions = await retry_operation(
521
+ create_paragraphs_by_subheading, task,
522
+ subheadings=subheadings, main_topic=query,
523
+ relevant_papers_df=relevant_papers_df,
524
+ uuid=uuid, customer_name=customer_name, model_name=save_name,
525
+ chat_func=chat_func,
526
+ max_retries=max_retries, delay=delay
527
+ )
528
+ if results is None: # no valid result after max retries
529
+ # store exception strings in status
530
+ task.status_string[model_name] = exceptions
531
+ await upload_task_json_to_minio(task, client)
532
+ raise RuntimeError("Pubmed Create Paragraphs Failed.") # exit
533
+
534
+
535
+ async def process_pubmed_translate(
536
+ task: PubMedPlusTask,
537
+ model_name: str,
538
+ save_name: str,
539
+ prev_name: str = None,
540
+ client: Minio = None,
541
+ max_retries: int = 5,
542
+ delay: float = 0.5
543
+ ):
544
+ """
545
+ Process PubMed Translate
546
+ Args:
547
+ task: PubMedPlusTask object, containing basic information for the task
548
+ prev_name: str, previous model name, referring to the previous step's result
549
+ model_name: str, name of the model used at this step
550
+ save_name: str, save name for minio path
551
+ client: Minio, minio client
552
+ max_retries: int, max retries for the operation
553
+ delay: float, delay between retries
554
+
555
+ Returns:
556
+ None; the translation is uploaded to MinIO
557
+ """
558
+
559
+ if client is None:
560
+ client = get_client()
561
+
562
+ customer_name = task.customer_name
563
+ uuid = task.uuid
564
+ do_refine = task.do_refine
565
+
566
+ chat_func = get_chat_func([model_name])[0]
567
+
568
+ if do_refine:
569
+ refined_review_content = await get_file_from_minio(
570
+ bucket_name=BUCKET_NAME,
571
+ object_name=f"{customer_name}/{uuid}/{prev_name}/review_paper_refined.docx",
572
+ client=client
573
+ )
574
+ refined_review_content = io.BytesIO(refined_review_content.data)
575
+
576
+ results, exceptions = await retry_operation(
577
+ translate_refined_review_to_chinese, task,
578
+ refined_review_content=refined_review_content,
579
+ uuid=uuid, customer_name=customer_name, model_name=save_name,
580
+ chat_func=chat_func,
581
+ max_retries=max_retries, delay=delay
582
+ )
583
+ if results is None: # no valid result after max retries
584
+ # store exception strings in status
585
+ task.status_string[model_name] = exceptions
586
+ await upload_task_json_to_minio(task, client)
587
+ raise RuntimeError("Pubmed Translate Refined Review Failed.") # exit
588
+ else:
589
+ review_content = await get_file_from_minio(
590
+ bucket_name=BUCKET_NAME,
591
+ object_name=f"{customer_name}/{uuid}/{prev_name}/review_non_refined.txt",
592
+ client=client
593
+ )
+ review_content = review_content.data.decode("utf-8")
594
+ results, exceptions = await retry_operation(
595
+ translate_to_chinese_before_references, task,
596
+ text=review_content,
597
+ uuid=uuid, customer_name=customer_name, model_name=save_name,
598
+ chat_func=chat_func,
599
+ max_retries=max_retries, delay=delay
600
+ )
601
+ if results is None: # no valid result after max retries
602
+ # store exception strings in status
603
+ task.status_string[model_name] = exceptions
604
+ await upload_task_json_to_minio(task, client)
605
+ raise RuntimeError("Pubmed Translate Failed.") # exit
606
+
607
+
608
+ async def process_pubmed_refine(
609
+ task: PubMedPlusTask,
610
+ model_name: str,
611
+ save_name: str,
612
+ prev_name: str = None,
613
+ client: Minio = None,
614
+ max_retries: int = 5,
615
+ delay: float = 0.5
616
+ ):
617
+ """
618
+ Process PubMed Refine
619
+ Args:
620
+ task: PubMedPlusTask object, containing basic information for the task
621
+ prev_name: str, previous model name, referring to the previous step's result
622
+ model_name: str, name of the model used at this step
623
+ save_name: str, save name for minio path
624
+ client: Minio, minio client
625
+ max_retries: int, max retries for the operation
626
+ delay: float, delay between retries
627
+
628
+ Returns:
629
+ None; returns 1 early when task.do_refine is False, otherwise the refined review is uploaded to MinIO
630
+ """
631
+
632
+ # additional check on if do_refine
633
+ # if not refine, exit here with 1
634
+ if not task.do_refine:
635
+ return 1
636
+
637
+ if client is None:
638
+ client = get_client()
639
+
640
+ customer_name = task.customer_name
641
+ uuid = task.uuid
642
+
643
+ chat_func = get_chat_func([model_name])[0]
644
+
645
+ review_content = await get_file_from_minio(
646
+ bucket_name=BUCKET_NAME,
647
+ object_name=f"{customer_name}/{uuid}/{prev_name}/review_non_refined.txt",
648
+ client=client
649
+ )
650
+ review_content = review_content.data.decode("utf-8")
651
+
652
+ results, exceptions = await retry_operation(
653
+ refine_review_content, task,
654
+ non_refine_content=review_content,
655
+ uuid=uuid, customer_name=customer_name, model_name=save_name,
656
+ chat_func=chat_func,
657
+ max_retries=max_retries, delay=delay
658
+ )
659
+ if results is None: # no valid result after max retries
660
+ # store exception strings in status
661
+ task.status_string[model_name] = exceptions
662
+ await upload_task_json_to_minio(task, client)
663
+ raise RuntimeError("Pubmed Refine Failed.") # exit
664
+
665
+
utils/pubmed_utils.py ADDED
@@ -0,0 +1,1078 @@
1
+ import os
2
+ import asyncio
3
+ import aiohttp
4
+ import requests
5
+ import pandas as pd
6
+
7
+ from minio import Minio
8
+ from loguru import logger
9
+ from bs4 import BeautifulSoup
10
+
11
+ from entities.task import PubMedTask
12
+ from utils.api_utils import (
13
+ retry_operation,
14
+ get_chat_func,
15
+ compare_chat_chocies
16
+ )
17
+ from utils.r2_utils import (
18
+ get_client,
19
+ get_file_from_minio,
20
+ get_dataframe_from_minio,
21
+ upload_text_to_minio,
22
+ upload_dataframe_to_minio,
23
+ upload_task_json_to_minio,
24
+ )
25
+ from utils.common_utils import escape_csv_field
26
+ from utils.paper_utils import (
27
+ process_papers,
28
+ generate_subheadings,
29
+ assign_subheadings_to_summaries,
30
+ create_paragraphs_by_subheading,
31
+ enhance_language_readability,
32
+ translate_to_chinese_before_references
33
+ )
34
+
35
+
36
+ BUCKET_NAME = "ai-scientist"
37
+
38
+
39
+ # =================================
40
+ # Function Groups: Pipeline for PubMed
41
+ #
42
+ # 1. pipeline
43
+ # 2. single model chat
44
+ # =================================
45
+
46
+ async def pubmed_pipeline(
47
+ task: PubMedTask,
48
+ client: Minio = None,
49
+ max_retries: int = 5,
50
+ delay: float = 0.5
51
+ ):
52
+ """
53
+ Pubmed pipeline
54
+
55
+ Args:
56
+ task: PubMedTask object, containing basic information for the task
57
+ client: Minio, minio client
58
+ max_retries: int, max retries for each step
59
+ delay: float, delay between each retry
60
+
61
+ Returns:
62
+ None
63
+
64
+ """
65
+ if client is None:
66
+ client = get_client()
67
+
68
+ customer_name = task.customer_name
69
+ uuid = task.uuid
70
+ model_names = task.model_names
71
+
72
+ task.status_string["overall"] = "processing"
73
+
74
+ await asyncio.gather(
75
+ *(process_pubmed_single_chat(
76
+ task, model_name, client, max_retries, delay
77
+ ) for model_name in model_names)
78
+ )
79
+
80
+ # if compare between models
81
+ # at least 3 models should be selected
82
+ logger.info("Check Compare...")
83
+ if task.do_compare and len(task.model_names) >= 3:
84
+ if task.status.get("compare", 0) == 0:
85
+ contents = await asyncio.gather(
86
+ *(get_file_from_minio(
87
+ bucket_name=BUCKET_NAME,
88
+ object_name=f"{customer_name}/{uuid}/{model_name}/review_paper.txt",
89
+ client=client
90
+ ) for model_name in model_names)
91
+ )
92
+ contents = [c.data.decode("utf-8") for c in contents]
93
+ task.status_string["overall"] = "Start Compare"
94
+
95
+ rank_scores = await compare_chat_chocies(
96
+ contents=contents,
97
+ model_names=model_names
98
+ )
99
+ best_content = contents[min(rank_scores, key=rank_scores.get)]
100
+ await upload_text_to_minio(
101
+ bucket_name=BUCKET_NAME,
102
+ object_name=f"{customer_name}/{uuid}/compared_review_paper.txt",
103
+ file_content=best_content
104
+ )
105
+ task.status_string["overall"] = "Finished"
106
+ task.status["compare"] = 1
107
+ await upload_task_json_to_minio(task, client)
108
+ else:
109
+ task.status_string["overall"] = "Finished"
110
+ await upload_task_json_to_minio(task, client)
111
+ else:
112
+ logger.info("No Compare.")
113
+ task.status_string["overall"] = "Finished"
114
+ await upload_task_json_to_minio(task, client)
115
+
116
+
117
+ async def process_pubmed_single_chat(
118
+ task: PubMedTask,
119
+ model_name: str,
120
+ client: Minio = None,
121
+ max_retries: int = 5,
122
+ delay: float = 0.5
123
+ ):
124
+ """
125
+ Process PubMed Task
126
+
127
+ Args:
128
+ task: PubMedTask object, containing basic information for the task
129
+ model_name: str, model name, refer to the model used at this step
130
+ client: Minio, minio client
131
+ max_retries: int, max retries for each step
132
+ delay: float, delay between each retry
133
+
134
+ Returns:
135
+ None
136
+
137
+ """
138
+
139
+ # get minio client
140
+ if client is None:
141
+ client = get_client()
142
+
143
+ # add status for <model_name>
144
+ if model_name not in task.status.keys():
145
+ task.status[model_name] = 0
146
+
147
+ # set task status string
148
+ task.status_string["overall"] = "processing"
149
+
150
+ process_steps = {
151
+ 0: process_pubmed_generate_pubmed_string,
152
+ 1: process_pubmed_fetch_data,
153
+ 2: process_pubmed_process_papers,
154
+ 3: process_pubmed_generate_subheadings,
155
+ 4: process_pubmed_assign_subheadings_to_summaries,
156
+ 5: process_pubmed_create_paragraphs_by_subheading,
157
+ 6: process_pubmed_enhance_language_readability,
158
+ 7: process_pubmed_translate
159
+ }
160
+
161
+ state_description = {
162
+ 0: "Finished pubmed string generation.",
163
+ 1: "Finished fetching data.",
164
+ 2: "Finished paper processing.",
165
+ 3: "Finished subheading generation.",
166
+ 4: "Finished subheading assignment.",
167
+ 5: "Finished paragraph generation.",
168
+ 6: "Finished review language readability enhancement.",
169
+ 7: "Finished review translation."
170
+ }
171
+
172
+ # Execute Phase
173
+ current_state = task.status[model_name]
174
+ for state in range(current_state, len(process_steps.keys())):
175
+ await process_steps[state](
176
+ task=task,
177
+ model_name=model_name,
178
+ save_name=model_name,
179
+ prev_name=model_name,
180
+ client=client,
181
+ max_retries=max_retries, delay=delay
182
+ )
183
+ task.status_string[model_name] = state_description[state]
184
+ task.status[model_name] = state + 1
185
+ await upload_task_json_to_minio(task, client)
186
+
187
+ task.status_string[model_name] = "Finished."
188
+ await upload_task_json_to_minio(task, client)
189
+
190
+
191
+ # =================================
192
+ # Function Groups: process_pubmed_*
193
+ # 1. _generate_pubmed_string
194
+ # 2. _fetch_data
195
+ # 3. _process_papers
196
+ # 4. _generate_subheadings
197
+ # 5. _assign_subheadings_to_summaries
198
+ # 6. _create_paragraphs_by_subheading
199
+ # 7. _enhance_language_readability
200
+ # 8. _translate
201
+ # =================================
202
+
203
+ async def process_pubmed_generate_pubmed_string(
204
+ task: PubMedTask,
205
+ model_name: str,
206
+ save_name: str,
207
+ prev_name: str = None,
208
+ client: Minio = None,
209
+ max_retries: int = 5,
210
+ delay: float = 0.5
211
+ ):
212
+ """
213
+ Generate pubmed search string step
214
+
215
+ Args:
216
+ task: PubMedTask object, containing basic information for the task
217
+ prev_name: str, previous model name, referring to the previous step's result
218
+ model_name: str, name of the model used at this step
219
+ save_name: str, save name for minio path
220
+ client: Minio, minio client
221
+ max_retries: int, max retries for each step
222
+ delay: float, delay between each retry
223
+
224
+ Returns:
225
+ None; the result is uploaded to MinIO
226
+
227
+ """
228
+
229
+ if client is None:
230
+ client = get_client()
231
+
232
+ if prev_name is not None:
233
+ logger.warning("For the first step, prev_name is not used.")
234
+
235
+ query = task.query
236
+ customer_name = task.customer_name
237
+ uuid = task.uuid
238
+
239
+ chat_func = get_chat_func(model_names=[model_name])[0]
240
+
241
+ pubmed_search_string, exceptions = await retry_operation(
242
+ generate_pubmed_search_string, task,
243
+ query=query,
244
+ max_retries=max_retries, delay=delay,
245
+ chat_func=chat_func
246
+ )
247
+ if pubmed_search_string is None: # no valid result after max retries
248
+ # store exception strings in status
249
+ task.status_string[model_name] = exceptions
250
+ await upload_task_json_to_minio(task, client)
251
+ raise RuntimeError("Pubmed Search String Generation Failed.") # exit
252
+
253
+ await upload_text_to_minio(
254
+ bucket_name=BUCKET_NAME,
255
+ object_name=f"{customer_name}/{uuid}/{save_name}/pubmed_search_string.txt",
256
+ file_content=pubmed_search_string
257
+ )
258
+
259
+
260
+ async def process_pubmed_fetch_data(
261
+ task: PubMedTask,
262
+ model_name: str,
263
+ save_name: str,
264
+ prev_name: str = None,
265
+ client: Minio = None,
266
+ max_retries: int = 5,
267
+ delay: float = 0.5
268
+ ):
269
+ """
270
+ Process PubMed Fetch Data
271
+
272
+ Args:
273
+ task: PubMedTask object, containing basic information for the task
274
+ prev_name: str, previous model name, referring to the previous step's result
275
+ model_name: str, name of the model used at this step
276
+ save_name: str, save name for minio path
277
+ client: Minio, minio client
278
+
279
+ Returns:
280
+ None; fetched results are uploaded to MinIO
281
+
282
+ """
283
+
284
+ if client is None:
285
+ client = get_client()
286
+
287
+ customer_name = task.customer_name
288
+ uuid = task.uuid
289
+ start_year = task.start_year
290
+ end_year = task.end_year
291
+ size = task.size
292
+
293
+ pubmed_search_string = await get_file_from_minio(
294
+ bucket_name=BUCKET_NAME,
295
+ object_name=f"{customer_name}/{uuid}/{prev_name}/pubmed_search_string.txt",
296
+ client=client
297
+ )
298
+ pubmed_search_string = pubmed_search_string.data.decode("utf-8")
299
+ results, exceptions = await retry_operation(
300
+ process_pubmed_data, task,
301
+ query=pubmed_search_string,
302
+ model_name=save_name,
303
+ start_year=start_year, end_year=end_year,
304
+ size=size,
305
+ uuid=uuid, customer_name=customer_name,
306
+ max_retries=max_retries, delay=delay
307
+ )
308
+ if results is None: # no valid result after max retries
309
+ # store exception strings in status
310
+ task.status_string[model_name] = exceptions
311
+ await upload_task_json_to_minio(task, client)
312
+ raise ConnectionError("Pubmed Data Fetch Failed.") # exit
313
+
314
+
315
+ async def process_pubmed_process_papers(
316
+ task: PubMedTask,
317
+ model_name: str,
318
+ save_name: str,
319
+ prev_name: str = None,
320
+ client: Minio = None,
321
+ max_retries: int = 5,
322
+ delay: float = 0.5
323
+ ):
324
+ """
325
+ Process PubMed Process Papers
326
+
327
+ Args:
328
+ task: PubMedTask object, containing basic information for the task
329
+ prev_name: str, previous model name, referring to the previous step's result
330
+ model_name: str, name of the model used at this step
331
+ save_name: str, save name for minio path
332
+ client: Minio, minio client
333
+
334
+ Returns:
335
+ None; the processed papers are uploaded to MinIO
336
+
337
+ """
338
+ if client is None:
339
+ client = get_client()
340
+
341
+ query = task.query
342
+ direction = task.direction
343
+ customer_name = task.customer_name
344
+ uuid = task.uuid
345
+
346
+ chat_func = get_chat_func(model_names=[model_name])[0]
347
+
348
+ non_review_pubmed_df = await get_dataframe_from_minio(
349
+ bucket_name=BUCKET_NAME,
350
+ object_name=f"{customer_name}/{uuid}/{prev_name}/pubmed_results_non_reviews.csv",
351
+ client=client
352
+ )
353
+ results, exceptions = await retry_operation(
354
+ process_papers, task,
355
+ dataframe=non_review_pubmed_df,
356
+ topic=query, direction=direction,
357
+ uuid=uuid, customer_name=customer_name, model_name=save_name,
358
+ max_retries=max_retries, delay=delay,
359
+ chat_func=chat_func
360
+ )
361
+ if results is None: # no valid result after max retries
362
+ # store exception strings in status
363
+ task.status_string[model_name] = exceptions
364
+ await upload_task_json_to_minio(task, client)
365
+ raise RuntimeError("Pubmed Paper Processing Failed.") # exit
366
+
367
+
368
+ async def process_pubmed_generate_subheadings(
369
+ task: PubMedTask,
370
+ model_name: str,
371
+ save_name: str,
372
+ prev_name: str = None,
373
+ client: Minio = None,
374
+ max_retries: int = 5,
375
+ delay: float = 0.5
376
+ ):
377
+ """
378
+ Process PubMed Generate Subheadings
379
+ Args:
380
+ task: PubMedTask object, containing basic information for the task
381
+ prev_name: str, previous model name, referring to the previous step's result
382
+ model_name: str, name of the model used at this step
383
+ save_name: str, save name for minio path
384
+
385
+ Returns:
386
+ None; the generated subheadings are uploaded to MinIO
387
+ """
388
+ if client is None:
389
+ client = get_client()
390
+
391
+ query = task.query
392
+ customer_name = task.customer_name
393
+ uuid = task.uuid
394
+
395
+ chat_func = get_chat_func([model_name])[0]
396
+
397
+ relevant_papers_df = await get_dataframe_from_minio(
398
+ bucket_name=BUCKET_NAME,
399
+ object_name=f"{customer_name}/{uuid}/{prev_name}/relevant_papers.csv",
400
+ client=client
401
+ )
402
+
403
+ results, exceptions = await retry_operation(
404
+ generate_subheadings, task,
405
+ relevant_papers_df=relevant_papers_df,
406
+ main_topic=query,
407
+ uuid=uuid, customer_name=customer_name, model_name=save_name,
408
+ chat_func=chat_func,
409
+ max_retries=max_retries, delay=delay
410
+ )
411
+ if results is None: # no valid result after max retries
412
+ # store exception strings in status
413
+ task.status_string[model_name] = exceptions
414
+ await upload_task_json_to_minio(task, client)
415
+ raise RuntimeError("Pubmed Generate Subheadings Failed.") # exit
416
+
417
+
418
+ async def process_pubmed_assign_subheadings_to_summaries(
419
+ task: PubMedTask,
420
+ model_name: str,
421
+ save_name: str,
422
+ prev_name: str = None,
423
+ client: Minio = None,
424
+ max_retries: int = 5,
425
+ delay: float = 0.5
426
+ ):
427
+ """
428
+ Process PubMed Assign Subheadings to Summaries
429
+ Args:
430
+ task: PubMedTask object, containing basic information for the task
431
+ prev_name: str, previous model name, referring to the previous step's result
432
+ model_name: str, name of the model used at this step
433
+ save_name: str, save name for minio path
434
+
435
+ Returns:
436
+ None; the assignments are uploaded to MinIO
437
+ """
438
+
439
+ if client is None:
440
+ client = get_client()
441
+
442
+ customer_name = task.customer_name
443
+ uuid = task.uuid
444
+
445
+ chat_func = get_chat_func([model_name])[0]
446
+
447
+ subheadings = await get_file_from_minio(
448
+ bucket_name=BUCKET_NAME,
449
+ object_name=f"{customer_name}/{uuid}/{prev_name}/generated_subheadings.txt",
450
+ client=client
451
+ )
452
+ subheadings = subheadings.data.decode("utf-8").split("\n")
453
+
454
+ relevant_papers_df = await get_dataframe_from_minio(
455
+ bucket_name=BUCKET_NAME,
456
+ object_name=f"{customer_name}/{uuid}/{prev_name}/relevant_papers.csv",
457
+ client=client
458
+ )
459
+
460
+ results, exceptions = await retry_operation(
461
+ assign_subheadings_to_summaries, task,
462
+ subheadings=subheadings,
463
+ relevant_papers_df=relevant_papers_df,
464
+ uuid=uuid, customer_name=customer_name, model_name=save_name,
465
+ chat_func=chat_func,
466
+ max_retries=max_retries, delay=delay
467
+ )
468
+ if results is None: # no valid result after max retries
469
+ # store exception strings in status
470
+ task.status_string[model_name] = exceptions
471
+ await upload_task_json_to_minio(task, client)
472
+ raise RuntimeError("Pubmed Assign Subheadings Failed.") # exit
473
+
474
+
475
+ async def process_pubmed_create_paragraphs_by_subheading(
476
+ task: PubMedTask,
477
+ model_name: str,
478
+ save_name: str,
479
+ prev_name: str = None,
480
+ client: Minio = None,
481
+ max_retries: int = 5,
482
+ delay: float = 0.5
483
+ ):
484
+ """
485
+ Process PubMed Create Paragraphs by Subheading
486
+ Args:
487
+ task: PubMedTask object, containing basic information for the task
488
+ prev_name: str, previous model name, referring to the previous step's result
489
+ model_name: str, name of the model used at this step
490
+ save_name: str, save name for minio path
491
+ client: Minio, minio client
492
+ max_retries: int, max retries for the operation
493
+ delay: float, delay between retries
494
+
495
+ Returns:
496
+ None; the generated review is uploaded to MinIO
497
+ """
498
+ if client is None:
499
+ client = get_client()
500
+
501
+ query = task.query
502
+ customer_name = task.customer_name
503
+ uuid = task.uuid
504
+
505
+ chat_func = get_chat_func([model_name])[0]
506
+
507
+ subheadings = await get_file_from_minio(
508
+ bucket_name=BUCKET_NAME,
509
+ object_name=f"{customer_name}/{uuid}/{prev_name}/generated_subheadings.txt",
510
+ client=client
511
+ )
512
+ subheadings = subheadings.data.decode("utf-8").split("\n")
513
+
514
+ relevant_papers_df = await get_dataframe_from_minio(
515
+ bucket_name=BUCKET_NAME,
516
+ object_name=f"{customer_name}/{uuid}/{prev_name}/assigned_subheadings.csv",
517
+ client=client
518
+ )
519
+
520
+ results, exceptions = await retry_operation(
521
+ create_paragraphs_by_subheading, task,
522
+ subheadings=subheadings, main_topic=query,
523
+ relevant_papers_df=relevant_papers_df,
524
+ uuid=uuid, customer_name=customer_name, model_name=save_name,
525
+ chat_func=chat_func,
526
+ max_retries=max_retries, delay=delay
527
+ )
528
+ if results is None: # no valid result after max retries
529
+ # store exception strings in status
530
+ task.status_string[model_name] = exceptions
531
+ await upload_task_json_to_minio(task, client)
532
+ raise RuntimeError("Pubmed Create Paragraphs Failed.") # exit
533
+
534
+
535
+ async def process_pubmed_enhance_language_readability(
536
+ task: PubMedTask,
537
+ model_name: str,
538
+ save_name: str,
539
+ prev_name: str = None,
540
+ client: Minio = None,
541
+ max_retries: int = 5,
542
+ delay: float = 0.5
543
+ ):
544
+ """
545
+ Process PubMed Enhance Language Readability
546
+ Args:
547
+ task: PubMedTask object, containing basic information for the task
548
+ prev_name: str, previous model name, referring to the previous step's result
549
+ model_name: str, name of the model used at this step
550
+ save_name: str, save name for minio path
551
+ client: Minio, minio client
552
+ max_retries: int, max retries for the operation
553
+ delay: float, delay between retries
554
+
555
+ Returns:
556
+ None; the enhanced review is uploaded to MinIO
557
+ """
558
+ if client is None:
559
+ client = get_client()
560
+
561
+ customer_name = task.customer_name
562
+ uuid = task.uuid
563
+
564
+ chat_func = get_chat_func([model_name])[0]
565
+
566
+ review_content = await get_file_from_minio(
567
+ bucket_name=BUCKET_NAME,
568
+ object_name=f"{customer_name}/{uuid}/{prev_name}/review_non_refined.txt",
569
+ client=client
570
+ )
571
+ review_content = review_content.data.decode("utf-8")
572
+
573
+ results, exceptions = await retry_operation(
574
+ enhance_language_readability, task,
575
+ content=review_content,
576
+ uuid=uuid, customer_name=customer_name, model_name=save_name,
577
+ chat_func=chat_func,
578
+ max_retries=max_retries, delay=delay
579
+ )
580
+ if results is None: # no valid result after max retries
581
+ # store exception strings in status
582
+ task.status_string[model_name] = exceptions
583
+ await upload_task_json_to_minio(task, client)
584
+ raise RuntimeError("Pubmed Enhance Language Readability Failed.") # exit
585
+
586
+
587
+ async def process_pubmed_translate(
588
+ task: PubMedTask,
589
+ model_name: str,
590
+ save_name: str,
591
+ prev_name: str = None,
592
+ client: Minio = None,
593
+ max_retries: int = 5,
594
+ delay: float = 0.5
595
+ ):
596
+ """
597
+ Process PubMed Translate
598
+ Args:
599
+ task: PubMedTask object, containing basic information for the task
600
+ prev_name: str, previous model name, referring to the previous step's result
601
+ model_name: str, name of the model used at this step
602
+ save_name: str, save name for minio path
603
+ client: Minio, minio client
604
+ max_retries: int, max retries for the operation
605
+ delay: float, delay between retries
606
+
607
+ Returns:
608
+ path to save results
609
+ """
610
+
611
+ if client is None:
612
+ client = get_client()
613
+
614
+ customer_name = task.customer_name
615
+ uuid = task.uuid
616
+
617
+ chat_func = get_chat_func([model_name])[0]
618
+
619
+ review_content = await get_file_from_minio(
620
+ bucket_name=BUCKET_NAME,
621
+ object_name=f"{customer_name}/{uuid}/{prev_name}/review_paper.txt",
622
+ client=client
623
+ )
624
+ review_content = review_content.data.decode("utf-8")
625
+
626
+ results, exceptions = await retry_operation(
627
+ translate_to_chinese_before_references, task,
628
+ text=review_content,
629
+ uuid=uuid, customer_name=customer_name, model_name=save_name,
630
+ chat_func=chat_func,
631
+ max_retries=max_retries, delay=delay
632
+ )
633
+ if results is None: # no valid result after max retries
634
+ # store exception strings in status
635
+ task.status_string[model_name] = exceptions
636
+ await upload_task_json_to_minio(task, client)
637
+ raise RuntimeError("Pubmed Translate Failed.") # exit
638
+
639
+
640
+ # =================================
+ # Function Groups: PubMed Task
+ #
+ # functions specific to the PubMed task
+ # =================================
+
+ async def generate_pubmed_search_string(query: str, chat_func) -> str:
+     # Construct the improved prompt using triple single quotes
+     prompt = f'''
+ ### Objective
+ Your task is to generate a precise PubMed search string based on the input query: "{query}". You should:
+
+ 1. **Extract Critical Keywords**: Identify the main entities and concepts that have independent and specific meanings, avoiding overly general terms commonly found in many articles (e.g., "analysis", "study"). Focus on terms central to the topic.
+
+ 2. **Understand Keyword Relationships**: Analyze the logical relationship between keywords. If two or more keywords are conceptually similar or interchangeable, connect them using the OR operator. If they represent distinct concepts that must co-exist, connect them using the AND operator.
+
+ 3. **Expand Synonyms Thoughtfully**: For each critical keyword, generate at least 6 relevant English synonyms or related terms used in academic research. Ensure they align with the context of the query, including synonyms that may look different but are relevant based on the keyword's definition and hierarchy.
+
+ 4. **Include MeSH Terms**: Find the corresponding MeSH (Medical Subject Headings) terms for each critical keyword if available.
+
+ 5. **Construct the PubMed Search String**: Combine the critical keywords, their synonyms, and MeSH terms using Boolean operators. Ensure correct grouping using parentheses to reflect the logical relationships:
+    - If a group of terms is interchangeable (e.g., synonyms), use OR within parentheses.
+    - Use AND between distinct keyword groups.
+
+ ### Instructions
+ - **Language**: All words must be in English.
+ - **Avoid Stop Words**: Do not include stop words (e.g., 'a', 'an', 'the').
+ - **Synonym Requirement**: For each critical keyword, generate **at least 6 synonyms** or related terms.
+ - **Logical Operator Selection**: Adjust the Boolean logic based on the relationship between terms to accurately represent (A OR B) AND C patterns.
+ - **Term Length**: Each term should be concise, with phrases containing at most two words.
+ - **Formatting**:
+    - Use Boolean operators (AND, OR) to connect terms and use parentheses where necessary.
+    - Format MeSH terms as: "Term"[MeSH Terms]
+    - Format other terms as: "Term"[All Fields]
+
+ ### Example
+ **Input**: Role of AI in antimicrobial resistance and drug discovery
+
+ **Process**:
+ 1. **Extract Critical Keywords**:
+    - AI
+    - Antimicrobial resistance
+    - Drug discovery
+
+ 2. **Analyze Keyword Relationships**:
+    - AI OR machine learning (similar concepts)
+    - Antimicrobial resistance AND drug discovery (distinct concepts)
+
+ 3. **Expand Synonyms Thoughtfully**:
+    - **AI**: machine learning, artificial intelligence, deep learning, neural networks, computational intelligence, data-driven algorithms
+    - **Antimicrobial resistance**: antibiotic resistance, drug resistance, microbial resistance, bacterial resistance, pathogen resistance, multidrug resistance
+    - **Drug discovery**: drug design, pharmaceutical research, drug development, lead discovery, molecular screening, target identification
+
+ 4. **Include MeSH Terms**:
+    - **AI**: "Artificial Intelligence"[MeSH Terms]
+    - **Antimicrobial resistance**: "Drug Resistance, Microbial"[MeSH Terms]
+    - **Drug discovery**: "Drug Discovery"[MeSH Terms]
+
+ 5. **Construct the PubMed Search String**:
+
+    '(("Artificial Intelligence"[MeSH Terms] OR "machine learning"[All Fields] OR "deep learning"[All Fields] OR "neural networks"[All Fields] OR "computational intelligence"[All Fields] OR "data-driven algorithms"[All Fields]) AND ("Drug Resistance, Microbial"[MeSH Terms] OR "antibiotic resistance"[All Fields] OR "microbial resistance"[All Fields] OR "bacterial resistance"[All Fields] OR "pathogen resistance"[All Fields] OR "multidrug resistance"[All Fields])) AND ("Drug Discovery"[MeSH Terms] OR "drug design"[All Fields] OR "pharmaceutical research"[All Fields] OR "drug development"[All Fields] OR "lead discovery"[All Fields] OR "molecular screening"[All Fields])'
+
+ ### Now, generate the PubMed search string for the following query:
+
+ **Query**: {query}
+
+ Please provide only the final PubMed search string in the specified format.
+ '''
+
+     # Call the language model to get the PubMed search string
+     result = await chat_func(prompt)
+
+     # Extract the PubMed search string from the model's response
+     pubmed_search_string = result.choices[0].message.content.strip()
+
+     return pubmed_search_string
+
+
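+ # Illustrative usage sketch (not wired into the pipeline; assumes get_chat_func
+ # from utils.api_utils and a running event loop):
+ #
+ #     chat_func = get_chat_func(["glm-4-flash"])[0]
+ #     search_string = await generate_pubmed_search_string(
+ #         "Role of AI in antimicrobial resistance and drug discovery", chat_func)
+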
+ async def process_pubmed_data(
+     query,
+     model_name,
+     start_year, end_year, size,
+     uuid, customer_name
+ ):
+     """
+     Process PubMed Data
+
+     Args:
+         query: str, query for PubMed search
+         model_name: str, model name
+         start_year: int, start year for PubMed search
+         end_year: int, end year for PubMed search
+         size: int, number of results per page
+         uuid: str, uuid for the task
+         customer_name: str, customer name for the task
+
+     Returns:
+         tuple of (full results DataFrame, non-review results DataFrame)
+     """
+
+     # get prefix
+     prefix = f"{customer_name}/{uuid}/{model_name}/"
+     output_folder = prefix
+
+     # set file paths
+     combined_txt_filename = os.path.join(
+         output_folder, 'pubmed_page_combined.txt')
+     results_csv_filename = os.path.join(output_folder, 'pubmed_results.csv')
+     results_with_links_csv_filename = os.path.join(
+         output_folder, 'pubmed_results_with_full_text_links.csv')
+     impact_factors_csv_filename = os.path.join(
+         output_folder, 'pubmed_results_with_impact_factors.csv')
+     non_review_csv_filename = os.path.join(
+         output_folder, 'pubmed_results_non_reviews.csv')
+
+     # step 1: save pubmed pages
+     await save_combined_pubmed_page(query, start_year, end_year, size, output_filename=combined_txt_filename)
+
+     # step 2: process pubmed files
+     await process_pubmed_file(combined_txt_filename, results_csv_filename)
+
+     # step 3: add full-text links
+     # pubmed_df = pd.read_csv(results_csv_filename)
+     pubmed_df = await get_dataframe_from_minio(
+         bucket_name=BUCKET_NAME,
+         object_name=results_csv_filename
+     )
+     pubmed_df["Full_Text_Links"] = pubmed_df["PMID"].apply(get_full_text_links)
+     await upload_dataframe_to_minio(
+         bucket_name=BUCKET_NAME,
+         object_name=results_with_links_csv_filename,
+         df=pubmed_df
+     )
+
+     # step 4: merge impact factor
+     impact_factors_df = await get_dataframe_from_minio(
+         bucket_name=BUCKET_NAME,
+         object_name='2023-JCR.xlsx'
+     )
+
+     # Standardize the case of the JT column in both dataframes to lowercase
+     pubmed_df['JT'] = pubmed_df['JT'].str.lower()
+     impact_factors_df['JT'] = impact_factors_df['JT'].str.lower()
+
+     # Perform the merge based on the JT column
+     merged_df = pd.merge(pubmed_df, impact_factors_df, on='JT', how='left')
+
+     # Save the merged dataframe to a new CSV file
+     await upload_dataframe_to_minio(
+         bucket_name=BUCKET_NAME,
+         object_name=impact_factors_csv_filename,
+         df=merged_df
+     )
+     logger.info(f"Merged data saved to {impact_factors_csv_filename}")
+
+     # step 5: filter non-review papers
+     pubmed_df = await get_dataframe_from_minio(
+         bucket_name=BUCKET_NAME,
+         object_name=impact_factors_csv_filename
+     )
+     non_review_pubmed_df = pubmed_df[pubmed_df["Review"] == "No"]
+     await upload_dataframe_to_minio(
+         bucket_name=BUCKET_NAME,
+         object_name=non_review_csv_filename,
+         df=non_review_pubmed_df
+     )
+
+     logger.info(f"Non-review articles saved to {non_review_csv_filename}")
+
+     return pubmed_df, non_review_pubmed_df
+
+
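+ # Illustrative call-site sketch (assumed; the surrounding task runner is not
+ # shown in this file):
+ #
+ #     all_df, non_review_df = await process_pubmed_data(
+ #         query=search_string, model_name="pubmed_data",
+ #         start_year=2018, end_year=2024, size=200,
+ #         uuid=task.uuid, customer_name=task.customer_name)
+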
+ async def save_combined_pubmed_page(query, start_year, end_year, size=200, output_filename='pubmed_page_combined.txt'):
+     content1 = await save_pubmed_page(query, start_year, end_year, size)
+     content2 = await save_pubmed_page_date(query, start_year, end_year, size)
+
+     combined_content = content1 + "\n" + content2
+
+     # save the combined page content to the given txt object
+     # async with aiofiles.open(output_filename, 'w', encoding='utf-8') as file:
+     #     await file.write(combined_content)
+     await upload_text_to_minio(
+         bucket_name=BUCKET_NAME,
+         object_name=output_filename,
+         file_content=combined_content
+     )
+
+     logger.info(f"Page content saved to {output_filename}")
+
+
+ async def save_pubmed_page(query, start_year, end_year, size=200):
+     base_url = "https://pubmed.ncbi.nlm.nih.gov/"
+     params = {
+         'term': query,
+         'filter': f'years.{start_year}-{end_year}',
+         'format': 'pubmed',
+         'size': size
+     }
+
+     # build the search URL (for logging only; the request below passes params)
+     search_url = f"{base_url}?term={params['term']}&filter={params['filter']}&format={params['format']}&size={params['size']}"
+     logger.info(f"Search URL: {search_url}")
+
+     async with aiohttp.ClientSession() as session:
+         async with session.get(base_url, params=params) as response:
+             if response.status != 200:
+                 logger.error("Failed to retrieve data from save_pubmed_page")
+                 raise ConnectionError(
+                     "Failed to retrieve data from save_pubmed_page")
+             return await response.text()
+
+
+ async def save_pubmed_page_date(query, start_year, end_year, size=200):
+     base_url = "https://pubmed.ncbi.nlm.nih.gov/"
+     params = {
+         'term': query,
+         'filter': f'years.{start_year}-{end_year}',
+         'format': 'pubmed',
+         'size': size,
+         'sort': 'date'
+     }
+
+     # build the search URL (for logging only; the request below passes params)
+     search_url = f"{base_url}?term={params['term']}&filter={params['filter']}&format={params['format']}&size={params['size']}&sort={params['sort']}"
+     logger.info(f"Search URL: {search_url}")
+
+     async with aiohttp.ClientSession() as session:
+         async with session.get(base_url, params=params) as response:
+             if response.status != 200:
+                 logger.error(
+                     "Failed to retrieve data from save_pubmed_page_date")
+                 raise ConnectionError(
+                     "Failed to retrieve data from save_pubmed_page_date")
+             return await response.text()
+
+
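+ # The parser below walks the MEDLINE-format text embedded in the saved PubMed
+ # pages. An abridged record, as this code expects it (note the single space
+ # before the dash on most tags), looks like:
+ #
+ #     PMID- 12345678
+ #     TI - Deep learning for antibiotic discovery.
+ #     AB - We review recent advances ...
+ #     FAU - Doe, Jane
+ #     JT - Nature
+ #     PT - Journal Article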
+ async def process_pubmed_file(input_file, output_file):
+     # Read the file and replace specific text
+     # async with aiofiles.open(input_file, 'r', encoding='utf-8') as file:
+     #     content = await file.read()
+
+     content = await get_file_from_minio(
+         bucket_name=BUCKET_NAME,
+         object_name=input_file
+     )
+     content = content.data.decode("utf-8")
+     content = content.replace(
+         '<pre class="search-results-chunk">PMID-', '<pre class="search-results-chunk">\nPMID-')
+
+     # Split the content into lines
+     lines = content.split('\n')
+     records = []
+     current_record = {}
+     collecting_abstract = False
+     collecting_title = False
+     collecting_pt = False
+     abstract_lines = []
+     title_lines = []
+     pt_lines = []
+     first_author_recorded = False  # Flag to capture the first occurrence of FAU
+
+     for line in lines:
+         if line.startswith("PMID- "):
+             if current_record:
+                 # Finalize the current record before starting a new one
+                 current_record['AB'] = ' '.join(
+                     abstract_lines).replace('\n', ' ')
+                 current_record['TI'] = ' '.join(title_lines).replace('\n', ' ')
+                 current_record['PT'] = ' '.join(pt_lines).replace('\n', ' ')
+
+                 # Default Review to 'No' if not set to 'Yes' during PT or AB processing
+                 if 'Review' not in current_record:
+                     current_record['Review'] = 'No'
+
+                 # Check for mismatches between FAU-frist and the first entry in the FAU list
+                 if 'FAU-frist' in current_record and 'FAU' in current_record and current_record['FAU']:
+                     if current_record['FAU-frist'] != current_record['FAU'][0]:
+                         current_record['FAU'].insert(
+                             0, current_record['FAU-frist'])
+
+                 if 'JT' not in current_record:
+                     # Ensure JT is present even if not found
+                     current_record['JT'] = ''
+                 if 'DCOM' not in current_record:
+                     # Ensure DCOM is present even if not found
+                     current_record['DCOM'] = ''
+
+                 # Add current record to list of records
+                 records.append(current_record)
+
+             # Start a new record
+             current_record = {'PMID': line.split("PMID- ")[1].strip()}
+             collecting_abstract = False
+             collecting_title = False
+             collecting_pt = False
+             abstract_lines = []
+             title_lines = []
+             pt_lines = []
+             first_author_recorded = False  # Reset the flag for a new record
+
+         elif line.startswith("FAU - "):
+             # Append each FAU to a list and capture the first occurrence
+             author_name = line.split("FAU - ")[1].strip()
+             if 'FAU' not in current_record:
+                 current_record['FAU'] = []
+             current_record['FAU'].append(author_name)
+
+             # Record the first occurrence in FAU-frist
+             if not first_author_recorded:
+                 current_record['FAU-frist'] = author_name
+                 first_author_recorded = True
+
+         elif line.startswith("JT - "):
+             current_record['JT'] = line.split("JT - ")[1].strip()
+
+         elif line.startswith("DCOM- "):
+             current_record['DCOM'] = line.split("DCOM- ")[1].strip()
+
+         elif line.startswith("TI - "):
+             collecting_title = True
+             title_lines.append(line.split("TI - ")[1].strip())
+
+         elif collecting_title:
+             if any(line.startswith(prefix) for prefix in ["LID - ", "AB - ", "FAU - ", "PG - "]):
+                 collecting_title = False
+             else:
+                 title_lines.append(line.strip())
+
+         elif line.startswith("LID - "):
+             lid = line.split("LID - ")[1].strip()
+             if '[doi]' in lid:
+                 lid = lid.split(' [doi]')[0]
+             # keep the longer LID
+             if 'LID' in current_record:
+                 current_record['LID'] = lid if len(lid) > len(
+                     current_record['LID']) else current_record['LID']
+             else:
+                 current_record['LID'] = lid
+
+         elif line.startswith("AB - "):
+             collecting_abstract = True
+             abstract_text = line.split("AB - ")[1].strip()
+             abstract_lines.append(abstract_text)
+             # Check if 'review' is in AB line (case insensitive)
+             if 'review' in abstract_text.lower():
+                 current_record['Review'] = 'Yes'
+
+         elif collecting_abstract:
+             if any(line.startswith(prefix) for prefix in ["LID - ", "FAU - ", "PG - "]):
+                 collecting_abstract = False
+             else:
+                 abstract_text = line.strip()
+                 abstract_lines.append(abstract_text)
+                 # Check if 'review' is in AB line (case insensitive)
+                 if 'review' in abstract_text.lower():
+                     current_record['Review'] = 'Yes'
+
+         elif line.startswith("PT - "):
+             collecting_pt = True  # mirror the TI/AB handling so wrapped PT values are collected
+             pt_line = line.split("PT - ")[1].strip()
+             pt_lines.append(pt_line)
+             # Check if 'review' is in PT line (case insensitive)
+             if 'review' in pt_line.lower():
+                 current_record['Review'] = 'Yes'
+
+         elif collecting_pt:
+             if any(line.startswith(prefix) for prefix in ["LID - ", "AB - ", "FAU - ", "PG - "]):
+                 collecting_pt = False
+             else:
+                 pt_text = line.strip()
+                 pt_lines.append(pt_text)
+                 # Check if 'review' is in PT line (case insensitive)
+                 if 'review' in pt_text.lower():
+                     current_record['Review'] = 'Yes'
+
+     # Final record handling after loop ends
+     if current_record:
+         current_record['AB'] = ' '.join(abstract_lines).replace('\n', ' ')
+         current_record['TI'] = ' '.join(title_lines).replace('\n', ' ')
+         current_record['PT'] = ' '.join(pt_lines).replace('\n', ' ')
+         if 'Review' not in current_record:
+             current_record['Review'] = 'No'
+
+         # Check for mismatches between FAU-frist and the first entry in the FAU list
+         if 'FAU-frist' in current_record and 'FAU' in current_record and current_record['FAU']:
+             if current_record['FAU-frist'] != current_record['FAU'][0]:
+                 current_record['FAU'].insert(0, current_record['FAU-frist'])
+
+         if 'JT' not in current_record:
+             current_record['JT'] = ''
+         if 'DCOM' not in current_record:
+             current_record['DCOM'] = ''
+
+         records.append(current_record)
+
+     # Remove duplicate records by PMID
+     unique_records = []
+     seen_pmids = set()
+     for record in records:
+         if record['PMID'] not in seen_pmids:
+             seen_pmids.add(record['PMID'])
+             unique_records.append(record)
+
+     # Write unique records to output CSV
+     # async with aiofiles.open(output_file, 'w', encoding='utf-8', newline='') as csvfile:
+     text = ""
+     fieldnames = ['JT', 'DCOM', 'PMID', 'TI', 'LID',
+                   'AB', 'FAU', 'FAU-frist', 'PT', 'Review']  # 'FAU-frist' (sic) matches the key used above
+     header = ','.join(fieldnames) + '\n'
+     text += header
+
+     # Write each record
+     for record in unique_records:
+         # Join the FAU list as a single string
+         record['FAU'] = '; '.join(record.get('FAU', []))  # Safely get 'FAU'
+
+         # Prepare the row as a CSV string
+         row = ','.join([escape_csv_field(str(record.get(field, '')))
+                         for field in fieldnames]) + '\n'
+         text += row
+
+     await upload_text_to_minio(
+         bucket_name=BUCKET_NAME,
+         object_name=output_file,
+         file_content=text
+     )
+
+
+ def get_full_text_links(pmid):
+     url = f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/"
+     response = requests.get(url)
+     soup = BeautifulSoup(response.text, 'html.parser')
+
+     # extract all links from the page
+     links = [link['href'] for link in soup.find_all('a', href=True)]
+
+     # return the 27th link if it exists (hard-coded position assumed to hold
+     # the full-text link on the PubMed page), otherwise None
+     return links[26] if len(links) >= 27 else None
utils/r2_utils.py ADDED
@@ -0,0 +1,197 @@
+ import boto3
+ import io
+ import json
+ import asyncio
+ import pandas as pd
+
+ from docx import Document
+ from loguru import logger
+ from entities.task import Task, task_factory
+
+
+ BUCKET_NAME = "ai-scientist"
+
+ r2_endpoint = "https://468d92a3c903c841bc2de3b413e45072.r2.cloudflarestorage.com/ai-scientist"
+
+ TOKEN = "KhGGD1ZJI_YTlLaZ0nSMfBJSLnOhgYN6cwq1De7G"
+ R2_ACCESS_KEY_ID = "b9bc4becece838742ae1dc161be92de3"
+ R2_SECRET_ACCESS_KEY = "f68eb82bd1c00528f26c6ac9b57d737fe0e4729ac7c429030fbc22a17dc8f105"
+
+
+ def get_client():
+     return boto3.client(
+         "s3",
+         endpoint_url=r2_endpoint,
+         aws_access_key_id=R2_ACCESS_KEY_ID,
+         aws_secret_access_key=R2_SECRET_ACCESS_KEY,
+         region_name="auto"  # R2 requires the region to be set to "auto"
+     )
+
+
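+ # Illustrative check (assumed snippet, not part of the original module): the
+ # client returned above speaks the standard S3 API against the R2 endpoint, so
+ # plain boto3 calls work, e.g.
+ #
+ #     client = get_client()
+ #     client.head_bucket(Bucket=BUCKET_NAME)  # raises if the bucket is unreachable
+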
+ async def get_task_from_minio(
+     uuid: str,
+     customer_name: str,
+     client=None
+ ) -> Task:
+     if client is None:
+         client = get_client()
+
+     response = await asyncio.to_thread(
+         lambda: client.list_objects_v2(
+             Bucket=BUCKET_NAME,
+             Prefix=f"{customer_name}/"
+         )
+     )
+
+     objects = response.get("Contents", [])
+     if not objects:
+         raise FileNotFoundError(f"No task found for customer {customer_name}")
+
+     object_names = [obj["Key"].split("/")[1] for obj in objects]
+     if uuid not in object_names:
+         raise FileNotFoundError(f"No task found for customer {customer_name} with uuid {uuid}")
+
+     json_file = await get_file_from_minio(
+         bucket_name=BUCKET_NAME,
+         object_name=f"{customer_name}/{uuid}/task.json",
+         client=client
+     )
+
+     json_data = json_file.decode("utf-8")
+     json_data = json.loads(json_data)
+     return task_factory[json_data["task_type"]].load_from_json(json_data)
+
+
+ async def get_all_tasks_from_minio(
+     customer_name: str,
+     client=None
+ ) -> list[Task]:
+     if client is None:
+         client = get_client()
+
+     response = await asyncio.to_thread(
+         lambda: client.list_objects_v2(
+             Bucket=BUCKET_NAME,
+             Prefix=f"{customer_name}/"
+         )
+     )
+     objects = response.get("Contents", [])
+     if not objects:
+         return []
+
+     task_ids = list(set(obj["Key"].split("/")[1] for obj in objects))
+     task_jsons = await asyncio.gather(
+         *(get_task_from_minio(uuid=task_id, customer_name=customer_name, client=client) for task_id in task_ids)
+     )
+     return task_jsons
+
+
+ async def upload_task_json_to_minio(task: Task, client=None) -> Task:
+     if client is None:
+         client = get_client()
+
+     json_data = task.save_to_json()
+     byte_data = io.BytesIO(json_data.encode("utf-8"))
+
+     await asyncio.to_thread(
+         lambda: client.put_object(
+             Bucket=BUCKET_NAME,
+             Key=f"{task.customer_name}/{task.uuid}/task.json",
+             Body=byte_data,
+             ContentType="application/json"
+         )
+     )
+     return task
+
+
+ async def upload_text_to_minio(
+     bucket_name: str,
+     object_name: str,
+     file_content: str,
+     client=None,
+ ):
+     if client is None:
+         client = get_client()
+
+     file_data = io.BytesIO(file_content.encode("utf-8"))
+
+     await asyncio.to_thread(
+         lambda: client.put_object(
+             Bucket=bucket_name,
+             Key=object_name,
+             Body=file_data
+         )
+     )
+
+
+ async def upload_dataframe_to_minio(
+     bucket_name: str,
+     object_name: str,
+     df: pd.DataFrame,
+     client=None,
+ ):
+     buffer = io.BytesIO()
+     df.to_csv(buffer, index=False)
+     await upload_text_to_minio(
+         bucket_name=bucket_name,
+         object_name=object_name,
+         file_content=buffer.getvalue().decode("utf-8"),
+         client=client
+     )
+
+
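+ # Illustrative round trip (assumed snippet): a DataFrame uploaded with
+ # upload_dataframe_to_minio can be read back with get_dataframe_from_minio,
+ # defined below, e.g.
+ #
+ #     await upload_dataframe_to_minio(BUCKET_NAME, "demo/out.csv", df)
+ #     df2 = await get_dataframe_from_minio(BUCKET_NAME, "demo/out.csv")
+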
+ async def upload_document_to_minio(
+     bucket_name: str,
+     object_name: str,
+     document: Document,
+     client=None,
+ ):
+     if client is None:
+         client = get_client()
+
+     buffer = io.BytesIO()
+     document.save(buffer)
+     buffer.seek(0)
+
+     await asyncio.to_thread(
+         lambda: client.put_object(
+             Bucket=bucket_name,
+             Key=object_name,
+             Body=buffer,
+             ContentType="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
+         )
+     )
+
+
+ async def get_file_from_minio(
+     bucket_name: str,
+     object_name: str,
+     client=None,
+ ):
+     if client is None:
+         client = get_client()
+
+     try:
+         response = await asyncio.to_thread(
+             lambda: client.get_object(Bucket=bucket_name, Key=object_name)
+         )
+         return response["Body"].read()
+     except Exception as e:
+         raise Exception(f"Error getting file from minio: {e}") from e
+
+
+ async def get_dataframe_from_minio(
+     bucket_name: str,
+     object_name: str,
+     client=None,
+ ):
+     file_data = await get_file_from_minio(
+         bucket_name=bucket_name,
+         object_name=object_name,
+         client=client
+     )
+
+     if object_name.endswith(".csv"):
+         df = pd.read_csv(io.BytesIO(file_data))
+     elif object_name.endswith(".xlsx") or object_name.endswith(".xls"):
+         df = pd.read_excel(io.BytesIO(file_data))
+     else:
+         raise ValueError(f"Unsupported file extension for {object_name}")
+     return df
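+
+
+ # Illustrative usage (assumed snippet): the JCR impact-factor table used by the
+ # PubMed pipeline can be fetched this way, e.g.
+ #
+ #     jcr_df = await get_dataframe_from_minio(BUCKET_NAME, "2023-JCR.xlsx")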