Spaces:
Paused
Paused
refined and tested gemini and browsing
Browse files- main.py +10 -10
- medrax/tools/web_browser.py +25 -15
main.py
CHANGED
|
@@ -111,12 +111,12 @@ if __name__ == "__main__":
|
|
| 111 |
# Example: initialize with only specific tools
|
| 112 |
# Here three tools are commented out, you can uncomment them to use them
|
| 113 |
selected_tools = [
|
| 114 |
-
"ImageVisualizerTool",
|
| 115 |
-
"DicomProcessorTool",
|
| 116 |
-
"ChestXRayClassifierTool",
|
| 117 |
-
"ChestXRaySegmentationTool",
|
| 118 |
-
"ChestXRayReportGeneratorTool",
|
| 119 |
-
"XRayVQATool",
|
| 120 |
"WebBrowserTool", # Add the web browser tool
|
| 121 |
# "LlavaMedTool",
|
| 122 |
# "XRayPhraseGroundingTool",
|
|
@@ -130,14 +130,14 @@ if __name__ == "__main__":
|
|
| 130 |
# You'll need to set these environment variables:
|
| 131 |
# - GOOGLE_SEARCH_API_KEY: Your Google Custom Search API key
|
| 132 |
# - GOOGLE_SEARCH_ENGINE_ID: Your Google Custom Search Engine ID
|
| 133 |
-
|
| 134 |
agent, tools_dict = initialize_agent(
|
| 135 |
"medrax/docs/system_prompts.txt",
|
| 136 |
tools_to_use=selected_tools,
|
| 137 |
-
model_dir="/
|
| 138 |
temp_dir="temp", # Change this to the path of the temporary directory
|
| 139 |
-
device="
|
| 140 |
-
model="
|
| 141 |
temperature=0.7,
|
| 142 |
top_p=0.95,
|
| 143 |
model_kwargs=model_kwargs
|
|
|
|
| 111 |
# Example: initialize with only specific tools
|
| 112 |
# Here three tools are commented out, you can uncomment them to use them
|
| 113 |
selected_tools = [
|
| 114 |
+
# "ImageVisualizerTool",
|
| 115 |
+
# "DicomProcessorTool",
|
| 116 |
+
# "ChestXRayClassifierTool",
|
| 117 |
+
# "ChestXRaySegmentationTool",
|
| 118 |
+
# "ChestXRayReportGeneratorTool",
|
| 119 |
+
# "XRayVQATool",
|
| 120 |
"WebBrowserTool", # Add the web browser tool
|
| 121 |
# "LlavaMedTool",
|
| 122 |
# "XRayPhraseGroundingTool",
|
|
|
|
| 130 |
# You'll need to set these environment variables:
|
| 131 |
# - GOOGLE_SEARCH_API_KEY: Your Google Custom Search API key
|
| 132 |
# - GOOGLE_SEARCH_ENGINE_ID: Your Google Custom Search Engine ID
|
| 133 |
+
|
| 134 |
agent, tools_dict = initialize_agent(
|
| 135 |
"medrax/docs/system_prompts.txt",
|
| 136 |
tools_to_use=selected_tools,
|
| 137 |
+
model_dir="/m_weights", # Change this to the path of the model weights
|
| 138 |
temp_dir="temp", # Change this to the path of the temporary directory
|
| 139 |
+
device="cpu", # Change this to the device you want to use
|
| 140 |
+
model="gemini-2.5-pro", # Change this to the model you want to use, e.g. gpt-4o-mini, gemini-2.5-pro
|
| 141 |
temperature=0.7,
|
| 142 |
top_p=0.95,
|
| 143 |
model_kwargs=model_kwargs
|
medrax/tools/web_browser.py
CHANGED
|
@@ -7,7 +7,8 @@ to search the web, visit URLs, and extract information from web pages.
|
|
| 7 |
import os
|
| 8 |
import re
|
| 9 |
import json
|
| 10 |
-
|
|
|
|
| 11 |
from urllib.parse import urlparse
|
| 12 |
|
| 13 |
import requests
|
|
@@ -16,6 +17,12 @@ from langchain_core.tools import BaseTool
|
|
| 16 |
from pydantic import BaseModel, Field
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
class SearchQuerySchema(BaseModel):
|
| 20 |
"""Schema for web search queries."""
|
| 21 |
query: str = Field(..., description="The search query string")
|
|
@@ -40,6 +47,7 @@ class WebBrowserTool(BaseTool):
|
|
| 40 |
search_engine_id: Optional[str] = None
|
| 41 |
user_agent: str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
| 42 |
max_results: int = 5
|
|
|
|
| 43 |
|
| 44 |
def __init__(self, search_api_key: Optional[str] = None, search_engine_id: Optional[str] = None, **kwargs):
|
| 45 |
"""Initialize the web browser tool.
|
|
@@ -179,27 +187,29 @@ class WebBrowserTool(BaseTool):
|
|
| 179 |
"""Run the tool asynchronously."""
|
| 180 |
return json.dumps(self._run(query=query, url=url))
|
| 181 |
|
| 182 |
-
def _run(self, query: str = "", url: str = "") -> Dict[str, Any]:
|
| 183 |
"""Run the web browser tool.
|
| 184 |
-
|
| 185 |
Args:
|
| 186 |
query: Search query (if searching)
|
| 187 |
url: URL to visit (if visiting a specific page)
|
| 188 |
-
|
| 189 |
Returns:
|
| 190 |
-
Dict containing the results
|
| 191 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
if url:
|
| 193 |
-
|
|
|
|
| 194 |
elif query:
|
| 195 |
-
|
|
|
|
| 196 |
else:
|
| 197 |
-
return {"error": "Please provide either a search query or a URL to visit"}
|
| 198 |
|
| 199 |
-
def args_schema(self) -> type[BaseModel]:
|
| 200 |
-
"""Return the schema for the tool arguments."""
|
| 201 |
-
class WebBrowserSchema(BaseModel):
|
| 202 |
-
"""Combined schema for web browser tool."""
|
| 203 |
-
query: str = Field("", description="The search query (leave empty if visiting a URL)")
|
| 204 |
-
url: str = Field("", description="The URL to visit (leave empty if performing a search)")
|
| 205 |
-
return WebBrowserSchema
|
|
|
|
| 7 |
import os
|
| 8 |
import re
|
| 9 |
import json
|
| 10 |
+
import time
|
| 11 |
+
from typing import Dict, Optional, Any, Type, Tuple
|
| 12 |
from urllib.parse import urlparse
|
| 13 |
|
| 14 |
import requests
|
|
|
|
| 17 |
from pydantic import BaseModel, Field
|
| 18 |
|
| 19 |
|
| 20 |
+
class WebBrowserSchema(BaseModel):
|
| 21 |
+
"""Schema for web browser tool."""
|
| 22 |
+
query: str = Field("", description="The search query (leave empty if visiting a URL)")
|
| 23 |
+
url: str = Field("", description="The URL to visit (leave empty if performing a search)")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
class SearchQuerySchema(BaseModel):
|
| 27 |
"""Schema for web search queries."""
|
| 28 |
query: str = Field(..., description="The search query string")
|
|
|
|
| 47 |
search_engine_id: Optional[str] = None
|
| 48 |
user_agent: str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
| 49 |
max_results: int = 5
|
| 50 |
+
args_schema: Type[BaseModel] = WebBrowserSchema
|
| 51 |
|
| 52 |
def __init__(self, search_api_key: Optional[str] = None, search_engine_id: Optional[str] = None, **kwargs):
|
| 53 |
"""Initialize the web browser tool.
|
|
|
|
| 187 |
"""Run the tool asynchronously."""
|
| 188 |
return json.dumps(self._run(query=query, url=url))
|
| 189 |
|
| 190 |
+
def _run(self, query: str = "", url: str = "") -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
| 191 |
"""Run the web browser tool.
|
| 192 |
+
|
| 193 |
Args:
|
| 194 |
query: Search query (if searching)
|
| 195 |
url: URL to visit (if visiting a specific page)
|
| 196 |
+
|
| 197 |
Returns:
|
| 198 |
+
Tuple[Dict[str, Any], Dict[str, Any]]: A tuple containing the results and metadata
|
| 199 |
"""
|
| 200 |
+
metadata = {
|
| 201 |
+
"query": query if query else "",
|
| 202 |
+
"url": url if url else "",
|
| 203 |
+
"timestamp": time.time(),
|
| 204 |
+
"tool": "WebBrowserTool"
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
if url:
|
| 208 |
+
result = self.visit_url(url)
|
| 209 |
+
return result, metadata
|
| 210 |
elif query:
|
| 211 |
+
result = self.search_web(query)
|
| 212 |
+
return result, metadata
|
| 213 |
else:
|
| 214 |
+
return {"error": "Please provide either a search query or a URL to visit"}, metadata
|
| 215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|