File size: 13,242 Bytes
b6fb0da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9dc78d7
 
 
b6fb0da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f83b1b
3d05726
37bdbfa
 
 
 
 
 
 
 
 
 
 
 
 
 
27f1dea
a782449
 
 
 
 
 
 
 
 
 
 
 
b6fb0da
1f83b1b
b8918e3
 
1f83b1b
 
 
 
 
 
 
 
 
 
b8918e3
 
55abc5d
1f83b1b
 
 
 
 
 
 
b8918e3
 
55abc5d
1f83b1b
 
 
 
 
 
f372684
 
 
 
 
 
 
 
 
 
 
1f83b1b
 
b8918e3
 
1f83b1b
 
 
 
 
 
b8918e3
 
1f83b1b
 
 
 
 
b6fb0da
 
 
8bd044e
 
 
b6fb0da
 
 
 
 
 
86d018b
 
11b0dac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2960ee
11b0dac
 
 
 
 
b6fb0da
9dc78d7
 
b2960ee
b6fb0da
 
 
9dc78d7
 
 
b2960ee
3a18164
 
 
 
9dc78d7
 
 
3a18164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9dc78d7
 
3a18164
 
9dc78d7
 
 
3a18164
 
 
9dc78d7
b6fb0da
 
9dc78d7
 
 
 
 
 
 
 
 
 
b6fb0da
9dc78d7
b6fb0da
 
86d018b
 
 
 
7d28f28
879e071
7d28f28
86d018b
 
 
 
a5cf627
 
 
86d018b
d2991ca
 
 
 
 
 
b2960ee
 
 
 
d2991ca
 
 
 
 
879e071
86d018b
879e071
 
8bd044e
b6fb0da
11b0dac
 
 
 
 
 
 
 
 
227178d
 
30b53f2
227178d
30b53f2
227178d
30b53f2
227178d
 
30b53f2
227178d
30b53f2
227178d
 
30b53f2
 
 
227178d
 
 
 
30b53f2
879e071
8edb8da
 
 
 
11b0dac
 
 
8edb8da
 
 
 
11b0dac
 
879e071
e7caa1e
cd8c9a2
 
 
4167f4b
 
 
 
 
 
 
 
 
 
79ed555
879e071
8edb8da
 
 
 
 
879e071
8edb8da
879e071
 
 
 
 
 
 
 
b6fb0da
 
0b51074
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
"""
MedRax2 Gradio Interface for Hugging Face Spaces
Simple standalone version for deployment
"""
import os
os.environ["OMP_NUM_THREADS"] = "1"

import warnings
warnings.filterwarnings("ignore")

from huggingface_hub import login
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    print("βœ“ Logged in to HuggingFace")

import gradio as gr
from dotenv import load_dotenv
import torch
from PIL import Image
import base64
from io import BytesIO

load_dotenv()

from langgraph.checkpoint.memory import MemorySaver
from medrax.models import ModelFactory
from medrax.agent import Agent
from medrax.utils import load_prompts_from_file

