am committed on
Commit
081767b
·
1 Parent(s): 394c7b7
Files changed (1) hide show
  1. app.py +98 -73
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
3
- from transformers.image_utils import load_image
4
  from transformers.image_transforms import resize
5
  from threading import Thread
6
  import re
@@ -13,10 +13,8 @@ import os
13
  from transformers import Qwen2_5_VLForConditionalGeneration
14
 
15
  pretrained_model_name_or_path=os.environ.get("MODEL", "amrn/testmodel2")
16
-
17
  auth_token = os.environ.get("HF_TOKEN") or True
18
-
19
-
20
 
21
  model = AutoModelForImageTextToText.from_pretrained(
22
  pretrained_model_name_or_path=pretrained_model_name_or_path,
@@ -40,76 +38,50 @@ processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path,
40
 
41
  @spaces.GPU
42
  def model_inference(
43
- input_dict, history
44
  ):
45
 
46
- print(f"input_dict: {input_dict}")
47
- print(f"history: {history}")
48
-
49
- text = input_dict["text"]
50
 
51
- if len(history) > 0:
52
- try:
53
- image = history[0]['content'][0]
54
- except:
55
- raise gr.Error("Please refresh the page to start over.")
56
-
57
- else:
58
- try:
59
- image = input_dict["files"][0]
60
- except:
61
- raise gr.Error("Please provide an image.", duration=2)
62
 
63
  if len(text) == 0:
64
- raise gr.Error("Please input a query.", duration=2)
 
65
 
66
- if len(image) == 0:
67
- raise gr.Error("Please provide an image.", duration=2)
68
 
69
- image = load_image(image)
 
70
 
71
- resulting_messages=[]
72
 
 
73
  if len(history) > 0:
74
- for i in range(1, len(history)):
 
75
  h = history[i]
76
- resulting_messages.append({
77
- "role": h['role'],
78
- "content": [{"type": "text", "text": h['content']}]
79
- })
80
 
81
- # latest
82
- resulting_messages.append({
83
- "role": "user",
84
- "content": [{"type": "text", "text": text}]
85
- })
86
- resulting_messages[0]['content'].append({"type": "image"})
87
 
 
 
 
88
 
89
- print(f"resulting_messages: {resulting_messages}")
90
- print(f"image0: {image} size: {image.size}")
91
 
 
92
 
93
- # width, height = image.size
94
- # max_pixels = 512*512
95
- # if height * width > max_pixels:
96
- # beta = math.sqrt((height * width) / max_pixels)
97
- # h_bar = math.floor(height / beta)
98
- # w_bar = math.floor(width / beta)
99
- # image = image.resize((w_bar, h_bar))
100
- # print(f"resizedimage: {image} size: {image.size}")
101
-
102
- # inputs = processor.apply_chat_template(
103
- # resulting_messages,
104
- # add_generation_prompt=True,
105
- # tokenize=True,
106
- # return_dict=True,
107
- # return_tensors="pt",
108
- # padding=True,
109
- # padding_side="left",
110
- # )
111
-
112
- prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
113
  inputs = processor(text=prompt, images=[image], return_tensors="pt")
114
  inputs = inputs.to('cuda')
115
 
@@ -135,25 +107,78 @@ def model_inference(
135
  yield buffer
136
 
137
 
138
- examples=[
139
- [{"text": "Find abnormalities and support devices.", "files": ["example_images/35.jpg"]}],
140
- [{"text": "Find abnormalities and support devices.", "files": ["example_images/363.jpg"]}],
141
- [{"text": "Find abnormalities and support devices.", "files": ["example_images/376.jpg"]}],
142
 
143
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
 
145
 
146
- demo = gr.ChatInterface(fn=model_inference,
147
- chatbot=gr.Chatbot(type="messages", render_markdown=True, sanitize_html=False, allow_tags=True, height=640, min_height=640, max_height=640, resizable=False),
148
  type="messages",
149
- title="Demo",
150
- description="Demo.",
151
- examples=examples,
152
- textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="single", lines=1, max_lines=4), stop_btn=True, multimodal=True,
153
- cache_examples=False,
154
- fill_height=False
155
- # flagging_mode="manual",
156
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
 
159
 
 
1
  import gradio as gr
2
  from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
3
+ from transformers.image_utils import load_image, valid_images
4
  from transformers.image_transforms import resize
5
  from threading import Thread
6
  import re
 
13
  from transformers import Qwen2_5_VLForConditionalGeneration
14
 
15
  pretrained_model_name_or_path=os.environ.get("MODEL", "amrn/testmodel2")
 
16
  auth_token = os.environ.get("HF_TOKEN") or True
17
+ DEFAULT_PROMPT = "Find abnormalities and support devices."
 
18
 
19
  model = AutoModelForImageTextToText.from_pretrained(
20
  pretrained_model_name_or_path=pretrained_model_name_or_path,
 
38
 
39
  @spaces.GPU
40
  def model_inference(
41
+ text, history, image=None
42
  ):
43
 
 
 
 
 
44
 
45
+ print(f"text: {text}")
46
+ print(f"history: {history}")
 
 
 
 
 
 
 
 
 
47
 
48
  if len(text) == 0:
49
+ # return 'bad request', 'Please input a query.'
50
+ raise gr.Error("Please input a query.", duration=3, print_exception=False)
51
 
52
+ if image is None:
53
+ raise gr.Error("Please provide an image.", duration=3, print_exception=False)
54
 
55
+ # image = load_image(image)
56
+ print(f"image0: {image} size: {image.size}")
57
 
 
58
 
59
+ messages=[]
60
  if len(history) > 0:
61
+ valid_index = None
62
+ for i in range(len(history)):
63
  h = history[i]
64
+ if len(h.get("content").strip()) > 0:
65
+ if valid_index is None and h['role'] == 'assistant':
66
+ valid_index = i-1 #supposed to be 0
67
+ messages.append({"role": h['role'], "content": [{"type": "text", "text": h['content']}] })
68
 
69
+ # print(f"valid_index: {valid_index}")
70
+ if valid_index is None:
71
+ messages = []
72
+ if len(messages) > 0 and valid_index > 0:
73
+ # print(f"removing previous messages (without image) valid_index: {valid_index}")
74
+ messages = messages[valid_index:] #remove previous messages (without image)
75
 
76
+ # current prompt
77
+ messages.append({"role": "user","content": [{"type": "text", "text": text}]})
78
+ messages[0]['content'].insert(0, {"type": "image"})
79
 
 
 
80
 
81
+ print(f"messages: {messages}")
82
 
83
+
84
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  inputs = processor(text=prompt, images=[image], return_tensors="pt")
86
  inputs = inputs.to('cuda')
87
 
 
107
  yield buffer
108
 
109
 
 
 
 
 
110
 
111
+ # css_no_header = """
112
+ # /* Hide the header row inside this Examples block */
113
+ # #ex_tbl thead { display: none !important; }
114
+ # """
115
+
116
+
117
+
118
+ theme = gr.themes.Default(
119
+ primary_hue="green",
120
+ # text_size="lg",
121
+ )
122
+
123
+ with gr.Blocks(theme=theme) as demo:
124
+
125
+ send_btn = gr.Button("Send", variant="primary", render=False)
126
+ textbox = gr.Textbox(show_label=False, placeholder="Enter your text here and press ENTER", render=False, submit_btn="Send")
127
+
128
+ with gr.Row():
129
+ with gr.Column(scale=4):
130
+ # input_type_radio = gr.Radio(choices=["Image", "Video"], value="Image", label="Select Input Type")
131
+ image_input = gr.Image(type="pil", visible=True, sources="upload", show_label=False)
132
+
133
+ clear_btn = gr.Button("Clear", variant="secondary")
134
+
135
+ with gr.Column():
136
+ ex =gr.Examples(
137
+ examples=[
138
+ ["example_images/35.jpg", "Find abnormalities and support devices."],
139
+ ["example_images/363.jpg", "Provide a comprehensive image analysis, and list all abnormalities."],
140
+ ["example_images/376.jpg", "Examine the chest X-ray."],
141
+ ],
142
+ inputs=[image_input, textbox],
143
+ # elem_id=css_no_header
144
+ )
145
 
146
+ with gr.Column(scale=7):
147
 
148
+ chat_interface = gr.ChatInterface(fn=model_inference,
 
149
  type="messages",
150
+ chatbot=gr.Chatbot(type="messages", label="AI", render_markdown=True, sanitize_html=False, allow_tags=True, height=800,),
151
+ textbox=textbox,
152
+ additional_inputs=image_input,
153
+ multimodal=False,
154
+ )
155
+
156
+ # Clear chat history when an example is selected (keep example-populated inputs intact)
157
+ ex.load_input_event.then(
158
+ lambda: ([], [], [], None),
159
+ None,
160
+ [chat_interface.chatbot, chat_interface.chatbot_state, chat_interface.chatbot_value, chat_interface.saved_input],
161
+ queue=False,
162
+ show_api=False,
163
+ )
164
+
165
+ # Clear chat history when a new image is uploaded via the image input
166
+ image_input.upload(
167
+ lambda: ([], [], [], None, DEFAULT_PROMPT),
168
+ None,
169
+ [chat_interface.chatbot, chat_interface.chatbot_state, chat_interface.chatbot_value, chat_interface.saved_input, textbox],
170
+ queue=False,
171
+ show_api=False,
172
+ )
173
+
174
+ # Clear everything on Clear button click
175
+ clear_btn.click(
176
+ lambda: ([], [], [], None, "", None),
177
+ None,
178
+ [chat_interface.chatbot, chat_interface.chatbot_state, chat_interface.chatbot_value, chat_interface.saved_input, textbox, image_input],
179
+ queue=False,
180
+ show_api=False,
181
+ )
182
 
183
 
184