| | import torch |
| | import torch.nn.functional as F |
| | import torch.nn as nn |
| | from transformers import AutoConfig, AutoModelForCausalLM |
| | from typing import List, Optional, Tuple, Union |
| | from .modeling_qwen2 import Qwen2Model, Qwen2ForCausalLM |
| | from .configuration_qwen2 import Qwen2Config |
| |
|
| |
|
| | |
class GeoUniConfig(Qwen2Config):
    """Configuration for the Geo-Uni model.

    Extends ``Qwen2Config`` with vocabulary bookkeeping for a unified
    text + VQ-image token space:

    Args:
        vocab_size: Total size of the extended vocabulary (text + special + VQ).
        num_vq_tokens: Number of VQ tokens that encode one image.
        num_new_special_tokens: Count of special tokens added on top of the
            base LLM vocabulary.
        llm_vocab_size: Size of the original (text-only) LLM vocabulary.
        codebook_size: Size of the VQ codebook.
    """

    model_type = "geo-uni"

    def __init__(
        self,
        vocab_size=159864,
        num_vq_tokens=256,
        num_new_special_tokens=7,
        llm_vocab_size=151665,
        codebook_size=8192,
        **kwargs,
    ):
        # Initialize the Qwen2 base config first, then overwrite/extend
        # with the Geo-Uni specific vocabulary layout.
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.num_vq_tokens = num_vq_tokens
        self.num_new_special_tokens = num_new_special_tokens
        self.llm_vocab_size = llm_vocab_size
        self.codebook_size = codebook_size
