Spaces:
Sleeping
Sleeping
added tools and agent
Browse files- agent.py +197 -0
- app.py +258 -0
- tool_TOON_formater.py +237 -0
- tool_create_FAISS_vector.py +473 -0
- tool_describe_figure.py +57 -0
- tool_fetch_documents_DOI.py +0 -0
- tool_query_FAISS_vector.py +52 -0
agent.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import requests
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
from markdownify import markdownify
|
| 6 |
+
from requests.exceptions import RequestException
|
| 7 |
+
from smolagents import (
|
| 8 |
+
LiteLLMModel,
|
| 9 |
+
CodeAgent,
|
| 10 |
+
ToolCallingAgent,
|
| 11 |
+
InferenceClientModel,
|
| 12 |
+
WebSearchTool,
|
| 13 |
+
tool,
|
| 14 |
+
FinalAnswerTool,
|
| 15 |
+
WikipediaSearchTool,
|
| 16 |
+
VisitWebpageTool,
|
| 17 |
+
DuckDuckGoSearchTool
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
load_dotenv()
|
| 21 |
+
|
| 22 |
+
from langfuse import get_client
|
| 23 |
+
langfuse = get_client()
|
| 24 |
+
if langfuse.auth_check():
|
| 25 |
+
print("Langfuse client is authenticated and ready!")
|
| 26 |
+
else:
|
| 27 |
+
print("Authentication failed. Please check your credentials and host.")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
|
| 31 |
+
SmolagentsInstrumentor().instrument()
|
| 32 |
+
|
| 33 |
+
model = LiteLLMModel(
|
| 34 |
+
model_id="openai/Qwen/Qwen3-Coder-480B-A35B-Instruct",
|
| 35 |
+
api_key=os.environ.get("NEBIUS_API_KEY"),
|
| 36 |
+
api_base="https://api.tokenfactory.nebius.com/v1/"
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
from tool_clinical_trial import ClinicalTrialsSearchTool
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@tool
def search_pubmed(topic: str, author: str) -> list[str]:
    """
    Searches the PubMed database for articles related to a specific topic.

    Args:
        topic: The topic or keywords to search for (e.g., "CRISPR gene editing").
        author: The name of the author to search for (e.g., "Albert Einstein").

    Returns:
        A list of PubMed IDs (strings) for up to 1000 matching articles.

    Raises:
        requests.exceptions.HTTPError: If the API request fails.
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"

    # Build the query from whichever arguments are non-empty; the author term
    # uses the PubMed [Author] field tag.
    terms = []
    if topic:
        terms.append(topic)
    if author:
        terms.append(f"{author}[Author]")

    query = " AND ".join(terms)
    params = {
        "db": "pubmed",
        "term": query,
        "retmode": "json",
        "retmax": 1000  # must stay in sync with the docstring above
    }
    # A timeout prevents the agent from hanging forever on a stalled NCBI call.
    response = requests.get(base_url, params=params, timeout=30)
    response.raise_for_status()
    data = response.json()

    return data["esearchresult"]["idlist"]
|
| 77 |
+
|
| 78 |
+
@tool
def parse_pdf(pdf_path: str) -> list[str]:
    """
    Reads a PDF file from a specified path and extracts the text content
    from every page.

    Args:
        pdf_path: The local file path (string) to the PDF document to be parsed.
            **NOTE**: In a remote agent environment, this path must be
            accessible by the executing process (e.g., a path to an
            uploaded file).

    Returns:
        A list of strings, where each string is the extracted text content
        from a single page of the PDF.
    """
    # Imported lazily so pypdf is only required when the tool actually runs.
    from pypdf import PdfReader

    reader = PdfReader(pdf_path)
    # reader.pages is directly iterable; no need for an index loop.
    return [page.extract_text() for page in reader.pages]
|
| 103 |
+
|
| 104 |
+
# @tool
|
| 105 |
+
# def make_rag_ressource(paths :list(str)) -> list(str):
|
| 106 |
+
# """
|
| 107 |
+
# Use extracted text to build a RAG tool and retreive documents to use to answer request
|
| 108 |
+
|
| 109 |
+
# Args:
|
| 110 |
+
# paths: The list of path where the file are stored
|
| 111 |
+
|
| 112 |
+
# Returns:
|
| 113 |
+
# A list of strings, where each string is the extracted text content
|
| 114 |
+
# from the retreiver
|
| 115 |
+
# """
|
| 116 |
+
|
| 117 |
+
# pdf_files=[]
|
| 118 |
+
# for path in paths:
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# pdf_documents = []
|
| 122 |
+
# for pdf_file in pdf_files:
|
| 123 |
+
# loader = PyPDFLoader(pdf_file)
|
| 124 |
+
# pdf_documents.extend(loader.load())
|
| 125 |
+
# embeddings_model = OpenAIEmbeddings()
|
| 126 |
+
# pdf_texts = [doc.page_content for doc in pdf_documents]
|
| 127 |
+
# return ""
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# # Initialize the model
|
| 131 |
+
# model = InferenceClientModel(
|
| 132 |
+
# model_id="Qwen/Qwen3-Coder-30B-A3B-Instruct",
|
| 133 |
+
# provider="nebius"
|
| 134 |
+
# )
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# Create clinical trial search agent
|
| 139 |
+
|
| 140 |
+
# Agent specialised in clinical-study retrieval.  Its tools now match the
# capabilities its description promises (trials + PubMed + PDF parsing);
# the previous description also contained a copy-pasted paragraph from
# search_online_info (Wikipedia/DuckDuckGo), which misguided routing.
clinical_agent = CodeAgent(
    name="clinical_agent",
    description=(
        "Retrieve and parse clinical study data for a given disease. "
        "Use ClinicalTrialsSearchTool for trials, search_pubmed for authors, and parse_pdf for full-text analysis. "
        "Return structured tables or summaries as requested."
    ),
    tools=[ClinicalTrialsSearchTool(), search_pubmed, parse_pdf],
    additional_authorized_imports=["time", "numpy", "pandas"],
    # executor_type="blaxel", #executor_type="modal",
    use_structured_outputs_internally=True,
    return_full_result=True,
    planning_interval=3,  # V3 add structure
    model=model,
    max_steps=6,
    verbosity_level=2
)
|
| 160 |
+
|
| 161 |
+
# Web-research agent: general or recent information from online sources.
search_online_info = CodeAgent(
    name="search_online_info",
    description=(
        "Gather general or recent information from online sources. "
        "Use Wikipedia for overviews, DuckDuckGo for recent data, and VisitWebpageTool for specific URLs. "
        "Return structured summaries with sources."
    ),
    # NOTE(review): search_pubmed/parse_pdf are attached here although the
    # description only mentions web sources — confirm this routing is intended.
    tools=[WikipediaSearchTool(),VisitWebpageTool(max_output_length=10000),DuckDuckGoSearchTool(max_results=5),search_pubmed,parse_pdf],
    additional_authorized_imports=["time", "numpy", "pandas"],
    # use_structured_outputs_internally=True,
    # executor_type="modal",
    planning_interval=2,  # re-plan every 2 steps
    model=model,
    max_steps=4,
    verbosity_level=2
)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# Orchestrator: delegates to clinical_agent and search_online_info, then
# assembles and validates the final answer via FinalAnswerTool.
manager_agent = CodeAgent(
    name="manager_agent",
    description=(
        "Most important task is to provide a complete answer to user questions based on clinical trial data and online information. "
        "Orchestrate workflow between clinical and online agents. "
        "Validate outputs, resolve conflicts, and ensure the final answer is complete and accurate."
    ),
    tools=[FinalAnswerTool()],
    model=model,
    managed_agents=[clinical_agent,search_online_info],
    # executor_type="modal",
    provide_run_summary=True,  # include sub-agent run summaries in the manager context
    additional_authorized_imports=["time", "numpy", "pandas"],
    use_structured_outputs_internally=True,
    verbosity_level=2,
    planning_interval=3,
    max_steps=6,
)
|
app.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from agent import manager_agent
import gradio as gr
from smolagents import stream_to_gradio
import smolagents
import json
import re
import ast

agent = manager_agent


import logging
# NOTE(review): logging.basicConfig is never called, so this INFO record is
# dropped by the default WARNING-level root handler — confirm whether the
# message is expected to appear anywhere.
logging.info("Processing request")


# --- PATCH OpenTelemetry detach bug (generator-safe) ---
# Gradio streams this app's generators; when a stream is closed mid-run,
# OpenTelemetry may try to detach a context token created in a different
# context and raises.  Wrap detach so those boundary errors are swallowed.
from opentelemetry.context import _RUNTIME_CONTEXT
_orig_detach = _RUNTIME_CONTEXT.detach
def _safe_detach(token):
    # Delegate to the original detach; silence any failure.
    try:
        _orig_detach(token)
    except Exception:
        # Suppress context-var boundary errors caused by streamed generators
        pass
_RUNTIME_CONTEXT.detach = _safe_detach
# --- PATCH OpenTelemetry detach bug (generator-safe) ---
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def answer_question(question):
    """Use a smolagent CodeAgent with tools to answer a question.
    The agent streams its thought process (planning steps) and the final answer.
    Args:
        question (str): The question to be answered by the agent.
    Yields:
        tuple(str, str): A tuple containing the current 'thoughts' (planning/intermediate steps)
                         and the current 'final_answer'.
    """
    thoughts = ""
    final_answer = ""
    n_tokens =0  # NOTE(review): never updated or read below — dead variable?
    try:
        logging.info(f"Received question: {question}")
        # Stream each intermediate step from the manager agent and surface it
        # to the UI as an updated (thoughts, final_answer) pair.
        for st in manager_agent.run(question,stream=True,return_full_result=True):
            if isinstance(st, smolagents.memory.PlanningStep):
                # Keep only the portion of the plan after the "## 2." heading.
                plan = st.model_output_message.content.split("## 2.")[-1]
                for m in plan.split("\n"):
                    thoughts += "\n" + m
                    yield thoughts, final_answer

            elif isinstance(st, smolagents.memory.ToolCall):
                # Show which tool was invoked and with which arguments.
                thoughts += f"\nTool called: {st.dict()['function']['name']}\n"
                for m in st.dict()['function']['arguments'].split("\n"):
                    thoughts += "\n" + m
                    yield thoughts, final_answer

            elif isinstance(st, smolagents.agents.ActionOutput):
                if st.output:
                    thoughts += "\n" + str(st.output) + "\n"
                    yield thoughts, final_answer
                else:
                    thoughts += "\n****************\nNo output from action.\n****************\n"
                    yield thoughts, final_answer

            elif isinstance(st, smolagents.memory.ActionStep):

                for m in st.model_output_message.content.split("\n"):
                    thoughts += m
                    yield thoughts, final_answer
                # Per-step footer: token usage and timing diagnostics.
                thoughts += "\n********** End fo Step " + str(st.step_number) + " : *********\n " + str(st.token_usage) + "\nStep duration" + str(st.timing) + "\n\n"
                yield thoughts, final_answer

            elif isinstance(st, smolagents.memory.FinalAnswerStep):
                # Terminal step: publish the agent's answer.
                final_answer = st.output
                yield thoughts, final_answer
    except GeneratorExit:
        # Raised when the client disconnects mid-stream; exit quietly.
        print("Stream closed cleanly.")
        return "",""
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# def create_rag_files(refs :list[str], VECTOR_DB_PATH:str)-> str:
|
| 83 |
+
# from tool_create_FAISS_vector import create_vector_store_from_list_of_doi
|
| 84 |
+
|
| 85 |
+
# FAISS_VECTOR_PATH = create_vector_store_from_list_of_doi(refs,VECTOR_DB_PATH)
|
| 86 |
+
# return FAISS_VECTOR_PATH
|
| 87 |
+
|
| 88 |
+
def tool_clinical_trial(query_cond:str=None, query_term:str=None,query_lead:str=None,max_results: int = 5000) -> list:
    """
    Search Clinical Trials database for trials with 4 arguments.

    Args:
        query_cond (str): Disease or condition (e.g., 'lung cancer', 'diabetes')
        query_term (str): Other terms (e.g., 'AREA[LastUpdatePostDate]RANGE[2023-01-15,MAX]').
        query_lead (str): Searches the LeadSponsorName
        max_results (int): Number of trials to return (max: 1000)

    Returns:
        list(str): each string being a structured representation of a trial.
    """
    # Imported here because app.py has no module-level "import requests";
    # without this the call below raised NameError at runtime.
    import requests

    from tool_TOON_formater import TOON_formater
    try:
        max_results = int(max_results)
    except (TypeError, ValueError):
        # Gradio may pass an empty string or None; fall back to a sane default.
        max_results = 500

    # The ClinicalTrials.gov v2 API caps pageSize at 1000; larger values are
    # rejected, so clamp to the documented maximum (docstring agrees).
    params = {
        "query.cond": query_cond,
        "query.term": query_term,
        "query.lead": query_lead,
        "pageSize": min(max_results, 1000),
        "format": "json"
    }
    # Drop unused filters so the API does not receive empty parameters.
    params = {k: v for k, v in params.items() if v is not None}
    try:
        response = requests.get(
            "https://clinicaltrials.gov/api/v2/studies",
            params=params,
            timeout=30
        )
        response.raise_for_status()
        studies = response.json().get("studies", [])

        # Convert each raw study record into a compact TOON string.
        return [TOON_formater(study) for study in studies]

    except Exception as e:
        # Surface the failure to the agent/UI as data instead of raising.
        return [f"Error searching clinical trials: {str(e)}"]
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def create_rag(refs :str, VECTOR_DB_PATH:str)-> str:
    """Create a RAG (Retrieval-Augmented Generation) vector store from a list of DOIs.
    Args:
        refs (str): A comma-separated string of DOIs (Digital Object Identifiers).
        VECTOR_DB_PATH (str): The local path where the FAISS vector store should be saved.
    Returns:
        str: The path to the newly created FAISS vector store.
    """
    # Lazy import: FAISS/embedding dependencies are heavy and only needed here.
    from tool_create_FAISS_vector import create_vector_store_from_list_of_doi
    FAISS_VECTOR_PATH = create_vector_store_from_list_of_doi(refs,VECTOR_DB_PATH)
    return FAISS_VECTOR_PATH
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def use_rag(query: str, store_name: str, top_k: int = 5) -> str:
    """Retrieve context from a FAISS vector store based on a query.
    Args:
        query (str): The question or query string to use for retrieval.
        store_name (str): The path to the FAISS vector store to query.
        top_k (int): The number of top-k most relevant context documents to retrieve (default: 5).
    Returns:
        str: A JSON string containing the retrieved context, including the content and source (DOI).
    """
    # Lazy import keeps FAISS dependencies out of module start-up.
    from tool_query_FAISS_vector import query_vector_store

    retrieved = query_vector_store(query, store_name, top_k)
    # Serialise for display in the Gradio textbox / MCP response.
    return json.dumps(retrieved, indent=2)
|
| 162 |
+
|
| 163 |
+
from PIL import Image
|
| 164 |
+
|
| 165 |
+
def describe_figure(figure : Image.Image) -> str:
    """Provide a detailed, thorough description of an image figure.
    Args:
        figure (Image.Image): The image figure object (from PIL) to be described.
    Returns:
        description (str): A detailed textual description of the figure's content.
    """
    # Lazy import: the vision model behind this helper is only loaded on use.
    from tool_describe_figure import thourough_picture_description
    description = thourough_picture_description(figure)
    return description
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# Create neat interface - Question Analyzer as a Blocks component
|
| 179 |
+
# Question-Analyzer tab: free-text question in, streamed reasoning + answer out.
with gr.Blocks() as interface2:
    gr.Markdown("# Question Analyzer")
    gr.Markdown("""Enter a question to analyze. Examples:
- Find the name of the sponsor that did the most studies on Alzheimer's disease in the last 10 years.
- Provide a summary of recent clinical trials on diabetes and list 3 relevant research articles from PubMed.
- What are the scientific paper linked to the clinical study referenced as NCT04516746?
- How many clinical studies on cancer were completed in the last 5 years?
- Find recent phase 3 trials for lung cancer sponsored by Pfizer
""")

    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(
                label="Question",
                placeholder="Enter your question here...",
                lines=3,
            )
            submit_btn = gr.Button("Submit", variant="primary")
            response_output = gr.Textbox(
                label="Final Answer",
                interactive=False,
                lines=8
            )
        with gr.Column():
            thoughts_output = gr.Textbox(
                label="LLM Thoughts/Reasoning",
                interactive=False,
                lines=8
            )


    # NOTE(review): chat_history is created but never wired to any event —
    # confirm whether conversation memory was intended here.
    chat_history = gr.State([])

    # answer_question is a generator: each yielded (thoughts, final_answer)
    # pair updates the two textboxes while the agent runs.
    submit_btn.click(
        fn=answer_question,
        inputs=[question_input],
        outputs=[thoughts_output, response_output],
        queue=True
    )
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# Combine interfaces into a single tabbed interface
|
| 221 |
+
# Combine interfaces into a single tabbed interface.
# TabbedInterface requires one tab title per interface (5 and 5).
demo = gr.TabbedInterface(
    [interface2,
     gr.Interface(
         fn=create_rag,
         inputs=[gr.Textbox("list of references to include in vector store", lines=2, info="(can be DOIs, PMIDs, erxivs, ... and a mix of it)"),
                 gr.Textbox("Name of the vector store", lines=2, placeholder="My_Diabetes_vector")],
         outputs=gr.Textbox("path of the vector store"),
         api_name="create_vector_store_for_rag"),

     gr.Interface(
         fn=use_rag,
         inputs=[gr.Textbox("question that needs context to answer"),
                 gr.Textbox("Name of the vector store to use", placeholder="Diabetes, Sickel_cell_anemia, Prostate_cancer, ..")],
         outputs=gr.Textbox("Answer with Rag"),
         api_name="use_vector_store_to_create_context"),
     gr.Interface(
         fn=tool_clinical_trial,
         inputs=[gr.Textbox("Disease or condition (e.g., 'lung cancer', 'diabetes')"),
                 gr.Textbox("Other terms (e.g., 'AREA[LastUpdatePostDate]RANGE[2023-01-15,MAX]'"),
                 gr.Textbox("Searches the LeadSponsorName"),
                 gr.Textbox("max results")],
         outputs=gr.Textbox("TOON formated response"),
         # BUG FIX: this api_name duplicated the RAG-query interface above,
         # which collides when the app is exposed as an MCP server.
         api_name="query_clinical_trial_database"),
     gr.Interface(
         describe_figure,
         gr.Image(type="pil"),
         gr.Textbox(),
         api_name="figure_description"),
     ],
    # BUG FIX: a missing comma fused the last two tab titles into one string,
    # leaving 4 titles for 5 interfaces (TabbedInterface rejects the mismatch).
    ["Use a code agent with sandbox execution equiped with clinical trial tool",
     "Create RAG tool with FAISS vector store",
     "Query RAG tool",
     "Query clinical trial database",
     "Thorough figure description"]
)
|
| 256 |
+
|
| 257 |
+
if __name__ == "__main__":
    # Serve the Gradio app; mcp_server=True also exposes every api_name-tagged
    # interface as an MCP tool endpoint.
    demo.queue().launch(mcp_server=True)
|
tool_TOON_formater.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def TOON_formater(api_response):
|
| 2 |
+
"""
|
| 3 |
+
Extract core partner identification information from ClinicalTrials.gov API response.
|
| 4 |
+
|
| 5 |
+
Args:
|
| 6 |
+
api_response (dict): Raw API response from ClinicalTrials.gov
|
| 7 |
+
|
| 8 |
+
Returns:
|
| 9 |
+
str: TOOn (Token-Oriented Object Notation) formatted string with 41 core fields
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
# Helper function to safely navigate nested dicts
|
| 13 |
+
def safe_get(data, *keys, default=None):
|
| 14 |
+
for key in keys:
|
| 15 |
+
if isinstance(data, dict):
|
| 16 |
+
data = data.get(key, {})
|
| 17 |
+
else:
|
| 18 |
+
return default
|
| 19 |
+
return data if data != {} else default
|
| 20 |
+
|
| 21 |
+
# Helper function to format value for TOOn
|
| 22 |
+
def format_value(val):
|
| 23 |
+
if val is None:
|
| 24 |
+
return ''
|
| 25 |
+
elif isinstance(val, bool):
|
| 26 |
+
return str(val).lower()
|
| 27 |
+
else:
|
| 28 |
+
return str(val)
|
| 29 |
+
|
| 30 |
+
# Helper function to format list for TOOn
|
| 31 |
+
def format_list(lst):
|
| 32 |
+
if not lst:
|
| 33 |
+
return ''
|
| 34 |
+
# Escape commas in individual items by wrapping in quotes if needed
|
| 35 |
+
formatted_items = []
|
| 36 |
+
for item in lst:
|
| 37 |
+
item_str = format_value(item)
|
| 38 |
+
if ',' in item_str or '\n' in item_str:
|
| 39 |
+
item_str = f'"{item_str}"'
|
| 40 |
+
formatted_items.append(item_str)
|
| 41 |
+
return ','.join(formatted_items)
|
| 42 |
+
|
| 43 |
+
protocol = api_response.get('protocolSection', {})
|
| 44 |
+
|
| 45 |
+
# Extract basic identification
|
| 46 |
+
identification = protocol.get('identificationModule', {})
|
| 47 |
+
nct_id = identification.get('nctId')
|
| 48 |
+
brief_title = identification.get('briefTitle')
|
| 49 |
+
official_title = identification.get('officialTitle')
|
| 50 |
+
org_full_name = safe_get(identification, 'organization', 'fullName')
|
| 51 |
+
|
| 52 |
+
# Extract status information
|
| 53 |
+
status = protocol.get('statusModule', {})
|
| 54 |
+
overall_status = status.get('overallStatus')
|
| 55 |
+
last_update_post_date = safe_get(status, 'lastUpdatePostDateStruct', 'date')
|
| 56 |
+
recruitment_status = overall_status
|
| 57 |
+
start_date = safe_get(status, 'startDateStruct', 'date')
|
| 58 |
+
primary_completion_date = safe_get(status, 'primaryCompletionDateStruct', 'date')
|
| 59 |
+
completion_date = safe_get(status, 'completionDateStruct', 'date')
|
| 60 |
+
study_first_post_date = safe_get(status, 'studyFirstPostDateStruct', 'date')
|
| 61 |
+
|
| 62 |
+
# Extract sponsor/collaborator information
|
| 63 |
+
sponsors = protocol.get('sponsorCollaboratorsModule', {})
|
| 64 |
+
lead_sponsor = sponsors.get('leadSponsor', {})
|
| 65 |
+
lead_sponsor_name = lead_sponsor.get('name')
|
| 66 |
+
lead_sponsor_class = lead_sponsor.get('class')
|
| 67 |
+
|
| 68 |
+
# Extract collaborators (list)
|
| 69 |
+
collaborators = sponsors.get('collaborators', [])
|
| 70 |
+
collaborator_names = [c.get('name') for c in collaborators if c.get('name')]
|
| 71 |
+
collaborator_classes = [c.get('class') for c in collaborators if c.get('class')]
|
| 72 |
+
num_collaborators = len(collaborators)
|
| 73 |
+
num_collaborators_plus_lead = num_collaborators + 1 if lead_sponsor_name else num_collaborators
|
| 74 |
+
|
| 75 |
+
# Extract responsible party
|
| 76 |
+
responsible_party = sponsors.get('responsibleParty', {})
|
| 77 |
+
responsible_party_investigator_full_name = responsible_party.get('investigatorFullName')
|
| 78 |
+
responsible_party_investigator_affiliation = responsible_party.get('investigatorAffiliation')
|
| 79 |
+
|
| 80 |
+
# Extract overall officials
|
| 81 |
+
contacts_locations = protocol.get('contactsLocationsModule', {})
|
| 82 |
+
overall_officials = contacts_locations.get('overallOfficials', [])
|
| 83 |
+
|
| 84 |
+
overall_official_names = [o.get('name') for o in overall_officials if o.get('name')]
|
| 85 |
+
overall_official_affiliations = [o.get('affiliation') for o in overall_officials if o.get('affiliation')]
|
| 86 |
+
overall_official_roles = [o.get('role') for o in overall_officials if o.get('role')]
|
| 87 |
+
|
| 88 |
+
# Extract conditions and interventions
|
| 89 |
+
conditions_module = protocol.get('conditionsModule', {})
|
| 90 |
+
conditions = conditions_module.get('conditions', [])
|
| 91 |
+
|
| 92 |
+
arms_interventions = protocol.get('armsInterventionsModule', {})
|
| 93 |
+
interventions = arms_interventions.get('interventions', [])
|
| 94 |
+
intervention_names = [i.get('name') for i in interventions if i.get('name')]
|
| 95 |
+
intervention_types = [i.get('type') for i in interventions if i.get('type')]
|
| 96 |
+
|
| 97 |
+
# Extract design information
|
| 98 |
+
design = protocol.get('designModule', {})
|
| 99 |
+
study_type = design.get('studyType')
|
| 100 |
+
phases = design.get('phases', [])
|
| 101 |
+
primary_purpose = safe_get(design, 'designInfo', 'primaryPurpose')
|
| 102 |
+
|
| 103 |
+
# Extract enrollment
|
| 104 |
+
enrollment_info = design.get('enrollmentInfo', {})
|
| 105 |
+
enrollment_count = enrollment_info.get('count')
|
| 106 |
+
|
| 107 |
+
# Extract primary outcome
|
| 108 |
+
outcomes = protocol.get('outcomesModule', {})
|
| 109 |
+
primary_outcomes = outcomes.get('primaryOutcomes', [])
|
| 110 |
+
primary_outcome_measures = [p.get('measure') for p in primary_outcomes if p.get('measure')]
|
| 111 |
+
|
| 112 |
+
# Extract locations
|
| 113 |
+
locations = contacts_locations.get('locations', [])
|
| 114 |
+
num_locations = len(locations)
|
| 115 |
+
|
| 116 |
+
location_facilities = [loc.get('facility') for loc in locations if loc.get('facility')]
|
| 117 |
+
location_cities = [loc.get('city') for loc in locations if loc.get('city')]
|
| 118 |
+
location_states = [loc.get('state') for loc in locations if loc.get('state')]
|
| 119 |
+
location_countries = [loc.get('country') for loc in locations if loc.get('country')]
|
| 120 |
+
location_statuses = [loc.get('status') for loc in locations if loc.get('status')]
|
| 121 |
+
|
| 122 |
+
# Extract geopoints
|
| 123 |
+
geopoints = [loc.get('geoPoint') for loc in locations if loc.get('geoPoint')]
|
| 124 |
+
|
| 125 |
+
# Extract MeSH terms
|
| 126 |
+
derived = api_response.get('derivedSection', {})
|
| 127 |
+
condition_browse = derived.get('conditionBrowseModule', {})
|
| 128 |
+
condition_mesh_terms = [m.get('term') for m in condition_browse.get('meshes', []) if m.get('term')]
|
| 129 |
+
|
| 130 |
+
intervention_browse = derived.get('interventionBrowseModule', {})
|
| 131 |
+
intervention_mesh_terms = [m.get('term') for m in intervention_browse.get('meshes', []) if m.get('term')]
|
| 132 |
+
|
| 133 |
+
# Extract has results
|
| 134 |
+
has_results = api_response.get('hasResults', False)
|
| 135 |
+
|
| 136 |
+
# Extract oversight
|
| 137 |
+
oversight = protocol.get('oversightModule', {})
|
| 138 |
+
oversight_has_dmc = oversight.get('oversightHasDmc')
|
| 139 |
+
is_fda_regulated_drug = oversight.get('isFdaRegulatedDrug')
|
| 140 |
+
is_fda_regulated_device = oversight.get('isFdaRegulatedDevice')
|
| 141 |
+
|
| 142 |
+
# Extract references/citations
|
| 143 |
+
references_module = protocol.get('referencesModule', {})
|
| 144 |
+
references = references_module.get('references', [])
|
| 145 |
+
citations = []
|
| 146 |
+
pmids = []
|
| 147 |
+
for ref in references:
|
| 148 |
+
citations.append(ref.get('citation'))
|
| 149 |
+
pmids.append(ref.get('pmid'))
|
| 150 |
+
|
| 151 |
+
# Build TOOn formatted output
|
| 152 |
+
toon_lines = []
|
| 153 |
+
|
| 154 |
+
# Basic identification
|
| 155 |
+
toon_lines.append(f"nct_id: {format_value(nct_id)}")
|
| 156 |
+
toon_lines.append(f"brief_title: {format_value(brief_title)}")
|
| 157 |
+
toon_lines.append(f"official_title: {format_value(official_title)}")
|
| 158 |
+
toon_lines.append(f"overall_status: {format_value(overall_status)}")
|
| 159 |
+
|
| 160 |
+
# Organization & Sponsor
|
| 161 |
+
toon_lines.append(f"lead_sponsor_name: {format_value(lead_sponsor_name)}")
|
| 162 |
+
toon_lines.append(f"lead_sponsor_class: {format_value(lead_sponsor_class)}")
|
| 163 |
+
toon_lines.append(f"collaborator_names[{len(collaborator_names)}]: {format_list(collaborator_names)}")
|
| 164 |
+
toon_lines.append(f"collaborator_classes[{len(collaborator_classes)}]: {format_list(collaborator_classes)}")
|
| 165 |
+
toon_lines.append(f"org_full_name: {format_value(org_full_name)}")
|
| 166 |
+
|
| 167 |
+
# Key personnel
|
| 168 |
+
toon_lines.append(f"overall_official_names[{len(overall_official_names)}]: {format_list(overall_official_names)}")
|
| 169 |
+
toon_lines.append(f"overall_official_affiliations[{len(overall_official_affiliations)}]: {format_list(overall_official_affiliations)}")
|
| 170 |
+
toon_lines.append(f"overall_official_roles[{len(overall_official_roles)}]: {format_list(overall_official_roles)}")
|
| 171 |
+
toon_lines.append(f"responsible_party_investigator_full_name: {format_value(responsible_party_investigator_full_name)}")
|
| 172 |
+
toon_lines.append(f"responsible_party_investigator_affiliation: {format_value(responsible_party_investigator_affiliation)}")
|
| 173 |
+
toon_lines.append(f"num_collaborators: {format_value(num_collaborators)}")
|
| 174 |
+
|
| 175 |
+
# Scientific focus
|
| 176 |
+
toon_lines.append(f"conditions[{len(conditions)}]: {format_list(conditions)}")
|
| 177 |
+
toon_lines.append(f"intervention_names[{len(intervention_names)}]: {format_list(intervention_names)}")
|
| 178 |
+
toon_lines.append(f"intervention_types[{len(intervention_types)}]: {format_list(intervention_types)}")
|
| 179 |
+
toon_lines.append(f"phases[{len(phases)}]: {format_list(phases)}")
|
| 180 |
+
toon_lines.append(f"primary_outcome_measures[{len(primary_outcome_measures)}]: {format_list(primary_outcome_measures)}")
|
| 181 |
+
|
| 182 |
+
# Study scope & capacity
|
| 183 |
+
toon_lines.append(f"enrollment_count: {format_value(enrollment_count)}")
|
| 184 |
+
toon_lines.append(f"study_type: {format_value(study_type)}")
|
| 185 |
+
toon_lines.append(f"num_locations: {format_value(num_locations)}")
|
| 186 |
+
toon_lines.append(f"location_facilities[{len(location_facilities)}]: {format_list(location_facilities)}")
|
| 187 |
+
toon_lines.append(f"location_cities[{len(location_cities)}]: {format_list(location_cities)}")
|
| 188 |
+
toon_lines.append(f"location_states[{len(location_states)}]: {format_list(location_states)}")
|
| 189 |
+
toon_lines.append(f"location_countries[{len(location_countries)}]: {format_list(location_countries)}")
|
| 190 |
+
|
| 191 |
+
# Experience & track record
|
| 192 |
+
toon_lines.append(f"study_first_post_date: {format_value(study_first_post_date)}")
|
| 193 |
+
toon_lines.append(f"completion_date: {format_value(completion_date)}")
|
| 194 |
+
toon_lines.append(f"has_results: {format_value(has_results)}")
|
| 195 |
+
toon_lines.append(f"num_collaborators_plus_lead: {format_value(num_collaborators_plus_lead)}")
|
| 196 |
+
|
| 197 |
+
# Therapeutic area expertise
|
| 198 |
+
toon_lines.append(f"condition_mesh_terms[{len(condition_mesh_terms)}]: {format_list(condition_mesh_terms)}")
|
| 199 |
+
toon_lines.append(f"intervention_mesh_terms[{len(intervention_mesh_terms)}]: {format_list(intervention_mesh_terms)}")
|
| 200 |
+
toon_lines.append(f"primary_purpose: {format_value(primary_purpose)}")
|
| 201 |
+
|
| 202 |
+
# Current activity status
|
| 203 |
+
toon_lines.append(f"last_update_post_date: {format_value(last_update_post_date)}")
|
| 204 |
+
toon_lines.append(f"recruitment_status: {format_value(recruitment_status)}")
|
| 205 |
+
toon_lines.append(f"start_date: {format_value(start_date)}")
|
| 206 |
+
toon_lines.append(f"primary_completion_date: {format_value(primary_completion_date)}")
|
| 207 |
+
|
| 208 |
+
# Secondary fields
|
| 209 |
+
toon_lines.append(f"oversight_has_dmc: {format_value(oversight_has_dmc)}")
|
| 210 |
+
toon_lines.append(f"is_fda_regulated_drug: {format_value(is_fda_regulated_drug)}")
|
| 211 |
+
toon_lines.append(f"is_fda_regulated_device: {format_value(is_fda_regulated_device)}")
|
| 212 |
+
toon_lines.append(f"location_statuses[{len(location_statuses)}]: {format_list(location_statuses)}")
|
| 213 |
+
|
| 214 |
+
# Additional fields
|
| 215 |
+
toon_lines.append(f"citations[{len(citations)}]: {format_list(citations)}")
|
| 216 |
+
toon_lines.append(f"pmids[{len(pmids)}]: {format_list(pmids)}")
|
| 217 |
+
|
| 218 |
+
# Geopoints (structured data - format as array of objects)
|
| 219 |
+
if geopoints:
|
| 220 |
+
geo_keys = set()
|
| 221 |
+
for gp in geopoints:
|
| 222 |
+
if gp:
|
| 223 |
+
geo_keys.update(gp.keys())
|
| 224 |
+
|
| 225 |
+
if geo_keys:
|
| 226 |
+
geo_keys_sorted = sorted(geo_keys)
|
| 227 |
+
toon_lines.append(f"geopoints[{len(geopoints)}]{{{','.join(geo_keys_sorted)}}}:")
|
| 228 |
+
for gp in geopoints:
|
| 229 |
+
if gp:
|
| 230 |
+
values = [format_value(gp.get(k)) for k in geo_keys_sorted]
|
| 231 |
+
toon_lines.append(f" {','.join(values)}")
|
| 232 |
+
else:
|
| 233 |
+
toon_lines.append(f" {','.join(['' for _ in geo_keys_sorted])}")
|
| 234 |
+
else:
|
| 235 |
+
toon_lines.append(f"geopoints[0]:")
|
| 236 |
+
|
| 237 |
+
return '\n'.join(toon_lines)
|
tool_create_FAISS_vector.py
ADDED
|
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pypdf import PdfReader
|
| 2 |
+
import requests
|
| 3 |
+
from io import BytesIO
|
| 4 |
+
import serpapi
|
| 5 |
+
import os
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
load_dotenv()
|
| 8 |
+
|
| 9 |
+
from langchain_core.documents import Document as LangchainDocument
|
| 10 |
+
from metapub import FindIt
|
| 11 |
+
import requests
|
| 12 |
+
import xml.etree.ElementTree as ET
|
| 13 |
+
|
| 14 |
+
from ftplib import FTP
|
| 15 |
+
from urllib.parse import urlparse
|
| 16 |
+
from io import BytesIO
|
| 17 |
+
|
| 18 |
+
from langchain_community.retrievers import ArxivRetriever
|
| 19 |
+
|
| 20 |
+
import arxiv
|
| 21 |
+
import requests
|
| 22 |
+
from io import BytesIO
|
| 23 |
+
from pypdf import PdfReader
|
| 24 |
+
import re
|
| 25 |
+
|
| 26 |
+
from langchain_community.vectorstores.utils import DistanceStrategy
|
| 27 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 28 |
+
from transformers import AutoTokenizer
|
| 29 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 30 |
+
from tqdm import tqdm
|
| 31 |
+
|
| 32 |
+
import re
|
| 33 |
+
from typing import List, Dict, Tuple
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def process_ref(extr_ref: tuple[str, str]) -> str | None:
    """Download the full text for a pre-classified reference.

    Args:
        extr_ref: Tuple of (reference value, reference kind) as produced by
            ReferenceExtractor.extract_references; kind is one of
            "arxiv", "pmid", "doi" or "pmcid".

    Returns:
        The extracted paper text from the first retrieval strategy that
        succeeds, or None when every strategy failed or the kind is
        unrecognized.
    """
    ref_value, ref_kind = extr_ref
    # Each kind has an ordered list of retrieval strategies; try them in turn.
    strategies = {
        "arxiv": [get_paper_from_arxiv_id, get_paper_from_arxiv_id_langchain],
        "pmid": [get_paper_from_pmid, parse_pdf_from_pubmed_pmid],
        "doi": [download_paper_from_doi, get_pdf_content_serpapi],
        "pmcid": [get_paper_from_pmid],
    }
    for tool in strategies.get(ref_kind, []):
        try:
            return tool(ref_value)
        except Exception:
            # Bug fix: was a bare `except:` that also swallowed
            # KeyboardInterrupt/SystemExit. Network or parsing failures are
            # expected; fall through to the next strategy instead.
            continue
    return None
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class ReferenceExtractor:
    """Extract and classify scholarly references (DOI, PMID, arXiv ID,
    PMCID) from free-form LLM outputs or list-like strings."""

    # Regex patterns for identification
    DOI_PATTERN = r"10\.\d{4,9}/[-._;()/:A-Za-z0-9]+"
    DOI_LOOSE = r"10\.\d{4,9}/[A-Za-z0-9.\-_/]+"
    PMID_PATTERN = r"\b\d{7,8}\b"
    ARXIV_NEW = r"\b\d{4}\.\d{4,5}(?:v\d+)?\b"
    ARXIV_OLD = r"\b[a-z\-]+/\d{7}\b"
    PMCID_PATTERN = r"\bPMC\d+\b"

    def __init__(self):
        """Initialize the extractor with compiled regex patterns."""
        # Insertion order matters: it is the priority order used when
        # scanning free text in extract_references().
        self.patterns = {
            'doi': re.compile(self.DOI_PATTERN, re.IGNORECASE),
            'pmid': re.compile(self.PMID_PATTERN),
            'arxiv': re.compile(f"({self.ARXIV_NEW})|({self.ARXIV_OLD})", re.IGNORECASE),
            'pmcid': re.compile(self.PMCID_PATTERN, re.IGNORECASE)
        }

    def extract_references(self, text: str) -> List[Tuple[str, str]]:
        """
        Extract all references from text and classify them.

        Args:
            text: Input string that may contain references in various formats
                (a Python-list-like string, a comma-separated string, or
                free-running prose).

        Returns:
            List of tuples: (reference_value, reference_type), deduplicated,
            where reference_type is "doi" / "pmid" / "arxiv" / "pmcid" /
            "unknown".
        """
        references = []
        seen = set()

        # First, try to parse as a list-like string
        list_refs = self._extract_from_list_format(text)
        if list_refs:
            for ref in list_refs:
                ref_type = self._classify_single_ref(ref)
                if ref not in seen:
                    references.append((ref, ref_type))
                    seen.add(ref)
            return references

        # If not a list format, extract using regex patterns (priority
        # order is the dict insertion order set in __init__).
        for ref_type, pattern in self.patterns.items():
            for match in pattern.finditer(text):
                ref_value = match.group(0).strip()
                if ref_value not in seen:
                    references.append((ref_value, ref_type))
                    seen.add(ref_value)

        return references

    def _extract_from_list_format(self, text: str) -> List[str]:
        """
        Extract references from list-like formats.
        Handles: "id1,id2,id3" and '["id1","id2"]' and "['id1', 'id2']"
        """
        text = text.strip()

        # Try parsing as Python list string
        if text.startswith('[') and text.endswith(']'):
            try:
                # Remove brackets and quotes, split by comma
                cleaned = text[1:-1]
                # Handle both single and double quotes
                items = re.findall(r'["\']([^"\']+)["\']', cleaned)
                if items:
                    return [item.strip() for item in items]
            except Exception:
                # Bug fix: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit.
                pass

        # Try comma-separated format (no brackets)
        if ',' in text and not any(char in text for char in ['\n', '(', ')']):
            # Check if it looks like a simple list
            if text.count(',') >= 1 and len(text) < 200:
                items = [item.strip().strip('"\'') for item in text.split(',')]
                # Filter out empty strings
                return [item for item in items if item]

        return []

    def _classify_single_ref(self, ref: str) -> str:
        """Classify a single extracted reference string."""
        ref = ref.strip().strip('"\'')

        # Check each pattern in priority order
        if re.match(r"^10\.\d{4,9}/[A-Za-z0-9.\-_/:()]+$", ref, re.IGNORECASE):
            return "doi"

        if re.match(r"^PMC\d+$", ref, re.IGNORECASE):
            return "pmcid"

        if re.match(r"^\d{4}\.\d{4,5}(?:v\d+)?$", ref):
            return "arxiv"

        if re.match(r"^[a-z\-]+/\d{7}$", ref, re.IGNORECASE):
            return "arxiv"

        if re.match(r"^\d{7,8}$", ref):
            return "pmid"

        return "unknown"
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def download_paper_from_doi(doi):
    """
    Attempt to download paper from DOI with multiple fallback methods.

    Tries, in order: Unpaywall open-access lookup, a direct arXiv PDF guess,
    then Sci-Hub scraping. Returns the extracted text from the first method
    that succeeds; falls through and implicitly returns None when every
    method fails (callers such as process_ref treat None as failure).
    """
    # Clean DOI if it has prefix
    doi = doi.replace('https://doi.org/', '').replace('http://doi.org/', '')

    # Method 1: Try Unpaywall API (free, legal access)
    try:
        # NOTE(review): 'your@email.com' is a placeholder — Unpaywall asks
        # for a real contact address; replace before production use.
        unpaywall_url = f"https://api.unpaywall.org/v2/{doi}?email=your@email.com"
        response = requests.get(unpaywall_url, timeout=10)
        if response.status_code == 200:
            data = response.json()
            # Only proceed when Unpaywall reports an OA location with a PDF.
            if data.get('best_oa_location') and data['best_oa_location'].get('url_for_pdf'):
                pdf_url = data['best_oa_location']['url_for_pdf']
                text = download_pdf_from_url(pdf_url)
                print(f"Found PDF via Unpaywall: {pdf_url}")
                return text
    except Exception as e:
        # Best-effort: log and fall through to the next method.
        print(f"Unpaywall failed: {e}")

    # Method 2: Try arXiv if it's an arXiv paper
    # NOTE(review): the startswith('2') heuristic matches any DOI beginning
    # with '2', not just arXiv DOIs — confirm this is intended.
    if 'arxiv' in doi.lower() or doi.startswith('2'):
        try:
            # Extract arXiv ID
            arxiv_id = doi.split('/')[-1] if '/' in doi else doi
            arxiv_pdf_url = f"https://arxiv.org/pdf/{arxiv_id}.pdf"
            text = download_pdf_from_url(arxiv_pdf_url)
            print(f"Trying arXiv: {arxiv_pdf_url}")
            return text
        except Exception as e:
            print(f"arXiv failed: {e}")

    # Method 3: Try Sci-Hub (use with caution - check your local laws)
    try:
        scihub_url = f"https://sci-hub.se/{doi}"
        print(f"Trying Sci-Hub: {scihub_url}")
        headers = {'User-Agent': 'Mozilla/5.0'}
        response = requests.get(scihub_url, headers=headers, timeout=15)

        if response.status_code == 200:
            # Look for PDF link in the HTML
            pdf_match = re.search(r'(https?://[^"]+\.pdf[^"]*)', response.text)
            if pdf_match:
                pdf_url = pdf_match.group(1)
                text = download_pdf_from_url(pdf_url)
                print(f"got {doi} by chance")
                return text
    except Exception as e:
        print(f"Sci-Hub failed: {e}")
    # Implicit None return: all three methods failed.
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def download_pdf_from_url(url):
    """Fetch a PDF from *url* and return its extracted text.

    Raises on HTTP errors or when the response does not look like a PDF
    (neither a PDF content-type nor a %PDF magic prefix).
    """
    request_headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
    }

    response = requests.get(url, headers=request_headers, timeout=30)
    response.raise_for_status()
    content_type = response.headers.get('content-type', '').lower()
    looks_like_pdf = 'pdf' in content_type or response.content.startswith(b'%PDF')
    if not looks_like_pdf:
        raise Exception(f"URL did not return a PDF (got {content_type})")

    reader = PdfReader(BytesIO(response.content))
    # Pages with no extractable text yield None; substitute the empty string.
    return "".join(page.extract_text() or "" for page in reader.pages)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def get_paper_from_arxiv_id(doi: str):
    """Search arXiv for *doi* (an arXiv ID or query string) and return the
    text of the first hit's PDF."""
    arxiv_client = arxiv.Client()
    query = arxiv.Search(query=doi, max_results=1)
    # StopIteration propagates when there is no result, as before.
    first_hit = next(arxiv_client.results(query))
    return parse_pdf_file(first_hit.pdf_url)
|
| 250 |
+
|
| 251 |
+
def get_paper_from_arxiv_id_langchain(arxiv_id: str):
    """
    Retrieve paper documents from arXiv via LangChain's ArxivRetriever.

    Args:
        arxiv_id: The arXiv identifier (e.g. "2304.07814") to fetch.

    Returns:
        A list of LangChain documents containing the full paper content.
    """
    retriever = ArxivRetriever(
        load_max_docs=2,
        get_full_documents=True,
    )
    # Bug fix: the query previously ignored `arxiv_id` and always fetched the
    # hard-coded paper "2304.07814"; use the caller-supplied identifier.
    docs = retriever.invoke(arxiv_id)
    return docs
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def parse_pdf_file(path: str) -> str:
    """Extract plain text from a PDF given a local path or a remote URL.

    Remote paths (http/https/ftp prefixes) are downloaded into memory first;
    anything else is treated as a local file path.
    """
    if path.startswith(("http://", "https://", "ftp://")):
        response = requests.get(path)
        response.raise_for_status()  # Ensure download succeeded
        source = BytesIO(response.content)
    else:
        source = path

    reader = PdfReader(source)
    # Concatenate page texts, treating pages with no extractable text as "".
    return "".join(page.extract_text() or "" for page in reader.pages)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def get_pdf_content_serpapi(doi: str) -> str:
    """Resolve a DOI to a paper link via a SerpAPI Google Scholar search and
    return the parsed text of that link."""
    scholar_client = serpapi.Client(api_key=os.getenv("SERPAPI_API_KEY"))
    search_results = scholar_client.search({
        'engine': 'google_scholar',
        'q': doi,
    })

    # Take the first organic result; KeyError/IndexError propagate when the
    # search returns nothing, as before.
    top_link = search_results["organic_results"][0]["link"]
    return parse_pdf_file(top_link)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def get_paper_from_pmid(pmid: str):
    """Locate a paper's full-text PDF via metapub's FindIt and return its
    parsed text; prints the failure reason and returns None when no URL is
    available."""
    finder = FindIt(pmid)
    if not finder.url:
        print(finder.reason)
        return None
    return parse_pdf_file(finder.url)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def download_pdf_via_ftp(url: str) -> BytesIO:
    """
    Download a file from an FTP URL and return it as an in-memory buffer.

    Args:
        url: An ftp:// URL pointing at the file to fetch.

    Returns:
        A BytesIO positioned at the start of the downloaded content.
        (Fix: the previous annotation claimed ``bytes``, but the function
        has always returned a BytesIO buffer.)
    """
    parsed_url = urlparse(url)
    ftp_host = parsed_url.netloc
    ftp_path = parsed_url.path

    file_buffer = BytesIO()

    with FTP(ftp_host) as ftp:
        ftp.login()  # anonymous login
        ftp.retrbinary(f'RETR {ftp_path}', file_buffer.write)

    # Fix: removed a dead `file_buffer.getvalue()` call whose result was
    # discarded. Rewind so callers (e.g. PdfReader) read from the start.
    file_buffer.seek(0)
    return file_buffer
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def parse_pdf_from_pubmed_pmid(pmid: str) -> str | None:
    """
    Download and parse a PDF from the PubMed Central OA service by PMID.

    Args:
        pmid: PubMed/PMC identifier to look up via the oa.fcgi service.

    Returns:
        The extracted text, or None when the service response is not valid
        XML or the article exposes no PDF link.
    """
    url = f"https://www.ncbi.nlm.nih.gov/pmc/utils/oa/oa.fcgi?id={pmid}"
    response = requests.get(url)
    cleaned_string = response.content.decode('utf-8').strip()
    try:
        root = ET.fromstring(cleaned_string)
    except ET.ParseError:
        # Service returned non-XML (e.g. an error page); treat as not found.
        return None

    pdf_link_element = root.find(".//link[@format='pdf']")
    if pdf_link_element is None:
        # Bug fix: previously `.get('href')` raised AttributeError when the
        # article had no PDF link in the OA record.
        return None

    ftp_url = pdf_link_element.get('href')
    file_byte = download_pdf_via_ftp(ftp_url)

    reader = PdfReader(file_byte)
    text = "".join(page.extract_text() or "" for page in reader.pages)
    print(f"got {pmid} via ftp download")
    return text
|
| 348 |
+
|
| 349 |
+
def safe_parse_of_ref_list(refs : list[str]) -> list[str]:
    # TODO(review): unimplemented stub — currently returns None despite the
    # list[str] return annotation, and nothing in this file calls it.

    return
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def classify_ref(ref: str) -> str:
    """Classify a single reference string as "doi", "pmid", "arxiv" or
    "unknown". Leading/trailing whitespace is ignored."""
    # (pattern, flags, label) triples checked in priority order; the first
    # match wins, mirroring the original doi -> pmid -> arxiv ordering.
    checks = [
        (r"10\.\d{4,9}/[-._;()/:A-Za-z0-9]+", re.IGNORECASE, "doi"),
        (r"^10\.\d{4,9}/?[A-Za-z0-9.\-_/]+$", re.IGNORECASE, "doi"),  # supports 'NEJMoa2307100'
        (r"^\d{7,8}$", 0, "pmid"),
        (r"^\d{4}\.\d{4,5}(v\d+)?$", re.IGNORECASE, "arxiv"),         # new style
        (r"^[a-z\-]+/\d{7}$", re.IGNORECASE, "arxiv"),                # old style hep-th/xxxxxxx
    ]

    candidate = ref.strip()
    for pattern, flags, label in checks:
        if re.match(pattern, candidate, flags):
            return label
    return "unknown"
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def process_ref(ref: str):
    """Classify *ref* and try up to two download strategies for its kind.

    Note: this definition shadows the earlier tuple-based process_ref, so it
    is the one active at runtime.

    Args:
        ref: A raw reference string (DOI, PMID or arXiv ID).

    Returns:
        The paper text from the first strategy that succeeds, or None when
        the reference is unclassifiable or every strategy failed.
    """
    kind = classify_ref(ref)
    strategies = {
        "doi": [download_paper_from_doi, get_pdf_content_serpapi],
        "pmid": [get_paper_from_pmid, parse_pdf_from_pubmed_pmid],
        "arxiv": [get_paper_from_arxiv_id, get_pdf_content_serpapi],
    }
    for tool in strategies.get(kind, []):
        try:
            return tool(ref)
        except Exception:
            # Bug fix: was a bare `except:` that also swallowed
            # KeyboardInterrupt/SystemExit; only recoverable errors should
            # trigger the fallback strategy.
            continue

    print(f"Skipping invalid ref: {ref}")
    return None
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
from langchain_community.vectorstores.utils import DistanceStrategy
|
| 401 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 402 |
+
from transformers import AutoTokenizer
|
| 403 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 404 |
+
from tqdm import tqdm
|
| 405 |
+
|
| 406 |
+
def create_vector_store_from_list_of_doi(refs: list[str], VECTOR_DB_PATH: str) -> str:
    """Build (or extend) a FAISS vector store from reference identifiers.

    Args:
        refs: String (or list-like string) containing references — DOIs,
            PMIDs, arXiv IDs, PMCIDs — parsed by ReferenceExtractor.
        VECTOR_DB_PATH: Folder where the FAISS index is (or will be) stored.

    Returns:
        VECTOR_DB_PATH when new documents were indexed, otherwise a message
        explaining why nothing was added.
    """
    from langchain_community.vectorstores import FAISS

    # define embedding
    embedding_name = "BAAI/bge-large-en-v1.5"
    embedding_model = HuggingFaceEmbeddings(
        model_name=embedding_name,
        model_kwargs={"device": "mps"},  # NOTE(review): Apple-only device — confirm for deployment
        encode_kwargs={"normalize_embeddings": True},
    )
    try:
        # Load the vector database from the folder
        print(f"try to load vector store from {VECTOR_DB_PATH}")
        KNOWLEDGE_VECTOR_DATABASE = FAISS.load_local(
            VECTOR_DB_PATH,
            embedding_model,
            allow_dangerous_deserialization=True  # Required for security in newer LangChain versions
        )
        existing_reference = [doc.metadata.get("source") for doc in KNOWLEDGE_VECTOR_DATABASE.docstore._dict.values()]
        print("vectro store loaded")
    except Exception as e:
        print("FAISS load error:", e)
        KNOWLEDGE_VECTOR_DATABASE = None
        existing_reference = []
        print("no vector store found, creating a new one...")

    # fetch docs
    extractor = ReferenceExtractor()
    REFS = extractor.extract_references(refs)  # Change here the type of IDs to DEBUG
    if not REFS:
        # Bug fix: guard against ZeroDivisionError in the yield computation
        # below when no reference could be parsed from the input.
        return f"no valid references found; vector store {VECTOR_DB_PATH} unchanged"

    existing_sources = set(existing_reference)  # hoisted: built once, not per iteration
    raw_docs = []

    for ref_value, _ref_kind in tqdm(REFS):
        # Bug fix: REFS holds (value, kind) tuples while the stored metadata
        # 'source' entries and the active (string-based) process_ref both use
        # the bare value string. Comparing/passing the tuple made every ref
        # look new and crashed inside classify_ref.
        if ref_value not in existing_sources:
            text = process_ref(ref_value)
            if text:
                raw_docs.append(LangchainDocument(page_content=text, metadata={'source': ref_value}))

    recover_yield = f" *** -> {round(100 * len(raw_docs) / len(REFS))}% papers downloaded"
    print(recover_yield)

    # split texts into chunks
    text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        AutoTokenizer.from_pretrained(embedding_name),
        chunk_size=3000,
        chunk_overlap=int(3000 / 10),
        add_start_index=True,
        strip_whitespace=True,
        separators="."
    )

    if raw_docs:
        docs_processed = text_splitter.split_documents(raw_docs)
        print("creating the vector store...")

        # create the vector store
        NEW_KNOWLEDGE_VECTOR_DATABASE = FAISS.from_documents(docs_processed, embedding_model, distance_strategy=DistanceStrategy.COSINE)

        if KNOWLEDGE_VECTOR_DATABASE:
            print("merge vector store")
            KNOWLEDGE_VECTOR_DATABASE.merge_from(NEW_KNOWLEDGE_VECTOR_DATABASE)
            KNOWLEDGE_VECTOR_DATABASE.save_local(VECTOR_DB_PATH)
        else:
            NEW_KNOWLEDGE_VECTOR_DATABASE.save_local(VECTOR_DB_PATH)

        return VECTOR_DB_PATH

    else:
        return f"all the data already in vector store {VECTOR_DB_PATH}"
|
tool_describe_figure.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
from openai import OpenAI
|
| 4 |
+
# The OpenAI library handles the API key and base URL automatically
|
| 5 |
+
# after instantiation.
|
| 6 |
+
|
| 7 |
+
def thorough_picture_description(figure: str) -> str:
    """
    Generate a thorough description for a given image URL using the
    Nebius Token Factory endpoint.

    Args:
        figure: The URL of the image to describe.

    Returns:
        The generated text description, or an error message string when the
        client cannot be created or the API call fails.
    """
    try:
        client = OpenAI(
            base_url="https://api.tokenfactory.nebius.com/v1/",
            api_key=os.environ.get("NEBIUS_API_KEY")
        )
    except Exception as e:
        return f"Error initializing OpenAI client: {e}"

    prompt_text = "Provide a very detailed, thorough, and descriptive analysis of this image."
    messages_payload = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt_text},
                {"type": "image_url", "image_url": {"url": figure}},
            ],
        }
    ]

    try:
        response = client.chat.completions.create(
            model="gemini-2.5-flash",
            messages=messages_payload,
            max_tokens=2048
        )
        if response.choices and response.choices[0].message.content:
            return response.choices[0].message.content
        return "Could not retrieve a description from the API."
    except Exception as e:
        return f"An error occurred during the API call: {e}"
