# GeoUni-Instruct / modeling_geouni.py
# (Hugging Face repo metadata: uploaded by JO-KU; commit b6ec373, "first", verified)
import torch
import torch.nn.functional as F
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM
from typing import List, Optional, Tuple, Union
from .modeling_qwen2 import Qwen2Model, Qwen2ForCausalLM
from .configuration_qwen2 import Qwen2Config
# Custom config for GeoUni
class GeoUniConfig(Qwen2Config):
    """Configuration for GeoUni: a Qwen2 config extended with the
    vocabulary-layout and VQ-tokenizer sizes used by the unified
    text/image model.
    """

    model_type = "geo-uni"

    def __init__(
        self,
        vocab_size=159864,
        num_vq_tokens=256,
        num_new_special_tokens=7,
        llm_vocab_size=151665,
        codebook_size=8192,
        **kwargs,
    ):
        # Let the Qwen2 base config consume all remaining keyword arguments
        # first; the extended vocab_size below then overrides its default.
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        # Number of VQ image tokens produced per image.
        self.num_vq_tokens = num_vq_tokens
        self.num_new_special_tokens = num_new_special_tokens
        # Size of the original text vocabulary.
        self.llm_vocab_size = llm_vocab_size
        # Size of the VQ-GAN codebook appended after text + special tokens.
        self.codebook_size = codebook_size
