import torch
import torch.nn.functional as F
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM
from typing import List, Optional, Tuple
from .modeling_qwen2 import Qwen2Model, Qwen2ForCausalLM
from .configuration_qwen2 import Qwen2Config


# Custom config for GeoUni
class GeoUniConfig(Qwen2Config):
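    """Configuration for GeoUni: a Qwen2 config extended with a VQ image vocabulary.

    The total vocab_size decomposes as llm_vocab_size + num_new_special_tokens
    + codebook_size (151665 + 7 + 8192 = 159864 with the defaults below).
    """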
    model_type = "geo-uni"

    def __init__(
        self,
        vocab_size=159864,
        num_vq_tokens=256,
        num_new_special_tokens=7,
        llm_vocab_size=151665,
        codebook_size=8192,
        **kwargs,
    ):
        super().__init__(**kwargs)  # Call parent class constructor
        self.vocab_size = vocab_size  # Full vocabulary: llm_vocab_size + num_new_special_tokens + codebook_size
        self.num_vq_tokens = num_vq_tokens  # Number of VQ tokens that encode one image
        self.num_new_special_tokens = num_new_special_tokens  # Special tokens added on top of the base LLM vocabulary
        self.llm_vocab_size = llm_vocab_size  # Size of the base Qwen2 text vocabulary
        self.codebook_size = codebook_size  # Size of the VQ-GAN codebook


class GeoUniModel(Qwen2Model):
    config_class = GeoUniConfig

    def __init__(self, config: Qwen2Config):
        super().__init__(config)


class GeoUniForCausalLM(Qwen2ForCausalLM):
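    """Causal LM over the joint text + VQ image vocabulary.

    Supports three task types in one training batch (text-to-image, reasoning,
    and mixed-modality), each with its own cross-entropy loss in forward().
    """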
    config_class = GeoUniConfig

    def __init__(self, config):
        # Deliberately skip Qwen2ForCausalLM.__init__ so we can build our own
        # GeoUniModel and lm_head instead of the stock Qwen2 ones.
        super(Qwen2ForCausalLM, self).__init__(config)
        self.model = GeoUniModel(config)
        self.vocab_size = config.vocab_size
        self.num_vq_tokens = config.num_vq_tokens
        self.num_new_special_tokens = config.num_new_special_tokens
        self.llm_vocab_size = config.llm_vocab_size
        self.codebook_size = config.codebook_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)  # Head over the extended text + VQ vocabulary

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            batch_size_t2i: int = 0,  # First batch_size_t2i samples: text-to-image
            batch_size_reasoning: int = 0,  # Next batch_size_reasoning samples: reasoning
            batch_size_mixing: int = 0,  # Last batch_size_mixing samples: mixed-modality
    ):
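        """Run the standard Qwen2 forward pass; when labels are provided, also
        return separate cross-entropy losses for the t2i, reasoning, and mixing
        sub-batches, assumed concatenated in that order along the batch dim."""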
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if labels is not None:
            logits = outputs.logits
            # Per-task next-token cross-entropy over the three contiguous
            # sub-batches; logits are shifted left and labels right by one.
            loss_t2i = F.cross_entropy(
                logits[:batch_size_t2i, :-1].contiguous().view(-1, self.vocab_size),
                labels[:batch_size_t2i, 1:].contiguous().view(-1), ignore_index=-100,
            )
            loss_reasoning = F.cross_entropy(
                logits[batch_size_t2i:batch_size_t2i+batch_size_reasoning, :-1].contiguous().view(-1, self.vocab_size),
                labels[batch_size_t2i:batch_size_t2i+batch_size_reasoning, 1:].contiguous().view(-1), ignore_index=-100,
            )
            # Note: assumes batch_size_mixing > 0; with 0, [-0:] would select the whole batch
            loss_mixing = F.cross_entropy(
                logits[-batch_size_mixing:, :-1].contiguous().view(-1, self.vocab_size),
                labels[-batch_size_mixing:, 1:].contiguous().view(-1), ignore_index=-100,
            )

            return logits, loss_t2i, loss_reasoning, loss_mixing
        
        return outputs

    @torch.no_grad()
    def t2i_generate(
            self,
            input_ids: torch.LongTensor,
            pad_token_id=151665,  # Matches the default llm_vocab_size
            temperature=1.0,
            attention_masks=None,
    ):
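        """Generate image tokens for a text prompt and return them as VQ-GAN codebook ids."""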
        # Generate exactly num_vq_tokens new tokens; greedy decoding with EOS
        # disabled so the image block is never cut short
        generated_tokens = self.generate(
            input_ids=input_ids,
            max_new_tokens=self.num_vq_tokens,
            attention_mask=attention_masks,
            pad_token_id=pad_token_id,
            eos_token_id=None,
            temperature=temperature,
            do_sample=False,
            top_p=None,
            use_cache=True,
        )

        # Shift the generated ids back into VQ-GAN codebook space, then clamp to the valid range
        new_tokens = generated_tokens[:, -self.num_vq_tokens:] - (self.llm_vocab_size + self.num_new_special_tokens)
        gen_token_ids = torch.clamp(new_tokens, max=self.codebook_size - 1, min=0)

        return gen_token_ids
    
    @torch.no_grad()
    def mix_generate(self,
                    input_ids: torch.LongTensor,
                    max_new_tokens: int,
                    temperature: float,
                    pad_token_id: int,
                    eos_token_id: int,
                    soi_token_id: int,
                    eoi_token_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
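        """Generate a mixed continuation and split it into image and text parts.

        Returns a (image_tokens, text_tokens) pair: VQ-GAN codebook ids for the
        image block and LLM vocabulary ids for the trailing text.
        """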
        
        # Generate the complete output sequence (image block followed by text)
        output_ids = self.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            do_sample=False,
            top_p=None,
            use_cache=True
        )
        output_ids = output_ids[:, input_ids.size(1):]  # Drop the prompt; keep only newly generated tokens

        batch_size = output_ids.size(0)
        assert batch_size == 1  # Only single-sample decoding is supported
        # The first generated token is expected to be the <soi> marker; the image tokens follow it
        image_tokens = output_ids[:, 1:1+self.num_vq_tokens]
        # Shift from LLM vocabulary ids into VQ-GAN codebook ids
        image_tokens = image_tokens - (self.llm_vocab_size + self.num_new_special_tokens)
        pad_length = self.num_vq_tokens - image_tokens.shape[1]
        # If generation stopped early, right-pad the image block with zeros (codebook entry 0)
        if pad_length > 0:
            padding = torch.zeros((image_tokens.shape[0], pad_length), dtype=image_tokens.dtype, device=image_tokens.device)
            image_tokens = torch.cat([image_tokens, padding], dim=1)
        image_tokens = torch.clamp(image_tokens, max=self.codebook_size - 1, min=0)
        text_tokens = output_ids[:, 2+self.num_vq_tokens:]  # Text resumes after the image block and its closing <eoi> marker
        return image_tokens, text_tokens
    
AutoConfig.register("geo-uni", GeoUniConfig)
AutoModelForCausalLM.register(GeoUniConfig, GeoUniForCausalLM)
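

# Minimal usage sketch (not part of the original module): exercises the API on
# a deliberately tiny, randomly initialized config. A real run would instead
# load trained weights, e.g. AutoModelForCausalLM.from_pretrained(<checkpoint>).
# Because of the relative imports above, run this as a module
# (python -m <package>.<this_module>); the package path is a placeholder.
if __name__ == "__main__":
    config = GeoUniConfig(
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    model = GeoUniForCausalLM(config).eval()
    prompt_ids = torch.randint(0, config.llm_vocab_size, (1, 16))  # stand-in prompt
    vq_ids = model.t2i_generate(input_ids=prompt_ids)
    # t2i_generate returns exactly num_vq_tokens ids, clamped to the codebook range
    assert vq_ids.shape == (1, config.num_vq_tokens)
    assert 0 <= int(vq_ids.min()) and int(vq_ids.max()) < config.codebook_size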