import dataclasses
import os
from enum import Enum, auto
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    CLIPImageProcessor,
    CLIPVisionModel,
    Qwen2Config,
    Qwen2ForCausalLM,
    Qwen2Model,
    StoppingCriteria,
    TextStreamer,
)
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast


class SeparatorStyle(Enum):
    """Different separator style."""

    SINGLE = auto()
    TWO = auto()
    MPT = auto()


def _expand2square(pil_img, background_color=(122, 116, 104)):
    """Pad ``pil_img`` onto a square canvas of side ``max(width, height)``.

    The original image is pasted at the top-left corner (not centred); the
    remaining area is filled with ``background_color``.
    """
    from PIL import Image

    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    result = Image.new(pil_img.mode, (side, side), background_color)
    result.paste(pil_img)
    return result


def _shrink_for_display(image, max_len=800, min_len=400):
    """Resize ``image`` for previews while keeping its aspect ratio.

    The short edge becomes ``int(min(max_len / aspect_ratio, min_len, short_edge))``,
    so the image is never upscaled and the long edge does not exceed ``max_len``.
    """
    max_hw, min_hw = max(image.size), min(image.size)
    aspect_ratio = max_hw / min_hw
    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
    longest_edge = int(shortest_edge * aspect_ratio)
    width, height = image.size
    if height > width:
        new_size = (shortest_edge, longest_edge)
    else:
        new_size = (longest_edge, shortest_edge)
    return image.resize(new_size)


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    ``messages`` is a list of ``[role, message]`` pairs. A message is a string,
    a falsy value (an open turn awaiting a reply), or a 3-tuple
    ``(text, image, image_process_mode)`` for turns that carry an image
    (the image is used through the PIL API).
    """

    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "<|im_end|>"
    sep2: Optional[str] = None
    version: str = "Unknown"
    skip_next: bool = False

    def get_prompt(self):
        """Render the conversation as one prompt string according to ``sep_style``.

        Image turns contribute only their text. A falsy message renders just the
        role header, leaving that turn open for the model to complete.

        Raises:
            ValueError: if ``sep_style`` is not a known :class:`SeparatorStyle`.
        """
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep + '\n'
            for role, message in self.messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        if self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        if self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep if self.system else ''
            for role, message in self.messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        """Append a ``[role, message]`` turn to the history."""
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        """Collect images attached to even-position turns (counted from ``offset``).

        Each image is preprocessed according to its ``image_process_mode``
        ("Pad", "Crop" or "Resize"). Returns the processed image objects when
        ``return_pil`` is true, otherwise base64-encoded JPEG strings.

        Raises:
            ValueError: for an unknown ``image_process_mode``.
        """
        import base64
        from io import BytesIO

        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 != 0 or type(msg) is not tuple:
                continue
            msg, image, image_process_mode = msg
            if image_process_mode == "Pad":
                image = _expand2square(image)
            elif image_process_mode == "Crop":
                image = _shrink_for_display(image)
            elif image_process_mode == "Resize":
                image = image.resize((224, 224))
            else:
                raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
            if return_pil:
                images.append(image)
            else:
                buffered = BytesIO()
                image.convert('RGB').save(buffered, format="JPEG")
                images.append(base64.b64encode(buffered.getvalue()).decode())
        return images

    def to_gradio_chatbot(self):
        """Convert the history into Gradio chatbot ``[first, second]`` pairs.

        Images on even-position turns are downscaled, JPEG-encoded and inlined as
        an HTML ``<img>`` tag in place of the ``DEFAULT_IMAGE_TOKEN`` placeholder.
        """
        import base64
        from io import BytesIO

        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, image, image_process_mode = msg
                    image = _shrink_for_display(image)
                    buffered = BytesIO()
                    # JPEG has no alpha channel; convert exactly like get_images() does.
                    image.convert('RGB').save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    msg = msg.replace(DEFAULT_IMAGE_TOKEN, img_str)
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a copy with an independent ``messages`` list.

        ``dataclasses.replace`` carries over every field (including ``version``
        and ``skip_next``, which a hand-written field list used to drop).
        """
        return dataclasses.replace(self, messages=[[x, y] for x, y in self.messages])

    def dict(self):
        """Return a serialisable summary of the conversation.

        When any even-position turn (from ``offset``) carries an image, every
        tuple message is reduced to its text so the result holds no image objects.
        """
        has_images = any(
            type(msg) is tuple
            for i, (_, msg) in enumerate(self.messages[self.offset:])
            if i % 2 == 0
        )
        if has_images:
            messages = [[x, y[0] if type(y) is tuple else y] for x, y in self.messages]
        else:
            messages = self.messages
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


