RCaz commited on
Commit
c64c1d6
·
1 Parent(s): 2ea0c95

added tools and agent

Browse files
agent.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import requests
4
+ from dotenv import load_dotenv
5
+ from markdownify import markdownify
6
+ from requests.exceptions import RequestException
7
+ from smolagents import (
8
+ LiteLLMModel,
9
+ CodeAgent,
10
+ ToolCallingAgent,
11
+ InferenceClientModel,
12
+ WebSearchTool,
13
+ tool,
14
+ FinalAnswerTool,
15
+ WikipediaSearchTool,
16
+ VisitWebpageTool,
17
+ DuckDuckGoSearchTool
18
+ )
19
+
20
+ load_dotenv()
21
+
22
+ from langfuse import get_client
23
+ langfuse = get_client()
24
+ if langfuse.auth_check():
25
+ print("Langfuse client is authenticated and ready!")
26
+ else:
27
+ print("Authentication failed. Please check your credentials and host.")
28
+
29
+
30
+ from openinference.instrumentation.smolagents import SmolagentsInstrumentor
31
+ SmolagentsInstrumentor().instrument()
32
+
33
+ model = LiteLLMModel(
34
+ model_id="openai/Qwen/Qwen3-Coder-480B-A35B-Instruct",
35
+ api_key=os.environ.get("NEBIUS_API_KEY"),
36
+ api_base="https://api.tokenfactory.nebius.com/v1/"
37
+ )
38
+
39
+ from tool_clinical_trial import ClinicalTrialsSearchTool
40
+
41
+
42
@tool
def search_pubmed(topic: str, author: str) -> list[str]:
    """
    Searches the PubMed database for articles related to a specific topic.

    Args:
        topic: The topic or keywords to search for (e.g., "CRISPR gene editing").
        author: The name of the author to search for (e.g., "Albert Einstein").

    Returns:
        A list of PubMed IDs (strings) for up to the top 1000 articles found
        (NCBI esearch `retmax`). Empty list when nothing matches.

    Raises:
        requests.exceptions.HTTPError: If the API request fails.
        requests.exceptions.Timeout: If NCBI does not answer within 30 s.
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"

    # Build the esearch term: topic AND author, skipping whichever is empty.
    terms = []
    if topic:
        terms.append(topic)
    if author:
        terms.append(f"{author}[Author]")

    query = " AND ".join(terms)
    params = {
        "db": "pubmed",
        "term": query,
        "retmode": "json",
        "retmax": 1000,
    }
    # Timeout keeps an unresponsive NCBI endpoint from hanging the agent step.
    response = requests.get(base_url, params=params, timeout=30)
    response.raise_for_status()
    data = response.json()

    return data["esearchresult"]["idlist"]
77
+
78
@tool
def parse_pdf(pdf_path: str) -> list[str]:
    """
    Reads a PDF file from a specified path and extracts the text content
    from every page.

    Args:
        pdf_path: The local file path (string) to the PDF document to be parsed.
            **NOTE**: In a remote agent environment, this path must be
            accessible by the executing process (e.g., a path to an
            uploaded file).

    Returns:
        A list of strings, where each string is the extracted text content
        from a single page of the PDF.
    """
    # Imported lazily so the tool only needs pypdf when actually invoked.
    from pypdf import PdfReader

    reader = PdfReader(pdf_path)
    # reader.pages is already iterable page-by-page; no index bookkeeping needed.
    return [page.extract_text() for page in reader.pages]
103
+
104
+ # @tool
105
+ # def make_rag_ressource(paths :list(str)) -> list(str):
106
+ # """
107
+ # Use extracted text to build a RAG tool and retreive documents to use to answer request
108
+
109
+ # Args:
110
+ # paths: The list of path where the file are stored
111
+
112
+ # Returns:
113
+ # A list of strings, where each string is the extracted text content
114
+ # from the retreiver
115
+ # """
116
+
117
+ # pdf_files=[]
118
+ # for path in paths:
119
+
120
+
121
+ # pdf_documents = []
122
+ # for pdf_file in pdf_files:
123
+ # loader = PyPDFLoader(pdf_file)
124
+ # pdf_documents.extend(loader.load())
125
+ # embeddings_model = OpenAIEmbeddings()
126
+ # pdf_texts = [doc.page_content for doc in pdf_documents]
127
+ # return ""
128
+
129
+
130
+ # # Initialize the model
131
+ # model = InferenceClientModel(
132
+ # model_id="Qwen/Qwen3-Coder-30B-A3B-Instruct",
133
+ # provider="nebius"
134
+ # )
135
+
136
+
137
+
138
+ # Create clinical trial search agent
139
+
140
# Agent dedicated to clinical-study retrieval. The original description was a
# copy-paste mixing in the online-search agent's text (Wikipedia/DuckDuckGo),
# and promised search_pubmed/parse_pdf without providing them as tools.
clinical_agent = CodeAgent(
    name="clinical_agent",
    description=(
        "Retrieve and parse clinical study data for a given disease. "
        "Use ClinicalTrialsSearchTool for trials, search_pubmed for authors, and parse_pdf for full-text analysis. "
        "Return structured tables or summaries as requested."
    ),
    # Tools now match the description: trials search plus PubMed and PDF parsing.
    tools=[ClinicalTrialsSearchTool(), search_pubmed, parse_pdf],
    additional_authorized_imports=["time", "numpy", "pandas"],
    # executor_type="blaxel", #executor_type="modal",
    use_structured_outputs_internally=True,
    return_full_result=True,
    planning_interval=3,  # V3 add structure
    model=model,
    max_steps=6,
    verbosity_level=2,
)
160
+
161
+ search_online_info = CodeAgent(
162
+ name="search_online_info",
163
+ description=(
164
+ "Gather general or recent information from online sources. "
165
+ "Use Wikipedia for overviews, DuckDuckGo for recent data, and VisitWebpageTool for specific URLs. "
166
+ "Return structured summaries with sources."
167
+ ),
168
+ tools=[WikipediaSearchTool(),VisitWebpageTool(max_output_length=10000),DuckDuckGoSearchTool(max_results=5),search_pubmed,parse_pdf],
169
+ additional_authorized_imports=["time", "numpy", "pandas"],
170
+ # use_structured_outputs_internally=True,
171
+ # executor_type="modal",
172
+ planning_interval=2,
173
+ model=model,
174
+ max_steps=4,
175
+ verbosity_level=2
176
+ )
177
+
178
+
179
+
180
+ manager_agent = CodeAgent(
181
+ name="manager_agent",
182
+ description=(
183
+ "Most important task is to provide a complete answer to user questions based on clinical trial data and online information. "
184
+ "Orchestrate workflow between clinical and online agents. "
185
+ "Validate outputs, resolve conflicts, and ensure the final answer is complete and accurate."
186
+ ),
187
+ tools=[FinalAnswerTool()],
188
+ model=model,
189
+ managed_agents=[clinical_agent,search_online_info],
190
+ # executor_type="modal",
191
+ provide_run_summary=True,
192
+ additional_authorized_imports=["time", "numpy", "pandas"],
193
+ use_structured_outputs_internally=True,
194
+ verbosity_level=2,
195
+ planning_interval=3,
196
+ max_steps=6,
197
+ )
app.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from agent import manager_agent
2
+ import gradio as gr
3
+ from smolagents import stream_to_gradio
4
+ import smolagents
5
+ import json
6
+ import re
7
+ import ast
8
+
9
+ agent = manager_agent
10
+
11
+
12
+ import logging
13
+ logging.info("Processing request")
14
+
15
+
16
# --- PATCH OpenTelemetry detach bug (generator-safe) ---
# Gradio streams this app's agent output through Python generators. OpenTelemetry
# context tokens attached inside a generator frame can be detached from a
# different context, which raises on detach. We wrap the runtime context's
# detach() so those boundary errors are swallowed instead of crashing the stream.
from opentelemetry.context import _RUNTIME_CONTEXT