| |
|
| |
|
class GeoUniModel(Qwen2Model):
    """Qwen2 transformer backbone wired to ``GeoUniConfig``.

    Behaviorally identical to ``Qwen2Model``; it exists so the Auto* machinery
    and the causal-LM wrapper resolve the Geo-Uni config class.
    """

    config_class = GeoUniConfig

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
| |
|
| |
|
class GeoUniForCausalLM(Qwen2ForCausalLM):
    """Geo-Uni causal LM: a Qwen2 LM head over a unified text + VQ-image vocabulary.

    Vocabulary layout (inferred from the offset arithmetic in
    ``t2i_generate``/``mix_generate`` — confirm against the tokenizer):
      [0, llm_vocab_size)                        original LLM text tokens
      [llm_vocab_size, llm_vocab_size + num_new_special_tokens)  added specials
      [llm_vocab_size + num_new_special_tokens, vocab_size)      VQ codebook ids
    """

    config_class = GeoUniConfig

    def __init__(self, config):
        # Deliberately skip Qwen2ForCausalLM.__init__ (it would build a plain
        # Qwen2Model) and initialize from its parent instead, so we can install
        # GeoUniModel and an lm_head sized for the extended vocabulary.
        super(Qwen2ForCausalLM, self).__init__(config)
        self.model = GeoUniModel(config)
        self.vocab_size = config.vocab_size
        self.num_vq_tokens = config.num_vq_tokens
        self.num_new_special_tokens = config.num_new_special_tokens
        self.llm_vocab_size = config.llm_vocab_size
        self.codebook_size = config.codebook_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Standard HF weight init / tying hook.
        self.post_init()

    def get_model(self):
        """Return the underlying ``GeoUniModel`` backbone."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        batch_size_t2i=0,
        batch_size_reasoning=0,
        batch_size_mixing=0,
    ):
        """Run the LM; with ``labels``, compute one loss per task sub-batch.

        Training batches are assumed to be concatenated along dim 0 as
        ``[t2i | reasoning | mixing]`` with the three ``batch_size_*``
        arguments giving the size of each segment.

        Returns:
            Without ``labels``: whatever the parent ``forward`` returns.
            With ``labels``: a tuple ``(logits, loss_t2i, loss_reasoning,
            loss_mixing)`` where each loss is a shifted next-token
            cross-entropy over its segment (``-100`` labels ignored).
        """
        # Fix: forward arguments by keyword. The upstream Qwen2 forward()
        # signature has gained parameters across transformers releases
        # (e.g. cache_position), so positional forwarding is brittle.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if labels is not None:
            # NOTE(review): assumes a ModelOutput return (return_dict truthy);
            # a tuple return from the parent would break `.logits` — confirm
            # callers never pass return_dict=False on the training path.
            logits = outputs.logits

            # NOTE(review): an empty segment (size 0) yields a NaN mean
            # cross-entropy; callers appear to guarantee non-empty segments
            # whenever labels are provided — confirm.
            loss_t2i = F.cross_entropy(
                logits[:batch_size_t2i, :-1].contiguous().view(-1, self.vocab_size),
                labels[:batch_size_t2i, 1:].contiguous().view(-1),
                ignore_index=-100,
            )
            loss_reasoning = F.cross_entropy(
                logits[batch_size_t2i:batch_size_t2i + batch_size_reasoning, :-1]
                .contiguous()
                .view(-1, self.vocab_size),
                labels[batch_size_t2i:batch_size_t2i + batch_size_reasoning, 1:]
                .contiguous()
                .view(-1),
                ignore_index=-100,
            )
            # Fix: the original sliced `[-batch_size_mixing:]`, which for
            # batch_size_mixing == 0 is `[-0:]` == the WHOLE batch, silently
            # computing the mixing loss over every sample. Index from the
            # front instead — identical when the three sizes partition the
            # batch, and an empty slice when the mixing segment is empty.
            mixing_start = batch_size_t2i + batch_size_reasoning
            loss_mixing = F.cross_entropy(
                logits[mixing_start:, :-1].contiguous().view(-1, self.vocab_size),
                labels[mixing_start:, 1:].contiguous().view(-1),
                ignore_index=-100,
            )

            return logits, loss_t2i, loss_reasoning, loss_mixing

        return outputs

    @torch.no_grad()
    def t2i_generate(
        self,
        input_ids: torch.LongTensor,
        pad_token_id=151665,
        temperature=1.0,
        attention_masks=None,
    ):
        """Greedily generate exactly ``num_vq_tokens`` image tokens.

        Returns:
            LongTensor of VQ codebook indices clamped to
            ``[0, codebook_size)``, shape ``(batch, num_vq_tokens)``.
        """
        generated_tokens = self.generate(
            input_ids=input_ids,
            max_new_tokens=self.num_vq_tokens,
            attention_mask=attention_masks,
            pad_token_id=pad_token_id,
            eos_token_id=None,  # never stop early: an image is a fixed-length block
            temperature=temperature,
            do_sample=False,  # greedy decoding; temperature is effectively inert here
            top_p=None,
            use_cache=True,
        )

        # Map vocabulary ids back to codebook indices; clamp any stray
        # text/special-token ids into the valid codebook range.
        new_tokens = generated_tokens[:, -self.num_vq_tokens:] - (
            self.llm_vocab_size + self.num_new_special_tokens
        )
        gen_token_ids = torch.clamp(new_tokens, max=self.codebook_size - 1, min=0)

        return gen_token_ids

    @torch.no_grad()
    def mix_generate(
        self,
        input_ids,
        max_new_tokens: int,
        temperature: float,
        pad_token_id: int,
        eos_token_id: int,
        soi_token_id: int,
        eoi_token_id: int,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Generate a mixed image + text continuation for a single prompt.

        Assumes the model emits ``<soi> [num_vq_tokens image ids] <eoi> text...``
        as the continuation (image at positions ``1 .. num_vq_tokens``, text
        from ``2 + num_vq_tokens`` on) — confirm against the training format;
        ``soi_token_id``/``eoi_token_id`` are accepted but not consulted here.

        Returns:
            ``(image_tokens, text_tokens)`` — codebook indices padded/clamped
            to ``(1, num_vq_tokens)``, and the trailing text token ids.
        """
        output_ids = self.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            do_sample=False,
            top_p=None,
            use_cache=True,
        )
        # Keep only the newly generated continuation.
        output_ids = output_ids[:, input_ids.size(1):]

        batch_size = output_ids.size(0)
        # NOTE(review): only single-prompt decoding is supported; `assert` is
        # stripped under -O, so this guard is best-effort only.
        assert batch_size == 1
        image_tokens = output_ids[:, 1:1 + self.num_vq_tokens]
        image_tokens = image_tokens - (self.llm_vocab_size + self.num_new_special_tokens)
        pad_length = self.num_vq_tokens - image_tokens.shape[1]

        # If generation stopped before a full image block, right-pad with
        # codebook index 0 so downstream decoding sees a fixed-size image.
        if pad_length > 0:
            padding = torch.zeros(
                (image_tokens.shape[0], pad_length),
                dtype=image_tokens.dtype,
                device=image_tokens.device,
            )
            image_tokens = torch.cat([image_tokens, padding], dim=1)
        image_tokens = torch.clamp(image_tokens, max=self.codebook_size - 1, min=0)
        text_tokens = output_ids[:, 2 + self.num_vq_tokens:]
        return image_tokens, text_tokens
| | |
# Register the custom config/model so the HF Auto* factories can resolve
# the "geo-uni" model_type (e.g. via AutoModelForCausalLM.from_pretrained).
AutoConfig.register("geo-uni", GeoUniConfig)
AutoModelForCausalLM.register(GeoUniConfig, GeoUniForCausalLM)