Add files using upload-large-folder tool
- r1-a/response_generation/minicpm/MiniCPM-o/assets/modelscope_logo.png +0 -0
- r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/model/__init__.py +1 -0
- r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/model/omnilmm.py +457 -0
- r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/model/resampler.py +171 -0
- r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/model/utils.py +555 -0
- r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/train/train_utils.py +153 -0
- r1-a/response_generation/minicpm/MiniCPM-o/quantize/bnb_quantize.py +81 -0
- r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/chatbot_web_demo_o2.6.py +552 -0
- r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/model_server.py +936 -0
- r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/vad_utils.py +301 -0
- r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/web_server/.env.development +0 -0
- r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/web_server/.env.production +0 -0
- r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/web_server/.eslintrc-auto-import.json +359 -0
- r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/web_server/.eslintrc.cjs +26 -0
- r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo.py +264 -0
- r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_2.5.py +256 -0
- r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_2.6.py +557 -0
- r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_streamlit-2_5.py +109 -0
- r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_streamlit-minicpmv2_6.py +271 -0
- r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_streamlit.py +99 -0
r1-a/response_generation/minicpm/MiniCPM-o/assets/modelscope_logo.png
ADDED
r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/model/__init__.py
ADDED
@@ -0,0 +1 @@
from .omnilmm import OmniLMMForCausalLM
r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/model/omnilmm.py
ADDED
@@ -0,0 +1,457 @@
import gc
import math
import timm
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from typing import List, Optional, Tuple, Union

from transformers import AutoConfig, AutoModelForCausalLM
from transformers import MistralForCausalLM, MistralModel, MistralConfig
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

from omnilmm.model.utils import build_transform
from omnilmm.model.resampler import Resampler

DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


class OmniLMMConfig(MistralConfig):
    model_type = "omnilmm"


class Identity(torch.nn.Identity):
    # accepts and ignores extra keyword arguments so it can stand in for
    # timm modules that are called with kwargs
    def forward(self, input: Tensor, **kwargs) -> Tensor:
        return super().forward(input)


def create_vision_module(config):
    vision_tower = timm.create_model('eva02_enormous_patch14_clip_224.laion2b_plus',
                                     pretrained=False,
                                     num_classes=0,
                                     dynamic_img_size=True,
                                     dynamic_img_pad=True)

    if isinstance(vision_tower, timm.models.VisionTransformer):
        if vision_tower.attn_pool is not None:
            vision_tower.attn_pool = Identity()

    # use the 2nd-to-last layer's output
    vision_tower.blocks[-1] = Identity()

    embed_dim = config.hidden_size
    resampler = Resampler(
        grid_size=int(math.sqrt(config.num_query)),
        embed_dim=embed_dim,
        num_heads=embed_dim // 128,
        kv_dim=vision_tower.embed_dim,
    )
    return vision_tower, resampler


class OmniLMMModel(MistralModel):
    config_class = OmniLMMConfig

    def __init__(self, config: OmniLMMConfig, mm_vision_tower=None, mm_hidden_size=None, tune_clip=True):
        super(OmniLMMModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            vision_tower, resampler = create_vision_module(config)

            # HACK: keep the vision tower inside a plain list so FSDP does not wrap it
            self.vision_tower = [vision_tower]
            self.resampler = resampler
            if tune_clip:
                self.vision_tower = self.vision_tower[0]

        # placeholder object used only as an attribute container; fields such as
        # im_patch_token are attached in initialize_vision_tokenizer
        self.vision_config = lambda x: None

    def initialize_vision_modules(self, vision_tower, no_randaug, num_query, image_size, tune_clip=False):
        self.config.mm_vision_tower = vision_tower
        self.config.use_mm_proj = True
        self.config.num_query = num_query
        self.config.image_size = image_size

        if not hasattr(self, 'vision_tower'):
            vision_tower, resampler = create_vision_module(self.config)
            state_dict = torch.load(
                '/tt/data/public/multimodal/multimodal_model_ckpts/timm/eva02_enormous_patch14_clip_224.laion2b_plus.pt')
            vision_tower.load_state_dict(state_dict, strict=False)
            del state_dict
            gc.collect()
        else:
            if isinstance(self.vision_tower, list):
                vision_tower = self.vision_tower[0]
            else:
                vision_tower = self.vision_tower
            resampler = self.resampler
        self.vision_tower = vision_tower if tune_clip else [vision_tower]
        self.resampler = resampler

        train_img_transform = build_transform(
            is_train=True, randaug=not no_randaug, input_size=self.config.image_size, std_mode='OPENAI_CLIP')
        eval_img_transform = build_transform(
            is_train=False, input_size=self.config.image_size, std_mode='OPENAI_CLIP')

        return dict(
            image_processor=(train_img_transform, eval_img_transform),
            image_token_len=num_query,
            vision_config=self.vision_config
        )

    def get_vision_embedding(self, pixel_values):
        if isinstance(self.vision_tower, list):
            vision_tower = self.vision_tower[0]  # HACK: for FSDP
        else:
            vision_tower = self.vision_tower

        dtype = vision_tower.pos_embed.data.dtype
        vision_embedding = vision_tower.forward_features(
            pixel_values.type(dtype))
        if hasattr(vision_tower, 'num_prefix_tokens') and vision_tower.num_prefix_tokens > 0:
            # drop CLS/register tokens, keep only patch tokens
            vision_embedding = vision_embedding[:, vision_tower.num_prefix_tokens:]
        res = self.resampler(vision_embedding)
        return res

    def get_vllm_embedding(self, data):
        if 'vision_hidden_states' not in data:
            pixel_values_list = data['pixel_values']
            vision_hidden_states = []
            for pixel_values in pixel_values_list:
                if len(pixel_values) > 0:
                    vision_hidden_states.append(
                        self.get_vision_embedding(pixel_values.unsqueeze(0))[0])
                else:
                    vision_hidden_states.append([])
        else:
            vision_hidden_states = data['vision_hidden_states']

        inputs_embeds = self.embed_tokens(data['input_ids'])
        vision_hidden_states = [i.type(inputs_embeds.dtype)
                                if isinstance(i, torch.Tensor) else i
                                for i in vision_hidden_states]

        # zero-valued dummy features keep text-only samples connected to the
        # vision parameters in the graph; the original code referenced
        # `dummy_image_features` here without defining it (a NameError), so it
        # is defined the same way as in forward()
        dummy_image_features = torch.zeros(
            self.config.num_query, self.config.hidden_size,
            device=inputs_embeds.device, dtype=inputs_embeds.dtype)

        # HACK: replace back original embeddings for LLaVA pretraining
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)

        new_input_embeds = []
        cur_image_idx = 0
        for cur_input_ids, cur_input_embeds in zip(data['input_ids'], inputs_embeds):
            if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
                # multimodal LLM, but the current sample is not multimodal
                cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
                new_input_embeds.append(cur_input_embeds)
                continue

            if self.vision_config.use_im_start_end:
                cur_image_features = vision_hidden_states[cur_image_idx]
                num_patches = cur_image_features.shape[0]
                if (cur_input_ids == self.vision_config.im_start_token).sum() != (cur_input_ids == self.vision_config.im_end_token).sum():
                    raise ValueError(
                        "The number of image start tokens and image end tokens should be the same.")
                image_start_tokens = torch.where(
                    cur_input_ids == self.vision_config.im_start_token)[0]
                for image_start_token_pos in image_start_tokens:
                    cur_image_features = vision_hidden_states[cur_image_idx].to(
                        device=cur_input_embeds.device)
                    num_patches = cur_image_features.shape[0]
                    if cur_input_ids[image_start_token_pos + num_patches + 1] != self.vision_config.im_end_token:
                        raise ValueError(
                            "The image end token should follow the image start token.")
                    # splice the image features between <im_start> and <im_end>;
                    # detach the surrounding text embeddings when the original
                    # embedding table is frozen
                    if orig_embeds_params is not None:
                        cur_new_input_embeds = torch.cat(
                            (cur_input_embeds[:image_start_token_pos].detach(),
                             cur_input_embeds[image_start_token_pos:image_start_token_pos + 1],
                             cur_image_features,
                             cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2],
                             cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
                    else:
                        cur_new_input_embeds = torch.cat(
                            (cur_input_embeds[:image_start_token_pos + 1],
                             cur_image_features,
                             cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
                    cur_image_idx += 1
                new_input_embeds.append(cur_new_input_embeds)
            else:
                raise NotImplementedError
        inputs_embeds = torch.stack(new_input_embeds, dim=0)

        return inputs_embeds, vision_hidden_states

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, BaseModelOutputWithPast]:

        # HACK: replace back original embeddings for LLaVA pretraining
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)

        if inputs_embeds is None and past_key_values is None:
            inputs_embeds = self.embed_tokens(input_ids)

        vision_tower = getattr(self, 'vision_tower', None)
        if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:

            if type(images) is list:
                image_features = []
                for image in images:
                    image_forward_out = self.get_vision_embedding(image.unsqueeze(0))[0]
                    image_features.append(image_forward_out)
            else:
                image_features = self.get_vision_embedding(images)

            dummy_image_features = torch.zeros(
                self.config.num_query,
                self.config.hidden_size,
                device=inputs_embeds.device,
                dtype=inputs_embeds.dtype)

            new_input_embeds = []
            cur_image_idx = 0
            for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
                if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
                    # multimodal LLM, but the current sample is not multimodal
                    cur_input_embeds = cur_input_embeds + \
                        (0. * dummy_image_features).sum()
                    new_input_embeds.append(cur_input_embeds)
                    continue

                if self.vision_config.use_im_start_end:
                    cur_image_features = image_features[cur_image_idx]
                    num_patches = cur_image_features.shape[0]
                    if (cur_input_ids == self.vision_config.im_start_token).sum() != (cur_input_ids == self.vision_config.im_end_token).sum():
                        raise ValueError(
                            "The number of image start tokens and image end tokens should be the same.")
                    image_start_tokens = torch.where(
                        cur_input_ids == self.vision_config.im_start_token)[0]
                    for image_start_token_pos in image_start_tokens:
                        cur_image_features = image_features[cur_image_idx].to(
                            device=cur_input_embeds.device)
                        num_patches = cur_image_features.shape[0]
                        if cur_input_ids[image_start_token_pos + num_patches + 1] != self.vision_config.im_end_token:
                            raise ValueError(
                                "The image end token should follow the image start token.")
                        if orig_embeds_params is not None:
                            cur_new_input_embeds = torch.cat(
                                (cur_input_embeds[:image_start_token_pos].detach(),
                                 cur_input_embeds[image_start_token_pos:image_start_token_pos + 1],
                                 cur_image_features,
                                 cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2],
                                 cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
                        else:
                            cur_new_input_embeds = torch.cat(
                                (cur_input_embeds[:image_start_token_pos + 1],
                                 cur_image_features,
                                 cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
                        cur_image_idx += 1
                    new_input_embeds.append(cur_new_input_embeds)
                else:
                    raise NotImplementedError
            inputs_embeds = torch.stack(new_input_embeds, dim=0)
            input_ids = None

        return super(OmniLMMModel, self).forward(
            input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )


class OmniLMMForCausalLM(MistralForCausalLM):
    config_class = OmniLMMConfig

    def __init__(self, config, mm_vision_tower=None, tune_clip=True):
        super(MistralForCausalLM, self).__init__(config)
        self.model = OmniLMMModel(
            config, mm_vision_tower=mm_vision_tower, tune_clip=tune_clip)

        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            images=images,
            **kwargs
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    # TODO could be removed for generate_vllm()
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs

    def generate_vllm(
        self,
        input_ids: torch.LongTensor = None,
        images: Optional[torch.FloatTensor] = None,
        vision_hidden_states=None,
        return_vision_hidden_states=False,
        **kwargs
    ):
        model_inputs = {'input_ids': input_ids}
        if vision_hidden_states is None:
            model_inputs['pixel_values'] = images
        else:
            model_inputs['vision_hidden_states'] = vision_hidden_states

        with torch.inference_mode():
            inputs_embeds, vision_hidden_states = self.model.get_vllm_embedding(model_inputs)

            result = self.generate(
                inputs_embeds=inputs_embeds,
                **kwargs
            )

        if return_vision_hidden_states:
            return result, vision_hidden_states

        return result

    def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device,
                                    tune_mm_mlp_adapter=False):
        self.model.vision_config.use_im_start_end = mm_use_im_start_end
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        self.resize_token_embeddings(len(tokenizer))

        if mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))
            self.model.vision_config.im_start_token, self.model.vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                # initialize new token embeddings with the mean of the existing ones
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            # for new sft data
            num_new_tokens = tokenizer.add_tokens(
                ['<box>', '</box>', '<ref>', '</ref>', '<quad>', '</quad>'], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if tune_mm_mlp_adapter:
                self.model.orig_embeds_params = [
                    self.get_input_embeddings().weight.data.clone().to(device=device)]
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

        self.model.vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
            [DEFAULT_IMAGE_PATCH_TOKEN])[0]
        print(f'Tokenizer: {tokenizer}\n patch_token_id: {self.model.vision_config.im_patch_token}, vision_config: {self.model.vision_config}', flush=True)


AutoConfig.register("omnilmm", OmniLMMConfig)
AutoModelForCausalLM.register(OmniLMMConfig, OmniLMMForCausalLM)
r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/model/resampler.py
ADDED
@@ -0,0 +1,171 @@
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict
import math
import requests
from io import BytesIO
from functools import partial
from PIL import Image
from typing import Callable, Optional, Sequence, Tuple, List, Union
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import trunc_normal_
from torchvision import transforms
from torchvision.transforms import InterpolationMode


def get_abs_pos(abs_pos, tgt_size):
    # abs_pos: L, C
    # tgt_size: M
    # return: M, C
    src_size = int(math.sqrt(abs_pos.size(0)))
    tgt_size = int(math.sqrt(tgt_size))
    dtype = abs_pos.dtype

    if src_size != tgt_size:
        # bicubically resize the square positional grid to the target size
        return F.interpolate(
            abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
            size=(tgt_size, tgt_size),
            mode="bicubic",
            align_corners=False,
        ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
    else:
        return abs_pos


# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate(
            [np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


class Resampler(nn.Module):
    """
    A 2D perceiver-resampler network with a single cross-attention layer over
    (grid_size**2) learnable queries and a 2D sincos pos_emb.
    Outputs:
        A tensor with the shape of (grid_size**2, embed_dim)
    """

    def __init__(
        self,
        grid_size,
        embed_dim,
        num_heads,
        kv_dim=None,
        norm_layer=partial(nn.LayerNorm, eps=1e-6)
    ):
        super().__init__()
        self.num_queries = grid_size ** 2
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.pos_embed = nn.Parameter(
            torch.from_numpy(get_2d_sincos_pos_embed(
                embed_dim, grid_size)).float()
        ).requires_grad_(False)

        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
        trunc_normal_(self.query, std=.02)

        if kv_dim is not None and kv_dim != embed_dim:
            self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        else:
            self.kv_proj = nn.Identity()

        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)

        self.ln_post = norm_layer(embed_dim)
        self.proj = nn.Parameter(
            (embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, attn_mask=None):
        # interpolate the key pos_emb to match the number of vision tokens
        pos_embed = get_abs_pos(self.pos_embed, x.size(1))

        x = self.kv_proj(x)
        x = self.ln_kv(x).permute(1, 0, 2)

        N = x.shape[1]
        q = self.ln_q(self.query)
        out = self.attn(
            self._repeat(q, N) + self.pos_embed.unsqueeze(1),
            x + pos_embed.unsqueeze(1),
            x,
            attn_mask=attn_mask)[0]
        x = out.permute(1, 0, 2)

        x = self.ln_post(x)
        x = x @ self.proj
        return x

    def _repeat(self, query, N: int):
        return query.unsqueeze(1).repeat(1, N, 1)
r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/model/utils.py
ADDED
|
@@ -0,0 +1,555 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchvision import transforms
|
| 2 |
+
from timm.data.transforms import RandomResizedCropAndInterpolation
|
| 3 |
+
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
| 4 |
+
from transformers import AutoConfig
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from io import BytesIO
|
| 7 |
+
import torch.distributed as dist
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pickle
|
| 10 |
+
import base64
|
| 11 |
+
import cv2
|
| 12 |
+
import os
|
| 13 |
+
import torch
|
| 14 |
+
from transformers import AutoConfig, StoppingCriteria
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
| 18 |
+
except ImportError:
|
| 19 |
+
OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
| 20 |
+
OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def auto_upgrade(config):
|
| 24 |
+
cfg = AutoConfig.from_pretrained(config)
|
| 25 |
+
if 'llava' in config and cfg.model_type != 'llava':
|
| 26 |
+
print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
|
| 27 |
+
print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
|
| 28 |
+
confirm = input(
|
| 29 |
+
"Please confirm that you want to upgrade the checkpoint. [Y/N]")
|
| 30 |
+
if confirm.lower() in ["y", "yes"]:
|
| 31 |
+
print("Upgrading checkpoint...")
|
| 32 |
+
assert len(cfg.architectures) == 1
|
| 33 |
+
setattr(cfg.__class__, "model_type", "llava")
|
| 34 |
+
cfg.architectures[0] = 'LlavaLlamaForCausalLM'
|
| 35 |
+
cfg.save_pretrained(config)
|
| 36 |
+
print("Checkpoint upgraded.")
|
| 37 |
+
else:
|
| 38 |
+
print("Checkpoint upgrade aborted.")
|
| 39 |
+
exit(1)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
| 43 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
| 44 |
+
self.keywords = keywords
|
| 45 |
+
self.tokenizer = tokenizer
|
| 46 |
+
self.start_len = None
|
| 47 |
+
self.input_ids = input_ids
|
| 48 |
+
|
| 49 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
| 50 |
+
if self.start_len is None:
|
| 51 |
+
self.start_len = self.input_ids.shape[1]
|
| 52 |
+
else:
|
| 53 |
+
outputs = self.tokenizer.batch_decode(
|
| 54 |
+
output_ids[:, self.start_len:], skip_special_tokens=True)[0]
|
| 55 |
+
for keyword in self.keywords:
|
| 56 |
+
if keyword in outputs:
|
| 57 |
+
return True
|
| 58 |
+
return False
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def auto_upgrade(config):
|
| 62 |
+
cfg = AutoConfig.from_pretrained(config)
|
| 63 |
+
if 'llava' in config and cfg.model_type != 'llava':
|
| 64 |
+
print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
|
| 65 |
+
print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
|
| 66 |
+
confirm = input(
|
| 67 |
+
"Please confirm that you want to upgrade the checkpoint. [Y/N]")
|
| 68 |
+
if confirm.lower() in ["y", "yes"]:
|
| 69 |
+
print("Upgrading checkpoint...")
|
| 70 |
+
assert len(cfg.architectures) == 1
|
| 71 |
+
setattr(cfg.__class__, "model_type", "llava")
|
| 72 |
+
cfg.architectures[0] = 'LlavaLlamaForCausalLM'
|
| 73 |
+
cfg.save_pretrained(config)
|
| 74 |
+
print("Checkpoint upgraded.")
|
| 75 |
+
else:
|
| 76 |
+
print("Checkpoint upgrade aborted.")
|
| 77 |
+
exit(1)
|
| 78 |
+
|
| 79 |
+
# aug functions
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def identity_func(img):
|
| 83 |
+
return img
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def autocontrast_func(img, cutoff=0):
|
| 87 |
+
'''
|
| 88 |
+
same output as PIL.ImageOps.autocontrast
|
| 89 |
+
'''
|
| 90 |
+
n_bins = 256
|
| 91 |
+
|
| 92 |
+
def tune_channel(ch):
|
| 93 |
+
n = ch.size
|
| 94 |
+
cut = cutoff * n // 100
|
| 95 |
+
if cut == 0:
|
| 96 |
+
high, low = ch.max(), ch.min()
|
| 97 |
+
else:
|
| 98 |
+
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
|
| 99 |
+
low = np.argwhere(np.cumsum(hist) > cut)
|
| 100 |
+
low = 0 if low.shape[0] == 0 else low[0]
|
| 101 |
+
high = np.argwhere(np.cumsum(hist[::-1]) > cut)
|
| 102 |
+
high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
|
| 103 |
+
if high <= low:
|
| 104 |
+
table = np.arange(n_bins)
|
| 105 |
+
else:
|
| 106 |
+
scale = (n_bins - 1) / (high - low)
|
| 107 |
+
table = np.arange(n_bins) * scale - low * scale
|
| 108 |
+
table[table < 0] = 0
|
| 109 |
+
table[table > n_bins - 1] = n_bins - 1
|
| 110 |
+
table = table.clip(0, 255).astype(np.uint8)
|
| 111 |
+
return table[ch]
|
| 112 |
+
|
| 113 |
+
channels = [tune_channel(ch) for ch in cv2.split(img)]
|
| 114 |
+
out = cv2.merge(channels)
|
| 115 |
+
return out
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def equalize_func(img):
|
| 119 |
+
'''
|
| 120 |
+
same output as PIL.ImageOps.equalize
|
| 121 |
+
PIL's implementation is different from cv2.equalize
|
| 122 |
+
'''
|
| 123 |
+
n_bins = 256
|
| 124 |
+
|
| 125 |
+
def tune_channel(ch):
|
| 126 |
+
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
|
| 127 |
+
non_zero_hist = hist[hist != 0].reshape(-1)
|
| 128 |
+
step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
|
| 129 |
+
if step == 0:
|
| 130 |
+
return ch
|
| 131 |
+
n = np.empty_like(hist)
|
| 132 |
+
n[0] = step // 2
|
| 133 |
+
n[1:] = hist[:-1]
|
| 134 |
+
table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
|
| 135 |
+
return table[ch]
|
| 136 |
+
|
| 137 |
+
channels = [tune_channel(ch) for ch in cv2.split(img)]
|
| 138 |
+
out = cv2.merge(channels)
|
| 139 |
+
return out
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def rotate_func(img, degree, fill=(0, 0, 0)):
|
| 143 |
+
'''
|
| 144 |
+
like PIL, rotate by degree, not radians
|
| 145 |
+
'''
|
| 146 |
+
H, W = img.shape[0], img.shape[1]
|
| 147 |
+
center = W / 2, H / 2
|
| 148 |
+
M = cv2.getRotationMatrix2D(center, degree, 1)
|
| 149 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
|
| 150 |
+
return out
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def solarize_func(img, thresh=128):
|
| 154 |
+
'''
|
| 155 |
+
same output as PIL.ImageOps.posterize
|
| 156 |
+
'''
|
| 157 |
+
table = np.array([el if el < thresh else 255 - el for el in range(256)])
|
| 158 |
+
table = table.clip(0, 255).astype(np.uint8)
|
| 159 |
+
out = table[img]
|
| 160 |
+
return out
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def color_func(img, factor):
|
| 164 |
+
'''
|
| 165 |
+
same output as PIL.ImageEnhance.Color
|
| 166 |
+
'''
|
| 167 |
+
# implementation according to PIL definition, quite slow
|
| 168 |
+
# degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
|
| 169 |
+
# out = blend(degenerate, img, factor)
|
| 170 |
+
# M = (
|
| 171 |
+
# np.eye(3) * factor
|
| 172 |
+
# + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
|
| 173 |
+
# )[np.newaxis, np.newaxis, :]
|
| 174 |
+
M = (
|
| 175 |
+
np.float32([
|
| 176 |
+
[0.886, -0.114, -0.114],
|
| 177 |
+
[-0.587, 0.413, -0.587],
|
| 178 |
+
[-0.299, -0.299, 0.701]]) * factor
|
| 179 |
+
+ np.float32([[0.114], [0.587], [0.299]])
|
| 180 |
+
)
|
| 181 |
+
out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
|
| 182 |
+
return out
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def contrast_func(img, factor):
|
| 186 |
+
"""
|
| 187 |
+
same output as PIL.ImageEnhance.Contrast
|
| 188 |
+
"""
|
| 189 |
+
mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
|
| 190 |
+
table = np.array([(
|
| 191 |
+
el - mean) * factor + mean
|
| 192 |
+
for el in range(256)
|
| 193 |
+
]).clip(0, 255).astype(np.uint8)
|
| 194 |
+
out = table[img]
|
| 195 |
+
return out
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def brightness_func(img, factor):
|
| 199 |
+
'''
|
| 200 |
+
same output as PIL.ImageEnhance.Contrast
|
| 201 |
+
'''
|
| 202 |
+
table = (np.arange(256, dtype=np.float32) *
|
| 203 |
+
factor).clip(0, 255).astype(np.uint8)
|
| 204 |
+
out = table[img]
|
| 205 |
+
return out
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def sharpness_func(img, factor):
|
| 209 |
+
'''
|
| 210 |
+
The differences the this result and PIL are all on the 4 boundaries, the center
|
| 211 |
+
areas are same
|
| 212 |
+
'''
|
| 213 |
+
kernel = np.ones((3, 3), dtype=np.float32)
|
| 214 |
+
kernel[1][1] = 5
|
| 215 |
+
kernel /= 13
|
| 216 |
+
degenerate = cv2.filter2D(img, -1, kernel)
|
| 217 |
+
if factor == 0.0:
|
| 218 |
+
out = degenerate
|
| 219 |
+
elif factor == 1.0:
|
| 220 |
+
out = img
|
| 221 |
+
else:
|
| 222 |
+
out = img.astype(np.float32)
|
| 223 |
+
degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
|
| 224 |
+
out[1:-1, 1:-1, :] = degenerate + factor * \
|
| 225 |
+
(out[1:-1, 1:-1, :] - degenerate)
|
| 226 |
+
out = out.astype(np.uint8)
|
| 227 |
+
return out
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def shear_x_func(img, factor, fill=(0, 0, 0)):
|
| 231 |
+
H, W = img.shape[0], img.shape[1]
|
| 232 |
+
M = np.float32([[1, factor, 0], [0, 1, 0]])
|
| 233 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
|
| 234 |
+
flags=cv2.INTER_LINEAR).astype(np.uint8)
|
| 235 |
+
return out
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def translate_x_func(img, offset, fill=(0, 0, 0)):
|
| 239 |
+
'''
|
| 240 |
+
same output as PIL.Image.transform
|
| 241 |
+
'''
|
| 242 |
+
H, W = img.shape[0], img.shape[1]
|
| 243 |
+
M = np.float32([[1, 0, -offset], [0, 1, 0]])
|
| 244 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
|
| 245 |
+
flags=cv2.INTER_LINEAR).astype(np.uint8)
|
| 246 |
+
return out
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def translate_y_func(img, offset, fill=(0, 0, 0)):
|
| 250 |
+
'''
|
| 251 |
+
same output as PIL.Image.transform
|
| 252 |
+
'''
|
| 253 |
+
H, W = img.shape[0], img.shape[1]
|
| 254 |
+
M = np.float32([[1, 0, 0], [0, 1, -offset]])
|
| 255 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
|
| 256 |
+
flags=cv2.INTER_LINEAR).astype(np.uint8)
|
| 257 |
+
return out
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def posterize_func(img, bits):
|
| 261 |
+
'''
|
| 262 |
+
same output as PIL.ImageOps.posterize
|
| 263 |
+
'''
|
| 264 |
+
out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
|
| 265 |
+
return out
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def shear_y_func(img, factor, fill=(0, 0, 0)):
|
| 269 |
+
H, W = img.shape[0], img.shape[1]
|
| 270 |
+
M = np.float32([[1, 0, 0], [factor, 1, 0]])
|
| 271 |
+
out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
|
| 272 |
+
flags=cv2.INTER_LINEAR).astype(np.uint8)
|
| 273 |
+
return out
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def cutout_func(img, pad_size, replace=(0, 0, 0)):
|
| 277 |
+
replace = np.array(replace, dtype=np.uint8)
|
| 278 |
+
H, W = img.shape[0], img.shape[1]
|
| 279 |
+
rh, rw = np.random.random(2)
|
| 280 |
+
pad_size = pad_size // 2
|
| 281 |
+
ch, cw = int(rh * H), int(rw * W)
|
| 282 |
+
x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
|
| 283 |
+
y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
|
| 284 |
+
out = img.copy()
|
| 285 |
+
out[x1:x2, y1:y2, :] = replace
|
| 286 |
+
return out
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
# level to args
|
| 290 |
+
def enhance_level_to_args(MAX_LEVEL):
|
| 291 |
+
def level_to_args(level):
|
| 292 |
+
return ((level / MAX_LEVEL) * 1.8 + 0.1,)
|
| 293 |
+
return level_to_args
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def shear_level_to_args(MAX_LEVEL, replace_value):
|
| 297 |
+
def level_to_args(level):
|
| 298 |
+
level = (level / MAX_LEVEL) * 0.3
|
| 299 |
+
if np.random.random() > 0.5:
|
| 300 |
+
level = -level
|
| 301 |
+
return (level, replace_value)
|
| 302 |
+
|
| 303 |
+
return level_to_args
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
|
| 307 |
+
def level_to_args(level):
|
| 308 |
+
level = (level / MAX_LEVEL) * float(translate_const)
|
| 309 |
+
if np.random.random() > 0.5:
|
| 310 |
+
level = -level
|
| 311 |
+
return (level, replace_value)
|
| 312 |
+
|
| 313 |
+
return level_to_args
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
|
| 317 |
+
def level_to_args(level):
|
| 318 |
+
level = int((level / MAX_LEVEL) * cutout_const)
|
| 319 |
+
return (level, replace_value)
|
| 320 |
+
|
| 321 |
+
return level_to_args
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def solarize_level_to_args(MAX_LEVEL):
|
| 325 |
+
def level_to_args(level):
|
| 326 |
+
level = int((level / MAX_LEVEL) * 256)
|
| 327 |
+
return (level, )
|
| 328 |
+
return level_to_args
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def none_level_to_args(level):
|
| 332 |
+
return ()
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def posterize_level_to_args(MAX_LEVEL):
|
| 336 |
+
def level_to_args(level):
|
| 337 |
+
level = int((level / MAX_LEVEL) * 4)
|
| 338 |
+
return (level, )
|
| 339 |
+
return level_to_args
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def rotate_level_to_args(MAX_LEVEL, replace_value):
|
| 343 |
+
def level_to_args(level):
|
| 344 |
+
level = (level / MAX_LEVEL) * 30
|
| 345 |
+
if np.random.random() < 0.5:
|
| 346 |
+
level = -level
|
| 347 |
+
return (level, replace_value)
|
| 348 |
+
|
| 349 |
+
return level_to_args
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
func_dict = {
|
| 353 |
+
'Identity': identity_func,
|
| 354 |
+
'AutoContrast': autocontrast_func,
|
| 355 |
+
'Equalize': equalize_func,
|
| 356 |
+
'Rotate': rotate_func,
|
| 357 |
+
'Solarize': solarize_func,
|
| 358 |
+
'Color': color_func,
|
| 359 |
+
'Contrast': contrast_func,
|
| 360 |
+
'Brightness': brightness_func,
|
| 361 |
+
'Sharpness': sharpness_func,
|
| 362 |
+
'ShearX': shear_x_func,
|
| 363 |
+
'TranslateX': translate_x_func,
|
| 364 |
+
'TranslateY': translate_y_func,
|
| 365 |
+
'Posterize': posterize_func,
|
| 366 |
+
'ShearY': shear_y_func,
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
translate_const = 10
|
| 370 |
+
MAX_LEVEL = 10
|
| 371 |
+
replace_value = (128, 128, 128)
|
| 372 |
+
arg_dict = {
|
| 373 |
+
'Identity': none_level_to_args,
|
| 374 |
+
'AutoContrast': none_level_to_args,
|
| 375 |
+
'Equalize': none_level_to_args,
|
| 376 |
+
'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
|
| 377 |
+
'Solarize': solarize_level_to_args(MAX_LEVEL),
|
| 378 |
+
'Color': enhance_level_to_args(MAX_LEVEL),
|
| 379 |
+
'Contrast': enhance_level_to_args(MAX_LEVEL),
|
| 380 |
+
'Brightness': enhance_level_to_args(MAX_LEVEL),
|
| 381 |
+
'Sharpness': enhance_level_to_args(MAX_LEVEL),
|
| 382 |
+
'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
|
| 383 |
+
'TranslateX': translate_level_to_args(
|
| 384 |
+
translate_const, MAX_LEVEL, replace_value
|
| 385 |
+
),
|
| 386 |
+
'TranslateY': translate_level_to_args(
|
| 387 |
+
translate_const, MAX_LEVEL, replace_value
|
| 388 |
+
),
|
| 389 |
+
'Posterize': posterize_level_to_args(MAX_LEVEL),
|
| 390 |
+
'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class RandomAugment(object):
|
| 395 |
+
|
| 396 |
+
def __init__(self, N=2, M=10, isPIL=False, augs=[]):
|
| 397 |
+
self.N = N
|
| 398 |
+
self.M = M
|
| 399 |
+
self.isPIL = isPIL
|
| 400 |
+
if augs:
|
| 401 |
+
self.augs = augs
|
| 402 |
+
else:
|
| 403 |
+
self.augs = list(arg_dict.keys())
|
| 404 |
+
|
| 405 |
+
def get_random_ops(self):
|
| 406 |
+
sampled_ops = np.random.choice(self.augs, self.N)
|
| 407 |
+
return [(op, 0.5, self.M) for op in sampled_ops]
|
| 408 |
+
|
| 409 |
+
def __call__(self, img):
|
| 410 |
+
if self.isPIL:
|
| 411 |
+
img = np.array(img)
|
| 412 |
+
ops = self.get_random_ops()
|
| 413 |
+
for name, prob, level in ops:
|
| 414 |
+
if np.random.random() > prob:
|
| 415 |
+
continue
|
| 416 |
+
args = arg_dict[name](level)
|
| 417 |
+
img = func_dict[name](img, *args)
|
| 418 |
+
return img
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def build_transform(is_train, randaug=True, input_size=224, interpolation='bicubic', std_mode='IMAGENET_INCEPTION'):
|
| 422 |
+
if std_mode == 'IMAGENET_INCEPTION':
|
| 423 |
+
mean = IMAGENET_INCEPTION_MEAN
|
| 424 |
+
std = IMAGENET_INCEPTION_STD
|
| 425 |
+
elif std_mode == 'OPENAI_CLIP':
|
| 426 |
+
mean = OPENAI_CLIP_MEAN
|
| 427 |
+
std = OPENAI_CLIP_STD
|
| 428 |
+
else:
|
| 429 |
+
raise NotImplementedError
|
| 430 |
+
|
| 431 |
+
if is_train:
|
| 432 |
+
crop_scale = float(os.environ.get('TRAIN_CROP_SCALE', 0.9999))
|
| 433 |
+
t = [
|
| 434 |
+
RandomResizedCropAndInterpolation(
|
| 435 |
+
input_size, scale=(crop_scale, 1.0), interpolation='bicubic'),
|
| 436 |
+
# transforms.RandomHorizontalFlip(),
|
| 437 |
+
]
|
| 438 |
+
if randaug and os.environ.get('TRAIN_DO_AUG', 'False') == 'True':
|
| 439 |
+
print(f'@@@@@ Do random aug during training', flush=True)
|
| 440 |
+
t.append(
|
| 441 |
+
RandomAugment(
|
| 442 |
+
2, 7, isPIL=True,
|
| 443 |
+
augs=[
|
| 444 |
+
'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
|
| 445 |
+
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
|
| 446 |
+
]))
|
| 447 |
+
else:
|
| 448 |
+
print(f'@@@@@ Skip random aug during training', flush=True)
|
| 449 |
+
t += [
|
| 450 |
+
transforms.ToTensor(),
|
| 451 |
+
transforms.Normalize(mean=mean, std=std),
|
| 452 |
+
]
|
| 453 |
+
t = transforms.Compose(t)
|
| 454 |
+
else:
|
| 455 |
+
t = transforms.Compose([
|
| 456 |
+
transforms.Resize((input_size, input_size),
|
| 457 |
+
interpolation=transforms.InterpolationMode.BICUBIC),
|
| 458 |
+
transforms.ToTensor(),
|
| 459 |
+
transforms.Normalize(mean=mean, std=std)
|
| 460 |
+
])
|
| 461 |
+
|
| 462 |
+
return t
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
def img2b64(img_path):
|
| 466 |
+
img = Image.open(img_path) # path to file
|
| 467 |
+
img_buffer = BytesIO()
|
| 468 |
+
img.save(img_buffer, format=img.format)
|
| 469 |
+
byte_data = img_buffer.getvalue()
|
| 470 |
+
base64_str = base64.b64encode(byte_data) # bytes
|
| 471 |
+
base64_str = base64_str.decode("utf-8") # str
|
| 472 |
+
return base64_str
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def str2b64(str):
|
| 476 |
+
return base64.b64encode(str.encode('utf-8')).decode('utf-8')
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
def b642str(b64):
|
| 480 |
+
return base64.b64decode(b64).decode('utf-8')
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
def is_dist_avail_and_initialized():
|
| 484 |
+
if not dist.is_available():
|
| 485 |
+
return False
|
| 486 |
+
if not dist.is_initialized():
|
| 487 |
+
return False
|
| 488 |
+
return True
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def get_world_size():
|
| 492 |
+
if not is_dist_avail_and_initialized():
|
| 493 |
+
return 1
|
| 494 |
+
return dist.get_world_size()
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def get_rank():
|
| 498 |
+
if not is_dist_avail_and_initialized():
|
| 499 |
+
return 0
|
| 500 |
+
return dist.get_rank()
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
def all_gather(data):
|
| 504 |
+
"""
|
| 505 |
+
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
| 506 |
+
Args:
|
| 507 |
+
data: any picklable object
|
| 508 |
+
Returns:
|
| 509 |
+
list[data]: list of data gathered from each rank
|
| 510 |
+
"""
|
| 511 |
+
world_size = get_world_size()
|
| 512 |
+
if world_size == 1:
|
| 513 |
+
return [data]
|
| 514 |
+
|
| 515 |
+
# serialized to a Tensor
|
| 516 |
+
buffer = pickle.dumps(data)
|
| 517 |
+
storage = torch.ByteStorage.from_buffer(buffer)
|
| 518 |
+
tensor = torch.ByteTensor(storage).to("cuda")
|
| 519 |
+
|
| 520 |
+
# obtain Tensor size of each rank
|
| 521 |
+
local_size = torch.LongTensor([tensor.numel()]).to("cuda")
|
| 522 |
+
size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
|
| 523 |
+
dist.all_gather(size_list, local_size)
|
| 524 |
+
size_list = [int(size.item()) for size in size_list]
|
| 525 |
+
max_size = max(size_list)
|
| 526 |
+
|
| 527 |
+
# receiving Tensor from all ranks
|
| 528 |
+
# we pad the tensor because torch all_gather does not support
|
| 529 |
+
# gathering tensors of different shapes
|
| 530 |
+
tensor_list = []
|
| 531 |
+
for _ in size_list:
|
| 532 |
+
tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
|
| 533 |
+
if local_size != max_size:
|
| 534 |
+
padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
|
| 535 |
+
tensor = torch.cat((tensor, padding), dim=0)
|
| 536 |
+
dist.all_gather(tensor_list, tensor)
|
| 537 |
+
|
| 538 |
+
data_list = []
|
| 539 |
+
for size, tensor in zip(size_list, tensor_list):
|
| 540 |
+
buffer = tensor.cpu().numpy().tobytes()[:size]
|
| 541 |
+
data_list.append(pickle.loads(buffer))
|
| 542 |
+
|
| 543 |
+
return data_list
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def mean(lst):
|
| 547 |
+
return sum(lst) / len(lst)
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
def stop_gradient_by_name(name: str):
|
| 551 |
+
def apply_fn(module):
|
| 552 |
+
if hasattr(module, name):
|
| 553 |
+
getattr(module, name).requires_grad_(False)
|
| 554 |
+
|
| 555 |
+
return apply_fn
|
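The all_gather helper above works around torch.distributed.all_gather requiring same-shape tensors: each rank first shares the byte length of its pickled payload, then pads to the global maximum before gathering. A minimal sketch of the same pattern, assuming a CUDA-capable process group is already initialized (the initialization itself is not shown in the source):

import pickle
import torch
import torch.distributed as dist

def gather_objects(obj):
    world_size = dist.get_world_size()
    # pickle the object into a uint8 tensor on the GPU
    buf = torch.frombuffer(bytearray(pickle.dumps(obj)), dtype=torch.uint8).to("cuda")
    # exchange lengths first, since all_gather needs identical shapes
    local_len = torch.tensor([buf.numel()], device="cuda")
    lens = [torch.zeros_like(local_len) for _ in range(world_size)]
    dist.all_gather(lens, local_len)
    max_len = max(int(n.item()) for n in lens)
    # pad to the global maximum, gather, then truncate and unpickle per rank
    padded = torch.zeros(max_len, dtype=torch.uint8, device="cuda")
    padded[:buf.numel()] = buf
    out = [torch.empty(max_len, dtype=torch.uint8, device="cuda") for _ in range(world_size)]
    dist.all_gather(out, padded)
    return [pickle.loads(t[:int(n.item())].cpu().numpy().tobytes()) for t, n in zip(out, lens)]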
r1-a/response_generation/minicpm/MiniCPM-o/omnilmm/train/train_utils.py
ADDED
|
@@ -0,0 +1,153 @@
|
| 1 |
+
import os
|
| 2 |
+
import gc
|
| 3 |
+
import copy
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import warnings
|
| 8 |
+
import transformers
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
from typing import Dict, Optional, Sequence
|
| 13 |
+
from omnilmm import conversation as conversation_lib
|
| 14 |
+
|
| 15 |
+
IGNORE_INDEX = -100
|
| 16 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
| 17 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
| 18 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
| 19 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _tokenize_fn(strings: Sequence[str],
|
| 23 |
+
tokenizer: transformers.PreTrainedTokenizer) -> Dict:
|
| 24 |
+
"""Tokenize a list of strings."""
|
| 25 |
+
tokenized_list = [
|
| 26 |
+
tokenizer(
|
| 27 |
+
text,
|
| 28 |
+
return_tensors="pt",
|
| 29 |
+
padding="longest",
|
| 30 |
+
max_length=tokenizer.model_max_length,
|
| 31 |
+
truncation=True,
|
| 32 |
+
) for text in strings
|
| 33 |
+
]
|
| 34 |
+
input_ids = labels = [
|
| 35 |
+
tokenized.input_ids[0] for tokenized in tokenized_list
|
| 36 |
+
]
|
| 37 |
+
input_ids_lens = labels_lens = [
|
| 38 |
+
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
|
| 39 |
+
for tokenized in tokenized_list
|
| 40 |
+
]
|
| 41 |
+
return dict(
|
| 42 |
+
input_ids=input_ids,
|
| 43 |
+
labels=labels,
|
| 44 |
+
input_ids_lens=input_ids_lens,
|
| 45 |
+
labels_lens=labels_lens,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def omni_preprocess(sources,
|
| 51 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 52 |
+
generation=False):
|
| 53 |
+
system_content = 'You are an artificial intelligence assistant, which gives helpful, detailed, and polite answers to the human\'s questions.'
|
| 54 |
+
ignore_index = -100
|
| 55 |
+
|
| 56 |
+
response_template = '\n<|assistant|>\n'
|
| 57 |
+
instruction_template = '\n<|user|>\n'
|
| 58 |
+
response_token_ids = tokenizer.encode(
|
| 59 |
+
response_template, add_special_tokens=False)
|
| 60 |
+
instruction_token_ids = tokenizer.encode(
|
| 61 |
+
instruction_template, add_special_tokens=False)
|
| 62 |
+
|
| 63 |
+
batch_input_ids = []
|
| 64 |
+
batch_labels = []
|
| 65 |
+
for i in range(len(sources)):
|
| 66 |
+
new_source = []
|
| 67 |
+
prev_role = 'unexpect'
|
| 68 |
+
for conv_turn in sources[i]:
|
| 69 |
+
role = conv_turn['from'] if 'from' in conv_turn else conv_turn['role']
|
| 70 |
+
content = conv_turn['value'] if 'value' in conv_turn else conv_turn['content']
|
| 71 |
+
|
| 72 |
+
role = 'user' if role == 'human' else role
|
| 73 |
+
role = 'assistant' if role == 'gpt' else role
|
| 74 |
+
|
| 75 |
+
assert role in ['user', 'assistant']
|
| 76 |
+
assert role != prev_role, f'role={role}, prev_role={prev_role}'
|
| 77 |
+
prev_role = role
|
| 78 |
+
|
| 79 |
+
new_turn = {
|
| 80 |
+
'role': role,
|
| 81 |
+
'content': content
|
| 82 |
+
}
|
| 83 |
+
new_source.append(new_turn)
|
| 84 |
+
if new_source[0]['role'] != 'system':
|
| 85 |
+
new_source.insert(0, {'role': 'system', 'content': system_content})
|
| 86 |
+
|
| 87 |
+
# TODO: this automatically adds '\n' to the end
|
| 88 |
+
res_text = tokenizer.apply_chat_template(
|
| 89 |
+
new_source, tokenize=False, add_generation_prompt=generation)
|
| 90 |
+
if not generation:
|
| 91 |
+
res_text = res_text.strip()
|
| 92 |
+
|
| 93 |
+
conversations_tokenized = _tokenize_fn([res_text], tokenizer)
|
| 94 |
+
res_input_ids = conversations_tokenized["input_ids"][0]
|
| 95 |
+
|
| 96 |
+
# since labels and input_ids are references to the same object
|
| 97 |
+
res_labels = copy.deepcopy(conversations_tokenized["labels"][0])
|
| 98 |
+
|
| 99 |
+
response_token_ids_idxs = []
|
| 100 |
+
human_token_ids_idxs = []
|
| 101 |
+
|
| 102 |
+
for assistant_idx in np.where(res_labels == response_token_ids[0])[0]:
|
| 103 |
+
# find the indexes of the start of a response.
|
| 104 |
+
if (response_token_ids == res_labels[assistant_idx: assistant_idx + len(
|
| 105 |
+
response_token_ids)].tolist()
|
| 106 |
+
):
|
| 107 |
+
response_token_ids_idxs.append(
|
| 108 |
+
assistant_idx + len(response_token_ids))
|
| 109 |
+
|
| 110 |
+
if len(response_token_ids_idxs) == 0:
|
| 111 |
+
warnings.warn(
|
| 112 |
+
f"Could not find response key `{response_template}` in the "
|
| 113 |
+
f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
|
| 114 |
+
f'Raw text is @===>{res_text}<===@'
|
| 115 |
+
f'Raw source is @===>{new_source}<===@'
|
| 116 |
+
f"This instance will be ignored in loss calculation. "
|
| 117 |
+
f"Note, if this happens often, consider increasing the `max_seq_length`."
|
| 118 |
+
)
|
| 119 |
+
res_labels[:] = ignore_index
|
| 120 |
+
|
| 121 |
+
human_token_ids = instruction_token_ids
|
| 122 |
+
for human_idx in np.where(res_labels == human_token_ids[0])[0]:
|
| 123 |
+
# find the indexes of the start of a user turn.
|
| 124 |
+
if human_token_ids == res_labels[human_idx: human_idx + len(human_token_ids)].tolist():
|
| 125 |
+
human_token_ids_idxs.append(human_idx)
|
| 126 |
+
|
| 127 |
+
if len(human_token_ids_idxs) == 0:
|
| 128 |
+
warnings.warn(
|
| 129 |
+
f"Could not find instruction key `{instruction_template}` in the "
|
| 130 |
+
f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
|
| 131 |
+
f'Raw text is @===>{res_text}<===@'
|
| 132 |
+
f'Raw source is @===>{new_source}<===@'
|
| 133 |
+
f"This instance will be ignored in loss calculation. "
|
| 134 |
+
f"Note, if this happens often, consider increasing the `max_seq_length`."
|
| 135 |
+
)
|
| 136 |
+
res_labels[:] = ignore_index
|
| 137 |
+
|
| 138 |
+
for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
|
| 139 |
+
# Make pytorch loss function ignore all non response tokens
|
| 140 |
+
if idx != 0:
|
| 141 |
+
res_labels[start:end] = ignore_index
|
| 142 |
+
else:
|
| 143 |
+
res_labels[:end] = ignore_index
|
| 144 |
+
|
| 145 |
+
if len(response_token_ids_idxs) < len(human_token_ids_idxs):
|
| 146 |
+
res_labels[human_token_ids_idxs[-1]:] = ignore_index
|
| 147 |
+
|
| 148 |
+
batch_input_ids.append(res_input_ids)
|
| 149 |
+
batch_labels.append(res_labels)
|
| 150 |
+
|
| 151 |
+
return dict(input_ids=batch_input_ids, labels=batch_labels)
|
| 152 |
+
|
| 153 |
+
|
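The core of omni_preprocess above is response-only supervision: labels begin as a copy of input_ids, and every token outside an assistant response span is overwritten with -100 so the cross-entropy loss ignores it. A minimal sketch of that masking step, using toy token ids and hand-picked span positions rather than real tokenizer output:

import torch

IGNORE_INDEX = -100

def mask_non_response(input_ids: torch.Tensor, response_spans) -> torch.Tensor:
    """response_spans: (start, end) index pairs covering assistant tokens."""
    labels = input_ids.clone()
    keep = torch.zeros_like(labels, dtype=torch.bool)
    for start, end in response_spans:
        keep[start:end] = True
    labels[~keep] = IGNORE_INDEX  # loss is computed only on kept positions
    return labels

# toy usage: tokens 5..8 form the assistant reply, everything else is masked
ids = torch.arange(12)
print(mask_non_response(ids, [(5, 9)]))
# tensor([-100, -100, -100, -100, -100, 5, 6, 7, 8, -100, -100, -100])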
r1-a/response_generation/minicpm/MiniCPM-o/quantize/bnb_quantize.py
ADDED
|
@@ -0,0 +1,81 @@
|
| 1 |
+
"""
|
| 2 |
+
This script uses bitsandbytes to quantize the MiniCPM-Llama3-V-2_5 model.
|
| 3 |
+
The quantized model can then be used directly or fine-tuned further.
|
| 4 |
+
You only need to set model_path and save_path, then run:
|
| 5 |
+
|
| 6 |
+
cd MiniCPM-V
|
| 7 |
+
python quantize/bnb_quantize.py
|
| 8 |
+
|
| 9 |
+
You will get the quantized model in save_path, plus the quantized model's inference time and GPU memory usage.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
|
| 15 |
+
from PIL import Image
|
| 16 |
+
import time
|
| 17 |
+
import torch
|
| 18 |
+
import GPUtil
|
| 19 |
+
import os
|
| 20 |
+
|
| 21 |
+
assert torch.cuda.is_available(),"CUDA is not available, but this code requires a GPU."
|
| 22 |
+
|
| 23 |
+
device = 'cuda' # Select GPU to use
|
| 24 |
+
model_path = '/root/ld/ld_model_pretrained/MiniCPM-Llama3-V-2_5' # Model download path
|
| 25 |
+
save_path = '/root/ld/ld_model_pretrain/MiniCPM-Llama3-V-2_5_int4' # Quantized model save path
|
| 26 |
+
image_path = './assets/airplane.jpeg'
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Create a configuration object to specify quantization parameters
|
| 30 |
+
quantization_config = BitsAndBytesConfig(
|
| 31 |
+
load_in_4bit=True, # Whether to perform 4-bit quantization
|
| 32 |
+
load_in_8bit=False, # Whether to perform 8-bit quantization
|
| 33 |
+
bnb_4bit_compute_dtype=torch.float16, # Computation precision setting
|
| 34 |
+
bnb_4bit_quant_storage=torch.uint8, # Storage format for quantized weights
|
| 35 |
+
bnb_4bit_quant_type="nf4", # Quantization format, here using normally distributed int4
|
| 36 |
+
bnb_4bit_use_double_quant=True, # Whether to use double quantization, i.e., quantizing zeropoint and scaling parameters
|
| 37 |
+
llm_int8_enable_fp32_cpu_offload=False, # Whether LLM uses int8, with fp32 parameters stored on the CPU
|
| 38 |
+
llm_int8_has_fp16_weight=False, # Whether mixed precision is enabled
|
| 39 |
+
llm_int8_skip_modules=["out_proj", "kv_proj", "lm_head"], # Modules not to be quantized
|
| 40 |
+
llm_int8_threshold=6.0 # Outlier value in the llm.int8() algorithm, distinguishing whether to perform quantization based on this value
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
| 44 |
+
model = AutoModel.from_pretrained(
|
| 45 |
+
model_path,
|
| 46 |
+
device_map=device, # Allocate model to device
|
| 47 |
+
quantization_config=quantization_config,
|
| 48 |
+
trust_remote_code=True
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
gpu_usage = GPUtil.getGPUs()[0].memoryUsed
|
| 52 |
+
start=time.time()
|
| 53 |
+
response = model.chat(
|
| 54 |
+
image=Image.open(image_path).convert("RGB"),
|
| 55 |
+
msgs=[
|
| 56 |
+
{
|
| 57 |
+
"role": "user",
|
| 58 |
+
"content": "What is in this picture?"
|
| 59 |
+
}
|
| 60 |
+
],
|
| 61 |
+
tokenizer=tokenizer
|
| 62 |
+
) # Model inference
|
| 63 |
+
print('Output after quantization:',response)
|
| 64 |
+
print('Inference time after quantization:',time.time()-start)
|
| 65 |
+
print(f"GPU memory usage after quantization: {round(gpu_usage/1024,2)}GB")
|
| 66 |
+
|
| 67 |
+
"""
|
| 68 |
+
Expected output:
|
| 69 |
+
|
| 70 |
+
Output after quantization: This picture contains specific parts of an airplane, including wings, engines, and tail sections. These components are key parts of large commercial aircraft.
|
| 71 |
+
The wings support lift during flight, while the engines provide thrust to move the plane forward. The tail section is typically used for stabilizing flight and plays a role in airline branding.
|
| 72 |
+
The design and color of the airplane indicate that it belongs to Air China, likely a passenger aircraft due to its large size and twin-engine configuration.
|
| 73 |
+
There are no markings or insignia on the airplane indicating the specific model or registration number; such information may require additional context or a clearer perspective to discern.
|
| 74 |
+
Inference time after quantization: 8.583992719650269 seconds
|
| 75 |
+
GPU memory usage after quantization: 6.41 GB
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
# Save the model and tokenizer
|
| 79 |
+
os.makedirs(save_path, exist_ok=True)
|
| 80 |
+
model.save_pretrained(save_path, safe_serialization=True)
|
| 81 |
+
tokenizer.save_pretrained(save_path)
|
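To use the checkpoint written above, it should be enough to load it straight back: recent transformers versions serialize the BitsAndBytesConfig into the saved config.json, so from_pretrained restores the 4-bit weights without re-quantizing. A minimal sketch, assuming save_path points at the directory this script produced:

from transformers import AutoModel, AutoTokenizer

save_path = "MiniCPM-Llama3-V-2_5_int4"  # assumed output directory of the script above
tokenizer = AutoTokenizer.from_pretrained(save_path, trust_remote_code=True)
model = AutoModel.from_pretrained(
    save_path,
    device_map="cuda",  # same single-GPU placement as the script
    trust_remote_code=True,
).eval()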
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/chatbot_web_demo_o2.6.py
ADDED
|
@@ -0,0 +1,552 @@
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
import torch
|
| 4 |
+
import argparse
|
| 5 |
+
from transformers import AutoModel, AutoTokenizer
|
| 6 |
+
import gradio as gr
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from decord import VideoReader, cpu
|
| 9 |
+
import io
|
| 10 |
+
import os
|
| 11 |
+
import copy
|
| 12 |
+
import requests
|
| 13 |
+
import base64
|
| 14 |
+
import json
|
| 15 |
+
import traceback
|
| 16 |
+
import re
|
| 17 |
+
import modelscope_studio as mgr
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# README, How to run demo on different devices
|
| 21 |
+
|
| 22 |
+
# For Nvidia GPUs.
|
| 23 |
+
# python chatbot_web_demo_o2.6.py
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Argparser
|
| 27 |
+
parser = argparse.ArgumentParser(description='demo')
|
| 28 |
+
parser.add_argument('--model', type=str , default="openbmb/MiniCPM-o-2_6", help="huggingface model name or local path")
|
| 29 |
+
parser.add_argument('--multi-gpus', action='store_true', default=False, help='use multi-gpus')
|
| 30 |
+
args = parser.parse_args()
|
| 31 |
+
device = "cuda"
|
| 32 |
+
model_name = 'MiniCPM-o 2.6'
|
| 33 |
+
|
| 34 |
+
# Load model
|
| 35 |
+
model_path = args.model
|
| 36 |
+
if args.multi_gpus:
|
| 37 |
+
from accelerate import load_checkpoint_and_dispatch, init_empty_weights, infer_auto_device_map
|
| 38 |
+
with init_empty_weights():
|
| 39 |
+
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, attn_implementation='sdpa', torch_dtype=torch.bfloat16,
|
| 40 |
+
init_audio=False, init_tts=False)
|
| 41 |
+
device_map = infer_auto_device_map(model, max_memory={0: "10GB", 1: "10GB"},
|
| 42 |
+
no_split_module_classes=['SiglipVisionTransformer', 'Qwen2DecoderLayer'])
|
| 43 |
+
device_id = device_map["llm.model.embed_tokens"]
|
| 44 |
+
device_map["llm.lm_head"] = device_id # firtt and last layer should be in same device
|
| 45 |
+
device_map["vpm"] = device_id
|
| 46 |
+
device_map["resampler"] = device_id
|
| 47 |
+
device_id2 = device_map["llm.model.layers.26"]
|
| 48 |
+
device_map["llm.model.layers.8"] = device_id2
|
| 49 |
+
device_map["llm.model.layers.9"] = device_id2
|
| 50 |
+
device_map["llm.model.layers.10"] = device_id2
|
| 51 |
+
device_map["llm.model.layers.11"] = device_id2
|
| 52 |
+
device_map["llm.model.layers.12"] = device_id2
|
| 53 |
+
device_map["llm.model.layers.13"] = device_id2
|
| 54 |
+
device_map["llm.model.layers.14"] = device_id2
|
| 55 |
+
device_map["llm.model.layers.15"] = device_id2
|
| 56 |
+
device_map["llm.model.layers.16"] = device_id2
|
| 57 |
+
#print(device_map)
|
| 58 |
+
|
| 59 |
+
model = load_checkpoint_and_dispatch(model, model_path, dtype=torch.bfloat16, device_map=device_map)
|
| 60 |
+
else:
|
| 61 |
+
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, init_audio=False, init_tts=False)
|
| 62 |
+
model = model.to(device=device)
|
| 63 |
+
|
| 64 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
| 65 |
+
model.eval()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
ERROR_MSG = "Error, please retry"
|
| 71 |
+
MAX_NUM_FRAMES = 64
|
| 72 |
+
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
|
| 73 |
+
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}
|
| 74 |
+
|
| 75 |
+
def get_file_extension(filename):
|
| 76 |
+
return os.path.splitext(filename)[1].lower()
|
| 77 |
+
|
| 78 |
+
def is_image(filename):
|
| 79 |
+
return get_file_extension(filename) in IMAGE_EXTENSIONS
|
| 80 |
+
|
| 81 |
+
def is_video(filename):
|
| 82 |
+
return get_file_extension(filename) in VIDEO_EXTENSIONS
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
form_radio = {
|
| 86 |
+
'choices': ['Beam Search', 'Sampling'],
|
| 87 |
+
#'value': 'Beam Search',
|
| 88 |
+
'value': 'Sampling',
|
| 89 |
+
'interactive': True,
|
| 90 |
+
'label': 'Decode Type'
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def create_component(params, comp='Slider'):
|
| 95 |
+
if comp == 'Slider':
|
| 96 |
+
return gr.Slider(
|
| 97 |
+
minimum=params['minimum'],
|
| 98 |
+
maximum=params['maximum'],
|
| 99 |
+
value=params['value'],
|
| 100 |
+
step=params['step'],
|
| 101 |
+
interactive=params['interactive'],
|
| 102 |
+
label=params['label']
|
| 103 |
+
)
|
| 104 |
+
elif comp == 'Radio':
|
| 105 |
+
return gr.Radio(
|
| 106 |
+
choices=params['choices'],
|
| 107 |
+
value=params['value'],
|
| 108 |
+
interactive=params['interactive'],
|
| 109 |
+
label=params['label']
|
| 110 |
+
)
|
| 111 |
+
elif comp == 'Button':
|
| 112 |
+
return gr.Button(
|
| 113 |
+
value=params['value'],
|
| 114 |
+
interactive=True
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def create_multimodal_input(upload_image_disabled=False, upload_video_disabled=False):
|
| 119 |
+
return mgr.MultimodalInput(
|
| 120 |
+
upload_image_button_props={'label': 'Upload Image', 'disabled': upload_image_disabled, 'file_count': 'multiple'},
|
| 121 |
+
upload_video_button_props={'label': 'Upload Video', 'disabled': upload_video_disabled, 'file_count': 'single'},
|
| 122 |
+
submit_button_props={'label': 'Submit'}
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
|
| 127 |
+
try:
|
| 128 |
+
print('msgs:', msgs)
|
| 129 |
+
answer = model.chat(
|
| 130 |
+
image=None,
|
| 131 |
+
msgs=msgs,
|
| 132 |
+
tokenizer=tokenizer,
|
| 133 |
+
**params
|
| 134 |
+
)
|
| 135 |
+
res = re.sub(r'(<box>.*</box>)', '', answer)
|
| 136 |
+
res = res.replace('<ref>', '')
|
| 137 |
+
res = res.replace('</ref>', '')
|
| 138 |
+
res = res.replace('<box>', '')
|
| 139 |
+
answer = res.replace('</box>', '')
|
| 140 |
+
print('answer:', answer)
|
| 141 |
+
return 0, answer, None, None
|
| 142 |
+
except Exception as e:
|
| 143 |
+
print(e)
|
| 144 |
+
traceback.print_exc()
|
| 145 |
+
return -1, ERROR_MSG, None, None
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def encode_image(image):
|
| 149 |
+
if not isinstance(image, Image.Image):
|
| 150 |
+
if hasattr(image, 'path'):
|
| 151 |
+
image = Image.open(image.path).convert("RGB")
|
| 152 |
+
else:
|
| 153 |
+
image = Image.open(image.file.path).convert("RGB")
|
| 154 |
+
# resize to max_size
|
| 155 |
+
max_size = 448*16
|
| 156 |
+
if max(image.size) > max_size:
|
| 157 |
+
w,h = image.size
|
| 158 |
+
if w > h:
|
| 159 |
+
new_w = max_size
|
| 160 |
+
new_h = int(h * max_size / w)
|
| 161 |
+
else:
|
| 162 |
+
new_h = max_size
|
| 163 |
+
new_w = int(w * max_size / h)
|
| 164 |
+
image = image.resize((new_w, new_h), resample=Image.BICUBIC)
|
| 165 |
+
return image
|
| 166 |
+
## save by BytesIO and convert to base64
|
| 167 |
+
#buffered = io.BytesIO()
|
| 168 |
+
#image.save(buffered, format="png")
|
| 169 |
+
#im_b64 = base64.b64encode(buffered.getvalue()).decode()
|
| 170 |
+
#return {"type": "image", "pairs": im_b64}
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def encode_video(video):
|
| 174 |
+
def uniform_sample(l, n):
|
| 175 |
+
gap = len(l) / n
|
| 176 |
+
idxs = [int(i * gap + gap / 2) for i in range(n)]
|
| 177 |
+
return [l[i] for i in idxs]
|
| 178 |
+
|
| 179 |
+
if hasattr(video, 'path'):
|
| 180 |
+
vr = VideoReader(video.path, ctx=cpu(0))
|
| 181 |
+
else:
|
| 182 |
+
vr = VideoReader(video.file.path, ctx=cpu(0))
|
| 183 |
+
sample_fps = round(vr.get_avg_fps() / 1)  # frame stride equals avg FPS, i.e. sample ~1 frame per second
|
| 184 |
+
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
| 185 |
+
if len(frame_idx)>MAX_NUM_FRAMES:
|
| 186 |
+
frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
|
| 187 |
+
video = vr.get_batch(frame_idx).asnumpy()
|
| 188 |
+
video = [Image.fromarray(v.astype('uint8')) for v in video]
|
| 189 |
+
video = [encode_image(v) for v in video]
|
| 190 |
+
print('video frames:', len(video))
|
| 191 |
+
return video
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def check_mm_type(mm_file):
|
| 195 |
+
if hasattr(mm_file, 'path'):
|
| 196 |
+
path = mm_file.path
|
| 197 |
+
else:
|
| 198 |
+
path = mm_file.file.path
|
| 199 |
+
if is_image(path):
|
| 200 |
+
return "image"
|
| 201 |
+
if is_video(path):
|
| 202 |
+
return "video"
|
| 203 |
+
return None
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def encode_mm_file(mm_file):
|
| 207 |
+
if check_mm_type(mm_file) == 'image':
|
| 208 |
+
return [encode_image(mm_file)]
|
| 209 |
+
if check_mm_type(mm_file) == 'video':
|
| 210 |
+
return encode_video(mm_file)
|
| 211 |
+
return None
|
| 212 |
+
|
| 213 |
+
def make_text(text):
|
| 214 |
+
#return {"type": "text", "pairs": text} # # For remote call
|
| 215 |
+
return text
|
| 216 |
+
|
| 217 |
+
def encode_message(_question):
|
| 218 |
+
files = _question.files
|
| 219 |
+
question = _question.text
|
| 220 |
+
pattern = r"\[mm_media\]\d+\[/mm_media\]"
|
| 221 |
+
matches = re.split(pattern, question)
|
| 222 |
+
message = []
|
| 223 |
+
if len(matches) != len(files) + 1:
|
| 224 |
+
gr.Warning("Number of Images not match the placeholder in text, please refresh the page to restart!")
|
| 225 |
+
assert len(matches) == len(files) + 1
|
| 226 |
+
|
| 227 |
+
text = matches[0].strip()
|
| 228 |
+
if text:
|
| 229 |
+
message.append(make_text(text))
|
| 230 |
+
for i in range(len(files)):
|
| 231 |
+
message += encode_mm_file(files[i])
|
| 232 |
+
text = matches[i + 1].strip()
|
| 233 |
+
if text:
|
| 234 |
+
message.append(make_text(text))
|
| 235 |
+
return message
|
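encode_message above interleaves uploaded files with text by splitting on [mm_media]N[/mm_media] placeholders: text segment i is followed by file i. A minimal sketch of the same split-and-interleave logic, with string stand-ins for the uploaded media objects:

import re

pattern = r"\[mm_media\]\d+\[/mm_media\]"
question = "Compare [mm_media]1[/mm_media] with [mm_media]2[/mm_media] please."
files = ["<image-1>", "<image-2>"]  # stand-ins for uploaded media

segments = re.split(pattern, question)
assert len(segments) == len(files) + 1  # one text segment around each placeholder
message = []
for text, media in zip(segments, files + [None]):
    if text.strip():
        message.append(text.strip())
    if media is not None:
        message.append(media)
print(message)  # ['Compare', '<image-1>', 'with', '<image-2>', 'please.']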
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def check_has_videos(_question):
|
| 239 |
+
images_cnt = 0
|
| 240 |
+
videos_cnt = 0
|
| 241 |
+
for file in _question.files:
|
| 242 |
+
if check_mm_type(file) == "image":
|
| 243 |
+
images_cnt += 1
|
| 244 |
+
else:
|
| 245 |
+
videos_cnt += 1
|
| 246 |
+
return images_cnt, videos_cnt
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def count_video_frames(_context):
|
| 250 |
+
num_frames = 0
|
| 251 |
+
for message in _context:
|
| 252 |
+
for item in message["content"]:
|
| 253 |
+
#if item["type"] == "image": # For remote call
|
| 254 |
+
if isinstance(item, Image.Image):
|
| 255 |
+
num_frames += 1
|
| 256 |
+
return num_frames
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def respond(_question, _chat_bot, _app_cfg, params_form):
|
| 260 |
+
_context = _app_cfg['ctx'].copy()
|
| 261 |
+
_context.append({'role': 'user', 'content': encode_message(_question)})
|
| 262 |
+
|
| 263 |
+
images_cnt = _app_cfg['images_cnt']
|
| 264 |
+
videos_cnt = _app_cfg['videos_cnt']
|
| 265 |
+
files_cnts = check_has_videos(_question)
|
| 266 |
+
if files_cnts[1] + videos_cnt > 1 or (files_cnts[1] + videos_cnt == 1 and files_cnts[0] + images_cnt > 0):
|
| 267 |
+
gr.Warning("Only supports single video file input right now!")
|
| 268 |
+
return _question, _chat_bot, _app_cfg
|
| 269 |
+
|
| 270 |
+
if params_form == 'Beam Search':
|
| 271 |
+
params = {
|
| 272 |
+
'sampling': False,
|
| 273 |
+
'num_beams': 3,
|
| 274 |
+
'repetition_penalty': 1.2,
|
| 275 |
+
"max_new_tokens": 2048
|
| 276 |
+
}
|
| 277 |
+
else:
|
| 278 |
+
params = {
|
| 279 |
+
'sampling': True,
|
| 280 |
+
'top_p': 0.8,
|
| 281 |
+
'top_k': 100,
|
| 282 |
+
'temperature': 0.7,
|
| 283 |
+
'repetition_penalty': 1.05,
|
| 284 |
+
"max_new_tokens": 2048
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
if files_cnts[1] + videos_cnt > 0:
|
| 288 |
+
params["max_inp_length"] = 4352 # 4096+256
|
| 289 |
+
params["use_image_id"] = False
|
| 290 |
+
params["max_slice_nums"] = 1 if count_video_frames(_context) > 16 else 2
|
| 291 |
+
|
| 292 |
+
code, _answer, _, sts = chat("", _context, None, params)
|
| 293 |
+
|
| 294 |
+
images_cnt += files_cnts[0]
|
| 295 |
+
videos_cnt += files_cnts[1]
|
| 296 |
+
_context.append({"role": "assistant", "content": [make_text(_answer)]})
|
| 297 |
+
_chat_bot.append((_question, _answer))
|
| 298 |
+
if code == 0:
|
| 299 |
+
_app_cfg['ctx']=_context
|
| 300 |
+
_app_cfg['sts']=sts
|
| 301 |
+
_app_cfg['images_cnt'] = images_cnt
|
| 302 |
+
_app_cfg['videos_cnt'] = videos_cnt
|
| 303 |
+
|
| 304 |
+
upload_image_disabled = videos_cnt > 0
|
| 305 |
+
upload_video_disabled = videos_cnt > 0 or images_cnt > 0
|
| 306 |
+
return create_multimodal_input(upload_image_disabled, upload_video_disabled), _chat_bot, _app_cfg
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def fewshot_add_demonstration(_image, _user_message, _assistant_message, _chat_bot, _app_cfg):
|
| 310 |
+
ctx = _app_cfg["ctx"]
|
| 311 |
+
message_item = []
|
| 312 |
+
if _image is not None:
|
| 313 |
+
image = Image.open(_image).convert("RGB")
|
| 314 |
+
ctx.append({"role": "user", "content": [encode_image(image), make_text(_user_message)]})
|
| 315 |
+
message_item.append({"text": "[mm_media]1[/mm_media]" + _user_message, "files": [_image]})
|
| 316 |
+
else:
|
| 317 |
+
if _user_message:
|
| 318 |
+
ctx.append({"role": "user", "content": [make_text(_user_message)]})
|
| 319 |
+
message_item.append({"text": _user_message, "files": []})
|
| 320 |
+
else:
|
| 321 |
+
message_item.append(None)
|
| 322 |
+
if _assistant_message:
|
| 323 |
+
ctx.append({"role": "assistant", "content": [make_text(_assistant_message)]})
|
| 324 |
+
message_item.append({"text": _assistant_message, "files": []})
|
| 325 |
+
else:
|
| 326 |
+
message_item.append(None)
|
| 327 |
+
|
| 328 |
+
_chat_bot.append(message_item)
|
| 329 |
+
return None, "", "", _chat_bot, _app_cfg
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def fewshot_respond(_image, _user_message, _chat_bot, _app_cfg, params_form):
|
| 333 |
+
user_message_contents = []
|
| 334 |
+
_context = _app_cfg["ctx"].copy()
|
| 335 |
+
if _image:
|
| 336 |
+
image = Image.open(_image).convert("RGB")
|
| 337 |
+
user_message_contents += [encode_image(image)]
|
| 338 |
+
if _user_message:
|
| 339 |
+
user_message_contents += [make_text(_user_message)]
|
| 340 |
+
if user_message_contents:
|
| 341 |
+
_context.append({"role": "user", "content": user_message_contents})
|
| 342 |
+
|
| 343 |
+
if params_form == 'Beam Search':
|
| 344 |
+
params = {
|
| 345 |
+
'sampling': False,
|
| 346 |
+
'num_beams': 3,
|
| 347 |
+
'repetition_penalty': 1.2,
|
| 348 |
+
"max_new_tokens": 2048
|
| 349 |
+
}
|
| 350 |
+
else:
|
| 351 |
+
params = {
|
| 352 |
+
'sampling': True,
|
| 353 |
+
'top_p': 0.8,
|
| 354 |
+
'top_k': 100,
|
| 355 |
+
'temperature': 0.7,
|
| 356 |
+
'repetition_penalty': 1.05,
|
| 357 |
+
"max_new_tokens": 2048
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
code, _answer, _, sts = chat("", _context, None, params)
|
| 361 |
+
|
| 362 |
+
_context.append({"role": "assistant", "content": [make_text(_answer)]})
|
| 363 |
+
|
| 364 |
+
if _image:
|
| 365 |
+
_chat_bot.append([
|
| 366 |
+
{"text": "[mm_media]1[/mm_media]" + _user_message, "files": [_image]},
|
| 367 |
+
{"text": _answer, "files": []}
|
| 368 |
+
])
|
| 369 |
+
else:
|
| 370 |
+
_chat_bot.append([
|
| 371 |
+
{"text": _user_message, "files": [_image]},
|
| 372 |
+
{"text": _answer, "files": []}
|
| 373 |
+
])
|
| 374 |
+
if code == 0:
|
| 375 |
+
_app_cfg['ctx']=_context
|
| 376 |
+
_app_cfg['sts']=sts
|
| 377 |
+
return None, '', '', _chat_bot, _app_cfg
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def regenerate_button_clicked(_question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg, params_form):
|
| 381 |
+
if len(_chat_bot) <= 1 or not _chat_bot[-1][1]:
|
| 382 |
+
gr.Warning('No question for regeneration.')
|
| 383 |
+
return '', _image, _user_message, _assistant_message, _chat_bot, _app_cfg
|
| 384 |
+
if _app_cfg["chat_type"] == "Chat":
|
| 385 |
+
images_cnt = _app_cfg['images_cnt']
|
| 386 |
+
videos_cnt = _app_cfg['videos_cnt']
|
| 387 |
+
_question = _chat_bot[-1][0]
|
| 388 |
+
_chat_bot = _chat_bot[:-1]
|
| 389 |
+
_app_cfg['ctx'] = _app_cfg['ctx'][:-2]
|
| 390 |
+
files_cnts = check_has_videos(_question)
|
| 391 |
+
images_cnt -= files_cnts[0]
|
| 392 |
+
videos_cnt -= files_cnts[1]
|
| 393 |
+
_app_cfg['images_cnt'] = images_cnt
|
| 394 |
+
_app_cfg['videos_cnt'] = videos_cnt
|
| 395 |
+
upload_image_disabled = videos_cnt > 0
|
| 396 |
+
upload_video_disabled = videos_cnt > 0 or images_cnt > 0
|
| 397 |
+
_question, _chat_bot, _app_cfg = respond(_question, _chat_bot, _app_cfg, params_form)
|
| 398 |
+
return _question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg
|
| 399 |
+
else:
|
| 400 |
+
last_message = _chat_bot[-1][0]
|
| 401 |
+
last_image = None
|
| 402 |
+
last_user_message = ''
|
| 403 |
+
if last_message.text:
|
| 404 |
+
last_user_message = last_message.text
|
| 405 |
+
if last_message.files:
|
| 406 |
+
last_image = last_message.files[0].file.path
|
| 407 |
+
_chat_bot = _chat_bot[:-1]
|
| 408 |
+
_app_cfg['ctx'] = _app_cfg['ctx'][:-2]
|
| 409 |
+
_image, _user_message, _assistant_message, _chat_bot, _app_cfg = fewshot_respond(last_image, last_user_message, _chat_bot, _app_cfg, params_form)
|
| 410 |
+
return _question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def flushed():
|
| 414 |
+
return gr.update(interactive=True)
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def clear(txt_message, chat_bot, app_session):
|
| 418 |
+
txt_message.files.clear()
|
| 419 |
+
txt_message.text = ''
|
| 420 |
+
chat_bot = copy.deepcopy(init_conversation)
|
| 421 |
+
app_session['sts'] = None
|
| 422 |
+
app_session['ctx'] = []
|
| 423 |
+
app_session['images_cnt'] = 0
|
| 424 |
+
app_session['videos_cnt'] = 0
|
| 425 |
+
return create_multimodal_input(), chat_bot, app_session, None, '', ''
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def select_chat_type(_tab, _app_cfg):
|
| 429 |
+
_app_cfg["chat_type"] = _tab
|
| 430 |
+
return _app_cfg
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
init_conversation = [
|
| 434 |
+
[
|
| 435 |
+
None,
|
| 436 |
+
{
|
| 437 |
+
# The bot's first message disables the typewriter effect.
|
| 438 |
+
"text": "You can talk to me now",
|
| 439 |
+
"flushing": False
|
| 440 |
+
}
|
| 441 |
+
],
|
| 442 |
+
]
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
css = """
|
| 446 |
+
video { height: auto !important; }
|
| 447 |
+
.example label { font-size: 16px;}
|
| 448 |
+
"""
|
| 449 |
+
|
| 450 |
+
introduction = """
|
| 451 |
+
|
| 452 |
+
## Features:
|
| 453 |
+
1. Chat with single image
|
| 454 |
+
2. Chat with multiple images
|
| 455 |
+
3. Chat with video
|
| 456 |
+
4. In-context few-shot learning
|
| 457 |
+
|
| 458 |
+
Click `How to use` tab to see examples.
|
| 459 |
+
"""
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
with gr.Blocks(css=css) as demo:
|
| 463 |
+
with gr.Tab(model_name):
|
| 464 |
+
with gr.Row():
|
| 465 |
+
with gr.Column(scale=1, min_width=300):
|
| 466 |
+
gr.Markdown(value=introduction)
|
| 467 |
+
params_form = create_component(form_radio, comp='Radio')
|
| 468 |
+
regenerate = create_component({'value': 'Regenerate'}, comp='Button')
|
| 469 |
+
clear_button = create_component({'value': 'Clear History'}, comp='Button')
|
| 470 |
+
|
| 471 |
+
with gr.Column(scale=3, min_width=500):
|
| 472 |
+
app_session = gr.State({'sts':None,'ctx':[], 'images_cnt': 0, 'videos_cnt': 0, 'chat_type': 'Chat'})
|
| 473 |
+
chat_bot = mgr.Chatbot(label=f"Chat with {model_name}", value=copy.deepcopy(init_conversation), height=600, flushing=False, bubble_full_width=False)
|
| 474 |
+
|
| 475 |
+
with gr.Tab("Chat") as chat_tab:
|
| 476 |
+
txt_message = create_multimodal_input()
|
| 477 |
+
chat_tab_label = gr.Textbox(value="Chat", interactive=False, visible=False)
|
| 478 |
+
|
| 479 |
+
txt_message.submit(
|
| 480 |
+
respond,
|
| 481 |
+
[txt_message, chat_bot, app_session, params_form],
|
| 482 |
+
[txt_message, chat_bot, app_session]
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
with gr.Tab("Few Shot") as fewshot_tab:
|
| 486 |
+
fewshot_tab_label = gr.Textbox(value="Few Shot", interactive=False, visible=False)
|
| 487 |
+
with gr.Row():
|
| 488 |
+
with gr.Column(scale=1):
|
| 489 |
+
image_input = gr.Image(type="filepath", sources=["upload"])
|
| 490 |
+
with gr.Column(scale=3):
|
| 491 |
+
user_message = gr.Textbox(label="User")
|
| 492 |
+
assistant_message = gr.Textbox(label="Assistant")
|
| 493 |
+
with gr.Row():
|
| 494 |
+
add_demonstration_button = gr.Button("Add Example")
|
| 495 |
+
generate_button = gr.Button(value="Generate", variant="primary")
|
| 496 |
+
add_demonstration_button.click(
|
| 497 |
+
fewshot_add_demonstration,
|
| 498 |
+
[image_input, user_message, assistant_message, chat_bot, app_session],
|
| 499 |
+
[image_input, user_message, assistant_message, chat_bot, app_session]
|
| 500 |
+
)
|
| 501 |
+
generate_button.click(
|
| 502 |
+
fewshot_respond,
|
| 503 |
+
[image_input, user_message, chat_bot, app_session, params_form],
|
| 504 |
+
[image_input, user_message, assistant_message, chat_bot, app_session]
|
| 505 |
+
)
|
| 506 |
+
|
| 507 |
+
chat_tab.select(
|
| 508 |
+
select_chat_type,
|
| 509 |
+
[chat_tab_label, app_session],
|
| 510 |
+
[app_session]
|
| 511 |
+
)
|
| 512 |
+
chat_tab.select( # do clear
|
| 513 |
+
clear,
|
| 514 |
+
[txt_message, chat_bot, app_session],
|
| 515 |
+
[txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
|
| 516 |
+
)
|
| 517 |
+
fewshot_tab.select(
|
| 518 |
+
select_chat_type,
|
| 519 |
+
[fewshot_tab_label, app_session],
|
| 520 |
+
[app_session]
|
| 521 |
+
)
|
| 522 |
+
fewshot_tab.select( # do clear
|
| 523 |
+
clear,
|
| 524 |
+
[txt_message, chat_bot, app_session],
|
| 525 |
+
[txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
|
| 526 |
+
)
|
| 527 |
+
chat_bot.flushed(
|
| 528 |
+
flushed,
|
| 529 |
+
outputs=[txt_message]
|
| 530 |
+
)
|
| 531 |
+
regenerate.click(
|
| 532 |
+
regenerate_button_clicked,
|
| 533 |
+
[txt_message, image_input, user_message, assistant_message, chat_bot, app_session, params_form],
|
| 534 |
+
[txt_message, image_input, user_message, assistant_message, chat_bot, app_session]
|
| 535 |
+
)
|
| 536 |
+
clear_button.click(
|
| 537 |
+
clear,
|
| 538 |
+
[txt_message, chat_bot, app_session],
|
| 539 |
+
[txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
|
| 540 |
+
)
|
| 541 |
+
|
| 542 |
+
with gr.Tab("How to use"):
|
| 543 |
+
with gr.Column():
|
| 544 |
+
with gr.Row():
|
| 545 |
+
image_example = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/m_bear2.gif", label='1. Chat with single or multiple images', interactive=False, width=400, elem_classes="example")
|
| 546 |
+
example2 = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/video2.gif", label='2. Chat with video', interactive=False, width=400, elem_classes="example")
|
| 547 |
+
example3 = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/fshot.gif", label='3. Few shot', interactive=False, width=400, elem_classes="example")
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
# launch
|
| 551 |
+
demo.launch(share=False, debug=True, show_api=False, server_port=8000, server_name="0.0.0.0")
|
| 552 |
+
|
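encode_video above keeps roughly one frame per second (the stride equals the rounded average FPS) and, if that still exceeds MAX_NUM_FRAMES, thins the indices uniformly by taking the middle of each of 64 equal gaps. The same index arithmetic as a self-contained sketch, with a made-up frame count and frame rate:

MAX_NUM_FRAMES = 64

def sample_frame_indices(num_frames: int, avg_fps: float) -> list:
    stride = round(avg_fps)  # one index per second of video
    idxs = list(range(0, num_frames, stride))
    if len(idxs) > MAX_NUM_FRAMES:
        # uniform thinning: middle element of each of MAX_NUM_FRAMES equal gaps
        gap = len(idxs) / MAX_NUM_FRAMES
        idxs = [idxs[int(i * gap + gap / 2)] for i in range(MAX_NUM_FRAMES)]
    return idxs

# a 3-minute clip at 30 fps: 5400 frames -> 180 per-second indices -> 64 kept
print(len(sample_frame_indices(5400, 30.0)))  # 64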
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/model_server.py
ADDED
|
@@ -0,0 +1,936 @@
|
| 1 |
+
import base64
|
| 2 |
+
import json
|
| 3 |
+
import asyncio
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os, sys, io
|
| 6 |
+
import threading
|
| 7 |
+
import time
|
| 8 |
+
import aiofiles
|
| 9 |
+
import librosa
|
| 10 |
+
import soundfile
|
| 11 |
+
import wave
|
| 12 |
+
from typing import Dict, List, Any, Optional
|
| 13 |
+
import argparse
|
| 14 |
+
import logging
|
| 15 |
+
import torch
|
| 16 |
+
from PIL import Image
|
| 17 |
+
from transformers import AutoModel, AutoTokenizer, AutoProcessor
|
| 18 |
+
import uvicorn
|
| 19 |
+
from fastapi import FastAPI, Header, Query, Request, HTTPException, WebSocket, WebSocketDisconnect
|
| 20 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
| 21 |
+
|
| 22 |
+
cur_path = os.path.split(os.path.realpath(__file__))[0]
|
| 23 |
+
sys.path.append(os.path.abspath(cur_path))
|
| 24 |
+
import vad_utils
|
| 25 |
+
|
| 26 |
+
def setup_logger():
|
| 27 |
+
logger = logging.getLogger("api_logger")
|
| 28 |
+
logger.setLevel(logging.DEBUG)
|
| 29 |
+
|
| 30 |
+
# Create formatter
|
| 31 |
+
formatter = logging.Formatter(
|
| 32 |
+
'%(asctime)s.%(msecs)03d-%(levelname)s-[%(filename)s:%(lineno)d] - %(message)s',
|
| 33 |
+
datefmt='%Y-%m-%d %H:%M:%S'
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# Create handlers for stdout and stderr
|
| 37 |
+
stdout_handler = logging.StreamHandler(sys.stdout)
|
| 38 |
+
stdout_handler.setLevel(logging.INFO) # only INFO reaches stdout (DEBUG is blocked by the handler level)
|
| 39 |
+
stdout_handler.setFormatter(formatter)
|
| 40 |
+
stdout_handler.addFilter(lambda record: record.levelno <= logging.INFO)
|
| 41 |
+
|
| 42 |
+
stderr_handler = logging.StreamHandler(sys.stderr)
|
| 43 |
+
stderr_handler.setLevel(logging.WARNING) # WARNING, ERROR, CRITICAL go to stderr
|
| 44 |
+
stderr_handler.setFormatter(formatter)
|
| 45 |
+
|
| 46 |
+
# Add handlers to logger
|
| 47 |
+
logger.addHandler(stdout_handler)
|
| 48 |
+
logger.addHandler(stderr_handler)
|
| 49 |
+
|
| 50 |
+
return logger
|
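The handler setup above routes records by severity: the stdout handler's INFO level plus the levelno <= INFO filter let exactly INFO records through, while WARNING and above go to the stderr handler. A quick usage check against the setup_logger defined above (nothing is assumed beyond calling it):

logger = setup_logger()
logger.info("normal progress message")   # -> stdout
logger.warning("something looks off")    # -> stderr
logger.debug("dropped: below the stdout handler's INFO level")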
| 51 |
+
|
| 52 |
+
|
| 53 |
+
app = FastAPI()
|
| 54 |
+
logger = setup_logger()
|
| 55 |
+
|
| 56 |
+
ap = argparse.ArgumentParser()
|
| 57 |
+
ap.add_argument('--port', type=int , default=32550)
|
| 58 |
+
ap.add_argument('--model', type=str , default="openbmb/MiniCPM-o-2_6", help="huggingface model name or local path")
|
| 59 |
+
args = ap.parse_args()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class StreamManager:
|
| 63 |
+
def __init__(self):
|
| 64 |
+
self.uid = None
|
| 65 |
+
|
| 66 |
+
self.is_streaming_complete = threading.Event()
|
| 67 |
+
self.conversation_started = threading.Event()
|
| 68 |
+
self.last_request_time = None
|
| 69 |
+
self.last_stream_time = None
|
| 70 |
+
self.timeout = 900 # seconds timeout
|
| 71 |
+
self.stream_timeout = 3 # seconds no stream
|
| 72 |
+
self.num_stream = 0
|
| 73 |
+
self.stream_started = False
|
| 74 |
+
self.stop_response = False
|
| 75 |
+
|
| 76 |
+
# VAD settings
|
| 77 |
+
self.vad_options = vad_utils.VadOptions()
|
| 78 |
+
self.vad_sequence_length = 5
|
| 79 |
+
self.vad_sequence = []
|
| 80 |
+
self.audio_prefill = []
|
| 81 |
+
self.audio_input = []
|
| 82 |
+
self.image_prefill = None
|
| 83 |
+
self.audio_chunk = 200
|
| 84 |
+
|
| 85 |
+
# customized options
|
| 86 |
+
self.customized_audio = None
|
| 87 |
+
self.customized_options = None
|
| 88 |
+
|
| 89 |
+
# Omni model
|
| 90 |
+
self.target_dtype = torch.bfloat16
|
| 91 |
+
self.device='cuda:0'
|
| 92 |
+
|
| 93 |
+
self.minicpmo_model_path = args.model #"openbmb/MiniCPM-o-2_6"
|
| 94 |
+
self.model_version = "2.6"
|
| 95 |
+
with torch.no_grad():
|
| 96 |
+
self.minicpmo_model = AutoModel.from_pretrained(self.minicpmo_model_path, trust_remote_code=True, torch_dtype=self.target_dtype, attn_implementation='sdpa')
|
| 97 |
+
self.minicpmo_tokenizer = AutoTokenizer.from_pretrained(self.minicpmo_model_path, trust_remote_code=True)
|
| 98 |
+
self.minicpmo_model.init_tts()
|
| 99 |
+
# self.minicpmo_model.tts.float()
|
| 100 |
+
self.minicpmo_model.to(self.device).eval()
|
| 101 |
+
|
| 102 |
+
self.ref_path_video_default = "assets/ref_audios/video_default.wav"
|
| 103 |
+
self.ref_path_default = "assets/ref_audios/default.wav"
|
| 104 |
+
self.ref_path_female = "assets/ref_audios/female_example.wav"
|
| 105 |
+
self.ref_path_male = "assets/ref_audios/male_example.wav"
|
| 106 |
+
|
| 107 |
+
self.input_audio_id = 0
|
| 108 |
+
self.input_audio_vad_id = 0
|
| 109 |
+
self.input_image_id = 0
|
| 110 |
+
self.output_audio_id = 0
|
| 111 |
+
self.flag_decode = False
|
| 112 |
+
self.cnts = None
|
| 113 |
+
|
| 114 |
+
self.all_start_time = time.time()
|
| 115 |
+
self.session_id = 233
|
| 116 |
+
self.sys_prompt_flag = False
|
| 117 |
+
self.vad_time = 0
|
| 118 |
+
self.ls_time = 0
|
| 119 |
+
self.msg_type = 1
|
| 120 |
+
|
| 121 |
+
self.speaking_time_stamp = 0
|
| 122 |
+
self.cycle_wait_time = 12800/24000 + 0.15
|
| 123 |
+
self.extra_wait_time = 2.5
|
| 124 |
+
self.server_wait = True
|
| 125 |
+
|
| 126 |
+
self.past_session_id = 0
|
| 127 |
+
self.sys_prompt_init(0)
|
| 128 |
+
self.session_id += 1
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def start_conversation(self):
|
| 132 |
+
logger.info(f"uid {self.uid}: new conversation started.")
|
| 133 |
+
self.conversation_started.set()
|
| 134 |
+
self.stop_response = False
|
| 135 |
+
|
| 136 |
+
def update_last_request_time(self):
|
| 137 |
+
self.last_request_time = time.time()
|
| 138 |
+
#logger.info(f"update last_request_time {self.last_request_time}")
|
| 139 |
+
|
| 140 |
+
def update_last_stream_time(self):
|
| 141 |
+
self.last_stream_time = time.time()
|
| 142 |
+
#logger.info(f"update last_stream_time {self.last_stream_time}")
|
| 143 |
+
|
| 144 |
+
def move_to_device(self, obj, device):
|
| 145 |
+
if isinstance(obj, torch.Tensor):
|
| 146 |
+
obj_ = obj.to(device)
|
| 147 |
+
if (obj_.dtype == torch.float) or (obj_.dtype == torch.half):
|
| 148 |
+
# cast to `torch.bfloat16`
|
| 149 |
+
obj_ = obj_.to(self.target_dtype)
|
| 150 |
+
return obj_
|
| 151 |
+
elif isinstance(obj, dict):
|
| 152 |
+
return {key: self.move_to_device(value, device) for key, value in obj.items()}
|
| 153 |
+
elif isinstance(obj, list):
|
| 154 |
+
return [self.move_to_device(item, device) for item in obj]
|
| 155 |
+
elif isinstance(obj, tuple):
|
| 156 |
+
return tuple(self.move_to_device(item, device) for item in obj)
|
| 157 |
+
elif isinstance(obj, set):
|
| 158 |
+
return {self.move_to_device(item, device) for item in obj}
|
| 159 |
+
else:
|
| 160 |
+
return obj
|
| 161 |
+
|
| 162 |
+
def reset(self):
|
| 163 |
+
logger.info("reset")
|
| 164 |
+
self.is_streaming_complete.clear()
|
| 165 |
+
self.conversation_started.clear()
|
| 166 |
+
self.last_request_time = None
|
| 167 |
+
self.last_stream_time = None
|
| 168 |
+
self.audio_buffer_raw = bytearray()
|
| 169 |
+
self.num_stream = 0
|
| 170 |
+
self.stream_started = False
|
| 171 |
+
self.stop_response = False
|
| 172 |
+
# self.customized_audio = None
|
| 173 |
+
# self.customized_options = None
|
| 174 |
+
# clear model
|
| 175 |
+
self.clear()
|
| 176 |
+
|
| 177 |
+
def merge_wav_files(self, input_bytes_list, output_file):
|
| 178 |
+
with wave.open(io.BytesIO(input_bytes_list[0]), 'rb') as wav:
|
| 179 |
+
params = wav.getparams()
|
| 180 |
+
n_channels, sampwidth, framerate, n_frames, comptype, compname = params
|
| 181 |
+
|
| 182 |
+
with wave.open(output_file, 'wb') as output_wav:
|
| 183 |
+
output_wav.setnchannels(n_channels)
|
| 184 |
+
output_wav.setsampwidth(sampwidth)
|
| 185 |
+
output_wav.setframerate(framerate)
|
| 186 |
+
output_wav.setcomptype(comptype, compname)
|
| 187 |
+
|
| 188 |
+
for wav_bytes in input_bytes_list:
|
| 189 |
+
with wave.open(io.BytesIO(wav_bytes), 'rb') as wav:
|
| 190 |
+
output_wav.writeframes(wav.readframes(wav.getnframes()))
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def is_timed_out(self):
|
| 194 |
+
if self.last_request_time is not None:
|
| 195 |
+
return time.time() - self.last_request_time > self.timeout
|
| 196 |
+
return False
|
| 197 |
+
|
| 198 |
+
def no_active_stream(self):
|
| 199 |
+
if self.last_stream_time is not None and self.stream_started:
|
| 200 |
+
no_stream_duration = time.time() - self.last_stream_time
|
| 201 |
+
if no_stream_duration > self.stream_timeout:
|
| 202 |
+
#logger.info(f"no active stream for {no_stream_duration} secs.")
|
| 203 |
+
return True
|
| 204 |
+
return False
|
| 205 |
+
|
| 206 |
+
def sys_prompt_init(self, msg_type):
|
| 207 |
+
if self.past_session_id == self.session_id:
|
| 208 |
+
return
|
| 209 |
+
logger.info("### sys_prompt_init ###")
|
| 210 |
+
|
| 211 |
+
logger.info(f'msg_type is {msg_type}')
|
| 212 |
+
if msg_type <= 1: #audio
|
| 213 |
+
audio_voice_clone_prompt = "Use the voice in the audio prompt to synthesize new content."
|
| 214 |
+
audio_assistant_prompt = "You are a helpful assistant with the above voice style."
|
| 215 |
+
ref_path = self.ref_path_default
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
if self.customized_options is not None:
|
| 219 |
+
audio_voice_clone_prompt = self.customized_options['voice_clone_prompt']
|
| 220 |
+
audio_assistant_prompt = self.customized_options['assistant_prompt']
|
| 221 |
+
if self.customized_options['use_audio_prompt'] == 1:
|
| 222 |
+
ref_path = self.ref_path_default
|
| 223 |
+
elif self.customized_options['use_audio_prompt'] == 2:
|
| 224 |
+
ref_path = self.ref_path_female
|
| 225 |
+
elif self.customized_options['use_audio_prompt'] == 3:
|
| 226 |
+
ref_path = self.ref_path_male
|
| 227 |
+
|
| 228 |
+
audio_prompt, sr = librosa.load(ref_path, sr=16000, mono=True)
|
| 229 |
+
sys_msg = {'role': 'user', 'content': [audio_voice_clone_prompt + "\n", audio_prompt, "\n" + audio_assistant_prompt]}
|
| 230 |
+
elif msg_type == 2: #video
|
| 231 |
+
voice_clone_prompt="你是一个AI助手。你能接受视频,音频和文本输入并输出语音和文本。模仿输入音频中的声音特征。"
|
| 232 |
+
assistant_prompt="作为助手,你将使用这种声音风格说话。"
|
| 233 |
+
ref_path = self.ref_path_video_default
|
| 234 |
+
|
| 235 |
+
if self.customized_options is not None:
|
| 236 |
+
voice_clone_prompt = self.customized_options['voice_clone_prompt']
|
| 237 |
+
assistant_prompt = self.customized_options['assistant_prompt']
|
| 238 |
+
if self.customized_options['use_audio_prompt'] == 1:
|
| 239 |
+
ref_path = self.ref_path_default
|
| 240 |
+
elif self.customized_options['use_audio_prompt'] == 2:
|
| 241 |
+
ref_path = self.ref_path_female
|
| 242 |
+
elif self.customized_options['use_audio_prompt'] == 3:
|
| 243 |
+
ref_path = self.ref_path_male
|
| 244 |
+
|
| 245 |
+
audio_prompt, sr = librosa.load(ref_path, sr=16000, mono=True)
|
| 246 |
+
sys_msg = {'role': 'user', 'content': [voice_clone_prompt, audio_prompt, assistant_prompt]}
|
| 247 |
+
# elif msg_type == 3: #user start
|
| 248 |
+
# assistant_prompt="作为助手,你将使用这种声音风格说话。"
|
| 249 |
+
# if self.customized_options is not None:
|
| 250 |
+
# assistant_prompt = self.customized_options['assistant_prompt']
|
| 251 |
+
|
| 252 |
+
# sys_msg = {'role': 'user', 'content': [assistant_prompt]}
|
| 253 |
+
|
| 254 |
+
self.msg_type = msg_type
|
| 255 |
+
msgs = [sys_msg]
|
| 256 |
+
if self.customized_options is not None:
|
| 257 |
+
if self.customized_options['use_audio_prompt'] > 0:
|
| 258 |
+
self.minicpmo_model.streaming_prefill(
|
| 259 |
+
session_id=str(self.session_id),
|
| 260 |
+
msgs=msgs,
|
| 261 |
+
tokenizer=self.minicpmo_tokenizer,
|
| 262 |
+
)
|
| 263 |
+
if msg_type == 0:
|
| 264 |
+
self.minicpmo_model.streaming_prefill(
|
| 265 |
+
session_id=str(self.session_id),
|
| 266 |
+
msgs=msgs,
|
| 267 |
+
tokenizer=self.minicpmo_tokenizer,
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
self.savedir = os.path.join(f"./log_data/{args.port}/", str(time.time()))
|
| 271 |
+
if not os.path.exists(self.savedir):
|
| 272 |
+
os.makedirs(self.savedir)
|
| 273 |
+
if not os.path.exists(self.savedir + "/input_audio_log"):
|
| 274 |
+
os.makedirs(self.savedir + "/input_audio_log")
|
| 275 |
+
if not os.path.exists(self.savedir + "/input_audio_vad_log"):
|
| 276 |
+
os.makedirs(self.savedir + "/input_audio_vad_log")
|
| 277 |
+
if not os.path.exists(self.savedir + "/input_image_log"):
|
| 278 |
+
os.makedirs(self.savedir + "/input_image_log")
|
| 279 |
+
if not os.path.exists(self.savedir + "/output_audio_log"):
|
| 280 |
+
os.makedirs(self.savedir + "/output_audio_log")
|
| 281 |
+
if not os.path.exists(self.savedir + "/feedback_log"):
|
| 282 |
+
os.makedirs(self.savedir + "/feedback_log")
|
| 283 |
+
if not os.path.exists(self.savedir + "/input_audio"):
|
| 284 |
+
os.makedirs(self.savedir + "/input_audio")
|
| 285 |
+
|
| 286 |
+
self.past_session_id = self.session_id
|
| 287 |
+
self.audio_prefill = []
|
| 288 |
+
self.audio_input = []
|
| 289 |
+
|
| 290 |
+
def clear(self):
|
| 291 |
+
try:
|
| 292 |
+
self.flag_decode = False
|
| 293 |
+
self.stream_started = False
|
| 294 |
+
self.cnts = None
|
| 295 |
+
self.vad_sequence = []
|
| 296 |
+
self.audio_prefill = []
|
| 297 |
+
self.audio_input = []
|
| 298 |
+
self.image_prefill = None
|
| 299 |
+
|
| 300 |
+
if self.minicpmo_model.llm_past_key_values[0][0].shape[2]>8192:
|
| 301 |
+
self.session_id += 1 # to clear all kv cache
|
| 302 |
+
self.sys_prompt_flag = False
|
| 303 |
+
|
| 304 |
+
self.vad_time = 0
|
| 305 |
+
self.ls_time = 0
|
| 306 |
+
self.msg_type = 1
|
| 307 |
+
|
| 308 |
+
except Exception as e:
|
| 309 |
+
raise ValueError(f"Clear error: {str(e)}")
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
    def process_message(self, message: Dict[str, Any]):
        try:
            # Process content items
            audio_data = None
            image_data = None
            for content_item in message["content"]:
                if content_item["type"] == "stop_response":
                    logger.info("process_message: received request to stop_response")
                    self.stop_response = True
                    return "stop"
                elif content_item["type"] == "input_audio":
                    audio_data = content_item["input_audio"]["data"]
                    audio_timestamp = content_item["input_audio"].get("timestamp", "")
                elif content_item["type"] == "image_data":
                    image_data = content_item["image_data"]["data"]
            if audio_data is None:
                return "empty audio"

            if self.conversation_started.is_set() and self.is_streaming_complete.is_set():
                logger.info("conversation not started or still in generation, skip stream message.")
                return "skip"

            if self.flag_decode:
                return "skip"

            try:
                audio_bytes = base64.b64decode(audio_data)

                image = None
                if image_data is not None:
                    if len(image_data) > 0:
                        image_bytes = base64.b64decode(image_data)
                        image_buffer = io.BytesIO(image_bytes)
                        image_buffer.seek(0)
                        image = Image.open(image_buffer)
                        # logger.info("read image")

                if self.sys_prompt_flag is False:
                    self.all_start_time = time.time()
                    self.sys_prompt_flag = True
                    if image_data is not None:
                        self.sys_prompt_init(2)
                    else:
                        self.sys_prompt_init(1)

                self.prefill(audio_bytes, image, False)

                self.vad_sequence.append(audio_bytes)
                if len(self.vad_sequence) < self.vad_sequence_length:
                    # logger.info('length of vad_sequence is {}, insufficient'.format(self.vad_sequence_length))
                    return "done"
                elif len(self.vad_sequence) > self.vad_sequence_length:
                    # logger.info('length of vad_sequence exceeds {}'.format(self.vad_sequence_length))
                    self.vad_sequence.pop(0)
                self.vad_check_audio_bytes(audio_bytes, image, 16000)

                return "done"

            except Exception as e:
                raise ValueError(f"Audio processing error: {str(e)}")

        except Exception as e:
            raise ValueError(f"Message processing error: {str(e)}")

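    # A minimal sketch of the message dict that process_message() expects, inferred from
    # the parsing above; the concrete values are hypothetical placeholders:
    #
    #   {
    #       "role": "user",
    #       "content": [
    #           {"type": "input_audio",
    #            "input_audio": {"data": "<base64 16 kHz mono PCM chunk>", "timestamp": ""}},
    #           {"type": "image_data",
    #            "image_data": {"data": "<base64-encoded image>"}},
    #       ],
    #   }
    #
    # A {"type": "stop_response"} item short-circuits the turn and returns "stop".
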
    def resample_audio(self, input_path, src_sr, tar_sr, output_path):
        audio_data, _ = librosa.load(input_path, sr=src_sr)
        audio_new = librosa.resample(audio_data, orig_sr=src_sr, target_sr=tar_sr)
        soundfile.write(output_path, audio_new, tar_sr)

    def calculate_rms(self, input_path, sr):
        audio_data, _ = librosa.load(input_path, sr=sr)
        return (np.sqrt(np.mean(audio_data**2)) > 0.002)

    def vad_check_audio_bytes(self, audio, image, sr):
        try:
            input_audio_vad_path = self.savedir + f"/input_audio_vad_log/vad_{self.input_audio_vad_id}.wav"
            self.input_audio_vad_id += 1
            self.merge_wav_files(self.vad_sequence, input_audio_vad_path)

            with open(input_audio_vad_path, "rb") as f:
                temp_audio = f.read()
            dur_vad, vad_audio_bytes, time_vad = vad_utils.run_vad(temp_audio, sr, self.vad_options)
            if self.customized_options is not None:
                vad_threshold = 1 - self.customized_options['vad_threshold']
            else:
                vad_threshold = 0.2

            if self.calculate_rms(input_audio_vad_path, sr) and dur_vad > 0.4:
                if not self.stream_started:
                    self.vad_time = time.time()
                self.stream_started = True
            elif dur_vad < vad_threshold:
                if self.stream_started:
                    self.stream_started = False
                    if (time.time() - self.vad_time >= 0.6):
                        self.prefill(audio, image, True)
                        self.is_streaming_complete.set()
                        # self.ls_time = time.time()

        except Exception as e:
            logger.error(f"VAD error: {e}")
            raise
        return

    def prefill(self, audio, image, is_end):
        if self.server_wait:
            now = time.time()
            await_time = self.speaking_time_stamp - now + self.extra_wait_time
            if await_time > 0:
                return False

        if self.flag_decode:
            return False

        if image is not None:
            self.image_prefill = image
        try:
            if not is_end:
                self.audio_prefill.append(audio)
                self.audio_input.append(audio)
            slice_nums = 1
            if is_end and self.customized_options is not None:
                if self.customized_options['hd_video']:
                    slice_nums = 6
                else:
                    return True
            if (len(self.audio_prefill) == (1000 / self.audio_chunk)) or (is_end and len(self.audio_prefill) > 0):
                time_prefill = time.time()
                input_audio_path = self.savedir + f"/input_audio_log/input_audio_{self.input_audio_id}.wav"
                self.merge_wav_files(self.audio_prefill, input_audio_path)
                with open(input_audio_path, "rb") as wav_io:
                    signal, sr = soundfile.read(wav_io, dtype='float32')
                soundfile.write(input_audio_path, signal, 16000)
                audio_np, sr = librosa.load(input_audio_path, sr=16000, mono=True)
                self.audio_prefill = []

                if len(audio_np) > 16000:
                    audio_np = audio_np[:16000]

                with torch.no_grad():
                    if self.image_prefill is not None:
                        input_image_path = self.savedir + f'/input_image_log/input_image_{self.input_audio_id}.png'
                        self.image_prefill.save(input_image_path, 'PNG')
                        self.image_prefill = self.image_prefill.convert("RGB")

                    cnts = None
                    if self.image_prefill is not None:
                        cnts = ["<unit>", self.image_prefill, audio_np]
                    else:
                        cnts = [audio_np]

                    if cnts is not None:
                        msg = {"role": "user", "content": cnts}
                        msgs = [msg]
                        res = self.minicpmo_model.streaming_prefill(
                            session_id=str(self.session_id),
                            msgs=msgs,
                            tokenizer=self.minicpmo_tokenizer,
                            max_slice_nums=slice_nums,
                        )

                self.input_audio_id += 1
            return True

        except Exception as e:
            logger.error(f"prefill error: {e}")
            import traceback
            traceback.print_exc()
            raise

    def generate_end(self):
        self.input_audio_id += 10
        self.output_audio_id += 10
        self.flag_decode = False
        self.reset()
        return

    async def generate(self):
        """ return audio bytes and response text (optional) """
        if self.stop_response:
            self.generate_end()
            return

        self.flag_decode = True
        try:
            with torch.no_grad():
                logger.info("=== model gen start ===")
                time_gen = time.time()
                input_audio_path = self.savedir + f"/input_audio/all_input_audio_{self.input_audio_id}.wav"
                self.merge_wav_files(self.audio_input, input_audio_path)
                audio_stream = None
                try:
                    with open(input_audio_path, 'rb') as wav_file:
                        audio_stream = wav_file.read()
                except FileNotFoundError:
                    print(f"File {input_audio_path} not found.")
                if audio_stream is not None:  # guard against the FileNotFoundError case above
                    yield base64.b64encode(audio_stream).decode('utf-8'), "assistant:\n"

                print('=== gen start: ', time.time() - time_gen)
                first_time = True
                temp_time = time.time()
                temp_time1 = time.time()
                with torch.inference_mode():
                    if self.stop_response:
                        self.generate_end()
                        return
                    self.minicpmo_model.config.stream_input = True
                    msg = {"role": "user", "content": self.cnts}
                    msgs = [msg]
                    text = ''
                    self.speaking_time_stamp = time.time()
                    try:
                        for r in self.minicpmo_model.streaming_generate(
                            session_id=str(self.session_id),
                            tokenizer=self.minicpmo_tokenizer,
                            generate_audio=True,
                            # enable_regenerate=True,
                        ):
                            if self.stop_response:
                                self.generate_end()
                                return
                            audio_np, sr, text = r["audio_wav"], r["sampling_rate"], r["text"]

                            output_audio_path = self.savedir + f'/output_audio_log/output_audio_{self.output_audio_id}.wav'
                            self.output_audio_id += 1
                            soundfile.write(output_audio_path, audio_np, samplerate=sr)
                            audio_stream = None
                            try:
                                with open(output_audio_path, 'rb') as wav_file:
                                    audio_stream = wav_file.read()
                            except FileNotFoundError:
                                print(f"File {output_audio_path} not found.")
                            temp_time1 = time.time()
                            print('text: ', text)
                            if audio_stream is not None:  # guard against the FileNotFoundError case above
                                yield base64.b64encode(audio_stream).decode('utf-8'), text
                            self.speaking_time_stamp += self.cycle_wait_time
                    except Exception as e:
                        logger.error(f"Error happened during generation: {str(e)}")
                        yield None, '\n<end>'

        except Exception as e:
            logger.error(f"Exception occurred: {e}")
            import traceback
            traceback.print_exc()
            raise

        finally:
            logger.info(f"uid {self.uid}: generation finished!")
            self.generate_end()

    async def check_activity(self):
        while True:
            # Check for overall inactivity (30 minutes)
            if self.is_timed_out():
                self.reset()
            if self.no_active_stream() and not self.is_streaming_complete.is_set():
                self.is_streaming_complete.set()

            await asyncio.sleep(1)  # Check every second

    def upload_customized_audio(self, audio_data, audio_fmt):
        self.customized_audio = None
        try:
            if audio_data is not None and len(audio_data) > 0:
                # if audio_fmt == "mp3" or audio_fmt == "wav":
                audio_bytes = base64.b64decode(audio_data)
                fio = io.BytesIO(audio_bytes)
                fio.seek(0)
                audio_np, sr = librosa.load(fio, sr=16000, mono=True)
                if audio_np is not None and len(audio_np) > 1000:
                    output_audio_path = self.savedir + '/customized_audio.wav'
                    soundfile.write(output_audio_path, audio_np, sr)
                    self.customized_audio = output_audio_path
                    logger.info(f"processed customized {audio_fmt} audio")
                    print(audio_np.shape, type(audio_np), sr)
                else:
                    logger.info("empty customized audio, use default value instead.")
                    self.customized_audio = None
        except Exception as e:
            raise ValueError(f"Process customized audio error: {str(e)}")

    def update_customized_options(self, uid, options):
        self.customized_options = None
        if options is None:
            raise ValueError("Invalid None type for options, expected dict type")
        self.customized_options = options
        logger.info(f"uid: {uid} set customized_options to {options}")


stream_manager = StreamManager()


@app.on_event("startup")
async def startup_event():
    logger.info("Starting application and activity checker")
    asyncio.create_task(stream_manager.check_activity())

@app.on_event("shutdown")
async def shutdown_event():
    logger.info("Shutting down application")

@app.post("/stream")
@app.post("/api/v1/stream")
async def stream(request: Request, uid: Optional[str] = Header(None)):
    global stream_manager

    stream_manager.update_last_request_time()
    stream_manager.update_last_stream_time()

    if not uid:
        raise HTTPException(status_code=400, detail="Missing uid in headers")
    if stream_manager.uid is not None and stream_manager.uid != uid:
        logger.error(f"uid changed during stream: previous uid {stream_manager.uid}, new uid {uid}")
        raise HTTPException(status_code=400, detail="uid changed in stream")

    try:
        # Parse JSON request
        data = await request.json()

        # Validate basic structure
        if not isinstance(data, dict) or "messages" not in data:
            raise HTTPException(status_code=400, detail="Invalid request format")

        # Process messages
        reason = ""
        for message in data["messages"]:
            if not isinstance(message, dict) or "role" not in message or "content" not in message:
                raise HTTPException(status_code=400, detail="Invalid message format")
            reason = stream_manager.process_message(message)

        # Return response using uid from header
        response = {
            "id": uid,
            "choices": {
                "role": "assistant",
                "content": "success",
                "finish_reason": reason
            }
        }
        return JSONResponse(content=response, status_code=200)

    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid JSON")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

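# For reference, a hypothetical /api/v1/stream request body, assembled from the
# validation above and the content types handled by process_message() (all values
# are placeholders; the caller must also send a "uid" header):
#
#   {"messages": [{"role": "user",
#                  "content": [{"type": "input_audio",
#                               "input_audio": {"data": "<base64 audio chunk>",
#                                               "timestamp": ""}}]}]}
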
@app.websocket("/ws/stream")
@app.websocket("/ws/api/v1/stream")
async def websocket_stream(websocket: WebSocket,
                           uid: Optional[str] = Query(None)):
    global stream_manager

    if not uid:
        # 1008 = policy violation; WebSocket close codes must be >= 1000 (400 is not a valid close code)
        await websocket.close(code=1008, reason="Missing uid in request")
        return

    # Accept the WebSocket connection
    await websocket.accept()

    # if stream_manager.uid is not None and stream_manager.uid != uid:
    #     logger.error(f"uid changed during stream: previous uid {stream_manager.uid}, new uid {uid}")
    #     await websocket.close(code=1008, reason="Uid changed in stream.")
    #     return

    try:
        while True:
            # Continuously listen for incoming messages from the client
            data = await websocket.receive_text()

            # Parse JSON request
            try:
                request_data = json.loads(data)
            except json.JSONDecodeError:
                await websocket.send_json({"error": "Invalid JSON"})
                continue

            stream_manager.update_last_request_time()
            stream_manager.update_last_stream_time()

            if stream_manager.uid is not None and stream_manager.uid != uid:
                logger.error(f"uid changed during stream: previous uid {stream_manager.uid}, new uid {uid}")
                await websocket.send_json({"error": "UID changed in stream"})
                continue

            # Validate basic structure
            if not isinstance(request_data, dict) or "messages" not in request_data:
                await websocket.send_json({"error": "Invalid request format"})
                continue

            # Process messages
            try:
                reason = ""
                for message in request_data["messages"]:
                    if not isinstance(message, dict) or "role" not in message or "content" not in message:
                        await websocket.send_json({"error": "Invalid message format"})
                        continue
                    reason = stream_manager.process_message(message)

                # Respond with success message
                response = {
                    "id": uid,
                    "choices": {
                        "role": "assistant",
                        "content": "success",
                        "finish_reason": reason,
                    },
                }
                await websocket.send_json(response)
            except WebSocketDisconnect:
                # Handle WebSocket disconnection
                break
            except Exception as e:
                logger.error(f"process message error: {str(e)}")
                await websocket.close(code=1011, reason=f"Internal server error: {str(e)}")

    except WebSocketDisconnect:
        # Handle WebSocket disconnection
        return
    except Exception as e:
        logger.error(f"ws_stream error: {str(e)}")
        await websocket.close(code=1011, reason=f"Unexpected error: {str(e)}")


async def generate_sse_response(request: Request, uid: Optional[str] = Header(None)):
    global stream_manager
    print(f"uid: {uid}")
    try:
        # Wait for streaming to complete or timeout
        while not stream_manager.is_streaming_complete.is_set():
            # if stream_manager.is_timed_out():
            #     yield f"data: {json.dumps({'error': 'Stream timeout'})}\n\n"
            #     return
            # print(f"{uid} while not stream_manager.is_streaming_complete.is_set(), asyncio.sleep(0.1)")
            await asyncio.sleep(0.1)

        logger.info("streaming complete\n")
        # Generate response
        try:
            yield "event: message\n"
            async for audio, text in stream_manager.generate():
                if text == "stop":
                    break
                res = {
                    "id": stream_manager.uid,
                    "response_id": stream_manager.output_audio_id,
                    "choices": [
                        {
                            "role": "assistant",
                            "audio": audio,
                            "text": text,
                            "finish_reason": "processing"
                        }
                    ]
                }
                # logger.info("generate_sse_response yield response")
                yield f"data: {json.dumps(res)}\n\n"
                await asyncio.sleep(0)

        except Exception as e:
            logger.error(f"Error while generation: {str(e)}")
            yield f'data:{{"error": "{str(e)}"}}\n\n'  # fixed: was str(exc), an undefined name
    except Exception as e:
        yield f'data:{{"error": "{str(e)}"}}\n\n'

@app.post("/completions")
|
| 777 |
+
@app.post("/api/v1/completions")
|
| 778 |
+
async def completions(request: Request, uid: Optional[str] = Header(None)):
|
| 779 |
+
global stream_manager
|
| 780 |
+
|
| 781 |
+
if not uid:
|
| 782 |
+
raise HTTPException(status_code=400, detail="Missing uid in headers")
|
| 783 |
+
|
| 784 |
+
try:
|
| 785 |
+
# if stream_manager.uid is not None and stream_manager.uid != uid:
|
| 786 |
+
if stream_manager.uid != uid:
|
| 787 |
+
# stream_manager.stop_response = True
|
| 788 |
+
# logger.info(f"uid changed, reset model: previous uid {stream_manager.uid}, new uid {uid}")
|
| 789 |
+
stream_manager.session_id += 1
|
| 790 |
+
stream_manager.sys_prompt_flag = False
|
| 791 |
+
stream_manager.reset()
|
| 792 |
+
|
| 793 |
+
# raise HTTPException(
|
| 794 |
+
# status_code=409,
|
| 795 |
+
# detail="User id changed, reset context."
|
| 796 |
+
# )
|
| 797 |
+
stream_manager.speaking_time_stamp = 0
|
| 798 |
+
stream_manager.update_last_request_time()
|
| 799 |
+
stream_manager.uid = uid
|
| 800 |
+
stream_manager.start_conversation()
|
| 801 |
+
|
| 802 |
+
data = await request.json()
|
| 803 |
+
|
| 804 |
+
return StreamingResponse(
|
| 805 |
+
generate_sse_response(request, uid),
|
| 806 |
+
media_type="text/event-stream",
|
| 807 |
+
headers={
|
| 808 |
+
"Cache-Control": "no-cache",
|
| 809 |
+
"Connection": "keep-alive",
|
| 810 |
+
"Transfer-Encoding": "chunked"
|
| 811 |
+
}
|
| 812 |
+
)
|
| 813 |
+
except asyncio.TimeoutError:
|
| 814 |
+
raise HTTPException(
|
| 815 |
+
status_code=503,
|
| 816 |
+
detail="Server busy, please try again later"
|
| 817 |
+
)
|
| 818 |
+
except Exception as e:
|
| 819 |
+
logger.error(f"Error processing request for user {uid}: {str(e)}")
|
| 820 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 821 |
+
|
| 822 |
+
|
| 823 |
+
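# The SSE stream produced above emits one "data:" line per generated audio/text chunk.
# A hypothetical wire-level excerpt (field names match generate_sse_response; the
# values are illustrative only):
#
#   event: message
#   data: {"id": "user-1", "response_id": 42, "choices": [{"role": "assistant",
#          "audio": "<base64 wav>", "text": "partial text", "finish_reason": "processing"}]}
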
@app.post("/stop")
|
| 824 |
+
@app.post("/api/v1/stop")
|
| 825 |
+
async def stop_response(request: Request, uid: Optional[str] = Header(None)):
|
| 826 |
+
if not uid:
|
| 827 |
+
raise HTTPException(status_code=400, detail="Missing uid in headers")
|
| 828 |
+
|
| 829 |
+
global stream_manager
|
| 830 |
+
# stream_manager.session_id += 1
|
| 831 |
+
logger.info(f"uid {uid}: received stop_response")
|
| 832 |
+
stream_manager.stop_response = True
|
| 833 |
+
response = {
|
| 834 |
+
"id": uid,
|
| 835 |
+
"choices": {
|
| 836 |
+
"role": "assistant",
|
| 837 |
+
"content": "success",
|
| 838 |
+
"finish_reason": "stop"
|
| 839 |
+
}
|
| 840 |
+
}
|
| 841 |
+
return JSONResponse(content=response, status_code=200)
|
| 842 |
+
|
| 843 |
+
@app.post("/feedback")
|
| 844 |
+
@app.post("/api/v1/feedback")
|
| 845 |
+
async def feedback(request: Request, uid: Optional[str] = Header(None)):
|
| 846 |
+
global stream_manager
|
| 847 |
+
|
| 848 |
+
# Validate the 'uid' header
|
| 849 |
+
if not uid:
|
| 850 |
+
raise HTTPException(status_code=400, detail="Missing 'uid' header")
|
| 851 |
+
|
| 852 |
+
try:
|
| 853 |
+
data = await request.json()
|
| 854 |
+
if "response_id" not in data or "rating" not in data:
|
| 855 |
+
raise HTTPException(status_code=400, detail="Invalid request: must have response_id and rating")
|
| 856 |
+
response_id = data.get("response_id", "")
|
| 857 |
+
rating = data.get("rating", "")
|
| 858 |
+
comment = data.get("comment", "")
|
| 859 |
+
# Validate the rating
|
| 860 |
+
if rating not in ["like", "dislike"]:
|
| 861 |
+
raise HTTPException(status_code=400, detail=f"Invalid rating value: {rating}")
|
| 862 |
+
|
| 863 |
+
# Define the log file path
|
| 864 |
+
log_file_path = f"{stream_manager.savedir}/feedback_log/{response_id}.{rating}"
|
| 865 |
+
# Write the feedback to the file asynchronously
|
| 866 |
+
async with aiofiles.open(log_file_path, mode="a") as file:
|
| 867 |
+
await file.write(f"model: {stream_manager.minicpmo_model_path}\nuid {uid}: {comment}\n")
|
| 868 |
+
response = {
|
| 869 |
+
"id": uid,
|
| 870 |
+
"choices": {
|
| 871 |
+
"role": "assistant",
|
| 872 |
+
"content": "success",
|
| 873 |
+
"finish_reason": "done"
|
| 874 |
+
}
|
| 875 |
+
}
|
| 876 |
+
return JSONResponse(content=response, status_code=200)
|
| 877 |
+
except Exception as e:
|
| 878 |
+
logger.error(f"Error processing feedback for user {uid}: {str(e)}")
|
| 879 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 880 |
+
|
| 881 |
+
|
| 882 |
+
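# A hypothetical /api/v1/feedback request body, matching the validation above
# ("rating" must be "like" or "dislike"; "comment" is optional):
#
#   {"response_id": 42, "rating": "like", "comment": "clear voice, fast response"}
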
@app.post("/init_options")
|
| 883 |
+
@app.post("/api/v1/init_options")
|
| 884 |
+
async def init_options(request: Request, uid: Optional[str] = Header(None)):
|
| 885 |
+
global stream_manager
|
| 886 |
+
|
| 887 |
+
stream_manager.update_last_request_time()
|
| 888 |
+
|
| 889 |
+
if not uid:
|
| 890 |
+
raise HTTPException(status_code=400, detail="Missing uid in headers")
|
| 891 |
+
try:
|
| 892 |
+
# Parse JSON request
|
| 893 |
+
data = await request.json()
|
| 894 |
+
|
| 895 |
+
# Validate basic structure
|
| 896 |
+
if not isinstance(data, dict) or "messages" not in data:
|
| 897 |
+
raise HTTPException(status_code=400, detail="Invalid request format")
|
| 898 |
+
|
| 899 |
+
messages = data.get("messages", [])
|
| 900 |
+
for message in messages:
|
| 901 |
+
if not isinstance(message, dict) or "role" not in message or "content" not in message:
|
| 902 |
+
raise HTTPException(status_code=400, detail="Invalid message format")
|
| 903 |
+
|
| 904 |
+
for content in message.get("content", []):
|
| 905 |
+
if content["type"] == "input_audio":
|
| 906 |
+
audio_data = content["input_audio"].get("data", "")
|
| 907 |
+
audio_fmt = content["input_audio"].get("format", "")
|
| 908 |
+
stream_manager.upload_customized_audio(audio_data, audio_fmt)
|
| 909 |
+
elif content["type"] == "options":
|
| 910 |
+
stream_manager.update_customized_options(uid, content["options"])
|
| 911 |
+
else:
|
| 912 |
+
ctype = content["type"]
|
| 913 |
+
raise HTTPException(status_code=400, detail=f"Invalid content type: {ctype}")
|
| 914 |
+
version = stream_manager.model_version
|
| 915 |
+
print(version)
|
| 916 |
+
response = {
|
| 917 |
+
"id": uid,
|
| 918 |
+
"choices": {
|
| 919 |
+
"role": "assistant",
|
| 920 |
+
"content": version,
|
| 921 |
+
"finish_reason": "done"
|
| 922 |
+
}
|
| 923 |
+
}
|
| 924 |
+
return JSONResponse(content=response, status_code=200)
|
| 925 |
+
except Exception as e:
|
| 926 |
+
raise HTTPException(status_code=400, detail=f"init options error: {str(e)}")
|
| 927 |
+
|
| 928 |
+
|
| 929 |
+
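# A hypothetical /api/v1/init_options request body. The option keys are the ones read
# from customized_options elsewhere in this file; the values here are placeholders:
#
#   {"messages": [{"role": "user",
#                  "content": [{"type": "options",
#                               "options": {"voice_clone_prompt": "...",
#                                           "assistant_prompt": "...",
#                                           "use_audio_prompt": 1,
#                                           "vad_threshold": 0.8,
#                                           "hd_video": false}}]}]}
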
@app.get('/health')
@app.get('/api/v1/health')
async def health_check():
    return {"status": "OK"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=args.port)
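For reference, a minimal client sketch (not part of the repository) that consumes the SSE stream served by /api/v1/completions above. The host, port, and uid are hypothetical, and any streaming HTTP client would work:

import json
import requests

with requests.post("http://localhost:32550/api/v1/completions",
                   headers={"uid": "demo-user"}, json={}, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            chunk = json.loads(line[len("data: "):])
            print(chunk["choices"][0]["text"])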
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/vad_utils.py
ADDED
@@ -0,0 +1,301 @@
import functools
import numpy as np
import librosa
import os
import time
import traceback
import warnings  # needed for warnings.warn() in get_speech_timestamps below

from typing import List, NamedTuple, Optional

class VadOptions(NamedTuple):
    """VAD options.

    Attributes:
      threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk;
        probabilities ABOVE this value are considered SPEECH. It is better to tune this
        parameter for each dataset separately, but a "lazy" 0.5 is pretty good for most datasets.
      min_speech_duration_ms: Final speech chunks shorter than min_speech_duration_ms are thrown out.
      max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
        than max_speech_duration_s will be split at the timestamp of the last silence that
        lasts more than 100 ms (if any), to prevent aggressive cutting. Otherwise, they will be
        split aggressively just before max_speech_duration_s.
      min_silence_duration_ms: At the end of each speech chunk, wait for min_silence_duration_ms
        before separating it.
      window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for a 16000 sample rate.
        Values other than these may affect model performance!
      speech_pad_ms: Final speech chunks are padded by speech_pad_ms on each side.
    """

    # threshold: float = 0.3  # rep 0.5
    # min_speech_duration_ms: int = 250
    # max_speech_duration_s: float = float("inf")
    # min_silence_duration_ms: int = 2000
    # window_size_samples: int = 1024
    # speech_pad_ms: int = 600  # rep 400

    threshold: float = 0.7  # gw: 0.3 # rep 0.5
    min_speech_duration_ms: int = 128  # original & gw: 250
    max_speech_duration_s: float = float("inf")
    min_silence_duration_ms: int = 500  # original & gw: 2000
    window_size_samples: int = 1024
    speech_pad_ms: int = 30  # gw: 600 # rep 400

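# Since VadOptions is a NamedTuple, per-call overrides are cheap; e.g. a stricter,
# hypothetical configuration:
#
#   strict_options = VadOptions(threshold=0.5, min_silence_duration_ms=1000)
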
class SileroVADModel:
    def __init__(self, path):
        try:
            import onnxruntime
        except ImportError as e:
            raise RuntimeError(
                "Applying the VAD filter requires the onnxruntime package"
            ) from e

        opts = onnxruntime.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1
        opts.log_severity_level = 4

        self.session = onnxruntime.InferenceSession(
            path,
            providers=["CPUExecutionProvider"],
            sess_options=opts,
        )

    def get_initial_state(self, batch_size: int):
        h = np.zeros((2, batch_size, 64), dtype=np.float32)
        c = np.zeros((2, batch_size, 64), dtype=np.float32)
        return h, c

    def __call__(self, x, state, sr: int):
        if len(x.shape) == 1:
            x = np.expand_dims(x, 0)
        if len(x.shape) > 2:
            raise ValueError(
                f"Too many dimensions for input audio chunk {len(x.shape)}"
            )
        if sr / x.shape[1] > 31.25:
            raise ValueError("Input audio chunk is too short")

        h, c = state

        ort_inputs = {
            "input": x,
            # "state": np.concatenate((h, c), axis=0),
            "h": h,
            "c": c,
            "sr": np.array(sr, dtype="int64"),
        }

        out, h, c = self.session.run(None, ort_inputs)
        # out = self.session.run(None, ort_inputs)
        state = (h, c)
        return out, state


@functools.lru_cache
def get_vad_model():
    """Returns the VAD model instance."""
    path = os.path.join(os.path.dirname(__file__), "silero_vad.onnx")
    return SileroVADModel(path)


def get_speech_timestamps(
    audio: np.ndarray,
    vad_options: Optional[VadOptions] = None,
    **kwargs,
) -> List[dict]:
    """This method is used for splitting long audios into speech chunks using silero VAD.

    Args:
      audio: One dimensional float array.
      vad_options: Options for VAD processing.
      kwargs: VAD options passed as keyword arguments for backward compatibility.

    Returns:
      List of dicts containing begin and end samples of each speech chunk.
    """
    if vad_options is None:
        vad_options = VadOptions(**kwargs)

    threshold = vad_options.threshold
    min_speech_duration_ms = vad_options.min_speech_duration_ms
    max_speech_duration_s = vad_options.max_speech_duration_s
    min_silence_duration_ms = vad_options.min_silence_duration_ms
    window_size_samples = vad_options.window_size_samples
    speech_pad_ms = vad_options.speech_pad_ms

    if window_size_samples not in [512, 1024, 1536]:
        warnings.warn(
            "Unusual window_size_samples! Supported window_size_samples:\n"
            " - [512, 1024, 1536] for 16000 sampling_rate"
        )

    sampling_rate = 16000
    min_speech_samples = sampling_rate * min_speech_duration_ms / 1000  # candidate segments shorter than this are not added
    speech_pad_samples = sampling_rate * speech_pad_ms / 1000
    max_speech_samples = (
        sampling_rate * max_speech_duration_s
        - window_size_samples
        - 2 * speech_pad_samples
    )
    min_silence_samples = sampling_rate * min_silence_duration_ms / 1000  # a silence must last this long before a segment is closed
    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000  # 0.098s # need to adjust?

    audio_length_samples = len(audio)

    # import pdb
    # pdb.set_trace()

    model = get_vad_model()
    state = model.get_initial_state(batch_size=1)

    speech_probs = []
    # print("audio_length_samples ", audio_length_samples, ", window_size_samples ", window_size_samples)
    for current_start_sample in range(0, audio_length_samples, window_size_samples):
        chunk = audio[current_start_sample : current_start_sample + window_size_samples]
        if len(chunk) < window_size_samples:
            chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
        speech_prob, state = model(chunk, state, sampling_rate)
        speech_probs.append(speech_prob)

    triggered = False
    speeches = []
    current_speech = {}
    neg_threshold = threshold - 0.15

    # to save potential segment end (and tolerate some silence)
    temp_end = 0
    # to save potential segment limits in case of maximum segment size reached
    prev_end = next_start = 0

    # Roughly: scan the audio for contiguous speech regions. On silence, temp_end is
    # recorded first; if speech resumes before the minimum silence length, temp_end is
    # reset. Silence boundaries are tracked separately so over-long segments can be
    # split (not fully verified, but with max_speech_duration_s = inf that path is never hit).

    for i, speech_prob in enumerate(speech_probs):
        if (speech_prob >= threshold) and temp_end:
            temp_end = 0
            if next_start < prev_end:
                next_start = window_size_samples * i

        if (speech_prob >= threshold) and not triggered:
            triggered = True
            current_speech["start"] = window_size_samples * i
            continue

        if (
            triggered
            and (window_size_samples * i) - current_speech["start"] > max_speech_samples
        ):
            if prev_end:
                current_speech["end"] = prev_end
                speeches.append(current_speech)
                current_speech = {}
                # previously reached silence (< neg_thres) and is still not speech (< thres)
                if next_start < prev_end:
                    triggered = False
                else:
                    current_speech["start"] = next_start
                prev_end = next_start = temp_end = 0
            else:
                current_speech["end"] = window_size_samples * i
                speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

        if (speech_prob < neg_threshold) and triggered:
            if not temp_end:
                temp_end = window_size_samples * i
            # condition to avoid cutting in very short silence
            if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
                prev_end = temp_end
            if (window_size_samples * i) - temp_end < min_silence_samples:
                continue
            else:
                current_speech["end"] = temp_end
                if (
                    current_speech["end"] - current_speech["start"]
                ) > min_speech_samples:
                    speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue


    if (
        current_speech
        and (audio_length_samples - current_speech["start"]) > min_speech_samples
    ):
        current_speech["end"] = audio_length_samples
        speeches.append(current_speech)

    # pad each chunk by speech_pad_ms on both sides; when two chunks are closer than
    # 2 * speech_pad_samples, the gap between them is split evenly
    for i, speech in enumerate(speeches):
        if i == 0:
            speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
        if i != len(speeches) - 1:
            silence_duration = speeches[i + 1]["start"] - speech["end"]
            if silence_duration < 2 * speech_pad_samples:
                speech["end"] += int(silence_duration // 2)
                speeches[i + 1]["start"] = int(
                    max(0, speeches[i + 1]["start"] - silence_duration // 2)
                )
            else:
                speech["end"] = int(
                    min(audio_length_samples, speech["end"] + speech_pad_samples)
                )
                speeches[i + 1]["start"] = int(
                    max(0, speeches[i + 1]["start"] - speech_pad_samples)
                )
        else:
            speech["end"] = int(
                min(audio_length_samples, speech["end"] + speech_pad_samples)
            )
    return speeches

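# Sketch of direct usage on a 16 kHz mono float array (the file name is hypothetical):
#
#   wav, _ = librosa.load("query.wav", sr=16000, mono=True)
#   for seg in get_speech_timestamps(wav):
#       print(seg["start"] / 16000.0, seg["end"] / 16000.0)
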
def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
    """Collects and concatenates audio chunks."""
    if not chunks:
        return np.array([], dtype=np.float32)

    return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks])


def run_vad(ori_audio, sr, vad_options=None):
    _st = time.time()
    try:
        audio = np.frombuffer(ori_audio, dtype=np.int16)
        audio = audio.astype(np.float32) / 32768.0
        sampling_rate = 16000
        if sr != sampling_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)
        # print('audio.encode.shape: {}'.format(audio.shape))
        if vad_options is None:
            vad_options = VadOptions()

        # make sure a VadOptions instance is passed to get_speech_timestamps
        speech_chunks = get_speech_timestamps(audio, vad_options=vad_options)
        # print(speech_chunks)
        audio = collect_chunks(audio, speech_chunks)
        # print(audio.shape)
        duration_after_vad = audio.shape[0] / sampling_rate

        # print('audio.decode.shape: {}'.format(audio.shape))
        if sr != sampling_rate:
            # resample to original sampling rate
            vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
        else:
            vad_audio = audio
        vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)

        # this rounding back to int16 introduces a small quantization error

        vad_audio_bytes = vad_audio.tobytes()

        return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
    except Exception as e:
        msg = f"[asr vad error] audio_len: {len(ori_audio)/(sr*2):.3f} s, trace: {traceback.format_exc()}"
        print(msg)
        return -1, ori_audio, round(time.time() - _st, 4)
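For reference, a minimal sketch of driving run_vad() end to end from int16 PCM bytes (the input file name is hypothetical; a 16 kHz mono WAV is assumed):

import numpy as np
import soundfile as sf
from vad_utils import run_vad, VadOptions

signal, sr = sf.read("sample.wav", dtype="int16")  # hypothetical 16 kHz mono file
dur, voiced_bytes, elapsed = run_vad(signal.tobytes(), sr, VadOptions())
print(f"kept {dur:.2f}s of speech, VAD took {elapsed}s")
sf.write("sample_voiced.wav", np.frombuffer(voiced_bytes, dtype=np.int16), sr)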
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/web_server/.env.development
ADDED
File without changes
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/web_server/.env.production
ADDED
File without changes
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/web_server/.eslintrc-auto-import.json
ADDED
@@ -0,0 +1,359 @@
{
  "globals": {
    "Component": true,
    "ComponentPublicInstance": true,
    "ComputedRef": true,
    "EffectScope": true,
    "ExtractDefaultPropTypes": true,
    "ExtractPropTypes": true,
    "ExtractPublicPropTypes": true,
    "InjectionKey": true,
    "LegalTypeEnum": true,
    "LoginTypeEnum": true,
    "PropType": true,
    "Ref": true,
    "VNode": true,
    "WritableComputedRef": true,
    "acceptHMRUpdate": true,
    "ajaxHeader": true,
    "asyncComputed": true,
    "authLogin": true,
    "autoResetRef": true,
    "computed": true,
    "computedAsync": true,
    "computedEager": true,
    "computedInject": true,
    "computedWithControl": true,
    "controlledComputed": true,
    "controlledRef": true,
    "createApp": true,
    "createEventHook": true,
    "createGlobalState": true,
    "createInjectionState": true,
    "createPinia": true,
    "createReactiveFn": true,
    "createReusableTemplate": true,
    "createSharedComposable": true,
    "createTemplatePromise": true,
    "createUnrefFn": true,
    "customRef": true,
    "debouncedRef": true,
    "debouncedWatch": true,
    "defineAsyncComponent": true,
    "defineComponent": true,
    "defineStore": true,
    "eagerComputed": true,
    "effectScope": true,
    "extendRef": true,
    "fetchSmsVerifyCode": true,
    "getActivePinia": true,
    "getCurrentInstance": true,
    "getCurrentScope": true,
    "getHomeInfo": true,
    "h": true,
    "ignorableWatch": true,
    "inject": true,
    "injectLocal": true,
    "isDefined": true,
    "isProxy": true,
    "isReactive": true,
    "isReadonly": true,
    "isRef": true,
    "loginSuccess": true,
    "makeDestructurable": true,
    "mapActions": true,
    "mapGetters": true,
    "mapState": true,
    "mapStores": true,
    "mapWritableState": true,
    "markRaw": true,
    "nextTick": true,
    "onActivated": true,
    "onBeforeMount": true,
    "onBeforeRouteLeave": true,
    "onBeforeRouteUpdate": true,
    "onBeforeUnmount": true,
    "onBeforeUpdate": true,
    "onClickOutside": true,
    "onDeactivated": true,
    "onErrorCaptured": true,
    "onKeyStroke": true,
    "onLongPress": true,
    "onMounted": true,
    "onRenderTracked": true,
    "onRenderTriggered": true,
    "onScopeDispose": true,
    "onServerPrefetch": true,
    "onStartTyping": true,
    "onUnmounted": true,
    "onUpdated": true,
    "pausableWatch": true,
    "provide": true,
    "provideLocal": true,
    "reactify": true,
    "reactifyObject": true,
    "reactive": true,
    "reactiveComputed": true,
    "reactiveOmit": true,
    "reactivePick": true,
    "readonly": true,
    "ref": true,
    "refAutoReset": true,
    "refDebounced": true,
    "refDefault": true,
    "refThrottled": true,
    "refWithControl": true,
    "resolveComponent": true,
    "resolveRef": true,
    "resolveUnref": true,
    "setActivePinia": true,
    "setMapStoreSuffix": true,
    "setupStore": true,
    "shallowReactive": true,
    "shallowReadonly": true,
    "shallowRef": true,
    "store": true,
    "storeToRefs": true,
    "submitFeedback": true,
    "syncRef": true,
    "syncRefs": true,
    "templateRef": true,
    "throttledRef": true,
    "throttledWatch": true,
    "toRaw": true,
    "toReactive": true,
    "toRef": true,
    "toRefs": true,
    "toValue": true,
    "triggerRef": true,
    "tryOnBeforeMount": true,
    "tryOnBeforeUnmount": true,
    "tryOnMounted": true,
    "tryOnScopeDispose": true,
    "tryOnUnmounted": true,
    "unref": true,
    "unrefElement": true,
    "until": true,
    "useActiveElement": true,
    "useAnimate": true,
    "useArrayDifference": true,
    "useArrayEvery": true,
    "useArrayFilter": true,
    "useArrayFind": true,
    "useArrayFindIndex": true,
    "useArrayFindLast": true,
    "useArrayIncludes": true,
    "useArrayJoin": true,
    "useArrayMap": true,
    "useArrayReduce": true,
    "useArraySome": true,
    "useArrayUnique": true,
    "useAsyncQueue": true,
    "useAsyncState": true,
    "useAttrs": true,
    "useBase64": true,
    "useBattery": true,
    "useBluetooth": true,
    "useBreakpoints": true,
    "useBroadcastChannel": true,
    "useBrowserLocation": true,
    "useCached": true,
    "useClearLocalCache": true,
    "useClipboard": true,
    "useClipboardItems": true,
    "useCloned": true,
    "useColorMode": true,
    "useConfirmDialog": true,
    "useCounter": true,
    "useCssModule": true,
    "useCssVar": true,
    "useCssVars": true,
    "useCurrentElement": true,
    "useCycleList": true,
    "useDark": true,
    "useDateFormat": true,
    "useDebounce": true,
    "useDebounceFn": true,
    "useDebouncedRefHistory": true,
    "useDeviceMotion": true,
    "useDeviceOrientation": true,
    "useDevicePixelRatio": true,
    "useDevicesList": true,
    "useDisplayMedia": true,
    "useDocumentVisibility": true,
    "useDraggable": true,
    "useDropZone": true,
    "useElementBounding": true,
    "useElementByPoint": true,
    "useElementHover": true,
    "useElementSize": true,
    "useElementVisibility": true,
    "useEventBus": true,
    "useEventListener": true,
    "useEventSource": true,
    "useEyeDropper": true,
    "useFavicon": true,
    "useFetch": true,
    "useFetchLogin": true,
    "useFileDialog": true,
    "useFileSystemAccess": true,
    "useFocus": true,
    "useFocusWithin": true,
    "useFps": true,
    "useFullscreen": true,
    "useGamepad": true,
    "useGeolocation": true,
    "useGetLocalCache": true,
    "useHttp": true,
    "useIdle": true,
    "useImage": true,
    "useInfiniteScroll": true,
    "useIntersectionObserver": true,
    "useInterval": true,
    "useIntervalFn": true,
    "useKeyModifier": true,
    "useLastChanged": true,
    "useLegal": true,
    "useLink": true,
    "useLocalStorage": true,
    "useLogin": true,
    "useMagicKeys": true,
    "useManualRefHistory": true,
    "useMediaControls": true,
    "useMediaQuery": true,
    "useMemoize": true,
    "useMemory": true,
    "useMounted": true,
    "useMouse": true,
    "useMouseInElement": true,
    "useMousePressed": true,
    "useMutationObserver": true,
    "useNavigatorLanguage": true,
    "useNetwork": true,
    "useNow": true,
    "useObjectUrl": true,
    "useOffsetPagination": true,
    "useOnline": true,
    "usePageLeave": true,
    "useParallax": true,
    "useParentElement": true,
    "usePerformanceObserver": true,
    "usePermission": true,
    "usePointer": true,
    "usePointerLock": true,
    "usePointerSwipe": true,
    "usePreferredColorScheme": true,
    "usePreferredContrast": true,
    "usePreferredDark": true,
    "usePreferredLanguages": true,
    "usePreferredReducedMotion": true,
    "usePrevious": true,
    "useRafFn": true,
    "useRefHistory": true,
    "useResizeObserver": true,
    "useRoute": true,
    "useRouter": true,
    "useScreenOrientation": true,
    "useScreenSafeArea": true,
    "useScriptTag": true,
    "useScroll": true,
    "useScrollLock": true,
    "useSessionStorage": true,
    "useSetLocalCache": true,
    "useShare": true,
    "useSlots": true,
    "useSorted": true,
    "useSpeechRecognition": true,
    "useSpeechSynthesis": true,
    "useStepper": true,
    "useStorage": true,
    "useStorageAsync": true,
    "useStyleTag": true,
    "useSupported": true,
    "useSwipe": true,
    "useTemplateRefsList": true,
    "useTextDirection": true,
    "useTextSelection": true,
    "useTextareaAutosize": true,
    "useThrottle": true,
    "useThrottleFn": true,
    "useThrottledRefHistory": true,
    "useTimeAgo": true,
    "useTimeout": true,
    "useTimeoutFn": true,
    "useTimeoutPoll": true,
    "useTimestamp": true,
    "useTitle": true,
    "useToNumber": true,
    "useToString": true,
    "useToggle": true,
    "useTransition": true,
    "useUrlSearchParams": true,
    "useUserMedia": true,
    "useUserStore": true,
    "useUserStoreWithOut": true,
    "useVModel": true,
    "useVModels": true,
    "useVibrate": true,
    "useVirtualList": true,
    "useWakeLock": true,
    "useWebNotification": true,
    "useWebSocket": true,
    "useWebWorker": true,
    "useWebWorkerFn": true,
    "useWindowFocus": true,
    "useWindowScroll": true,
    "useWindowSize": true,
    "watch": true,
    "watchArray": true,
    "watchAtMost": true,
    "watchDebounced": true,
    "watchDeep": true,
    "watchEffect": true,
    "watchIgnorable": true,
    "watchImmediate": true,
    "watchOnce": true,
    "watchPausable": true,
    "watchPostEffect": true,
    "watchSyncEffect": true,
    "watchThrottled": true,
    "watchTriggerable": true,
    "watchWithFilter": true,
    "whenever": true,
    "ElMessage": true,
    "ElLoading": true,
    "deleteHistoryBatch": true,
    "deleteHistoryItem": true,
    "getHistory": true,
    "createConv": true,
    "fetchHistoryList": true,
    "stopChat": true,
    "useChatStore": true,
    "useChatStoreWithOut": true,
    "useChatExchangeStore": true,
    "useChatExchangeStoreWithOut": true,
    "useExchangeStore": true,
    "useExchangeStoreWithOut": true,
    "delMessage": true,
    "sendRating": true,
    "getInitialActions": true,
    "sendFeedback": true,
    "md": true,
    "useMarkdown": true,
    "connectService": true,
    "sendMessage": true,
    "Audio": true,
    "SoundRecording": true,
    "getVolume": true,
    "ElMessageBox": true,
    "encodeWav": true,
    "encodeWAV": true,
    "stopMessage": true,
    "TaskQueue": true,
    "getNewUserId": true,
    "setNewUserId": true,
    "uploadFile": true,
    "feedback": true,
    "uploadConfig": true
  }
}
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/minicpm-o_2.6/web_server/.eslintrc.cjs
ADDED
@@ -0,0 +1,26 @@

/* eslint-env node */
require('@rushstack/eslint-patch/modern-module-resolution');

module.exports = {
  root: true,
  extends: [
    'plugin:vue/vue3-essential',
    'eslint:recommended',
    '@vue/eslint-config-prettier/skip-formatting',
    './.eslintrc-auto-import.json',
  ],
  parserOptions: {
    ecmaVersion: 'latest',
  },
  rules: {
    'no-console': process.env.NODE_ENV === 'production' ? 'off' : 'warn',
    'no-debugger': process.env.NODE_ENV === 'production' ? 'error' : 'warn',
    'no-var': process.env.NODE_ENV === 'production' ? 'off' : 'warn',
    'no-undef': process.env.NODE_ENV === 'production' ? 'error' : 'warn',
    'vue/multi-word-component-names': 'off', // do not enforce multi-word component names
    'no-empty': 0, // allow empty blocks
    'vue/no-unused-components': 'warn',
    'no-unused-vars': 'warn',
    'prettier/prettier': 'off', // do not have ESLint report Prettier formatting violations as errors
  },
};

r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo.py
ADDED
@@ -0,0 +1,264 @@

#!/usr/bin/env python
# encoding: utf-8
import gradio as gr
from PIL import Image
import traceback
import re
import torch
import argparse
from transformers import AutoModel, AutoTokenizer

# README: how to run the demo on different devices

# For Nvidia GPUs that support BF16 (e.g. A100, H100, RTX3090):
# python web_demo.py --device cuda --dtype bf16

# For Nvidia GPUs that do NOT support BF16 (e.g. V100, T4, RTX2080):
# python web_demo.py --device cuda --dtype fp16

# For Mac with MPS (Apple silicon or AMD GPUs):
# PYTORCH_ENABLE_MPS_FALLBACK=1 python web_demo.py --device mps --dtype fp16

# Argparser
parser = argparse.ArgumentParser(description='demo')
parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
parser.add_argument('--dtype', type=str, default='bf16', help='bf16 or fp16')
args = parser.parse_args()
device = args.device
assert device in ['cuda', 'mps']
if args.dtype == 'bf16':
    if device == 'mps':
        print('Warning: MPS does not support bf16, will use fp16 instead')
        dtype = torch.float16
    else:
        dtype = torch.bfloat16
else:
    dtype = torch.float16

# Load model directly in the dtype selected above
model_path = 'openbmb/MiniCPM-V-2'
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=dtype)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

model = model.to(device=device)
model.eval()


ERROR_MSG = "Error, please retry"
model_name = 'MiniCPM-V 2.0'

form_radio = {
    'choices': ['Beam Search', 'Sampling'],
    #'value': 'Beam Search',
    'value': 'Sampling',
    'interactive': True,
    'label': 'Decode Type'
}
# Beam Form
num_beams_slider = {
    'minimum': 0,
    'maximum': 5,
    'value': 3,
    'step': 1,
    'interactive': True,
    'label': 'Num Beams'
}
repetition_penalty_slider = {
    'minimum': 0,
    'maximum': 3,
    'value': 1.2,
    'step': 0.01,
    'interactive': True,
    'label': 'Repetition Penalty'
}
repetition_penalty_slider2 = {
    'minimum': 0,
    'maximum': 3,
    'value': 1.05,
    'step': 0.01,
    'interactive': True,
    'label': 'Repetition Penalty'
}
max_new_tokens_slider = {
    'minimum': 1,
    'maximum': 4096,
    'value': 1024,
    'step': 1,
    'interactive': True,
    'label': 'Max New Tokens'
}

top_p_slider = {
    'minimum': 0,
    'maximum': 1,
    'value': 0.8,
    'step': 0.05,
    'interactive': True,
    'label': 'Top P'
}
top_k_slider = {
    'minimum': 0,
    'maximum': 200,
    'value': 100,
    'step': 1,
    'interactive': True,
    'label': 'Top K'
}
temperature_slider = {
    'minimum': 0,
    'maximum': 2,
    'value': 0.7,
    'step': 0.05,
    'interactive': True,
    'label': 'Temperature'
}


def create_component(params, comp='Slider'):
    if comp == 'Slider':
        return gr.Slider(
            minimum=params['minimum'],
            maximum=params['maximum'],
            value=params['value'],
            step=params['step'],
            interactive=params['interactive'],
            label=params['label']
        )
    elif comp == 'Radio':
        return gr.Radio(
            choices=params['choices'],
            value=params['value'],
            interactive=params['interactive'],
            label=params['label']
        )
    elif comp == 'Button':
        return gr.Button(
            value=params['value'],
            interactive=True
        )


def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
    default_params = {"num_beams": 3, "repetition_penalty": 1.2, "max_new_tokens": 1024}
    if params is None:
        params = default_params
    if img is None:
        return -1, "Error, invalid image, please upload a new image", None, None
    try:
        image = img.convert('RGB')
        answer, context, _ = model.chat(
            image=image,
            msgs=msgs,
            context=None,
            tokenizer=tokenizer,
            **params
        )
        # Strip grounding tags (<ref>/<box>) from the answer before display
        res = re.sub(r'(<box>.*</box>)', '', answer)
        res = res.replace('<ref>', '')
        res = res.replace('</ref>', '')
        res = res.replace('<box>', '')
        answer = res.replace('</box>', '')
        return 0, answer, None, None
    except Exception as err:
        print(err)
        traceback.print_exc()
        return -1, ERROR_MSG, None, None


def upload_img(image, _chatbot, _app_session):
    image = Image.fromarray(image)

    _app_session['sts'] = None
    _app_session['ctx'] = []
    _app_session['img'] = image
    _chatbot.append(('', 'Image uploaded successfully, you can talk to me now'))
    return _chatbot, _app_session


def respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
    if _app_cfg.get('ctx', None) is None:
        _chat_bot.append((_question, 'Please upload an image to start'))
        return '', _chat_bot, _app_cfg

    _context = _app_cfg['ctx'].copy()
    if _context:
        _context.append({"role": "user", "content": _question})
    else:
        _context = [{"role": "user", "content": _question}]
    print('<User>:', _question)

    if params_form == 'Beam Search':
        params = {
            'sampling': False,
            'num_beams': num_beams,
            'repetition_penalty': repetition_penalty,
            "max_new_tokens": 896
        }
    else:
        params = {
            'sampling': True,
            'top_p': top_p,
            'top_k': top_k,
            'temperature': temperature,
            'repetition_penalty': repetition_penalty_2,
            "max_new_tokens": 896
        }
    code, _answer, _, sts = chat(_app_cfg['img'], _context, None, params)
    print('<Assistant>:', _answer)

    _context.append({"role": "assistant", "content": _answer})
    _chat_bot.append((_question, _answer))
    if code == 0:
        _app_cfg['ctx'] = _context
        _app_cfg['sts'] = sts
    return '', _chat_bot, _app_cfg


def regenerate_button_clicked(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
    if len(_chat_bot) <= 1:
        _chat_bot.append(('Regenerate', 'No question for regeneration.'))
        return '', _chat_bot, _app_cfg
    elif _chat_bot[-1][0] == 'Regenerate':
        return '', _chat_bot, _app_cfg
    else:
        _question = _chat_bot[-1][0]
        _chat_bot = _chat_bot[:-1]
        _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
        return respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature)


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            params_form = create_component(form_radio, comp='Radio')
            with gr.Accordion("Beam Search") as beams_accordion:
                num_beams = create_component(num_beams_slider)
                repetition_penalty = create_component(repetition_penalty_slider)
            with gr.Accordion("Sampling") as sampling_accordion:
                top_p = create_component(top_p_slider)
                top_k = create_component(top_k_slider)
                temperature = create_component(temperature_slider)
                repetition_penalty_2 = create_component(repetition_penalty_slider2)
            regenerate = create_component({'value': 'Regenerate'}, comp='Button')
        with gr.Column(scale=3, min_width=500):
            app_session = gr.State({'sts': None, 'ctx': None, 'img': None})
            bt_pic = gr.Image(label="Upload an image to start")
            chat_bot = gr.Chatbot(label=f"Chat with {model_name}")
            txt_message = gr.Textbox(label="Input text")

            regenerate.click(
                regenerate_button_clicked,
                [txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
                [txt_message, chat_bot, app_session]
            )
            txt_message.submit(
                respond,
                [txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
                [txt_message, chat_bot, app_session]
            )
            bt_pic.upload(lambda: None, None, chat_bot, queue=False).then(upload_img, inputs=[bt_pic, chat_bot, app_session], outputs=[chat_bot, app_session])

# launch
demo.launch(share=False, debug=True, show_api=False, server_port=8080, server_name="0.0.0.0")
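
The demo above is a thin Gradio wrapper around the checkpoint's remote-code `chat` method. A minimal scripted sketch of the same call outside Gradio, using the exact signature the demo's `chat()` helper uses, with a hypothetical local image `example.jpg`:

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

model_path = 'openbmb/MiniCPM-V-2'
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
model = model.to(device='cuda').eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

image = Image.open('example.jpg').convert('RGB')  # hypothetical input image
msgs = [{"role": "user", "content": "Describe this image."}]
# same (answer, context, _) return and sampling kwargs as in the demo's chat() helper
answer, context, _ = model.chat(
    image=image, msgs=msgs, context=None, tokenizer=tokenizer,
    sampling=True, top_p=0.8, top_k=100, temperature=0.7,
    repetition_penalty=1.05, max_new_tokens=896,
)
print(answer)
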
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_2.5.py
ADDED
@@ -0,0 +1,256 @@

#!/usr/bin/env python
# encoding: utf-8
import gradio as gr
from PIL import Image
import traceback
import re
import torch
import argparse
from transformers import AutoModel, AutoTokenizer

# README: how to run the demo on different devices

# For Nvidia GPUs:
# python web_demo_2.5.py --device cuda

# For Mac with MPS (Apple silicon or AMD GPUs):
# PYTORCH_ENABLE_MPS_FALLBACK=1 python web_demo_2.5.py --device mps

# Argparser
parser = argparse.ArgumentParser(description='demo')
parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
args = parser.parse_args()
device = args.device
assert device in ['cuda', 'mps']

# Load model
model_path = 'openbmb/MiniCPM-Llama3-V-2_5'
if 'int4' in model_path:
    if device == 'mps':
        print('Error: running the int4 model with bitsandbytes on Mac is not supported right now.')
        exit()
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
else:
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.eval()


ERROR_MSG = "Error, please retry"
model_name = 'MiniCPM-V 2.5'

form_radio = {
    'choices': ['Beam Search', 'Sampling'],
    #'value': 'Beam Search',
    'value': 'Sampling',
    'interactive': True,
    'label': 'Decode Type'
}
# Beam Form
num_beams_slider = {
    'minimum': 0,
    'maximum': 5,
    'value': 3,
    'step': 1,
    'interactive': True,
    'label': 'Num Beams'
}
repetition_penalty_slider = {
    'minimum': 0,
    'maximum': 3,
    'value': 1.2,
    'step': 0.01,
    'interactive': True,
    'label': 'Repetition Penalty'
}
repetition_penalty_slider2 = {
    'minimum': 0,
    'maximum': 3,
    'value': 1.05,
    'step': 0.01,
    'interactive': True,
    'label': 'Repetition Penalty'
}
max_new_tokens_slider = {
    'minimum': 1,
    'maximum': 4096,
    'value': 1024,
    'step': 1,
    'interactive': True,
    'label': 'Max New Tokens'
}

top_p_slider = {
    'minimum': 0,
    'maximum': 1,
    'value': 0.8,
    'step': 0.05,
    'interactive': True,
    'label': 'Top P'
}
top_k_slider = {
    'minimum': 0,
    'maximum': 200,
    'value': 100,
    'step': 1,
    'interactive': True,
    'label': 'Top K'
}
temperature_slider = {
    'minimum': 0,
    'maximum': 2,
    'value': 0.7,
    'step': 0.05,
    'interactive': True,
    'label': 'Temperature'
}


def create_component(params, comp='Slider'):
    if comp == 'Slider':
        return gr.Slider(
            minimum=params['minimum'],
            maximum=params['maximum'],
            value=params['value'],
            step=params['step'],
            interactive=params['interactive'],
            label=params['label']
        )
    elif comp == 'Radio':
        return gr.Radio(
            choices=params['choices'],
            value=params['value'],
            interactive=params['interactive'],
            label=params['label']
        )
    elif comp == 'Button':
        return gr.Button(
            value=params['value'],
            interactive=True
        )


def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
    default_params = {"num_beams": 3, "repetition_penalty": 1.2, "max_new_tokens": 1024}
    if params is None:
        params = default_params
    if img is None:
        return -1, "Error, invalid image, please upload a new image", None, None
    try:
        image = img.convert('RGB')
        answer = model.chat(
            image=image,
            msgs=msgs,
            tokenizer=tokenizer,
            **params
        )
        # Strip grounding tags (<ref>/<box>) from the answer before display
        res = re.sub(r'(<box>.*</box>)', '', answer)
        res = res.replace('<ref>', '')
        res = res.replace('</ref>', '')
        res = res.replace('<box>', '')
        answer = res.replace('</box>', '')
        return 0, answer, None, None
    except Exception as err:
        print(err)
        traceback.print_exc()
        return -1, ERROR_MSG, None, None


def upload_img(image, _chatbot, _app_session):
    image = Image.fromarray(image)

    _app_session['sts'] = None
    _app_session['ctx'] = []
    _app_session['img'] = image
    _chatbot.append(('', 'Image uploaded successfully, you can talk to me now'))
    return _chatbot, _app_session


def respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
    if _app_cfg.get('ctx', None) is None:
        _chat_bot.append((_question, 'Please upload an image to start'))
        return '', _chat_bot, _app_cfg

    _context = _app_cfg['ctx'].copy()
    if _context:
        _context.append({"role": "user", "content": _question})
    else:
        _context = [{"role": "user", "content": _question}]
    print('<User>:', _question)

    if params_form == 'Beam Search':
        params = {
            'sampling': False,
            'num_beams': num_beams,
            'repetition_penalty': repetition_penalty,
            "max_new_tokens": 896
        }
    else:
        params = {
            'sampling': True,
            'top_p': top_p,
            'top_k': top_k,
            'temperature': temperature,
            'repetition_penalty': repetition_penalty_2,
            "max_new_tokens": 896
        }
    code, _answer, _, sts = chat(_app_cfg['img'], _context, None, params)
    print('<Assistant>:', _answer)

    _context.append({"role": "assistant", "content": _answer})
    _chat_bot.append((_question, _answer))
    if code == 0:
        _app_cfg['ctx'] = _context
        _app_cfg['sts'] = sts
    return '', _chat_bot, _app_cfg


def regenerate_button_clicked(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
    if len(_chat_bot) <= 1:
        _chat_bot.append(('Regenerate', 'No question for regeneration.'))
        return '', _chat_bot, _app_cfg
    elif _chat_bot[-1][0] == 'Regenerate':
        return '', _chat_bot, _app_cfg
    else:
        _question = _chat_bot[-1][0]
        _chat_bot = _chat_bot[:-1]
        _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
        return respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature)


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            params_form = create_component(form_radio, comp='Radio')
            with gr.Accordion("Beam Search") as beams_accordion:
                num_beams = create_component(num_beams_slider)
                repetition_penalty = create_component(repetition_penalty_slider)
            with gr.Accordion("Sampling") as sampling_accordion:
                top_p = create_component(top_p_slider)
                top_k = create_component(top_k_slider)
                temperature = create_component(temperature_slider)
                repetition_penalty_2 = create_component(repetition_penalty_slider2)
            regenerate = create_component({'value': 'Regenerate'}, comp='Button')
        with gr.Column(scale=3, min_width=500):
            app_session = gr.State({'sts': None, 'ctx': None, 'img': None})
            bt_pic = gr.Image(label="Upload an image to start")
            chat_bot = gr.Chatbot(label=f"Chat with {model_name}")
            txt_message = gr.Textbox(label="Input text")

            regenerate.click(
                regenerate_button_clicked,
                [txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
                [txt_message, chat_bot, app_session]
            )
            txt_message.submit(
                respond,
                [txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
                [txt_message, chat_bot, app_session]
            )
            bt_pic.upload(lambda: None, None, chat_bot, queue=False).then(upload_img, inputs=[bt_pic, chat_bot, app_session], outputs=[chat_bot, app_session])

# launch
demo.launch(share=False, debug=True, show_api=False, server_port=8080, server_name="0.0.0.0")
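
To exercise the `int4` branch above, `model_path` only needs to point at a quantized checkpoint; the branch then loads the bitsandbytes-quantized weights as-is and skips the fp16 cast and `device_map` placement. A sketch, assuming the quantized checkpoint is published under the name `openbmb/MiniCPM-Llama3-V-2_5-int4`:

model_path = 'openbmb/MiniCPM-Llama3-V-2_5-int4'  # assumed name of the int4 checkpoint
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)  # weights load pre-quantized
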
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_2.6.py
ADDED
@@ -0,0 +1,557 @@

#!/usr/bin/env python
# encoding: utf-8
import torch
import argparse
from transformers import AutoModel, AutoTokenizer
import gradio as gr
from PIL import Image
from decord import VideoReader, cpu
import io
import os
import copy
import requests
import base64
import json
import traceback
import re
import modelscope_studio as mgr


# README: how to run the demo on different devices

# For Nvidia GPUs:
# python web_demo_2.6.py --device cuda

# For Mac with MPS (Apple silicon or AMD GPUs):
# PYTORCH_ENABLE_MPS_FALLBACK=1 python web_demo_2.6.py --device mps

# Argparser
parser = argparse.ArgumentParser(description='demo')
parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
parser.add_argument('--multi-gpus', action='store_true', default=False, help='use multi-gpus')
args = parser.parse_args()
device = args.device
assert device in ['cuda', 'mps']

# Load model
model_path = 'openbmb/MiniCPM-V-2_6'
if 'int4' in model_path:
    if device == 'mps':
        print('Error: running the int4 model with bitsandbytes on Mac is not supported right now.')
        exit()
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
else:
    if args.multi_gpus:
        from accelerate import load_checkpoint_and_dispatch, init_empty_weights, infer_auto_device_map
        with init_empty_weights():
            model = AutoModel.from_pretrained(model_path, trust_remote_code=True, attn_implementation='sdpa', torch_dtype=torch.bfloat16)
        device_map = infer_auto_device_map(model, max_memory={0: "10GB", 1: "10GB"},
                                           no_split_module_classes=['SiglipVisionTransformer', 'Qwen2DecoderLayer'])
        device_id = device_map["llm.model.embed_tokens"]
        device_map["llm.lm_head"] = device_id  # first and last layers should be on the same device
        device_map["vpm"] = device_id
        device_map["resampler"] = device_id
        device_id2 = device_map["llm.model.layers.26"]
        # move some middle layers to the second device to balance memory usage
        device_map["llm.model.layers.8"] = device_id2
        device_map["llm.model.layers.9"] = device_id2
        device_map["llm.model.layers.10"] = device_id2
        device_map["llm.model.layers.11"] = device_id2
        device_map["llm.model.layers.12"] = device_id2
        device_map["llm.model.layers.13"] = device_id2
        device_map["llm.model.layers.14"] = device_id2
        device_map["llm.model.layers.15"] = device_id2
        device_map["llm.model.layers.16"] = device_id2
        #print(device_map)

        model = load_checkpoint_and_dispatch(model, model_path, dtype=torch.bfloat16, device_map=device_map)
    else:
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
        model = model.to(device=device)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.eval()


ERROR_MSG = "Error, please retry"
model_name = 'MiniCPM-V 2.6'
MAX_NUM_FRAMES = 64
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}

def get_file_extension(filename):
    return os.path.splitext(filename)[1].lower()

def is_image(filename):
    return get_file_extension(filename) in IMAGE_EXTENSIONS

def is_video(filename):
    return get_file_extension(filename) in VIDEO_EXTENSIONS


form_radio = {
    'choices': ['Beam Search', 'Sampling'],
    #'value': 'Beam Search',
    'value': 'Sampling',
    'interactive': True,
    'label': 'Decode Type'
}


def create_component(params, comp='Slider'):
    if comp == 'Slider':
        return gr.Slider(
            minimum=params['minimum'],
            maximum=params['maximum'],
            value=params['value'],
            step=params['step'],
            interactive=params['interactive'],
            label=params['label']
        )
    elif comp == 'Radio':
        return gr.Radio(
            choices=params['choices'],
            value=params['value'],
            interactive=params['interactive'],
            label=params['label']
        )
    elif comp == 'Button':
        return gr.Button(
            value=params['value'],
            interactive=True
        )


def create_multimodal_input(upload_image_disabled=False, upload_video_disabled=False):
    return mgr.MultimodalInput(upload_image_button_props={'label': 'Upload Image', 'disabled': upload_image_disabled, 'file_count': 'multiple'},
                               upload_video_button_props={'label': 'Upload Video', 'disabled': upload_video_disabled, 'file_count': 'single'},
                               submit_button_props={'label': 'Submit'})


def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
    try:
        print('msgs:', msgs)
        answer = model.chat(
            image=None,
            msgs=msgs,
            tokenizer=tokenizer,
            **params
        )
        # Strip grounding tags (<ref>/<box>) from the answer before display
        res = re.sub(r'(<box>.*</box>)', '', answer)
        res = res.replace('<ref>', '')
        res = res.replace('</ref>', '')
        res = res.replace('<box>', '')
        answer = res.replace('</box>', '')
        print('answer:', answer)
        return 0, answer, None, None
    except Exception as e:
        print(e)
        traceback.print_exc()
        return -1, ERROR_MSG, None, None


def encode_image(image):
    if not isinstance(image, Image.Image):
        if hasattr(image, 'path'):
            image = Image.open(image.path).convert("RGB")
        else:
            image = Image.open(image.file.path).convert("RGB")
    # resize to max_size
    max_size = 448 * 16
    if max(image.size) > max_size:
        w, h = image.size
        if w > h:
            new_w = max_size
            new_h = int(h * max_size / w)
        else:
            new_h = max_size
            new_w = int(w * max_size / h)
        image = image.resize((new_w, new_h), resample=Image.BICUBIC)
    return image
    ## save by BytesIO and convert to base64
    #buffered = io.BytesIO()
    #image.save(buffered, format="png")
    #im_b64 = base64.b64encode(buffered.getvalue()).decode()
    #return {"type": "image", "pairs": im_b64}


def encode_video(video):
    def uniform_sample(l, n):
        gap = len(l) / n
        idxs = [int(i * gap + gap / 2) for i in range(n)]
        return [l[i] for i in idxs]

    if hasattr(video, 'path'):
        vr = VideoReader(video.path, ctx=cpu(0))
    else:
        vr = VideoReader(video.file.path, ctx=cpu(0))
    sample_fps = round(vr.get_avg_fps() / 1)  # sample roughly one frame per second
    frame_idx = [i for i in range(0, len(vr), sample_fps)]
    if len(frame_idx) > MAX_NUM_FRAMES:
        frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
    video = vr.get_batch(frame_idx).asnumpy()
    video = [Image.fromarray(v.astype('uint8')) for v in video]
    video = [encode_image(v) for v in video]
    print('video frames:', len(video))
    return video


def check_mm_type(mm_file):
    if hasattr(mm_file, 'path'):
        path = mm_file.path
    else:
        path = mm_file.file.path
    if is_image(path):
        return "image"
    if is_video(path):
        return "video"
    return None


def encode_mm_file(mm_file):
    if check_mm_type(mm_file) == 'image':
        return [encode_image(mm_file)]
    if check_mm_type(mm_file) == 'video':
        return encode_video(mm_file)
    return None

def make_text(text):
    #return {"type": "text", "pairs": text}  # For remote call
    return text

def encode_message(_question):
    files = _question.files
    question = _question.text
    pattern = r"\[mm_media\]\d+\[/mm_media\]"
    matches = re.split(pattern, question)
    message = []
    if len(matches) != len(files) + 1:
        gr.Warning("Number of images does not match the number of placeholders in the text, please refresh the page to restart!")
        assert len(matches) == len(files) + 1

    text = matches[0].strip()
    if text:
        message.append(make_text(text))
    for i in range(len(files)):
        message += encode_mm_file(files[i])
        text = matches[i + 1].strip()
        if text:
            message.append(make_text(text))
    return message


def check_has_videos(_question):
    images_cnt = 0
    videos_cnt = 0
    for file in _question.files:
        if check_mm_type(file) == "image":
            images_cnt += 1
        else:
            videos_cnt += 1
    return images_cnt, videos_cnt


def count_video_frames(_context):
    num_frames = 0
    for message in _context:
        for item in message["content"]:
            #if item["type"] == "image":  # For remote call
            if isinstance(item, Image.Image):
                num_frames += 1
    return num_frames


def respond(_question, _chat_bot, _app_cfg, params_form):
    _context = _app_cfg['ctx'].copy()
    _context.append({'role': 'user', 'content': encode_message(_question)})

    images_cnt = _app_cfg['images_cnt']
    videos_cnt = _app_cfg['videos_cnt']
    files_cnts = check_has_videos(_question)
    if files_cnts[1] + videos_cnt > 1 or (files_cnts[1] + videos_cnt == 1 and files_cnts[0] + images_cnt > 0):
        gr.Warning("Only a single video file input is supported right now!")
        return _question, _chat_bot, _app_cfg

    if params_form == 'Beam Search':
        params = {
            'sampling': False,
            'num_beams': 3,
            'repetition_penalty': 1.2,
            "max_new_tokens": 2048
        }
    else:
        params = {
            'sampling': True,
            'top_p': 0.8,
            'top_k': 100,
            'temperature': 0.7,
            'repetition_penalty': 1.05,
            "max_new_tokens": 2048
        }

    if files_cnts[1] + videos_cnt > 0:
        params["max_inp_length"] = 4352  # 4096 + 256
        params["use_image_id"] = False
        params["max_slice_nums"] = 1 if count_video_frames(_context) > 16 else 2

    code, _answer, _, sts = chat("", _context, None, params)

    images_cnt += files_cnts[0]
    videos_cnt += files_cnts[1]
    _context.append({"role": "assistant", "content": [make_text(_answer)]})
    _chat_bot.append((_question, _answer))
    if code == 0:
        _app_cfg['ctx'] = _context
        _app_cfg['sts'] = sts
        _app_cfg['images_cnt'] = images_cnt
        _app_cfg['videos_cnt'] = videos_cnt

    upload_image_disabled = videos_cnt > 0
    upload_video_disabled = videos_cnt > 0 or images_cnt > 0
    return create_multimodal_input(upload_image_disabled, upload_video_disabled), _chat_bot, _app_cfg


def fewshot_add_demonstration(_image, _user_message, _assistant_message, _chat_bot, _app_cfg):
    ctx = _app_cfg["ctx"]
    message_item = []
    if _image is not None:
        image = Image.open(_image).convert("RGB")
        ctx.append({"role": "user", "content": [encode_image(image), make_text(_user_message)]})
        message_item.append({"text": "[mm_media]1[/mm_media]" + _user_message, "files": [_image]})
    else:
        if _user_message:
            ctx.append({"role": "user", "content": [make_text(_user_message)]})
            message_item.append({"text": _user_message, "files": []})
        else:
            message_item.append(None)
    if _assistant_message:
        ctx.append({"role": "assistant", "content": [make_text(_assistant_message)]})
        message_item.append({"text": _assistant_message, "files": []})
    else:
        message_item.append(None)

    _chat_bot.append(message_item)
    return None, "", "", _chat_bot, _app_cfg


def fewshot_respond(_image, _user_message, _chat_bot, _app_cfg, params_form):
    user_message_contents = []
    _context = _app_cfg["ctx"].copy()
    if _image:
        image = Image.open(_image).convert("RGB")
        user_message_contents += [encode_image(image)]
    if _user_message:
        user_message_contents += [make_text(_user_message)]
    if user_message_contents:
        _context.append({"role": "user", "content": user_message_contents})

    if params_form == 'Beam Search':
        params = {
            'sampling': False,
            'num_beams': 3,
            'repetition_penalty': 1.2,
            "max_new_tokens": 2048
        }
    else:
        params = {
            'sampling': True,
            'top_p': 0.8,
            'top_k': 100,
            'temperature': 0.7,
            'repetition_penalty': 1.05,
            "max_new_tokens": 2048
        }

    code, _answer, _, sts = chat("", _context, None, params)

    _context.append({"role": "assistant", "content": [make_text(_answer)]})

    if _image:
        _chat_bot.append([
            {"text": "[mm_media]1[/mm_media]" + _user_message, "files": [_image]},
            {"text": _answer, "files": []}
        ])
    else:
        _chat_bot.append([
            {"text": _user_message, "files": []},  # no image in this branch
            {"text": _answer, "files": []}
        ])
    if code == 0:
        _app_cfg['ctx'] = _context
        _app_cfg['sts'] = sts
    return None, '', '', _chat_bot, _app_cfg


def regenerate_button_clicked(_question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg, params_form):
    if len(_chat_bot) <= 1 or not _chat_bot[-1][1]:
        gr.Warning('No question for regeneration.')
        return '', _image, _user_message, _assistant_message, _chat_bot, _app_cfg
    if _app_cfg["chat_type"] == "Chat":
        images_cnt = _app_cfg['images_cnt']
        videos_cnt = _app_cfg['videos_cnt']
        _question = _chat_bot[-1][0]
        _chat_bot = _chat_bot[:-1]
        _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
        files_cnts = check_has_videos(_question)
        images_cnt -= files_cnts[0]
        videos_cnt -= files_cnts[1]
        _app_cfg['images_cnt'] = images_cnt
        _app_cfg['videos_cnt'] = videos_cnt
        _question, _chat_bot, _app_cfg = respond(_question, _chat_bot, _app_cfg, params_form)
        return _question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg
    else:
        last_message = _chat_bot[-1][0]
        last_image = None
        last_user_message = ''
        if last_message.text:
            last_user_message = last_message.text
        if last_message.files:
            last_image = last_message.files[0].file.path
        _chat_bot = _chat_bot[:-1]
        _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
        _image, _user_message, _assistant_message, _chat_bot, _app_cfg = fewshot_respond(last_image, last_user_message, _chat_bot, _app_cfg, params_form)
        return _question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg


def flushed():
    return gr.update(interactive=True)


def clear(txt_message, chat_bot, app_session):
    txt_message.files.clear()
    txt_message.text = ''
    chat_bot = copy.deepcopy(init_conversation)
    app_session['sts'] = None
    app_session['ctx'] = []
    app_session['images_cnt'] = 0
    app_session['videos_cnt'] = 0
    return create_multimodal_input(), chat_bot, app_session, None, '', ''


def select_chat_type(_tab, _app_cfg):
    _app_cfg["chat_type"] = _tab
    return _app_cfg


init_conversation = [
    [
        None,
        {
            # The first message of the bot closes the typewriter effect.
            "text": "You can talk to me now",
            "flushing": False
        }
    ],
]


css = """
video { height: auto !important; }
.example label { font-size: 16px;}
"""

introduction = """

## Features:
1. Chat with single image
2. Chat with multiple images
3. Chat with video
4. In-context few-shot learning

Click the `How to use` tab to see examples.
"""


with gr.Blocks(css=css) as demo:
    with gr.Tab(model_name):
        with gr.Row():
            with gr.Column(scale=1, min_width=300):
                gr.Markdown(value=introduction)
                params_form = create_component(form_radio, comp='Radio')
                regenerate = create_component({'value': 'Regenerate'}, comp='Button')
                clear_button = create_component({'value': 'Clear History'}, comp='Button')

            with gr.Column(scale=3, min_width=500):
                app_session = gr.State({'sts': None, 'ctx': [], 'images_cnt': 0, 'videos_cnt': 0, 'chat_type': 'Chat'})
                chat_bot = mgr.Chatbot(label=f"Chat with {model_name}", value=copy.deepcopy(init_conversation), height=600, flushing=False, bubble_full_width=False)

                with gr.Tab("Chat") as chat_tab:
                    txt_message = create_multimodal_input()
                    chat_tab_label = gr.Textbox(value="Chat", interactive=False, visible=False)

                    txt_message.submit(
                        respond,
                        [txt_message, chat_bot, app_session, params_form],
                        [txt_message, chat_bot, app_session]
                    )

                with gr.Tab("Few Shot") as fewshot_tab:
                    fewshot_tab_label = gr.Textbox(value="Few Shot", interactive=False, visible=False)
                    with gr.Row():
                        with gr.Column(scale=1):
                            image_input = gr.Image(type="filepath", sources=["upload"])
                        with gr.Column(scale=3):
                            user_message = gr.Textbox(label="User")
                            assistant_message = gr.Textbox(label="Assistant")
                            with gr.Row():
                                add_demonstration_button = gr.Button("Add Example")
                                generate_button = gr.Button(value="Generate", variant="primary")
                    add_demonstration_button.click(
                        fewshot_add_demonstration,
                        [image_input, user_message, assistant_message, chat_bot, app_session],
                        [image_input, user_message, assistant_message, chat_bot, app_session]
                    )
                    generate_button.click(
                        fewshot_respond,
                        [image_input, user_message, chat_bot, app_session, params_form],
                        [image_input, user_message, assistant_message, chat_bot, app_session]
                    )

                chat_tab.select(
                    select_chat_type,
                    [chat_tab_label, app_session],
                    [app_session]
                )
                chat_tab.select(  # do clear
                    clear,
                    [txt_message, chat_bot, app_session],
                    [txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
                )
                fewshot_tab.select(
                    select_chat_type,
                    [fewshot_tab_label, app_session],
                    [app_session]
                )
                fewshot_tab.select(  # do clear
                    clear,
                    [txt_message, chat_bot, app_session],
                    [txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
                )
                chat_bot.flushed(
                    flushed,
                    outputs=[txt_message]
                )
                regenerate.click(
                    regenerate_button_clicked,
                    [txt_message, image_input, user_message, assistant_message, chat_bot, app_session, params_form],
                    [txt_message, image_input, user_message, assistant_message, chat_bot, app_session]
                )
                clear_button.click(
                    clear,
                    [txt_message, chat_bot, app_session],
                    [txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
                )

    with gr.Tab("How to use"):
        with gr.Column():
            with gr.Row():
                image_example = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/m_bear2.gif", label='1. Chat with single or multiple images', interactive=False, width=400, elem_classes="example")
                example2 = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/video2.gif", label='2. Chat with video', interactive=False, width=400, elem_classes="example")
                example3 = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/fshot.gif", label='3. Few shot', interactive=False, width=400, elem_classes="example")


# launch
demo.launch(share=False, debug=True, show_api=False, server_port=8885, server_name="0.0.0.0")
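
The frame selection in `encode_video` above first takes roughly one frame per second, then thins the candidates uniformly once they exceed `MAX_NUM_FRAMES`. A self-contained worked example for a hypothetical 30 fps, 80-second clip (2400 frames):

def uniform_sample(l, n):
    gap = len(l) / n
    idxs = [int(i * gap + gap / 2) for i in range(n)]  # pick the midpoint of each of n equal buckets
    return [l[i] for i in idxs]

fps, num_frames, MAX_NUM_FRAMES = 30, 2400, 64
frame_idx = list(range(0, num_frames, round(fps)))  # 80 one-per-second candidate frames
if len(frame_idx) > MAX_NUM_FRAMES:
    frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
print(len(frame_idx), frame_idx[:5])  # 64 [0, 30, 90, 120, 150]
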
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_streamlit-2_5.py
ADDED
@@ -0,0 +1,109 @@

import streamlit as st
from PIL import Image
import torch
from transformers import AutoModel, AutoTokenizer

# Model path
model_path = "openbmb/MiniCPM-Llama3-V-2_5"

# User and assistant names
U_NAME = "User"
A_NAME = "Assistant"

# Set page configuration
st.set_page_config(
    page_title="MiniCPM-Llama3-V-2_5 Streamlit",
    page_icon=":robot:",
    layout="wide"
)


# Load model and tokenizer
@st.cache_resource
def load_model_and_tokenizer():
    print(f"load_model_and_tokenizer from {model_path}")
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16).to(device="cuda")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    return model, tokenizer


# Initialize model and tokenizer in session state
if 'model' not in st.session_state:
    st.session_state.model, st.session_state.tokenizer = load_model_and_tokenizer()
    st.session_state.model.eval()
    print("model and tokenizer loaded successfully!")

# Initialize chat history in session state
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

# Sidebar settings
sidebar_name = st.sidebar.title("MiniCPM-Llama3-V-2_5 Streamlit")
max_length = st.sidebar.slider("max_length", 0, 4096, 2048, step=2)
repetition_penalty = st.sidebar.slider("repetition_penalty", 0.0, 2.0, 1.05, step=0.01)
top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
top_k = st.sidebar.slider("top_k", 0, 100, 100, step=1)
temperature = st.sidebar.slider("temperature", 0.0, 1.0, 0.7, step=0.01)

# Clear chat history button
buttonClean = st.sidebar.button("Clear chat history", key="clean")
if buttonClean:
    st.session_state.chat_history = []
    st.session_state.response = ""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    st.rerun()

# Display chat history
for i, message in enumerate(st.session_state.chat_history):
    if message["role"] == "user":
        with st.chat_message(name="user", avatar="user"):
            if message["image"] is not None:
                st.image(message["image"], caption='User uploaded image', width=448, use_column_width=False)
                continue
            elif message["content"] is not None:
                st.markdown(message["content"])
    else:
        with st.chat_message(name="model", avatar="assistant"):
            st.markdown(message["content"])

# Select mode
selected_mode = st.sidebar.selectbox("Select mode", ["Text", "Image"])
if selected_mode == "Image":
    # Image mode
    uploaded_image = st.sidebar.file_uploader("Upload image", key=1, type=["jpg", "jpeg", "png"],
                                              accept_multiple_files=False)
    if uploaded_image is not None:
        st.image(uploaded_image, caption='User uploaded image', width=468, use_column_width=False)
        # Add uploaded image to chat history
        st.session_state.chat_history.append({"role": "user", "content": None, "image": uploaded_image})

# User input box
user_text = st.chat_input("Enter your question")
if user_text:
    with st.chat_message(U_NAME, avatar="user"):
        st.session_state.chat_history.append({"role": "user", "content": user_text, "image": None})
        st.markdown(f"{U_NAME}: {user_text}")

    # Generate reply using the model
    model = st.session_state.model
    tokenizer = st.session_state.tokenizer
    imagefile = None

    with st.chat_message(A_NAME, avatar="assistant"):
        # If the previous message contains an image, pass the image to the model
        if len(st.session_state.chat_history) > 1 and st.session_state.chat_history[-2]["image"] is not None:
            uploaded_image = st.session_state.chat_history[-2]["image"]
            imagefile = Image.open(uploaded_image).convert('RGB')

        msgs = [{"role": "user", "content": user_text}]
        res = model.chat(image=imagefile, msgs=msgs, context=None, tokenizer=tokenizer,
                         sampling=True, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty,
                         temperature=temperature, stream=True)

        # Collect the generated text from the stream
        generated_text = st.write_stream(res)

    st.session_state.chat_history.append({"role": "model", "content": generated_text, "image": None})

    st.divider()
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_streamlit-minicpmv2_6.py
ADDED
@@ -0,0 +1,271 @@
import os.path

import streamlit as st
import torch
from PIL import Image
from decord import VideoReader, cpu
import numpy as np
from transformers import AutoModel, AutoTokenizer

# Model path
model_path = "openbmb/MiniCPM-V-2_6"
upload_path = ".\\uploads"

# User and assistant names
U_NAME = "User"
A_NAME = "Assistant"

# Set page configuration
st.set_page_config(
    page_title="MiniCPM-V-2_6 Streamlit",
    page_icon=":robot:",
    layout="wide"
)


# Load model and tokenizer
@st.cache_resource
def load_model_and_tokenizer():
    print(f"load_model_and_tokenizer from {model_path}")
    model = (AutoModel.from_pretrained(model_path, trust_remote_code=True, attn_implementation='sdpa')
             .to(dtype=torch.bfloat16))
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    return model, tokenizer
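# Note (added): attn_implementation='sdpa' selects PyTorch's built-in
# scaled_dot_product_attention kernels, and casting to bfloat16 roughly halves
# weight memory versus float32 while keeping the same exponent range.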


# Initialize model and tokenizer in session state
if 'model' not in st.session_state:
    st.session_state.model, st.session_state.tokenizer = load_model_and_tokenizer()
    st.session_state.model.eval().cuda()
    print("Model and tokenizer loaded successfully!")

# Initialize chat history and upload bookkeeping in session state
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []
    st.session_state.uploaded_image_list = []
    st.session_state.uploaded_image_num = 0
    st.session_state.uploaded_video_list = []
    st.session_state.uploaded_video_num = 0
    st.session_state.response = ""

# Sidebar settings
sidebar_name = st.sidebar.title("MiniCPM-V-2_6 Streamlit")
max_length = st.sidebar.slider("max_length", 0, 4096, 2048, step=2)
repetition_penalty = st.sidebar.slider("repetition_penalty", 0.0, 2.0, 1.05, step=0.01)
top_k = st.sidebar.slider("top_k", 0, 100, 100, step=1)
top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
temperature = st.sidebar.slider("temperature", 0.0, 1.0, 0.7, step=0.01)

# Button to clear session history
buttonClean = st.sidebar.button("Clear session history", key="clean")
if buttonClean:
    # Reset the session state history and uploaded file lists
    st.session_state.chat_history = []
    st.session_state.uploaded_image_list = []
    st.session_state.uploaded_image_num = 0
    st.session_state.uploaded_video_list = []
    st.session_state.uploaded_video_num = 0
    st.session_state.response = ""

    # If using GPU, clear the CUDA cache to free up memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Rerun to refresh the interface
    st.rerun()

# Display chat history
for i, message in enumerate(st.session_state.chat_history):
    if message["role"] == "user":
        with st.chat_message(name="user", avatar="user"):
            if message["image"] is not None:
                st.image(message["image"], caption='User uploaded images', width=512, use_column_width=False)
                continue
            elif message["video"] is not None:
                st.video(message["video"], format="video/mp4", loop=False, autoplay=False, muted=True)
                continue
            elif message["content"] is not None:
                st.markdown(message["content"])
    else:
        with st.chat_message(name="model", avatar="assistant"):
            st.markdown(message["content"])

# Select mode
selected_mode = st.sidebar.selectbox("Select Mode", ["Text", "Single Image", "Multiple Images", "Video"])

# Supported image file extensions
image_type = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']

if selected_mode == "Single Image":
    # Single Image Mode
    uploaded_image = st.sidebar.file_uploader("Upload a Single Image", key=1, type=image_type,
                                              accept_multiple_files=False)
    if uploaded_image is not None:
        st.image(uploaded_image, caption='User Uploaded Image', width=512, use_column_width=False)
        # Add the uploaded image to the chat history
        st.session_state.chat_history.append({"role": "user", "content": None, "image": uploaded_image, "video": None})
        st.session_state.uploaded_image_list = [uploaded_image]
        st.session_state.uploaded_image_num = 1

if selected_mode == "Multiple Images":
    # Multiple Images Mode
    uploaded_image_list = st.sidebar.file_uploader("Upload Multiple Images", key=2, type=image_type,
                                                   accept_multiple_files=True)
    uploaded_image_num = len(uploaded_image_list)

    if uploaded_image_list is not None and uploaded_image_num > 0:
        for img in uploaded_image_list:
            st.image(img, caption='User Uploaded Image', width=512, use_column_width=False)
            # Add the uploaded images to the chat history
            st.session_state.chat_history.append({"role": "user", "content": None, "image": img, "video": None})
        # Update the uploaded image list and count in st.session_state
        st.session_state.uploaded_image_list = uploaded_image_list
        st.session_state.uploaded_image_num = uploaded_image_num

# Supported video format suffixes
video_type = ['.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v']

# Tip: You can use the command `streamlit run ./web_demo_streamlit-minicpmv2_6.py --server.maxUploadSize 1024`
# to raise the maximum upload size to 1024 MB for larger files.
# Streamlit's default 200 MB file_uploader limit can be insufficient for video-based interaction.
# Adjust the size based on your GPU memory usage.

if selected_mode == "Video":
    # Single-video mode
    uploaded_video = st.sidebar.file_uploader("Upload a single video file", key=3, type=video_type,
                                              accept_multiple_files=False)
    if uploaded_video is not None:
        st.video(uploaded_video, format="video/mp4", loop=False, autoplay=False, muted=True)
        st.session_state.chat_history.append({"role": "user", "content": None, "image": None, "video": uploaded_video})

        uploaded_video_path = os.path.join(upload_path, uploaded_video.name)
        with open(uploaded_video_path, "wb") as vf:
            vf.write(uploaded_video.getvalue())
        st.session_state.uploaded_video_list = [uploaded_video_path]
        st.session_state.uploaded_video_num = 1

MAX_NUM_FRAMES = 64  # If CUDA runs out of memory, set a smaller number


# Encode a video by sampling frames at a fixed rate and converting them to PIL images.
def encode_video(video_path):
    def uniform_sample(frame_indices, num_samples):
        # Calculate the sampling interval and uniformly sample frame indices
        gap = len(frame_indices) / num_samples
        sampled_idxs = np.linspace(gap / 2, len(frame_indices) - gap / 2, num_samples, dtype=int)
        return [frame_indices[i] for i in sampled_idxs]
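        # Worked example (added, illustrative): with 100 candidate indices and
        # num_samples=10, gap is 10.0, and np.linspace(5.0, 95.0, 10) truncated
        # to int picks indices 5, 15, ..., 95 -- one frame from the middle of
        # each of the 10 equal-width buckets.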

    # Read the video and set the decoder's context to CPU
    vr = VideoReader(video_path, ctx=cpu(0))

    # Calculate the sampling interval to sample video frames at 1 FPS
    sample_fps = round(vr.get_avg_fps() / 1)  # Use integer FPS
    frame_idx = list(range(0, len(vr), sample_fps))

    # If the number of sampled frames exceeds the maximum limit, uniformly subsample them
    if len(frame_idx) > MAX_NUM_FRAMES:
        frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)

    # Retrieve the sampled frames and convert them to image arrays
    frames = vr.get_batch(frame_idx).asnumpy()
    frames = [Image.fromarray(frame.astype('uint8')) for frame in frames]

    print('Number of frames:', len(frames))
    return frames
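
# Usage sketch (added; the clip name is hypothetical): encoding a saved upload
# returns a list of PIL images sampled at roughly 1 FPS, capped at
# MAX_NUM_FRAMES, ready to be placed in a model.chat content list, e.g.
# frames = encode_video(os.path.join(upload_path, "example.mp4"))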


# User input box
user_text = st.chat_input("Enter your question")
if user_text is not None:
    if user_text.strip() == "":
        st.warning('Input message cannot be empty!', icon="⚠️")
    else:
        # Display user input and save it to the session history
        with st.chat_message(U_NAME, avatar="user"):
            st.session_state.chat_history.append({
                "role": "user",
                "content": user_text,
                "image": None,
                "video": None
            })
            st.markdown(f"{U_NAME}: {user_text}")

        # Generate responses using the model
        model = st.session_state.model
        tokenizer = st.session_state.tokenizer
        content_list = []  # Store the content (text or images) that will be passed to the model
        imageFile = None

        with st.chat_message(A_NAME, avatar="assistant"):
            # Handle different inputs depending on the mode selected by the user
            if selected_mode == "Single Image":
                # Single-image mode: pass in the last uploaded image
                print("Single Image mode in use")
                if len(st.session_state.chat_history) > 1 and len(st.session_state.uploaded_image_list) >= 1:
                    uploaded_image = st.session_state.uploaded_image_list[-1]
                    if uploaded_image:
                        imageFile = Image.open(uploaded_image).convert('RGB')
                        content_list.append(imageFile)
                else:
                    print("Single Image mode: No image found")

            elif selected_mode == "Multiple Images":
                # Multi-image mode: pass in all images from the last upload
                print("Multiple Images mode in use")
                if len(st.session_state.chat_history) > 1 and st.session_state.uploaded_image_num >= 1:
                    for uploaded_image in st.session_state.uploaded_image_list:
                        imageFile = Image.open(uploaded_image).convert('RGB')
                        content_list.append(imageFile)
                else:
                    print("Multiple Images mode: No image found")

            elif selected_mode == "Video":
                # Video mode: pass in sampled frames of the uploaded video
                print("Video mode in use")
                if len(st.session_state.chat_history) > 1 and st.session_state.uploaded_video_num == 1:
                    uploaded_video_path = st.session_state.uploaded_video_list[-1]
                    if uploaded_video_path:
                        with st.spinner('Encoding your video, please wait...'):
                            frames = encode_video(uploaded_video_path)
                else:
                    print("Video mode: No video found")

            # Define model parameters
            params = {
                'sampling': True,
                'top_p': top_p,
                'top_k': top_k,
                'temperature': temperature,
                'repetition_penalty': repetition_penalty,
                "max_new_tokens": max_length,
                "stream": True
            }
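            # Note (added): these kwargs are forwarded verbatim to model.chat via
            # **params below; "stream": True makes the call return a generator
            # that st.write_stream can render incrementally.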

            # Set different input parameters depending on whether a video was uploaded
            if st.session_state.uploaded_video_num == 1 and selected_mode == "Video":
                msgs = [{"role": "user", "content": frames + [user_text]}]
                # Set decode params for video
                params["max_inp_length"] = 4352  # Set the maximum input length for video mode
                params["use_image_id"] = False  # Do not use image_id
                params["max_slice_nums"] = 1  # Use 1 if CUDA OOMs and video resolution > 448*448
            else:
                content_list.append(user_text)
                msgs = [{"role": "user", "content": content_list}]

            print("content_list:", content_list)  # debug
            print("params:", params)  # debug

            # Generate and display the model's responses
            with st.spinner('AI is thinking...'):
                response = model.chat(image=None, msgs=msgs, context=None, tokenizer=tokenizer, **params)
                st.session_state.response = st.write_stream(response)
            st.session_state.chat_history.append({
                "role": "model",
                "content": st.session_state.response,
                "image": None,
                "video": None
            })

st.divider()  # Add a separator to the interface
r1-a/response_generation/minicpm/MiniCPM-o/web_demos/web_demo_streamlit.py
ADDED
@@ -0,0 +1,99 @@
import streamlit as st
from PIL import Image
import torch
from transformers import AutoModel, AutoTokenizer

# Model path
model_path = "openbmb/MiniCPM-V-2"

# User and assistant names
U_NAME = "User"
A_NAME = "Assistant"

# Set page configuration
st.set_page_config(
    page_title="MiniCPM-V-2 Streamlit",
    page_icon=":robot:",
    layout="wide"
)


# Load model and tokenizer
@st.cache_resource
def load_model_and_tokenizer():
    print(f"load_model_and_tokenizer from {model_path}")
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(
        device="cuda:0", dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    return model, tokenizer


# Initialize model and tokenizer in session state
if 'model' not in st.session_state:
    st.session_state.model, st.session_state.tokenizer = load_model_and_tokenizer()
    print("Model and tokenizer loaded successfully!")

# Initialize chat history in session state
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

# Sidebar settings
sidebar_name = st.sidebar.title("MiniCPM-V-2 Streamlit")
max_length = st.sidebar.slider("max_length", 0, 4096, 2048, step=2)
top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
temperature = st.sidebar.slider("temperature", 0.0, 1.0, 0.7, step=0.01)

# Clear chat history button
buttonClean = st.sidebar.button("Clear chat history", key="clean")
if buttonClean:
    st.session_state.chat_history = []
    st.session_state.response = ""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    st.rerun()

# Display chat history
for i, message in enumerate(st.session_state.chat_history):
    if message["role"] == "user":
        with st.chat_message(name="user", avatar="user"):
            if message["image"] is not None:
                st.image(message["image"], caption='User uploaded image', width=468, use_column_width=False)
                continue
            elif message["content"] is not None:
                st.markdown(message["content"])
    else:
        with st.chat_message(name="model", avatar="assistant"):
            st.markdown(message["content"])

# Select mode
selected_mode = st.sidebar.selectbox("Select mode", ["Text", "Image"])
if selected_mode == "Image":
    # Image mode
    uploaded_image = st.sidebar.file_uploader("Upload image", key=1, type=["jpg", "jpeg", "png"], accept_multiple_files=False)
    if uploaded_image is not None:
        st.image(uploaded_image, caption='User uploaded image', width=468, use_column_width=False)
        # Add uploaded image to chat history
        st.session_state.chat_history.append({"role": "user", "content": None, "image": uploaded_image})

# User input box
user_text = st.chat_input("Enter your question")
if user_text:
    with st.chat_message(U_NAME, avatar="user"):
        st.session_state.chat_history.append({"role": "user", "content": user_text, "image": None})
        st.markdown(f"{U_NAME}: {user_text}")

    # Generate reply using the model
    model = st.session_state.model
    tokenizer = st.session_state.tokenizer
    imagefile = None  # Ensure imagefile is defined even when no image was uploaded

    with st.chat_message(A_NAME, avatar="assistant"):
        # If the previous message contains an image, pass the image to the model
        if len(st.session_state.chat_history) > 1 and st.session_state.chat_history[-2]["image"] is not None:
            uploaded_image = st.session_state.chat_history[-2]["image"]
            imagefile = Image.open(uploaded_image).convert('RGB')

        msgs = [{"role": "user", "content": user_text}]
        res, context, _ = model.chat(image=imagefile, msgs=msgs, context=None, tokenizer=tokenizer,
                                     sampling=True, top_p=top_p, temperature=temperature)
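        # Note (added): unlike the streaming 2.5/2.6 demos above, this call is
        # blocking: model.chat returns the finished reply plus an updated
        # context object, so the answer is rendered with a single st.markdown.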
        st.markdown(f"{A_NAME}: {res}")
        st.session_state.chat_history.append({"role": "model", "content": res, "image": None})

st.divider()