| """ | |
| MedRax2 Gradio Interface for Hugging Face Spaces | |
| Simple standalone version for deployment | |
| """ | |
| import os | |
| os.environ["OMP_NUM_THREADS"] = "1" | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
from huggingface_hub import login

hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    print("✓ Logged in to HuggingFace")

import gradio as gr
from dotenv import load_dotenv
import torch
from PIL import Image
import base64
from io import BytesIO

load_dotenv()

from langgraph.checkpoint.memory import MemorySaver
from medrax.models import ModelFactory
from medrax.agent import Agent
from medrax.utils import load_prompts_from_file

os.makedirs("temp", exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
tools = []

if device == "cuda":
    # Load GPU-based tools

    # NV-Reason-CXR - Disabled due to transformers version conflict
    # Requires transformers 4.56.0, but MAIRA-2 (grounding) requires <4.52
    # Prioritizing grounding tool for visualization
    # try:
    #     from medrax.tools import NVReasonCXRTool
    #     nv_reason_tool = NVReasonCXRTool(
    #         device=device,
    #         load_in_4bit=False
    #     )
    #     tools.append(nv_reason_tool)
    #     print("✓ Loaded NV-Reason-CXR tool")
    # except Exception as e:
    #     print(f"✗ Failed to load NV-Reason-CXR tool: {e}")
    print("⊗ NV-Reason-CXR tool disabled (transformers version conflict)")

    # MAIRA-2 Grounding - Re-enabled for L40S (48GB VRAM)
    try:
        from medrax.tools import XRayPhraseGroundingTool

        grounding_tool = XRayPhraseGroundingTool(
            device=device,
            temp_dir="temp",
            load_in_4bit=False  # Quantization disabled due to compatibility
        )
        tools.append(grounding_tool)
        print("✓ Loaded grounding tool")
    except Exception as e:
        print(f"✗ Failed to load grounding tool: {e}")
    try:
        from medrax.tools.vqa import CheXagentXRayVQATool

        vqa_tool = CheXagentXRayVQATool(
            device=device,
            temp_dir="temp",
            load_in_4bit=True
        )
        tools.append(vqa_tool)
        print("✓ Loaded VQA tool")
    except Exception as e:
        print(f"✗ Failed to load VQA tool: {e}")

    try:
        from medrax.tools.classification import TorchXRayVisionClassifierTool

        classification_tool = TorchXRayVisionClassifierTool(device=device)
        tools.append(classification_tool)
        print("✓ Loaded classification tool")
    except Exception as e:
        print(f"✗ Failed to load classification tool: {e}")

    try:
        from medrax.tools.report_generation import ChestXRayReportGeneratorTool

        report_tool = ChestXRayReportGeneratorTool(device=device)
        tools.append(report_tool)
        print("✓ Loaded report generation tool")
    except Exception as e:
        print(f"✗ Failed to load report generation tool: {e}")

    try:
        from medrax.tools.segmentation import ChestXRaySegmentationTool

        segmentation_tool = ChestXRaySegmentationTool(
            device=device,
            temp_dir="temp"
        )
        tools.append(segmentation_tool)
        print("✓ Loaded segmentation tool")
    except Exception as e:
        print(f"✗ Failed to load segmentation tool: {e}")

# Load non-GPU tools
try:
    from medrax.tools.dicom import DicomProcessorTool

    dicom_tool = DicomProcessorTool(temp_dir="temp")
    tools.append(dicom_tool)
    print("✓ Loaded DICOM tool")
except Exception as e:
    print(f"✗ Failed to load DICOM tool: {e}")

try:
    from medrax.tools.browsing import WebBrowserTool

    browsing_tool = WebBrowserTool()
    tools.append(browsing_tool)
    print("✓ Loaded web browsing tool")
except Exception as e:
    print(f"✗ Failed to load web browsing tool: {e}")
checkpointer = MemorySaver()

llm = ModelFactory.create_model(
    model_name="gemini-2.0-flash",
    temperature=0.7,
    max_tokens=5000
)
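# System prompts are stored by name (e.g. MEDICAL_ASSISTANT, SOCRATIC_TUTOR)
# in this file; get_or_create_agent() below selects one per mode.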
prompts = load_prompts_from_file("medrax/docs/system_prompts.txt")
print(f"Tools loaded: {len(tools)}")

import glob
import time

# Store agents for each mode to avoid recreating them
agents_cache = {}


def get_or_create_agent(mode):
    """Get or create an agent for the specified mode."""
    if mode not in agents_cache:
        # Select the appropriate prompt based on mode
        if mode == "socratic":
            prompt = prompts.get("SOCRATIC_TUTOR", "You are a Socratic medical educator.")
        else:
            prompt = prompts.get("MEDICAL_ASSISTANT", "You are a helpful medical imaging assistant.")
        # Create an agent with the specified mode
        agents_cache[mode] = Agent(
            llm,
            tools=tools,
            system_prompt=prompt,
            checkpointer=checkpointer,
            mode=mode  # Pass the mode to the Agent
        )
    return agents_cache[mode]

def chat(message, history, mode, uploaded_image_path=None):
    """Chat function that uses the appropriate agent based on mode."""
    config = {"configurable": {"thread_id": f"thread_{mode}"}}

    # Get or create the appropriate agent
    agent = get_or_create_agent(mode)

    # Handle multimodal input - Gemini 2.0 Flash supports vision
    image_content = None
    current_upload = None  # Track the current uploaded image

    if isinstance(message, dict):
        text = message.get("text", "")
        files = message.get("files", [])
        if files and len(files) > 0:
            image_path = files[0]
            current_upload = image_path  # Store for visualization

            # Check if it's a DICOM file
            is_dicom = image_path.lower().endswith((".dcm", ".dicom"))

            # Store the image path for tools to use;
            # LangChain Google GenAI expects images as base64 or PIL
            try:
                if is_dicom:
                    # DICOM files need to be converted first;
                    # just pass the path and let the agent handle it
                    text = f"[DICOM file uploaded: {image_path}]\n\n{text}"
                    print(f"DICOM file detected: {image_path}")
                else:
                    # Open and encode the image for Gemini
                    with Image.open(image_path) as img:
                        # Convert to RGB if needed
                        if img.mode != "RGB":
                            img = img.convert("RGB")
                        # Resize if too large (max 4096x4096 for Gemini)
                        max_size = 4096
                        if img.width > max_size or img.height > max_size:
                            img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
                        # Store as bytes for LangChain
                        buffered = BytesIO()
                        img.save(buffered, format="PNG")
                        img_bytes = buffered.getvalue()
                    img_b64 = base64.b64encode(img_bytes).decode()

                    # Create multimodal content for Gemini
                    # Format: [{"type": "text", "text": "..."}, {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]
                    image_content = {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{img_b64}"
                        }
                    }
                    # Include the image path in the text for tools to use
                    text = f"[Image: {image_path}]\n\n{text}"
            except Exception as e:
                print(f"Error processing image: {e}")
                import traceback
                traceback.print_exc()
                text = f"[Failed to load image: {image_path}. Error: {str(e)}]\n\n{text}"
        message = text

    # Create the message content - multimodal if an image was attached
    if image_content:
        # For Gemini multimodal input, pass a list of content parts
        user_message = [
            {"type": "text", "text": message},
            image_content
        ]
    else:
        user_message = message
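    # Invoke the LangGraph workflow; the checkpointer threads this exchange
    # into the per-mode history, so follow-up questions keep their context.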
    # Record when the turn starts so newly generated visualization files can
    # be distinguished from leftovers of earlier turns
    invoke_start = time.time()
    response = agent.workflow.invoke(
        {"messages": [("user", user_message)]},
        config=config
    )

    # Extract the text response
    assistant_message = response["messages"][-1].content
    # Check for visualization images (grounding or segmentation)
    viz_image = None
    viz_files = glob.glob("temp/grounding_*.png") + glob.glob("temp/segmentation_*.png")
    # Keep only files written during this turn, so stale visualizations from
    # earlier requests are not mistaken for new output
    viz_files = [f for f in viz_files if os.path.getmtime(f) >= invoke_start]
    if viz_files:
        # Get the most recent visualization
        viz_files.sort(key=os.path.getmtime, reverse=True)
        latest_viz = viz_files[0]
        # Return the file path directly - Gradio can handle it
        if os.path.exists(latest_viz):
            viz_image = latest_viz
        # If the assistant message is empty but we have a visualization, provide a default message
        if not assistant_message or assistant_message.strip() == "":
            if "segmentation" in latest_viz:
                assistant_message = "I've segmented the requested anatomical structures. The visualization is shown on the right."
            elif "grounding" in latest_viz:
                assistant_message = "I've highlighted the requested regions. The visualization is shown on the right."
    else:
        # No visualization generated - show the uploaded image as reference
        if current_upload:
            viz_image = current_upload

    # Final fallback for empty messages
    if not assistant_message or assistant_message.strip() == "":
        assistant_message = "I processed your request. Please check the visualization panel on the right."

    return assistant_message, viz_image
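
# Illustrative direct call (hypothetical file path), bypassing the Gradio UI:
#   reply, viz = chat({"text": "Describe this X-ray", "files": ["temp/sample.png"]},
#                     history=[], mode="assistant")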

# Custom interface with an image output panel
with gr.Blocks() as demo:
    gr.Markdown(
        f"# MedRAX2 - Medical AI Assistant\n"
        f"**Device:** {device} | **Tools:** {len(tools)} loaded | **Orchestrator:** Gemini 2.0 Flash"
    )

    # Mode toggle at the top
    with gr.Row():
        mode_toggle = gr.Radio(
            ["Assistant Mode", "Tutor Mode"],
            value="Assistant Mode",
            label="Interaction Mode",
            info="Assistant Mode: Direct answers | Tutor Mode: Socratic guidance through questions"
        )

    # Side-by-side layout: chat on the left, visualization on the right
    with gr.Row():
        # Left column: chat interface (unified chat + message box)
        with gr.Column(scale=2):
            # Chatbot with reduced height to leave room for the message box
            try:
                chatbot = gr.Chatbot(type="messages", height=520, show_label=False)
            except TypeError:
                # Fallback for older Gradio versions
                chatbot = gr.Chatbot(height=520, show_label=False)
            # Message box directly below the chatbot (no gap)
            msg = gr.MultimodalTextbox(
                placeholder="Upload an X-ray image (JPG, PNG, DICOM) and ask a question...",
                file_types=["image", ".dcm", ".dicom", ".DCM", ".DICOM"],
                show_label=False,
                container=False
            )
        # Right column: visualization
        with gr.Column(scale=1):
            viz_output = gr.Image(label="Visualization", height=600)

    # State to persist the uploaded image across interactions
    current_image_state = gr.State(None)

    def respond(message, chat_history, mode_selection, current_image):
        # Convert the mode selection to the internal mode string
        mode = "socratic" if mode_selection == "Tutor Mode" else "assistant"

        # Track the uploaded image - update state when a new image is uploaded
        if isinstance(message, dict) and message.get("files"):
            current_image = message["files"][0]

        # Get the response and visualization for the selected mode
        bot_message, viz_image = chat(message, chat_history, mode)

        # Initialize chat history if None
        if chat_history is None:
            chat_history = []

        # Extract the text from a multimodal message
        if isinstance(message, dict):
            user_text = message.get("text", "")
            if message.get("files"):
                user_text = f"[Image uploaded] {user_text}"
        else:
            user_text = message

        # Add BOTH the user message and the assistant response to create a proper chat flow
        chat_history.append({"role": "user", "content": user_text})
        chat_history.append({"role": "assistant", "content": bot_message})

        # If no visualization was generated, show the current uploaded image as a reference
        if viz_image is None and current_image is not None:
            viz_image = current_image

        return "", chat_history, viz_image, current_image
    msg.submit(
        respond,
        [msg, chatbot, mode_toggle, current_image_state],
        [msg, chatbot, viz_output, current_image_state]
    )

    gr.Examples(
        examples=[
            [{"text": "What do you see in this X-ray?", "files": []}],
            [{"text": "Can you show me where exactly using grounding?", "files": []}],
        ],
        inputs=msg,
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)