---
license: apache-2.0
language:
- en
base_model:
- openai/clip-vit-large-patch14
pipeline_tag: image-to-text
tags:
- medical
- retinal
---

# RetinalGPT: Large Language-and-Vision Assistant for Retinal Health 👁️

**RetinalGPT** is a specialized multimodal vision-language model (VLM) based on the **LLaVA-v1.5** architecture. It is engineered for the high-precision domain of **ophthalmology**, with a focus on interpreting retinal fundus photography and Optical Coherence Tomography (OCT) scans.

---

## 📌 Model Summary

RetinalGPT bridges the gap between general-purpose VLMs and specialized ophthalmic diagnostics. By fine-tuning on a curated corpus of retinal image-text pairs, the model demonstrates advanced capabilities in identifying pathologies such as Diabetic Retinopathy (DR), Glaucoma, and Age-related Macular Degeneration (AMD).

- **Base LLM:** Llama-7B
- **Vision Tower:** CLIP-ViT-L-14-336px
- **Connector:** MLP projection layer
- **Domain:** Ophthalmology / Retinal Imaging

---

## 🚀 Key Capabilities

RetinalGPT is trained to perform complex visual reasoning tasks, including:

* **Automated Screening:** Grading Diabetic Retinopathy severity (Stage 0-4).
* **Lesion Characterization:** Identifying and describing microaneurysms, hemorrhages, and exudates.
* **Anatomical Mapping:** Precise description of the optic disc, cup-to-disc ratio, and foveal reflex.
* **Clinical QA:** Engaging in multi-turn dialogues about specific clinical findings in a retinal scan.

---

## 💻 How to Use

RetinalGPT follows the standard LLaVA inference pipeline. You will need the `llava` library installed.
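If `llava` is not already available in your environment, one common approach is an editable install from the official LLaVA repository (shown below as a sketch; adjust paths and environment management to your setup):

```shell
# Clone the LLaVA repository and install it in editable mode
git clone https://github.com/haotian-liu/LLaVA.git
cd LLaVA
pip install -e .
```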
### Usage

```python
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from PIL import Image
import torch

model_path = "your-username/retinalgpt"
model_name = get_model_name_from_path(model_path)

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=model_name
)

# Prepare the image (ensure RGB mode for fundus photographs)
image = Image.open("fundus_sample.jpg").convert("RGB")
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()

# The prompt must contain the image placeholder token
prompt = DEFAULT_IMAGE_TOKEN + "\nCan you describe this image?"
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

# Generate the response
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        do_sample=True,
        temperature=0.2,
        max_new_tokens=512,
        use_cache=True
    )

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```
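Because the model returns free-form text, downstream screening pipelines often need to pull structured findings out of the response. Below is a minimal sketch that extracts a Diabetic Retinopathy stage (0-4) from a generated report; the phrasing patterns it matches are assumptions about the model's output style, not a guaranteed format, so validate them against real outputs:

```python
import re
from typing import Optional

def extract_dr_stage(report: str) -> Optional[int]:
    """Best-effort extraction of a DR stage (0-4) from free-form model output.

    NOTE: the patterns below are illustrative assumptions about how the
    model phrases its grading; they are not part of the model's contract.
    """
    patterns = [
        r"[Ss]tage\s*([0-4])",              # e.g. "Stage 2 diabetic retinopathy"
        r"[Gg]rade\s*([0-4])",              # e.g. "DR grade 3"
        r"severity\s*(?:of|:)?\s*([0-4])",  # e.g. "severity: 1"
    ]
    for pat in patterns:
        match = re.search(pat, report)
        if match:
            return int(match.group(1))
    return None  # no recognizable grade found

print(extract_dr_stage("Findings are consistent with Stage 2 diabetic retinopathy."))  # 2
print(extract_dr_stage("No retinopathy detected."))  # None
```

A helper like this makes it straightforward to log grades alongside image IDs when screening a batch of fundus photographs.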