Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import requests | |
| import tempfile | |
| from pathlib import Path | |
| from PIL import Image | |
| from transformers import AutoModelForCausalLM, AutoProcessor | |
# Loaded (model, processor) pairs, keyed by the (whitespace-stripped)
# Hugging Face token that was used to download them.
_model_cache = {}


def load_model_and_processor(hf_token: str):
    """
    Load the MAIRA-2 model and processor from Hugging Face using ``hf_token``.

    The loaded pair is cached per token, so repeated calls with the same
    token reuse the already-loaded objects instead of downloading again.

    Args:
        hf_token: Hugging Face access token with access to
            ``microsoft/maira-2``. Leading/trailing whitespace is ignored.

    Returns:
        A ``(model, processor)`` tuple; the model is in eval mode on the CPU.

    Raises:
        ValueError: If ``hf_token`` is ``None``, empty, or only whitespace.
        Exception: Errors raised by ``from_pretrained`` (e.g. an invalid token
            or missing model access) are propagated unchanged.
    """
    # Tokens pasted into a text box often carry a trailing space/newline, which
    # would break authentication and create a second cached model copy.
    token = (hf_token or "").strip()
    if not token:
        raise ValueError("A Hugging Face access token is required.")

    cached = _model_cache.get(token)
    if cached is not None:
        return cached

    device = torch.device("cpu")
    # ``token=`` replaces the deprecated ``use_auth_token=`` argument.
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/maira-2",
        trust_remote_code=True,
        token=token,
    )
    processor = AutoProcessor.from_pretrained(
        "microsoft/maira-2",
        trust_remote_code=True,
        token=token,
    )
    model.eval()
    model.to(device)
    _model_cache[token] = (model, processor)
    return model, processor
# Raw bytes of the sample images, keyed by URL, so each sample image is
# downloaded at most once per process instead of on every request.
_sample_image_bytes = {}


def get_sample_data() -> dict:
    """
    Return a sample chest X-ray study: frontal/lateral images plus text fields.

    The images are fetched from Open-i on first use. The downloaded bytes are
    cached, and a fresh RGB ``PIL.Image`` is decoded on every call, so callers
    never share the same image object.

    Returns:
        dict with keys ``"frontal"`` and ``"lateral"`` (RGB PIL images) and
        ``"indication"``, ``"technique"``, ``"comparison"``, ``"phrase"`` (str).

    Raises:
        requests.RequestException: If a download fails, times out, or returns
            a non-success HTTP status.
        PIL.UnidentifiedImageError: If the downloaded bytes are not an image.
    """
    from io import BytesIO

    frontal_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
    lateral_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"

    def download_and_open(url: str) -> Image.Image:
        data = _sample_image_bytes.get(url)
        if data is None:
            # Bounded timeout so an unresponsive host cannot hang the request forever.
            response = requests.get(url, headers={"User-Agent": "MAIRA-2"}, timeout=30)
            # Surface HTTP errors directly instead of handing an error page to PIL.
            response.raise_for_status()
            # ``.content`` is decoded according to Content-Encoding; ``.raw`` is not.
            data = response.content
            _sample_image_bytes[url] = data
        return Image.open(BytesIO(data)).convert("RGB")

    frontal = download_and_open(frontal_image_url)
    lateral = download_and_open(lateral_image_url)
    return {
        "frontal": frontal,
        "lateral": lateral,
        "indication": "Dyspnea.",
        "technique": "PA and lateral views of the chest.",
        "comparison": "None.",
        "phrase": "Pleural effusion."
    }
def generate_report(hf_token, frontal, lateral, indication, technique, comparison, use_grounding):
    """
    Generate a radiology findings report for the current study with MAIRA-2.

    A ``None`` image or a falsy text field is replaced by the matching value
    from ``get_sample_data()``. The sample study is only fetched when at least
    one input is missing, so fully specified requests do not depend on the
    sample-image host. No prior study is sent to the model.

    Args:
        hf_token: Hugging Face token passed to ``load_model_and_processor``.
        frontal: Current frontal image, or ``None`` to use the sample image.
        lateral: Current lateral image, or ``None`` to use the sample image.
        indication: Clinical indication text (falsy -> sample text).
        technique: Imaging technique text (falsy -> sample text).
        comparison: Comparison text (falsy -> sample text).
        use_grounding: Forwarded as ``get_grounding``; also raises the
            generation budget from 300 to 450 new tokens.

    Returns:
        The result of ``processor.convert_output_to_plaintext_or_grounded_sequence``,
        or an error string if the model or the sample data could not be loaded.
    """
    try:
        model, processor = load_model_and_processor(hf_token)
    except Exception as e:
        return f"Error loading model: {str(e)}"
    device = torch.device("cpu")

    needs_sample = (
        frontal is None
        or lateral is None
        or not indication
        or not technique
        or not comparison
    )
    if needs_sample:
        try:
            sample = get_sample_data()
        except Exception as e:
            return f"Error loading sample data: {str(e)}"
        if frontal is None:
            frontal = sample["frontal"]
        if lateral is None:
            lateral = sample["lateral"]
        if not indication:
            indication = sample["indication"]
        if not technique:
            technique = sample["technique"]
        if not comparison:
            comparison = sample["comparison"]

    processed_inputs = processor.format_and_preprocess_reporting_input(
        current_frontal=frontal,
        current_lateral=lateral,
        prior_frontal=None,  # No prior study is used in this demo.
        indication=indication,
        technique=technique,
        comparison=comparison,
        prior_report=None,
        return_tensors="pt",
        get_grounding=use_grounding,
    )
    # Move the tensors to the CPU, dropping keys containing "image_sizes" to
    # prevent unexpected-keyword errors in ``model.generate``.
    processed_inputs = {
        key: value.to(device)
        for key, value in processed_inputs.items()
        if "image_sizes" not in key
    }

    max_tokens = 450 if use_grounding else 300
    with torch.no_grad():
        output_decoding = model.generate(
            **processed_inputs,
            max_new_tokens=max_tokens,
            use_cache=True,
        )
    # Decode only the newly generated tokens (everything after the prompt).
    prompt_length = processed_inputs["input_ids"].shape[-1]
    decoded_text = processor.decode(output_decoding[0][prompt_length:], skip_special_tokens=True)
    decoded_text = decoded_text.lstrip()  # Remove any leading whitespace
    prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
    return prediction
