samwell committed on
Commit
27f1dea
·
1 Parent(s): 1f83b1b

Add NVIDIA NV-Reason-CXR tool for expert chest X-ray analysis

Browse files
Files changed (3) hide show
  1. app.py +11 -0
  2. medrax/tools/__init__.py +1 -0
  3. medrax/tools/nv_reason_cxr.py +202 -0
app.py CHANGED
@@ -34,6 +34,17 @@ tools = []
34
 
35
  if device == "cuda":
36
  # Load GPU-based tools
 
 
 
 
 
 
 
 
 
 
 
37
  try:
38
  from medrax.tools import XRayPhraseGroundingTool
39
  grounding_tool = XRayPhraseGroundingTool(
 
34
 
35
  if device == "cuda":
36
  # Load GPU-based tools
37
+ try:
38
+ from medrax.tools import NVReasonCXRTool
39
+ nv_reason_tool = NVReasonCXRTool(
40
+ device=device,
41
+ load_in_4bit=True
42
+ )
43
+ tools.append(nv_reason_tool)
44
+ print("✓ Loaded NV-Reason-CXR tool")
45
+ except Exception as e:
46
+ print(f"✗ Failed to load NV-Reason-CXR tool: {e}")
47
+
48
  try:
49
  from medrax.tools import XRayPhraseGroundingTool
50
  grounding_tool = XRayPhraseGroundingTool(
medrax/tools/__init__.py CHANGED
@@ -11,3 +11,4 @@ from .utils import *
11
  from .rag import *
12
  from .browsing import *
13
  from .python_tool import *
 
 
11
  from .rag import *
12
  from .browsing import *
13
  from .python_tool import *
14
+ from .nv_reason_cxr import *
medrax/tools/nv_reason_cxr.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """NVIDIA NV-Reason-CXR tool for expert chest X-ray analysis."""
2
+ from typing import Dict, Optional, Tuple, Type, Any
3
+ from pathlib import Path
4
+ import torch
5
+ from PIL import Image
6
+ from pydantic import BaseModel, Field
7
+ from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
8
+
9
+ from langchain_core.callbacks import (
10
+ AsyncCallbackManagerForToolRun,
11
+ CallbackManagerForToolRun,
12
+ )
13
+ from langchain_core.tools import BaseTool
14
+
15
+
16
class NVReasonCXRInput(BaseModel):
    """Input schema for the NV-Reason-CXR Tool."""

    # NOTE: the ``description`` values below are runtime strings handed to
    # pydantic's ``Field`` (not comments), so they are reproduced verbatim.

    # Required argument: ``...`` (Ellipsis) marks the field as having no default.
    image_path: str = Field(
        ...,
        description="Path to the chest X-ray image file (JPG or PNG)",
    )

    # Free-form instruction forwarded to the model as the user message text.
    query: str = Field(
        description=(
            "Question or instruction for analyzing the X-ray "
            "(e.g., 'Find abnormalities and support devices', "
            "'Provide differential diagnoses', 'Write a structured report')"
        ),
        default="Find abnormalities and support devices.",
    )

    # Upper bound on the length of the generated answer, in tokens.
    max_new_tokens: int = Field(
        description="Maximum number of tokens to generate in response",
        default=2048,
    )
31
+
32
+
33
class NVReasonCXRTool(BaseTool):
    """Tool for expert chest X-ray analysis using NVIDIA's NV-Reason-CXR model.

    This tool uses NVIDIA's specialized NV-Reason-CXR-3B model for detailed chest X-ray
    analysis, including abnormality detection, support device identification, differential
    diagnoses, and structured report generation.

    The model and its processor are loaded eagerly in ``__init__``. ``_run``
    catches ``Exception`` and reports failures through the returned
    ``(output, metadata)`` dictionaries instead of raising.
    """

    name: str = "nv_reason_cxr_analysis"
    description: str = (
        "Expert chest X-ray analysis using NVIDIA's specialized NV-Reason-CXR model. "
        "This tool provides detailed medical reasoning and can: "
        "1) Detect abnormalities and support devices in chest X-rays "
        "2) Provide differential diagnoses "
        "3) Generate structured radiology reports "
        "4) Answer specific questions about chest X-ray findings. "
        "Use this for comprehensive chest X-ray interpretation. "
        "Example input: {'image_path': '/path/to/xray.jpg', 'query': 'Find abnormalities and support devices'}"
    )
    args_schema: Type[BaseModel] = NVReasonCXRInput

    # Populated in ``__init__``; typed ``Any`` because they are transformers objects.
    model: Any = None
    processor: Any = None
    device: str = "cuda"
    # Identifier the weights were actually loaded from; echoed in result
    # metadata so it stays truthful when a non-default ``model_path`` is used.
    checkpoint: str = "nvidia/NV-Reason-CXR-3B"

    def __init__(
        self,
        model_path: str = "nvidia/NV-Reason-CXR-3B",
        cache_dir: Optional[str] = None,
        load_in_4bit: bool = True,
        device: Optional[str] = "cuda",
    ):
        """Initialize the NV-Reason-CXR Tool.

        Args:
            model_path: Model identifier or path passed to ``from_pretrained``.
            cache_dir: Optional cache directory passed to ``from_pretrained``.
            load_in_4bit: When True, load the model with a 4-bit NF4
                ``BitsAndBytesConfig``; otherwise no quantization config is used.
            device: Value used as ``device_map`` for the model and as the
                target of ``.to(...)`` for the prepared inputs.
        """
        super().__init__()
        self.device = device
        self.checkpoint = model_path

        # Setup quantization config
        if load_in_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
        else:
            quantization_config = None

        # Load model
        print(f"Loading NV-Reason-CXR model from {model_path}...")
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_path,
            device_map=self.device,
            cache_dir=cache_dir,
            torch_dtype=torch.bfloat16,
            quantization_config=quantization_config,
            trust_remote_code=True,
        ).eval()

        self.processor = AutoProcessor.from_pretrained(
            model_path,
            cache_dir=cache_dir,
            trust_remote_code=True,
            use_fast=True,
        )

        print("✓ NV-Reason-CXR model loaded successfully")

    def _run(
        self,
        image_path: str,
        query: str = "Find abnormalities and support devices.",
        max_new_tokens: int = 2048,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Analyze a chest X-ray image using NV-Reason-CXR.

        Args:
            image_path: Path to the chest X-ray image file
            query: Question or instruction for analysis
            max_new_tokens: Maximum tokens to generate
            run_manager: Optional callback manager (accepted but not used here)

        Returns:
            Tuple[Dict, Dict]: Output dictionary and metadata dictionary. On
            success the output holds ``analysis`` and ``query`` and the
            metadata ``status`` is ``"completed"``. On any ``Exception`` the
            output holds ``error`` with ``analysis`` set to None and the
            metadata ``status`` is ``"failed"``.
        """
        try:
            # Image.open is lazy and keeps the file open. Decode inside a
            # context manager so the handle is released deterministically;
            # convert("RGB") forces the pixel data to load and returns a
            # standalone RGB image that remains valid after the file closes.
            with Image.open(image_path) as source_image:
                image = source_image.convert("RGB")

            # Prepare messages in chat format
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": query},
                    ],
                }
            ]

            # Apply chat template
            prompt = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
            )

            # Prepare inputs
            inputs = self.processor(
                text=prompt,
                images=[image],
                return_tensors="pt",
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Generate response
            with torch.inference_mode():
                output_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,  # Deterministic for medical analysis
                    pad_token_id=self.processor.tokenizer.eos_token_id,
                )

            # Decode response: drop the leading prompt tokens so only the
            # newly generated continuation is decoded.
            prompt_length = inputs["input_ids"].shape[-1]
            generated_ids = output_ids[0][prompt_length:]
            response = self.processor.decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

            output = {
                "analysis": response,
                "query": query,
            }

            metadata = {
                "image_path": image_path,
                "model": self.checkpoint,
                "device": str(self.device),
                "tokens_generated": len(generated_ids),
                "status": "completed",
            }

            return output, metadata

        except Exception as e:
            # Deliberate best-effort boundary: the failure is returned to the
            # caller as data rather than raised.
            output = {
                "error": str(e),
                "analysis": None,
            }
            metadata = {
                "image_path": image_path,
                "status": "failed",
                "error_details": str(e),
            }
            return output, metadata

    async def _arun(
        self,
        image_path: str,
        query: str = "Find abnormalities and support devices.",
        max_new_tokens: int = 2048,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Asynchronous version of _run.

        Delegates directly to the synchronous ``_run``, so the call runs to
        completion before control returns to the event loop. The async
        ``run_manager`` is not forwarded: ``_run`` does not use its callback
        manager and its parameter is typed for the synchronous variant.
        """
        return self._run(image_path, query, max_new_tokens)