os.makedirs("temp", exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Tools are accumulated best-effort: each loader is wrapped in try/except so
# one missing/broken model does not prevent the rest of the app from starting.
tools = []

if device == "cuda":
    # Load GPU-based tools

    # NV-Reason-CXR - Disabled due to transformers version conflict
    # Requires transformers 4.56.0, but MAIRA-2 (grounding) requires <4.52
    # Prioritizing grounding tool for visualization
    # try:
    #     from medrax.tools import NVReasonCXRTool
    #     nv_reason_tool = NVReasonCXRTool(
    #         device=device,
    #         load_in_4bit=False
    #     )
    #     tools.append(nv_reason_tool)
    #     print("βœ“ Loaded NV-Reason-CXR tool")
    # except Exception as e:
    #     print(f"βœ— Failed to load NV-Reason-CXR tool: {e}")
    print("βŠ— NV-Reason-CXR tool disabled (transformers version conflict)")

    # MAIRA-2 Grounding - Re-enabled for L40S (48GB VRAM)
    try:
        from medrax.tools import XRayPhraseGroundingTool
        grounding_tool = XRayPhraseGroundingTool(
            device=device,
            temp_dir="temp",
            load_in_4bit=False  # Quantization disabled due to compatibility
        )
        tools.append(grounding_tool)
        print("βœ“ Loaded grounding tool")
    except Exception as e:
        print(f"βœ— Failed to load grounding tool: {e}")

    # CheXagent visual question answering (4-bit to fit alongside MAIRA-2).
    try:
        from medrax.tools.vqa import CheXagentXRayVQATool
        vqa_tool = CheXagentXRayVQATool(
            device=device,
            temp_dir="temp",
            load_in_4bit=True
        )
        tools.append(vqa_tool)
        print("βœ“ Loaded VQA tool")
    except Exception as e:
        print(f"βœ— Failed to load VQA tool: {e}")

    # TorchXRayVision pathology classifier.
    try:
        from medrax.tools.classification import TorchXRayVisionClassifierTool
        classification_tool = TorchXRayVisionClassifierTool(
            device=device
        )
        tools.append(classification_tool)
        print("βœ“ Loaded classification tool")
    except Exception as e:
        print(f"βœ— Failed to load classification tool: {e}")

    # Radiology report generation.
    try:
        from medrax.tools.report_generation import ChestXRayReportGeneratorTool
        report_tool = ChestXRayReportGeneratorTool(
            device=device
        )
        tools.append(report_tool)
        print("βœ“ Loaded report generation tool")
    except Exception as e:
        print(f"βœ— Failed to load report generation tool: {e}")

    # Anatomical segmentation; writes overlays into temp/ for the viz panel.
    try:
        from medrax.tools.segmentation import ChestXRaySegmentationTool
        segmentation_tool = ChestXRaySegmentationTool(
            device=device,
            temp_dir="temp"
        )
        tools.append(segmentation_tool)
        print("βœ“ Loaded segmentation tool")
    except Exception as e:
        print(f"βœ— Failed to load segmentation tool: {e}")

# Load non-GPU tools
try:
    from medrax.tools.dicom import DicomProcessorTool
    dicom_tool = DicomProcessorTool(temp_dir="temp")
    tools.append(dicom_tool)
    print("βœ“ Loaded DICOM tool")
except Exception as e:
    print(f"βœ— Failed to load DICOM tool: {e}")

try:
    from medrax.tools.browsing import WebBrowserTool
    browsing_tool = WebBrowserTool()
    tools.append(browsing_tool)
    print("βœ“ Loaded web browsing tool")
except Exception as e:
    print(f"βœ— Failed to load web browsing tool: {e}")

# In-memory LangGraph checkpointer: conversation state persists per thread_id
# for the lifetime of the process (lost on restart).
checkpointer = MemorySaver()

# Orchestrator LLM that routes between the tools above.
llm = ModelFactory.create_model(
    model_name="gemini-2.0-flash",
    temperature=0.7,
    max_tokens=5000
)

# System prompts keyed by role name (e.g. MEDICAL_ASSISTANT, SOCRATIC_TUTOR).
prompts = load_prompts_from_file("medrax/docs/system_prompts.txt")

print(f"Tools loaded: {len(tools)}")

import glob

# Store agents for each mode to avoid recreating them
agents_cache = {}

def get_or_create_agent(mode):
    """Return the cached Agent for *mode*, constructing it on first use.

    Agents are memoized in the module-level ``agents_cache`` so the
    expensive Agent construction happens at most once per mode.
    """
    cached = agents_cache.get(mode)
    if cached is not None:
        return cached

    # Pick the system prompt matching the interaction mode, with a generic
    # fallback in case the prompt file lacks the named entry.
    prompt_key, fallback = (
        ("SOCRATIC_TUTOR", "You are a Socratic medical educator.")
        if mode == "socratic"
        else ("MEDICAL_ASSISTANT", "You are a helpful medical imaging assistant.")
    )
    system_prompt = prompts.get(prompt_key, fallback)

    agent = Agent(
        llm,
        tools=tools,
        system_prompt=system_prompt,
        checkpointer=checkpointer,
        mode=mode,  # the Agent itself also needs to know the mode
    )
    agents_cache[mode] = agent
    return agent

def chat(message, history, mode, uploaded_image_path=None):
    """Run one conversational turn through the mode-specific agent.

    Args:
        message: Either a plain string or a Gradio multimodal dict with
            ``text`` and ``files`` keys.
        history: Gradio chat history (unused here; the agent keeps its own
            state via the checkpointer, keyed by a per-mode thread id).
        mode: "socratic" or "assistant" — selects which cached agent to use.
        uploaded_image_path: Optional path of a previously uploaded image,
            shown as the visualization fallback when this turn carries no
            new upload and the tools produce no visualization.

    Returns:
        Tuple ``(assistant_message, viz_image)`` where ``viz_image`` is a
        file path suitable for a gr.Image component, or None.
    """
    config = {"configurable": {"thread_id": f"thread_{mode}"}}

    # Get or create the appropriate agent
    agent = get_or_create_agent(mode)

    # Handle multimodal input - Gemini 2.0 Flash supports vision
    image_content = None
    # Fix: seed the "current upload" with the caller-persisted path so the
    # previously uploaded image is still shown for follow-up turns that
    # attach no new file (previously this was hard-coded to None and the
    # uploaded_image_path parameter was dead).
    current_upload = uploaded_image_path
    if isinstance(message, dict):
        text = message.get("text", "")
        files = message.get("files", [])

        if files and len(files) > 0:
            image_path = files[0]
            current_upload = image_path  # Store for visualization

            # Check if it's a DICOM file
            is_dicom = image_path.lower().endswith(('.dcm', '.dicom'))

            # Store image path for tools to use
            # LangChain Google GenAI expects images as base64 or PIL
            try:
                if is_dicom:
                    # DICOM files need to be converted first
                    # We'll just pass the path and let the agent handle it
                    text = f"[DICOM file uploaded: {image_path}]\n\n{text}"
                    print(f"DICOM file detected: {image_path}")
                else:
                    # Open and encode image for Gemini
                    with Image.open(image_path) as img:
                        # Convert to RGB if needed
                        if img.mode != "RGB":
                            img = img.convert("RGB")

                        # Resize if too large (max 4096x4096 for Gemini)
                        max_size = 4096
                        if img.width > max_size or img.height > max_size:
                            img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

                        # Store as bytes for LangChain
                        buffered = BytesIO()
                        img.save(buffered, format="PNG")
                        img_bytes = buffered.getvalue()
                        img_b64 = base64.b64encode(img_bytes).decode()

                        # Create multimodal content for Gemini
                        # Format: [{"type": "text", "text": "..."}, {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]
                        image_content = {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{img_b64}"
                            }
                        }

                        # Include image path in text for tools to use
                        text = f"[Image: {image_path}]\n\n{text}"

            except Exception as e:
                # Best-effort: surface the failure to the model in-band rather
                # than crashing the turn.
                print(f"Error processing image: {e}")
                import traceback
                traceback.print_exc()
                text = f"[Failed to load image: {image_path}. Error: {str(e)}]\n\n{text}"

        message = text

    # Create message content - multimodal if image exists
    if image_content:
        # For Gemini multimodal: pass list of content parts
        user_message = [
            {"type": "text", "text": message},
            image_content
        ]
    else:
        user_message = message

    response = agent.workflow.invoke(
        {"messages": [("user", user_message)]},
        config=config
    )

    # Extract text response
    assistant_message = response["messages"][-1].content

    # Check for visualization images (grounding or segmentation)
    viz_image = None
    viz_files = glob.glob("temp/grounding_*.png") + glob.glob("temp/segmentation_*.png")
    if viz_files:
        # Get the most recent visualization
        viz_files.sort(key=os.path.getmtime, reverse=True)
        latest_viz = viz_files[0]
        # Return the file path directly - Gradio can handle it
        if os.path.exists(latest_viz):
            viz_image = latest_viz

            # If assistant message is empty but we have a visualization, provide a default message
            if not assistant_message or assistant_message.strip() == "":
                if "segmentation" in latest_viz:
                    assistant_message = "I've segmented the requested anatomical structures. The visualization is shown on the right."
                elif "grounding" in latest_viz:
                    assistant_message = "I've highlighted the requested regions. The visualization is shown on the right."
    else:
        # No visualization generated - show the uploaded image as reference
        if current_upload:
            viz_image = current_upload

    # Final fallback for empty messages
    if not assistant_message or assistant_message.strip() == "":
        assistant_message = "I processed your request. Please check the visualization panel on the right."

    return assistant_message, viz_image

