henryu committed
Commit 0201e5d · 1 Parent(s): 9fd474a

Update app.py

Files changed (1)
  1. app.py +372 -1
app.py CHANGED
@@ -1 +1,372 @@
- cc
+ import base64
+ import os
+ from io import BytesIO
+
+ import gradio as gr
+ import torch
+ from PIL import Image
+
+ from mmgpt.models.builder import create_model_and_transforms
+
+ TEMPLATE = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
+ response_split = "### Response:"
+
+
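+ # Inferencer wraps checkpoint loading and text generation for the
+ # OpenFlamingo-based Multi-modal GPT model; everything runs in fp16 on a
+ # single CUDA device.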
+ class Inferencer:
+
+     def __init__(self, finetune_path, llama_path, open_flamingo_path):
+         ckpt = torch.load(finetune_path, map_location="cpu")
+         if "model_state_dict" in ckpt:
+             state_dict = ckpt["model_state_dict"]
+             # remove the "module." prefix
+             state_dict = {
+                 k[7:]: v
+                 for k, v in state_dict.items() if k.startswith("module.")
+             }
+         else:
+             state_dict = ckpt
+         tuning_config = ckpt.get("tuning_config")
+         if tuning_config is None:
+             print("tuning_config not found in checkpoint")
+         else:
+             print("tuning_config found in checkpoint: ", tuning_config)
+         model, image_processor, tokenizer = create_model_and_transforms(
+             model_name="open_flamingo",
+             clip_vision_encoder_path="ViT-L-14",
+             clip_vision_encoder_pretrained="openai",
+             lang_encoder_path=llama_path,
+             tokenizer_path=llama_path,
+             pretrained_model_path=open_flamingo_path,
+             tuning_config=tuning_config,
+         )
+         model.load_state_dict(state_dict, strict=False)
+         model.half()
+         model = model.to("cuda")
+         model.eval()
+         tokenizer.padding_side = "left"
+         tokenizer.add_eos_token = False
+         self.model = model
+         self.image_processor = image_processor
+         self.tokenizer = tokenizer
+
+     def __call__(self, prompt, imgpaths, max_new_token, num_beams,
+                  temperature, top_k, top_p, do_sample):
+         if imgpaths is not None and len(imgpaths) > 1:
+             raise gr.Error(
+                 "Only one image is currently supported; please clear the "
+                 "gallery and upload a single image")
+         lang_x = self.tokenizer([prompt], return_tensors="pt")
+         if imgpaths is None or len(imgpaths) == 0:
+             # text-only turn: condition the decoder on language alone
+             for layer in self.model.lang_encoder._get_decoder_layers():
+                 layer.condition_only_lang_x(True)
+             output_ids = self.model.lang_encoder.generate(
+                 input_ids=lang_x["input_ids"].cuda(),
+                 attention_mask=lang_x["attention_mask"].cuda(),
+                 max_new_tokens=max_new_token,
+                 num_beams=num_beams,
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+                 do_sample=do_sample,
+             )[0]
+             for layer in self.model.lang_encoder._get_decoder_layers():
+                 layer.condition_only_lang_x(False)
+         else:
+             images = (Image.open(fp) for fp in imgpaths)
+             vision_x = [self.image_processor(im).unsqueeze(0) for im in images]
+             vision_x = torch.cat(vision_x, dim=0)
+             # reshape to (batch, num_images, frames, C, H, W) as the
+             # Flamingo model expects
+             vision_x = vision_x.unsqueeze(1).unsqueeze(0).half()
+
+             output_ids = self.model.generate(
+                 vision_x=vision_x.cuda(),
+                 lang_x=lang_x["input_ids"].cuda(),
+                 attention_mask=lang_x["attention_mask"].cuda(),
+                 max_new_tokens=max_new_token,
+                 num_beams=num_beams,
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+                 do_sample=do_sample,
+             )[0]
+         generated_text = self.tokenizer.decode(
+             output_ids, skip_special_tokens=True)
+         result = generated_text.split(response_split)[-1].strip()
+         return result
+
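+ # A minimal usage sketch (hypothetical image path; the checkpoint paths match
+ # the ones used in __main__ below):
+ #
+ #   inferencer = Inferencer("checkpoints/mmgpt-lora-v0-release.pt",
+ #                           "checkpoints/llama-7b_hf",
+ #                           "checkpoints/OpenFlamingo-9B/checkpoint.pt")
+ #   reply = inferencer(prompt, ["demo.jpg"], max_new_token=512, num_beams=3,
+ #                      temperature=1.0, top_k=20, top_p=1.0, do_sample=True)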
+
+
+ class PromptGenerator:
+
+     def __init__(
+         self,
+         prompt_template=TEMPLATE,
+         ai_prefix="Response",
+         user_prefix="Instruction",
+         sep: str = "\n\n### ",
+         buffer_size=0,
+     ):
+         self.all_history = list()
+         self.ai_prefix = ai_prefix
+         self.user_prefix = user_prefix
+         self.buffer_size = buffer_size
+         self.prompt_template = prompt_template
+         self.sep = sep
+
+     def add_message(self, role, message):
+         self.all_history.append([role, message])
+
+     def get_images(self):
+         img_list = list()
+         if self.buffer_size > 0:
+             all_history = self.all_history[-2 * (self.buffer_size + 1):]
+         elif self.buffer_size == 0:
+             all_history = self.all_history[-2:]
+         else:
+             all_history = self.all_history[:]
+         for his in all_history:
+             if isinstance(his[-1], tuple):
+                 img_list.append(his[-1][-1])
+         return img_list
+
+     def get_prompt(self):
+         format_dict = dict()
+         if "{user_prefix}" in self.prompt_template:
+             format_dict["user_prefix"] = self.user_prefix
+         if "{ai_prefix}" in self.prompt_template:
+             format_dict["ai_prefix"] = self.ai_prefix
+         prompt_template = self.prompt_template.format(**format_dict)
+         ret = prompt_template
+         if self.buffer_size > 0:
+             all_history = self.all_history[-2 * (self.buffer_size + 1):]
+         elif self.buffer_size == 0:
+             all_history = self.all_history[-2:]
+         else:
+             all_history = self.all_history[:]
+         context = []
+         have_image = False
+         for role, message in all_history[::-1]:
+             if message:
+                 if (isinstance(message, tuple) and message[1] is not None
+                         and not have_image):
+                     message, _ = message
+                     # only the most recent image gets an <image> token
+                     have_image = True
+                     context.append(self.sep + "Image:\n<image>" + self.sep +
+                                    role + ":\n" + message)
+                 else:
+                     if isinstance(message, tuple):
+                         message = message[0]
+                     context.append(self.sep + role + ":\n" + message)
+             else:
+                 context.append(self.sep + role + ":\n")
+
+         ret += "".join(context[::-1])
+         return ret
+
+
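+ # For reference, with the defaults above, one turn containing an image makes
+ # get_prompt() return roughly:
+ #
+ #   Below is an instruction that describes a task. Write a response that
+ #   appropriately completes the request.
+ #
+ #   ### Image:
+ #   <image>
+ #
+ #   ### Instruction:
+ #   What is in this image?
+ #
+ #   ### Response: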
+ def to_gradio_chatbot(prompt_generator):
+     ret = []
+     for i, (role, msg) in enumerate(prompt_generator.all_history):
+         if i % 2 == 0:
+             if isinstance(msg, tuple):
+                 msg, image = msg
+                 if isinstance(image, str):
+                     image = Image.open(image)
+                 max_hw, min_hw = max(image.size), min(image.size)
+                 aspect_ratio = max_hw / min_hw
+                 max_len, min_len = 800, 400
+                 shortest_edge = int(
+                     min(max_len / aspect_ratio, min_len, min_hw))
+                 longest_edge = int(shortest_edge * aspect_ratio)
+                 W, H = image.size  # PIL size is (width, height)
+                 if W > H:
+                     W, H = longest_edge, shortest_edge
+                 else:
+                     W, H = shortest_edge, longest_edge
+                 image = image.resize((W, H))
+                 buffered = BytesIO()
+                 # JPEG cannot store an alpha channel, so drop it if present
+                 image.convert("RGB").save(buffered, format="JPEG")
+                 img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                 img_str = (f'<img src="data:image/jpeg;base64,{img_b64_str}" '
+                            'alt="user uploaded image" />')
+                 msg = msg + img_str
+             ret.append([msg, None])
+         else:
+             ret[-1][-1] = msg
+     return ret
+
+
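+ # The tuple bot() returns maps one-to-one onto the Gradio output components
+ # registered in build_conversation_demo below: (state, chatbot history,
+ # cleared textbox, cleared imagebox, the assembled model prompt, GPU memory
+ # in use).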
+ def bot(
+     text,
+     image,
+     state,
+     prompt,
+     ai_prefix,
+     user_prefix,
+     separator,
+     history_buffer,
+     max_new_token,
+     num_beams,
+     temperature,
+     top_k,
+     top_p,
+     do_sample,
+ ):
+     state.prompt_template = prompt
+     state.ai_prefix = ai_prefix
+     state.user_prefix = user_prefix
+     state.sep = separator
+     state.buffer_size = history_buffer
+     if image:
+         state.add_message(user_prefix, (text, image))
+     else:
+         state.add_message(user_prefix, text)
+     state.add_message(ai_prefix, None)
+     inputs = state.get_prompt()
+     image_paths = state.get_images()[-1:]
+
+     inference_results = inferencer(inputs, image_paths, max_new_token,
+                                    num_beams, temperature, top_k, top_p,
+                                    do_sample)
+     state.all_history[-1][-1] = inference_results
+     memory_allocated = str(round(torch.cuda.memory_allocated() / 1024**3,
+                                  2)) + 'GB'
+     return state, to_gradio_chatbot(state), "", None, inputs, memory_allocated
+
+
+ def clear(state):
+     state.all_history = []
+     return state, to_gradio_chatbot(state), "", None, ""
+
+
+ title_markdown = ("""
+ # 🤖 Multi-modal GPT
+ [[Project]](https://github.com/open-mmlab/Multimodal-GPT.git)""")
+
+
+ def build_conversation_demo():
+     with gr.Blocks(title="Multi-modal GPT") as demo:
+         gr.Markdown(title_markdown)
+
+         state = gr.State(PromptGenerator())
+         with gr.Row():
+             with gr.Column(scale=3):
+                 memory_allocated = gr.Textbox(
+                     value=init_memory, label="Memory")
+                 imagebox = gr.Image(type="filepath")
+                 # TODO config parameters
+                 with gr.Accordion(
+                         "Parameters",
+                         open=True,
+                 ):
+                     max_new_token_bar = gr.Slider(
+                         0, 1024, 512, label="max_new_token", step=1)
+                     num_beams_bar = gr.Slider(
+                         1, 10, 3, label="num_beams", step=1)
+                     temperature_bar = gr.Slider(
+                         0.0, 1.0, 1.0, label="temperature", step=0.01)
+                     topk_bar = gr.Slider(0, 100, 20, label="top_k", step=1)
+                     topp_bar = gr.Slider(0, 1.0, 1.0, label="top_p", step=0.01)
+                     do_sample = gr.Checkbox(True, label="do_sample")
+                 with gr.Accordion(
+                         "Prompt",
+                         open=False,
+                 ):
+                     with gr.Row():
+                         ai_prefix = gr.Text("Response", label="AI Prefix")
+                         user_prefix = gr.Text(
+                             "Instruction", label="User Prefix")
+                         separator = gr.Text("\n\n### ", label="Separator")
+                     history_buffer = gr.Slider(
+                         -1, 10, -1, label="History buffer", step=1)
+                     prompt = gr.Text(TEMPLATE, label="Prompt")
+                     model_inputs = gr.Textbox(label="Actual inputs for Model")
+
+             with gr.Column(scale=6):
+                 with gr.Row():
+                     with gr.Column():
+                         chatbot = gr.Chatbot(elem_id="chatbot").style(
+                             height=750)
+                         with gr.Row():
+                             with gr.Column(scale=8):
+                                 textbox = gr.Textbox(
+                                     show_label=False,
+                                     placeholder="Enter text and press ENTER",
+                                 ).style(container=False)
+                             submit_btn = gr.Button(value="Submit")
+                             clear_btn = gr.Button(value="🗑️ Clear history")
+         cur_dir = os.path.dirname(os.path.abspath(__file__))
+         gr.Examples(
+             examples=[
+                 [
+                     f"{cur_dir}/docs/images/demo_image.jpg",
+                     "What is in this image?"
+                 ],
+             ],
+             inputs=[imagebox, textbox],
+         )
+         # textbox submit and the Submit button share the same wiring
+         bot_inputs = [
+             textbox,
+             imagebox,
+             state,
+             prompt,
+             ai_prefix,
+             user_prefix,
+             separator,
+             history_buffer,
+             max_new_token_bar,
+             num_beams_bar,
+             temperature_bar,
+             topk_bar,
+             topp_bar,
+             do_sample,
+         ]
+         bot_outputs = [
+             state, chatbot, textbox, imagebox, model_inputs,
+             memory_allocated
+         ]
+         textbox.submit(bot, bot_inputs, bot_outputs)
+         submit_btn.click(bot, bot_inputs, bot_outputs)
+         clear_btn.click(clear, [state],
+                         [state, chatbot, textbox, imagebox, model_inputs])
+     return demo
+
+
+ if __name__ == "__main__":
+     llama_path = "checkpoints/llama-7b_hf"
+     open_flamingo_path = "checkpoints/OpenFlamingo-9B/checkpoint.pt"
+     finetune_path = "checkpoints/mmgpt-lora-v0-release.pt"
+
+     inferencer = Inferencer(
+         llama_path=llama_path,
+         open_flamingo_path=open_flamingo_path,
+         finetune_path=finetune_path)
+     init_memory = str(round(torch.cuda.memory_allocated() / 1024**3, 2)) + 'GB'
+     demo = build_conversation_demo()
+     demo.queue(concurrency_count=3)
+     IP = "0.0.0.0"
+     PORT = 8997
+     demo.launch(server_name=IP, server_port=PORT, share=True)
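+
+ # To try the demo locally (assuming the three checkpoints above are in
+ # place), run `python app.py` and open http://localhost:8997 in a browser.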