# Keep a reference to the real detach so the wrapper can delegate to it.
_orig_detach = _RUNTIME_CONTEXT.detach


def _safe_detach(token):
    """Delegate to the original detach, ignoring any error it raises.

    NOTE(review): this monkey-patches a private OpenTelemetry attribute
    (_RUNTIME_CONTEXT) — verify it still exists when upgrading the SDK.
    """
    try:
        _orig_detach(token)
    except Exception:
        # Suppress context-var boundary errors caused by streamed generators
        pass


_RUNTIME_CONTEXT.detach = _safe_detach
# --- PATCH OpenTelemetry detach bug (generator-safe) ---
27
+
28
+
29
def answer_question(question):
    """Use a smolagents CodeAgent with tools to answer a question.

    The agent streams its thought process (planning steps, tool calls and
    action outputs) and the final answer.

    Args:
        question (str): The question to be answered by the agent.

    Yields:
        tuple(str, str): A tuple containing the current 'thoughts'
        (planning/intermediate steps) and the current 'final_answer'.
        'final_answer' stays empty until a FinalAnswerStep is produced.
    """
    thoughts = ""
    final_answer = ""
    try:
        logging.info(f"Received question: {question}")
        for st in manager_agent.run(question, stream=True, return_full_result=True):
            if isinstance(st, smolagents.memory.PlanningStep):
                # Show only the plan body after the "## 2." heading; guard
                # against a None content from the model.
                content = st.model_output_message.content or ""
                plan = content.split("## 2.")[-1]
                for m in plan.split("\n"):
                    thoughts += "\n" + m
                    yield thoughts, final_answer

            elif isinstance(st, smolagents.memory.ToolCall):
                thoughts += f"\nTool called: {st.dict()['function']['name']}\n"
                for m in st.dict()['function']['arguments'].split("\n"):
                    thoughts += "\n" + m
                    yield thoughts, final_answer

            elif isinstance(st, smolagents.agents.ActionOutput):
                if st.output:
                    thoughts += "\n" + str(st.output) + "\n"
                    yield thoughts, final_answer
                else:
                    thoughts += "\n****************\nNo output from action.\n****************\n"
                    yield thoughts, final_answer

            elif isinstance(st, smolagents.memory.ActionStep):
                # Stream the model output line by line, then a per-step footer
                # with token usage and timing. Guard against None content.
                content = st.model_output_message.content or ""
                for m in content.split("\n"):
                    thoughts += m
                    yield thoughts, final_answer

                thoughts += "\n********** End of Step " + str(st.step_number) + " : *********\n " + str(st.token_usage) + "\nStep duration" + str(st.timing) + "\n\n"
                yield thoughts, final_answer

            elif isinstance(st, smolagents.memory.FinalAnswerStep):
                final_answer = st.output
                yield thoughts, final_answer
    except GeneratorExit:
        # The Gradio client disconnected mid-stream; stop quietly.
        print("Stream closed cleanly.")
        return "", ""
79
+
80
+
81
+
82
+ # def create_rag_files(refs :list[str], VECTOR_DB_PATH:str)-> str:
83
+ # from tool_create_FAISS_vector import create_vector_store_from_list_of_doi
84
+
85
+ # FAISS_VECTOR_PATH = create_vector_store_from_list_of_doi(refs,VECTOR_DB_PATH)
86
+ # return FAISS_VECTOR_PATH
87
+
88
def tool_clinical_trial(query_cond: str = None, query_term: str = None, query_lead: str = None, max_results: int = 5000) -> list:
    """
    Search the ClinicalTrials.gov v2 database for trials with 4 arguments.

    Args:
        query_cond (str): Disease or condition (e.g., 'lung cancer', 'diabetes')
        query_term (str): Other terms (e.g., 'AREA[LastUpdatePostDate]RANGE[2023-01-15,MAX]').
        query_lead (str): Searches the LeadSponsorName
        max_results (int): Number of trials to return (max: 1000, the API's pageSize limit)

    Returns:
        list(str): each string being a structured representation of a trial,
        or a single-element list carrying an error message on failure.
    """
    # Local import: app.py does not import requests at module level, so the
    # original code raised NameError the first time this tool was called.
    import requests
    from tool_TOON_formater import TOON_formater

    try:
        max_results = int(max_results)
    except (TypeError, ValueError):
        # Non-numeric input (e.g. empty Textbox) -> sensible default.
        max_results = 500

    params = {
        "query.cond": query_cond,
        "query.term": query_term,
        "query.lead": query_lead,
        # The v2 API caps pageSize at 1000; the old cap of 5000 was silently
        # rejected by the server (and contradicted this docstring).
        "pageSize": min(max_results, 1000),
        "format": "json"
    }
    # Drop unset filters so they are not sent as literal "None" strings.
    params = {k: v for k, v in params.items() if v is not None}
    try:
        response = requests.get(
            "https://clinicaltrials.gov/api/v2/studies",
            params=params,
            timeout=30
        )
        response.raise_for_status()
        studies = response.json().get("studies", [])

        # Serialize every study to its compact TOON representation.
        return [TOON_formater(study) for study in studies]

    except Exception as e:
        # Best-effort tool: surface the failure as data instead of raising,
        # so the calling agent/UI can display it.
        return [f"Error searching clinical trials: {str(e)}"]
133
+
134
+
135
+
136
def create_rag(refs :str, VECTOR_DB_PATH:str)-> str:
    """Create a RAG (Retrieval-Augmented Generation) vector store from references.

    Thin wrapper around the project-local FAISS builder; imported lazily so the
    heavy vector-store dependencies load only when this tool is used.

    Args:
        refs (str): A comma-separated string of references
            (DOIs, PMIDs, arXiv IDs, ... — see tool_create_FAISS_vector).
        VECTOR_DB_PATH (str): The local path where the FAISS vector store
            should be saved.

    Returns:
        str: The path to the newly created FAISS vector store.
    """
    from tool_create_FAISS_vector import create_vector_store_from_list_of_doi
    FAISS_VECTOR_PATH = create_vector_store_from_list_of_doi(refs,VECTOR_DB_PATH)
    return FAISS_VECTOR_PATH