|
tool_fetch_documents_DOI.py
ADDED
|
File without changes
|
tool_query_FAISS_vector.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def query_vector_store(query: str, store_name: str, top_k: int = 5) -> dict:
    """
    Query a specific vector store to retrieve top_k documents related to the
    user question. Each document has metadata identifying its source, which
    is returned alongside the retrieved text.

    Args:
        query (str): User's question
        store_name (str): Which vector store to search
        top_k (int): Number of chunks to retrieve

    Returns:
        dict: Retrieved context, sources, store_name — or an "error" key
        when the named store does not exist.
    """
    from langchain_community.vectorstores import FAISS
    # Bug fix: HuggingFaceEmbeddings was used without ever being imported
    # in this module (NameError at runtime).
    from langchain_community.embeddings import HuggingFaceEmbeddings

    vector_stores = os.listdir("./vector_stores")
    # Bug fix: the path was missing the '/' separator
    # ("./vector_storesNAME"), so loading always failed.
    store_path = f"./vector_stores/{store_name}"
    if store_name not in vector_stores:
        return {"error": f"Vector store '{store_name}' not found, you must create it first with tool create faiss vector"}

    embedding_name = "BAAI/bge-large-en-v1.5"
    embedding_model = HuggingFaceEmbeddings(
        model_name=embedding_name,
        model_kwargs={"device": "mps"},
        encode_kwargs={"normalize_embeddings": True},
    )

    vector_store = FAISS.load_local(
        store_path,
        embedding_model,
        allow_dangerous_deserialization=True
    )

    # Bug fix: similarity_search returns Document objects (not dicts) and no
    # score, so r["text"]/r["metadata"]/r["score"] all raised TypeError; use
    # similarity_search_with_score and attribute access instead.
    results = vector_store.similarity_search_with_score(query, k=top_k)

    context = "\n\n".join(doc.page_content for doc, _score in results)
    sources = [
        {"ids": doc.metadata.get("source"), "relevance": float(score)}
        for doc, score in results
    ]

    return {
        "context": context,
        "sources": sources,
        "store_name": store_name
    }
|