---
library_name: transformers
pipeline_tag: image-text-to-text
base_model:
  - microsoft/Florence-2-base-ft
license: apache-2.0
tags:
  - vision-language
  - abnormality-grounding
  - medical-imaging
  - knowledge-distillation
  - multimodal
model-index:
  - name: AG-KD
    results:
      - task:
          type: Abnormality Grounding
          name: Grounding
        metrics:
          - name: none
            type: none
            value: null
---

# πŸš€ Enhancing Abnormality Grounding for Vision-Language Models with Knowledge Descriptions

This repository provides the code and model weights for our paper:
**Enhancing Abnormality Grounding for Vision-Language Models with Knowledge Descriptions**

πŸ§ͺ Explore our live demo on Hugging Face Spaces to see the model in action!

## πŸ“Œ Overview

AG-KD (Abnormality Grounding with Knowledge Descriptions) is a compact 0.23B vision-language model designed for abnormality grounding in medical images. Despite its small size, it delivers performance comparable to 7B state-of-the-art medical VLMs. Our approach integrates structured knowledge descriptions into prompts, enhancing the model’s ability to localize medical abnormalities in images.
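The knowledge-description prompt pairs an abnormality name with a short textual definition before asking the model to ground it. A minimal sketch of that construction, mirroring the usage example below (the `<CAPTION_TO_PHRASE_GROUNDING>` task token comes from the Florence-2 base model):

```python
def build_prompt(target: str, definition: str) -> str:
    """Compose an AG-KD grounding prompt: abnormality name plus its knowledge description."""
    # Task token and phrasing mirror the usage example in this README.
    return (
        "<CAPTION_TO_PHRASE_GROUNDING>"
        f"Locate the phrases in the caption: {target} means {definition}."
    )

print(build_prompt(
    "pulmonary fibrosis",
    "Scarring of the lung tissue creating a dense fibrous appearance.",
))
```

Feeding the definition alongside the abnormality name is what injects the structured knowledge into the otherwise generic phrase-grounding task.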

## πŸ’» How to Use

### Simple Example

For detailed examples, visit the AG-KD GitHub repository.


```python
import torch
import requests
from io import BytesIO
from PIL import Image
import numpy as np
import albumentations as A
from transformers import AutoModelForCausalLM, AutoProcessor


def apply_transform(image, size=512):
    # Letterbox the image into a square of side `size`: scale the longest
    # side, pad with black, then resize (a no-op after square padding).
    transform = A.Compose([
        A.LongestMaxSize(max_size=size),
        A.PadIfNeeded(min_height=size, min_width=size, border_mode=0, value=(0, 0, 0)),
        A.Resize(height=size, width=size),
    ])
    return transform(image=np.array(image))["image"]


def run_simple(image_url, target, definition, model, processor, device):
    # Knowledge-description prompt: abnormality name plus its textual definition.
    prompt = f"<CAPTION_TO_PHRASE_GROUNDING>Locate the phrases in the caption: {target} means {definition}."
    response = requests.get(image_url)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert("RGB")
    np_image = apply_transform(image)

    inputs = processor(text=[prompt], images=[np_image], return_tensors="pt", padding=True).to(device)

    outputs = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
        output_scores=True,
        return_dict_in_generate=True,
    )

    # Per-token log-probabilities of the returned beam, used below to derive
    # a length-normalized sequence confidence.
    transition_scores = model.compute_transition_scores(
        outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
    )
    generated_text = processor.batch_decode(outputs.sequences, skip_special_tokens=False)[0]

    scores_np = transition_scores.cpu().numpy()
    output_len = np.sum(scores_np < 0, axis=1)  # count generated (non-padding) tokens
    length_penalty = model.generation_config.length_penalty
    score = scores_np.sum(axis=1) / (output_len ** length_penalty)
    prob = np.exp(score)

    print(f"\n[IMAGE URL] {image_url}")
    print(f"[TARGET] {target}")
    print(f"[PROBABILITY] {prob[0] * 100:.2f}%")
    print(f"[GENERATED TEXT]\n{generated_text}")


if __name__ == "__main__":
    image_url = "https://huggingface.co/spaces/RioJune/AG-KD/resolve/main/examples/f1eb2216d773ced6330b1f31e18f04f8.png"
    target = "pulmonary fibrosis"
    definition = "Scarring of the lung tissue creating a dense fibrous appearance."

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_name = "RioJune/AG-KD"

    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True).to(device)
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

    run_simple(image_url, target, definition, model, processor, device)
```
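Because the example decodes with `skip_special_tokens=False`, the generated text keeps Florence-2-style `<loc_k>` location tokens. A hedged sketch of turning those tokens into pixel boxes; the `num_bins = 1000` dequantization convention is an assumption borrowed from Florence-2, and the bundled processor's post-processing (if present in the remote code) may already do this for you:

```python
import re


def parse_loc_boxes(generated_text: str, image_size: int = 512, num_bins: int = 1000):
    """Decode <loc_k> tokens into (x1, y1, x2, y2) pixel boxes.

    Assumes Florence-2-style quantization: coordinate = k / num_bins * image_size.
    """
    coords = [int(k) for k in re.findall(r"<loc_(\d+)>", generated_text)]
    boxes = []
    # Group coordinates in fours; any trailing incomplete group is dropped.
    for i in range(0, len(coords) - len(coords) % 4, 4):
        x1, y1, x2, y2 = coords[i:i + 4]
        boxes.append(tuple(c / num_bins * image_size for c in (x1, y1, x2, y2)))
    return boxes


# Example on a synthetic output string:
demo = "pulmonary fibrosis<loc_100><loc_200><loc_300><loc_400>"
print(parse_loc_boxes(demo))  # [(51.2, 102.4, 153.6, 204.8)]
```

Note that the boxes are in the coordinate frame of the 512Γ—512 letterboxed input produced by `apply_transform`, not the original image; undo the padding and scaling if you need coordinates on the source image.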

## πŸ“– Citation

If you use our work, please cite:

```bibtex
@article{li2025enhancing,
    title={Enhancing Abnormality Grounding for Vision Language Models with Knowledge Descriptions},
    author={Li, J. and Liu, C. and Bai, W. and Arcucci, R. and Bercea, C. I. and Schnabel, J. A.},
    journal={arXiv preprint arXiv:2503.03278},
    year={2025}
}
```