---
title: hyzhouMedVersa
sdk: gradio
sdk_version: 4.24.0
---
# MedVersa: A Generalist Learner for Multifaceted Medical Image Interpretation
A model implementation for our paper [A Generalist Learner for Multifaceted Medical Image Interpretation](https://arxiv.org/abs/2405.07988).
MedVersa is a compound medical AI system that coordinates multimodal inputs, orchestrates models and tools for various medical imaging tasks, and generates multimodal outputs.
## Installation
### Prerequisites
- [Python](https://www.python.org/)
- [NVIDIA CUDA-compatible GPU](https://developer.nvidia.com/cuda-gpus)
- [Miniconda](https://docs.anaconda.com/free/miniconda/index.html) or [Anaconda](https://www.anaconda.com/)
### Environment Setup
1. Create and activate the conda environment that matches your GPU (run only one of the two options below):
```bash
# Option A: NVIDIA A100 GPUs
conda env create -f environment.yml
conda activate medversa

# Option B: NVIDIA H100 GPUs (CUDA 11.8)
conda env create -f environment_cu118.yml
conda activate medversa
```
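Once the environment is active, it can help to confirm that PyTorch sees the GPU before loading the model. The snippet below is a minimal sketch, not part of the MedVersa codebase: `cuda_summary` is a hypothetical helper name, and it assumes only that PyTorch is the framework installed by the environment files.

```python
import importlib.util


def cuda_summary():
    """Return a small dict describing the PyTorch/CUDA setup (hypothetical helper)."""
    # Avoid an ImportError if PyTorch is missing from the active environment.
    if importlib.util.find_spec("torch") is None:
        return {"torch_installed": False}

    import torch

    info = {
        "torch_installed": True,
        "torch_version": torch.__version__,
        "cuda_build": torch.version.cuda,  # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),
    }
    if info["cuda_available"]:
        info["gpu_name"] = torch.cuda.get_device_name(0)
    return info


print(cuda_summary())
```

If `cuda_available` is `False` on a machine with a supported GPU, the installed PyTorch build and the driver's CUDA version are the first things to compare.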
### Troubleshooting
If you encounter dependency issues, try these solutions:
```bash
# Fix OpenCV issues
pip install opencv-contrib-python
# Fix incompatible torchvision version
pip install torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
```
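Before applying either fix, it may be worth checking which versions are actually installed. The sketch below uses only the standard library; `installed_versions` is a hypothetical helper, and the package list is an assumption based on the commands above.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# Distribution names taken from the troubleshooting commands, plus torch itself.
print(installed_versions(["torch", "torchvision", "opencv-contrib-python"]))
```

A `None` entry means the package is not installed in the active environment, which usually points to the wrong environment being active rather than a version conflict.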
## Usage
### Basic Inference
```python
from utils import *  # provides `registry` and `generate_predictions`
from torch import cuda

# Initialize model
device = 'cuda' if cuda.is_available() else 'cpu'
model_cls = registry.get_model_class('medomni') # medomni is the architecture name :)
model = model_cls.from_pretrained('hyzhou/MedVersa').to(device).eval()
# Define example input
example = {
    'images': ["./demo_ex/c536f749-2326f755-6a65f28f-469affd2-26392ce9.png"],
    'context': "Age:30-40.\nGender:F.\nIndication: ___-year-old female with end-stage renal disease not on dialysis presents with dyspnea. PICC line placement.\nComparison: None.",
    'prompt': "How would you characterize the findings from <img0>?",
    'modality': "cxr",
    'task': "report generation"
}
# Configure generation parameters
params = {
    'num_beams': 1,
    'do_sample': True,
    'min_length': 1,
    'top_p': 0.9,
    'repetition_penalty': 1,
    'length_penalty': 1,
    'temperature': 0.1
}
# Generate predictions
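seg_mask_2d, seg_mask_3d, output_text = generate_predictions(
    model,
    example['images'],
    example['context'],
    example['prompt'],
    example['modality'],
    example['task'],
    **params,
    # Passed by keyword: a positional argument after **params is a SyntaxError.
    # This assumes the parameter is named `device`.
    device=device,
)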
print(output_text)
```
For more detailed examples and usage scenarios, see `inference.py`.
## Prompts
More prompts can be found in `medomni/datasets/prompts.json`.
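To browse that file from Python, a sketch like the following may help. The internal structure of `prompts.json` is not described here, so the code only assumes it is valid JSON and reports the top-level keys (or the list length); `load_prompts` and `summarize` are hypothetical helpers.

```python
import json
from pathlib import Path


def load_prompts(path="medomni/datasets/prompts.json"):
    """Parse the prompts file and return whatever JSON value it contains."""
    with Path(path).open(encoding="utf-8") as f:
        return json.load(f)


def summarize(prompts):
    """Describe the top-level structure without assuming a particular schema."""
    if isinstance(prompts, dict):
        return sorted(prompts)
    if isinstance(prompts, list):
        return f"list with {len(prompts)} entries"
    return type(prompts).__name__


# Only runs when executed from the repository root, where the file exists.
prompts_path = Path("medomni/datasets/prompts.json")
if prompts_path.exists():
    print(summarize(load_prompts(prompts_path)))
```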