VictorLJZ committed on
Commit
f237c31
·
1 Parent(s): e0efe4c

added llm factory and browsing capabilities

Browse files
main.py CHANGED
@@ -5,9 +5,7 @@ from dotenv import load_dotenv
5
  from transformers import logging
6
 
7
  from langgraph.checkpoint.memory import MemorySaver
8
- from langchain_openai import ChatOpenAI
9
- from langgraph.checkpoint.memory import MemorySaver
10
- from langchain_openai import ChatOpenAI
11
 
12
  from interface import create_demo
13
  from medrax.agent import *
@@ -25,10 +23,10 @@ def initialize_agent(
25
  model_dir="/model-weights",
26
  temp_dir="temp",
27
  device="cuda",
28
- model="chatgpt-4o-latest",
29
  temperature=0.7,
30
  top_p=0.95,
31
- openai_kwargs={}
32
  ):
33
  """Initialize the MedRAX agent with specified tools and configuration.
34
 
@@ -38,10 +36,10 @@ def initialize_agent(
38
  model_dir (str, optional): Directory containing model weights. Defaults to "/model-weights".
39
  temp_dir (str, optional): Directory for temporary files. Defaults to "temp".
40
  device (str, optional): Device to run models on. Defaults to "cuda".
41
- model (str, optional): Model to use. Defaults to "chatgpt-4o-latest".
42
  temperature (float, optional): Temperature for the model. Defaults to 0.7.
43
  top_p (float, optional): Top P for the model. Defaults to 0.95.
44
- openai_kwargs (dict, optional): Additional keyword arguments for OpenAI API, such as API key and base URL.
45
 
46
  Returns:
47
  Tuple[Agent, Dict[str, BaseTool]]: Initialized agent and dictionary of tool instances
@@ -65,6 +63,7 @@ def initialize_agent(
65
  ),
66
  "ImageVisualizerTool": lambda: ImageVisualizerTool(),
67
  "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
 
68
  }
69
 
70
  # Initialize only selected tools or all if none specified
@@ -75,9 +74,22 @@ def initialize_agent(
75
  tools_dict[tool_name] = all_tools[tool_name]()
76
 
77
  checkpointer = MemorySaver()
78
- model = ChatOpenAI(model=model, temperature=temperature, top_p=top_p, **openai_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  agent = Agent(
80
- model,
81
  tools=list(tools_dict.values()),
82
  log_tools=True,
83
  log_dir="logs",
@@ -105,18 +117,19 @@ if __name__ == "__main__":
105
  "ChestXRaySegmentationTool",
106
  "ChestXRayReportGeneratorTool",
107
  "XRayVQATool",
 
108
  # "LlavaMedTool",
109
  # "XRayPhraseGroundingTool",
110
  # "ChestXRayGeneratorTool",
111
  ]
112
 
113
- # Collect the ENV variables
114
- openai_kwargs = {}
115
- if api_key := os.getenv("OPENAI_API_KEY"):
116
- openai_kwargs["api_key"] = api_key
117
-
118
- if base_url := os.getenv("OPENAI_BASE_URL"):
119
- openai_kwargs["base_url"] = base_url
120
 
121
  agent, tools_dict = initialize_agent(
122
  "medrax/docs/system_prompts.txt",
@@ -124,10 +137,10 @@ if __name__ == "__main__":
124
  model_dir="/model-weights", # Change this to the path of the model weights
125
  temp_dir="temp", # Change this to the path of the temporary directory
126
  device="cuda", # Change this to the device you want to use
127
- model="gpt-4o", # Change this to the model you want to use, e.g. gpt-4o-mini
128
  temperature=0.7,
129
  top_p=0.95,
130
- openai_kwargs=openai_kwargs
131
  )
132
  demo = create_demo(agent, tools_dict)
133
 
 
5
  from transformers import logging
6
 
7
  from langgraph.checkpoint.memory import MemorySaver
8
+ from medrax.models import ModelFactory
 
 
9
 
10
  from interface import create_demo
11
  from medrax.agent import *
 
23
  model_dir="/model-weights",
24
  temp_dir="temp",
25
  device="cuda",
26
+ model="gpt-4o",
27
  temperature=0.7,
28
  top_p=0.95,
29
+ model_kwargs={}
30
  ):
31
  """Initialize the MedRAX agent with specified tools and configuration.