147
+
148
+
149
+
150
def use_rag(query: str, store_name: str, top_k: int = 5) -> str:
    """Retrieve context from a FAISS vector store based on a query.

    Thin wrapper around the project-local retriever; imported lazily so the
    FAISS dependencies load only when this tool is used.

    Args:
        query (str): The question or query string to use for retrieval.
        store_name (str): The path to the FAISS vector store to query.
        top_k (int): The number of top-k most relevant context documents to
            retrieve (default: 5).

    Returns:
        str: A JSON string containing the retrieved context, including the
        content and source (DOI), pretty-printed with indent=2.
    """
    from tool_query_FAISS_vector import query_vector_store
    context_as_dict = query_vector_store(query, store_name, top_k)
    return json.dumps(context_as_dict, indent=2)
162
+
163
+ from PIL import Image
164
+
165
def describe_figure(figure : Image.Image) -> str:
    """Provide a detailed, thorough description of an image figure.

    Thin wrapper around the project-local vision helper; imported lazily so it
    loads only when this tool is used.

    Args:
        figure (Image.Image): The PIL image object to be described.

    Returns:
        str: A detailed textual description of the figure's content.
    """
    from tool_describe_figure import thourough_picture_description
    description = thourough_picture_description(figure)
    return description
175
+
176
+
177
+
178
+ # Create neat interface - Question Analyzer as a Blocks component
179
# Question Analyzer as a Blocks component: streams agent thoughts and answer.
with gr.Blocks() as interface2:
    gr.Markdown("# Question Analyzer")
    gr.Markdown("""Enter a question to analyze. Examples:
    - Find the name of the sponsor that did the most studies on Alzheimer's disease in the last 10 years.
    - Provide a summary of recent clinical trials on diabetes and list 3 relevant research articles from PubMed.
    - What are the scientific paper linked to the clinical study referenced as NCT04516746?
    - How many clinical studies on cancer were completed in the last 5 years?
    - Find recent phase 3 trials for lung cancer sponsored by Pfizer
    """)

    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(
                label="Question",
                placeholder="Enter your question here...",
                lines=3,
            )
            submit_btn = gr.Button("Submit", variant="primary")
            response_output = gr.Textbox(
                label="Final Answer",
                interactive=False,
                lines=8
            )
        with gr.Column():
            thoughts_output = gr.Textbox(
                label="LLM Thoughts/Reasoning",
                interactive=False,
                lines=8
            )

    chat_history = gr.State([])

    # answer_question is a generator, so queue=True streams partial updates.
    submit_btn.click(
        fn=answer_question,
        inputs=[question_input],
        outputs=[thoughts_output, response_output],
        queue=True
    )


# Combine interfaces into a single tabbed interface.
# Fixes vs original: the tab-label list was missing a comma (two labels were
# silently concatenated, leaving 4 labels for 5 tabs), two interfaces shared
# the same api_name, and several user-facing typos are corrected.
demo = gr.TabbedInterface(
    [interface2,
     gr.Interface(
         fn=create_rag,
         inputs=[gr.Textbox("list of references to include in vector store", lines=2, info="(can be DOIs, PMIDs, arxivs, ... and a mix of it)"),
                 gr.Textbox("Name of the vector store", lines=2, placeholder="My_Diabetes_vector")],
         outputs=gr.Textbox("path of the vector store"),
         api_name="create_vector_store_for_rag"),

     gr.Interface(
         fn=use_rag,
         inputs=[gr.Textbox("question that needs context to answer"),
                 gr.Textbox("Name of the vector store to use", placeholder="Diabetes, Sickle_cell_anemia, Prostate_cancer, ..")],
         outputs=gr.Textbox("Answer with Rag"),
         api_name="use_vector_store_to_create_context"),
     gr.Interface(
         fn=tool_clinical_trial,
         inputs=[gr.Textbox("Disease or condition (e.g., 'lung cancer', 'diabetes')"),
                 gr.Textbox("Other terms (e.g., 'AREA[LastUpdatePostDate]RANGE[2023-01-15,MAX]'"),
                 gr.Textbox("Searches the LeadSponsorName"),
                 gr.Textbox("max results")],
         outputs=gr.Textbox("TOON formatted response"),
         # Unique api_name (was a duplicate of the RAG-query endpoint).
         api_name="query_clinical_trials_database"),
     gr.Interface(
         describe_figure,
         gr.Image(type="pil"),
         gr.Textbox(),
         api_name="figure_description"),
     ],
    ["Use a code agent with sandbox execution equipped with clinical trial tool",
     "Create RAG tool with FAISS vector store",
     "Query RAG tool",
     "Query clinical trial database",
     "Thorough figure description"]
)

if __name__ == "__main__":
    # mcp_server=True exposes the interfaces as MCP tools as well.
    demo.queue().launch(mcp_server=True)