def run_phrase_grounding(hf_token, frontal, phrase):
    """
    Ground a phrase on a frontal chest X-ray with MAIRA-2.

    A ``None`` image or a falsy phrase is replaced by the matching value from
    ``get_sample_data()``. The sample study is only fetched when something is
    missing, so fully specified requests do not depend on the sample-image host.

    Args:
        hf_token: Hugging Face token passed to ``load_model_and_processor``.
        frontal: Frontal image, or ``None`` to use the sample image.
        phrase: Phrase to ground (falsy -> sample phrase).

    Returns:
        The result of ``processor.convert_output_to_plaintext_or_grounded_sequence``,
        or an error string if the model or the sample data could not be loaded.
    """
    try:
        model, processor = load_model_and_processor(hf_token)
    except Exception as e:
        return f"Error loading model: {str(e)}"
    device = torch.device("cpu")

    if frontal is None or not phrase:
        try:
            sample = get_sample_data()
        except Exception as e:
            return f"Error loading sample data: {str(e)}"
        if frontal is None:
            frontal = sample["frontal"]
        if not phrase:
            phrase = sample["phrase"]

    processed_inputs = processor.format_and_preprocess_phrase_grounding_input(
        frontal_image=frontal,
        phrase=phrase,
        return_tensors="pt",
    )
    # Move the tensors to the CPU, dropping keys containing "image_sizes" to
    # prevent unexpected-keyword errors in ``model.generate``.
    processed_inputs = {
        key: value.to(device)
        for key, value in processed_inputs.items()
        if "image_sizes" not in key
    }

    with torch.no_grad():
        output_decoding = model.generate(
            **processed_inputs,
            max_new_tokens=150,
            use_cache=True,
        )
    # Decode only the newly generated tokens (everything after the prompt).
    prompt_length = processed_inputs["input_ids"].shape[-1]
    decoded_text = processor.decode(output_decoding[0][prompt_length:], skip_special_tokens=True)
    prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
    return prediction
def login_ui(hf_token):
    """
    Validate ``hf_token`` by attempting to load the model.

    Returns a human-readable status string for the authentication textbox:
    a success message, or the failure reason prefixed with a lock emoji.
    """
    status = "🔓 Login successful! You can now use the model."
    try:
        load_model_and_processor(hf_token)
    except Exception as err:
        status = f"❌ Login failed: {str(err)}"
    return status
def generate_report_ui(hf_token, frontal_path, lateral_path, indication, technique, comparison,
                       prior_frontal_path, prior_lateral_path, prior_report, grounding):
    """
    UI adapter for ``generate_report``: open the uploaded image files, then delegate.

    An empty path becomes ``None``, so ``generate_report`` falls back to the
    sample image. The prior-study arguments exist only to match the UI wiring
    and are not used. If an image cannot be opened, an error string is returned.
    """
    try:
        frontal, lateral = (
            Image.open(path) if path else None
            for path in (frontal_path, lateral_path)
        )
    except Exception as e:
        return f"❌ Error loading images: {str(e)}"
    return generate_report(hf_token, frontal, lateral, indication, technique, comparison, grounding)
def run_phrase_grounding_ui(hf_token, frontal_path, phrase):
    """
    UI adapter for ``run_phrase_grounding``: open the frontal image file, then delegate.

    An empty path becomes ``None``, so the sample image is used. If the file
    cannot be opened, an error string is returned instead.
    """
    frontal = None
    if frontal_path:
        try:
            frontal = Image.open(frontal_path)
        except Exception as e:
            return f"❌ Error loading image: {str(e)}"
    return run_phrase_grounding(hf_token, frontal, phrase)