32
 
 
36
  model_dir (str, optional): Directory containing model weights. Defaults to "/model-weights".
37
  temp_dir (str, optional): Directory for temporary files. Defaults to "temp".
38
  device (str, optional): Device to run models on. Defaults to "cuda".
39
+ model (str, optional): Model to use. Defaults to "gpt-4o".
40
  temperature (float, optional): Temperature for the model. Defaults to 0.7.
41
  top_p (float, optional): Top P for the model. Defaults to 0.95.
42
+ model_kwargs (dict, optional): Additional keyword arguments for model.
43
 
44
  Returns:
45
  Tuple[Agent, Dict[str, BaseTool]]: Initialized agent and dictionary of tool instances
 
63
  ),
64
  "ImageVisualizerTool": lambda: ImageVisualizerTool(),
65
  "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
66
+ "WebBrowserTool": lambda: WebBrowserTool(),
67
  }
68
 
69
  # Initialize only selected tools or all if none specified
 
74
  tools_dict[tool_name] = all_tools[tool_name]()
75
 
76
  checkpointer = MemorySaver()
77
+
78
+ # Create the language model using the factory
79
+ try:
80
+ llm = ModelFactory.create_model(
81
+ model_name=model,
82
+ temperature=temperature,
83
+ top_p=top_p,
84
+ **model_kwargs
85
+ )
86
+ except ValueError as e:
87
+ print(f"Error creating language model: {e}")
88
+ print(f"Available model providers: {list(ModelFactory._model_providers.keys())}")
89
+ raise
90
+
91
  agent = Agent(
92
+ llm,
93
  tools=list(tools_dict.values()),
94
  log_tools=True,
95
  log_dir="logs",
 
117
  "ChestXRaySegmentationTool",
118
  "ChestXRayReportGeneratorTool",
119
  "XRayVQATool",
120
+ "WebBrowserTool", # Add the web browser tool
121
  # "LlavaMedTool",
122
  # "XRayPhraseGroundingTool",
123
  # "ChestXRayGeneratorTool",
124
  ]
125
 
126
+ # Prepare any additional model-specific kwargs
127
+ model_kwargs = {}
128
+
129
+ # Set up API keys for the web browser tool
130
+ # You'll need to set these environment variables:
131
+ # - GOOGLE_SEARCH_API_KEY: Your Google Custom Search API key
132
+ # - GOOGLE_SEARCH_ENGINE_ID: Your Google Custom Search Engine ID
133
 
134
  agent, tools_dict = initialize_agent(
135
  "medrax/docs/system_prompts.txt",
 
137
  model_dir="/model-weights", # Change this to the path of the model weights
138
  temp_dir="temp", # Change this to the path of the temporary directory
139
  device="cuda", # Change this to the device you want to use
140
+ model="gpt-4o", # Change this to the model you want to use, e.g. gpt-4o-mini, gemini-2.5-pro
141
  temperature=0.7,
142
  top_p=0.95,
143
+ model_kwargs=model_kwargs
144
  )
145
  demo = create_demo(agent, tools_dict)
146
 