tool_TOON_formater.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def TOON_formater(api_response):
    """
    Extract core partner identification information from a ClinicalTrials.gov
    API study record and serialize it as TOON text.

    Args:
        api_response (dict): Raw study record from the ClinicalTrials.gov API
            (one element of the v2 "studies" array).

    Returns:
        str: TOON (Token-Oriented Object Notation) formatted string with 41
        core fields; missing values are rendered as empty strings.
    """

    # Helper function to safely navigate nested dicts.
    def safe_get(data, *keys, default=None):
        for key in keys:
            if isinstance(data, dict):
                data = data.get(key, {})
            else:
                return default
        # An empty dict means some key along the path was missing.
        return data if data != {} else default

    # Helper function to format a scalar for TOON: None -> '', bools lowercase.
    def format_value(val):
        if val is None:
            return ''
        elif isinstance(val, bool):
            return str(val).lower()
        else:
            return str(val)

    # Helper function to format a list for TOON: comma-joined, items containing
    # commas/newlines wrapped in double quotes (embedded quotes not escaped).
    def format_list(lst):
        if not lst:
            return ''
        formatted_items = []
        for item in lst:
            item_str = format_value(item)
            if ',' in item_str or '\n' in item_str:
                item_str = f'"{item_str}"'
            formatted_items.append(item_str)
        return ','.join(formatted_items)

    protocol = api_response.get('protocolSection', {})

    # Extract basic identification
    identification = protocol.get('identificationModule', {})
    nct_id = identification.get('nctId')
    brief_title = identification.get('briefTitle')
    official_title = identification.get('officialTitle')
    org_full_name = safe_get(identification, 'organization', 'fullName')

    # Extract status information
    status = protocol.get('statusModule', {})
    overall_status = status.get('overallStatus')
    last_update_post_date = safe_get(status, 'lastUpdatePostDateStruct', 'date')
    recruitment_status = overall_status
    start_date = safe_get(status, 'startDateStruct', 'date')
    primary_completion_date = safe_get(status, 'primaryCompletionDateStruct', 'date')
    completion_date = safe_get(status, 'completionDateStruct', 'date')
    study_first_post_date = safe_get(status, 'studyFirstPostDateStruct', 'date')

    # Extract sponsor/collaborator information
    sponsors = protocol.get('sponsorCollaboratorsModule', {})
    lead_sponsor = sponsors.get('leadSponsor', {})
    lead_sponsor_name = lead_sponsor.get('name')
    lead_sponsor_class = lead_sponsor.get('class')

    # Extract collaborators (list)
    collaborators = sponsors.get('collaborators', [])
    collaborator_names = [c.get('name') for c in collaborators if c.get('name')]
    collaborator_classes = [c.get('class') for c in collaborators if c.get('class')]
    num_collaborators = len(collaborators)
    num_collaborators_plus_lead = (num_collaborators + 1) if lead_sponsor_name else num_collaborators

    # Extract responsible party
    responsible_party = sponsors.get('responsibleParty', {})
    responsible_party_investigator_full_name = responsible_party.get('investigatorFullName')
    responsible_party_investigator_affiliation = responsible_party.get('investigatorAffiliation')

    # Extract overall officials
    contacts_locations = protocol.get('contactsLocationsModule', {})
    overall_officials = contacts_locations.get('overallOfficials', [])

    overall_official_names = [o.get('name') for o in overall_officials if o.get('name')]
    overall_official_affiliations = [o.get('affiliation') for o in overall_officials if o.get('affiliation')]
    overall_official_roles = [o.get('role') for o in overall_officials if o.get('role')]

    # Extract conditions and interventions
    conditions_module = protocol.get('conditionsModule', {})
    conditions = conditions_module.get('conditions', [])

    arms_interventions = protocol.get('armsInterventionsModule', {})
    interventions = arms_interventions.get('interventions', [])
    intervention_names = [i.get('name') for i in interventions if i.get('name')]
    intervention_types = [i.get('type') for i in interventions if i.get('type')]

    # Extract design information
    design = protocol.get('designModule', {})
    study_type = design.get('studyType')
    phases = design.get('phases', [])
    primary_purpose = safe_get(design, 'designInfo', 'primaryPurpose')

    # Extract enrollment
    enrollment_info = design.get('enrollmentInfo', {})
    enrollment_count = enrollment_info.get('count')

    # Extract primary outcome
    outcomes = protocol.get('outcomesModule', {})
    primary_outcomes = outcomes.get('primaryOutcomes', [])
    primary_outcome_measures = [p.get('measure') for p in primary_outcomes if p.get('measure')]

    # Extract locations
    locations = contacts_locations.get('locations', [])
    num_locations = len(locations)

    location_facilities = [loc.get('facility') for loc in locations if loc.get('facility')]
    location_cities = [loc.get('city') for loc in locations if loc.get('city')]
    location_states = [loc.get('state') for loc in locations if loc.get('state')]
    location_countries = [loc.get('country') for loc in locations if loc.get('country')]
    location_statuses = [loc.get('status') for loc in locations if loc.get('status')]

    # Extract geopoints
    geopoints = [loc.get('geoPoint') for loc in locations if loc.get('geoPoint')]

    # Extract MeSH terms
    derived = api_response.get('derivedSection', {})
    condition_browse = derived.get('conditionBrowseModule', {})
    condition_mesh_terms = [m.get('term') for m in condition_browse.get('meshes', []) if m.get('term')]

    intervention_browse = derived.get('interventionBrowseModule', {})
    intervention_mesh_terms = [m.get('term') for m in intervention_browse.get('meshes', []) if m.get('term')]

    # Extract has results
    has_results = api_response.get('hasResults', False)

    # Extract oversight
    oversight = protocol.get('oversightModule', {})
    oversight_has_dmc = oversight.get('oversightHasDmc')
    is_fda_regulated_drug = oversight.get('isFdaRegulatedDrug')
    is_fda_regulated_device = oversight.get('isFdaRegulatedDevice')

    # Extract references/citations. Only keep present values so the reported
    # list lengths reflect real entries (consistent with every other list
    # extraction above — the original appended None for missing fields).
    references_module = protocol.get('referencesModule', {})
    references = references_module.get('references', [])
    citations = [ref.get('citation') for ref in references if ref.get('citation')]
    pmids = [ref.get('pmid') for ref in references if ref.get('pmid')]

    # Build TOON formatted output
    toon_lines = []

    # Basic identification
    toon_lines.append(f"nct_id: {format_value(nct_id)}")
    toon_lines.append(f"brief_title: {format_value(brief_title)}")
    toon_lines.append(f"official_title: {format_value(official_title)}")
    toon_lines.append(f"overall_status: {format_value(overall_status)}")

    # Organization & Sponsor
    toon_lines.append(f"lead_sponsor_name: {format_value(lead_sponsor_name)}")
    toon_lines.append(f"lead_sponsor_class: {format_value(lead_sponsor_class)}")
    toon_lines.append(f"collaborator_names[{len(collaborator_names)}]: {format_list(collaborator_names)}")
    toon_lines.append(f"collaborator_classes[{len(collaborator_classes)}]: {format_list(collaborator_classes)}")
    toon_lines.append(f"org_full_name: {format_value(org_full_name)}")

    # Key personnel
    toon_lines.append(f"overall_official_names[{len(overall_official_names)}]: {format_list(overall_official_names)}")
    toon_lines.append(f"overall_official_affiliations[{len(overall_official_affiliations)}]: {format_list(overall_official_affiliations)}")
    toon_lines.append(f"overall_official_roles[{len(overall_official_roles)}]: {format_list(overall_official_roles)}")
    toon_lines.append(f"responsible_party_investigator_full_name: {format_value(responsible_party_investigator_full_name)}")
    toon_lines.append(f"responsible_party_investigator_affiliation: {format_value(responsible_party_investigator_affiliation)}")
    toon_lines.append(f"num_collaborators: {format_value(num_collaborators)}")

    # Scientific focus
    toon_lines.append(f"conditions[{len(conditions)}]: {format_list(conditions)}")
    toon_lines.append(f"intervention_names[{len(intervention_names)}]: {format_list(intervention_names)}")
    toon_lines.append(f"intervention_types[{len(intervention_types)}]: {format_list(intervention_types)}")
    toon_lines.append(f"phases[{len(phases)}]: {format_list(phases)}")
    toon_lines.append(f"primary_outcome_measures[{len(primary_outcome_measures)}]: {format_list(primary_outcome_measures)}")

    # Study scope & capacity
    toon_lines.append(f"enrollment_count: {format_value(enrollment_count)}")
    toon_lines.append(f"study_type: {format_value(study_type)}")
    toon_lines.append(f"num_locations: {format_value(num_locations)}")
    toon_lines.append(f"location_facilities[{len(location_facilities)}]: {format_list(location_facilities)}")
    toon_lines.append(f"location_cities[{len(location_cities)}]: {format_list(location_cities)}")
    toon_lines.append(f"location_states[{len(location_states)}]: {format_list(location_states)}")
    toon_lines.append(f"location_countries[{len(location_countries)}]: {format_list(location_countries)}")

    # Experience & track record
    toon_lines.append(f"study_first_post_date: {format_value(study_first_post_date)}")
    toon_lines.append(f"completion_date: {format_value(completion_date)}")
    toon_lines.append(f"has_results: {format_value(has_results)}")
    toon_lines.append(f"num_collaborators_plus_lead: {format_value(num_collaborators_plus_lead)}")

    # Therapeutic area expertise
    toon_lines.append(f"condition_mesh_terms[{len(condition_mesh_terms)}]: {format_list(condition_mesh_terms)}")
    toon_lines.append(f"intervention_mesh_terms[{len(intervention_mesh_terms)}]: {format_list(intervention_mesh_terms)}")
    toon_lines.append(f"primary_purpose: {format_value(primary_purpose)}")

    # Current activity status
    toon_lines.append(f"last_update_post_date: {format_value(last_update_post_date)}")
    toon_lines.append(f"recruitment_status: {format_value(recruitment_status)}")
    toon_lines.append(f"start_date: {format_value(start_date)}")
    toon_lines.append(f"primary_completion_date: {format_value(primary_completion_date)}")

    # Secondary fields
    toon_lines.append(f"oversight_has_dmc: {format_value(oversight_has_dmc)}")
    toon_lines.append(f"is_fda_regulated_drug: {format_value(is_fda_regulated_drug)}")
    toon_lines.append(f"is_fda_regulated_device: {format_value(is_fda_regulated_device)}")
    toon_lines.append(f"location_statuses[{len(location_statuses)}]: {format_list(location_statuses)}")

    # Additional fields
    toon_lines.append(f"citations[{len(citations)}]: {format_list(citations)}")
    toon_lines.append(f"pmids[{len(pmids)}]: {format_list(pmids)}")

    # Geopoints (structured data - format as array of objects)
    if geopoints:
        geo_keys = set()
        for gp in geopoints:
            if gp:
                geo_keys.update(gp.keys())

        if geo_keys:
            geo_keys_sorted = sorted(geo_keys)
            toon_lines.append(f"geopoints[{len(geopoints)}]{{{','.join(geo_keys_sorted)}}}:")
            for gp in geopoints:
                if gp:
                    values = [format_value(gp.get(k)) for k in geo_keys_sorted]
                    toon_lines.append(f"  {','.join(values)}")
                else:
                    toon_lines.append(f"  {','.join(['' for _ in geo_keys_sorted])}")
    else:
        toon_lines.append(f"geopoints[0]:")

    return '\n'.join(toon_lines)
