---
license: apache-2.0
language:
- en
base_model:
- openai/clip-vit-large-patch14
pipeline_tag: image-to-text
tags:
- medical
- retinal
---

# RetinalGPT: Large Language-and-Vision Assistant for Retinal Health

**RetinalGPT** is a specialized multimodal vision-language model (VLM) based on the **LLaVA-v1.5** architecture. It is engineered for the high-precision domain of **ophthalmology**, with a focus on interpreting retinal fundus photography and Optical Coherence Tomography (OCT) scans.

---

## Model Summary

RetinalGPT bridges the gap between general-purpose VLMs and specialized ophthalmic diagnostics. Fine-tuned on a curated corpus of retinal image-text pairs, the model identifies pathologies such as Diabetic Retinopathy (DR), Glaucoma, and Age-related Macular Degeneration (AMD).

- **Base LLM:** Llama-7B
- **Vision Tower:** CLIP-ViT-L-14-336px
- **Connector:** MLP Projection Layer
- **Domain:** Ophthalmology / Retinal Imaging
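
The connector in LLaVA-v1.5-style models is a small MLP that maps patch features from the vision tower into the LLM's embedding space. A minimal sketch of that idea, assuming illustrative dimensions (CLIP ViT-L/14 features of width 1024, a 4096-dim LLM hidden size, and a 24×24 patch grid for a 336px input):

```python
import torch
import torch.nn as nn

# Illustrative dimensions: CLIP ViT-L/14 patch features -> 7B-class LLM hidden size
VISION_DIM, LLM_DIM = 1024, 4096

# Two-layer MLP projector in the style of LLaVA-v1.5's "mlp2x_gelu" connector
projector = nn.Sequential(
    nn.Linear(VISION_DIM, LLM_DIM),
    nn.GELU(),
    nn.Linear(LLM_DIM, LLM_DIM),
)

# One image -> 576 patch tokens (24x24 grid for a 336px input with 14px patches)
patch_features = torch.randn(1, 576, VISION_DIM)
visual_tokens = projector(patch_features)
print(visual_tokens.shape)  # torch.Size([1, 576, 4096])
```

The projected `visual_tokens` are then concatenated with the text token embeddings at the position of the image placeholder token.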

---

## Key Capabilities

RetinalGPT is trained to perform complex visual reasoning tasks, including:
* **Automated Screening:** Grading Diabetic Retinopathy severity (Stage 0-4).
* **Lesion Characterization:** Identifying and describing microaneurysms, hemorrhages, and exudates.
* **Anatomical Mapping:** Precise description of the optic disc, cup-to-disc ratio, and foveal reflex.
* **Clinical QA:** Engaging in multi-turn dialogue about specific clinical findings in a retinal scan.
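
For the screening task, the Stage 0-4 grades correspond to the five levels of the International Clinical Diabetic Retinopathy (ICDR) severity scale. As an illustration, a hypothetical helper (not part of the model API) that builds a grading prompt from those labels:

```python
# International Clinical Diabetic Retinopathy (ICDR) severity scale, Stage 0-4
DR_STAGES = {
    0: "No apparent diabetic retinopathy",
    1: "Mild non-proliferative DR (microaneurysms only)",
    2: "Moderate non-proliferative DR",
    3: "Severe non-proliferative DR",
    4: "Proliferative DR",
}

def grading_prompt() -> str:
    """Build a screening prompt that enumerates the five severity grades."""
    options = "; ".join(f"{stage}: {label}" for stage, label in DR_STAGES.items())
    return (
        "Grade the diabetic retinopathy severity of this fundus image "
        f"on the following scale ({options}). Answer with the stage number."
    )

print(grading_prompt())
```

Enumerating the grades in the prompt constrains the model toward a single, parseable stage number rather than free-form description.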

---

## How to Use

RetinalGPT follows the standard LLaVA inference pipeline. You will need the `llava` library installed.

### Usage

```python
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from PIL import Image
import torch

model_path = "your-username/retinalgpt"
model_name = get_model_name_from_path(model_path)

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=model_name
)

# Prepare the image (the preprocessor expects RGB input)
image = Image.open("fundus_sample.jpg").convert("RGB")
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()

# The prompt must contain the image token so the visual features are spliced in
prompt = DEFAULT_IMAGE_TOKEN + "\nCan you describe this image?"

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

# Generate a response
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        do_sample=True,
        temperature=0.2,
        max_new_tokens=512,
        use_cache=True
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```