medrax/models/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Model module for MedRAX."""
2
+
3
+ from .model_factory import ModelFactory
4
+
5
+ __all__ = ["ModelFactory"]
medrax/models/model_factory.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Factory for creating language model instances based on model name."""
2
+
3
+ import os
4
+ from typing import Dict, Any, Type
5
+
6
+ from langchain_core.language_models import BaseLanguageModel
7
+ from langchain_openai import ChatOpenAI
8
+ from langchain_google_genai import ChatGoogleGenerativeAI
9
+
10
+
11
class ModelFactory:
    """Factory for creating language model instances based on model name.

    A provider is selected by matching the start of the requested model name
    against the prefixes registered in ``_model_providers``. Each registry
    entry holds the LangChain model class (``"class"``), the name of the
    environment variable carrying the API key (``"env_key"``) and, optionally,
    the name of an environment variable carrying a custom base URL
    (``"base_url_key"``). Any other key in an entry is forwarded verbatim to
    the model constructor.
    """

    # Entry keys that configure the factory itself; they are never forwarded
    # to the model constructor.
    _reserved_keys = ("class", "env_key", "base_url_key")

    # Registry of model providers, keyed by model-name prefix
    _model_providers: Dict[str, Dict[str, Any]] = {
        "gpt": {
            "class": ChatOpenAI,
            "env_key": "OPENAI_API_KEY",
            "base_url_key": "OPENAI_BASE_URL"
        },
        # OpenAI also publishes aliases such as "chatgpt-4o-latest", which do
        # not start with "gpt" and would otherwise find no provider.
        "chatgpt": {
            "class": ChatOpenAI,
            "env_key": "OPENAI_API_KEY",
            "base_url_key": "OPENAI_BASE_URL"
        },
        "gemini": {
            "class": ChatGoogleGenerativeAI,
            "env_key": "GOOGLE_API_KEY"
        },
        # Add more providers with default configurations here
    }

    @classmethod
    def register_provider(cls, prefix: str, model_class: Type[BaseLanguageModel],
                          env_key: str, **kwargs) -> None:
        """Register a new model provider, replacing any existing entry for the prefix.

        Args:
            prefix (str): The prefix used to identify this model provider (e.g., 'gpt', 'gemini')
            model_class (Type[BaseLanguageModel]): The LangChain model class to use
            env_key (str): The environment variable name for the API key
            **kwargs: Additional provider-specific configuration. ``base_url_key`` names an
                environment variable holding a base URL; every other key is passed to the
                model constructor when a model is created.
        """
        cls._model_providers[prefix] = {
            "class": model_class,
            "env_key": env_key,
            **kwargs
        }

    @classmethod
    def create_model(cls, model_name: str, temperature: float = 0.7,
                     top_p: float = 0.95, **kwargs) -> BaseLanguageModel:
        """Create and return an instance of the appropriate language model.

        Args:
            model_name (str): Name of the model to create (e.g., 'gpt-4o', 'gemini-2.5-pro')
            temperature (float, optional): Temperature parameter. Defaults to 0.7.
            top_p (float, optional): Top-p sampling parameter. Defaults to 0.95.
            **kwargs: Additional model-specific parameters. These take precedence over
                values derived from the environment or the provider registry, so an
                explicit ``api_key`` or ``base_url`` may be supplied here.

        Returns:
            BaseLanguageModel: An initialized language model instance

        Raises:
            ValueError: If no provider is found for the given model name

        Note:
            A missing API key does not raise here; a warning is printed and the
            model class is left to deal with authentication itself.
        """
        # Pick the most specific (longest) matching prefix so the result does
        # not depend on registration order when prefixes overlap.
        provider_prefix = max(
            (prefix for prefix in cls._model_providers if model_name.startswith(prefix)),
            key=len,
            default=None,
        )

        if not provider_prefix:
            raise ValueError(
                f"No provider found for model: {model_name}. "
                f"Registered providers are for: {list(cls._model_providers.keys())}"
            )

        provider = cls._model_providers[provider_prefix]
        model_class = provider["class"]
        env_key = provider["env_key"]

        # Set up provider-specific kwargs
        provider_kwargs: Dict[str, Any] = {}

        # Handle API key; a variable that is set but empty counts as missing
        api_key = os.environ.get(env_key)
        if api_key:
            provider_kwargs["api_key"] = api_key
        elif "api_key" not in kwargs:
            # Warn but don't fail - the model class might handle missing API keys differently
            print(f"Warning: Environment variable {env_key} not found. Authentication may fail.")

        # Check for base_url if applicable
        base_url_key = provider.get("base_url_key")
        base_url = os.environ.get(base_url_key) if base_url_key else None
        if base_url:
            provider_kwargs["base_url"] = base_url

        # Merge with any additional provider-specific settings from the registry
        provider_kwargs.update(
            (k, v) for k, v in provider.items() if k not in cls._reserved_keys
        )

        # Merge into one dict first: splatting two dicts that share a key
        # (e.g. a caller-supplied api_key) would raise TypeError.
        constructor_kwargs = {**provider_kwargs, **kwargs}

        # Create and return the model instance
        return model_class(
            model=model_name,
            temperature=temperature,
            top_p=top_p,
            **constructor_kwargs
        )

    @classmethod
    def list_providers(cls) -> Dict[str, Dict[str, Any]]:
        """List all registered model providers.

        Returns:
            Dict[str, Dict[str, Any]]: Dictionary of registered providers and their
            configurations, with the model class itself left out
        """
        # Return a copy to prevent accidental modification
        return {
            prefix: {key: value for key, value in config.items() if key != "class"}
            for prefix, config in cls._model_providers.items()
        }
