1f commited on
Commit
313b9b3
·
verified ·
1 Parent(s): e2cfb48

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. r1-a/response_generation/minicpm/MiniCPM-o/assets/modelscope_logo.png +0 -0
  2. r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/model/__init__.py +1 -0
  3. r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/model/omnilmm.py +457 -0
  4. r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/model/resampler.py +171 -0
  5. r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/model/utils.py +555 -0
  6. r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/train/train_utils.py +153 -0
  7. r1-a/response_generation/minicpm/MiniCPM-o/quantize/bnb_quantize.py +81 -0
  8. r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/chatbot_web_demo_o2.6.py +552 -0
  9. r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/model_server.py +936 -0
  10. r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/vad_utils.py +301 -0
  11. r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/web_server/.env.development +0 -0
  12. r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/web_server/.env.production +0 -0
  13. r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/web_server/.eslintrc-auto-import.json +359 -0
  14. r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/web_server/.eslintrc.cjs +26 -0
  15. r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo.py +264 -0
  16. r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_2.5.py +256 -0
  17. r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_2.6.py +557 -0
  18. r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_streamlit-2_5.py +109 -0
  19. r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_streamlit-minicpmv2_6.py +271 -0
  20. r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_streamlit.py +99 -0
r1-a/response_generation/minicpm/MiniCPM-o/assets/modelscope_logo.png ADDED
r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/model/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .omnilmm import OmniLMMForCausalLM
r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/model/omnilmm.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gc
3
+ import math
4
+ import timm
5
+ import torch
6
+ from torch import Tensor
7
+ import torch.nn as nn
8
+ from torch.nn import CrossEntropyLoss
9
+ from typing import List, Optional, Tuple, Union
10
+
11
+ from transformers import AutoConfig, AutoModelForCausalLM
12
+ from transformers import MistralForCausalLM, MistralModel, MistralConfig
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+
15
+ from omnilmm.model.utils import build_transform
16
+ from omnilmm.model.resampler import Resampler
17
+
18
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
19
+ DEFAULT_IM_START_TOKEN = "<im_start>"
20
+ DEFAULT_IM_END_TOKEN = "<im_end>"
21
+
22
+
23
+ class OmniLMMConfig(MistralConfig):
24
+ model_type = "omnilmm"
25
+
26
+
27
+ class Identity(torch.nn.Identity):
28
+ def forward(self, input: Tensor, **kwargs) -> Tensor:
29
+ return super().forward(input)
30
+
31
+
32
+ def create_vision_module(config):
33
+ vision_tower = timm.create_model('eva02_enormous_patch14_clip_224.laion2b_plus',
34
+ pretrained=False,
35
+ num_classes=0,
36
+ dynamic_img_size=True,
37
+ dynamic_img_pad=True)
38
+
39
+ if isinstance(vision_tower, timm.models.VisionTransformer):
40
+ if vision_tower.attn_pool is not None:
41
+ vision_tower.attn_pool = Identity()
42
+
43
+ # use 2nd last layer's output
44
+ vision_tower.blocks[-1] = Identity()
45
+
46
+ embed_dim = config.hidden_size
47
+ resampler = Resampler(
48
+ grid_size=int(math.sqrt(config.num_query)),
49
+ embed_dim=embed_dim,
50
+ num_heads=embed_dim // 128,
51
+ kv_dim=vision_tower.embed_dim,
52
+ )
53
+ return vision_tower, resampler
54
+
55
+
56
+ class OmniLMMModel(MistralModel):
57
+ config_class = OmniLMMConfig
58
+
59
+ def __init__(self, config: OmniLMMConfig, mm_vision_tower=None, mm_hidden_size=None, tune_clip=True):
60
+ super(OmniLMMModel, self).__init__(config)
61
+
62
+ if hasattr(config, "mm_vision_tower"):
63
+ vision_tower, resampler = create_vision_module(config)
64
+
65
+ # print(__file__, 'skip loading vision tower weights')
66
+
67
+ # HACK: for FSDP
68
+ self.vision_tower = [vision_tower]
69
+ self.resampler = resampler
70
+ if tune_clip:
71
+ self.vision_tower = self.vision_tower[0]
72
+
73
+ self.vision_config = lambda x: None
74
+
75
+ def initialize_vision_modules(self, vision_tower, no_randaug, num_query, image_size, tune_clip=False):
76
+ self.config.mm_vision_tower = vision_tower
77
+ self.config.use_mm_proj = True
78
+ self.config.num_query = num_query
79
+ self.config.image_size = image_size
80
+
81
+ if not hasattr(self, 'vision_tower'):
82
+ vision_tower, resampler = create_vision_module(self.config)
83
+ state_dict = torch.load(
84
+ '/tt/data/public/multimodal/multimodal_model_ckpts/timm/eva02_enormous_patch14_clip_224.laion2b_plus.pt')
85
+ vision_tower.load_state_dict(state_dict, strict=False)
86
+ del state_dict
87
+ gc.collect()
88
+ else:
89
+ if isinstance(self.vision_tower, list):
90
+ vision_tower = self.vision_tower[0]
91
+ else:
92
+ vision_tower = self.vision_tower
93
+ resampler = self.resampler
94
+ self.vision_tower = vision_tower if tune_clip else [vision_tower]
95
+ self.resampler = resampler
96
+
97
+ train_img_transform = build_transform(
98
+ is_train=True, randaug=not no_randaug, input_size=self.config.image_size, std_mode='OPENAI_CLIP')
99
+ eval_img_transform = build_transform(
100
+ is_train=False, input_size=self.config.image_size, std_mode='OPENAI_CLIP')
101
+
102
+ return dict(
103
+ image_processor=(train_img_transform, eval_img_transform),
104
+ image_token_len=num_query,
105
+ vision_config=self.vision_config
106
+ )
107
+
108
+ def get_vision_embedding(self, pixel_values):
109
+ if isinstance(self.vision_tower, list):
110
+ vision_tower = self.vision_tower[0] # HACK: for FSDP
111
+ else:
112
+ vision_tower = self.vision_tower
113
+
114
+ dtype = vision_tower.pos_embed.data.dtype
115
+ vision_embedding = vision_tower.forward_features(
116
+ pixel_values.type(dtype))
117
+ if hasattr(vision_tower, 'num_prefix_tokens') and vision_tower.num_prefix_tokens > 0:
118
+ vision_embedding = vision_embedding[:,
119
+ vision_tower.num_prefix_tokens:]
120
+ res = self.resampler(vision_embedding)
121
+ return res
122
+
123
+ def get_vllm_embedding(self, data):
124
+
125
+ if 'vision_hidden_states' not in data:
126
+ pixel_values_list = data['pixel_values']
127
+ vision_hidden_states = []
128
+ for pixel_values in pixel_values_list:
129
+ if len(pixel_values) > 0:
130
+ vision_hidden_states.append(self.get_vision_embedding(pixel_values.unsqueeze(0))[0])
131
+ else:
132
+ vision_hidden_states.append([])
133
+ else:
134
+ vision_hidden_states = data['vision_hidden_states']
135
+
136
+ #vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
137
+ inputs_embeds = self.embed_tokens(data['input_ids'])
138
+ vision_hidden_states = [i.type(inputs_embeds.dtype)
139
+ if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
140
+ ]
141
+
142
+
143
+ # HACK: replace back original embeddings for LLaVA pretraining
144
+ orig_embeds_params = getattr(self, 'orig_embeds_params', None)
145
+
146
+ new_input_embeds = []
147
+ cur_image_idx = 0
148
+ for cur_input_ids, cur_input_embeds in zip(data['input_ids'], inputs_embeds):
149
+ if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
150
+ # multimodal LLM, but the current sample is not multimodal
151
+ cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
152
+ new_input_embeds.append(cur_input_embeds)
153
+ continue
154
+
155
+ if self.vision_config.use_im_start_end:
156
+ cur_image_features = vision_hidden_states[cur_image_idx]
157
+ num_patches = cur_image_features.shape[0]
158
+ if (cur_input_ids == self.vision_config.im_start_token).sum() != (cur_input_ids == self.vision_config.im_end_token).sum():
159
+ raise ValueError(
160
+ "The number of image start tokens and image end tokens should be the same.")
161
+ image_start_tokens = torch.where(
162
+ cur_input_ids == self.vision_config.im_start_token)[0]
163
+ for image_start_token_pos in image_start_tokens:
164
+ cur_image_features = vision_hidden_states[cur_image_idx].to(
165
+ device=cur_input_embeds.device)
166
+ num_patches = cur_image_features.shape[0]
167
+ if cur_input_ids[image_start_token_pos + num_patches + 1] != self.vision_config.im_end_token:
168
+ raise ValueError(
169
+ "The image end token should follow the image start token.")
170
+ if orig_embeds_params is not None:
171
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features,
172
+ cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
173
+ else:
174
+ cur_new_input_embeds = torch.cat(
175
+ (cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
176
+ cur_image_idx += 1
177
+ new_input_embeds.append(cur_new_input_embeds)
178
+ else:
179
+ raise NotImplementedError
180
+ inputs_embeds = torch.stack(new_input_embeds, dim=0)
181
+
182
+ return inputs_embeds, vision_hidden_states
183
+
184
+ def forward(
185
+ self,
186
+ input_ids: torch.LongTensor = None,
187
+ attention_mask: Optional[torch.Tensor] = None,
188
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
189
+ inputs_embeds: Optional[torch.FloatTensor] = None,
190
+ use_cache: Optional[bool] = None,
191
+ output_attentions: Optional[bool] = None,
192
+ output_hidden_states: Optional[bool] = None,
193
+ images: Optional[torch.FloatTensor] = None,
194
+ return_dict: Optional[bool] = None,
195
+ **kwargs
196
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
197
+
198
+ # HACK: replace back original embeddings for LLaVA pretraining
199
+ orig_embeds_params = getattr(self, 'orig_embeds_params', None)
200
+
201
+ if inputs_embeds is None and past_key_values is None:
202
+ inputs_embeds = self.embed_tokens(input_ids)
203
+
204
+ vision_tower = getattr(self, 'vision_tower', None)
205
+ if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
206
+
207
+ if type(images) is list:
208
+ image_features = []
209
+ for image in images:
210
+ image_forward_out = self.get_vision_embedding(image.unsqueeze(0))[
211
+ 0]
212
+ image_features.append(image_forward_out)
213
+ else:
214
+ image_features = self.get_vision_embedding(images)
215
+
216
+ dummy_image_features = torch.zeros(
217
+ self.config.num_query,
218
+ self.config.hidden_size,
219
+ device=inputs_embeds.device,
220
+ dtype=inputs_embeds.dtype)
221
+
222
+ new_input_embeds = []
223
+ cur_image_idx = 0
224
+ for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
225
+ if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
226
+ # multimodal LLM, but the current sample is not multimodal
227
+ cur_input_embeds = cur_input_embeds + \
228
+ (0. * dummy_image_features).sum()
229
+ new_input_embeds.append(cur_input_embeds)
230
+ continue
231
+
232
+ if self.vision_config.use_im_start_end:
233
+ cur_image_features = image_features[cur_image_idx]
234
+ num_patches = cur_image_features.shape[0]
235
+ if (cur_input_ids == self.vision_config.im_start_token).sum() != (cur_input_ids == self.vision_config.im_end_token).sum():
236
+ raise ValueError(
237
+ "The number of image start tokens and image end tokens should be the same.")
238
+ image_start_tokens = torch.where(
239
+ cur_input_ids == self.vision_config.im_start_token)[0]
240
+ for image_start_token_pos in image_start_tokens:
241
+ cur_image_features = image_features[cur_image_idx].to(
242
+ device=cur_input_embeds.device)
243
+ num_patches = cur_image_features.shape[0]
244
+ if cur_input_ids[image_start_token_pos + num_patches + 1] != self.vision_config.im_end_token:
245
+ raise ValueError(
246
+ "The image end token should follow the image start token.")
247
+ if orig_embeds_params is not None:
248
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features,
249
+ cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
250
+ else:
251
+ cur_new_input_embeds = torch.cat(
252
+ (cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
253
+ cur_image_idx += 1
254
+ new_input_embeds.append(cur_new_input_embeds)
255
+ else:
256
+ raise NotImplementedError
257
+ inputs_embeds = torch.stack(new_input_embeds, dim=0)
258
+ input_ids = None
259
+
260
+ return super(OmniLMMModel, self).forward(
261
+ input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values,
262
+ inputs_embeds=inputs_embeds, use_cache=use_cache,
263
+ output_attentions=output_attentions, output_hidden_states=output_hidden_states,
264
+ return_dict=return_dict,
265
+ **kwargs
266
+ )
267
+
268
+
269
+ class OmniLMMForCausalLM(MistralForCausalLM):
270
+ config_class = OmniLMMConfig
271
+
272
+ def __init__(self, config, mm_vision_tower=None, tune_clip=True):
273
+ super(MistralForCausalLM, self).__init__(config)
274
+ self.model = OmniLMMModel(
275
+ config, mm_vision_tower=mm_vision_tower, tune_clip=tune_clip)
276
+
277
+ self.lm_head = nn.Linear(
278
+ config.hidden_size, config.vocab_size, bias=False)
279
+
280
+ # Initialize weights and apply final processing
281
+ self.post_init()
282
+
283
+ def forward(
284
+ self,
285
+ input_ids: torch.LongTensor = None,
286
+ attention_mask: Optional[torch.Tensor] = None,
287
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
288
+ inputs_embeds: Optional[torch.FloatTensor] = None,
289
+ labels: Optional[torch.LongTensor] = None,
290
+ use_cache: Optional[bool] = None,
291
+ output_attentions: Optional[bool] = None,
292
+ output_hidden_states: Optional[bool] = None,
293
+ images: Optional[torch.FloatTensor] = None,
294
+ return_dict: Optional[bool] = None,
295
+ **kwargs
296
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
297
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
298
+ output_hidden_states = (
299
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
300
+ )
301
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
302
+
303
+ # print(f'@@@ At forward, labels: {labels.shape}-{labels}', flush=True)
304
+ # print(f'@@@ At forward, input_ids: {input_ids.shape}-{input_ids}', flush=True)
305
+ # print(f'@@@ At forward, input_ids: {attention_mask.shape}-{attention_mask}', flush=True)
306
+
307
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
308
+ outputs = self.model(
309
+ input_ids=input_ids,
310
+ attention_mask=attention_mask,
311
+ past_key_values=past_key_values,
312
+ inputs_embeds=inputs_embeds,
313
+ use_cache=use_cache,
314
+ output_attentions=output_attentions,
315
+ output_hidden_states=output_hidden_states,
316
+ return_dict=return_dict,
317
+ images=images,
318
+ **kwargs
319
+ )
320
+
321
+ hidden_states = outputs[0]
322
+ logits = self.lm_head(hidden_states)
323
+
324
+ loss = None
325
+ if labels is not None:
326
+ # Shift so that tokens < n predict n
327
+ shift_logits = logits[..., :-1, :].contiguous()
328
+ shift_labels = labels[..., 1:].contiguous()
329
+ # Flatten the tokens
330
+ loss_fct = CrossEntropyLoss()
331
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
332
+ shift_labels = shift_labels.view(-1)
333
+ # Enable model/pipeline parallelism
334
+ shift_labels = shift_labels.to(shift_logits.device)
335
+ loss = loss_fct(shift_logits, shift_labels)
336
+
337
+ if not return_dict:
338
+ output = (logits,) + outputs[1:]
339
+ return (loss,) + output if loss is not None else output
340
+
341
+ return CausalLMOutputWithPast(
342
+ loss=loss,
343
+ logits=logits,
344
+ past_key_values=outputs.past_key_values,
345
+ hidden_states=outputs.hidden_states,
346
+ attentions=outputs.attentions,
347
+ )
348
+
349
+ # TODO could be removed for generate_vllm()
350
+ def prepare_inputs_for_generation(
351
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
352
+ ):
353
+ if past_key_values:
354
+ input_ids = input_ids[:, -1:]
355
+
356
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
357
+ if inputs_embeds is not None and past_key_values is None:
358
+ model_inputs = {"inputs_embeds": inputs_embeds}
359
+ else:
360
+ model_inputs = {"input_ids": input_ids}
361
+
362
+ model_inputs.update(
363
+ {
364
+ "past_key_values": past_key_values,
365
+ "use_cache": kwargs.get("use_cache"),
366
+ "attention_mask": attention_mask,
367
+ "images": kwargs.get("images", None),
368
+ }
369
+ )
370
+ return model_inputs
371
+
372
+ def generate_vllm(
373
+ self,
374
+ input_ids: torch.LongTensor = None,
375
+ images: Optional[torch.FloatTensor] = None,
376
+ vision_hidden_states=None,
377
+ return_vision_hidden_states=False,
378
+ **kwargs
379
+ ):
380
+ model_inputs = {'input_ids': input_ids}
381
+ if vision_hidden_states is None:
382
+ model_inputs['pixel_values'] = images
383
+ else:
384
+ model_inputs['vision_hidden_states'] = vision_hidden_states
385
+
386
+ with torch.inference_mode():
387
+ inputs_embeds, vision_hidden_states = self.model.get_vllm_embedding(model_inputs)
388
+
389
+ result = self.generate(
390
+ inputs_embeds=inputs_embeds,
391
+ **kwargs
392
+ )
393
+
394
+ if return_vision_hidden_states:
395
+ return result, vision_hidden_states
396
+
397
+ return result
398
+
399
+
400
+ def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device,
401
+ tune_mm_mlp_adapter=False):
402
+ self.model.vision_config.use_im_start_end = mm_use_im_start_end
403
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
404
+ self.resize_token_embeddings(len(tokenizer))
405
+
406
+ if mm_use_im_start_end:
407
+ num_new_tokens = tokenizer.add_tokens(
408
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
409
+ self.resize_token_embeddings(len(tokenizer))
410
+ self.model.vision_config.im_start_token, self.model.vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
411
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
412
+
413
+ if num_new_tokens > 0:
414
+ input_embeddings = self.get_input_embeddings().weight.data
415
+ output_embeddings = self.get_output_embeddings().weight.data
416
+
417
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
418
+ dim=0, keepdim=True)
419
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
420
+ dim=0, keepdim=True)
421
+
422
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
423
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
424
+
425
+ # for new sft data
426
+ num_new_tokens = tokenizer.add_tokens(
427
+ ['<box>', '</box>', '<ref>', '</ref>', '<quad>', '</quad>'], special_tokens=True)
428
+ self.resize_token_embeddings(len(tokenizer))
429
+
430
+ if num_new_tokens > 0:
431
+ input_embeddings = self.get_input_embeddings().weight.data
432
+ output_embeddings = self.get_output_embeddings().weight.data
433
+
434
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
435
+ dim=0, keepdim=True)
436
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
437
+ dim=0, keepdim=True)
438
+
439
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
440
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
441
+
442
+ if tune_mm_mlp_adapter:
443
+ self.model.orig_embeds_params = [
444
+ self.get_input_embeddings().weight.data.clone().to(device=device)]
445
+ for p in self.get_input_embeddings().parameters():
446
+ p.requires_grad = True
447
+ for p in self.get_output_embeddings().parameters():
448
+ p.requires_grad = False
449
+
450
+ self.model.vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
451
+ [DEFAULT_IMAGE_PATCH_TOKEN])[0]
452
+ print(f'Tokenizer: {tokenizer}\n patch_token_id: {self.model.vision_config.im_patch_token}, visoin_config: {self.model.vision_config}', flush=True)
453
+ # exit()
454
+
455
+
456
+ AutoConfig.register("omnilmm", OmniLMMConfig)
457
+ AutoModelForCausalLM.register(OmniLMMConfig, OmniLMMForCausalLM)
r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/model/resampler.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import OrderedDict
7
+ import math
8
+ import requests
9
+ from io import BytesIO
10
+ from functools import partial
11
+ from PIL import Image
12
+ from typing import Callable, Optional, Sequence, Tuple, List, Union
13
+ import numpy as np
14
+
15
+ import torch
16
+ from torch import nn
17
+ from torch.nn import functional as F
18
+ from torch.nn.init import trunc_normal_
19
+ from torchvision import transforms
20
+ from torchvision.transforms import InterpolationMode
21
+
22
+
23
+ def get_abs_pos(abs_pos, tgt_size):
24
+ # abs_pos: L, C
25
+ # tgt_size: M
26
+ # return: M, C
27
+ src_size = int(math.sqrt(abs_pos.size(0)))
28
+ tgt_size = int(math.sqrt(tgt_size))
29
+ dtype = abs_pos.dtype
30
+
31
+ if src_size != tgt_size:
32
+ return F.interpolate(
33
+ abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
34
+ size=(tgt_size, tgt_size),
35
+ mode="bicubic",
36
+ align_corners=False,
37
+ ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
38
+ else:
39
+ return abs_pos
40
+
41
+
42
+ # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
43
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
44
+ """
45
+ grid_size: int of the grid height and width
46
+ return:
47
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
48
+ """
49
+ grid_h = np.arange(grid_size, dtype=np.float32)
50
+ grid_w = np.arange(grid_size, dtype=np.float32)
51
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
52
+ grid = np.stack(grid, axis=0)
53
+
54
+ grid = grid.reshape([2, 1, grid_size, grid_size])
55
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
56
+ if cls_token:
57
+ pos_embed = np.concatenate(
58
+ [np.zeros([1, embed_dim]), pos_embed], axis=0)
59
+ return pos_embed
60
+
61
+
62
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
63
+ assert embed_dim % 2 == 0
64
+
65
+ # use half of dimensions to encode grid_h
66
+ emb_h = get_1d_sincos_pos_embed_from_grid(
67
+ embed_dim // 2, grid[0]) # (H*W, D/2)
68
+ emb_w = get_1d_sincos_pos_embed_from_grid(
69
+ embed_dim // 2, grid[1]) # (H*W, D/2)
70
+
71
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
72
+ return emb
73
+
74
+
75
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
76
+ """
77
+ embed_dim: output dimension for each position
78
+ pos: a list of positions to be encoded: size (M,)
79
+ out: (M, D)
80
+ """
81
+ assert embed_dim % 2 == 0
82
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
83
+ omega /= embed_dim / 2.
84
+ omega = 1. / 10000 ** omega # (D/2,)
85
+
86
+ pos = pos.reshape(-1) # (M,)
87
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
88
+
89
+ emb_sin = np.sin(out) # (M, D/2)
90
+ emb_cos = np.cos(out) # (M, D/2)
91
+
92
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
93
+ return emb
94
+
95
+
96
+ class Resampler(nn.Module):
97
+ """
98
+ A 2D perceiver-resampler network with one cross attention layers by
99
+ (grid_size**2) learnable queries and 2d sincos pos_emb
100
+ Outputs:
101
+ A tensor with the shape of (grid_size**2, embed_dim)
102
+ """
103
+
104
+ def __init__(
105
+ self,
106
+ grid_size,
107
+ embed_dim,
108
+ num_heads,
109
+ kv_dim=None,
110
+ norm_layer=partial(nn.LayerNorm, eps=1e-6)
111
+ ):
112
+ super().__init__()
113
+ self.num_queries = grid_size ** 2
114
+ self.embed_dim = embed_dim
115
+ self.num_heads = num_heads
116
+
117
+ self.pos_embed = nn.Parameter(
118
+ torch.from_numpy(get_2d_sincos_pos_embed(
119
+ embed_dim, grid_size)).float()
120
+ ).requires_grad_(False)
121
+
122
+ self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
123
+ trunc_normal_(self.query, std=.02)
124
+
125
+ if kv_dim is not None and kv_dim != embed_dim:
126
+ self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
127
+ else:
128
+ self.kv_proj = nn.Identity()
129
+
130
+ self.attn = nn.MultiheadAttention(embed_dim, num_heads)
131
+ self.ln_q = norm_layer(embed_dim)
132
+ self.ln_kv = norm_layer(embed_dim)
133
+
134
+ self.ln_post = norm_layer(embed_dim)
135
+ self.proj = nn.Parameter(
136
+ (embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))
137
+
138
+ self.apply(self._init_weights)
139
+
140
+ def _init_weights(self, m):
141
+ if isinstance(m, nn.Linear):
142
+ trunc_normal_(m.weight, std=.02)
143
+ if isinstance(m, nn.Linear) and m.bias is not None:
144
+ nn.init.constant_(m.bias, 0)
145
+ elif isinstance(m, nn.LayerNorm):
146
+ nn.init.constant_(m.bias, 0)
147
+ nn.init.constant_(m.weight, 1.0)
148
+
149
+ def forward(self, x, attn_mask=None):
150
+
151
+ pos_embed = get_abs_pos(self.pos_embed, x.size(1))
152
+
153
+ x = self.kv_proj(x)
154
+ x = self.ln_kv(x).permute(1, 0, 2)
155
+
156
+ N = x.shape[1]
157
+ q = self.ln_q(self.query)
158
+ # print((self._repeat(q, N) + self.pos_embed.unsqueeze(1)).dtype, (x + pos_embed.unsqueeze(1)).dtype, x.dtype)
159
+ out = self.attn(
160
+ self._repeat(q, N) + self.pos_embed.unsqueeze(1),
161
+ x + pos_embed.unsqueeze(1),
162
+ x,
163
+ attn_mask=attn_mask)[0]
164
+ x = out.permute(1, 0, 2)
165
+
166
+ x = self.ln_post(x)
167
+ x = x @ self.proj
168
+ return x
169
+
170
+ def _repeat(self, query, N: int):
171
+ return query.unsqueeze(1).repeat(1, N, 1)
r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/model/utils.py ADDED
@@ -0,0 +1,555 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import transforms
2
+ from timm.data.transforms import RandomResizedCropAndInterpolation
3
+ from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
4
+ from transformers import AutoConfig
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ import torch.distributed as dist
8
+ import numpy as np
9
+ import pickle
10
+ import base64
11
+ import cv2
12
+ import os
13
+ import torch
14
+ from transformers import AutoConfig, StoppingCriteria
15
+
16
+ try:
17
+ from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
18
+ except ImportError:
19
+ OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
20
+ OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
21
+
22
+
23
+ def auto_upgrade(config):
24
+ cfg = AutoConfig.from_pretrained(config)
25
+ if 'llava' in config and cfg.model_type != 'llava':
26
+ print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
27
+ print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
28
+ confirm = input(
29
+ "Please confirm that you want to upgrade the checkpoint. [Y/N]")
30
+ if confirm.lower() in ["y", "yes"]:
31
+ print("Upgrading checkpoint...")
32
+ assert len(cfg.architectures) == 1
33
+ setattr(cfg.__class__, "model_type", "llava")
34
+ cfg.architectures[0] = 'LlavaLlamaForCausalLM'
35
+ cfg.save_pretrained(config)
36
+ print("Checkpoint upgraded.")
37
+ else:
38
+ print("Checkpoint upgrade aborted.")
39
+ exit(1)
40
+
41
+
42
+ class KeywordsStoppingCriteria(StoppingCriteria):
43
+ def __init__(self, keywords, tokenizer, input_ids):
44
+ self.keywords = keywords
45
+ self.tokenizer = tokenizer
46
+ self.start_len = None
47
+ self.input_ids = input_ids
48
+
49
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
50
+ if self.start_len is None:
51
+ self.start_len = self.input_ids.shape[1]
52
+ else:
53
+ outputs = self.tokenizer.batch_decode(
54
+ output_ids[:, self.start_len:], skip_special_tokens=True)[0]
55
+ for keyword in self.keywords:
56
+ if keyword in outputs:
57
+ return True
58
+ return False
59
+
60
+
61
+ def auto_upgrade(config):
62
+ cfg = AutoConfig.from_pretrained(config)
63
+ if 'llava' in config and cfg.model_type != 'llava':
64
+ print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
65
+ print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
66
+ confirm = input(
67
+ "Please confirm that you want to upgrade the checkpoint. [Y/N]")
68
+ if confirm.lower() in ["y", "yes"]:
69
+ print("Upgrading checkpoint...")
70
+ assert len(cfg.architectures) == 1
71
+ setattr(cfg.__class__, "model_type", "llava")
72
+ cfg.architectures[0] = 'LlavaLlamaForCausalLM'
73
+ cfg.save_pretrained(config)
74
+ print("Checkpoint upgraded.")
75
+ else:
76
+ print("Checkpoint upgrade aborted.")
77
+ exit(1)
78
+
79
+ # aug functions
80
+
81
+
82
+ def identity_func(img):
83
+ return img
84
+
85
+
86
+ def autocontrast_func(img, cutoff=0):
87
+ '''
88
+ same output as PIL.ImageOps.autocontrast
89
+ '''
90
+ n_bins = 256
91
+
92
+ def tune_channel(ch):
93
+ n = ch.size
94
+ cut = cutoff * n // 100
95
+ if cut == 0:
96
+ high, low = ch.max(), ch.min()
97
+ else:
98
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
99
+ low = np.argwhere(np.cumsum(hist) > cut)
100
+ low = 0 if low.shape[0] == 0 else low[0]
101
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
102
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
103
+ if high <= low:
104
+ table = np.arange(n_bins)
105
+ else:
106
+ scale = (n_bins - 1) / (high - low)
107
+ table = np.arange(n_bins) * scale - low * scale
108
+ table[table < 0] = 0
109
+ table[table > n_bins - 1] = n_bins - 1
110
+ table = table.clip(0, 255).astype(np.uint8)
111
+ return table[ch]
112
+
113
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
114
+ out = cv2.merge(channels)
115
+ return out
116
+
117
+
118
+ def equalize_func(img):
119
+ '''
120
+ same output as PIL.ImageOps.equalize
121
+ PIL's implementation is different from cv2.equalize
122
+ '''
123
+ n_bins = 256
124
+
125
+ def tune_channel(ch):
126
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
127
+ non_zero_hist = hist[hist != 0].reshape(-1)
128
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
129
+ if step == 0:
130
+ return ch
131
+ n = np.empty_like(hist)
132
+ n[0] = step // 2
133
+ n[1:] = hist[:-1]
134
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
135
+ return table[ch]
136
+
137
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
138
+ out = cv2.merge(channels)
139
+ return out
140
+
141
+
142
+ def rotate_func(img, degree, fill=(0, 0, 0)):
143
+ '''
144
+ like PIL, rotate by degree, not radians
145
+ '''
146
+ H, W = img.shape[0], img.shape[1]
147
+ center = W / 2, H / 2
148
+ M = cv2.getRotationMatrix2D(center, degree, 1)
149
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
150
+ return out
151
+
152
+
153
+ def solarize_func(img, thresh=128):
154
+ '''
155
+ same output as PIL.ImageOps.posterize
156
+ '''
157
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
158
+ table = table.clip(0, 255).astype(np.uint8)
159
+ out = table[img]
160
+ return out
161
+
162
+
163
+ def color_func(img, factor):
164
+ '''
165
+ same output as PIL.ImageEnhance.Color
166
+ '''
167
+ # implementation according to PIL definition, quite slow
168
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
169
+ # out = blend(degenerate, img, factor)
170
+ # M = (
171
+ # np.eye(3) * factor
172
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
173
+ # )[np.newaxis, np.newaxis, :]
174
+ M = (
175
+ np.float32([
176
+ [0.886, -0.114, -0.114],
177
+ [-0.587, 0.413, -0.587],
178
+ [-0.299, -0.299, 0.701]]) * factor
179
+ + np.float32([[0.114], [0.587], [0.299]])
180
+ )
181
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
182
+ return out
183
+
184
+
185
+ def contrast_func(img, factor):
186
+ """
187
+ same output as PIL.ImageEnhance.Contrast
188
+ """
189
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
190
+ table = np.array([(
191
+ el - mean) * factor + mean
192
+ for el in range(256)
193
+ ]).clip(0, 255).astype(np.uint8)
194
+ out = table[img]
195
+ return out
196
+
197
+
198
+ def brightness_func(img, factor):
199
+ '''
200
+ same output as PIL.ImageEnhance.Contrast
201
+ '''
202
+ table = (np.arange(256, dtype=np.float32) *
203
+ factor).clip(0, 255).astype(np.uint8)
204
+ out = table[img]
205
+ return out
206
+
207
+
208
+ def sharpness_func(img, factor):
209
+ '''
210
+ The differences the this result and PIL are all on the 4 boundaries, the center
211
+ areas are same
212
+ '''
213
+ kernel = np.ones((3, 3), dtype=np.float32)
214
+ kernel[1][1] = 5
215
+ kernel /= 13
216
+ degenerate = cv2.filter2D(img, -1, kernel)
217
+ if factor == 0.0:
218
+ out = degenerate
219
+ elif factor == 1.0:
220
+ out = img
221
+ else:
222
+ out = img.astype(np.float32)
223
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
224
+ out[1:-1, 1:-1, :] = degenerate + factor * \
225
+ (out[1:-1, 1:-1, :] - degenerate)
226
+ out = out.astype(np.uint8)
227
+ return out
228
+
229
+
230
+ def shear_x_func(img, factor, fill=(0, 0, 0)):
231
+ H, W = img.shape[0], img.shape[1]
232
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
233
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
234
+ flags=cv2.INTER_LINEAR).astype(np.uint8)
235
+ return out
236
+
237
+
238
+ def translate_x_func(img, offset, fill=(0, 0, 0)):
239
+ '''
240
+ same output as PIL.Image.transform
241
+ '''
242
+ H, W = img.shape[0], img.shape[1]
243
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
244
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
245
+ flags=cv2.INTER_LINEAR).astype(np.uint8)
246
+ return out
247
+
248
+
249
+ def translate_y_func(img, offset, fill=(0, 0, 0)):
250
+ '''
251
+ same output as PIL.Image.transform
252
+ '''
253
+ H, W = img.shape[0], img.shape[1]
254
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
255
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
256
+ flags=cv2.INTER_LINEAR).astype(np.uint8)
257
+ return out
258
+
259
+
260
+ def posterize_func(img, bits):
261
+ '''
262
+ same output as PIL.ImageOps.posterize
263
+ '''
264
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
265
+ return out
266
+
267
+
268
+ def shear_y_func(img, factor, fill=(0, 0, 0)):
269
+ H, W = img.shape[0], img.shape[1]
270
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
271
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
272
+ flags=cv2.INTER_LINEAR).astype(np.uint8)
273
+ return out
274
+
275
+
276
+ def cutout_func(img, pad_size, replace=(0, 0, 0)):
277
+ replace = np.array(replace, dtype=np.uint8)
278
+ H, W = img.shape[0], img.shape[1]
279
+ rh, rw = np.random.random(2)
280
+ pad_size = pad_size // 2
281
+ ch, cw = int(rh * H), int(rw * W)
282
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
283
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
284
+ out = img.copy()
285
+ out[x1:x2, y1:y2, :] = replace
286
+ return out
287
+
288
+
289
+ # level to args
290
+ def enhance_level_to_args(MAX_LEVEL):
291
+ def level_to_args(level):
292
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
293
+ return level_to_args
294
+
295
+
296
+ def shear_level_to_args(MAX_LEVEL, replace_value):
297
+ def level_to_args(level):
298
+ level = (level / MAX_LEVEL) * 0.3
299
+ if np.random.random() > 0.5:
300
+ level = -level
301
+ return (level, replace_value)
302
+
303
+ return level_to_args
304
+
305
+
306
+ def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
307
+ def level_to_args(level):
308
+ level = (level / MAX_LEVEL) * float(translate_const)
309
+ if np.random.random() > 0.5:
310
+ level = -level
311
+ return (level, replace_value)
312
+
313
+ return level_to_args
314
+
315
+
316
+ def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
317
+ def level_to_args(level):
318
+ level = int((level / MAX_LEVEL) * cutout_const)
319
+ return (level, replace_value)
320
+
321
+ return level_to_args
322
+
323
+
324
+ def solarize_level_to_args(MAX_LEVEL):
325
+ def level_to_args(level):
326
+ level = int((level / MAX_LEVEL) * 256)
327
+ return (level, )
328
+ return level_to_args
329
+
330
+
331
+ def none_level_to_args(level):
332
+ return ()
333
+
334
+
335
+ def posterize_level_to_args(MAX_LEVEL):
336
+ def level_to_args(level):
337
+ level = int((level / MAX_LEVEL) * 4)
338
+ return (level, )
339
+ return level_to_args
340
+
341
+
342
+ def rotate_level_to_args(MAX_LEVEL, replace_value):
343
+ def level_to_args(level):
344
+ level = (level / MAX_LEVEL) * 30
345
+ if np.random.random() < 0.5:
346
+ level = -level
347
+ return (level, replace_value)
348
+
349
+ return level_to_args
350
+
351
+
352
+ func_dict = {
353
+ 'Identity': identity_func,
354
+ 'AutoContrast': autocontrast_func,
355
+ 'Equalize': equalize_func,
356
+ 'Rotate': rotate_func,
357
+ 'Solarize': solarize_func,
358
+ 'Color': color_func,
359
+ 'Contrast': contrast_func,
360
+ 'Brightness': brightness_func,
361
+ 'Sharpness': sharpness_func,
362
+ 'ShearX': shear_x_func,
363
+ 'TranslateX': translate_x_func,
364
+ 'TranslateY': translate_y_func,
365
+ 'Posterize': posterize_func,
366
+ 'ShearY': shear_y_func,
367
+ }
368
+
369
+ translate_const = 10
370
+ MAX_LEVEL = 10
371
+ replace_value = (128, 128, 128)
372
+ arg_dict = {
373
+ 'Identity': none_level_to_args,
374
+ 'AutoContrast': none_level_to_args,
375
+ 'Equalize': none_level_to_args,
376
+ 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
377
+ 'Solarize': solarize_level_to_args(MAX_LEVEL),
378
+ 'Color': enhance_level_to_args(MAX_LEVEL),
379
+ 'Contrast': enhance_level_to_args(MAX_LEVEL),
380
+ 'Brightness': enhance_level_to_args(MAX_LEVEL),
381
+ 'Sharpness': enhance_level_to_args(MAX_LEVEL),
382
+ 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
383
+ 'TranslateX': translate_level_to_args(
384
+ translate_const, MAX_LEVEL, replace_value
385
+ ),
386
+ 'TranslateY': translate_level_to_args(
387
+ translate_const, MAX_LEVEL, replace_value
388
+ ),
389
+ 'Posterize': posterize_level_to_args(MAX_LEVEL),
390
+ 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
391
+ }
392
+
393
+
394
+ class RandomAugment(object):
395
+
396
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
397
+ self.N = N
398
+ self.M = M
399
+ self.isPIL = isPIL
400
+ if augs:
401
+ self.augs = augs
402
+ else:
403
+ self.augs = list(arg_dict.keys())
404
+
405
+ def get_random_ops(self):
406
+ sampled_ops = np.random.choice(self.augs, self.N)
407
+ return [(op, 0.5, self.M) for op in sampled_ops]
408
+
409
+ def __call__(self, img):
410
+ if self.isPIL:
411
+ img = np.array(img)
412
+ ops = self.get_random_ops()
413
+ for name, prob, level in ops:
414
+ if np.random.random() > prob:
415
+ continue
416
+ args = arg_dict[name](level)
417
+ img = func_dict[name](img, *args)
418
+ return img
419
+
420
+
421
+ def build_transform(is_train, randaug=True, input_size=224, interpolation='bicubic', std_mode='IMAGENET_INCEPTION'):
422
+ if std_mode == 'IMAGENET_INCEPTION':
423
+ mean = IMAGENET_INCEPTION_MEAN
424
+ std = IMAGENET_INCEPTION_STD
425
+ elif std_mode == 'OPENAI_CLIP':
426
+ mean = OPENAI_CLIP_MEAN
427
+ std = OPENAI_CLIP_STD
428
+ else:
429
+ raise NotImplementedError
430
+
431
+ if is_train:
432
+ crop_scale = float(os.environ.get('TRAIN_CROP_SCALE', 0.9999))
433
+ t = [
434
+ RandomResizedCropAndInterpolation(
435
+ input_size, scale=(crop_scale, 1.0), interpolation='bicubic'),
436
+ # transforms.RandomHorizontalFlip(),
437
+ ]
438
+ if randaug and os.environ.get('TRAIN_DO_AUG', 'False') == 'True':
439
+ print(f'@@@@@ Do random aug during training', flush=True)
440
+ t.append(
441
+ RandomAugment(
442
+ 2, 7, isPIL=True,
443
+ augs=[
444
+ 'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
445
+ 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
446
+ ]))
447
+ else:
448
+ print(f'@@@@@ Skip random aug during training', flush=True)
449
+ t += [
450
+ transforms.ToTensor(),
451
+ transforms.Normalize(mean=mean, std=std),
452
+ ]
453
+ t = transforms.Compose(t)
454
+ else:
455
+ t = transforms.Compose([
456
+ transforms.Resize((input_size, input_size),
457
+ interpolation=transforms.InterpolationMode.BICUBIC),
458
+ transforms.ToTensor(),
459
+ transforms.Normalize(mean=mean, std=std)
460
+ ])
461
+
462
+ return t
463
+
464
+
465
+ def img2b64(img_path):
466
+ img = Image.open(img_path) # path to file
467
+ img_buffer = BytesIO()
468
+ img.save(img_buffer, format=img.format)
469
+ byte_data = img_buffer.getvalue()
470
+ base64_str = base64.b64encode(byte_data) # bytes
471
+ base64_str = base64_str.decode("utf-8") # str
472
+ return base64_str
473
+
474
+
475
+ def str2b64(str):
476
+ return base64.b64encode(str.encode('utf-8')).decode('utf-8')
477
+
478
+
479
+ def b642str(b64):
480
+ return base64.b64decode(b64).decode('utf-8')
481
+
482
+
483
+ def is_dist_avail_and_initialized():
484
+ if not dist.is_available():
485
+ return False
486
+ if not dist.is_initialized():
487
+ return False
488
+ return True
489
+
490
+
491
+ def get_world_size():
492
+ if not is_dist_avail_and_initialized():
493
+ return 1
494
+ return dist.get_world_size()
495
+
496
+
497
+ def get_rank():
498
+ if not is_dist_avail_and_initialized():
499
+ return 0
500
+ return dist.get_rank()
501
+
502
+
503
+ def all_gather(data):
504
+ """
505
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
506
+ Args:
507
+ data: any picklable object
508
+ Returns:
509
+ list[data]: list of data gathered from each rank
510
+ """
511
+ world_size = get_world_size()
512
+ if world_size == 1:
513
+ return [data]
514
+
515
+ # serialized to a Tensor
516
+ buffer = pickle.dumps(data)
517
+ storage = torch.ByteStorage.from_buffer(buffer)
518
+ tensor = torch.ByteTensor(storage).to("cuda")
519
+
520
+ # obtain Tensor size of each rank
521
+ local_size = torch.LongTensor([tensor.numel()]).to("cuda")
522
+ size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
523
+ dist.all_gather(size_list, local_size)
524
+ size_list = [int(size.item()) for size in size_list]
525
+ max_size = max(size_list)
526
+
527
+ # receiving Tensor from all ranks
528
+ # we pad the tensor because torch all_gather does not support
529
+ # gathering tensors of different shapes
530
+ tensor_list = []
531
+ for _ in size_list:
532
+ tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
533
+ if local_size != max_size:
534
+ padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
535
+ tensor = torch.cat((tensor, padding), dim=0)
536
+ dist.all_gather(tensor_list, tensor)
537
+
538
+ data_list = []
539
+ for size, tensor in zip(size_list, tensor_list):
540
+ buffer = tensor.cpu().numpy().tobytes()[:size]
541
+ data_list.append(pickle.loads(buffer))
542
+
543
+ return data_list
544
+
545
+
546
+ def mean(lst):
547
+ return sum(lst) / len(lst)
548
+
549
+
550
+ def stop_gradient_by_name(name: str):
551
+ def apply_fn(module):
552
+ if hasattr(module, name):
553
+ getattr(module, name).requires_grad_(False)
554
+
555
+ return apply_fn
r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/train/train_utils.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import copy
4
+ import time
5
+
6
+ import torch
7
+ import warnings
8
+ import transformers
9
+
10
+ import numpy as np
11
+
12
+ from typing import Dict, Optional, Sequence
13
+ from omnilmm import conversation as conversation_lib
14
+
15
+ IGNORE_INDEX = -100
16
+ DEFAULT_IMAGE_TOKEN = "<image>"
17
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
18
+ DEFAULT_IM_START_TOKEN = "<im_start>"
19
+ DEFAULT_IM_END_TOKEN = "<im_end>"
20
+
21
+
22
+ def _tokenize_fn(strings: Sequence[str],
23
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
24
+ """Tokenize a list of strings."""
25
+ tokenized_list = [
26
+ tokenizer(
27
+ text,
28
+ return_tensors="pt",
29
+ padding="longest",
30
+ max_length=tokenizer.model_max_length,
31
+ truncation=True,
32
+ ) for text in strings
33
+ ]
34
+ input_ids = labels = [
35
+ tokenized.input_ids[0] for tokenized in tokenized_list
36
+ ]
37
+ input_ids_lens = labels_lens = [
38
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
39
+ for tokenized in tokenized_list
40
+ ]
41
+ return dict(
42
+ input_ids=input_ids,
43
+ labels=labels,
44
+ input_ids_lens=input_ids_lens,
45
+ labels_lens=labels_lens,
46
+ )
47
+
48
+
49
+
50
+ def omni_preprocess(sources,
51
+ tokenizer: transformers.PreTrainedTokenizer,
52
+ generation=False):
53
+ system_content = 'You are an artificial intelligence assistant, which gives helpful, detailed, and polite answers to the human\'s questions.'
54
+ ignore_index = -100
55
+
56
+ response_template = '\n<|assistant|>\n'
57
+ instruction_template = '\n<|user|>\n'
58
+ response_token_ids = tokenizer.encode(
59
+ response_template, add_special_tokens=False)
60
+ instruction_token_ids = tokenizer.encode(
61
+ instruction_template, add_special_tokens=False)
62
+
63
+ batch_input_ids = []
64
+ batch_labels = []
65
+ for i in range(len(sources)):
66
+ new_source = []
67
+ prev_role = 'unexpect'
68
+ for conv_turn in sources[i]:
69
+ role = conv_turn['from'] if 'from' in conv_turn else conv_turn['role']
70
+ content = conv_turn['value'] if 'value' in conv_turn else conv_turn['content']
71
+
72
+ role = 'user' if role == 'human' else role
73
+ role = 'assistant' if role == 'gpt' else role
74
+
75
+ assert role in ['user', 'assistant']
76
+ assert role != prev_role, f'role={role}, prev_role={prev_role}'
77
+ prev_role = role
78
+
79
+ new_turn = {
80
+ 'role': role,
81
+ 'content': content
82
+ }
83
+ new_source.append(new_turn)
84
+ if new_source[0]['role'] != 'system':
85
+ new_source.insert(0, {'role': 'system', 'content': system_content})
86
+
87
+ # TODO: this automatically add '\n' to the end
88
+ res_text = tokenizer.apply_chat_template(
89
+ new_source, tokenize=False, add_generation_prompt=generation)
90
+ if not generation:
91
+ res_text = res_text.strip()
92
+
93
+ conversations_tokenized = _tokenize_fn([res_text], tokenizer)
94
+ res_input_ids = conversations_tokenized["input_ids"][0]
95
+
96
+ # since labels and input_ids are reference towards the same object
97
+ res_labels = copy.deepcopy(conversations_tokenized["labels"][0])
98
+
99
+ response_token_ids_idxs = []
100
+ human_token_ids_idxs = []
101
+
102
+ for assistant_idx in np.where(res_labels == response_token_ids[0])[0]:
103
+ # find the indexes of the start of a response.
104
+ if (response_token_ids == res_labels[assistant_idx: assistant_idx + len(
105
+ response_token_ids)].tolist()
106
+ ):
107
+ response_token_ids_idxs.append(
108
+ assistant_idx + len(response_token_ids))
109
+
110
+ if len(response_token_ids_idxs) == 0:
111
+ warnings.warn(
112
+ f"Could not find response key `{response_template}` in the "
113
+ f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
114
+ f'Raw text is @===>{res_text}<===@'
115
+ f'Raw source is @===>{new_source}<===@'
116
+ f"This instance will be ignored in loss calculation. "
117
+ f"Note, if this happens often, consider increasing the `max_seq_length`."
118
+ )
119
+ res_labels[:] = ignore_index
120
+
121
+ human_token_ids = instruction_token_ids
122
+ for human_idx in np.where(res_labels == human_token_ids[0])[0]:
123
+ # find the indexes of the start of a human answer.
124
+ if human_token_ids == res_labels[human_idx: human_idx + len(human_token_ids)].tolist():
125
+ human_token_ids_idxs.append(human_idx)
126
+
127
+ if len(human_token_ids_idxs) == 0:
128
+ warnings.warn(
129
+ f"Could not find instruction key `{instruction_template}` in the "
130
+ f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
131
+ f'Raw text is @===>{res_text}<===@'
132
+ f'Raw source is @===>{new_source}<===@'
133
+ f"This instance will be ignored in loss calculation. "
134
+ f"Note, if this happens often, consider increasing the `max_seq_length`."
135
+ )
136
+ res_labels[:] = ignore_index
137
+
138
+ for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
139
+ # Make pytorch loss function ignore all non response tokens
140
+ if idx != 0:
141
+ res_labels[start:end] = ignore_index
142
+ else:
143
+ res_labels[:end] = ignore_index
144
+
145
+ if len(response_token_ids_idxs) < len(human_token_ids_idxs):
146
+ res_labels[human_token_ids_idxs[-1]:] = ignore_index
147
+
148
+ batch_input_ids.append(res_input_ids)
149
+ batch_labels.append(res_labels)
150
+
151
+ return dict(input_ids=batch_input_ids, labels=batch_labels)
152
+
153
+
r1-a/response_generation/minicpm/MiniCPM-o/quantize/bnb_quantize.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ the script will use bitandbytes to quantize the MiniCPM-Llama3-V-2_5 model.
3
+ the be quantized model can be finetuned by MiniCPM-Llama3-V-2_5 or not.
4
+ you only need to set the model_path 、save_path and run bash code
5
+
6
+ cd MiniCPM-V
7
+ python quantize/bnb_quantize.py
8
+
9
+ you will get the quantized model in save_path、quantized_model test time and gpu usage
10
+ """
11
+
12
+
13
+ import torch
14
+ from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
15
+ from PIL import Image
16
+ import time
17
+ import torch
18
+ import GPUtil
19
+ import os
20
+
21
+ assert torch.cuda.is_available(),"CUDA is not available, but this code requires a GPU."
22
+
23
+ device = 'cuda' # Select GPU to use
24
+ model_path = '/root/ld/ld_model_pretrained/MiniCPM-Llama3-V-2_5' # Model download path
25
+ save_path = '/root/ld/ld_model_pretrain/MiniCPM-Llama3-V-2_5_int4' # Quantized model save path
26
+ image_path = './assets/airplane.jpeg'
27
+
28
+
29
+ # Create a configuration object to specify quantization parameters
30
+ quantization_config = BitsAndBytesConfig(
31
+ load_in_4bit=True, # Whether to perform 4-bit quantization
32
+ load_in_8bit=False, # Whether to perform 8-bit quantization
33
+ bnb_4bit_compute_dtype=torch.float16, # Computation precision setting
34
+ bnb_4bit_quant_storage=torch.uint8, # Storage format for quantized weights
35
+ bnb_4bit_quant_type="nf4", # Quantization format, here using normally distributed int4
36
+ bnb_4bit_use_double_quant=True, # Whether to use double quantization, i.e., quantizing zeropoint and scaling parameters
37
+ llm_int8_enable_fp32_cpu_offload=False, # Whether LLM uses int8, with fp32 parameters stored on the CPU
38
+ llm_int8_has_fp16_weight=False, # Whether mixed precision is enabled
39
+ llm_int8_skip_modules=["out_proj", "kv_proj", "lm_head"], # Modules not to be quantized
40
+ llm_int8_threshold=6.0 # Outlier value in the llm.int8() algorithm, distinguishing whether to perform quantization based on this value
41
+ )
42
+
43
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
44
+ model = AutoModel.from_pretrained(
45
+ model_path,
46
+ device_map=device, # Allocate model to device
47
+ quantization_config=quantization_config,
48
+ trust_remote_code=True
49
+ )
50
+
51
+ gpu_usage = GPUtil.getGPUs()[0].memoryUsed
52
+ start=time.time()
53
+ response = model.chat(
54
+ image=Image.open(image_path).convert("RGB"),
55
+ msgs=[
56
+ {
57
+ "role": "user",
58
+ "content": "What is in this picture?"
59
+ }
60
+ ],
61
+ tokenizer=tokenizer
62
+ ) # 模型推理
63
+ print('Output after quantization:',response)
64
+ print('Inference time after quantization:',time.time()-start)
65
+ print(f"GPU memory usage after quantization: {round(gpu_usage/1024,2)}GB")
66
+
67
+ """
68
+ Expected output:
69
+
70
+ Output after quantization: This picture contains specific parts of an airplane, including wings, engines, and tail sections. These components are key parts of large commercial aircraft.
71
+ The wings support lift during flight, while the engines provide thrust to move the plane forward. The tail section is typically used for stabilizing flight and plays a role in airline branding.
72
+ The design and color of the airplane indicate that it belongs to Air China, likely a passenger aircraft due to its large size and twin-engine configuration.
73
+ There are no markings or insignia on the airplane indicating the specific model or registration number; such information may require additional context or a clearer perspective to discern.
74
+ Inference time after quantization: 8.583992719650269 seconds
75
+ GPU memory usage after quantization: 6.41 GB
76
+ """
77
+
78
+ # Save the model and tokenizer
79
+ os.makedirs(save_path, exist_ok=True)
80
+ model.save_pretrained(save_path, safe_serialization=True)
81
+ tokenizer.save_pretrained(save_path)
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/chatbot_web_demo_o2.6.py ADDED
@@ -0,0 +1,552 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ import torch
4
+ import argparse
5
+ from transformers import AutoModel, AutoTokenizer
6
+ import gradio as gr
7
+ from PIL import Image
8
+ from decord import VideoReader, cpu
9
+ import io
10
+ import os
11
+ import copy
12
+ import requests
13
+ import base64
14
+ import json
15
+ import traceback
16
+ import re
17
+ import modelscope_studio as mgr
18
+
19
+
20
+ # README: how to run this demo on different devices
21
+
22
+ # For Nvidia GPUs.
23
+ # python chatbot_web_demo_o2.6.py
24
+
25
+
26
+ # Argparser
27
+ parser = argparse.ArgumentParser(description='demo')
28
+ parser.add_argument('--model', type=str , default="openbmb/MiniCPM-o-2_6", help="huggingface model name or local path")
29
+ parser.add_argument('--multi-gpus', action='store_true', default=False, help='use multi-gpus')
30
+ args = parser.parse_args()
31
+ device = "cuda"
32
+ model_name = 'MiniCPM-o 2.6'
33
+
34
+ # Load model
35
+ model_path = args.model
36
+ if args.multi_gpus:
37
+ from accelerate import load_checkpoint_and_dispatch, init_empty_weights, infer_auto_device_map
38
+ with init_empty_weights():
39
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True, attn_implementation='sdpa', torch_dtype=torch.bfloat16,
40
+ init_audio=False, init_tts=False)
41
+ device_map = infer_auto_device_map(model, max_memory={0: "10GB", 1: "10GB"},
42
+ no_split_module_classes=['SiglipVisionTransformer', 'Qwen2DecoderLayer'])
43
+ device_id = device_map["llm.model.embed_tokens"]
44
+ device_map["llm.lm_head"] = device_id # firtt and last layer should be in same device
45
+ device_map["vpm"] = device_id
46
+ device_map["resampler"] = device_id
47
+ device_id2 = device_map["llm.model.layers.26"]
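+ # The reassignments below manually rebalance the auto-inferred device map: middle decoder
+ # layers 8-16 are pinned to the same GPU as layer 26 so the two 10GB budgets are filled
+ # more evenly (the exact layer indices are an empirical choice, not a hard requirement).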
48
+ device_map["llm.model.layers.8"] = device_id2
49
+ device_map["llm.model.layers.9"] = device_id2
50
+ device_map["llm.model.layers.10"] = device_id2
51
+ device_map["llm.model.layers.11"] = device_id2
52
+ device_map["llm.model.layers.12"] = device_id2
53
+ device_map["llm.model.layers.13"] = device_id2
54
+ device_map["llm.model.layers.14"] = device_id2
55
+ device_map["llm.model.layers.15"] = device_id2
56
+ device_map["llm.model.layers.16"] = device_id2
57
+ #print(device_map)
58
+
59
+ model = load_checkpoint_and_dispatch(model, model_path, dtype=torch.bfloat16, device_map=device_map)
60
+ else:
61
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, init_audio=False, init_tts=False)
62
+ model = model.to(device=device)
63
+
64
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
65
+ model.eval()
66
+
67
+
68
+
69
+
70
+ ERROR_MSG = "Error, please retry"
71
+ MAX_NUM_FRAMES = 64
72
+ IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
73
+ VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}
74
+
75
+ def get_file_extension(filename):
76
+ return os.path.splitext(filename)[1].lower()
77
+
78
+ def is_image(filename):
79
+ return get_file_extension(filename) in IMAGE_EXTENSIONS
80
+
81
+ def is_video(filename):
82
+ return get_file_extension(filename) in VIDEO_EXTENSIONS
83
+
84
+
85
+ form_radio = {
86
+ 'choices': ['Beam Search', 'Sampling'],
87
+ #'value': 'Beam Search',
88
+ 'value': 'Sampling',
89
+ 'interactive': True,
90
+ 'label': 'Decode Type'
91
+ }
92
+
93
+
94
+ def create_component(params, comp='Slider'):
95
+ if comp == 'Slider':
96
+ return gr.Slider(
97
+ minimum=params['minimum'],
98
+ maximum=params['maximum'],
99
+ value=params['value'],
100
+ step=params['step'],
101
+ interactive=params['interactive'],
102
+ label=params['label']
103
+ )
104
+ elif comp == 'Radio':
105
+ return gr.Radio(
106
+ choices=params['choices'],
107
+ value=params['value'],
108
+ interactive=params['interactive'],
109
+ label=params['label']
110
+ )
111
+ elif comp == 'Button':
112
+ return gr.Button(
113
+ value=params['value'],
114
+ interactive=True
115
+ )
116
+
117
+
118
+ def create_multimodal_input(upload_image_disabled=False, upload_video_disabled=False):
119
+ return mgr.MultimodalInput(
120
+ upload_image_button_props={'label': 'Upload Image', 'disabled': upload_image_disabled, 'file_count': 'multiple'},
121
+ upload_video_button_props={'label': 'Upload Video', 'disabled': upload_video_disabled, 'file_count': 'single'},
122
+ submit_button_props={'label': 'Submit'}
123
+ )
124
+
125
+
126
+ def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
127
+ try:
128
+ print('msgs:', msgs)
129
+ answer = model.chat(
130
+ image=None,
131
+ msgs=msgs,
132
+ tokenizer=tokenizer,
133
+ **params
134
+ )
135
+ res = re.sub(r'(<box>.*</box>)', '', answer)
136
+ res = res.replace('<ref>', '')
137
+ res = res.replace('</ref>', '')
138
+ res = res.replace('<box>', '')
139
+ answer = res.replace('</box>', '')
140
+ print('answer:', answer)
141
+ return 0, answer, None, None
142
+ except Exception as e:
143
+ print(e)
144
+ traceback.print_exc()
145
+ return -1, ERROR_MSG, None, None
146
+
147
+
148
+ def encode_image(image):
149
+ if not isinstance(image, Image.Image):
150
+ if hasattr(image, 'path'):
151
+ image = Image.open(image.path).convert("RGB")
152
+ else:
153
+ image = Image.open(image.file.path).convert("RGB")
154
+ # resize to max_size
155
+ max_size = 448*16
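+ # cap the longer side at 448 * 16 = 7168 px (presumably 16 tiles of the model's 448-px image slice size)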
156
+ if max(image.size) > max_size:
157
+ w,h = image.size
158
+ if w > h:
159
+ new_w = max_size
160
+ new_h = int(h * max_size / w)
161
+ else:
162
+ new_h = max_size
163
+ new_w = int(w * max_size / h)
164
+ image = image.resize((new_w, new_h), resample=Image.BICUBIC)
165
+ return image
166
+ ## save by BytesIO and convert to base64
167
+ #buffered = io.BytesIO()
168
+ #image.save(buffered, format="png")
169
+ #im_b64 = base64.b64encode(buffered.getvalue()).decode()
170
+ #return {"type": "image", "pairs": im_b64}
171
+
172
+
173
+ def encode_video(video):
174
+ def uniform_sample(l, n):
175
+ gap = len(l) / n
176
+ idxs = [int(i * gap + gap / 2) for i in range(n)]
177
+ return [l[i] for i in idxs]
178
+
179
+ if hasattr(video, 'path'):
180
+ vr = VideoReader(video.path, ctx=cpu(0))
181
+ else:
182
+ vr = VideoReader(video.file.path, ctx=cpu(0))
183
+ sample_fps = round(vr.get_avg_fps() / 1) # frame stride for sampling roughly 1 frame per second
184
+ frame_idx = [i for i in range(0, len(vr), sample_fps)]
185
+ if len(frame_idx)>MAX_NUM_FRAMES:
186
+ frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
187
+ video = vr.get_batch(frame_idx).asnumpy()
188
+ video = [Image.fromarray(v.astype('uint8')) for v in video]
189
+ video = [encode_image(v) for v in video]
190
+ print('video frames:', len(video))
191
+ return video
192
+
193
+
194
+ def check_mm_type(mm_file):
195
+ if hasattr(mm_file, 'path'):
196
+ path = mm_file.path
197
+ else:
198
+ path = mm_file.file.path
199
+ if is_image(path):
200
+ return "image"
201
+ if is_video(path):
202
+ return "video"
203
+ return None
204
+
205
+
206
+ def encode_mm_file(mm_file):
207
+ if check_mm_type(mm_file) == 'image':
208
+ return [encode_image(mm_file)]
209
+ if check_mm_type(mm_file) == 'video':
210
+ return encode_video(mm_file)
211
+ return None
212
+
213
+ def make_text(text):
214
+ #return {"type": "text", "pairs": text} # # For remote call
215
+ return text
216
+
217
+ def encode_message(_question):
218
+ files = _question.files
219
+ question = _question.text
220
+ pattern = r"\[mm_media\]\d+\[/mm_media\]"
221
+ matches = re.split(pattern, question)
222
+ message = []
223
+ if len(matches) != len(files) + 1:
224
+ gr.Warning("Number of Images not match the placeholder in text, please refresh the page to restart!")
225
+ assert len(matches) == len(files) + 1
226
+
227
+ text = matches[0].strip()
228
+ if text:
229
+ message.append(make_text(text))
230
+ for i in range(len(files)):
231
+ message += encode_mm_file(files[i])
232
+ text = matches[i + 1].strip()
233
+ if text:
234
+ message.append(make_text(text))
235
+ return message
236
+
237
+
238
+ def check_has_videos(_question):
239
+ images_cnt = 0
240
+ videos_cnt = 0
241
+ for file in _question.files:
242
+ if check_mm_type(file) == "image":
243
+ images_cnt += 1
244
+ else:
245
+ videos_cnt += 1
246
+ return images_cnt, videos_cnt
247
+
248
+
249
+ def count_video_frames(_context):
250
+ num_frames = 0
251
+ for message in _context:
252
+ for item in message["content"]:
253
+ #if item["type"] == "image": # For remote call
254
+ if isinstance(item, Image.Image):
255
+ num_frames += 1
256
+ return num_frames
257
+
258
+
259
+ def respond(_question, _chat_bot, _app_cfg, params_form):
260
+ _context = _app_cfg['ctx'].copy()
261
+ _context.append({'role': 'user', 'content': encode_message(_question)})
262
+
263
+ images_cnt = _app_cfg['images_cnt']
264
+ videos_cnt = _app_cfg['videos_cnt']
265
+ files_cnts = check_has_videos(_question)
266
+ if files_cnts[1] + videos_cnt > 1 or (files_cnts[1] + videos_cnt == 1 and files_cnts[0] + images_cnt > 0):
267
+ gr.Warning("Only supports single video file input right now!")
268
+ return _question, _chat_bot, _app_cfg
269
+
270
+ if params_form == 'Beam Search':
271
+ params = {
272
+ 'sampling': False,
273
+ 'num_beams': 3,
274
+ 'repetition_penalty': 1.2,
275
+ "max_new_tokens": 2048
276
+ }
277
+ else:
278
+ params = {
279
+ 'sampling': True,
280
+ 'top_p': 0.8,
281
+ 'top_k': 100,
282
+ 'temperature': 0.7,
283
+ 'repetition_penalty': 1.05,
284
+ "max_new_tokens": 2048
285
+ }
286
+
287
+ if files_cnts[1] + videos_cnt > 0:
288
+ params["max_inp_length"] = 4352 # 4096+256
289
+ params["use_image_id"] = False
290
+ params["max_slice_nums"] = 1 if count_video_frames(_context) > 16 else 2
291
+
292
+ code, _answer, _, sts = chat("", _context, None, params)
293
+
294
+ images_cnt += files_cnts[0]
295
+ videos_cnt += files_cnts[1]
296
+ _context.append({"role": "assistant", "content": [make_text(_answer)]})
297
+ _chat_bot.append((_question, _answer))
298
+ if code == 0:
299
+ _app_cfg['ctx']=_context
300
+ _app_cfg['sts']=sts
301
+ _app_cfg['images_cnt'] = images_cnt
302
+ _app_cfg['videos_cnt'] = videos_cnt
303
+
304
+ upload_image_disabled = videos_cnt > 0
305
+ upload_video_disabled = videos_cnt > 0 or images_cnt > 0
306
+ return create_multimodal_input(upload_image_disabled, upload_video_disabled), _chat_bot, _app_cfg
307
+
308
+
309
+ def fewshot_add_demonstration(_image, _user_message, _assistant_message, _chat_bot, _app_cfg):
310
+ ctx = _app_cfg["ctx"]
311
+ message_item = []
312
+ if _image is not None:
313
+ image = Image.open(_image).convert("RGB")
314
+ ctx.append({"role": "user", "content": [encode_image(image), make_text(_user_message)]})
315
+ message_item.append({"text": "[mm_media]1[/mm_media]" + _user_message, "files": [_image]})
316
+ else:
317
+ if _user_message:
318
+ ctx.append({"role": "user", "content": [make_text(_user_message)]})
319
+ message_item.append({"text": _user_message, "files": []})
320
+ else:
321
+ message_item.append(None)
322
+ if _assistant_message:
323
+ ctx.append({"role": "assistant", "content": [make_text(_assistant_message)]})
324
+ message_item.append({"text": _assistant_message, "files": []})
325
+ else:
326
+ message_item.append(None)
327
+
328
+ _chat_bot.append(message_item)
329
+ return None, "", "", _chat_bot, _app_cfg
330
+
331
+
332
+ def fewshot_respond(_image, _user_message, _chat_bot, _app_cfg, params_form):
333
+ user_message_contents = []
334
+ _context = _app_cfg["ctx"].copy()
335
+ if _image:
336
+ image = Image.open(_image).convert("RGB")
337
+ user_message_contents += [encode_image(image)]
338
+ if _user_message:
339
+ user_message_contents += [make_text(_user_message)]
340
+ if user_message_contents:
341
+ _context.append({"role": "user", "content": user_message_contents})
342
+
343
+ if params_form == 'Beam Search':
344
+ params = {
345
+ 'sampling': False,
346
+ 'num_beams': 3,
347
+ 'repetition_penalty': 1.2,
348
+ "max_new_tokens": 2048
349
+ }
350
+ else:
351
+ params = {
352
+ 'sampling': True,
353
+ 'top_p': 0.8,
354
+ 'top_k': 100,
355
+ 'temperature': 0.7,
356
+ 'repetition_penalty': 1.05,
357
+ "max_new_tokens": 2048
358
+ }
359
+
360
+ code, _answer, _, sts = chat("", _context, None, params)
361
+
362
+ _context.append({"role": "assistant", "content": [make_text(_answer)]})
363
+
364
+ if _image:
365
+ _chat_bot.append([
366
+ {"text": "[mm_media]1[/mm_media]" + _user_message, "files": [_image]},
367
+ {"text": _answer, "files": []}
368
+ ])
369
+ else:
370
+ _chat_bot.append([
371
+ {"text": _user_message, "files": [_image]},
372
+ {"text": _answer, "files": []}
373
+ ])
374
+ if code == 0:
375
+ _app_cfg['ctx']=_context
376
+ _app_cfg['sts']=sts
377
+ return None, '', '', _chat_bot, _app_cfg
378
+
379
+
380
+ def regenerate_button_clicked(_question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg, params_form):
381
+ if len(_chat_bot) <= 1 or not _chat_bot[-1][1]:
382
+ gr.Warning('No question for regeneration.')
383
+ return '', _image, _user_message, _assistant_message, _chat_bot, _app_cfg
384
+ if _app_cfg["chat_type"] == "Chat":
385
+ images_cnt = _app_cfg['images_cnt']
386
+ videos_cnt = _app_cfg['videos_cnt']
387
+ _question = _chat_bot[-1][0]
388
+ _chat_bot = _chat_bot[:-1]
389
+ _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
390
+ files_cnts = check_has_videos(_question)
391
+ images_cnt -= files_cnts[0]
392
+ videos_cnt -= files_cnts[1]
393
+ _app_cfg['images_cnt'] = images_cnt
394
+ _app_cfg['videos_cnt'] = videos_cnt
395
+ upload_image_disabled = videos_cnt > 0
396
+ upload_video_disabled = videos_cnt > 0 or images_cnt > 0
397
+ _question, _chat_bot, _app_cfg = respond(_question, _chat_bot, _app_cfg, params_form)
398
+ return _question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg
399
+ else:
400
+ last_message = _chat_bot[-1][0]
401
+ last_image = None
402
+ last_user_message = ''
403
+ if last_message.text:
404
+ last_user_message = last_message.text
405
+ if last_message.files:
406
+ last_image = last_message.files[0].file.path
407
+ _chat_bot = _chat_bot[:-1]
408
+ _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
409
+ _image, _user_message, _assistant_message, _chat_bot, _app_cfg = fewshot_respond(last_image, last_user_message, _chat_bot, _app_cfg, params_form)
410
+ return _question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg
411
+
412
+
413
+ def flushed():
414
+ return gr.update(interactive=True)
415
+
416
+
417
+ def clear(txt_message, chat_bot, app_session):
418
+ txt_message.files.clear()
419
+ txt_message.text = ''
420
+ chat_bot = copy.deepcopy(init_conversation)
421
+ app_session['sts'] = None
422
+ app_session['ctx'] = []
423
+ app_session['images_cnt'] = 0
424
+ app_session['videos_cnt'] = 0
425
+ return create_multimodal_input(), chat_bot, app_session, None, '', ''
426
+
427
+
428
+ def select_chat_type(_tab, _app_cfg):
429
+ _app_cfg["chat_type"] = _tab
430
+ return _app_cfg
431
+
432
+
433
+ init_conversation = [
434
+ [
435
+ None,
436
+ {
437
+ # The bot's first message disables the typewriter (flushing) effect.
438
+ "text": "You can talk to me now",
439
+ "flushing": False
440
+ }
441
+ ],
442
+ ]
443
+
444
+
445
+ css = """
446
+ video { height: auto !important; }
447
+ .example label { font-size: 16px;}
448
+ """
449
+
450
+ introduction = """
451
+
452
+ ## Features:
453
+ 1. Chat with single image
454
+ 2. Chat with multiple images
455
+ 3. Chat with video
456
+ 4. In-context few-shot learning
457
+
458
+ Click `How to use` tab to see examples.
459
+ """
460
+
461
+
462
+ with gr.Blocks(css=css) as demo:
463
+ with gr.Tab(model_name):
464
+ with gr.Row():
465
+ with gr.Column(scale=1, min_width=300):
466
+ gr.Markdown(value=introduction)
467
+ params_form = create_component(form_radio, comp='Radio')
468
+ regenerate = create_component({'value': 'Regenerate'}, comp='Button')
469
+ clear_button = create_component({'value': 'Clear History'}, comp='Button')
470
+
471
+ with gr.Column(scale=3, min_width=500):
472
+ app_session = gr.State({'sts':None,'ctx':[], 'images_cnt': 0, 'videos_cnt': 0, 'chat_type': 'Chat'})
473
+ chat_bot = mgr.Chatbot(label=f"Chat with {model_name}", value=copy.deepcopy(init_conversation), height=600, flushing=False, bubble_full_width=False)
474
+
475
+ with gr.Tab("Chat") as chat_tab:
476
+ txt_message = create_multimodal_input()
477
+ chat_tab_label = gr.Textbox(value="Chat", interactive=False, visible=False)
478
+
479
+ txt_message.submit(
480
+ respond,
481
+ [txt_message, chat_bot, app_session, params_form],
482
+ [txt_message, chat_bot, app_session]
483
+ )
484
+
485
+ with gr.Tab("Few Shot") as fewshot_tab:
486
+ fewshot_tab_label = gr.Textbox(value="Few Shot", interactive=False, visible=False)
487
+ with gr.Row():
488
+ with gr.Column(scale=1):
489
+ image_input = gr.Image(type="filepath", sources=["upload"])
490
+ with gr.Column(scale=3):
491
+ user_message = gr.Textbox(label="User")
492
+ assistant_message = gr.Textbox(label="Assistant")
493
+ with gr.Row():
494
+ add_demonstration_button = gr.Button("Add Example")
495
+ generate_button = gr.Button(value="Generate", variant="primary")
496
+ add_demonstration_button.click(
497
+ fewshot_add_demonstration,
498
+ [image_input, user_message, assistant_message, chat_bot, app_session],
499
+ [image_input, user_message, assistant_message, chat_bot, app_session]
500
+ )
501
+ generate_button.click(
502
+ fewshot_respond,
503
+ [image_input, user_message, chat_bot, app_session, params_form],
504
+ [image_input, user_message, assistant_message, chat_bot, app_session]
505
+ )
506
+
507
+ chat_tab.select(
508
+ select_chat_type,
509
+ [chat_tab_label, app_session],
510
+ [app_session]
511
+ )
512
+ chat_tab.select( # do clear
513
+ clear,
514
+ [txt_message, chat_bot, app_session],
515
+ [txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
516
+ )
517
+ fewshot_tab.select(
518
+ select_chat_type,
519
+ [fewshot_tab_label, app_session],
520
+ [app_session]
521
+ )
522
+ fewshot_tab.select( # do clear
523
+ clear,
524
+ [txt_message, chat_bot, app_session],
525
+ [txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
526
+ )
527
+ chat_bot.flushed(
528
+ flushed,
529
+ outputs=[txt_message]
530
+ )
531
+ regenerate.click(
532
+ regenerate_button_clicked,
533
+ [txt_message, image_input, user_message, assistant_message, chat_bot, app_session, params_form],
534
+ [txt_message, image_input, user_message, assistant_message, chat_bot, app_session]
535
+ )
536
+ clear_button.click(
537
+ clear,
538
+ [txt_message, chat_bot, app_session],
539
+ [txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
540
+ )
541
+
542
+ with gr.Tab("How to use"):
543
+ with gr.Column():
544
+ with gr.Row():
545
+ image_example = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/m_bear2.gif", label='1. Chat with single or multiple images', interactive=False, width=400, elem_classes="example")
546
+ example2 = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/video2.gif", label='2. Chat with video', interactive=False, width=400, elem_classes="example")
547
+ example3 = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/fshot.gif", label='3. Few shot', interactive=False, width=400, elem_classes="example")
548
+
549
+
550
+ # launch
551
+ demo.launch(share=False, debug=True, show_api=False, server_port=8000, server_name="0.0.0.0")
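+ # Once launched, the demo listens on all interfaces at port 8000 (e.g. http://localhost:8000).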
552
+
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/model_server.py ADDED
@@ -0,0 +1,936 @@
1
+ import base64
2
+ import json
3
+ import asyncio
4
+ import numpy as np
5
+ import os, sys, io
6
+ import threading
7
+ import time
8
+ import aiofiles
9
+ import librosa
10
+ import soundfile
11
+ import wave
12
+ from typing import Dict, List, Any, Optional
13
+ import argparse
14
+ import logging
15
+ import torch
16
+ from PIL import Image
17
+ from transformers import AutoModel, AutoTokenizer, AutoProcessor
18
+ import uvicorn
19
+ from fastapi import FastAPI, Header, Query, Request, HTTPException, WebSocket, WebSocketDisconnect
20
+ from fastapi.responses import JSONResponse, StreamingResponse
21
+
22
+ cur_path = os.path.split(os.path.realpath(__file__))[0]
23
+ sys.path.append(os.path.abspath(cur_path))
24
+ import vad_utils
25
+
26
+ def setup_logger():
27
+ logger = logging.getLogger("api_logger")
28
+ logger.setLevel(logging.DEBUG)
29
+
30
+ # Create formatter
31
+ formatter = logging.Formatter(
32
+ '%(asctime)s.%(msecs)03d-%(levelname)s-[%(filename)s:%(lineno)d] - %(message)s',
33
+ datefmt='%Y-%m-%d %H:%M:%S'
34
+ )
35
+
36
+ # Create handlers for stdout and stderr
37
+ stdout_handler = logging.StreamHandler(sys.stdout)
38
+ stdout_handler.setLevel(logging.INFO) # INFO and DEBUG go to stdout
39
+ stdout_handler.setFormatter(formatter)
40
+ stdout_handler.addFilter(lambda record: record.levelno <= logging.INFO)
41
+
42
+ stderr_handler = logging.StreamHandler(sys.stderr)
43
+ stderr_handler.setLevel(logging.WARNING) # WARNING, ERROR, CRITICAL go to stderr
44
+ stderr_handler.setFormatter(formatter)
45
+
46
+ # Add handlers to logger
47
+ logger.addHandler(stdout_handler)
48
+ logger.addHandler(stderr_handler)
49
+
50
+ return logger
51
+
52
+
53
+ app = FastAPI()
54
+ logger = setup_logger()
55
+
56
+ ap = argparse.ArgumentParser()
57
+ ap.add_argument('--port', type=int , default=32550)
58
+ ap.add_argument('--model', type=str , default="openbmb/MiniCPM-o-2_6", help="huggingface model name or local path")
59
+ args = ap.parse_args()
60
+
61
+
62
+ class StreamManager:
63
+ def __init__(self):
64
+ self.uid = None
65
+
66
+ self.is_streaming_complete = threading.Event()
67
+ self.conversation_started = threading.Event()
68
+ self.last_request_time = None
69
+ self.last_stream_time = None
70
+ self.timeout = 900 # seconds timeout
71
+ self.stream_timeout = 3 # seconds no stream
72
+ self.num_stream = 0
73
+ self.stream_started = False
74
+ self.stop_response = False
75
+
76
+ # VAD settings
77
+ self.vad_options = vad_utils.VadOptions()
78
+ self.vad_sequence_length = 5
79
+ self.vad_sequence = []
80
+ self.audio_prefill = []
81
+ self.audio_input = []
82
+ self.image_prefill = None
83
+ self.audio_chunk = 200
84
+
85
+ # customized options
86
+ self.customized_audio = None
87
+ self.customized_options = None
88
+
89
+ # Omni model
90
+ self.target_dtype = torch.bfloat16
91
+ self.device='cuda:0'
92
+
93
+ self.minicpmo_model_path = args.model #"openbmb/MiniCPM-o-2_6"
94
+ self.model_version = "2.6"
95
+ with torch.no_grad():
96
+ self.minicpmo_model = AutoModel.from_pretrained(self.minicpmo_model_path, trust_remote_code=True, torch_dtype=self.target_dtype, attn_implementation='sdpa')
97
+ self.minicpmo_tokenizer = AutoTokenizer.from_pretrained(self.minicpmo_model_path, trust_remote_code=True)
98
+ self.minicpmo_model.init_tts()
99
+ # self.minicpmo_model.tts.float()
100
+ self.minicpmo_model.to(self.device).eval()
101
+
102
+ self.ref_path_video_default = "assets/ref_audios/video_default.wav"
103
+ self.ref_path_default = "assets/ref_audios/default.wav"
104
+ self.ref_path_female = "assets/ref_audios/female_example.wav"
105
+ self.ref_path_male = "assets/ref_audios/male_example.wav"
106
+
107
+ self.input_audio_id = 0
108
+ self.input_audio_vad_id = 0
109
+ self.input_image_id = 0
110
+ self.output_audio_id = 0
111
+ self.flag_decode = False
112
+ self.cnts = None
113
+
114
+ self.all_start_time = time.time()
115
+ self.session_id = 233
116
+ self.sys_prompt_flag = False
117
+ self.vad_time = 0
118
+ self.ls_time = 0
119
+ self.msg_type = 1
120
+
121
+ self.speaking_time_stamp = 0
122
+ self.cycle_wait_time = 12800/24000 + 0.15
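+ # 12800 samples at 24 kHz ≈ 0.53 s of audio per generated chunk, plus a 0.15 s scheduling
+ # margin (the chunk size is presumably the TTS streaming output length).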
123
+ self.extra_wait_time = 2.5
124
+ self.server_wait = True
125
+
126
+ self.past_session_id = 0
127
+ self.sys_prompt_init(0)
128
+ self.session_id += 1
129
+
130
+
131
+ def start_conversation(self):
132
+ logger.info(f"uid {self.uid}: new conversation started.")
133
+ self.conversation_started.set()
134
+ self.stop_response = False
135
+
136
+ def update_last_request_time(self):
137
+ self.last_request_time = time.time()
138
+ #logger.info(f"update last_request_time {self.last_request_time}")
139
+
140
+ def update_last_stream_time(self):
141
+ self.last_stream_time = time.time()
142
+ #logger.info(f"update last_stream_time {self.last_stream_time}")
143
+
144
+ def move_to_device(self, obj, device):
145
+ if isinstance(obj, torch.Tensor):
146
+ obj_ = obj.to(device)
147
+ if (obj_.dtype == torch.float) or (obj_.dtype == torch.half):
148
+ # cast to `torch.bfloat16`
149
+ obj_ = obj_.to(self.target_dtype)
150
+ return obj_
151
+ elif isinstance(obj, dict):
152
+ return {key: self.move_to_device(value, device) for key, value in obj.items()}
153
+ elif isinstance(obj, list):
154
+ return [self.move_to_device(item, device) for item in obj]
155
+ elif isinstance(obj, tuple):
156
+ return tuple(self.move_to_device(item, device) for item in obj)
157
+ elif isinstance(obj, set):
158
+ return {self.move_to_device(item, device) for item in obj}
159
+ else:
160
+ return obj
161
+
162
+ def reset(self):
163
+ logger.info("reset")
164
+ self.is_streaming_complete.clear()
165
+ self.conversation_started.clear()
166
+ self.last_request_time = None
167
+ self.last_stream_time = None
168
+ self.audio_buffer_raw = bytearray()
169
+ self.num_stream = 0
170
+ self.stream_started = False
171
+ self.stop_response = False
172
+ # self.customized_audio = None
173
+ # self.customized_options = None
174
+ # clear model
175
+ self.clear()
176
+
177
+ def merge_wav_files(self, input_bytes_list, output_file):
178
+ with wave.open(io.BytesIO(input_bytes_list[0]), 'rb') as wav:
179
+ params = wav.getparams()
180
+ n_channels, sampwidth, framerate, n_frames, comptype, compname = params
181
+
182
+ with wave.open(output_file, 'wb') as output_wav:
183
+ output_wav.setnchannels(n_channels)
184
+ output_wav.setsampwidth(sampwidth)
185
+ output_wav.setframerate(framerate)
186
+ output_wav.setcomptype(comptype, compname)
187
+
188
+ for wav_bytes in input_bytes_list:
189
+ with wave.open(io.BytesIO(wav_bytes), 'rb') as wav:
190
+ output_wav.writeframes(wav.readframes(wav.getnframes()))
191
+
192
+
193
+ def is_timed_out(self):
194
+ if self.last_request_time is not None:
195
+ return time.time() - self.last_request_time > self.timeout
196
+ return False
197
+
198
+ def no_active_stream(self):
199
+ if self.last_stream_time is not None and self.stream_started:
200
+ no_stream_duration = time.time() - self.last_stream_time
201
+ if no_stream_duration > self.stream_timeout:
202
+ #logger.info(f"no active stream for {no_stream_duration} secs.")
203
+ return True
204
+ return False
205
+
206
+ def sys_prompt_init(self, msg_type):
207
+ if self.past_session_id == self.session_id:
208
+ return
209
+ logger.info("### sys_prompt_init ###")
210
+
211
+ logger.info(f'msg_type is {msg_type}')
212
+ if msg_type <= 1: #audio
213
+ audio_voice_clone_prompt = "Use the voice in the audio prompt to synthesize new content."
214
+ audio_assistant_prompt = "You are a helpful assistant with the above voice style."
215
+ ref_path = self.ref_path_default
216
+
217
+
218
+ if self.customized_options is not None:
219
+ audio_voice_clone_prompt = self.customized_options['voice_clone_prompt']
220
+ audio_assistant_prompt = self.customized_options['assistant_prompt']
221
+ if self.customized_options['use_audio_prompt'] == 1:
222
+ ref_path = self.ref_path_default
223
+ elif self.customized_options['use_audio_prompt'] == 2:
224
+ ref_path = self.ref_path_female
225
+ elif self.customized_options['use_audio_prompt'] == 3:
226
+ ref_path = self.ref_path_male
227
+
228
+ audio_prompt, sr = librosa.load(ref_path, sr=16000, mono=True)
229
+ sys_msg = {'role': 'user', 'content': [audio_voice_clone_prompt + "\n", audio_prompt, "\n" + audio_assistant_prompt]}
230
+ elif msg_type == 2: #video
231
+ voice_clone_prompt="你是一个AI助手。你能接受视频,音频和文本输入并输出语音和文本。模仿输入音频中的声音特征。"
232
+ assistant_prompt="作为助手,你将使用这种声音风格说话。"
233
+ ref_path = self.ref_path_video_default
234
+
235
+ if self.customized_options is not None:
236
+ voice_clone_prompt = self.customized_options['voice_clone_prompt']
237
+ assistant_prompt = self.customized_options['assistant_prompt']
238
+ if self.customized_options['use_audio_prompt'] == 1:
239
+ ref_path = self.ref_path_default
240
+ elif self.customized_options['use_audio_prompt'] == 2:
241
+ ref_path = self.ref_path_female
242
+ elif self.customized_options['use_audio_prompt'] == 3:
243
+ ref_path = self.ref_path_male
244
+
245
+ audio_prompt, sr = librosa.load(ref_path, sr=16000, mono=True)
246
+ sys_msg = {'role': 'user', 'content': [voice_clone_prompt, audio_prompt, assistant_prompt]}
247
+ # elif msg_type == 3: #user start
248
+ # assistant_prompt="作为助手,你将使用这种声音风格说话。"
249
+ # if self.customized_options is not None:
250
+ # assistant_prompt = self.customized_options['assistant_prompt']
251
+
252
+ # sys_msg = {'role': 'user', 'content': [assistant_prompt]}
253
+
254
+ self.msg_type = msg_type
255
+ msgs = [sys_msg]
256
+ if self.customized_options is not None:
257
+ if self.customized_options['use_audio_prompt'] > 0:
258
+ self.minicpmo_model.streaming_prefill(
259
+ session_id=str(self.session_id),
260
+ msgs=msgs,
261
+ tokenizer=self.minicpmo_tokenizer,
262
+ )
263
+ if msg_type == 0:
264
+ self.minicpmo_model.streaming_prefill(
265
+ session_id=str(self.session_id),
266
+ msgs=msgs,
267
+ tokenizer=self.minicpmo_tokenizer,
268
+ )
269
+
270
+ self.savedir = os.path.join(f"./log_data/{args.port}/", str(time.time()))
271
+ if not os.path.exists(self.savedir):
272
+ os.makedirs(self.savedir)
273
+ if not os.path.exists(self.savedir + "/input_audio_log"):
274
+ os.makedirs(self.savedir + "/input_audio_log")
275
+ if not os.path.exists(self.savedir + "/input_audio_vad_log"):
276
+ os.makedirs(self.savedir + "/input_audio_vad_log")
277
+ if not os.path.exists(self.savedir + "/input_image_log"):
278
+ os.makedirs(self.savedir + "/input_image_log")
279
+ if not os.path.exists(self.savedir + "/output_audio_log"):
280
+ os.makedirs(self.savedir + "/output_audio_log")
281
+ if not os.path.exists(self.savedir + "/feedback_log"):
282
+ os.makedirs(self.savedir + "/feedback_log")
283
+ if not os.path.exists(self.savedir + "/input_audio"):
284
+ os.makedirs(self.savedir + "/input_audio")
285
+
286
+ self.past_session_id = self.session_id
287
+ self.audio_prefill = []
288
+ self.audio_input = []
289
+
290
+ def clear(self):
291
+ try:
292
+ self.flag_decode = False
293
+ self.stream_started = False
294
+ self.cnts = None
295
+ self.vad_sequence = []
296
+ self.audio_prefill = []
297
+ self.audio_input = []
298
+ self.image_prefill = None
299
+
300
+ if self.minicpmo_model.llm_past_key_values[0][0].shape[2]>8192:
301
+ self.session_id += 1 # to clear all kv cache
302
+ self.sys_prompt_flag = False
303
+
304
+ self.vad_time = 0
305
+ self.ls_time = 0
306
+ self.msg_type = 1
307
+
308
+ except Exception as e:
309
+ raise ValueError(f"Clear error: {str(e)}")
310
+
311
+
312
+ def process_message(self, message: Dict[str, Any]):
313
+ try:
314
+ # Process content items
315
+ audio_data = None
316
+ image_data = None
317
+ for content_item in message["content"]:
318
+ if content_item["type"] == "stop_response":
319
+ logger.info("process_message: received request to stop_response")
320
+ self.stop_response = True
321
+ return "stop"
322
+ elif content_item["type"] == "input_audio":
323
+ audio_data = content_item["input_audio"]["data"]
324
+ audio_timestamp = content_item["input_audio"].get("timestamp", "")
325
+ elif content_item["type"] == "image_data":
326
+ image_data = content_item["image_data"]["data"]
327
+ if audio_data is None:
328
+ return "empty audio"
329
+
330
+ if self.conversation_started.is_set() and self.is_streaming_complete.is_set():
331
+ logger.info("conversation not started or still in generation, skip stream message.")
332
+ return "skip"
333
+
334
+ if self.flag_decode:
335
+ return "skip"
336
+
337
+ try:
338
+ audio_bytes = base64.b64decode(audio_data)
339
+
340
+ image = None
341
+ if image_data is not None:
342
+ if len(image_data) > 0:
343
+ image_bytes = base64.b64decode(image_data)
344
+ image_buffer = io.BytesIO(image_bytes)
345
+ image_buffer.seek(0)
346
+ image = Image.open(image_buffer)
347
+ # logger.info("read image")
348
+
349
+ if self.sys_prompt_flag is False:
350
+ self.all_start_time = time.time()
351
+ self.sys_prompt_flag = True
352
+ if image_data is not None:
353
+ self.sys_prompt_init(2)
354
+ else:
355
+ self.sys_prompt_init(1)
356
+
357
+ self.prefill(audio_bytes, image, False)
358
+
359
+ self.vad_sequence.append(audio_bytes)
360
+ if len(self.vad_sequence) < self.vad_sequence_length:
361
+ # logger.info('length of vad_sequence is {}, insufficient'.format(self.vad_sequence_length))
362
+ return "done"
363
+ elif len(self.vad_sequence) > self.vad_sequence_length:
364
+ # logger.info('length of vad_sequence exceeds {}'.format(self.vad_sequence_length))
365
+ self.vad_sequence.pop(0)
366
+ self.vad_check_audio_bytes(audio_bytes, image, 16000)
367
+
368
+ return "done"
369
+
370
+ except Exception as e:
371
+ raise ValueError(f"Audio processing error: {str(e)}")
372
+
373
+ except Exception as e:
374
+ raise ValueError(f"Message processing error: {str(e)}")
375
+
376
+ def resample_audio(self, input_path, src_sr, tar_sr, output_path):
377
+ audio_data, _ = librosa.load(input_path, sr=src_sr)
378
+ audio_new = librosa.resample(audio_data, orig_sr=src_sr, target_sr=tar_sr)
379
+ soundfile.write(output_path, audio_new, tar_sr)
380
+
381
+ def calculate_rms(self, input_path, sr):
382
+ audio_data, _ = librosa.load(input_path, sr=sr)
383
+ return (np.sqrt(np.mean(audio_data**2)) > 0.002)
384
+
385
+ def vad_check_audio_bytes(self, audio, image, sr):
386
+ try:
387
+ input_audio_vad_path = self.savedir + f"/input_audio_vad_log/vad_{self.input_audio_vad_id}.wav"
388
+ self.input_audio_vad_id += 1
389
+ self.merge_wav_files(self.vad_sequence, input_audio_vad_path)
390
+
391
+ with open(input_audio_vad_path,"rb") as f:
392
+ temp_audio = f.read()
393
+ dur_vad, vad_audio_bytes, time_vad = vad_utils.run_vad(temp_audio, sr, self.vad_options)
394
+ if self.customized_options is not None:
395
+ vad_threshold = 1 - self.customized_options['vad_threshold']
396
+ else:
397
+ vad_threshold = 0.2
398
+
399
+ if self.calculate_rms(input_audio_vad_path, sr) and dur_vad > 0.4:
400
+ if self.stream_started == False:
401
+ self.vad_time = time.time()
402
+ self.stream_started = True
403
+ elif dur_vad < vad_threshold:
404
+ if self.stream_started:
405
+ self.stream_started = False
406
+ if (time.time() - self.vad_time >= 0.6):
407
+ self.prefill(audio, image, True)
408
+ self.is_streaming_complete.set()
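+ # Turn-taking heuristic as implemented above: a window with enough energy and more than
+ # 0.4 s of detected speech opens a user turn; once detected speech drops below
+ # vad_threshold and at least 0.6 s has passed since the turn opened, the turn is closed,
+ # a final prefill is sent and generation is triggered.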
409
+ # self.ls_time = time.time()
410
+
411
+ except Exception as e:
412
+ logger.error(f"VAD error: {e}")
413
+ raise
414
+ return
415
+
416
+ def prefill(self, audio, image, is_end):
417
+ if self.server_wait:
418
+ now = time.time()
419
+ await_time = self.speaking_time_stamp - now + self.extra_wait_time
420
+ if await_time > 0:
421
+ return False
422
+
423
+ if self.flag_decode:
424
+ return False
425
+
426
+ if image is not None:
427
+ self.image_prefill = image
428
+ try:
429
+ if is_end == False:
430
+ self.audio_prefill.append(audio)
431
+ self.audio_input.append(audio)
432
+ slice_nums = 1
433
+ if is_end and self.customized_options is not None:
434
+ if self.customized_options['hd_video']:
435
+ slice_nums = 6
436
+ else:
437
+ return True
438
+ if (len(self.audio_prefill) == (1000/self.audio_chunk)) or (is_end and len(self.audio_prefill)>0):
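+ # Prefill roughly once per second of audio: with audio_chunk = 200 ms,
+ # 1000 / 200 = 5 buffered chunks trigger a streaming_prefill call (or earlier at end of turn).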
439
+ time_prefill = time.time()
440
+ input_audio_path = self.savedir + f"/input_audio_log/input_audio_{self.input_audio_id}.wav"
441
+ self.merge_wav_files(self.audio_prefill, input_audio_path)
442
+ with open(input_audio_path,"rb") as wav_io:
443
+ signal, sr = soundfile.read(wav_io, dtype='float32')
444
+ soundfile.write(input_audio_path, signal, 16000)
445
+ audio_np, sr = librosa.load(input_audio_path, sr=16000, mono=True)
446
+ self.audio_prefill = []
447
+
448
+ if len(audio_np) > 16000:
449
+ audio_np = audio_np[:16000]
450
+
451
+ with torch.no_grad():
452
+ if self.image_prefill is not None:
453
+ input_image_path = self.savedir + f'/input_image_log/input_image_{self.input_audio_id}.png'
454
+ self.image_prefill.save(input_image_path, 'PNG')
455
+ self.image_prefill = self.image_prefill.convert("RGB")
456
+
457
+ cnts = None
458
+ if self.image_prefill is not None:
459
+ cnts = ["<unit>", self.image_prefill, audio_np]
460
+ else:
461
+ cnts = [audio_np]
462
+
463
+ if cnts is not None:
464
+ msg = {"role":"user", "content": cnts}
465
+ msgs = [msg]
466
+ res = self.minicpmo_model.streaming_prefill(
467
+ session_id=str(self.session_id),
468
+ msgs=msgs,
469
+ tokenizer=self.minicpmo_tokenizer,
470
+ max_slice_nums=slice_nums,
471
+ )
472
+
473
+ self.input_audio_id += 1
474
+ return True
475
+
476
+ except Exception as e:
477
+ logger.error(f"prefill error: {e}")
478
+ import traceback
479
+ traceback.print_exc()
480
+ raise
481
+
482
+ def generate_end(self):
483
+ self.input_audio_id += 10
484
+ self.output_audio_id += 10
485
+ self.flag_decode = False
486
+ self.reset()
487
+ return
488
+
489
+ async def generate(self):
490
+ """ return audio bytes and response text (optional) """
491
+ if self.stop_response:
492
+ self.generate_end()
493
+ return
494
+
495
+ self.flag_decode = True
496
+ try:
497
+ with torch.no_grad():
498
+ logger.info("=== model gen start ===")
499
+ time_gen = time.time()
500
+ input_audio_path = self.savedir + f"/input_audio/all_input_audio_{self.input_audio_id}.wav"
501
+ self.merge_wav_files(self.audio_input, input_audio_path)
502
+ audio_stream = None
503
+ try:
504
+ with open(input_audio_path, 'rb') as wav_file:
505
+ audio_stream = wav_file.read()
506
+ except FileNotFoundError:
507
+ print(f"File {input_audio_path} not found.")
508
+ yield base64.b64encode(audio_stream).decode('utf-8'), "assistant:\n"
509
+
510
+ print('=== gen start: ', time.time() - time_gen)
511
+ first_time = True
512
+ temp_time = time.time()
513
+ temp_time1 = time.time()
514
+ with torch.inference_mode():
515
+ if self.stop_response:
516
+ self.generate_end()
517
+ return
518
+ self.minicpmo_model.config.stream_input=True
519
+ msg = {"role":"user", "content": self.cnts}
520
+ msgs = [msg]
521
+ text = ''
522
+ self.speaking_time_stamp = time.time()
523
+ try:
524
+ for r in self.minicpmo_model.streaming_generate(
525
+ session_id=str(self.session_id),
526
+ tokenizer=self.minicpmo_tokenizer,
527
+ generate_audio=True,
528
+ # enable_regenerate=True,
529
+ ):
530
+ if self.stop_response:
531
+ self.generate_end()
532
+ return
533
+ audio_np, sr, text = r["audio_wav"], r["sampling_rate"], r["text"]
534
+
535
+ output_audio_path = self.savedir + f'/output_audio_log/output_audio_{self.output_audio_id}.wav'
536
+ self.output_audio_id += 1
537
+ soundfile.write(output_audio_path, audio_np, samplerate=sr)
538
+ audio_stream = None
539
+ try:
540
+ with open(output_audio_path, 'rb') as wav_file:
541
+ audio_stream = wav_file.read()
542
+ except FileNotFoundError:
543
+ print(f"File {output_audio_path} not found.")
544
+ temp_time1 = time.time()
545
+ print('text: ', text)
546
+ yield base64.b64encode(audio_stream).decode('utf-8'), text
547
+ self.speaking_time_stamp += self.cycle_wait_time
548
+ except Exception as e:
549
+ logger.error(f"Error happened during generation: {str(e)}")
550
+ yield None, '\n<end>'
551
+
552
+ except Exception as e:
553
+ logger.error(f"发生异常:{e}")
554
+ import traceback
555
+ traceback.print_exc()
556
+ raise
557
+
558
+ finally:
559
+ logger.info(f"uid {self.uid}: generation finished!")
560
+ self.generate_end()
561
+
562
+ async def check_activity(self):
563
+ while True:
564
+ # Check for overall inactivity (30 minutes)
565
+ if self.is_timed_out():
566
+ self.reset()
567
+ if self.no_active_stream() and not self.is_streaming_complete.is_set():
568
+ self.is_streaming_complete.set()
569
+
570
+ await asyncio.sleep(1) # Check every second
571
+
572
+ def upload_customized_audio(self, audio_data, audio_fmt):
573
+ self.customized_audio = None
574
+ try:
575
+ if audio_data is not None and len(audio_data) > 0:
576
+ # if audio_fmt == "mp3" or audio_fmt == "wav":
577
+ audio_bytes = base64.b64decode(audio_data)
578
+ fio = io.BytesIO(audio_bytes)
579
+ fio.seek(0)
580
+ audio_np, sr = librosa.load(fio, sr=16000, mono=True)
581
+ if audio_np is not None and len(audio_np) > 1000:
582
+ output_audio_path = self.savedir + f'/customized_audio.wav'
583
+ soundfile.write(output_audio_path, audio_np, sr)
584
+ self.customized_audio = output_audio_path
585
+ logger.info(f"processed customized {audio_fmt} audio")
586
+ print(audio_np.shape, type(audio_np), sr)
587
+ else:
588
+ logger.info(f"empty customized audio, use default value instead.")
589
+ self.customized_audio = None
590
+ except Exception as e:
591
+ raise ValueError(f"Process customized audio error: {str(e)}")
592
+
593
+ def update_customized_options(self, uid, options):
594
+ self.customized_options = None
595
+ if options is None:
596
+ raise ValueError("Invalid None type for options, expected dict type")
597
+ self.customized_options = options
598
+ logger.info(f"uid: {uid} set customized_options to {options}")
599
+
600
+
601
+ stream_manager = StreamManager()
602
+
603
+
604
+ @app.on_event("startup")
605
+ async def startup_event():
606
+ logger.info("Starting application and activity checker")
607
+ asyncio.create_task(stream_manager.check_activity())
608
+
609
+ @app.on_event("shutdown")
610
+ async def shutdown_event():
611
+ logger.info("Shutting down application")
612
+
613
+ @app.post("/stream")
614
+ @app.post("/api/v1/stream")
615
+ async def stream(request: Request, uid: Optional[str] = Header(None)):
616
+ global stream_manager
617
+
618
+ stream_manager.update_last_request_time()
619
+ stream_manager.update_last_stream_time()
620
+
621
+ if not uid:
622
+ raise HTTPException(status_code=400, detail="Missing uid in headers")
623
+ if stream_manager.uid is not None and stream_manager.uid != uid:
624
+ logger.error(f"uid changed during steram: previous uid {stream_manager.uid}, new uid {uid}")
625
+ raise HTTPException(status_code=400, detail="uid changed in stream")
626
+
627
+ try:
628
+ # Parse JSON request
629
+ data = await request.json()
630
+
631
+ # Validate basic structure
632
+ if not isinstance(data, dict) or "messages" not in data:
633
+ raise HTTPException(status_code=400, detail="Invalid request format")
634
+
635
+ # Process messages
636
+ reason = ""
637
+ for message in data["messages"]:
638
+ if not isinstance(message, dict) or "role" not in message or "content" not in message:
639
+ raise HTTPException(status_code=400, detail="Invalid message format")
640
+ reason = stream_manager.process_message(message)
641
+
642
+ # Return response using uid from header
643
+ response = {
644
+ "id": uid,
645
+ "choices": {
646
+ "role": "assistant",
647
+ "content": "success",
648
+ "finish_reason": reason
649
+ }
650
+ }
651
+ return JSONResponse(content=response, status_code=200)
652
+
653
+ except json.JSONDecodeError:
654
+ raise HTTPException(status_code=400, detail="Invalid JSON")
655
+ except Exception as e:
656
+ raise HTTPException(status_code=500, detail=str(e))
657
+
658
+ @app.websocket("/ws/stream")
659
+ @app.websocket("/ws/api/v1/stream")
660
+ async def websocket_stream(websocket: WebSocket,
661
+ uid: Optional[str] = Query(None)):
662
+ global stream_manager
663
+
664
+ if not uid:
665
+ await websocket.close(code=400, reason="Missing uid in request")
666
+ return
667
+
668
+ # Accept the WebSocket connection
669
+ await websocket.accept()
670
+
671
+ #if stream_manager.uid is not None and stream_manager.uid != uid:
672
+ # logger.error(f"uid changed during steram: previous uid {stream_manager.uid}, new uid {uid}")
673
+ # await websocket.close(code=400, reason="Uid changed in stream.")
674
+ # return
675
+
676
+ try:
677
+ while True:
678
+ # Continuously listen for incoming messages from the client
679
+ data = await websocket.receive_text()
680
+
681
+ # Parse JSON request
682
+ try:
683
+ request_data = json.loads(data)
684
+ except json.JSONDecodeError:
685
+ await websocket.send_json({"error": "Invalid JSON"})
686
+ continue
687
+
688
+ stream_manager.update_last_request_time()
689
+ stream_manager.update_last_stream_time()
690
+
691
+ if stream_manager.uid is not None and stream_manager.uid != uid:
692
+ logger.error(f"uid changed during stream: previous uid {stream_manager.uid}, new uid {uid}")
693
+ await websocket.send_json({"error": "UID changed in stream"})
694
+ continue
695
+
696
+ # Validate basic structure
697
+ if not isinstance(request_data, dict) or "messages" not in request_data:
698
+ await websocket.send_json({"error": "Invalid request format"})
699
+ continue
700
+
701
+ # Process messages
702
+ try:
703
+ reason = ""
704
+ for message in request_data["messages"]:
705
+ if not isinstance(message, dict) or "role" not in message or "content" not in message:
706
+ await websocket.send_json({"error": "Invalid message format"})
707
+ continue
708
+ reason = stream_manager.process_message(message)
709
+
710
+ # Respond with success message
711
+ response = {
712
+ "id": uid,
713
+ "choices": {
714
+ "role": "assistant",
715
+ "content": "success",
716
+ "finish_reason": reason,
717
+ },
718
+ }
719
+ await websocket.send_json(response)
720
+ except WebSocketDisconnect:
721
+ # Handle WebSocket disconnection
722
+ break
723
+ except Exception as e:
724
+ logger.error(f"process message error: {str(e)}")
725
+ await websocket.close(code=1011, reason =f"Internal server error: {str(e)}")
726
+
727
+ except WebSocketDisconnect:
728
+ # Handle WebSocket disconnection
729
+ return
730
+ except Exception as e:
731
+ logger.error(f"ws_stream error: {str(e)}")
732
+ await websocket.close(code=1011, reason =f"Unexpected error: {str(e)}")
733
+
734
+
735
+ async def generate_sse_response(request: Request, uid: Optional[str] = Header(None)):
736
+ global stream_manager
737
+ print(f"uid: {uid}")
738
+ try:
739
+ # Wait for streaming to complete or timeout
740
+ while not stream_manager.is_streaming_complete.is_set():
741
+ # if stream_manager.is_timed_out():
742
+ # yield f"data: {json.dumps({'error': 'Stream timeout'})}\n\n"
743
+ # return
744
+ # print(f"{uid} whille not stream_manager.is_streaming_complete.is_set(), asyncio.sleep(0.1)")
745
+ await asyncio.sleep(0.1)
746
+
747
+ logger.info("streaming complete\n")
748
+ # Generate response
749
+ try:
750
+ yield f"event: message\n"
751
+ async for audio, text in stream_manager.generate():
752
+ if text == "stop":
753
+ break
754
+ res = {
755
+ "id": stream_manager.uid,
756
+ "response_id": stream_manager.output_audio_id,
757
+ "choices": [
758
+ {
759
+ "role": "assistant",
760
+ "audio": audio,
761
+ "text": text,
762
+ "finish_reason": "processing"
763
+ }
764
+ ]
765
+ }
766
+ # logger.info("generate_sse_response yield response")
767
+ yield f"data: {json.dumps(res)}\n\n"
768
+ await asyncio.sleep(0)
769
+
770
+ except Exception as e:
771
+ logger.error(f"Error while generation: {str(e)}")
772
+ yield f'data:{{"error": "{str(exc)}"}}\n\n'
773
+ except Exception as e:
774
+ yield f'data:{{"error": "{str(e)}"}}\n\n'
775
+
776
+ @app.post("/completions")
777
+ @app.post("/api/v1/completions")
778
+ async def completions(request: Request, uid: Optional[str] = Header(None)):
779
+ global stream_manager
780
+
781
+ if not uid:
782
+ raise HTTPException(status_code=400, detail="Missing uid in headers")
783
+
784
+ try:
785
+ # if stream_manager.uid is not None and stream_manager.uid != uid:
786
+ if stream_manager.uid != uid:
787
+ # stream_manager.stop_response = True
788
+ # logger.info(f"uid changed, reset model: previous uid {stream_manager.uid}, new uid {uid}")
789
+ stream_manager.session_id += 1
790
+ stream_manager.sys_prompt_flag = False
791
+ stream_manager.reset()
792
+
793
+ # raise HTTPException(
794
+ # status_code=409,
795
+ # detail="User id changed, reset context."
796
+ # )
797
+ stream_manager.speaking_time_stamp = 0
798
+ stream_manager.update_last_request_time()
799
+ stream_manager.uid = uid
800
+ stream_manager.start_conversation()
801
+
802
+ data = await request.json()
803
+
804
+ return StreamingResponse(
805
+ generate_sse_response(request, uid),
806
+ media_type="text/event-stream",
807
+ headers={
808
+ "Cache-Control": "no-cache",
809
+ "Connection": "keep-alive",
810
+ "Transfer-Encoding": "chunked"
811
+ }
812
+ )
813
+ except asyncio.TimeoutError:
814
+ raise HTTPException(
815
+ status_code=503,
816
+ detail="Server busy, please try again later"
817
+ )
818
+ except Exception as e:
819
+ logger.error(f"Error processing request for user {uid}: {str(e)}")
820
+ raise HTTPException(status_code=500, detail=str(e))
821
+
822
+
823
+ @app.post("/stop")
824
+ @app.post("/api/v1/stop")
825
+ async def stop_response(request: Request, uid: Optional[str] = Header(None)):
826
+ if not uid:
827
+ raise HTTPException(status_code=400, detail="Missing uid in headers")
828
+
829
+ global stream_manager
830
+ # stream_manager.session_id += 1
831
+ logger.info(f"uid {uid}: received stop_response")
832
+ stream_manager.stop_response = True
833
+ response = {
834
+ "id": uid,
835
+ "choices": {
836
+ "role": "assistant",
837
+ "content": "success",
838
+ "finish_reason": "stop"
839
+ }
840
+ }
841
+ return JSONResponse(content=response, status_code=200)
842
+
843
+ @app.post("/feedback")
844
+ @app.post("/api/v1/feedback")
845
+ async def feedback(request: Request, uid: Optional[str] = Header(None)):
846
+ global stream_manager
847
+
848
+ # Validate the 'uid' header
849
+ if not uid:
850
+ raise HTTPException(status_code=400, detail="Missing 'uid' header")
851
+
852
+ try:
853
+ data = await request.json()
854
+ if "response_id" not in data or "rating" not in data:
855
+ raise HTTPException(status_code=400, detail="Invalid request: must have response_id and rating")
856
+ response_id = data.get("response_id", "")
857
+ rating = data.get("rating", "")
858
+ comment = data.get("comment", "")
859
+ # Validate the rating
860
+ if rating not in ["like", "dislike"]:
861
+ raise HTTPException(status_code=400, detail=f"Invalid rating value: {rating}")
862
+
863
+ # Define the log file path
864
+ log_file_path = f"{stream_manager.savedir}/feedback_log/{response_id}.{rating}"
865
+ # Write the feedback to the file asynchronously
866
+ async with aiofiles.open(log_file_path, mode="a") as file:
867
+ await file.write(f"model: {stream_manager.minicpmo_model_path}\nuid {uid}: {comment}\n")
868
+ response = {
869
+ "id": uid,
870
+ "choices": {
871
+ "role": "assistant",
872
+ "content": "success",
873
+ "finish_reason": "done"
874
+ }
875
+ }
876
+ return JSONResponse(content=response, status_code=200)
877
+ except Exception as e:
878
+ logger.error(f"Error processing feedback for user {uid}: {str(e)}")
879
+ raise HTTPException(status_code=500, detail=str(e))
880
+
881
+
882
+ @app.post("/init_options")
883
+ @app.post("/api/v1/init_options")
884
+ async def init_options(request: Request, uid: Optional[str] = Header(None)):
885
+ global stream_manager
886
+
887
+ stream_manager.update_last_request_time()
888
+
889
+ if not uid:
890
+ raise HTTPException(status_code=400, detail="Missing uid in headers")
891
+ try:
892
+ # Parse JSON request
893
+ data = await request.json()
894
+
895
+ # Validate basic structure
896
+ if not isinstance(data, dict) or "messages" not in data:
897
+ raise HTTPException(status_code=400, detail="Invalid request format")
898
+
899
+ messages = data.get("messages", [])
900
+ for message in messages:
901
+ if not isinstance(message, dict) or "role" not in message or "content" not in message:
902
+ raise HTTPException(status_code=400, detail="Invalid message format")
903
+
904
+ for content in message.get("content", []):
905
+ if content["type"] == "input_audio":
906
+ audio_data = content["input_audio"].get("data", "")
907
+ audio_fmt = content["input_audio"].get("format", "")
908
+ stream_manager.upload_customized_audio(audio_data, audio_fmt)
909
+ elif content["type"] == "options":
910
+ stream_manager.update_customized_options(uid, content["options"])
911
+ else:
912
+ ctype = content["type"]
913
+ raise HTTPException(status_code=400, detail=f"Invalid content type: {ctype}")
914
+ version = stream_manager.model_version
915
+ print(version)
916
+ response = {
917
+ "id": uid,
918
+ "choices": {
919
+ "role": "assistant",
920
+ "content": version,
921
+ "finish_reason": "done"
922
+ }
923
+ }
924
+ return JSONResponse(content=response, status_code=200)
925
+ except Exception as e:
926
+ raise HTTPException(status_code=400, detail=f"init options error: {str(e)}")
927
+
928
+
929
+ @app.get('/health')
930
+ @app.get('/api/v1/health')
931
+ async def health_check():
932
+ return {"status": "OK"}
933
+
934
+
935
+ if __name__ == "__main__":
936
+ uvicorn.run(app, host="0.0.0.0", port=args.port)
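+
+ # A minimal client sketch (assumed usage, not part of the original file):
+ # curl http://localhost:32550/api/v1/health
+ # curl -N -X POST http://localhost:32550/api/v1/completions -H "uid: demo-user"
+ # Audio (and optional image) frames are pushed separately to /api/v1/stream or the
+ # /ws/api/v1/stream websocket as JSON messages whose content items carry
+ # base64-encoded "input_audio" / "image_data" payloads.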
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/vad_utils.py ADDED
@@ -0,0 +1,301 @@
1
+ import functools
2
+ import numpy as np
3
+ import librosa
4
+ import os
5
+ import time
6
+ import traceback
7
+
8
+ from typing import List, NamedTuple, Optional
9
+
10
+ class VadOptions(NamedTuple):
11
+ """VAD options.
12
+
13
+ Attributes:
14
+ threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
15
+ probabilities ABOVE this value are considered as SPEECH. It is better to tune this
16
+ parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
17
+ min_speech_duration_ms: Final speech chunks shorter than min_speech_duration_ms are thrown out.
18
+ max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
19
+ than max_speech_duration_s will be split at the timestamp of the last silence that
20
+ lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
21
+ split aggressively just before max_speech_duration_s.
22
+ min_silence_duration_ms: At the end of each speech chunk, wait for min_silence_duration_ms
23
+ before separating it
24
+ window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
25
+ WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
26
+ Values other than these may affect model performance!!
27
+ speech_pad_ms: Final speech chunks are padded by speech_pad_ms on each side
28
+ """
29
+
30
+ # threshold: float = 0.3 # rep 0.5
31
+ # min_speech_duration_ms: int = 250
32
+ # max_speech_duration_s: float = float("inf")
33
+ # min_silence_duration_ms: int = 2000
34
+ # window_size_samples: int = 1024
35
+ # speech_pad_ms: int = 600 # rep 400
36
+
37
+ threshold: float = 0.7 # gw: 0.3 # rep 0.5
38
+ min_speech_duration_ms: int = 128 # original & gw: 250
39
+ max_speech_duration_s: float = float("inf")
40
+ min_silence_duration_ms: int = 500 # original & gw: 2000
41
+ window_size_samples: int = 1024
42
+ speech_pad_ms: int = 30 # gw: 600 # rep 400
43
+
44
+ class SileroVADModel:
45
+ def __init__(self, path):
46
+ try:
47
+ import onnxruntime
48
+ except ImportError as e:
49
+ raise RuntimeError(
50
+ "Applying the VAD filter requires the onnxruntime package"
51
+ ) from e
52
+
53
+ opts = onnxruntime.SessionOptions()
54
+ opts.inter_op_num_threads = 1
55
+ opts.intra_op_num_threads = 1
56
+ opts.log_severity_level = 4
57
+
58
+ self.session = onnxruntime.InferenceSession(
59
+ path,
60
+ providers=["CPUExecutionProvider"],
61
+ sess_options=opts,
62
+ )
63
+
64
+ def get_initial_state(self, batch_size: int):
65
+ h = np.zeros((2, batch_size, 64), dtype=np.float32)
66
+ c = np.zeros((2, batch_size, 64), dtype=np.float32)
67
+ return h, c
68
+
69
+ def __call__(self, x, state, sr: int):
70
+ if len(x.shape) == 1:
71
+ x = np.expand_dims(x, 0)
72
+ if len(x.shape) > 2:
73
+ raise ValueError(
74
+ f"Too many dimensions for input audio chunk {len(x.shape)}"
75
+ )
76
+ if sr / x.shape[1] > 31.25:
77
+ raise ValueError("Input audio chunk is too short")
78
+
79
+ h, c = state
80
+
81
+ ort_inputs = {
82
+ "input": x,
83
+ #"state": np.concatenate((h, c), axis=0),
84
+ "h": h,
85
+ "c": c,
86
+ "sr": np.array(sr, dtype="int64"),
87
+ }
88
+
89
+ out, h, c = self.session.run(None, ort_inputs)
90
+ #out = self.session.run(None, ort_inputs)
91
+ state = (h, c)
92
+ return out, state
93
+
94
+
95
+ @functools.lru_cache
96
+ def get_vad_model():
97
+ """Returns the VAD model instance."""
98
+ path = os.path.join(os.path.dirname(__file__), "silero_vad.onnx")
99
+ return SileroVADModel(path)
100
+
101
+
102
+ def get_speech_timestamps(
103
+ audio: np.ndarray,
104
+ vad_options: Optional[VadOptions] = None,
105
+ **kwargs,
106
+ ) -> List[dict]:
107
+ """This method is used for splitting long audios into speech chunks using silero VAD.
108
+
109
+ Args:
110
+ audio: One dimensional float array.
111
+ vad_options: Options for VAD processing.
112
+ kwargs: VAD options passed as keyword arguments for backward compatibility.
113
+
114
+ Returns:
115
+ List of dicts containing begin and end samples of each speech chunk.
116
+ """
117
+ if vad_options is None:
118
+ vad_options = VadOptions(**kwargs)
119
+
120
+ threshold = vad_options.threshold
121
+ min_speech_duration_ms = vad_options.min_speech_duration_ms
122
+ max_speech_duration_s = vad_options.max_speech_duration_s
123
+ min_silence_duration_ms = vad_options.min_silence_duration_ms
124
+ window_size_samples = vad_options.window_size_samples
125
+ speech_pad_ms = vad_options.speech_pad_ms
126
+
127
+ if window_size_samples not in [512, 1024, 1536]:
128
+ warnings.warn(
129
+ "Unusual window_size_samples! Supported window_size_samples:\n"
130
+ " - [512, 1024, 1536] for 16000 sampling_rate"
131
+ )
132
+
133
+ sampling_rate = 16000
134
+ min_speech_samples = sampling_rate * min_speech_duration_ms / 1000  # segments shorter than this are not added
135
+ speech_pad_samples = sampling_rate * speech_pad_ms / 1000
136
+ max_speech_samples = (
137
+ sampling_rate * max_speech_duration_s
138
+ - window_size_samples
139
+ - 2 * speech_pad_samples
140
+ )
141
+ min_silence_samples = sampling_rate * min_silence_duration_ms / 1000  # a speech chunk only ends after min_silence_duration_ms of silence
142
+ min_silence_samples_at_max_speech = sampling_rate * 98 / 1000 # 0.098s # need to adjust?
143
+
144
+ audio_length_samples = len(audio)
145
+
146
+ # import pdb
147
+ # pdb.set_trace()
148
+
149
+ model = get_vad_model()
150
+ state = model.get_initial_state(batch_size=1)
151
+
152
+ speech_probs = []
153
+ #print("audio_length_samples ", audio_length_samples, ", window_size_samples ", window_size_samples)
154
+ for current_start_sample in range(0, audio_length_samples, window_size_samples):
155
+ chunk = audio[current_start_sample : current_start_sample + window_size_samples]
156
+ if len(chunk) < window_size_samples:
157
+ chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
158
+ speech_prob, state = model(chunk, state, sampling_rate)
159
+ speech_probs.append(speech_prob)
160
+
161
+ triggered = False
162
+ speeches = []
163
+ current_speech = {}
164
+ neg_threshold = threshold - 0.15
165
+
166
+ # to save potential segment end (and tolerate some silence)
167
+ temp_end = 0
168
+ # to save potential segment limits in case of maximum segment size reached
169
+ prev_end = next_start = 0
170
+
171
+ # Roughly: scan the audio for contiguous speech regions. When silence is hit, temp_end is recorded first; if speech resumes before the minimum silence length is reached, temp_end is reset. Segment start/end points are tracked, and a segment is split once it exceeds the maximum length (not fully verified, but with max_speech_duration_s = inf that branch is never reached anyway).
172
+
173
+ for i, speech_prob in enumerate(speech_probs):
174
+ if (speech_prob >= threshold) and temp_end:
175
+ temp_end = 0
176
+ if next_start < prev_end:
177
+ next_start = window_size_samples * i
178
+
179
+ if (speech_prob >= threshold) and not triggered:
180
+ triggered = True
181
+ current_speech["start"] = window_size_samples * i
182
+ continue
183
+
184
+ if (
185
+ triggered
186
+ and (window_size_samples * i) - current_speech["start"] > max_speech_samples
187
+ ):
188
+ if prev_end:
189
+ current_speech["end"] = prev_end
190
+ speeches.append(current_speech)
191
+ current_speech = {}
192
+ # previously reached silence (< neg_thres) and is still not speech (< thres)
193
+ if next_start < prev_end:
194
+ triggered = False
195
+ else:
196
+ current_speech["start"] = next_start
197
+ prev_end = next_start = temp_end = 0
198
+ else:
199
+ current_speech["end"] = window_size_samples * i
200
+ speeches.append(current_speech)
201
+ current_speech = {}
202
+ prev_end = next_start = temp_end = 0
203
+ triggered = False
204
+ continue
205
+
206
+ if (speech_prob < neg_threshold) and triggered:
207
+ if not temp_end:
208
+ temp_end = window_size_samples * i
209
+ # condition to avoid cutting in very short silence
210
+ if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
211
+ prev_end = temp_end
212
+ if (window_size_samples * i) - temp_end < min_silence_samples:
213
+ continue
214
+ else:
215
+ current_speech["end"] = temp_end
216
+ if (
217
+ current_speech["end"] - current_speech["start"]
218
+ ) > min_speech_samples:
219
+ speeches.append(current_speech)
220
+ current_speech = {}
221
+ prev_end = next_start = temp_end = 0
222
+ triggered = False
223
+ continue
224
+
225
+
226
+ if (
227
+ current_speech
228
+ and (audio_length_samples - current_speech["start"]) > min_speech_samples
229
+ ):
230
+ current_speech["end"] = audio_length_samples
231
+ speeches.append(current_speech)
232
+
233
+ # pad each chunk by speech_pad_ms; when the silence between two chunks is too short, it is split evenly between them
234
+ for i, speech in enumerate(speeches):
235
+ if i == 0:
236
+ speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
237
+ if i != len(speeches) - 1:
238
+ silence_duration = speeches[i + 1]["start"] - speech["end"]
239
+ if silence_duration < 2 * speech_pad_samples:
240
+ speech["end"] += int(silence_duration // 2)
241
+ speeches[i + 1]["start"] = int(
242
+ max(0, speeches[i + 1]["start"] - silence_duration // 2)
243
+ )
244
+ else:
245
+ speech["end"] = int(
246
+ min(audio_length_samples, speech["end"] + speech_pad_samples)
247
+ )
248
+ speeches[i + 1]["start"] = int(
249
+ max(0, speeches[i + 1]["start"] - speech_pad_samples)
250
+ )
251
+ else:
252
+ speech["end"] = int(
253
+ min(audio_length_samples, speech["end"] + speech_pad_samples)
254
+ )
255
+ return speeches
256
+
257
+ def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
258
+ """Collects and concatenates audio chunks."""
259
+ if not chunks:
260
+ return np.array([], dtype=np.float32)
261
+
262
+ return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks])
263
+
264
+
265
+ def run_vad(ori_audio, sr, vad_options=None):
266
+ _st = time.time()
267
+ try:
268
+ audio = np.frombuffer(ori_audio, dtype=np.int16)
269
+ audio = audio.astype(np.float32) / 32768.0
270
+ sampling_rate = 16000
271
+ if sr != sampling_rate:
272
+ audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)
273
+ # print('audio.encode.shape: {}'.format(audio.shape))
274
+ if vad_options is None:
275
+ vad_options = VadOptions()
276
+
277
+ # make sure a VadOptions instance is passed to get_speech_timestamps
278
+ speech_chunks = get_speech_timestamps(audio, vad_options=vad_options)
279
+ # print(speech_chunks)
280
+ audio = collect_chunks(audio, speech_chunks)
281
+ # print(audio.shape)
282
+ duration_after_vad = audio.shape[0] / sampling_rate
283
+
284
+ # print('audio.decode.shape: {}'.format(audio.shape))
285
+ if sr != sampling_rate:
286
+ # resample to original sampling rate
287
+ vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
288
+ else:
289
+ vad_audio = audio
290
+ vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
291
+
292
+ # rounding back to int16 introduces a small quantization error
293
+
294
+ vad_audio_bytes = vad_audio.tobytes()
295
+
296
+ return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
297
+ except Exception as e:
298
+ msg = f"[asr vad error] audio_len: {len(ori_audio)/(sr*2):.3f} s, trace: {traceback.format_exc()}"
299
+ print(msg)
300
+ return -1, ori_audio, round(time.time() - _st, 4)
301
+
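A short usage sketch for the VAD helpers above, assuming a 16 kHz mono WAV file and that this script sits next to vad_utils.py; the file name is a placeholder. run_vad expects raw int16 PCM bytes plus the sample rate and returns the voiced duration, the filtered PCM bytes, and the elapsed time.

```python
# Hypothetical driver for vad_utils; "example.wav" is a placeholder path.
import librosa
import numpy as np

from vad_utils import VadOptions, run_vad

audio, sr = librosa.load("example.wav", sr=16000, mono=True)   # float32 in [-1, 1]
pcm_bytes = (audio * 32768.0).clip(-32768, 32767).astype(np.int16).tobytes()

duration, voiced_bytes, elapsed = run_vad(pcm_bytes, sr, vad_options=VadOptions(threshold=0.5))
print(f"speech after VAD: {duration:.2f}s (processed in {elapsed}s)")
```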
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/web_server/.env.development ADDED
File without changes
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/web_server/.env.production ADDED
File without changes
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/web_server/.eslintrc-auto-import.json ADDED
@@ -0,0 +1,359 @@
1
+ {
2
+ "globals": {
3
+ "Component": true,
4
+ "ComponentPublicInstance": true,
5
+ "ComputedRef": true,
6
+ "EffectScope": true,
7
+ "ExtractDefaultPropTypes": true,
8
+ "ExtractPropTypes": true,
9
+ "ExtractPublicPropTypes": true,
10
+ "InjectionKey": true,
11
+ "LegalTypeEnum": true,
12
+ "LoginTypeEnum": true,
13
+ "PropType": true,
14
+ "Ref": true,
15
+ "VNode": true,
16
+ "WritableComputedRef": true,
17
+ "acceptHMRUpdate": true,
18
+ "ajaxHeader": true,
19
+ "asyncComputed": true,
20
+ "authLogin": true,
21
+ "autoResetRef": true,
22
+ "computed": true,
23
+ "computedAsync": true,
24
+ "computedEager": true,
25
+ "computedInject": true,
26
+ "computedWithControl": true,
27
+ "controlledComputed": true,
28
+ "controlledRef": true,
29
+ "createApp": true,
30
+ "createEventHook": true,
31
+ "createGlobalState": true,
32
+ "createInjectionState": true,
33
+ "createPinia": true,
34
+ "createReactiveFn": true,
35
+ "createReusableTemplate": true,
36
+ "createSharedComposable": true,
37
+ "createTemplatePromise": true,
38
+ "createUnrefFn": true,
39
+ "customRef": true,
40
+ "debouncedRef": true,
41
+ "debouncedWatch": true,
42
+ "defineAsyncComponent": true,
43
+ "defineComponent": true,
44
+ "defineStore": true,
45
+ "eagerComputed": true,
46
+ "effectScope": true,
47
+ "extendRef": true,
48
+ "fetchSmsVerifyCode": true,
49
+ "getActivePinia": true,
50
+ "getCurrentInstance": true,
51
+ "getCurrentScope": true,
52
+ "getHomeInfo": true,
53
+ "h": true,
54
+ "ignorableWatch": true,
55
+ "inject": true,
56
+ "injectLocal": true,
57
+ "isDefined": true,
58
+ "isProxy": true,
59
+ "isReactive": true,
60
+ "isReadonly": true,
61
+ "isRef": true,
62
+ "loginSuccess": true,
63
+ "makeDestructurable": true,
64
+ "mapActions": true,
65
+ "mapGetters": true,
66
+ "mapState": true,
67
+ "mapStores": true,
68
+ "mapWritableState": true,
69
+ "markRaw": true,
70
+ "nextTick": true,
71
+ "onActivated": true,
72
+ "onBeforeMount": true,
73
+ "onBeforeRouteLeave": true,
74
+ "onBeforeRouteUpdate": true,
75
+ "onBeforeUnmount": true,
76
+ "onBeforeUpdate": true,
77
+ "onClickOutside": true,
78
+ "onDeactivated": true,
79
+ "onErrorCaptured": true,
80
+ "onKeyStroke": true,
81
+ "onLongPress": true,
82
+ "onMounted": true,
83
+ "onRenderTracked": true,
84
+ "onRenderTriggered": true,
85
+ "onScopeDispose": true,
86
+ "onServerPrefetch": true,
87
+ "onStartTyping": true,
88
+ "onUnmounted": true,
89
+ "onUpdated": true,
90
+ "pausableWatch": true,
91
+ "provide": true,
92
+ "provideLocal": true,
93
+ "reactify": true,
94
+ "reactifyObject": true,
95
+ "reactive": true,
96
+ "reactiveComputed": true,
97
+ "reactiveOmit": true,
98
+ "reactivePick": true,
99
+ "readonly": true,
100
+ "ref": true,
101
+ "refAutoReset": true,
102
+ "refDebounced": true,
103
+ "refDefault": true,
104
+ "refThrottled": true,
105
+ "refWithControl": true,
106
+ "resolveComponent": true,
107
+ "resolveRef": true,
108
+ "resolveUnref": true,
109
+ "setActivePinia": true,
110
+ "setMapStoreSuffix": true,
111
+ "setupStore": true,
112
+ "shallowReactive": true,
113
+ "shallowReadonly": true,
114
+ "shallowRef": true,
115
+ "store": true,
116
+ "storeToRefs": true,
117
+ "submitFeedback": true,
118
+ "syncRef": true,
119
+ "syncRefs": true,
120
+ "templateRef": true,
121
+ "throttledRef": true,
122
+ "throttledWatch": true,
123
+ "toRaw": true,
124
+ "toReactive": true,
125
+ "toRef": true,
126
+ "toRefs": true,
127
+ "toValue": true,
128
+ "triggerRef": true,
129
+ "tryOnBeforeMount": true,
130
+ "tryOnBeforeUnmount": true,
131
+ "tryOnMounted": true,
132
+ "tryOnScopeDispose": true,
133
+ "tryOnUnmounted": true,
134
+ "unref": true,
135
+ "unrefElement": true,
136
+ "until": true,
137
+ "useActiveElement": true,
138
+ "useAnimate": true,
139
+ "useArrayDifference": true,
140
+ "useArrayEvery": true,
141
+ "useArrayFilter": true,
142
+ "useArrayFind": true,
143
+ "useArrayFindIndex": true,
144
+ "useArrayFindLast": true,
145
+ "useArrayIncludes": true,
146
+ "useArrayJoin": true,
147
+ "useArrayMap": true,
148
+ "useArrayReduce": true,
149
+ "useArraySome": true,
150
+ "useArrayUnique": true,
151
+ "useAsyncQueue": true,
152
+ "useAsyncState": true,
153
+ "useAttrs": true,
154
+ "useBase64": true,
155
+ "useBattery": true,
156
+ "useBluetooth": true,
157
+ "useBreakpoints": true,
158
+ "useBroadcastChannel": true,
159
+ "useBrowserLocation": true,
160
+ "useCached": true,
161
+ "useClearLocalCache": true,
162
+ "useClipboard": true,
163
+ "useClipboardItems": true,
164
+ "useCloned": true,
165
+ "useColorMode": true,
166
+ "useConfirmDialog": true,
167
+ "useCounter": true,
168
+ "useCssModule": true,
169
+ "useCssVar": true,
170
+ "useCssVars": true,
171
+ "useCurrentElement": true,
172
+ "useCycleList": true,
173
+ "useDark": true,
174
+ "useDateFormat": true,
175
+ "useDebounce": true,
176
+ "useDebounceFn": true,
177
+ "useDebouncedRefHistory": true,
178
+ "useDeviceMotion": true,
179
+ "useDeviceOrientation": true,
180
+ "useDevicePixelRatio": true,
181
+ "useDevicesList": true,
182
+ "useDisplayMedia": true,
183
+ "useDocumentVisibility": true,
184
+ "useDraggable": true,
185
+ "useDropZone": true,
186
+ "useElementBounding": true,
187
+ "useElementByPoint": true,
188
+ "useElementHover": true,
189
+ "useElementSize": true,
190
+ "useElementVisibility": true,
191
+ "useEventBus": true,
192
+ "useEventListener": true,
193
+ "useEventSource": true,
194
+ "useEyeDropper": true,
195
+ "useFavicon": true,
196
+ "useFetch": true,
197
+ "useFetchLogin": true,
198
+ "useFileDialog": true,
199
+ "useFileSystemAccess": true,
200
+ "useFocus": true,
201
+ "useFocusWithin": true,
202
+ "useFps": true,
203
+ "useFullscreen": true,
204
+ "useGamepad": true,
205
+ "useGeolocation": true,
206
+ "useGetLocalCache": true,
207
+ "useHttp": true,
208
+ "useIdle": true,
209
+ "useImage": true,
210
+ "useInfiniteScroll": true,
211
+ "useIntersectionObserver": true,
212
+ "useInterval": true,
213
+ "useIntervalFn": true,
214
+ "useKeyModifier": true,
215
+ "useLastChanged": true,
216
+ "useLegal": true,
217
+ "useLink": true,
218
+ "useLocalStorage": true,
219
+ "useLogin": true,
220
+ "useMagicKeys": true,
221
+ "useManualRefHistory": true,
222
+ "useMediaControls": true,
223
+ "useMediaQuery": true,
224
+ "useMemoize": true,
225
+ "useMemory": true,
226
+ "useMounted": true,
227
+ "useMouse": true,
228
+ "useMouseInElement": true,
229
+ "useMousePressed": true,
230
+ "useMutationObserver": true,
231
+ "useNavigatorLanguage": true,
232
+ "useNetwork": true,
233
+ "useNow": true,
234
+ "useObjectUrl": true,
235
+ "useOffsetPagination": true,
236
+ "useOnline": true,
237
+ "usePageLeave": true,
238
+ "useParallax": true,
239
+ "useParentElement": true,
240
+ "usePerformanceObserver": true,
241
+ "usePermission": true,
242
+ "usePointer": true,
243
+ "usePointerLock": true,
244
+ "usePointerSwipe": true,
245
+ "usePreferredColorScheme": true,
246
+ "usePreferredContrast": true,
247
+ "usePreferredDark": true,
248
+ "usePreferredLanguages": true,
249
+ "usePreferredReducedMotion": true,
250
+ "usePrevious": true,
251
+ "useRafFn": true,
252
+ "useRefHistory": true,
253
+ "useResizeObserver": true,
254
+ "useRoute": true,
255
+ "useRouter": true,
256
+ "useScreenOrientation": true,
257
+ "useScreenSafeArea": true,
258
+ "useScriptTag": true,
259
+ "useScroll": true,
260
+ "useScrollLock": true,
261
+ "useSessionStorage": true,
262
+ "useSetLocalCache": true,
263
+ "useShare": true,
264
+ "useSlots": true,
265
+ "useSorted": true,
266
+ "useSpeechRecognition": true,
267
+ "useSpeechSynthesis": true,
268
+ "useStepper": true,
269
+ "useStorage": true,
270
+ "useStorageAsync": true,
271
+ "useStyleTag": true,
272
+ "useSupported": true,
273
+ "useSwipe": true,
274
+ "useTemplateRefsList": true,
275
+ "useTextDirection": true,
276
+ "useTextSelection": true,
277
+ "useTextareaAutosize": true,
278
+ "useThrottle": true,
279
+ "useThrottleFn": true,
280
+ "useThrottledRefHistory": true,
281
+ "useTimeAgo": true,
282
+ "useTimeout": true,
283
+ "useTimeoutFn": true,
284
+ "useTimeoutPoll": true,
285
+ "useTimestamp": true,
286
+ "useTitle": true,
287
+ "useToNumber": true,
288
+ "useToString": true,
289
+ "useToggle": true,
290
+ "useTransition": true,
291
+ "useUrlSearchParams": true,
292
+ "useUserMedia": true,
293
+ "useUserStore": true,
294
+ "useUserStoreWithOut": true,
295
+ "useVModel": true,
296
+ "useVModels": true,
297
+ "useVibrate": true,
298
+ "useVirtualList": true,
299
+ "useWakeLock": true,
300
+ "useWebNotification": true,
301
+ "useWebSocket": true,
302
+ "useWebWorker": true,
303
+ "useWebWorkerFn": true,
304
+ "useWindowFocus": true,
305
+ "useWindowScroll": true,
306
+ "useWindowSize": true,
307
+ "watch": true,
308
+ "watchArray": true,
309
+ "watchAtMost": true,
310
+ "watchDebounced": true,
311
+ "watchDeep": true,
312
+ "watchEffect": true,
313
+ "watchIgnorable": true,
314
+ "watchImmediate": true,
315
+ "watchOnce": true,
316
+ "watchPausable": true,
317
+ "watchPostEffect": true,
318
+ "watchSyncEffect": true,
319
+ "watchThrottled": true,
320
+ "watchTriggerable": true,
321
+ "watchWithFilter": true,
322
+ "whenever": true,
323
+ "ElMessage": true,
324
+ "ElLoading": true,
325
+ "deleteHistoryBatch": true,
326
+ "deleteHistoryItem": true,
327
+ "getHistory": true,
328
+ "createConv": true,
329
+ "fetchHistoryList": true,
330
+ "stopChat": true,
331
+ "useChatStore": true,
332
+ "useChatStoreWithOut": true,
333
+ "useChatExchangeStore": true,
334
+ "useChatExchangeStoreWithOut": true,
335
+ "useExchangeStore": true,
336
+ "useExchangeStoreWithOut": true,
337
+ "delMessage": true,
338
+ "sendRating": true,
339
+ "getInitialActions": true,
340
+ "sendFeedback": true,
341
+ "md": true,
342
+ "useMarkdown": true,
343
+ "connectService": true,
344
+ "sendMessage": true,
345
+ "Audio": true,
346
+ "SoundRecording": true,
347
+ "getVolume": true,
348
+ "ElMessageBox": true,
349
+ "encodeWav": true,
350
+ "encodeWAV": true,
351
+ "stopMessage": true,
352
+ "TaskQueue": true,
353
+ "getNewUserId": true,
354
+ "setNewUserId": true,
355
+ "uploadFile": true,
356
+ "feedback": true,
357
+ "uploadConfig": true
358
+ }
359
+ }
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/web_server/.eslintrc.cjs ADDED
@@ -0,0 +1,26 @@
1
+ /* eslint-env node */
2
+ require('@rushstack/eslint-patch/modern-module-resolution');
3
+
4
+ module.exports = {
5
+ root: true,
6
+ extends: [
7
+ 'plugin:vue/vue3-essential',
8
+ 'eslint:recommended',
9
+ '@vue/eslint-config-prettier/skip-formatting',
10
+ './.eslintrc-auto-import.json',
11
+ ],
12
+ parserOptions: {
13
+ ecmaVersion: 'latest',
14
+ },
15
+ rules: {
16
+ 'no-console': process.env.NODE_ENV === 'production' ? 'off' : 'warn',
17
+ 'no-debugger': process.env.NODE_ENV === 'production' ? 'error' : 'warn',
18
+ 'no-var': process.env.NODE_ENV === 'production' ? 'off' : 'warn',
19
+ 'no-undef': process.env.NODE_ENV === 'production' ? 'error' : 'warn',
20
+ 'vue/multi-word-component-names': 'off', // do not enforce multi-word component names
21
+ 'no-empty': 0, // allow empty code blocks
22
+ 'vue/no-unused-components': 'warn',
23
+ 'no-unused-vars': 'warn',
24
+ 'prettier/prettier': 'off', // do not let eslint report errors for code that violates prettier formatting
25
+ },
26
+ };
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo.py ADDED
@@ -0,0 +1,264 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ import gradio as gr
4
+ from PIL import Image
5
+ import traceback
6
+ import re
7
+ import torch
8
+ import argparse
9
+ from transformers import AutoModel, AutoTokenizer
10
+
11
+ # README, How to run demo on different devices
12
+ # For Nvidia GPUs that support BF16 (like A100, H100, RTX3090)
13
+ # python web_demo.py --device cuda --dtype bf16
14
+
15
+ # For Nvidia GPUs that do NOT support BF16 (like V100, T4, RTX2080)
16
+ # python web_demo.py --device cuda --dtype fp16
17
+
18
+ # For Mac with MPS (Apple silicon or AMD GPUs).
19
+ # PYTORCH_ENABLE_MPS_FALLBACK=1 python web_demo.py --device mps --dtype fp16
20
+
21
+ # Argparser
22
+ parser = argparse.ArgumentParser(description='demo')
23
+ parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
24
+ parser.add_argument('--dtype', type=str, default='bf16', help='bf16 or fp16')
25
+ args = parser.parse_args()
26
+ device = args.device
27
+ assert device in ['cuda', 'mps']
28
+ if args.dtype == 'bf16':
29
+ if device == 'mps':
30
+ print('Warning: MPS does not support bf16, will use fp16 instead')
31
+ dtype = torch.float16
32
+ else:
33
+ dtype = torch.bfloat16
34
+ else:
35
+ dtype = torch.float16
36
+
37
+ # Load model
38
+ model_path = 'openbmb/MiniCPM-V-2'
39
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.bfloat16)
40
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
41
+
42
+ model = model.to(device=device, dtype=dtype)
43
+ model.eval()
44
+
45
+
46
+
47
+ ERROR_MSG = "Error, please retry"
48
+ model_name = 'MiniCPM-V 2.0'
49
+
50
+ form_radio = {
51
+ 'choices': ['Beam Search', 'Sampling'],
52
+ #'value': 'Beam Search',
53
+ 'value': 'Sampling',
54
+ 'interactive': True,
55
+ 'label': 'Decode Type'
56
+ }
57
+ # Beam Form
58
+ num_beams_slider = {
59
+ 'minimum': 0,
60
+ 'maximum': 5,
61
+ 'value': 3,
62
+ 'step': 1,
63
+ 'interactive': True,
64
+ 'label': 'Num Beams'
65
+ }
66
+ repetition_penalty_slider = {
67
+ 'minimum': 0,
68
+ 'maximum': 3,
69
+ 'value': 1.2,
70
+ 'step': 0.01,
71
+ 'interactive': True,
72
+ 'label': 'Repetition Penalty'
73
+ }
74
+ repetition_penalty_slider2 = {
75
+ 'minimum': 0,
76
+ 'maximum': 3,
77
+ 'value': 1.05,
78
+ 'step': 0.01,
79
+ 'interactive': True,
80
+ 'label': 'Repetition Penalty'
81
+ }
82
+ max_new_tokens_slider = {
83
+ 'minimum': 1,
84
+ 'maximum': 4096,
85
+ 'value': 1024,
86
+ 'step': 1,
87
+ 'interactive': True,
88
+ 'label': 'Max New Tokens'
89
+ }
90
+
91
+ top_p_slider = {
92
+ 'minimum': 0,
93
+ 'maximum': 1,
94
+ 'value': 0.8,
95
+ 'step': 0.05,
96
+ 'interactive': True,
97
+ 'label': 'Top P'
98
+ }
99
+ top_k_slider = {
100
+ 'minimum': 0,
101
+ 'maximum': 200,
102
+ 'value': 100,
103
+ 'step': 1,
104
+ 'interactive': True,
105
+ 'label': 'Top K'
106
+ }
107
+ temperature_slider = {
108
+ 'minimum': 0,
109
+ 'maximum': 2,
110
+ 'value': 0.7,
111
+ 'step': 0.05,
112
+ 'interactive': True,
113
+ 'label': 'Temperature'
114
+ }
115
+
116
+
117
+ def create_component(params, comp='Slider'):
118
+ if comp == 'Slider':
119
+ return gr.Slider(
120
+ minimum=params['minimum'],
121
+ maximum=params['maximum'],
122
+ value=params['value'],
123
+ step=params['step'],
124
+ interactive=params['interactive'],
125
+ label=params['label']
126
+ )
127
+ elif comp == 'Radio':
128
+ return gr.Radio(
129
+ choices=params['choices'],
130
+ value=params['value'],
131
+ interactive=params['interactive'],
132
+ label=params['label']
133
+ )
134
+ elif comp == 'Button':
135
+ return gr.Button(
136
+ value=params['value'],
137
+ interactive=True
138
+ )
139
+
140
+
141
+ def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
142
+ default_params = {"num_beams":3, "repetition_penalty": 1.2, "max_new_tokens": 1024}
143
+ if params is None:
144
+ params = default_params
145
+ if img is None:
146
+ return -1, "Error, invalid image, please upload a new image", None, None
147
+ try:
148
+ image = img.convert('RGB')
149
+ answer, context, _ = model.chat(
150
+ image=image,
151
+ msgs=msgs,
152
+ context=None,
153
+ tokenizer=tokenizer,
154
+ **params
155
+ )
156
+ res = re.sub(r'(<box>.*</box>)', '', answer)
157
+ res = res.replace('<ref>', '')
158
+ res = res.replace('</ref>', '')
159
+ res = res.replace('<box>', '')
160
+ answer = res.replace('</box>', '')
161
+ return 0, answer, None, None
162
+ except Exception as err:
163
+ print(err)
164
+ traceback.print_exc()
165
+ return -1, ERROR_MSG, None, None
166
+
167
+
168
+ def upload_img(image, _chatbot, _app_session):
169
+ image = Image.fromarray(image)
170
+
171
+ _app_session['sts']=None
172
+ _app_session['ctx']=[]
173
+ _app_session['img']=image
174
+ _chatbot.append(('', 'Image uploaded successfully, you can talk to me now'))
175
+ return _chatbot, _app_session
176
+
177
+
178
+ def respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
179
+ if _app_cfg.get('ctx', None) is None:
180
+ _chat_bot.append((_question, 'Please upload an image to start'))
181
+ return '', _chat_bot, _app_cfg
182
+
183
+ _context = _app_cfg['ctx'].copy()
184
+ if _context:
185
+ _context.append({"role": "user", "content": _question})
186
+ else:
187
+ _context = [{"role": "user", "content": _question}]
188
+ print('<User>:', _question)
189
+
190
+ if params_form == 'Beam Search':
191
+ params = {
192
+ 'sampling': False,
193
+ 'num_beams': num_beams,
194
+ 'repetition_penalty': repetition_penalty,
195
+ "max_new_tokens": 896
196
+ }
197
+ else:
198
+ params = {
199
+ 'sampling': True,
200
+ 'top_p': top_p,
201
+ 'top_k': top_k,
202
+ 'temperature': temperature,
203
+ 'repetition_penalty': repetition_penalty_2,
204
+ "max_new_tokens": 896
205
+ }
206
+ code, _answer, _, sts = chat(_app_cfg['img'], _context, None, params)
207
+ print('<Assistant>:', _answer)
208
+
209
+ _context.append({"role": "assistant", "content": _answer})
210
+ _chat_bot.append((_question, _answer))
211
+ if code == 0:
212
+ _app_cfg['ctx']=_context
213
+ _app_cfg['sts']=sts
214
+ return '', _chat_bot, _app_cfg
215
+
216
+
217
+ def regenerate_button_clicked(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
218
+ if len(_chat_bot) <= 1:
219
+ _chat_bot.append(('Regenerate', 'No question for regeneration.'))
220
+ return '', _chat_bot, _app_cfg
221
+ elif _chat_bot[-1][0] == 'Regenerate':
222
+ return '', _chat_bot, _app_cfg
223
+ else:
224
+ _question = _chat_bot[-1][0]
225
+ _chat_bot = _chat_bot[:-1]
226
+ _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
227
+ return respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature)
228
+
229
+
230
+
231
+ with gr.Blocks() as demo:
232
+ with gr.Row():
233
+ with gr.Column(scale=1, min_width=300):
234
+ params_form = create_component(form_radio, comp='Radio')
235
+ with gr.Accordion("Beam Search") as beams_according:
236
+ num_beams = create_component(num_beams_slider)
237
+ repetition_penalty = create_component(repetition_penalty_slider)
238
+ with gr.Accordion("Sampling") as sampling_according:
239
+ top_p = create_component(top_p_slider)
240
+ top_k = create_component(top_k_slider)
241
+ temperature = create_component(temperature_slider)
242
+ repetition_penalty_2 = create_component(repetition_penalty_slider2)
243
+ regenerate = create_component({'value': 'Regenerate'}, comp='Button')
244
+ with gr.Column(scale=3, min_width=500):
245
+ app_session = gr.State({'sts':None,'ctx':None,'img':None})
246
+ bt_pic = gr.Image(label="Upload an image to start")
247
+ chat_bot = gr.Chatbot(label=f"Chat with {model_name}")
248
+ txt_message = gr.Textbox(label="Input text")
249
+
250
+ regenerate.click(
251
+ regenerate_button_clicked,
252
+ [txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
253
+ [txt_message, chat_bot, app_session]
254
+ )
255
+ txt_message.submit(
256
+ respond,
257
+ [txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
258
+ [txt_message, chat_bot, app_session]
259
+ )
260
+ bt_pic.upload(lambda: None, None, chat_bot, queue=False).then(upload_img, inputs=[bt_pic,chat_bot,app_session], outputs=[chat_bot,app_session])
261
+
262
+ # launch
263
+ demo.launch(share=False, debug=True, show_api=False, server_port=8080, server_name="0.0.0.0")
264
+
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_2.5.py ADDED
@@ -0,0 +1,256 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ import gradio as gr
4
+ from PIL import Image
5
+ import traceback
6
+ import re
7
+ import torch
8
+ import argparse
9
+ from transformers import AutoModel, AutoTokenizer
10
+
11
+ # README, How to run demo on different devices
12
+
13
+ # For Nvidia GPUs.
14
+ # python web_demo_2.5.py --device cuda
15
+
16
+ # For Mac with MPS (Apple silicon or AMD GPUs).
17
+ # PYTORCH_ENABLE_MPS_FALLBACK=1 python web_demo_2.5.py --device mps
18
+
19
+ # Argparser
20
+ parser = argparse.ArgumentParser(description='demo')
21
+ parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
22
+ args = parser.parse_args()
23
+ device = args.device
24
+ assert device in ['cuda', 'mps']
25
+
26
+ # Load model
27
+ model_path = 'openbmb/MiniCPM-Llama3-V-2_5'
28
+ if 'int4' in model_path:
29
+ if device == 'mps':
30
+ print('Error: running int4 model with bitsandbytes on Mac is not supported right now.')
31
+ exit()
32
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
33
+ else:
34
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map=device)
35
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
36
+ model.eval()
37
+
38
+
39
+
40
+ ERROR_MSG = "Error, please retry"
41
+ model_name = 'MiniCPM-V 2.5'
42
+
43
+ form_radio = {
44
+ 'choices': ['Beam Search', 'Sampling'],
45
+ #'value': 'Beam Search',
46
+ 'value': 'Sampling',
47
+ 'interactive': True,
48
+ 'label': 'Decode Type'
49
+ }
50
+ # Beam Form
51
+ num_beams_slider = {
52
+ 'minimum': 0,
53
+ 'maximum': 5,
54
+ 'value': 3,
55
+ 'step': 1,
56
+ 'interactive': True,
57
+ 'label': 'Num Beams'
58
+ }
59
+ repetition_penalty_slider = {
60
+ 'minimum': 0,
61
+ 'maximum': 3,
62
+ 'value': 1.2,
63
+ 'step': 0.01,
64
+ 'interactive': True,
65
+ 'label': 'Repetition Penalty'
66
+ }
67
+ repetition_penalty_slider2 = {
68
+ 'minimum': 0,
69
+ 'maximum': 3,
70
+ 'value': 1.05,
71
+ 'step': 0.01,
72
+ 'interactive': True,
73
+ 'label': 'Repetition Penalty'
74
+ }
75
+ max_new_tokens_slider = {
76
+ 'minimum': 1,
77
+ 'maximum': 4096,
78
+ 'value': 1024,
79
+ 'step': 1,
80
+ 'interactive': True,
81
+ 'label': 'Max New Tokens'
82
+ }
83
+
84
+ top_p_slider = {
85
+ 'minimum': 0,
86
+ 'maximum': 1,
87
+ 'value': 0.8,
88
+ 'step': 0.05,
89
+ 'interactive': True,
90
+ 'label': 'Top P'
91
+ }
92
+ top_k_slider = {
93
+ 'minimum': 0,
94
+ 'maximum': 200,
95
+ 'value': 100,
96
+ 'step': 1,
97
+ 'interactive': True,
98
+ 'label': 'Top K'
99
+ }
100
+ temperature_slider = {
101
+ 'minimum': 0,
102
+ 'maximum': 2,
103
+ 'value': 0.7,
104
+ 'step': 0.05,
105
+ 'interactive': True,
106
+ 'label': 'Temperature'
107
+ }
108
+
109
+
110
+ def create_component(params, comp='Slider'):
111
+ if comp == 'Slider':
112
+ return gr.Slider(
113
+ minimum=params['minimum'],
114
+ maximum=params['maximum'],
115
+ value=params['value'],
116
+ step=params['step'],
117
+ interactive=params['interactive'],
118
+ label=params['label']
119
+ )
120
+ elif comp == 'Radio':
121
+ return gr.Radio(
122
+ choices=params['choices'],
123
+ value=params['value'],
124
+ interactive=params['interactive'],
125
+ label=params['label']
126
+ )
127
+ elif comp == 'Button':
128
+ return gr.Button(
129
+ value=params['value'],
130
+ interactive=True
131
+ )
132
+
133
+
134
+ def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
135
+ default_params = {"num_beams":3, "repetition_penalty": 1.2, "max_new_tokens": 1024}
136
+ if params is None:
137
+ params = default_params
138
+ if img is None:
139
+ return -1, "Error, invalid image, please upload a new image", None, None
140
+ try:
141
+ image = img.convert('RGB')
142
+ answer = model.chat(
143
+ image=image,
144
+ msgs=msgs,
145
+ tokenizer=tokenizer,
146
+ **params
147
+ )
148
+ res = re.sub(r'(<box>.*</box>)', '', answer)
149
+ res = res.replace('<ref>', '')
150
+ res = res.replace('</ref>', '')
151
+ res = res.replace('<box>', '')
152
+ answer = res.replace('</box>', '')
153
+ return 0, answer, None, None
154
+ except Exception as err:
155
+ print(err)
156
+ traceback.print_exc()
157
+ return -1, ERROR_MSG, None, None
158
+
159
+
160
+ def upload_img(image, _chatbot, _app_session):
161
+ image = Image.fromarray(image)
162
+
163
+ _app_session['sts']=None
164
+ _app_session['ctx']=[]
165
+ _app_session['img']=image
166
+ _chatbot.append(('', 'Image uploaded successfully, you can talk to me now'))
167
+ return _chatbot, _app_session
168
+
169
+
170
+ def respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
171
+ if _app_cfg.get('ctx', None) is None:
172
+ _chat_bot.append((_question, 'Please upload an image to start'))
173
+ return '', _chat_bot, _app_cfg
174
+
175
+ _context = _app_cfg['ctx'].copy()
176
+ if _context:
177
+ _context.append({"role": "user", "content": _question})
178
+ else:
179
+ _context = [{"role": "user", "content": _question}]
180
+ print('<User>:', _question)
181
+
182
+ if params_form == 'Beam Search':
183
+ params = {
184
+ 'sampling': False,
185
+ 'num_beams': num_beams,
186
+ 'repetition_penalty': repetition_penalty,
187
+ "max_new_tokens": 896
188
+ }
189
+ else:
190
+ params = {
191
+ 'sampling': True,
192
+ 'top_p': top_p,
193
+ 'top_k': top_k,
194
+ 'temperature': temperature,
195
+ 'repetition_penalty': repetition_penalty_2,
196
+ "max_new_tokens": 896
197
+ }
198
+ code, _answer, _, sts = chat(_app_cfg['img'], _context, None, params)
199
+ print('<Assistant>:', _answer)
200
+
201
+ _context.append({"role": "assistant", "content": _answer})
202
+ _chat_bot.append((_question, _answer))
203
+ if code == 0:
204
+ _app_cfg['ctx']=_context
205
+ _app_cfg['sts']=sts
206
+ return '', _chat_bot, _app_cfg
207
+
208
+
209
+ def regenerate_button_clicked(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
210
+ if len(_chat_bot) <= 1:
211
+ _chat_bot.append(('Regenerate', 'No question for regeneration.'))
212
+ return '', _chat_bot, _app_cfg
213
+ elif _chat_bot[-1][0] == 'Regenerate':
214
+ return '', _chat_bot, _app_cfg
215
+ else:
216
+ _question = _chat_bot[-1][0]
217
+ _chat_bot = _chat_bot[:-1]
218
+ _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
219
+ return respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature)
220
+
221
+
222
+
223
+ with gr.Blocks() as demo:
224
+ with gr.Row():
225
+ with gr.Column(scale=1, min_width=300):
226
+ params_form = create_component(form_radio, comp='Radio')
227
+ with gr.Accordion("Beam Search") as beams_according:
228
+ num_beams = create_component(num_beams_slider)
229
+ repetition_penalty = create_component(repetition_penalty_slider)
230
+ with gr.Accordion("Sampling") as sampling_according:
231
+ top_p = create_component(top_p_slider)
232
+ top_k = create_component(top_k_slider)
233
+ temperature = create_component(temperature_slider)
234
+ repetition_penalty_2 = create_component(repetition_penalty_slider2)
235
+ regenerate = create_component({'value': 'Regenerate'}, comp='Button')
236
+ with gr.Column(scale=3, min_width=500):
237
+ app_session = gr.State({'sts':None,'ctx':None,'img':None})
238
+ bt_pic = gr.Image(label="Upload an image to start")
239
+ chat_bot = gr.Chatbot(label=f"Chat with {model_name}")
240
+ txt_message = gr.Textbox(label="Input text")
241
+
242
+ regenerate.click(
243
+ regenerate_button_clicked,
244
+ [txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
245
+ [txt_message, chat_bot, app_session]
246
+ )
247
+ txt_message.submit(
248
+ respond,
249
+ [txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
250
+ [txt_message, chat_bot, app_session]
251
+ )
252
+ bt_pic.upload(lambda: None, None, chat_bot, queue=False).then(upload_img, inputs=[bt_pic,chat_bot,app_session], outputs=[chat_bot,app_session])
253
+
254
+ # launch
255
+ demo.launch(share=False, debug=True, show_api=False, server_port=8080, server_name="0.0.0.0")
256
+
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_2.6.py ADDED
@@ -0,0 +1,557 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ import torch
4
+ import argparse
5
+ from transformers import AutoModel, AutoTokenizer
6
+ import gradio as gr
7
+ from PIL import Image
8
+ from decord import VideoReader, cpu
9
+ import io
10
+ import os
11
+ import copy
12
+ import requests
13
+ import base64
14
+ import json
15
+ import traceback
16
+ import re
17
+ import modelscope_studio as mgr
18
+
19
+
20
+ # README, How to run demo on different devices
21
+
22
+ # For Nvidia GPUs.
23
+ # python web_demo_2.6.py --device cuda
24
+
25
+ # For Mac with MPS (Apple silicon or AMD GPUs).
26
+ # PYTORCH_ENABLE_MPS_FALLBACK=1 python web_demo_2.6.py --device mps
27
+
28
+ # Argparser
29
+ parser = argparse.ArgumentParser(description='demo')
30
+ parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
31
+ parser.add_argument('--multi-gpus', action='store_true', default=False, help='use multi-gpus')
32
+ args = parser.parse_args()
33
+ device = args.device
34
+ assert device in ['cuda', 'mps']
35
+
36
+ # Load model
37
+ model_path = 'openbmb/MiniCPM-V-2_6'
38
+ if 'int4' in model_path:
39
+ if device == 'mps':
40
+ print('Error: running int4 model with bitsandbytes on Mac is not supported right now.')
41
+ exit()
42
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
43
+ else:
44
+ if args.multi_gpus:
45
+ from accelerate import load_checkpoint_and_dispatch, init_empty_weights, infer_auto_device_map
46
+ with init_empty_weights():
47
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True, attn_implementation='sdpa', torch_dtype=torch.bfloat16)
48
+ device_map = infer_auto_device_map(model, max_memory={0: "10GB", 1: "10GB"},
49
+ no_split_module_classes=['SiglipVisionTransformer', 'Qwen2DecoderLayer'])
50
+ device_id = device_map["llm.model.embed_tokens"]
51
+ device_map["llm.lm_head"] = device_id # firtt and last layer should be in same device
52
+ device_map["vpm"] = device_id
53
+ device_map["resampler"] = device_id
54
+ device_id2 = device_map["llm.model.layers.26"]
55
+ device_map["llm.model.layers.8"] = device_id2
56
+ device_map["llm.model.layers.9"] = device_id2
57
+ device_map["llm.model.layers.10"] = device_id2
58
+ device_map["llm.model.layers.11"] = device_id2
59
+ device_map["llm.model.layers.12"] = device_id2
60
+ device_map["llm.model.layers.13"] = device_id2
61
+ device_map["llm.model.layers.14"] = device_id2
62
+ device_map["llm.model.layers.15"] = device_id2
63
+ device_map["llm.model.layers.16"] = device_id2
64
+ #print(device_map)
65
+
66
+ model = load_checkpoint_and_dispatch(model, model_path, dtype=torch.bfloat16, device_map=device_map)
67
+ else:
68
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
69
+ model = model.to(device=device)
70
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
71
+ model.eval()
72
+
73
+
74
+
75
+
76
+ ERROR_MSG = "Error, please retry"
77
+ model_name = 'MiniCPM-V 2.6'
78
+ MAX_NUM_FRAMES = 64
79
+ IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
80
+ VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}
81
+
82
+ def get_file_extension(filename):
83
+ return os.path.splitext(filename)[1].lower()
84
+
85
+ def is_image(filename):
86
+ return get_file_extension(filename) in IMAGE_EXTENSIONS
87
+
88
+ def is_video(filename):
89
+ return get_file_extension(filename) in VIDEO_EXTENSIONS
90
+
91
+
92
+ form_radio = {
93
+ 'choices': ['Beam Search', 'Sampling'],
94
+ #'value': 'Beam Search',
95
+ 'value': 'Sampling',
96
+ 'interactive': True,
97
+ 'label': 'Decode Type'
98
+ }
99
+
100
+
101
+ def create_component(params, comp='Slider'):
102
+ if comp == 'Slider':
103
+ return gr.Slider(
104
+ minimum=params['minimum'],
105
+ maximum=params['maximum'],
106
+ value=params['value'],
107
+ step=params['step'],
108
+ interactive=params['interactive'],
109
+ label=params['label']
110
+ )
111
+ elif comp == 'Radio':
112
+ return gr.Radio(
113
+ choices=params['choices'],
114
+ value=params['value'],
115
+ interactive=params['interactive'],
116
+ label=params['label']
117
+ )
118
+ elif comp == 'Button':
119
+ return gr.Button(
120
+ value=params['value'],
121
+ interactive=True
122
+ )
123
+
124
+
125
+ def create_multimodal_input(upload_image_disabled=False, upload_video_disabled=False):
126
+ return mgr.MultimodalInput(upload_image_button_props={'label': 'Upload Image', 'disabled': upload_image_disabled, 'file_count': 'multiple'},
127
+ upload_video_button_props={'label': 'Upload Video', 'disabled': upload_video_disabled, 'file_count': 'single'},
128
+ submit_button_props={'label': 'Submit'})
129
+
130
+
131
+ def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
132
+ try:
133
+ print('msgs:', msgs)
134
+ answer = model.chat(
135
+ image=None,
136
+ msgs=msgs,
137
+ tokenizer=tokenizer,
138
+ **params
139
+ )
140
+ res = re.sub(r'(<box>.*</box>)', '', answer)
141
+ res = res.replace('<ref>', '')
142
+ res = res.replace('</ref>', '')
143
+ res = res.replace('<box>', '')
144
+ answer = res.replace('</box>', '')
145
+ print('answer:', answer)
146
+ return 0, answer, None, None
147
+ except Exception as e:
148
+ print(e)
149
+ traceback.print_exc()
150
+ return -1, ERROR_MSG, None, None
151
+
152
+
153
+ def encode_image(image):
154
+ if not isinstance(image, Image.Image):
155
+ if hasattr(image, 'path'):
156
+ image = Image.open(image.path).convert("RGB")
157
+ else:
158
+ image = Image.open(image.file.path).convert("RGB")
159
+ # resize to max_size
160
+ max_size = 448*16
161
+ if max(image.size) > max_size:
162
+ w,h = image.size
163
+ if w > h:
164
+ new_w = max_size
165
+ new_h = int(h * max_size / w)
166
+ else:
167
+ new_h = max_size
168
+ new_w = int(w * max_size / h)
169
+ image = image.resize((new_w, new_h), resample=Image.BICUBIC)
170
+ return image
171
+ ## save by BytesIO and convert to base64
172
+ #buffered = io.BytesIO()
173
+ #image.save(buffered, format="png")
174
+ #im_b64 = base64.b64encode(buffered.getvalue()).decode()
175
+ #return {"type": "image", "pairs": im_b64}
176
+
177
+
178
+ def encode_video(video):
179
+ def uniform_sample(l, n):
180
+ gap = len(l) / n
181
+ idxs = [int(i * gap + gap / 2) for i in range(n)]
182
+ return [l[i] for i in idxs]
183
+
184
+ if hasattr(video, 'path'):
185
+ vr = VideoReader(video.path, ctx=cpu(0))
186
+ else:
187
+ vr = VideoReader(video.file.path, ctx=cpu(0))
188
+ sample_fps = round(vr.get_avg_fps() / 1) # frame stride for roughly 1 fps sampling
189
+ frame_idx = [i for i in range(0, len(vr), sample_fps)]
190
+ if len(frame_idx)>MAX_NUM_FRAMES:
191
+ frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
192
+ video = vr.get_batch(frame_idx).asnumpy()
193
+ video = [Image.fromarray(v.astype('uint8')) for v in video]
194
+ video = [encode_image(v) for v in video]
195
+ print('video frames:', len(video))
196
+ return video
197
+
198
+
199
+ def check_mm_type(mm_file):
200
+ if hasattr(mm_file, 'path'):
201
+ path = mm_file.path
202
+ else:
203
+ path = mm_file.file.path
204
+ if is_image(path):
205
+ return "image"
206
+ if is_video(path):
207
+ return "video"
208
+ return None
209
+
210
+
211
+ def encode_mm_file(mm_file):
212
+ if check_mm_type(mm_file) == 'image':
213
+ return [encode_image(mm_file)]
214
+ if check_mm_type(mm_file) == 'video':
215
+ return encode_video(mm_file)
216
+ return None
217
+
218
+ def make_text(text):
219
+ #return {"type": "text", "pairs": text} # # For remote call
220
+ return text
221
+
222
+ def encode_message(_question):
223
+ files = _question.files
224
+ question = _question.text
225
+ pattern = r"\[mm_media\]\d+\[/mm_media\]"
226
+ matches = re.split(pattern, question)
227
+ message = []
228
+ if len(matches) != len(files) + 1:
229
+ gr.Warning("Number of Images not match the placeholder in text, please refresh the page to restart!")
230
+ assert len(matches) == len(files) + 1
231
+
232
+ text = matches[0].strip()
233
+ if text:
234
+ message.append(make_text(text))
235
+ for i in range(len(files)):
236
+ message += encode_mm_file(files[i])
237
+ text = matches[i + 1].strip()
238
+ if text:
239
+ message.append(make_text(text))
240
+ return message
241
+
242
+
243
+ def check_has_videos(_question):
244
+ images_cnt = 0
245
+ videos_cnt = 0
246
+ for file in _question.files:
247
+ if check_mm_type(file) == "image":
248
+ images_cnt += 1
249
+ else:
250
+ videos_cnt += 1
251
+ return images_cnt, videos_cnt
252
+
253
+
254
+ def count_video_frames(_context):
255
+ num_frames = 0
256
+ for message in _context:
257
+ for item in message["content"]:
258
+ #if item["type"] == "image": # For remote call
259
+ if isinstance(item, Image.Image):
260
+ num_frames += 1
261
+ return num_frames
262
+
263
+
264
+ def respond(_question, _chat_bot, _app_cfg, params_form):
265
+ _context = _app_cfg['ctx'].copy()
266
+ _context.append({'role': 'user', 'content': encode_message(_question)})
267
+
268
+ images_cnt = _app_cfg['images_cnt']
269
+ videos_cnt = _app_cfg['videos_cnt']
270
+ files_cnts = check_has_videos(_question)
271
+ if files_cnts[1] + videos_cnt > 1 or (files_cnts[1] + videos_cnt == 1 and files_cnts[0] + images_cnt > 0):
272
+ gr.Warning("Only supports single video file input right now!")
273
+ return _question, _chat_bot, _app_cfg
274
+
275
+ if params_form == 'Beam Search':
276
+ params = {
277
+ 'sampling': False,
278
+ 'num_beams': 3,
279
+ 'repetition_penalty': 1.2,
280
+ "max_new_tokens": 2048
281
+ }
282
+ else:
283
+ params = {
284
+ 'sampling': True,
285
+ 'top_p': 0.8,
286
+ 'top_k': 100,
287
+ 'temperature': 0.7,
288
+ 'repetition_penalty': 1.05,
289
+ "max_new_tokens": 2048
290
+ }
291
+
292
+ if files_cnts[1] + videos_cnt > 0:
293
+ params["max_inp_length"] = 4352 # 4096+256
294
+ params["use_image_id"] = False
295
+ params["max_slice_nums"] = 1 if count_video_frames(_context) > 16 else 2
296
+
297
+ code, _answer, _, sts = chat("", _context, None, params)
298
+
299
+ images_cnt += files_cnts[0]
300
+ videos_cnt += files_cnts[1]
301
+ _context.append({"role": "assistant", "content": [make_text(_answer)]})
302
+ _chat_bot.append((_question, _answer))
303
+ if code == 0:
304
+ _app_cfg['ctx']=_context
305
+ _app_cfg['sts']=sts
306
+ _app_cfg['images_cnt'] = images_cnt
307
+ _app_cfg['videos_cnt'] = videos_cnt
308
+
309
+ upload_image_disabled = videos_cnt > 0
310
+ upload_video_disabled = videos_cnt > 0 or images_cnt > 0
311
+ return create_multimodal_input(upload_image_disabled, upload_video_disabled), _chat_bot, _app_cfg
312
+
313
+
314
+ def fewshot_add_demonstration(_image, _user_message, _assistant_message, _chat_bot, _app_cfg):
315
+ ctx = _app_cfg["ctx"]
316
+ message_item = []
317
+ if _image is not None:
318
+ image = Image.open(_image).convert("RGB")
319
+ ctx.append({"role": "user", "content": [encode_image(image), make_text(_user_message)]})
320
+ message_item.append({"text": "[mm_media]1[/mm_media]" + _user_message, "files": [_image]})
321
+ else:
322
+ if _user_message:
323
+ ctx.append({"role": "user", "content": [make_text(_user_message)]})
324
+ message_item.append({"text": _user_message, "files": []})
325
+ else:
326
+ message_item.append(None)
327
+ if _assistant_message:
328
+ ctx.append({"role": "assistant", "content": [make_text(_assistant_message)]})
329
+ message_item.append({"text": _assistant_message, "files": []})
330
+ else:
331
+ message_item.append(None)
332
+
333
+ _chat_bot.append(message_item)
334
+ return None, "", "", _chat_bot, _app_cfg
335
+
336
+
337
+ def fewshot_respond(_image, _user_message, _chat_bot, _app_cfg, params_form):
338
+ user_message_contents = []
339
+ _context = _app_cfg["ctx"].copy()
340
+ if _image:
341
+ image = Image.open(_image).convert("RGB")
342
+ user_message_contents += [encode_image(image)]
343
+ if _user_message:
344
+ user_message_contents += [make_text(_user_message)]
345
+ if user_message_contents:
346
+ _context.append({"role": "user", "content": user_message_contents})
347
+
348
+ if params_form == 'Beam Search':
349
+ params = {
350
+ 'sampling': False,
351
+ 'num_beams': 3,
352
+ 'repetition_penalty': 1.2,
353
+ "max_new_tokens": 2048
354
+ }
355
+ else:
356
+ params = {
357
+ 'sampling': True,
358
+ 'top_p': 0.8,
359
+ 'top_k': 100,
360
+ 'temperature': 0.7,
361
+ 'repetition_penalty': 1.05,
362
+ "max_new_tokens": 2048
363
+ }
364
+
365
+ code, _answer, _, sts = chat("", _context, None, params)
366
+
367
+ _context.append({"role": "assistant", "content": [make_text(_answer)]})
368
+
369
+ if _image:
370
+ _chat_bot.append([
371
+ {"text": "[mm_media]1[/mm_media]" + _user_message, "files": [_image]},
372
+ {"text": _answer, "files": []}
373
+ ])
374
+ else:
375
+ _chat_bot.append([
376
+ {"text": _user_message, "files": [_image]},
377
+ {"text": _answer, "files": []}
378
+ ])
379
+ if code == 0:
380
+ _app_cfg['ctx']=_context
381
+ _app_cfg['sts']=sts
382
+ return None, '', '', _chat_bot, _app_cfg
383
+
384
+
385
+ def regenerate_button_clicked(_question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg, params_form):
386
+ if len(_chat_bot) <= 1 or not _chat_bot[-1][1]:
387
+ gr.Warning('No question for regeneration.')
388
+ return '', _image, _user_message, _assistant_message, _chat_bot, _app_cfg
389
+ if _app_cfg["chat_type"] == "Chat":
390
+ images_cnt = _app_cfg['images_cnt']
391
+ videos_cnt = _app_cfg['videos_cnt']
392
+ _question = _chat_bot[-1][0]
393
+ _chat_bot = _chat_bot[:-1]
394
+ _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
395
+ files_cnts = check_has_videos(_question)
396
+ images_cnt -= files_cnts[0]
397
+ videos_cnt -= files_cnts[1]
398
+ _app_cfg['images_cnt'] = images_cnt
399
+ _app_cfg['videos_cnt'] = videos_cnt
400
+ upload_image_disabled = videos_cnt > 0
401
+ upload_video_disabled = videos_cnt > 0 or images_cnt > 0
402
+ _question, _chat_bot, _app_cfg = respond(_question, _chat_bot, _app_cfg, params_form)
403
+ return _question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg
404
+ else:
405
+ last_message = _chat_bot[-1][0]
406
+ last_image = None
407
+ last_user_message = ''
408
+ if last_message.text:
409
+ last_user_message = last_message.text
410
+ if last_message.files:
411
+ last_image = last_message.files[0].file.path
412
+ _chat_bot = _chat_bot[:-1]
413
+ _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
414
+ _image, _user_message, _assistant_message, _chat_bot, _app_cfg = fewshot_respond(last_image, last_user_message, _chat_bot, _app_cfg, params_form)
415
+ return _question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg
416
+
417
+
418
+ def flushed():
419
+ return gr.update(interactive=True)
420
+
421
+
422
+ def clear(txt_message, chat_bot, app_session):
423
+ txt_message.files.clear()
424
+ txt_message.text = ''
425
+ chat_bot = copy.deepcopy(init_conversation)
426
+ app_session['sts'] = None
427
+ app_session['ctx'] = []
428
+ app_session['images_cnt'] = 0
429
+ app_session['videos_cnt'] = 0
430
+ return create_multimodal_input(), chat_bot, app_session, None, '', ''
431
+
432
+
433
+ def select_chat_type(_tab, _app_cfg):
434
+ _app_cfg["chat_type"] = _tab
435
+ return _app_cfg
436
+
437
+
438
+ init_conversation = [
439
+ [
440
+ None,
441
+ {
442
+ # The first message of bot closes the typewriter.
443
+ "text": "You can talk to me now",
444
+ "flushing": False
445
+ }
446
+ ],
447
+ ]
448
+
449
+
450
+ css = """
451
+ video { height: auto !important; }
452
+ .example label { font-size: 16px;}
453
+ """
454
+
455
+ introduction = """
456
+
457
+ ## Features:
458
+ 1. Chat with a single image
459
+ 2. Chat with multiple images
460
+ 3. Chat with a video
461
+ 4. In-context few-shot learning
462
+
463
+ Click the `How to use` tab to see examples.
464
+ """
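For reference, the in-context few-shot flow above builds the conversation as a plain list of role/content turns, where each content list mixes make_text(...) and encode_image(...) items before the whole context is passed to chat(). A minimal sketch of one demonstration pair plus a follow-up query, assuming the encode_image, make_text, and chat helpers defined earlier in this demo (the image path is hypothetical):

    # Sketch only: relies on encode_image, make_text, and chat from this demo file.
    from PIL import Image

    context = []
    # One demonstration: a user turn (image + text) and the expected assistant turn.
    demo_image = Image.open("example.jpg").convert("RGB")  # hypothetical path
    context.append({"role": "user",
                    "content": [encode_image(demo_image), make_text("What is in this image?")]})
    context.append({"role": "assistant", "content": [make_text("A brown bear wading in a river.")]})
    # The actual query turn, answered with the same Beam Search parameters used above.
    context.append({"role": "user", "content": [make_text("Describe the next image in the same style.")]})
    params = {'sampling': False, 'num_beams': 3, 'repetition_penalty': 1.2, "max_new_tokens": 2048}
    code, answer, _, sts = chat("", context, None, params)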
465
+
466
+
467
+ with gr.Blocks(css=css) as demo:
468
+ with gr.Tab(model_name):
469
+ with gr.Row():
470
+ with gr.Column(scale=1, min_width=300):
471
+ gr.Markdown(value=introduction)
472
+ params_form = create_component(form_radio, comp='Radio')
473
+ regenerate = create_component({'value': 'Regenerate'}, comp='Button')
474
+ clear_button = create_component({'value': 'Clear History'}, comp='Button')
475
+
476
+ with gr.Column(scale=3, min_width=500):
477
+ app_session = gr.State({'sts': None, 'ctx': [], 'images_cnt': 0, 'videos_cnt': 0, 'chat_type': 'Chat'})
478
+ chat_bot = mgr.Chatbot(label=f"Chat with {model_name}", value=copy.deepcopy(init_conversation), height=600, flushing=False, bubble_full_width=False)
479
+
480
+ with gr.Tab("Chat") as chat_tab:
481
+ txt_message = create_multimodal_input()
482
+ chat_tab_label = gr.Textbox(value="Chat", interactive=False, visible=False)
483
+
484
+ txt_message.submit(
485
+ respond,
486
+ [txt_message, chat_bot, app_session, params_form],
487
+ [txt_message, chat_bot, app_session]
488
+ )
489
+
490
+ with gr.Tab("Few Shot") as fewshot_tab:
491
+ fewshot_tab_label = gr.Textbox(value="Few Shot", interactive=False, visible=False)
492
+ with gr.Row():
493
+ with gr.Column(scale=1):
494
+ image_input = gr.Image(type="filepath", sources=["upload"])
495
+ with gr.Column(scale=3):
496
+ user_message = gr.Textbox(label="User")
497
+ assistant_message = gr.Textbox(label="Assistant")
498
+ with gr.Row():
499
+ add_demonstration_button = gr.Button("Add Example")
500
+ generate_button = gr.Button(value="Generate", variant="primary")
501
+ add_demonstration_button.click(
502
+ fewshot_add_demonstration,
503
+ [image_input, user_message, assistant_message, chat_bot, app_session],
504
+ [image_input, user_message, assistant_message, chat_bot, app_session]
505
+ )
506
+ generate_button.click(
507
+ fewshot_respond,
508
+ [image_input, user_message, chat_bot, app_session, params_form],
509
+ [image_input, user_message, assistant_message, chat_bot, app_session]
510
+ )
511
+
512
+ chat_tab.select(
513
+ select_chat_type,
514
+ [chat_tab_label, app_session],
515
+ [app_session]
516
+ )
517
+ chat_tab.select( # do clear
518
+ clear,
519
+ [txt_message, chat_bot, app_session],
520
+ [txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
521
+ )
522
+ fewshot_tab.select(
523
+ select_chat_type,
524
+ [fewshot_tab_label, app_session],
525
+ [app_session]
526
+ )
527
+ fewshot_tab.select( # do clear
528
+ clear,
529
+ [txt_message, chat_bot, app_session],
530
+ [txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
531
+ )
532
+ chat_bot.flushed(
533
+ flushed,
534
+ outputs=[txt_message]
535
+ )
536
+ regenerate.click(
537
+ regenerate_button_clicked,
538
+ [txt_message, image_input, user_message, assistant_message, chat_bot, app_session, params_form],
539
+ [txt_message, image_input, user_message, assistant_message, chat_bot, app_session]
540
+ )
541
+ clear_button.click(
542
+ clear,
543
+ [txt_message, chat_bot, app_session],
544
+ [txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
545
+ )
546
+
547
+ with gr.Tab("How to use"):
548
+ with gr.Column():
549
+ with gr.Row():
550
+ image_example = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/m_bear2.gif", label='1. Chat with single or multiple images', interactive=False, width=400, elem_classes="example")
551
+ example2 = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/video2.gif", label='2. Chat with video', interactive=False, width=400, elem_classes="example")
552
+ example3 = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/fshot.gif", label='3. Few shot', interactive=False, width=400, elem_classes="example")
553
+
554
+
555
+ # launch
556
+ demo.launch(share=False, debug=True, show_api=False, server_port=8885, server_name="0.0.0.0")
557
+
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_streamlit-2_5.py ADDED
@@ -0,0 +1,109 @@
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import torch
4
+ from transformers import AutoModel, AutoTokenizer
5
+
6
+ # Model path
7
+ model_path = "openbmb/MiniCPM-Llama3-V-2_5"
8
+
9
+ # User and assistant names
10
+ U_NAME = "User"
11
+ A_NAME = "Assistant"
12
+
13
+ # Set page configuration
14
+ st.set_page_config(
15
+ page_title="MiniCPM-Llama3-V-2_5 Streamlit",
16
+ page_icon=":robot:",
17
+ layout="wide"
18
+ )
19
+
20
+
21
+ # Load model and tokenizer
22
+ @st.cache_resource
23
+ def load_model_and_tokenizer():
24
+ print(f"load_model_and_tokenizer from {model_path}")
25
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16).to(device="cuda")
26
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
27
+ return model, tokenizer
28
+
29
+
30
+ # Initialize session state
31
+ if 'model' not in st.session_state:
32
+ st.session_state.model, st.session_state.tokenizer = load_model_and_tokenizer()
33
+ st.session_state.model.eval()
34
+ print("Model and tokenizer loaded successfully!")
35
+
36
+ # Initialize session state
37
+ if 'chat_history' not in st.session_state:
38
+ st.session_state.chat_history = []
39
+
40
+ # Sidebar settings
41
+ sidebar_name = st.sidebar.title("MiniCPM-Llama3-V-2_5 Streamlit")
42
+ max_length = st.sidebar.slider("max_length", 0, 4096, 2048, step=2)
43
+ repetition_penalty = st.sidebar.slider("repetition_penalty", 0.0, 2.0, 1.05, step=0.01)
44
+ top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
45
+ top_k = st.sidebar.slider("top_k", 0, 100, 100, step=1)
46
+ temperature = st.sidebar.slider("temperature", 0.0, 1.0, 0.7, step=0.01)
47
+
48
+ # Clear chat history button
49
+ buttonClean = st.sidebar.button("Clear chat history", key="clean")
50
+ if buttonClean:
51
+ st.session_state.chat_history = []
52
+ st.session_state.response = ""
53
+ if torch.cuda.is_available():
54
+ torch.cuda.empty_cache()
55
+ st.rerun()
56
+
57
+ # Display chat history
58
+ for i, message in enumerate(st.session_state.chat_history):
59
+ if message["role"] == "user":
60
+ with st.chat_message(name="user", avatar="user"):
61
+ if message["image"] is not None:
62
+ st.image(message["image"], caption='User uploaded image', width=448, use_column_width=False)
63
+ continue
64
+ elif message["content"] is not None:
65
+ st.markdown(message["content"])
66
+ else:
67
+ with st.chat_message(name="model", avatar="assistant"):
68
+ st.markdown(message["content"])
69
+
70
+ # Select mode
71
+ selected_mode = st.sidebar.selectbox("Select mode", ["Text", "Image"])
72
+ if selected_mode == "Image":
73
+ # Image mode
74
+ uploaded_image = st.sidebar.file_uploader("Upload image", key=1, type=["jpg", "jpeg", "png"],
75
+ accept_multiple_files=False)
76
+ if uploaded_image is not None:
77
+ st.image(uploaded_image, caption='User uploaded image', width=468, use_column_width=False)
78
+ # Add uploaded image to chat history
79
+ st.session_state.chat_history.append({"role": "user", "content": None, "image": uploaded_image})
80
+
81
+ # User input box
82
+ user_text = st.chat_input("Enter your question")
83
+ if user_text:
84
+ with st.chat_message(U_NAME, avatar="user"):
85
+ st.session_state.chat_history.append({"role": "user", "content": user_text, "image": None})
86
+ st.markdown(f"{U_NAME}: {user_text}")
87
+
88
+ # Generate reply using the model
89
+ model = st.session_state.model
90
+ tokenizer = st.session_state.tokenizer
91
+ imagefile = None
92
+
93
+ with st.chat_message(A_NAME, avatar="assistant"):
94
+ # If the previous message contains an image, pass the image to the model
95
+ if len(st.session_state.chat_history) > 1 and st.session_state.chat_history[-2]["image"] is not None:
96
+ uploaded_image = st.session_state.chat_history[-2]["image"]
97
+ imagefile = Image.open(uploaded_image).convert('RGB')
98
+
99
+ msgs = [{"role": "user", "content": user_text}]
100
+ res = model.chat(image=imagefile, msgs=msgs, context=None, tokenizer=tokenizer,
101
+ sampling=True, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty,
102
+ temperature=temperature, stream=True)
103
+
104
+ # Collect the generated_text str
105
+ generated_text = st.write_stream(res)
106
+
107
+ st.session_state.chat_history.append({"role": "model", "content": generated_text, "image": None})
108
+
109
+ st.divider()
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_streamlit-minicpmv2_6.py ADDED
@@ -0,0 +1,271 @@
1
+ import os.path
2
+
3
+ import streamlit as st
4
+ import torch
5
+ from PIL import Image
6
+ from decord import VideoReader, cpu
7
+ import numpy as np
8
+ from transformers import AutoModel, AutoTokenizer
9
+
10
+ # Model path
11
+ model_path = "openbmb/MiniCPM-V-2_6"
12
+ upload_path = "./uploads"
+ os.makedirs(upload_path, exist_ok=True)  # make sure the upload directory exists before saving videos
13
+
14
+ # User and assistant names
15
+ U_NAME = "User"
16
+ A_NAME = "Assistant"
17
+
18
+ # Set page configuration
19
+ st.set_page_config(
20
+ page_title="MiniCPM-V-2_6 Streamlit",
21
+ page_icon=":robot:",
22
+ layout="wide"
23
+ )
24
+
25
+
26
+ # Load model and tokenizer
27
+ @st.cache_resource
28
+ def load_model_and_tokenizer():
29
+ print(f"load_model_and_tokenizer from {model_path}")
30
+ model = (AutoModel.from_pretrained(model_path, trust_remote_code=True, attn_implementation='sdpa').
31
+ to(dtype=torch.bfloat16))
32
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
33
+ return model, tokenizer
34
+
35
+
36
+ # Initialize session state
37
+ if 'model' not in st.session_state:
38
+ st.session_state.model, st.session_state.tokenizer = load_model_and_tokenizer()
39
+ st.session_state.model.eval().cuda()
40
+ print("Model and tokenizer loaded successfully!")
41
+
42
+ # Initialize session state
43
+ if 'chat_history' not in st.session_state:
44
+ st.session_state.chat_history = []
45
+ st.session_state.uploaded_image_list = []
46
+ st.session_state.uploaded_image_num = 0
47
+ st.session_state.uploaded_video_list = []
48
+ st.session_state.uploaded_video_num = 0
49
+ st.session_state.response = ""
50
+
51
+ # Sidebar settings
52
+ sidebar_name = st.sidebar.title("MiniCPM-V-2_6 Streamlit")
53
+ max_length = st.sidebar.slider("max_length", 0, 4096, 2048, step=2)
54
+ repetition_penalty = st.sidebar.slider("repetition_penalty", 0.0, 2.0, 1.05, step=0.01)
55
+ top_k = st.sidebar.slider("top_k", 0, 100, 100, step=1)
56
+ top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
57
+ temperature = st.sidebar.slider("temperature", 0.0, 1.0, 0.7, step=0.01)
58
+
59
+ # Button to clear session history
60
+ buttonClean = st.sidebar.button("Clear session history", key="clean")
61
+ if buttonClean:
62
+ # Reset the session state history and uploaded file lists
63
+ st.session_state.chat_history = []
64
+ st.session_state.uploaded_image_list = []
65
+ st.session_state.uploaded_image_num = 0
66
+ st.session_state.uploaded_video_list = []
67
+ st.session_state.uploaded_video_num = 0
68
+ st.session_state.response = ""
69
+
70
+ # If using GPU, clear the CUDA cache to free up memory
71
+ if torch.cuda.is_available():
72
+ torch.cuda.empty_cache()
73
+
74
+ # Rerun to refresh the interface
75
+ st.rerun()
76
+
77
+ # Display chat history
78
+ for i, message in enumerate(st.session_state.chat_history):
79
+ if message["role"] == "user":
80
+ with st.chat_message(name="user", avatar="user"):
81
+ if message["image"] is not None:
82
+ st.image(message["image"], caption='User uploaded images', width=512, use_column_width=False)
83
+ continue
84
+ elif message["video"] is not None:
85
+ st.video(message["video"], format="video/mp4", loop=False, autoplay=False, muted=True)
86
+ continue
87
+ elif message["content"] is not None:
88
+ st.markdown(message["content"])
89
+ else:
90
+ with st.chat_message(name="model", avatar="assistant"):
91
+ st.markdown(message["content"])
92
+
93
+ # Select mode
94
+ selected_mode = st.sidebar.selectbox("Select Mode", ["Text", "Single Image", "Multiple Images", "Video"])
95
+
96
+ # Supported image file extensions
97
+ image_type = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']
98
+
99
+ if selected_mode == "Single Image":
100
+ # Single Image Mode
101
+ uploaded_image = st.sidebar.file_uploader("Upload a Single Image", key=1, type=image_type,
102
+ accept_multiple_files=False)
103
+ if uploaded_image is not None:
104
+ st.image(uploaded_image, caption='User Uploaded Image', width=512, use_column_width=False)
105
+ # Add the uploaded image to the chat history
106
+ st.session_state.chat_history.append({"role": "user", "content": None, "image": uploaded_image, "video": None})
107
+ st.session_state.uploaded_image_list = [uploaded_image]
108
+ st.session_state.uploaded_image_num = 1
109
+
110
+ if selected_mode == "Multiple Images":
111
+ # Multiple Images Mode
112
+ uploaded_image_list = st.sidebar.file_uploader("Upload Multiple Images", key=2, type=image_type,
113
+ accept_multiple_files=True)
114
+ uploaded_image_num = len(uploaded_image_list)
115
+
116
+ if uploaded_image_list is not None and uploaded_image_num > 0:
117
+ for img in uploaded_image_list:
118
+ st.image(img, caption='User Uploaded Image', width=512, use_column_width=False)
119
+ # Add the uploaded images to the chat history
120
+ st.session_state.chat_history.append({"role": "user", "content": None, "image": img, "video": None})
121
+ # Update the uploaded image list and count in st.session_state
122
+ st.session_state.uploaded_image_list = uploaded_image_list
123
+ st.session_state.uploaded_image_num = uploaded_image_num
124
+
125
+ # Supported video format suffixes
126
+ video_type = ['.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v']
127
+
128
+ # Tip: You can use the command `streamlit run ./web_demo_streamlit-minicpmv2_6.py --server.maxUploadSize 1024`
129
+ # to raise the maximum upload size to 1024 MB when you need to upload larger video files.
130
+ # The default 200MB limit of Streamlit's file_uploader component might be insufficient for video-based interactions.
131
+ # Adjust the size based on your GPU memory usage.
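If you prefer not to pass the flag on each launch, the same limit can also be set persistently through Streamlit's configuration file: add maxUploadSize = 1024 under the [server] section of .streamlit/config.toml in the project directory, which should be equivalent to the CLI flag above.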
132
+
133
+ if selected_mode == "Video":
134
+ # Single-video mode
135
+ uploaded_video = st.sidebar.file_uploader("Upload a single video file", key=3, type=video_type,
136
+ accept_multiple_files=False)
137
+ if uploaded_video is not None:
138
+ st.video(uploaded_video, format="video/mp4", loop=False, autoplay=False, muted=True)
139
+ st.session_state.chat_history.append({"role": "user", "content": None, "image": None, "video": uploaded_video})
140
+
141
+ uploaded_video_path = os.path.join(upload_path, uploaded_video.name)
142
+ with open(uploaded_video_path, "wb") as vf:
143
+ vf.write(uploaded_video.getvalue())
144
+ st.session_state.uploaded_video_list = [uploaded_video_path]
145
+ st.session_state.uploaded_video_num = 1
146
+
147
+ MAX_NUM_FRAMES = 64 # if cuda OOM set a smaller number
148
+
149
+
150
+ # Encodes a video by sampling frames at a fixed rate and converting them to image arrays.
151
+ def encode_video(video_path):
152
+ def uniform_sample(frame_indices, num_samples):
153
+ # Calculate sampling interval and uniformly sample frame indices
154
+ gap = len(frame_indices) / num_samples
155
+ sampled_idxs = np.linspace(gap / 2, len(frame_indices) - gap / 2, num_samples, dtype=int)
156
+ return [frame_indices[i] for i in sampled_idxs]
157
+
158
+ # Read the video and set the decoder's context to CPU
159
+ vr = VideoReader(video_path, ctx=cpu(0))
160
+
161
+ # Calculate the sampling interval to sample video frames at 1 FPS
162
+ sample_fps = round(vr.get_avg_fps() / 1) # frame stride that gives roughly 1 FPS sampling
163
+ frame_idx = list(range(0, len(vr), sample_fps))
164
+
165
+ # If the number of sampled frames exceeds the maximum limit, uniformly sample them
166
+ if len(frame_idx) > MAX_NUM_FRAMES:
167
+ frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
168
+
169
+ # Retrieve the sampled frames and convert them to image arrays
170
+ frames = vr.get_batch(frame_idx).asnumpy()
171
+ frames = [Image.fromarray(frame.astype('uint8')) for frame in frames]
172
+
173
+ print('Number of frames:', len(frames))
174
+ return frames
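As a usage sketch (the clip path below is hypothetical), the frames returned by encode_video are simply prepended to the question text so that the whole video becomes part of a single user turn, which is how the Video branch further down builds msgs:

    frames = encode_video("./uploads/clip.mp4")   # hypothetical uploaded file
    msgs = [{"role": "user", "content": frames + ["Describe what happens in this video."]}]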
175
+
176
+
177
+
178
+ # User input box
179
+ user_text = st.chat_input("Enter your question")
180
+ if user_text is not None:
181
+ if not user_text.strip():
182
+ st.warning('Input message could not be empty!', icon="⚠️")
183
+ else:
184
+ # Display user input and save it to session history
185
+ with st.chat_message(U_NAME, avatar="user"):
186
+ st.session_state.chat_history.append({
187
+ "role": "user",
188
+ "content": user_text,
189
+ "image": None,
190
+ "video": None
191
+ })
192
+ st.markdown(f"{U_NAME}: {user_text}")
193
+
194
+ # Generate responses using the model
195
+ model = st.session_state.model
196
+ tokenizer = st.session_state.tokenizer
197
+ content_list = [] # Store the content (text or image) that will be passed into the model
198
+ imageFile = None
199
+
200
+ with st.chat_message(A_NAME, avatar="assistant"):
201
+ # Handle different inputs depending on the mode selected by the user
202
+ if selected_mode == "Single Image":
203
+ # Single image mode: pass in the last uploaded image
204
+ print("Single Image mode in use")
205
+ if len(st.session_state.chat_history) > 1 and len(st.session_state.uploaded_image_list) >= 1:
206
+ uploaded_image = st.session_state.uploaded_image_list[-1]
207
+ if uploaded_image:
208
+ imageFile = Image.open(uploaded_image).convert('RGB')
209
+ content_list.append(imageFile)
210
+ else:
211
+ print("Single Image mode: No image found")
212
+
213
+ elif selected_mode == "Multiple Images":
214
+ # Multi-image mode: pass in all the images uploaded last time
215
+ print("Multiple Images mode in use")
216
+ if len(st.session_state.chat_history) > 1 and st.session_state.uploaded_image_num >= 1:
217
+ for uploaded_image in st.session_state.uploaded_image_list:
218
+ imageFile = Image.open(uploaded_image).convert('RGB')
219
+ content_list.append(imageFile)
220
+ else:
221
+ print("Multiple Images mode: No image found")
222
+
223
+ elif selected_mode == "Video":
224
+ # Video mode: pass in slice frames of uploaded video
225
+ print("Video mode in use")
226
+ if len(st.session_state.chat_history) > 1 and st.session_state.uploaded_video_num == 1:
227
+ uploaded_video_path = st.session_state.uploaded_video_list[-1]
228
+ if uploaded_video_path:
229
+ with st.spinner('Encoding your video, please wait...'):
230
+ frames = encode_video(uploaded_video_path)
231
+ else:
232
+ print("Video Mode: No video found")
233
+
234
+ # Defining model parameters
235
+ params = {
236
+ 'sampling': True,
237
+ 'top_p': top_p,
238
+ 'top_k': top_k,
239
+ 'temperature': temperature,
240
+ 'repetition_penalty': repetition_penalty,
241
+ "max_new_tokens": max_length,
242
+ "stream": True
243
+ }
244
+
245
+ # Set different input parameters depending on whether to upload a video
246
+ if st.session_state.uploaded_video_num == 1 and selected_mode == "Video":
247
+ msgs = [{"role": "user", "content": frames + [user_text]}]
248
+ # Set decode params for video
249
+ params["max_inp_length"] = 4352 # Set the maximum input length of the video mode
250
+ params["use_image_id"] = False # Do not use image_id
251
+ params["max_slice_nums"] = 1 # use 1 if CUDA OOM occurs and video resolution > 448*448
252
+ else:
253
+ content_list.append(user_text)
254
+ msgs = [{"role": "user", "content": content_list}]
255
+
256
+ print("content_list:", content_list) # debug
257
+ print("params:", params) # debug
258
+
259
+ # Generate and display the model's responses
260
+ with st.spinner('AI is thinking...'):
261
+ response = model.chat(image=None, msgs=msgs, context=None, tokenizer=tokenizer, **params)
262
+ st.session_state.response = st.write_stream(response)
263
+ st.session_state.chat_history.append({
264
+ "role": "model",
265
+ "content": st.session_state.response,
266
+ "image": None,
267
+ "video": None
268
+ })
269
+
270
+ st.divider() # Add separators to the interface
271
+
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_streamlit.py ADDED
@@ -0,0 +1,99 @@
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import torch
4
+ from transformers import AutoModel, AutoTokenizer
5
+
6
+ # Model path
7
+ model_path = "openbmb/MiniCPM-V-2"
8
+
9
+ # User and assistant names
10
+ U_NAME = "User"
11
+ A_NAME = "Assistant"
12
+
13
+ # Set page configuration
14
+ st.set_page_config(
15
+ page_title="MiniCPM-V-2 Streamlit",
16
+ page_icon=":robot:",
17
+ layout="wide"
18
+ )
19
+
20
+ # Load model and tokenizer
21
+ @st.cache_resource
22
+ def load_model_and_tokenizer():
23
+ print(f"load_model_and_tokenizer from {model_path}")
24
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(
25
+ device="cuda:0", dtype=torch.bfloat16)
26
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
27
+ return model, tokenizer
28
+
29
+ # Initialize session state
30
+ if 'model' not in st.session_state:
31
+ st.session_state.model, st.session_state.tokenizer = load_model_and_tokenizer()
32
+ print("Model and tokenizer loaded successfully!")
33
+
34
+ # Initialize session state
35
+ if 'chat_history' not in st.session_state:
36
+ st.session_state.chat_history = []
37
+
38
+ # Sidebar settings
39
+ sidebar_name = st.sidebar.title("MiniCPM-V-2 Streamlit")
40
+ max_length = st.sidebar.slider("max_length", 0, 4096, 2048, step=2)
41
+ top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
42
+ temperature = st.sidebar.slider("temperature", 0.0, 1.0, 0.7, step=0.01)
43
+
44
+ # Clear chat history button
45
+ buttonClean = st.sidebar.button("Clear chat history", key="clean")
46
+ if buttonClean:
47
+ st.session_state.chat_history = []
48
+ st.session_state.response = ""
49
+ if torch.cuda.is_available():
50
+ torch.cuda.empty_cache()
51
+ st.rerun()
52
+
53
+ # Display chat history
54
+ for i, message in enumerate(st.session_state.chat_history):
55
+ if message["role"] == "user":
56
+ with st.chat_message(name="user", avatar="user"):
57
+ if message["image"] is not None:
58
+ st.image(message["image"], caption='User uploaded image', width=468, use_column_width=False)
59
+ continue
60
+ elif message["content"] is not None:
61
+ st.markdown(message["content"])
62
+ else:
63
+ with st.chat_message(name="model", avatar="assistant"):
64
+ st.markdown(message["content"])
65
+
66
+ # Select mode
67
+ selected_mode = st.sidebar.selectbox("Select mode", ["Text", "Image"])
68
+ if selected_mode == "Image":
69
+ # Image mode
70
+ uploaded_image = st.sidebar.file_uploader("Upload image", key=1, type=["jpg", "jpeg", "png"], accept_multiple_files=False)
71
+ if uploaded_image is not None:
72
+ st.image(uploaded_image, caption='User uploaded image', width=468, use_column_width=False)
73
+ # Add uploaded image to chat history
74
+ st.session_state.chat_history.append({"role": "user", "content": None, "image": uploaded_image})
75
+
76
+ # User input box
77
+ user_text = st.chat_input("Enter your question")
78
+ if user_text:
79
+ with st.chat_message(U_NAME, avatar="user"):
80
+ st.session_state.chat_history.append({"role": "user", "content": user_text, "image": None})
81
+ st.markdown(f"{U_NAME}: {user_text}")
82
+
83
+ # Generate reply using the model
84
+ model = st.session_state.model
85
+ tokenizer = st.session_state.tokenizer
+ imagefile = None  # default to text-only chat when no image has been uploaded
86
+
87
+ with st.chat_message(A_NAME, avatar="assistant"):
88
+ # If the previous message contains an image, pass the image to the model
89
+ if len(st.session_state.chat_history) > 1 and st.session_state.chat_history[-2]["image"] is not None:
90
+ uploaded_image = st.session_state.chat_history[-2]["image"]
91
+ imagefile = Image.open(uploaded_image).convert('RGB')
92
+
93
+ msgs = [{"role": "user", "content": user_text}]
94
+ res, context, _ = model.chat(image=imagefile, msgs=msgs, context=None, tokenizer=tokenizer,
95
+ sampling=True, top_p=top_p, temperature=temperature)
96
+ st.markdown(f"{A_NAME}: {res}")
97
+ st.session_state.chat_history.append({"role": "model", "content": res, "image": None})
98
+
99
+ st.divider()