import io

import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig


def load_image_from_url(url):
    """Download an image and return it as a PIL image, or None on failure."""
    try:
        response = requests.get(url)
        response.raise_for_status()

        image = Image.open(io.BytesIO(response.content))
        return image

    except requests.exceptions.RequestException as e:
        print(f"Error loading image: {e}")
        return None


def do_generate(prompts, images, model, processor, generation_config):
    """Run batched generation for a list of prompt/image pairs.

    Args:
        prompts (List[str]): Prompt texts for the entire batch
        images (List[str or PIL.Image]): Image paths or PIL images for the entire batch
        model (MllmForConditionalGeneration): The loaded model
        processor (MllmProcessor): The matching processor
        generation_config (GenerationConfig): Generation configuration

    Returns:
        outputs (List[str]): Generated responses for the entire batch
    """

    # preprocess text and images into model-ready tensors
    inputs = processor(texts=prompts, images=images)

    # cast floating-point tensors to the model dtype, then move every tensor
    # to the model device
    inputs = {
        k: v.to(model.dtype) if v.dtype == torch.float else v for k, v in inputs.items()
    }
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # batched generation
    with torch.inference_mode():
        res = model.generate(**inputs, generation_config=generation_config)

    # decode generated token ids back to strings
    outputs = processor.batch_decode(res, skip_special_tokens=True)
    return outputs
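

def generate_single(prompt, image, model, processor, generation_config):
    """Hypothetical convenience wrapper (an assumption, not part of the
    original script): run do_generate on a single prompt/image pair and
    return the one decoded response."""
    return do_generate([prompt], [image], model, processor, generation_config)[0]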


if __name__ == "__main__":
    # Setup constants
    device = torch.device("cuda")
    dtype = torch.bfloat16
    do_sample = False

    # Load Processor and Model
    processor = AutoProcessor.from_pretrained("Deepnoid/M4CXR-TNNLS", trust_remote_code=True)
    generation_config = GenerationConfig.from_pretrained("Deepnoid/M4CXR-TNNLS")
    model = AutoModelForCausalLM.from_pretrained(
        "Deepnoid/M4CXR-TNNLS",
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map=device,
    )

    # Prepare images (the same sample chest X-ray is reused for both prompts)
    images = [
        load_image_from_url(
            "https://upload.wikimedia.org/wikipedia/commons/a/a1/Normal_posteroanterior_%28PA%29_chest_radiograph_%28X-ray%29.jpg"
        ),
        load_image_from_url(
            "https://upload.wikimedia.org/wikipedia/commons/a/a1/Normal_posteroanterior_%28PA%29_chest_radiograph_%28X-ray%29.jpg"
        ),
    ]
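
    # Hypothetical guard (not in the original script): load_image_from_url
    # returns None on a failed download, and the processor expects valid PIL
    # images, so fail fast here instead of erroring inside preprocessing.
    if any(img is None for img in images):
        raise RuntimeError("One or more images failed to download")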

    # One question per image in the batch
    questions = [
        "radiology image: <image> What is the view of this chest X-ray?",
        "radiology image: <image> Provide a description of the findings in the radiology image.",
    ]

    # Build prompts with the chat template
    prompts = []
    for question in questions:
        chats = [{"role": "user", "content": question}]
        prompt = processor.apply_chat_template(chats, tokenize=False)
        prompts.append(prompt)

    # Generate responses
    generation_config.do_sample = do_sample
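    # Note: do_sample=False gives greedy decoding. To sample instead, one could
    # set do_sample=True along with standard GenerationConfig fields such as
    # temperature or top_p (specific values would be assumptions, not from the
    # original script).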
    outputs = do_generate(prompts, images, model, processor, generation_config)
    print(outputs)