medrax/tools/__init__.py CHANGED
@@ -9,3 +9,4 @@ from .grounding import *
9
  from .generation import *
10
  from .dicom import *
11
  from .utils import *
 
 
9
  from .generation import *
10
  from .dicom import *
11
  from .utils import *
12
+ from .web_browser import *
medrax/tools/web_browser.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Web browser tool for MedRAX2.
2
+
3
+ This module implements a web browsing tool for MedRAX2, allowing the agent
4
+ to search the web, visit URLs, and extract information from web pages.
5
+ """
6
+
7
+ import os
8
+ import re
9
+ import json
10
+ from typing import Dict, Optional, Any
11
+ from urllib.parse import urlparse
12
+
13
+ import requests
14
+ from bs4 import BeautifulSoup
15
+ from langchain_core.tools import BaseTool
16
+ from pydantic import BaseModel, Field
17
+
18
+
19
# Limits applied to the payload handed back to the agent.
_MAX_CONTENT_CHARS = 10000
_MAX_LINKS = 10
_MAX_IMAGES = 3


class SearchQuerySchema(BaseModel):
    """Schema for web search queries."""

    query: str = Field(..., description="The search query string")


class VisitUrlSchema(BaseModel):
    """Schema for URL visits."""

    url: str = Field(..., description="The URL to visit")


class WebBrowserSchema(BaseModel):
    """Combined schema for web browser tool."""

    query: str = Field("", description="The search query (leave empty if visiting a URL)")
    url: str = Field("", description="The URL to visit (leave empty if performing a search)")


class WebBrowserTool(BaseTool):
    """Tool for browsing the web, searching for information, and visiting URLs.

    This tool provides the agent with internet browsing capabilities, including:
    1. Performing web searches using the Google Custom Search API
    2. Visiting specific URLs and extracting their text content
    3. Listing the links and images found on a visited page

    Every operation returns a dict; failures are reported under an ``"error"``
    key instead of being raised.
    """

    name: str = "WebBrowserTool"
    description: str = "Search the web for information or visit specific URLs to retrieve content"
    # Declared as a field (not a method) so it overrides BaseTool.args_schema
    # and the framework can read the argument schema from it.
    args_schema: type[BaseModel] = WebBrowserSchema
    search_api_key: Optional[str] = None
    search_engine_id: Optional[str] = None
    user_agent: str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
    max_results: int = 5

    def __init__(self, search_api_key: Optional[str] = None, search_engine_id: Optional[str] = None, **kwargs):
        """Initialize the web browser tool.

        Args:
            search_api_key: Google Custom Search API key (optional); falls back to the
                GOOGLE_SEARCH_API_KEY environment variable
            search_engine_id: Google Custom Search Engine ID (optional); falls back to the
                GOOGLE_SEARCH_ENGINE_ID environment variable
            **kwargs: Additional keyword arguments forwarded to BaseTool
        """
        # Resolve the credentials before construction so they are set as
        # ordinary fields in one step rather than assigned afterwards.
        super().__init__(
            search_api_key=search_api_key or os.environ.get("GOOGLE_SEARCH_API_KEY"),
            search_engine_id=search_engine_id or os.environ.get("GOOGLE_SEARCH_ENGINE_ID"),
            **kwargs,
        )

    def search_web(self, query: str) -> Dict[str, Any]:
        """Search the web using Google Custom Search API.

        Args:
            query: The search query string

        Returns:
            Dict with ``results`` (list of title/link/snippet/source dicts) and
            ``message``, or a dict with a single ``error`` key on failure
        """
        if not self.search_api_key or not self.search_engine_id:
            return {
                "error": "Search API key or engine ID not configured. Please set GOOGLE_SEARCH_API_KEY and GOOGLE_SEARCH_ENGINE_ID environment variables."
            }

        url = "https://www.googleapis.com/customsearch/v1"
        params = {
            "key": self.search_api_key,
            "cx": self.search_engine_id,
            "q": query,
            "num": self.max_results
        }

        try:
            response = requests.get(url, params=params, timeout=10)
            response.raise_for_status()
            results = response.json()

            if "items" not in results:
                return {"results": [], "message": "No results found"}

            formatted_results = [
                {
                    "title": item.get("title"),
                    "link": item.get("link"),
                    "snippet": item.get("snippet"),
                    "source": item.get("displayLink")
                }
                for item in results["items"]
            ]

            return {
                "results": formatted_results,
                "message": f"Found {len(formatted_results)} results for query: {query}"
            }

        except Exception as e:
            # Tool boundary: report the failure to the agent instead of raising.
            return {"error": f"Search failed: {str(e)}"}

    def visit_url(self, url: str) -> Dict[str, Any]:
        """Visit a URL and extract its content.

        Args:
            url: The absolute http(s) URL to visit

        Returns:
            Dict containing the page title, text content (truncated to
            10000 characters), up to 10 links, up to 3 image URLs and response
            metadata, or a dict with a single ``error`` key on failure
        """
        # Imported locally because the module only imports ``urlparse`` at file level.
        from urllib.parse import urljoin

        try:
            # Validate URL: only absolute http(s) URLs can be fetched.
            # NOTE(review): the URL comes from the model; hosts are not
            # restricted here, so internal addresses remain reachable.
            parsed_url = urlparse(url)
            if parsed_url.scheme not in ("http", "https") or not parsed_url.netloc:
                return {"error": f"Invalid URL: {url}"}

            headers = {"User-Agent": self.user_agent}
            response = requests.get(url, headers=headers, timeout=15)
            response.raise_for_status()

            # Parse the HTML content
            soup = BeautifulSoup(response.text, "html.parser")

            # Extract title; ``.string`` is None for an empty or nested <title>
            raw_title = soup.title.string if soup.title else None
            title = (str(raw_title).strip() if raw_title else "") or "No title"

            # Extract main content (remove scripts, styles, etc.)
            for tag in soup(["script", "style", "meta", "noscript"]):
                tag.extract()

            # Get text content
            text_content = soup.get_text(separator="\n", strip=True)
            # Clean up whitespace
            text_content = re.sub(r'\n+', '\n', text_content)
            text_content = re.sub(r' +', ' ', text_content)

            # Resolve relative references against the URL actually served
            # (after redirects); urljoin also handles "//host/path" and
            # document-relative paths such as "page.html".
            base_url = response.url or url

            # Extract links
            links = []
            for anchor in soup.find_all("a", href=True):
                raw_href = anchor["href"].strip()
                # Skip empty targets and same-page fragment links
                if not raw_href or raw_href.startswith("#"):
                    continue
                href = urljoin(base_url, raw_href)
                # Drops mailto:, javascript:, data: and similar targets
                if href.startswith(("http://", "https://")):
                    links.append({
                        "text": anchor.get_text(strip=True) or href,
                        "url": href
                    })
                    if len(links) >= _MAX_LINKS:
                        break

            # Extract images (only the first 3 <img> tags are considered)
            images = []
            for img in soup.find_all("img", src=True)[:_MAX_IMAGES]:
                src = urljoin(base_url, img["src"].strip())
                if src.startswith(("http://", "https://")):
                    images.append(src)

            return {
                "title": title,
                "content": text_content[:_MAX_CONTENT_CHARS],
                "url": url,
                "links": links,
                "images": images,
                "content_type": response.headers.get("Content-Type", ""),
                "content_length": len(text_content),
                "truncated": len(text_content) > _MAX_CONTENT_CHARS
            }

        except Exception as e:
            # Tool boundary: report the failure to the agent instead of raising.
            return {"error": f"Failed to visit {url}: {str(e)}"}

    async def _arun(self, query: str = "", url: str = "") -> str:
        """Run the tool asynchronously and return the result as a JSON string."""
        # Imported locally because the module does not import asyncio at file level.
        import asyncio

        # The underlying HTTP calls are blocking; run them in a worker thread
        # so the event loop is not stalled while waiting on the network.
        result = await asyncio.to_thread(self._run, query=query, url=url)
        return json.dumps(result)

    def _run(self, query: str = "", url: str = "") -> Dict[str, Any]:
        """Run the web browser tool.

        A URL takes precedence over a query when both are given.

        Args:
            query: Search query (if searching)
            url: URL to visit (if visiting a specific page)

        Returns:
            Dict containing the results
        """
        # Treat None and whitespace-only values as "not provided"
        url = (url or "").strip()
        query = (query or "").strip()

        if url:
            return self.visit_url(url)
        if query:
            return self.search_web(query)
        return {"error": "Please provide either a search query or a URL to visit"}