tool_create_FAISS_vector.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pypdf import PdfReader
2
+ import requests
3
+ from io import BytesIO
4
+ import serpapi
5
+ import os
6
+ from dotenv import load_dotenv
7
+ load_dotenv()
8
+
9
+ from langchain_core.documents import Document as LangchainDocument
10
+ from metapub import FindIt
11
+ import requests
12
+ import xml.etree.ElementTree as ET
13
+
14
+ from ftplib import FTP
15
+ from urllib.parse import urlparse
16
+ from io import BytesIO
17
+
18
+ from langchain_community.retrievers import ArxivRetriever
19
+
20
+ import arxiv
21
+ import requests
22
+ from io import BytesIO
23
+ from pypdf import PdfReader
24
+ import re
25
+
26
+ from langchain_community.vectorstores.utils import DistanceStrategy
27
+ from langchain_community.embeddings import HuggingFaceEmbeddings
28
+ from transformers import AutoTokenizer
29
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
30
+ from tqdm import tqdm
31
+
32
+ import re
33
+ from typing import List, Dict, Tuple
34
+
35
+
36
+
37
def process_ref(extr_ref: tuple[str, str]) -> str:
    """Fetch the full text for one classified reference, best-effort.

    Args:
        extr_ref: A (reference_value, reference_type) pair as produced by
            ReferenceExtractor, with type one of 'arxiv', 'pmid', 'doi',
            'pmcid'.

    Returns:
        The retrieved text from the first tool that succeeds, or None when
        every tool fails or the reference type is unknown.
    """
    ref_value, ref_type = extr_ref

    # Ordered fallback chain per reference type. The pmcid branch now uses the
    # same try/skip behavior as the others (it previously let exceptions
    # propagate while every other branch swallowed them).
    if ref_type == "arxiv":
        retrievers = [get_paper_from_arxiv_id, get_paper_from_arxiv_id_langchain]
    elif ref_type == "pmid":
        retrievers = [get_paper_from_pmid, parse_pdf_from_pubmed_pmid]
    elif ref_type == "doi":
        retrievers = [download_paper_from_doi, get_pdf_content_serpapi]
    elif ref_type == "pmcid":
        retrievers = [get_paper_from_pmid]
    else:
        return None

    for retriever in retrievers:
        try:
            return retriever(ref_value)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are no longer swallowed; any retrieval failure tries the next tool.
            continue
    return None
58
+
59
+
60
class ReferenceExtractor:
    """Extract and classify references (DOI / PMID / arXiv / PMCID) from LLM outputs."""

    # Regex patterns for identification (substring matches, hence the \b anchors).
    DOI_PATTERN = r"10\.\d{4,9}/[-._;()/:A-Za-z0-9]+"
    DOI_LOOSE = r"10\.\d{4,9}/[A-Za-z0-9.\-_/]+"
    PMID_PATTERN = r"\b\d{7,8}\b"
    ARXIV_NEW = r"\b\d{4}\.\d{4,5}(?:v\d+)?\b"   # new-style id, e.g. 2304.07814v2
    ARXIV_OLD = r"\b[a-z\-]+/\d{7}\b"            # old-style id, e.g. hep-th/9901001
    PMCID_PATTERN = r"\bPMC\d+\b"

    def __init__(self):
        """Compile the identification patterns once, keyed by reference type."""
        self.patterns = {
            'doi': re.compile(self.DOI_PATTERN, re.IGNORECASE),
            'pmid': re.compile(self.PMID_PATTERN),
            'arxiv': re.compile(f"({self.ARXIV_NEW})|({self.ARXIV_OLD})", re.IGNORECASE),
            'pmcid': re.compile(self.PMCID_PATTERN, re.IGNORECASE)
        }

    def extract_references(self, text: str) -> List[Tuple[str, str]]:
        """
        Extract all references from text and classify them.

        A list-like input ("a,b,c" or '["a","b"]') is split item-by-item;
        otherwise the whole text is scanned with the compiled patterns.

        Args:
            text: Input string that may contain references in various formats

        Returns:
            List of tuples: (reference_value, reference_type), deduplicated,
            in order of first appearance.
        """
        references = []
        seen = set()

        # First, try to parse as a list-like string
        list_refs = self._extract_from_list_format(text)
        if list_refs:
            for ref in list_refs:
                ref_type = self._classify_single_ref(ref)
                if ref not in seen:
                    references.append((ref, ref_type))
                    seen.add(ref)
            return references

        # If not a list format, extract using regex patterns.
        # NOTE(review): pattern order matters — a digit run inside a DOI can
        # also satisfy PMID_PATTERN and be emitted as a separate (pmid) hit.
        for ref_type, pattern in self.patterns.items():
            matches = pattern.finditer(text)
            for match in matches:
                ref_value = match.group(0).strip()
                if ref_value not in seen:
                    references.append((ref_value, ref_type))
                    seen.add(ref_value)

        return references

    def _extract_from_list_format(self, text: str) -> List[str]:
        """
        Extract references from list-like formats.
        Handles: "id1,id2,id3" and '["id1","id2"]' and "['id1', 'id2']"

        Returns an empty list when the text does not look like a simple list.
        """
        text = text.strip()

        # Try parsing as Python list string
        if text.startswith('[') and text.endswith(']'):
            try:
                # Remove brackets and quotes, split by comma
                cleaned = text[1:-1]
                # Handle both single and double quotes
                items = re.findall(r'["\']([^"\']+)["\']', cleaned)
                if items:
                    return [item.strip() for item in items]
            except Exception:  # narrowed from bare except:
                pass

        # Try comma-separated format (no brackets)
        if ',' in text and not any(char in text for char in ['\n', '(', ')']):
            # Check if it looks like a simple list (short, at least one comma)
            if text.count(',') >= 1 and len(text) < 200:
                items = [item.strip().strip('"\'') for item in text.split(',')]
                # Filter out empty strings
                return [item for item in items if item]

        return []

    def _classify_single_ref(self, ref: str) -> str:
        """Classify a single extracted reference string.

        Returns one of "doi", "pmcid", "arxiv", "pmid" or "unknown";
        checks run in that priority order with fully-anchored patterns.
        """
        ref = ref.strip().strip('"\'')

        if re.match(r"^10\.\d{4,9}/[A-Za-z0-9.\-_/:()]+$", ref, re.IGNORECASE):
            return "doi"

        if re.match(r"^PMC\d+$", ref, re.IGNORECASE):
            return "pmcid"

        if re.match(r"^\d{4}\.\d{4,5}(?:v\d+)?$", ref):
            return "arxiv"

        if re.match(r"^[a-z\-]+/\d{7}$", ref, re.IGNORECASE):
            return "arxiv"

        if re.match(r"^\d{7,8}$", ref):
            return "pmid"

        return "unknown"
