File size: 6,310 Bytes
895f657
 
 
 
 
 
 
 
 
07794b9
895f657
07794b9
 
ef294a8
07794b9
 
 
f2960e2
895f657
081767b
895f657
 
 
4e42180
895f657
 
 
 
aa123d5
 
 
 
 
895f657
 
07794b9
895f657
 
081767b
 
895f657
 
081767b
895f657
081767b
 
895f657
081767b
07794b9
 
 
895f657
 
081767b
895f657
081767b
 
895f657
081767b
 
07794b9
081767b
895f657
081767b
 
 
 
895f657
081767b
 
 
 
895f657
081767b
 
895f657
 
 
 
 
 
07794b9
c39b808
 
 
 
 
 
 
 
 
 
 
 
895f657
 
3431d22
081767b
07794b9
7e73b68
53eb065
7e580b6
 
07794b9
7e580b6
 
7e73b68
7e580b6
7e73b68
 
07794b9
7e73b68
07794b9
 
 
 
 
 
7e73b68
 
 
7e580b6
07794b9
 
 
 
 
 
 
 
 
 
 
 
 
 
081767b
7e73b68
 
 
 
 
 
 
 
53eb065
7e73b68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
895f657
 
 
07794b9
895f657
07794b9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import gradio as gr
from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
from threading import Thread
import re
import time
import torch
import spaces
import math
import os
# from qwen_vl_utils import process_vision_info, fetch_image

# run locally: CUDA_VISIBLE_DEVICES=0 GRADIO_SERVER_PORT=7860 MODEL=./model_dir python app.py
# and open http://localhost:7860

# pretrained_model_name_or_path=os.environ.get("MODEL", "amrn/gmdsv5mx3")
# pretrained_model_name_or_path=os.environ.get("MODEL", "amrn/gr1")
pretrained_model_name_or_path=os.environ.get("MODEL", "amrn/mrcxr1")

auth_token = os.environ.get("HF_TOKEN") or True
DEFAULT_PROMPT = "Find abnormalities and support devices."

model = AutoModelForImageTextToText.from_pretrained(
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    dtype=torch.bfloat16,
    token=auth_token
).eval().to("cuda")


processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path,
    use_fast=True,
  )


@spaces.GPU
def model_inference(
    text, history, image
): 

    print(f"text: {text}")
    print(f"history: {history}")

    if len(text) == 0:
        raise gr.Error("Please input a query.", duration=3, print_exception=False)

    if image is None:
        raise gr.Error("Please provide an image.", duration=3, print_exception=False)

    print(f"image0: {image} size: {image.size}")
    # image = fetch_image({"image": image, "min_pixels": 28*28*2, "max_pixels": 476*476})
    # image.thumbnail((512, 512)) #resize image to 512x512  preserve aspect ratio
    # print(f"image1: {image} size: {image.size}")


    messages=[]
    if len(history) > 0:
        valid_index = None
        for i in range(len(history)):
            h = history[i]
            if len(h.get("content").strip()) > 0:
                if valid_index is None and h['role'] == 'assistant':
                    valid_index = i-1 
                messages.append({"role": h['role'], "content": [{"type": "text", "text": h['content']}] })

        if valid_index is None:
            messages = []
        if len(messages) > 0 and valid_index > 0:
            messages = messages[valid_index:] #remove previous messages (without image)

    # current prompt
    messages.append({"role": "user","content": [{"type": "text", "text": text}]})
    messages[0]['content'].insert(0, {"type": "image"})
    print(f"messages: {messages}")


    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[image], return_tensors="pt")
    inputs = inputs.to('cuda')


    # Generate
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_args = dict(inputs, streamer=streamer, max_new_tokens=4096)

    with torch.inference_mode():
        thread = Thread(target=model.generate, kwargs=generation_args)
        thread.start()

        yield "..."
        buffer = ""
        
        
        for new_text in streamer:
            buffer += new_text
            yield buffer


with gr.Blocks() as demo:

    # gr.Markdown('<h1 style="text-align:center; margin: 0.2em 0;">Demo.</h1>')
    send_btn = gr.Button("Send", variant="primary", render=False)
    textbox = gr.Textbox(show_label=False, placeholder="Enter your text here and press ENTER", render=False, submit_btn="Send")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", visible=True, sources="upload", show_label=False)

            clear_btn = gr.Button("Clear", variant="secondary")

            ex =gr.Examples(
                examples=[
                    ["example_images/35.jpg", "Examine the chest X-ray."],
                    ["example_images/363.jpg", "Provide a comprehensive image analysis, and list all abnormalities."],
                    ["example_images/4747.jpg", "Find abnormalities and support devices."],
                    ["example_images/87.jpg", "Find abnormalities and support devices."],
                    ["example_images/6218.jpg", "Find abnormalities and support devices."],
                    ["example_images/6447.jpg", "Find abnormalities and support devices."],


                ],
                inputs=[image_input, textbox],
            )

        with gr.Column(scale=2):
            chat_interface = gr.ChatInterface(fn=model_inference,
                type="messages",
                chatbot=gr.Chatbot(type="messages", label="AI", render_markdown=True, sanitize_html=False, allow_tags=True, height='35vw', container=False, show_share_button=False),
                textbox=textbox,
                additional_inputs=image_input,
                multimodal=False,
                fill_height=False,
                show_api=False,
                )
            gr.HTML('<span style="color:lightgray">Start with a full prompt: Find abnormalities and support devices.<br>\
                Follow up with additial questions, such as Provide differentials or Write a structured report.<br>')



        # Clear chat history when an example is selected (keep example-populated inputs intact)
        ex.load_input_event.then(
                lambda: ([], [], [], None),
                None,
                [chat_interface.chatbot, chat_interface.chatbot_state, chat_interface.chatbot_value, chat_interface.saved_input],
                queue=False,
                show_api=False,
            )
               
        # Clear chat history when a new image is uploaded via the image input
        image_input.upload(
                lambda: ([], [], [], None, DEFAULT_PROMPT),
                None,
                [chat_interface.chatbot, chat_interface.chatbot_state, chat_interface.chatbot_value, chat_interface.saved_input, textbox],
                queue=False,
                show_api=False,
            )

        # Clear everything on Clear button click
        clear_btn.click(
                lambda: ([], [], [], None, "", None),
                None,
                [chat_interface.chatbot, chat_interface.chatbot_state, chat_interface.chatbot_value, chat_interface.saved_input, textbox, image_input],
                queue=False,
                show_api=False,
            )



demo.queue(max_size=10)
demo.launch(debug=False, server_name="0.0.0.0")