import torch
import torch.nn.functional as F
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM
from typing import List, Optional, Tuple, Union
from .modeling_qwen2 import Qwen2Model, Qwen2ForCausalLM
from .configuration_qwen2 import Qwen2Config
# Custom config for GeoUni
class GeoUniConfig(Qwen2Config):
    """Configuration for GeoUni, a Qwen2-based unified geometry model.

    Adds vocabulary bookkeeping on top of ``Qwen2Config`` so generated
    token ids can be mapped onto VQ-GAN codebook indices:
    the full vocabulary is laid out as
    ``llm_vocab_size`` text tokens + ``num_new_special_tokens`` specials
    + ``codebook_size`` image-code tokens.
    """

    model_type = "geo-uni"

    def __init__(
        self,
        vocab_size=159864,
        num_vq_tokens=256,
        num_new_special_tokens=7,
        llm_vocab_size=151665,
        codebook_size=8192,
        **kwargs,
    ):
        # Let the parent consume every remaining Qwen2 option first, then
        # override/record the GeoUni-specific fields.
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.num_vq_tokens = num_vq_tokens
        self.num_new_special_tokens = num_new_special_tokens
        self.llm_vocab_size = llm_vocab_size
        self.codebook_size = codebook_size
class GeoUniModel(Qwen2Model):
    """Backbone transformer for GeoUni.

    Identical to ``Qwen2Model``; it only pins ``config_class`` so the
    HuggingFace auto-loading machinery resolves ``GeoUniConfig``.
    """

    config_class = GeoUniConfig

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
class GeoUniForCausalLM(Qwen2ForCausalLM):
    """Causal LM head for GeoUni.

    Wraps a :class:`GeoUniModel` backbone with an LM head over the extended
    vocabulary (text tokens + new special tokens + VQ codebook tokens) and
    provides task-split training losses plus two inference helpers:
    ``t2i_generate`` (text-to-image token generation) and ``mix_generate``
    (interleaved image + text generation).
    """

    config_class = GeoUniConfig

    def __init__(self, config):
        # Deliberately skip Qwen2ForCausalLM.__init__ (note the explicit
        # super(Qwen2ForCausalLM, self)) so we can install our own backbone
        # and an LM head sized to the extended vocabulary.
        super(Qwen2ForCausalLM, self).__init__(config)
        self.model = GeoUniModel(config)
        self.vocab_size = config.vocab_size
        self.num_vq_tokens = config.num_vq_tokens
        self.num_new_special_tokens = config.num_new_special_tokens
        self.llm_vocab_size = config.llm_vocab_size
        self.codebook_size = config.codebook_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing.
        self.post_init()

    def get_model(self):
        """Return the underlying transformer backbone."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        batch_size_t2i=0,
        batch_size_reasoning=0,
        batch_size_mixing=0,
    ):
        """Run the LM; when ``labels`` are given, split the batch into three
        contiguous task groups laid out as ``[t2i | reasoning | mixing]``
        along dim 0 and return a per-task next-token cross-entropy loss.

        Returns:
            With ``labels``: ``(logits, loss_t2i, loss_reasoning, loss_mixing)``.
            Without ``labels``: the parent model's output object.
        """
        # Keyword arguments guard against signature drift in the parent class.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if labels is not None:
            logits = outputs.logits
            # End of the reasoning group == start of the mixing group.
            reasoning_end = batch_size_t2i + batch_size_reasoning
            # Standard next-token shift: predict labels[t+1] from logits[t].
            loss_t2i = F.cross_entropy(
                logits[:batch_size_t2i, :-1].contiguous().view(-1, self.vocab_size),
                labels[:batch_size_t2i, 1:].contiguous().view(-1),
                ignore_index=-100,
            )
            loss_reasoning = F.cross_entropy(
                logits[batch_size_t2i:reasoning_end, :-1].contiguous().view(-1, self.vocab_size),
                labels[batch_size_t2i:reasoning_end, 1:].contiguous().view(-1),
                ignore_index=-100,
            )
            # BUG FIX: the previous `logits[-batch_size_mixing:]` selected the
            # WHOLE batch when batch_size_mixing == 0 (x[-0:] == x[0:]).
            # Slicing from the explicit start index is equivalent when the
            # three group sizes partition the batch, and correct for size 0.
            loss_mixing = F.cross_entropy(
                logits[reasoning_end:, :-1].contiguous().view(-1, self.vocab_size),
                labels[reasoning_end:, 1:].contiguous().view(-1),
                ignore_index=-100,
            )
            return logits, loss_t2i, loss_reasoning, loss_mixing
        return outputs

    @torch.no_grad()
    def t2i_generate(
        self,
        input_ids: torch.LongTensor,
        pad_token_id=151665,
        temperature=1.0,
        attention_masks=None,
    ):
        """Greedily generate exactly ``num_vq_tokens`` image tokens and map
        them into VQ-GAN codebook index space.

        Returns:
            LongTensor of shape ``(batch, num_vq_tokens)`` with values in
            ``[0, codebook_size)``.
        """
        # Generate num_vq_tokens new tokens (greedy: do_sample=False, no EOS).
        generated_tokens = self.generate(
            input_ids=input_ids,
            max_new_tokens=self.num_vq_tokens,
            attention_mask=attention_masks,
            pad_token_id=pad_token_id,
            eos_token_id=None,
            temperature=temperature,
            do_sample=False,
            top_p=None,
            use_cache=True,
        )
        # Shift ids from the extended LM vocabulary into codebook space;
        # clamp so out-of-range ids cannot crash the VQ-GAN decoder.
        new_tokens = generated_tokens[:, -self.num_vq_tokens:] - (
            self.llm_vocab_size + self.num_new_special_tokens
        )
        gen_token_ids = torch.clamp(new_tokens, max=self.codebook_size - 1, min=0)
        return gen_token_ids

    @torch.no_grad()
    def mix_generate(self,
                     input_ids,
                     max_new_tokens: int,
                     temperature: float,
                     pad_token_id: int,
                     eos_token_id: int,
                     soi_token_id: int,
                     eoi_token_id: int) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Generate an interleaved image+text continuation (batch size 1 only).

        Expected generated layout: ``<soi> <num_vq_tokens image tokens> <eoi>
        <text tokens...>``; positions are sliced by that fixed layout rather
        than by searching for soi/eoi ids.

        Returns:
            ``(image_tokens, text_tokens)`` where ``image_tokens`` is
            ``(1, num_vq_tokens)`` in codebook space and ``text_tokens`` is
            everything after the image span.
        """
        # Generate the full continuation greedily.
        output_ids = self.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            do_sample=False,
            top_p=None,
            use_cache=True,
        )
        # Drop the prompt; keep only newly generated tokens.
        output_ids = output_ids[:, input_ids.size(1):]
        batch_size = output_ids.size(0)
        assert batch_size == 1
        # Skip position 0 (assumed <soi>), take the image-token span.
        image_tokens = output_ids[:, 1:1 + self.num_vq_tokens]
        image_tokens = image_tokens - (self.llm_vocab_size + self.num_new_special_tokens)
        # If generation stopped early, right-pad the image span with zeros.
        pad_length = self.num_vq_tokens - image_tokens.shape[1]
        if pad_length > 0:
            padding = torch.zeros(
                (image_tokens.shape[0], pad_length),
                dtype=image_tokens.dtype,
                device=image_tokens.device,
            )
            image_tokens = torch.cat([image_tokens, padding], dim=1)
        image_tokens = torch.clamp(image_tokens, max=self.codebook_size - 1, min=0)
        # Text begins after <soi> + image tokens + <eoi>.
        text_tokens = output_ids[:, 2 + self.num_vq_tokens:]
        return image_tokens, text_tokens
# Register GeoUni with the HuggingFace auto classes so
# AutoConfig/AutoModelForCausalLM.from_pretrained resolve "geo-uni".
# (Fixes a stray trailing "|" artifact that broke the file's syntax.)
AutoConfig.register("geo-uni", GeoUniConfig)
AutoModelForCausalLM.register(GeoUniConfig, GeoUniForCausalLM)