---
language: en
license: mit
tags:
- vision
- text-generation
- medical
- chest-xray
- healthcare
- multimodal
pipeline_tag: image-to-text
---

# 🩺 ChestX – Chest X-ray Report Generation (ViT-GPT2)

This model generates **medical diagnostic reports from chest X-ray images**.
It was developed for the **TWESD Healthcare AI Competition 2024** as part of my final-year engineering project.

The architecture combines a **Vision Transformer (ViT)** for image encoding with **GPT-2** as the language decoder, forming an **encoder–decoder multimodal model**.

---

## 📌 Model Description
- **Architecture:** `VisionEncoderDecoderModel` (ViT + GPT-2)
- **Input:** Chest X-ray image
- **Output:** Text report describing findings
- **Framework:** PyTorch + Hugging Face Transformers
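
To make the encoder-decoder wiring above concrete, here is a minimal sketch of how a ViT encoder and a GPT-2 decoder can be joined in a `VisionEncoderDecoderModel`. It is an illustration only, not the training setup of this checkpoint: the layer sizes, vocabulary size, and token ids are hypothetical and deliberately tiny, and the weights are random, so it builds without downloading anything.

```python
import torch
from transformers import (
    GPT2Config,
    ViTConfig,
    VisionEncoderDecoderConfig,
    VisionEncoderDecoderModel,
)

# Hypothetical, deliberately tiny sizes (NOT the real ChestX configuration).
encoder_config = ViTConfig(
    image_size=32,
    patch_size=8,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
)
decoder_config = GPT2Config(
    vocab_size=100,
    n_positions=32,
    n_embd=32,
    n_layer=1,
    n_head=2,
    bos_token_id=0,
    eos_token_id=1,
)

# Pairing the configs marks GPT-2 as a decoder and enables cross-attention,
# which is how the text side attends to the image patch embeddings.
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
    encoder_config, decoder_config
)
model = VisionEncoderDecoderModel(config=config)
model.eval()

# One random "image" and a short, made-up prefix of token ids.
pixel_values = torch.randn(1, 3, 32, 32)
decoder_input_ids = torch.tensor([[0, 5, 7]])

with torch.no_grad():
    outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)

# One row of vocabulary scores per decoder position.
print(tuple(outputs.logits.shape))
```

With pretrained weights, the same pairing is usually created from two checkpoints via `VisionEncoderDecoderModel.from_encoder_decoder_pretrained`, then fine-tuned on image-report pairs.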

---

## 💡 Intended Uses & Limitations
✅ Intended for:
- Research in **medical AI & multimodal learning**
- Exploring **vision-to-text generation**
- Educational and prototyping purposes

⚠️ Limitations:
- Not intended for **real clinical diagnosis**
- Trained on a limited dataset (IU Chest X-ray), so it may not generalize to all populations

---

## 🛠️ How to Use

```python
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoImageProcessor
from PIL import Image
import torch

# Use the GPU when one is available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model, tokenizer, and image preprocessor
model = VisionEncoderDecoderModel.from_pretrained("Molkaatb/ChestX").to(device)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

# Example image (converted to RGB because ViT expects 3-channel input)
image = Image.open("example_xray.png").convert("RGB")
pixel_values = image_processor(images=image, return_tensors="pt").pixel_values.to(device)

# Generate report with beam search
outputs = model.generate(pixel_values, max_length=512, num_beams=4)
report = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(report)