"""
MedRax2 Gradio Interface for Hugging Face Spaces
Simple standalone version for deployment
"""
import glob
import os
import time
import warnings

os.environ["OMP_NUM_THREADS"] = "1"
warnings.filterwarnings("ignore")
from huggingface_hub import login
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    print("✓ Logged in to HuggingFace")
import gradio as gr
from dotenv import load_dotenv
import torch
from PIL import Image
import base64
from io import BytesIO
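# Local runs read secrets from a .env file; on Hugging Face Spaces they are
# injected as environment variables instead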
load_dotenv()
from langgraph.checkpoint.memory import MemorySaver
from medrax.models import ModelFactory
from medrax.agent import Agent
from medrax.utils import load_prompts_from_file
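# temp/ holds intermediate artifacts: converted DICOM images and the
# grounding/segmentation overlays that the UI picks up after each turn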
os.makedirs("temp", exist_ok=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
tools = []
if device == "cuda":
    # Load GPU-based tools

    # NV-Reason-CXR is disabled due to a transformers version conflict:
    # it requires transformers 4.56.0, but MAIRA-2 (grounding) requires <4.52.
    # The grounding tool is prioritized because it provides visualization.
    # try:
    #     from medrax.tools import NVReasonCXRTool
    #     nv_reason_tool = NVReasonCXRTool(device=device, load_in_4bit=False)
    #     tools.append(nv_reason_tool)
    #     print("✓ Loaded NV-Reason-CXR tool")
    # except Exception as e:
    #     print(f"✗ Failed to load NV-Reason-CXR tool: {e}")
    print("✗ NV-Reason-CXR tool disabled (transformers version conflict)")
    # MAIRA-2 grounding - re-enabled for L40S hardware (48 GB VRAM)
    try:
        from medrax.tools import XRayPhraseGroundingTool

        grounding_tool = XRayPhraseGroundingTool(
            device=device,
            temp_dir="temp",
            load_in_4bit=False,  # Quantization disabled due to compatibility issues
        )
        tools.append(grounding_tool)
        print("✓ Loaded grounding tool")
    except Exception as e:
        print(f"✗ Failed to load grounding tool: {e}")

    try:
        from medrax.tools.vqa import CheXagentXRayVQATool

        vqa_tool = CheXagentXRayVQATool(device=device, temp_dir="temp", load_in_4bit=True)
        tools.append(vqa_tool)
        print("✓ Loaded VQA tool")
    except Exception as e:
        print(f"✗ Failed to load VQA tool: {e}")

    try:
        from medrax.tools.classification import TorchXRayVisionClassifierTool

        classification_tool = TorchXRayVisionClassifierTool(device=device)
        tools.append(classification_tool)
        print("✓ Loaded classification tool")
    except Exception as e:
        print(f"✗ Failed to load classification tool: {e}")

    try:
        from medrax.tools.report_generation import ChestXRayReportGeneratorTool

        report_tool = ChestXRayReportGeneratorTool(device=device)
        tools.append(report_tool)
        print("✓ Loaded report generation tool")
    except Exception as e:
        print(f"✗ Failed to load report generation tool: {e}")

    try:
        from medrax.tools.segmentation import ChestXRaySegmentationTool

        segmentation_tool = ChestXRaySegmentationTool(device=device, temp_dir="temp")
        tools.append(segmentation_tool)
        print("✓ Loaded segmentation tool")
    except Exception as e:
        print(f"✗ Failed to load segmentation tool: {e}")
# Load non-GPU tools
try:
    from medrax.tools.dicom import DicomProcessorTool

    dicom_tool = DicomProcessorTool(temp_dir="temp")
    tools.append(dicom_tool)
    print("✓ Loaded DICOM tool")
except Exception as e:
    print(f"✗ Failed to load DICOM tool: {e}")

try:
    from medrax.tools.browsing import WebBrowserTool

    browsing_tool = WebBrowserTool()
    tools.append(browsing_tool)
    print("✓ Loaded web browsing tool")
except Exception as e:
    print(f"✗ Failed to load web browsing tool: {e}")
checkpointer = MemorySaver()
llm = ModelFactory.create_model(
    model_name="gemini-2.0-flash",
    temperature=0.7,
    max_tokens=5000,
)
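# system_prompts.txt defines the named personas used below
# (MEDICAL_ASSISTANT and SOCRATIC_TUTOR)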
prompts = load_prompts_from_file("medrax/docs/system_prompts.txt")
print(f"Tools loaded: {len(tools)}")
# Store agents for each mode to avoid recreating them
agents_cache = {}
def get_or_create_agent(mode):
    """Get or create the agent for the specified mode."""
    if mode not in agents_cache:
        # Select the appropriate prompt for the mode
        if mode == "socratic":
            prompt = prompts.get("SOCRATIC_TUTOR", "You are a Socratic medical educator.")
        else:
            prompt = prompts.get("MEDICAL_ASSISTANT", "You are a helpful medical imaging assistant.")
        # Create the agent for this mode
        agents_cache[mode] = Agent(
            llm,
            tools=tools,
            system_prompt=prompt,
            checkpointer=checkpointer,
            mode=mode,  # Pass the mode through to the Agent
        )
    return agents_cache[mode]
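
# Each mode keeps its own checkpointer thread ("thread_assistant" vs
# "thread_socratic"), so Assistant and Tutor conversations never share history.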
def chat(message, history, mode, uploaded_image_path=None):
    """Chat function that routes to the appropriate agent based on mode.

    `history` is unused here: conversation memory lives in the agent's
    checkpointer. `uploaded_image_path` lets the caller pass a previously
    uploaded image so the visualization fallback works on follow-up turns.
    """
    config = {"configurable": {"thread_id": f"thread_{mode}"}}
    # Get or create the appropriate agent
    agent = get_or_create_agent(mode)

    # Handle multimodal input - Gemini 2.0 Flash supports vision
    image_content = None
    current_upload = uploaded_image_path  # Track the most recently uploaded image
    if isinstance(message, dict):
        text = message.get("text", "")
        files = message.get("files", [])
        if files:
            image_path = files[0]
            current_upload = image_path  # Store for the visualization fallback
            # Check whether it is a DICOM file
            is_dicom = image_path.lower().endswith((".dcm", ".dicom"))
            # LangChain Google GenAI expects images as base64 or PIL
            try:
                if is_dicom:
                    # DICOM files need conversion first; pass the path along
                    # and let the agent handle it via the DICOM tool
                    text = f"[DICOM file uploaded: {image_path}]\n\n{text}"
                    print(f"DICOM file detected: {image_path}")
                else:
                    # Open and encode the image for Gemini
                    with Image.open(image_path) as img:
                        # Convert to RGB if needed
                        if img.mode != "RGB":
                            img = img.convert("RGB")
                        # Resize if too large (Gemini accepts at most 4096x4096)
                        max_size = 4096
                        if img.width > max_size or img.height > max_size:
                            img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
                        # Encode as base64 PNG for LangChain
                        buffered = BytesIO()
                        img.save(buffered, format="PNG")
                        img_bytes = buffered.getvalue()
                    img_b64 = base64.b64encode(img_bytes).decode()
                    # Create the multimodal content part for Gemini
                    image_content = {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{img_b64}"},
                    }
                    # Include the image path in the text so tools can locate the file
                    text = f"[Image: {image_path}]\n\n{text}"
            except Exception as e:
                print(f"Error processing image: {e}")
                import traceback

                traceback.print_exc()
                text = f"[Failed to load image: {image_path}. Error: {str(e)}]\n\n{text}"
        message = text
    # Build the message content - multimodal if an image was attached
    if image_content:
        # For Gemini multimodal input, pass a list of content parts:
        # [{"type": "text", ...}, {"type": "image_url", ...}]
        user_message = [{"type": "text", "text": message}, image_content]
    else:
        user_message = message

    # Record when the invocation starts so only visualizations produced by this
    # turn are picked up; stale files from earlier turns are ignored
    invoke_start = time.time()
    response = agent.workflow.invoke(
        {"messages": [("user", user_message)]},
        config=config,
    )
    # Extract the text response
    assistant_message = response["messages"][-1].content

    # Check for visualization images (grounding or segmentation) from this turn
    viz_image = None
    viz_files = glob.glob("temp/grounding_*.png") + glob.glob("temp/segmentation_*.png")
    viz_files = [f for f in viz_files if os.path.getmtime(f) >= invoke_start]
    if viz_files:
        # Use the most recent visualization
        viz_files.sort(key=os.path.getmtime, reverse=True)
        latest_viz = viz_files[0]
        # Return the file path directly - Gradio can render it
        if os.path.exists(latest_viz):
            viz_image = latest_viz
        # If the assistant message is empty but a visualization exists, supply a default
        if not assistant_message or assistant_message.strip() == "":
            if "segmentation" in latest_viz:
                assistant_message = "I've segmented the requested anatomical structures. The visualization is shown on the right."
            elif "grounding" in latest_viz:
                assistant_message = "I've highlighted the requested regions. The visualization is shown on the right."
    else:
        # No visualization was generated - show the uploaded image as reference
        if current_upload:
            viz_image = current_upload

    # Final fallback for empty messages
    if not assistant_message or assistant_message.strip() == "":
        assistant_message = "I processed your request. Please check the visualization panel on the right."
    return assistant_message, viz_image
# Custom interface with image output
with gr.Blocks() as demo:
    gr.Markdown(
        f"# MedRAX2 - Medical AI Assistant\n"
        f"**Device:** {device} | **Tools:** {len(tools)} loaded | **Orchestrator:** Gemini 2.0 Flash"
    )

    # Mode toggle at the top
    with gr.Row():
        mode_toggle = gr.Radio(
            ["Assistant Mode", "Tutor Mode"],
            value="Assistant Mode",
            label="Interaction Mode",
            info="Assistant Mode: direct answers | Tutor Mode: Socratic guidance through questions",
        )

    # Side-by-side layout: chat on the left, visualization on the right
    with gr.Row():
        # Left column: unified chat interface (chatbot + message box)
        with gr.Column(scale=2):
            # Chatbot with reduced height to leave room for the message box
            try:
                chatbot = gr.Chatbot(type="messages", height=520, show_label=False)
            except TypeError:
                # Fallback for older Gradio versions without the `type` argument
                chatbot = gr.Chatbot(height=520, show_label=False)
            # Message box directly below the chatbot (no gap)
            msg = gr.MultimodalTextbox(
                placeholder="Upload an X-ray image (JPG, PNG, DICOM) and ask a question...",
                file_types=["image", ".dcm", ".dicom", ".DCM", ".DICOM"],
                show_label=False,
                container=False,
            )
        # Right column: visualization
        with gr.Column(scale=1):
            viz_output = gr.Image(label="Visualization", height=600)

    # State that persists the uploaded image across interactions
    current_image_state = gr.State(None)
    def respond(message, chat_history, mode_selection, current_image):
        # Convert the UI mode selection to the internal mode string
        mode = "socratic" if mode_selection == "Tutor Mode" else "assistant"
        # Track the uploaded image - update state when a new image arrives
        if isinstance(message, dict) and message.get("files"):
            current_image = message["files"][0]
        # Get the response and visualization for the selected mode, passing the
        # persisted image so the fallback works on follow-up turns
        bot_message, viz_image = chat(message, chat_history, mode, current_image)
        # Initialize chat history if None
        if chat_history is None:
            chat_history = []
        # Extract the text from a multimodal message
        if isinstance(message, dict):
            user_text = message.get("text", "")
            if message.get("files"):
                user_text = f"[Image uploaded] {user_text}"
        else:
            user_text = message
        # Append both the user message and the assistant response
        chat_history.append({"role": "user", "content": user_text})
        chat_history.append({"role": "assistant", "content": bot_message})
        # If no visualization was generated, show the uploaded image as reference
        if viz_image is None and current_image is not None:
            viz_image = current_image
        return "", chat_history, viz_image, current_image

    msg.submit(
        respond,
        [msg, chatbot, mode_toggle, current_image_state],
        [msg, chatbot, viz_output, current_image_state],
    )
    gr.Examples(
        examples=[
            [{"text": "What do you see in this X-ray?", "files": []}],
            [{"text": "Can you show me where exactly using grounding?", "files": []}],
        ],
        inputs=msg,
    )
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)