# ChatML-style template; ``messages`` is a tuple so the shared template cannot be
# mutated by accident — callers use ``.copy()`` to get a mutable list.
conv_mpt = Conversation(
    system="""<|im_start|>system
You should follow the instructions carefully and explain your answers in detail.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

conv_templates = {
    "mpt": conv_mpt,
}


class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any keyword appears after the prompt.

    Single-token keywords are matched against the last generated id of the first
    sequence; all keywords are also searched in the decoded continuation (special
    tokens skipped).
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
        self.keyword_ids = [ids[0] for ids in keyword_ids if type(ids) is list and len(ids) == 1]
        self.tokenizer = tokenizer
        self.input_ids = input_ids
        # Prompt length, fixed up front so the very first generated token is checked too.
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for keyword_id in self.keyword_ids:
            if output_ids[0, -1] == keyword_id:
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
        return any(keyword in outputs for keyword in self.keywords)


# Placeholder tokens marking the latent span in prompts and contexts. The start and
# end markers must be distinct, non-empty tokens: the model locates the span by
# comparing token ids against config.im_start_token / config.im_end_token.
# NOTE(review): values follow the GOT-OCR convention (<image>, <imgpad>, <img>, </img>);
# confirm they match the special tokens registered in the tokenizer.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'


class WeaverConfig(Qwen2Config):
    """Qwen2 config registered under the ``"Weaver"`` model type.

    The model additionally reads ``latent_token_len`` and
    ``contexts_compression_llm_hidden_size`` from the config.
    """

    model_type = "Weaver"
class MemoryCompressor(nn.Module):
    """Weaver side: compress N latent tokens into a single memory token.

    ``(B, N, H) -> (B, 1, H)``: the N token vectors are flattened and mapped to one
    H-dimensional vector by a single bias-free linear layer, so the input must have
    exactly ``num_latent_tokens`` tokens of width ``hidden_size``.
    """

    def __init__(self, hidden_size, num_latent_tokens=32):
        super().__init__()
        self.num_tokens = num_latent_tokens
        self.hidden_size = hidden_size
        self.input_dim = num_latent_tokens * hidden_size
        # Compress: N*H -> H
        self.compressor = nn.Linear(self.input_dim, hidden_size, bias=False)

    def forward(self, x):
        """Compress ``x`` of shape ``(B, N, H)`` to ``(B, 1, H)``.

        Raises:
            ValueError: if ``(N, H)`` differs from ``(num_latent_tokens, hidden_size)``.
        """
        B, N, H = x.shape
        if (N, H) != (self.num_tokens, self.hidden_size):
            raise ValueError(
                f"MemoryCompressor expects input of shape (B, {self.num_tokens}, {self.hidden_size}), "
                f"got {tuple(x.shape)}"
            )
        # reshape (not view) also accepts non-contiguous inputs.
        x_flat = x.reshape(B, N * H)  # (B, N*H)
        return self.compressor(x_flat).unsqueeze(1)  # (B, 1, H)


class MemoryDecompressor(nn.Module):
    """Reasoner side: expand one memory token back into N tokens.

    ``(B, 1, H) -> (B, N, H)`` via a SwiGLU-style gated up-projection,
    ``silu(up_gate(x)) * up_val(x)``, where both projections map H -> N*H.
    """

    def __init__(self, hidden_size, num_latent_tokens=32):
        super().__init__()
        self.num_tokens = num_latent_tokens
        self.hidden_size = hidden_size
        self.output_dim = num_latent_tokens * hidden_size
        # Decompress (SwiGLU style): H -> N*H
        self.up_gate = nn.Linear(hidden_size, self.output_dim, bias=False)
        self.up_val = nn.Linear(hidden_size, self.output_dim, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        """Expand ``x`` of shape ``(B, 1, H)`` to ``(B, N, H)``."""
        B = x.shape[0]
        x_squeezed = x.squeeze(1)  # (B, H)
        gate = self.act_fn(self.up_gate(x_squeezed))
        val = self.up_val(x_squeezed)
        out_flat = gate * val  # (B, N*H)
        return out_flat.view(B, self.num_tokens, self.hidden_size)  # (B, N, H)


class WeaverQwenModel(Qwen2Model):
    """Qwen2 decoder whose prompt latent span is filled from a separate "weaver" model.

    On a multi-token forward pass (or any pass in training mode) with ``context_ids``:

    1. The span between the start/end markers in ``context_ids`` is overwritten
       with the learnable queries ``self.Q`` and run through the weaver; the last
       entry of the weaver's ``hidden_states`` is read back at those positions.
    2. Those states are projected by ``mm_projector``, squeezed through a one-token
       bottleneck (``memory_compressor`` -> ``memory_decompressor``) and written into
       the matching span of the embeddings of ``input_ids``.

    Single-token passes outside training skip this step.
    """

    config_class = WeaverConfig

    def __init__(self, config: Qwen2Config):
        super(WeaverQwenModel, self).__init__(config)
        # Learnable queries concatenated with the weaver's token embeddings, so their
        # width must equal the weaver's embedding size.
        self.Q = nn.Embedding(config.latent_token_len, config.contexts_compression_llm_hidden_size)
        # Maps weaver hidden states to this model's hidden size.
        self.mm_projector = nn.Linear(config.contexts_compression_llm_hidden_size, config.hidden_size)
        # Attached after construction by WeaverQwenForCausalLM.from_pretrained.
        self.weaver = None
        self.config.use_im_start_end = True
        self.memory_compressor = MemoryCompressor(
            hidden_size=config.hidden_size,
            num_latent_tokens=config.latent_token_len,
        )
        self.memory_decompressor = MemoryDecompressor(
            hidden_size=config.hidden_size,
            num_latent_tokens=config.latent_token_len,
        )

    def _encode_contexts(self, context_ids, context_attention_mask, output_attentions):
        """Run the weaver over ``context_ids`` with the latent span replaced by ``self.Q``.

        Returns:
            A ``(B, N, H_weaver)`` tensor of weaver hidden states at the latent span,
            where ``N = latent_token_len``.

        Raises:
            ValueError: if a context does not contain exactly one start marker, or the
                end marker does not follow the N latent slots.
        """
        im_start_token = getattr(self.config, "im_start_token", -1)
        im_end_token = getattr(self.config, "im_end_token", -1)
        queries = self.Q.weight
        num_latent = queries.shape[0]

        context_embeds = self.weaver.model.embed_tokens(context_ids)
        new_context_embeds = []
        start_positions = []
        for cur_ids, cur_embeds in zip(context_ids, context_embeds):
            starts = torch.where(cur_ids == im_start_token)[0]
            if starts.numel() != 1:
                raise ValueError(
                    f"Each context must contain exactly one image start token, found {starts.numel()}."
                )
            pos = int(starts[0])
            end_pos = pos + num_latent + 1
            if end_pos >= cur_ids.shape[0] or cur_ids[end_pos] != im_end_token:
                raise ValueError("The image end token should follow the image start token.")
            # Match the weaver embeddings' device and dtype so torch.cat does not promote.
            latent_slots = queries.to(device=cur_embeds.device, dtype=cur_embeds.dtype)
            new_context_embeds.append(
                torch.cat((cur_embeds[:pos + 1], latent_slots, cur_embeds[end_pos:]), dim=0)
            )
            start_positions.append(pos)

        weaver_outputs = self.weaver(
            input_ids=None,
            attention_mask=context_attention_mask,
            past_key_values=None,
            inputs_embeds=torch.stack(new_context_embeds, dim=0),
            # The weaver is run once per prompt; its cache is never reused.
            use_cache=False,
            position_ids=None,
            output_attentions=output_attentions,
            output_hidden_states=True,
            # Always a ModelOutput so ``hidden_states`` is addressable by name.
            return_dict=True,
        )
        last_hidden = weaver_outputs.hidden_states[-1]
        return torch.stack(
            [h[pos + 1:pos + 1 + num_latent] for h, pos in zip(last_hidden, start_positions)],
            dim=0,
        )

    def _inject_latents(self, input_ids, inputs_embeds, latent_features):
        """Write ``latent_features[i]`` into the first latent span of ``input_ids[i]``.

        Samples without a start marker are left unchanged.

        Raises:
            ValueError: if start/end marker counts differ, or the end marker does not
                follow the latent slots.
        """
        im_start_token = getattr(self.config, "im_start_token", -1)
        im_end_token = getattr(self.config, "im_end_token", -1)

        new_input_embeds = []
        for cur_ids, cur_embeds, cur_latent in zip(input_ids, inputs_embeds, latent_features):
            if (cur_ids == im_start_token).sum() != (cur_ids == im_end_token).sum():
                raise ValueError("The number of image start tokens and image end tokens should be the same.")
            starts = torch.where(cur_ids == im_start_token)[0]
            # One compressed context per sample, so only the first span is filled.
            for pos in starts[:1]:
                pos = int(pos)
                num_latent = cur_latent.shape[0]
                end_pos = pos + num_latent + 1
                if end_pos >= cur_ids.shape[0] or cur_ids[end_pos] != im_end_token:
                    raise ValueError("The image end token should follow the image start token.")
                latent = cur_latent.to(device=cur_embeds.device, dtype=cur_embeds.dtype)
                cur_embeds = torch.cat((cur_embeds[:pos + 1], latent, cur_embeds[end_pos:]), dim=0)
            new_input_embeds.append(cur_embeds)
        return torch.stack(new_input_embeds, dim=0)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        context_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        context_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed ``input_ids``, inject compressed context latents, then run Qwen2."""
        # HACK: replace back original embeddings for LLaVA pretraining
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)
        if orig_embeds_params is not None:
            with torch.no_grad():
                self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Latents are only injected on multi-token passes (or in training); the
        # weaver is not run at all otherwise.
        needs_latents = (
            context_ids is not None
            and input_ids is not None
            and (input_ids.shape[1] != 1 or self.training)
        )
        if needs_latents:
            latent_contexts = self._encode_contexts(context_ids, context_attention_mask, output_attentions)
            aligned = self.mm_projector(latent_contexts)  # (B, N, H)
            # One-token memory bottleneck: (B, N, H) -> (B, 1, H) -> (B, N, H).
            latent_features = self.memory_decompressor(self.memory_compressor(aligned))
            inputs_embeds = self._inject_latents(input_ids, inputs_embeds, latent_features)

        return super(WeaverQwenModel, self).forward(
            input_ids=None,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class WeaverQwenForCausalLM(Qwen2ForCausalLM):
    """Causal LM head on top of :class:`WeaverQwenModel`, plus loading and chat helpers."""

    config_class = WeaverConfig

    def __init__(self, config):
        # Call the grandparent initializer so Qwen2ForCausalLM.__init__ does not build
        # a plain Qwen2Model; the Weaver model is installed instead.
        super(Qwen2ForCausalLM, self).__init__(config)
        self.model = WeaverQwenModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        context_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        context_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the Weaver model and LM head; compute shifted cross-entropy when ``labels`` is given."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            context_ids=context_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            context_attention_mask=context_attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states).float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            # Use the actual logits width so the reshape stays correct even if the
            # config's vocab_size and the LM head ever disagree.
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            # Enable model parallelism
            shift_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Trim already-cached tokens and forward the context inputs on every generation step."""
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_length = past_key_values.get_seq_length()
            else:
                past_length = past_key_values[0][0].shape[2]

            # Keep only the unprocessed tokens:
            # 1 - attention_mask longer than input_ids: some inputs exist only in the
            #     cache (e.g. inputs_embeds were passed), so keep the unprocessed tail.
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # 2 - input_ids holds all tokens: drop the part already in the cache.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - otherwise assume input_ids only has unprocessed tokens.

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "context_ids": kwargs.get("context_ids", None),
                # Previously dropped, so a padded batch of contexts ignored its mask.
                "context_attention_mask": kwargs.get("context_attention_mask", None),
            }
        )
        return model_inputs

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *model_args,
        **kwargs,
    ):
        """Load the model, then load and attach the weaver from the ``weaver`` subfolder.

        The weaver defaults to ``torch.float16``, ``device_map="auto"`` and safetensors
        unless ``torch_dtype`` / ``device_map`` / ``use_safetensors`` are given. Hub
        options (``cache_dir``, ``revision``, ``token``, ``local_files_only``) are
        forwarded to the weaver load when supplied.
        """
        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        weaver_kwargs = {
            "use_safetensors": kwargs.get("use_safetensors", True),
            "torch_dtype": kwargs.get("torch_dtype", torch.float16),
            "device_map": kwargs.get("device_map", "auto"),
        }
        for key in ("cache_dir", "revision", "token", "local_files_only"):
            if key in kwargs:
                weaver_kwargs[key] = kwargs[key]

        if os.path.exists(pretrained_model_name_or_path):
            weaver_path = os.path.join(pretrained_model_name_or_path, "weaver")
            print(f"Loading weaver from path: {weaver_path}")
            weaver = Qwen2ForCausalLM.from_pretrained(weaver_path, **weaver_kwargs)
        else:
            print("Loading weaver from HF")
            weaver = Qwen2ForCausalLM.from_pretrained(
                pretrained_model_name_or_path, subfolder="weaver", **weaver_kwargs
            )

        model.model.weaver = weaver
        print("Successfully loaded and attached weaver.")
        return model

    def initialize_special_tokenizer(self, tokenizer, device="cuda"):
        """Resize embeddings to ``len(tokenizer)`` and record the marker token ids in the config.

        ``device`` is unused; it is kept for backward compatibility.
        """
        config = self.get_model().config
        self.resize_token_embeddings(len(tokenizer))
        config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
        config.use_im_start_end = True
        config.im_start_token, config.im_end_token = tokenizer.convert_tokens_to_ids(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
        )

    def chat(self, tokenizer, context, prompt):
        """Answer ``prompt`` about ``context`` with greedy decoding, streaming via ``TextStreamer``.

        A latent span (start marker, ``latent_token_len`` patch tokens, end marker)
        is appended to ``context`` and prepended to ``prompt``; the model fills the
        prompt's span with compressed weaver features of the context.
        """
        self.initialize_special_tokenizer(tokenizer)
        latent_len = self.get_model().config.latent_token_len
        latent_span = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * latent_len + DEFAULT_IM_END_TOKEN
        qs = latent_span + '\n' + prompt
        context = context + latent_span

        conv = conv_templates["mpt"].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        # Follow the model's device instead of hard-coding CUDA.
        device = self.device
        input_ids = torch.as_tensor(tokenizer([prompt]).input_ids, device=device)
        inputs_context_ids = torch.as_tensor(tokenizer([context]).input_ids, device=device)

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        with torch.autocast(device.type, dtype=torch.bfloat16):
            output_ids = self.generate(
                input_ids,
                context_ids=inputs_context_ids,
                do_sample=False,
                num_beams=1,
                no_repeat_ngram_size=20,
                streamer=streamer,
                max_new_tokens=4096,
                stopping_criteria=[stopping_criteria],
            )

        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        return outputs.strip()


AutoConfig.register("Weaver", WeaverConfig)
AutoModelForCausalLM.register(WeaverConfig, WeaverQwenForCausalLM)