def save_temp_image(img: "Image.Image") -> str:
    """
    Save ``img`` to a new temporary ``.png`` file and return the file's path.

    The file is not deleted automatically; the caller owns it. If saving
    fails, the temporary file is removed and the error is re-raised.
    """
    # Only the unique path is needed. Close the handle right away rather than
    # leaving an open descriptor for the garbage collector to clean up.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
        path = handle.name
    try:
        img.save(path)
    except BaseException:
        # Don't leave an empty/partial .png behind on failure.
        Path(path).unlink(missing_ok=True)
        raise
    return path
def load_sample_findings():
    """
    Build the values that populate the Report Generation tab with the sample study.

    Returns a 9-item list in this order:
    1. Temp-file paths for the frontal and lateral sample images.
    2. The indication, technique and comparison texts.
    3. ``None`` for each of the three prior-study inputs (unused).
    4. ``False`` for the grounding checkbox.
    """
    sample = get_sample_data()
    image_paths = [save_temp_image(sample[view]) for view in ("frontal", "lateral")]
    text_fields = [sample[field] for field in ("indication", "technique", "comparison")]
    unused_prior_study = [None] * 3  # prior frontal, prior lateral, prior report
    grounding_default = [False]
    return image_paths + text_fields + unused_prior_study + grounding_default
def load_sample_phrase():
    """
    Build the values that populate the Phrase Grounding tab with sample data.

    Returns ``[frontal_image_path, phrase]``. The path points to a temporary
    file holding the sample frontal image.
    """
    sample = get_sample_data()
    frontal_path = save_temp_image(sample["frontal"])
    sample_phrase = sample["phrase"]
    return [frontal_path, sample_phrase]
# Gradio UI: an authentication row plus two tabs (report generation and
# phrase grounding). Component creation order inside each ``with`` block
# determines the on-screen layout.
with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
    gr.Markdown(
        """
        # MAIRA-2 Medical Assistant
        **Authentication required** - You need a Hugging Face account and access token to use this model.
        1. Get your access token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
        2. Request model access at [https://huggingface.co/microsoft/maira-2](https://huggingface.co/microsoft/maira-2)
        3. Paste your token below to begin
        """
    )
    with gr.Row():
        hf_token = gr.Textbox(
            label="Hugging Face Token",
            placeholder="hf_xxxxxxxxxxxxxxxxxxxx",
            type="password"
        )
        login_btn = gr.Button("Authenticate")
    login_status = gr.Textbox(label="Authentication Status", interactive=False)
    login_btn.click(
        login_ui,
        inputs=hf_token,
        outputs=login_status
    )
    with gr.Tabs():
        with gr.Tab("Report Generation"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## Current Study")
                    frontal = gr.Image(label="Frontal View", type="filepath")
                    lateral = gr.Image(label="Lateral View", type="filepath")
                    indication = gr.Textbox(label="Clinical Indication")
                    technique = gr.Textbox(label="Imaging Technique")
                    comparison = gr.Textbox(label="Comparison")
                    # generate_report_ui accepts these prior-study inputs but does
                    # not pass them to the model; the heading says so explicitly.
                    gr.Markdown("## Prior Study (not used by this demo)")
                    prior_frontal = gr.Image(label="Prior Frontal View", type="filepath")
                    prior_lateral = gr.Image(label="Prior Lateral View", type="filepath")
                    prior_report = gr.Textbox(label="Prior Report")
                    grounding = gr.Checkbox(label="Include Grounding")
                    sample_btn = gr.Button("Load Sample Data")
                with gr.Column():
                    report_output = gr.Textbox(label="Generated Report", lines=10)
                    generate_btn = gr.Button("Generate Report")
            # Output order must match the list returned by load_sample_findings.
            sample_btn.click(
                load_sample_findings,
                outputs=[frontal, lateral, indication, technique, comparison,
                         prior_frontal, prior_lateral, prior_report, grounding]
            )
            generate_btn.click(
                generate_report_ui,
                inputs=[hf_token, frontal, lateral, indication, technique, comparison,
                        prior_frontal, prior_lateral, prior_report, grounding],
                outputs=report_output
            )
        with gr.Tab("Phrase Grounding"):
            with gr.Row():
                with gr.Column():
                    pg_frontal = gr.Image(label="Frontal View", type="filepath")
                    phrase = gr.Textbox(label="Phrase to Ground")
                    pg_sample_btn = gr.Button("Load Sample Data")
                with gr.Column():
                    pg_output = gr.Textbox(label="Grounding Result", lines=3)
                    pg_btn = gr.Button("Find Phrase")
            pg_sample_btn.click(
                load_sample_phrase,
                outputs=[pg_frontal, phrase]
            )
            pg_btn.click(
                run_phrase_grounding_ui,
                inputs=[hf_token, pg_frontal, phrase],
                outputs=pg_output
            )
demo.launch()