# Custom interface with image output
with gr.Blocks() as demo:
    gr.Markdown(f"# MedRAX2 - Medical AI Assistant\n**Device:** {device} | **Tools:** {len(tools)} loaded | **Orchestrator:** Gemini 2.0 Flash")

    # Add mode toggle at the top
    with gr.Row():
        mode_toggle = gr.Radio(
            ["Assistant Mode", "Tutor Mode"],
            value="Assistant Mode",
            label="Interaction Mode",
            info="Assistant Mode: Direct answers | Tutor Mode: Socratic guidance through questions"
        )

    # Side-by-side layout: Chat on left, Visualization on right
    with gr.Row():
        # Left column: Chat interface (unified chat + message box)
        with gr.Column(scale=2):
            # Chatbot with reduced height to leave room for message box
            try:
                chatbot = gr.Chatbot(type="messages", height=520, show_label=False)
            except TypeError:
                # Fallback for older Gradio versions
                chatbot = gr.Chatbot(height=520, show_label=False)

            # Message box directly below chatbot (no gap)
            msg = gr.MultimodalTextbox(
                placeholder="Upload an X-ray image (JPG, PNG, DICOM) and ask a question...",
                file_types=["image", ".dcm", ".dicom", ".DCM", ".DICOM"],
                show_label=False,
                container=False
            )

        # Right column: Visualization
        with gr.Column(scale=1):
            viz_output = gr.Image(label="Visualization", height=600)

    # State to persist the uploaded image across interactions
    # (holds the file path of the most recent upload, or None).
    current_image_state = gr.State(None)
    def respond(message, chat_history, mode_selection, current_image):
        """Gradio submit handler for one turn.

        Args:
            message: MultimodalTextbox value (dict with text/files) or str.
            chat_history: Current messages-format chat history (may be None).
            mode_selection: Radio value ("Assistant Mode" / "Tutor Mode").
            current_image: gr.State value — path of the last uploaded image.

        Returns:
            Tuple clearing the textbox and updating (chatbot, viz_output,
            current_image_state).
        """
        # Convert mode selection to internal mode string
        mode = "socratic" if mode_selection == "Tutor Mode" else "assistant"

        # Track uploaded image - update state when new image is uploaded
        if isinstance(message, dict) and message.get("files"):
            current_image = message["files"][0]

        # Fix: forward the persisted upload so chat() can fall back to it when
        # this turn carries no new file (previously the state was never passed
        # through and chat()'s uploaded_image_path parameter went unused).
        bot_message, viz_image = chat(message, chat_history, mode, current_image)

        # Initialize chat history if None
        if chat_history is None:
            chat_history = []

        # Extract text from multimodal message
        if isinstance(message, dict):
            user_text = message.get("text", "")
            if message.get("files"):
                user_text = f"[Image uploaded] {user_text}"
        else:
            user_text = message

        # Add BOTH user message and assistant response to create proper chat flow
        chat_history.append({"role": "user", "content": user_text})
        chat_history.append({"role": "assistant", "content": bot_message})

        # If no visualization was generated, show the current uploaded image as reference
        if viz_image is None and current_image is not None:
            viz_image = current_image

        return "", chat_history, viz_image, current_image

    # Wire Enter/submit: clears the textbox, appends to the chat, updates the
    # visualization panel, and persists the uploaded-image path in state.
    msg.submit(respond, [msg, chatbot, mode_toggle, current_image_state], [msg, chatbot, viz_output, current_image_state])

    # Clickable example prompts (no files attached; users add their own image).
    gr.Examples(
        examples=[
            [{"text": "What do you see in this X-ray?", "files": []}],
            [{"text": "Can you show me where exactly using grounding?", "files": []}],
        ],
        inputs=msg,
    )

if __name__ == "__main__":
    # 0.0.0.0:7860 is the standard bind for Hugging Face Spaces containers.
    demo.launch(server_name="0.0.0.0", server_port=7860)