164
+
165
+
166
def download_paper_from_doi(doi):
    """
    Attempt to download a paper's full text from its DOI, trying several
    sources in order: Unpaywall (legal OA lookup), arXiv, then Sci-Hub.

    Args:
        doi: A DOI string, optionally prefixed with an https://doi.org/ resolver.

    Returns:
        The extracted PDF text on success, or None when every source fails.
    """
    # Normalise: strip a resolver prefix if the caller passed a full URL.
    doi = doi.replace('https://doi.org/', '').replace('http://doi.org/', '')

    # Method 1: Try Unpaywall API (free, legal access)
    try:
        # TODO(review): Unpaywall requires a real contact email; replace the placeholder.
        unpaywall_url = f"https://api.unpaywall.org/v2/{doi}?email=your@email.com"
        response = requests.get(unpaywall_url, timeout=10)
        if response.status_code == 200:
            data = response.json()
            if data.get('best_oa_location') and data['best_oa_location'].get('url_for_pdf'):
                pdf_url = data['best_oa_location']['url_for_pdf']
                text = download_pdf_from_url(pdf_url)
                print(f"Found PDF via Unpaywall: {pdf_url}")
                return text
    except Exception as e:
        print(f"Unpaywall failed: {e}")

    # Method 2: Try arXiv if it's an arXiv paper.
    # NOTE(review): the startswith('2') heuristic also matches many non-arXiv DOIs.
    if 'arxiv' in doi.lower() or doi.startswith('2'):
        try:
            # Extract arXiv ID (last path segment of the DOI, if any)
            arxiv_id = doi.split('/')[-1] if '/' in doi else doi
            arxiv_pdf_url = f"https://arxiv.org/pdf/{arxiv_id}.pdf"
            text = download_pdf_from_url(arxiv_pdf_url)
            print(f"Trying arXiv: {arxiv_pdf_url}")
            return text
        except Exception as e:
            print(f"arXiv failed: {e}")

    # Method 3: Try Sci-Hub (use with caution - legality varies by jurisdiction)
    try:
        scihub_url = f"https://sci-hub.se/{doi}"
        print(f"Trying Sci-Hub: {scihub_url}")
        headers = {'User-Agent': 'Mozilla/5.0'}
        response = requests.get(scihub_url, headers=headers, timeout=15)

        if response.status_code == 200:
            # Look for the first PDF link in the returned HTML
            pdf_match = re.search(r'(https?://[^"]+\.pdf[^"]*)', response.text)
            if pdf_match:
                pdf_url = pdf_match.group(1)
                text = download_pdf_from_url(pdf_url)
                print(f"got {doi} by chance")
                return text
    except Exception as e:
        print(f"Sci-Hub failed: {e}")

    # Explicit instead of implicit: every source failed.
    return None
216
+
217
+
218
+
219
def download_pdf_from_url(url):
    """Download a PDF from *url* and return its extracted text.

    Raises on HTTP errors and when the response is not a PDF.
    """
    resp = requests.get(
        url,
        headers={
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        },
        timeout=30,
    )
    resp.raise_for_status()

    # Accept the response if either the header or the magic bytes say "PDF".
    content_type = resp.headers.get('content-type', '').lower()
    looks_like_pdf = 'pdf' in content_type or resp.content.startswith(b'%PDF')
    if not looks_like_pdf:
        raise Exception(f"URL did not return a PDF (got {content_type})")

    pdf = PdfReader(BytesIO(resp.content))
    return "".join(page.extract_text() or "" for page in pdf.pages)
238
+
239
+
240
def get_paper_from_arxiv_id(doi: str):
    """
    Retrieve paper from arXiv using its arXiv ID.

    Searches arXiv for the given identifier, takes the first hit and
    returns the parsed text of its PDF.
    """
    first_hit = next(
        arxiv.Client().results(arxiv.Search(query=doi, max_results=1))
    )
    return parse_pdf_file(first_hit.pdf_url)
250
+
251
def get_paper_from_arxiv_id_langchain(arxiv_id: str):
    """
    Retrieve a paper from arXiv via LangChain's ArxivRetriever.

    Args:
        arxiv_id: The arXiv identifier to look up (e.g. "2304.07814").

    Returns:
        A list of LangChain Documents containing the paper's full text.
    """
    retriever = ArxivRetriever(
        load_max_docs=2,
        get_full_documents=True,
    )
    # Bug fix: the original queried a hard-coded id ("2304.07814") and
    # ignored the arxiv_id argument entirely.
    return retriever.invoke(arxiv_id)
262
+
263
+
264
def parse_pdf_file(path: str) -> str:
    """
    Extract text from a PDF given a local path or an http(s)/ftp URL.

    Args:
        path: Local file path, http(s) URL, or ftp URL of a PDF.

    Returns:
        The concatenated text of every page (pages with no extractable
        text contribute "").
    """
    if path.startswith("ftp://"):
        # Bug fix: requests does not support the FTP scheme, so ftp URLs
        # must go through the dedicated FTP downloader (returns a BytesIO).
        reader = PdfReader(download_pdf_via_ftp(path))
    elif path.startswith("http://") or path.startswith("https://"):
        response = requests.get(path)
        response.raise_for_status()  # Ensure download succeeded
        reader = PdfReader(BytesIO(response.content))
    else:
        reader = PdfReader(path)

    text = ""
    for page in reader.pages:
        text += page.extract_text() or ""

    return text