class GeoUniModel(Qwen2Model):
    """Backbone transformer for GeoUni; identical to ``Qwen2Model`` apart
    from binding the GeoUni config class."""

    config_class = GeoUniConfig

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
class GeoUniForCausalLM(Qwen2ForCausalLM):
    """GeoUni causal LM: a Qwen2 decoder whose vocabulary is extended with
    new special tokens and VQ-GAN codebook tokens so that one model can emit
    both text and image tokens.

    Token-id layout implied by the offsets used below (TODO confirm against
    the tokenizer):
        [0, llm_vocab_size)                                   text tokens
        [llm_vocab_size, llm_vocab_size+num_new_special_tokens) special tokens
        [llm_vocab_size+num_new_special_tokens, vocab_size)    VQ codebook tokens
    """

    config_class = GeoUniConfig

    def __init__(self, config):
        # Deliberately skip Qwen2ForCausalLM.__init__ (which would build a
        # plain Qwen2Model) and initialise from its parent, then install the
        # GeoUni backbone and an LM head over the extended vocabulary.
        super(Qwen2ForCausalLM, self).__init__(config)
        self.model = GeoUniModel(config)
        self.vocab_size = config.vocab_size
        self.num_vq_tokens = config.num_vq_tokens
        self.num_new_special_tokens = config.num_new_special_tokens
        self.llm_vocab_size = config.llm_vocab_size
        self.codebook_size = config.codebook_size
        # LM head spans the full (text + special + codebook) vocabulary.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the underlying ``GeoUniModel`` backbone."""
        return self.model

    def _next_token_loss(self, logits, labels, start, stop):
        """Shifted next-token cross-entropy over batch rows ``start:stop``.

        Returns a zero scalar on the logits' device when the slice is empty
        (where ``F.cross_entropy`` would otherwise yield NaN).
        """
        piece = logits[start:stop]
        if piece.numel() == 0:
            return logits.new_zeros(())
        return F.cross_entropy(
            piece[:, :-1].contiguous().view(-1, self.vocab_size),
            labels[start:stop, 1:].contiguous().view(-1),
            ignore_index=-100,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        batch_size_t2i=0,
        batch_size_reasoning=0,
        batch_size_mixing=0,
    ):
        """Run the base LM; when ``labels`` is given, split the batch into
        three consecutive task partitions (text-to-image | reasoning | mixing,
        of the given sizes) and return a separate next-token loss for each.

        Returns:
            With ``labels``: ``(logits, loss_t2i, loss_reasoning, loss_mixing)``.
            Otherwise: whatever the base ``forward`` returns.
        """
        # Pass arguments by keyword so this does not silently break if the
        # base signature ever gains or reorders parameters.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if labels is not None:
            # NOTE(review): assumes a dict-style output; with return_dict=False
            # the base returns a tuple and this attribute access would fail.
            logits = outputs.logits
            loss_t2i = self._next_token_loss(logits, labels, 0, batch_size_t2i)
            loss_reasoning = self._next_token_loss(
                logits, labels, batch_size_t2i, batch_size_t2i + batch_size_reasoning
            )
            # BUG FIX: the original used logits[-batch_size_mixing:], which
            # selects the WHOLE batch when batch_size_mixing == 0 (negative
            # zero slice). Guard explicitly so an absent partition gives an
            # empty slice (and hence a zero loss).
            n_rows = logits.size(0)
            mix_start = n_rows - batch_size_mixing if batch_size_mixing > 0 else n_rows
            loss_mixing = self._next_token_loss(logits, labels, mix_start, n_rows)
            return logits, loss_t2i, loss_reasoning, loss_mixing
        return outputs

    @torch.no_grad()
    def t2i_generate(
        self,
        input_ids: torch.LongTensor,
        pad_token_id=151665,
        temperature=1.0,
        attention_masks=None,
    ):
        """Greedily generate exactly ``num_vq_tokens`` new tokens and map them
        into VQ-GAN codebook indices.

        Args:
            input_ids: prompt token ids, shape (batch, seq).
            pad_token_id: pad id forwarded to ``generate`` (default is the
                first id after the text vocabulary).
            temperature: forwarded to ``generate`` (no effect while
                ``do_sample=False``).
            attention_masks: optional attention mask for the prompt.

        Returns:
            LongTensor of codebook indices in ``[0, codebook_size)``,
            shape (batch, num_vq_tokens).
        """
        # Generate num_vq_tokens new tokens; eos is disabled so generation
        # never stops before every VQ position is filled.
        generated_tokens = self.generate(
            input_ids=input_ids,
            max_new_tokens=self.num_vq_tokens,
            attention_mask=attention_masks,
            pad_token_id=pad_token_id,
            eos_token_id=None,
            temperature=temperature,
            do_sample=False,
            top_p=None,
            use_cache=True,
        )
        # Shift LM-vocabulary ids down into the VQ-GAN codebook index space.
        new_tokens = generated_tokens[:, -self.num_vq_tokens:] - (
            self.llm_vocab_size + self.num_new_special_tokens
        )
        # Clamp any stray text/special-token ids into the valid codebook range.
        gen_token_ids = torch.clamp(new_tokens, max=self.codebook_size - 1, min=0)
        return gen_token_ids

    @torch.no_grad()
    def mix_generate(
        self,
        input_ids,
        max_new_tokens: int,
        temperature: float,
        pad_token_id: int,
        eos_token_id: int,
        soi_token_id: int,
        eoi_token_id: int,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Generate a mixed image+text continuation and split it in two.

        The split is purely positional and assumes the continuation is laid
        out as ``<soi> <num_vq_tokens image tokens> <eoi> <text...>``;
        ``soi_token_id``/``eoi_token_id`` are accepted for interface
        compatibility but are not used to locate the image span.
        Only batch size 1 is supported.

        Returns:
            ``(image_tokens, text_tokens)`` — ``image_tokens`` is a
            (1, num_vq_tokens) tensor of codebook ids (zero-padded on the
            right if generation stopped early); ``text_tokens`` holds the ids
            following the image span.
        """
        # Generate the full continuation.
        output_ids = self.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            do_sample=False,
            top_p=None,
            use_cache=True,
        )
        # Drop the prompt part.
        output_ids = output_ids[:, input_ids.size(1):]
        batch_size = output_ids.size(0)
        assert batch_size == 1
        # Positions 1..num_vq_tokens hold the image span (position 0 is <soi>).
        image_tokens = output_ids[:, 1:1 + self.num_vq_tokens]
        image_tokens = image_tokens - (self.llm_vocab_size + self.num_new_special_tokens)
        pad_length = self.num_vq_tokens - image_tokens.shape[1]
        # If generation ended before the image completed, right-pad with zeros.
        if pad_length > 0:
            padding = torch.zeros(
                (image_tokens.shape[0], pad_length),
                dtype=image_tokens.dtype,
                device=image_tokens.device,
            )
            image_tokens = torch.cat([image_tokens, padding], dim=1)
        image_tokens = torch.clamp(image_tokens, max=self.codebook_size - 1, min=0)
        # Everything after <eoi> (position num_vq_tokens + 1) is text.
        text_tokens = output_ids[:, 2 + self.num_vq_tokens:]
        return image_tokens, text_tokens
# Make the custom classes discoverable through the transformers Auto* factory
# interfaces under the "geo-uni" model_type, so AutoModelForCausalLM.from_pretrained
# can resolve GeoUni checkpoints.
AutoConfig.register("geo-uni", GeoUniConfig)
AutoModelForCausalLM.register(GeoUniConfig, GeoUniForCausalLM)