278
+
279
+
280
def get_pdf_content_serpapi(doi: str) -> str:
    """
    Get the link to the paper from its DOI using SerpAPI Google Scholar
    search, then download and parse the PDF behind the top result.
    """
    scholar = serpapi.Client(api_key=os.getenv("SERPAPI_API_KEY"))
    hits = scholar.search({
        'engine': 'google_scholar',
        'q': doi,
    })

    # Take the first organic result's link and hand it to the PDF parser.
    top_link = hits["organic_results"][0]["link"]
    return parse_pdf_file(top_link)
293
+
294
+
295
+
296
+
297
def get_paper_from_pmid(pmid: str):
    """Resolve a PMID to a full-text URL via metapub's FindIt and parse the PDF.

    Returns the parsed text, or None (after printing FindIt's reason) when
    no full-text URL is available.
    """
    located = FindIt(pmid)
    if not located.url:
        # FindIt exposes why no URL could be resolved (paywall, embargo, ...).
        print(located.reason)
        return None
    return parse_pdf_file(located.url)
304
+
305
+
306
+
307
+
308
def download_pdf_via_ftp(url: str) -> BytesIO:
    """
    Download a file from an FTP URL.

    Args:
        url: An ftp:// URL pointing at the file to fetch.

    Returns:
        A BytesIO rewound to position 0 with the downloaded content
        (ready to hand to e.g. pypdf's PdfReader). The original ``-> bytes``
        annotation was wrong: a file-like object is returned, not bytes.
    """
    parsed_url = urlparse(url)

    file_buffer = BytesIO()

    # Anonymous login (ftplib's login() default when no credentials are given).
    with FTP(parsed_url.netloc) as ftp:
        ftp.login()
        ftp.retrbinary(f'RETR {parsed_url.path}', file_buffer.write)

    # Removed a dead `file_buffer.getvalue()` call whose result was discarded.
    file_buffer.seek(0)  # rewind so callers read from the beginning
    return file_buffer
325
+
326
+
327
def parse_pdf_from_pubmed_pmid(pmid: str) -> str:
    """
    Download and parse a PDF from PubMed using its PMID.

    Queries the NCBI PMC Open Access web service (oa.fcgi) for a PDF link,
    downloads the file over FTP, and extracts its text.

    Args:
        pmid: Identifier passed as the ``id`` query parameter.
            NOTE(review): oa.fcgi is documented around PMCIDs — confirm that
            a bare PMID is actually accepted here.

    Returns:
        The extracted text on success; None when the service response is not
        parseable XML. Other failures (no pdf link element, FTP errors)
        propagate as exceptions, which the callers' retry loops rely on.
    """
    url = f"https://www.ncbi.nlm.nih.gov/pmc/utils/oa/oa.fcgi?id={pmid}"
    response = requests.get(url)
    cleaned_string = response.content.decode('utf-8').strip()
    try:
        root = ET.fromstring(cleaned_string)
        # First <link format='pdf'> element; find() yields None if absent,
        # making the .get() below raise AttributeError (caught by callers).
        pdf_link_element = root.find(".//link[@format='pdf']")
        ftp_url = pdf_link_element.get('href')
        file_byte = download_pdf_via_ftp(ftp_url)  # BytesIO, file-like for PdfReader

        reader = PdfReader(file_byte)
        text = ""
        for page in reader.pages:
            text += page.extract_text() or ""
        print(f"got {pmid} via ftp download")
        return text
    except ET.ParseError as e:
        # Service returned non-XML (e.g. an error page) — treat as not found.
        pass
348
+
349
def safe_parse_of_ref_list(refs : list[str]) -> list[str]:
    """Placeholder for validating/normalising a list of reference strings.

    NOTE(review): unimplemented stub — it ignores ``refs`` and returns
    ``None`` despite the ``list[str]`` annotation. Implement or remove
    before anything calls it.
    """

    # Bare return -> None; see the note above.
    return
353
+
354
+
355
+
356
+
357
def classify_ref(ref: str) -> str:
    """Classify a reference string as "doi", "pmid", "arxiv" or "unknown".

    Patterns are checked in priority order: DOI (strict then loose, the
    loose form supporting ids like 'NEJMoa2307100'), then PMID, then arXiv
    (new 'NNNN.NNNNN[vN]' style and old 'hep-th/NNNNNNN' style).
    """
    candidate = ref.strip()

    # kind -> list of (pattern, flags); insertion order preserves priority.
    pattern_map = {
        "doi": [
            (r"10\.\d{4,9}/[-._;()/:A-Za-z0-9]+", re.IGNORECASE),
            (r"^10\.\d{4,9}/?[A-Za-z0-9.\-_/]+$", re.IGNORECASE),
        ],
        "pmid": [
            (r"^\d{7,8}$", 0),
        ],
        "arxiv": [
            (r"^\d{4}\.\d{4,5}(v\d+)?$", re.IGNORECASE),
            (r"^[a-z\-]+/\d{7}$", re.IGNORECASE),
        ],
    }

    for kind, patterns in pattern_map.items():
        if any(re.match(pattern, candidate, flags) for pattern, flags in patterns):
            return kind
    return "unknown"
372
+
373
+
374
def process_ref(ref: str):
    """Try to download the full text for a single reference string.

    The reference is classified first (doi / pmid / arxiv), and each kind
    has two download strategies attempted in order ("we try twice").

    Args:
        ref: A raw reference identifier (DOI, PMID or arXiv id).

    Returns:
        The paper text on success, otherwise None (after printing a skip
        message — note the message also fires for valid refs whose
        downloads all failed, preserving the original behavior).

    NOTE(review): this re-defines the earlier ``process_ref(extr_ref)``;
    only this version survives at import time.
    """
    kind = classify_ref(ref)
    # Two strategies per kind, tried in order; first success wins.
    strategies = {
        "doi": [download_paper_from_doi, get_pdf_content_serpapi],
        "pmid": [get_paper_from_pmid, parse_pdf_from_pubmed_pmid],
        "arxiv": [get_paper_from_arxiv_id, get_pdf_content_serpapi],
    }
    for tool in strategies.get(kind, []):
        try:
            return tool(ref)
        except Exception:  # narrowed from bare except: (keeps Ctrl-C working)
            continue

    print(f"Skipping invalid ref: {ref}")
    return None
398
+
399
+
400
+ from langchain_community.vectorstores.utils import DistanceStrategy
401
+ from langchain_community.embeddings import HuggingFaceEmbeddings
402
+ from transformers import AutoTokenizer
403
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
404
+ from tqdm import tqdm
405
+
406
def create_vector_store_from_list_of_doi(refs: list[str], VECTOR_DB_PATH: str) -> str:
    """Build (or extend) a FAISS vector store from a collection of references.

    Args:
        refs: Raw text containing reference identifiers (DOIs, PMIDs,
            arXiv ids, PMCIDs); parsed with ReferenceExtractor.
        VECTOR_DB_PATH: Folder where the FAISS index is loaded from and
            saved to.

    Returns:
        VECTOR_DB_PATH when new documents were indexed, otherwise a message
        saying everything was already in the store.
    """
    from langchain_community.vectorstores import FAISS

    # Embedding model — must match the one the existing store was built with.
    embedding_name = "BAAI/bge-large-en-v1.5"
    embedding_model = HuggingFaceEmbeddings(
        model_name=embedding_name,
        model_kwargs={"device": "mps"},  # NOTE(review): Apple-GPU specific — confirm target hardware
        encode_kwargs={"normalize_embeddings": True},
    )

    try:
        # Load the vector database from the folder
        print(f"try to load vector store from {VECTOR_DB_PATH}")
        KNOWLEDGE_VECTOR_DATABASE = FAISS.load_local(
            VECTOR_DB_PATH,
            embedding_model,
            allow_dangerous_deserialization=True  # required by newer LangChain; only load trusted indexes
        )
        existing_reference = [
            doc.metadata.get("source")
            for doc in KNOWLEDGE_VECTOR_DATABASE.docstore._dict.values()
        ]
        print("vector store loaded")  # typo fix: was "vectro store loaded"
    except Exception as e:
        print("FAISS load error:", e)
        KNOWLEDGE_VECTOR_DATABASE = None
        existing_reference = []
        print("no vector store found, creating a new one...")

    # Fetch documents for references not already present in the store.
    extractor = ReferenceExtractor()
    REFS = extractor.extract_references(refs)
    already_indexed = set(existing_reference)  # hoisted: build the set once, not per iteration
    raw_docs = []
    for ref in tqdm(REFS):
        # ref is an (identifier, kind) tuple. Bug fixes:
        #  * the store's metadata holds the identifier string, so compare
        #    ref[0] (the old tuple-vs-string check never matched and every
        #    paper was re-downloaded);
        #  * the active process_ref expects a string, not the tuple.
        if ref[0] not in already_indexed:
            text = process_ref(ref[0])
            if text:
                raw_docs.append(
                    LangchainDocument(page_content=text, metadata={'source': ref[0]})
                )

    if REFS:  # guard: the original divided by zero on empty input
        recover_yield = f" *** -> {round(100 * len(raw_docs) / len(REFS))}% papers downloaded"
        print(recover_yield)

    # Split texts into chunks sized in tokens of the embedding model.
    text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        AutoTokenizer.from_pretrained(embedding_name),
        chunk_size=3000,
        chunk_overlap=int(3000 / 10),
        add_start_index=True,
        strip_whitespace=True,
        separators=["."]  # a list of separators is expected; a bare "." only worked by accident
    )

    if raw_docs:
        docs_processed = text_splitter.split_documents(raw_docs)
        print("creating the vector store...")

        # Create the vector store from the freshly downloaded documents.
        NEW_KNOWLEDGE_VECTOR_DATABASE = FAISS.from_documents(
            docs_processed, embedding_model, distance_strategy=DistanceStrategy.COSINE
        )

        if KNOWLEDGE_VECTOR_DATABASE:
            print("merge vector store")
            KNOWLEDGE_VECTOR_DATABASE.merge_from(NEW_KNOWLEDGE_VECTOR_DATABASE)
            KNOWLEDGE_VECTOR_DATABASE.save_local(VECTOR_DB_PATH)
        else:
            NEW_KNOWLEDGE_VECTOR_DATABASE.save_local(VECTOR_DB_PATH)

        return VECTOR_DB_PATH

    return f"all the data already in vector store {VECTOR_DB_PATH}"
tool_describe_figure.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ from openai import OpenAI
4
+ # The OpenAI library handles the API key and base URL automatically
5
+ # after instantiation.
6
+
7
def thorough_picture_description(figure: str) -> str:
    """
    Generates a thorough description for a given image URL using
    the Nebius Token Factory endpoint.

    Args:
        figure: The URL of the image to describe.

    Returns:
        The generated text description of the image, or an error string
        when client construction or the API call fails.
    """
    try:
        vision_client = OpenAI(
            base_url="https://api.tokenfactory.nebius.com/v1/",
            api_key=os.environ.get("NEBIUS_API_KEY"),
        )
    except Exception as e:
        return f"Error initializing OpenAI client: {e}"

    # Single multimodal user turn: the instruction plus the image URL.
    user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": "Provide a very detailed, thorough, and descriptive analysis of this image."},
            {"type": "image_url", "image_url": {"url": figure}},
        ],
    }

    try:
        reply = vision_client.chat.completions.create(
            model="gemini-2.5-flash",
            messages=[user_turn],
            max_tokens=2048,
        )

        first_choice = reply.choices[0] if reply.choices else None
        if first_choice is not None and first_choice.message.content:
            return first_choice.message.content
        return "Could not retrieve a description from the API."

    except Exception as e:
        return f"An error occurred during the API call: {e}"
tool_fetch_documents_DOI.py ADDED
File without changes
tool_query_FAISS_vector.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+
4
+
5
+
6
def query_vector_store(query: str, store_name: str, top_k: int = 5) -> dict:
    """
    Query a specific vector store to retrieve top_k documents related to the
    user question. Each document carries metadata identifying its source,
    which must be reported clearly.

    Args:
        query (str): User's question
        store_name (str): Which vector store to search
        top_k (int): Number of chunks to retrieve

    Returns:
        dict: Retrieved context, sources, store_name — or an "error" key
        when the store does not exist yet.
    """
    from langchain_community.vectorstores import FAISS
    # Bug fix: HuggingFaceEmbeddings was never imported -> NameError at runtime.
    from langchain_community.embeddings import HuggingFaceEmbeddings

    stores_root = "./vector_stores"
    vector_stores = os.listdir(stores_root) if os.path.isdir(stores_root) else []
    if store_name not in vector_stores:
        return {"error": f"Vector store '{store_name}' not found, you must create it first with tool create faiss vector"}
    # Bug fix: f"./vector_stores{store_name}" was missing the path separator.
    store_path = os.path.join(stores_root, store_name)

    embedding_model = HuggingFaceEmbeddings(
        model_name="BAAI/bge-large-en-v1.5",
        model_kwargs={"device": "mps"},  # NOTE(review): Apple-GPU specific — confirm target hardware
        encode_kwargs={"normalize_embeddings": True},
    )

    vector_store = FAISS.load_local(
        store_path,
        embedding_model,
        allow_dangerous_deserialization=True  # only load indexes you created yourself
    )

    # Bug fix: similarity_search returns Documents (not subscriptable dicts)
    # and carries no scores; use the *_with_score variant + attribute access.
    results = vector_store.similarity_search_with_score(query, k=top_k)

    context = "\n\n".join(doc.page_content for doc, _score in results)
    sources = [
        {"ids": doc.metadata.get("source"), "relevance": float(score)}
        for doc, score in results
    ]

    return {
        "context": context,
        "sources": sources,
        "store_name": store_name
    }
+ }