import PIL.Image
import torch
from attrdict import AttrDict
from dataclasses import dataclass
from einops import rearrange
from torch.nn import CrossEntropyLoss
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    LlamaConfig,
    LlamaForCausalLM,
    PreTrainedModel,
)
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from janus.models.clip_encoder import CLIPVisionTower
from janus.models.projector import MlpProjector


class vision_head(torch.nn.Module):
    """MLP head that maps LM hidden states to logits over the image-token vocabulary."""

    def __init__(self, params):
        super().__init__()
        self.output_mlp_projector = torch.nn.Linear(
            params.n_embed, params.image_token_embed
        )
        self.vision_activation = torch.nn.GELU()
        self.vision_head = torch.nn.Linear(
            params.image_token_embed, params.image_token_size
        )

    def forward(self, x):
        x = self.output_mlp_projector(x)
        x = self.vision_activation(x)
        x = self.vision_head(x)
        return x


def model_name_to_cls(cls_name):
    """Resolve a config `cls` string to the corresponding module class."""
    if "MlpProjector" in cls_name:
        cls = MlpProjector
    elif "CLIPVisionTower" in cls_name:
        cls = CLIPVisionTower
    elif "VQ" in cls_name:
        # Imported lazily so understanding-only use does not pull in the VQ model.
        from janus.models.vq_model import VQ_models

        cls = VQ_models[cls_name]
    elif "vision_head" in cls_name:
        cls = vision_head
    else:
        raise ValueError(f"class_name {cls_name} is invalid.")
    return cls


class VisionConfig(PretrainedConfig):
    model_type = "vision"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = AttrDict(kwargs.get("params", {}))


class AlignerConfig(PretrainedConfig):
    model_type = "aligner"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = AttrDict(kwargs.get("params", {}))


class GenVisionConfig(PretrainedConfig):
    model_type = "gen_vision"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = AttrDict(kwargs.get("params", {}))


class GenAlignerConfig(PretrainedConfig):
    model_type = "gen_aligner"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = AttrDict(kwargs.get("params", {}))


class GenHeadConfig(PretrainedConfig):
    model_type = "gen_head"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = AttrDict(kwargs.get("params", {}))


@dataclass
class VLChatProcessorOutput:
    sft_format: str
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    num_image_tokens: torch.IntTensor

    def __len__(self):
        return len(self.input_ids)


class MultiModalityConfig(PretrainedConfig):
    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig

    gen_vision_config: GenVisionConfig
    gen_aligner_config: GenAlignerConfig
    gen_head_config: GenHeadConfig

    language_config: LlamaConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        vision_config = kwargs.get("vision_config", {})
        self.vision_config = VisionConfig(**vision_config)

        aligner_config = kwargs.get("aligner_config", {})
        self.aligner_config = AlignerConfig(**aligner_config)

        gen_vision_config = kwargs.get("gen_vision_config", {})
        self.gen_vision_config = GenVisionConfig(**gen_vision_config)

        gen_aligner_config = kwargs.get("gen_aligner_config", {})
        self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)

        gen_head_config = kwargs.get("gen_head_config", {})
        self.gen_head_config = GenHeadConfig(**gen_head_config)

        language_config = kwargs.get("language_config", {})
        if isinstance(language_config, LlamaConfig):
            self.language_config = language_config
        else:
            self.language_config = LlamaConfig(**language_config)


class MultiModalityPreTrainedModel(PreTrainedModel):
    config_class = MultiModalityConfig
    base_model_prefix = "multi_modality"
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"


class MultiModalityCausalLM(MultiModalityPreTrainedModel):
    def __init__(self, config: MultiModalityConfig):
        super().__init__(config)

        vision_config = config.vision_config
        vision_cls = model_name_to_cls(vision_config.cls)
        self.vision_model = vision_cls(**vision_config.params)

        aligner_config = config.aligner_config
        aligner_cls = model_name_to_cls(aligner_config.cls)
        self.aligner = aligner_cls(aligner_config.params)

        gen_vision_config = config.gen_vision_config
        gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
        self.gen_vision_model = gen_vision_cls()

        gen_aligner_config = config.gen_aligner_config
        gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
        self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)

        gen_head_config = config.gen_head_config
        gen_head_cls = model_name_to_cls(gen_head_config.cls)
        self.gen_head = gen_head_cls(gen_head_config.params)

        self.gen_embed = torch.nn.Embedding(
            gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed
        )

        language_config = config.language_config
        self.language_model = LlamaForCausalLM(language_config)

    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        images_seq_mask: torch.LongTensor = None,
        images_emb_mask: torch.LongTensor = None,
        **kwargs,
    ):
        """

        Args:
            input_ids (torch.LongTensor): [b, T]
            pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
            images_seq_mask (torch.BoolTensor): [b, T]
            images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]

            assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)

        Returns:
            input_embeds (torch.Tensor): [b, T, D]
        """

        bs, n = pixel_values.shape[0:2]
        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")

        # [b * n, T2, D]
        images_embeds = self.aligner(self.vision_model(images))

        # [b * n, T2, D] -> [b, n * T2, D]
        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
        # [b, n, T2] -> [b, n * T2]
        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")

        # Clamp negative placeholder ids so the embedding lookup stays in range;
        # the affected positions are overwritten with image embeddings below.
        input_ids[input_ids < 0] = 0
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        # Scatter the aligned image embeddings into the text-embedding sequence.
        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]

        return inputs_embeds

    def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
        return self.gen_aligner(self.gen_embed(image_ids))

    def forward(
        self,
        vl_chat_processor,
        input_ids,
        labels=None,
        task="understanding",
        return_dict=True,
        pixel_values=None,
        images_seq_mask=None,
        images_emb_mask=None,
        **kwargs,
    ):
        if task == "understanding":
            inputs_embeds = self.prepare_inputs_embeds(
                input_ids, pixel_values, images_seq_mask, images_emb_mask
            )
            return self.language_model.forward(
                inputs_embeds=inputs_embeds,
                labels=labels,
                **kwargs,
            )

        elif task == "generation":
            cfg_weight = 5

            # Duplicate the batch for classifier-free guidance: the first copy
            # keeps the conditional prompt, the second is masked (everything
            # except the first and last positions) to token id 100015 as the
            # unconditional branch. The cond/uncond slicing below
            # (logits[0::2] / logits[1::2]) assumes batch size 1, where this
            # block layout and an interleaved layout coincide.
            tokens = torch.zeros(
                (2 * input_ids.size(0), input_ids.size(1)), dtype=torch.int
            ).cuda()
            for i in range(2):
                tokens[i * input_ids.size(0):(i + 1) * input_ids.size(0), :] = input_ids
                if i % 2 != 0:
                    tokens[i * input_ids.size(0):(i + 1) * input_ids.size(0), 1:-1] = 100015

            inputs_embeds = self.language_model.get_input_embeddings()(tokens)

            # Run the base LlamaModel (not the LM head); labels are used only
            # for the image-token loss computed below.
            outputs = self.language_model.model(
                inputs_embeds=inputs_embeds, use_cache=True, past_key_values=None
            )

            hidden_states = outputs.last_hidden_state
            logits = self.gen_head(hidden_states)

            logits_cond = logits[0::2, :]
            logits_uncond = logits[1::2, :]

            # Classifier-free guidance over image-token logits.
            all_logits = logits_uncond + cfg_weight * (logits_cond - logits_uncond)

            loss_fct = CrossEntropyLoss()
            shift_logits = all_logits[..., :-1, :].contiguous()
            shift_logits = shift_logits.view(
                -1, self.config.gen_head_config.params.image_token_size
            )

            if labels is not None:
                shift_labels = labels[..., 1:].contiguous()
                shift_labels = shift_labels.view(-1)
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)
            else:
                loss = None
            if not return_dict:
                output = (logits,) + outputs[1:]
                return ((loss,) + output) if loss is not None else output

            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
elif task == "generation_direct": |
|
|
outputs = self.language_model.model(input_ids=input_ids, **kwargs) |
|
|
hidden_states = outputs[0] |
|
|
logits = self.gen_head(hidden_states) |
|
|
|
|
|
loss = None |
|
|
|
|
|
logits = logits.float() |
|
|
|
|
|
shift_logits = logits[..., :-1, :].contiguous() |
|
|
shift_logits = shift_logits.view(-1, self.config.gen_head_config.params.image_token_size) |
|
|
|
|
|
if labels is not None: |
|
|
shift_labels = labels[..., 1:].contiguous() |
|
|
|
|
|
loss_fct = CrossEntropyLoss() |
|
|
shift_labels = shift_labels.view(-1) |
|
|
|
|
|
shift_labels = shift_labels.to(shift_logits.device) |
|
|
loss = loss_fct(shift_logits, shift_labels) |
|
|
else: |
|
|
loss = None |
|
|
|
|
|
if not return_dict: |
|
|
output = (logits,) + outputs[1:] |
|
|
return ((loss,) + output) if loss is not None else output |
|
|
|
|
|
return CausalLMOutputWithPast( |
|
|
loss=loss, |
|
|
logits=logits, |
|
|
past_key_values=outputs.past_key_values, |
|
|
hidden_states=outputs.hidden_states, |
|
|
attentions=outputs.attentions, |
|
|
) |
|
|
elif task == "image_editing": |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
image_token_num_per_image = 576 |
|
|
img_size = 384 |
|
|
patch_size = 16 |
|
|
cfg_weight = 5 |
|
|
temperature = 1 |
|
|
|
|
|
tokens = torch.zeros((3 * input_ids.size(0), input_ids.size(1)), dtype=torch.int).cuda() |
|
|
pre_data = [] |
|
|
img_len = len(kwargs['source_image']) |
|
|
|
|
|
print(kwargs['source_image']) |
|
|
print(len(kwargs['source_image'][0])) |
|
|
import PIL.Image |
|
|
images = [PIL.Image.open(image_path).convert("RGB") for image_path in kwargs['source_image']] |
|
|
|
|
|
print('len_images : ',len(images)) |
|
|
encoder_pixel_values = vl_chat_processor.image_processor(images, return_tensors="pt")['pixel_values'] |
|
|
print(encoder_pixel_values.shape) |
|
|
print(encoder_pixel_values[0].shape) |
|
|
print(encoder_pixel_values[0][0][0][:2]) |
|
|
|
|
|
|
|
|
|
|
|
for i in range(3 * input_ids.size(0)): |
|
|
print(input_ids.shape) |
|
|
print(input_ids.size(0)) |
|
|
tokens[i * input_ids.size(0):(i + 1) * input_ids.size(0),:] = input_ids[i // 3,:] |
|
|
if i % 3 == 2: |
|
|
tokens[i * input_ids.size(0):(i + 1) * input_ids.size(0), 1:-1] = 100002 |
|
|
print(encoder_pixel_values[i//3,:].shape) |
|
|
print(len(kwargs['sft_format'][i//3])) |
|
|
print(tokens[i].shape) |
|
|
pre_data.append(VLChatProcessorOutput(sft_format=kwargs['sft_format'][i//3], pixel_values=encoder_pixel_values[i//3,:], |
|
|
input_ids=tokens[i - 2], |
|
|
num_image_tokens=[vl_chat_processor.num_image_tokens] * 1)) |
|
|
pre_data.append(VLChatProcessorOutput(sft_format=kwargs['sft_format'][i//3], pixel_values=encoder_pixel_values[i//3,:], |
|
|
input_ids=tokens[i - 1], |
|
|
num_image_tokens=[vl_chat_processor.num_image_tokens] * 1)) |
|
|
pre_data.append(VLChatProcessorOutput(sft_format=kwargs['sft_format'][i//3], pixel_values=None, input_ids=tokens[i], |
|
|
num_image_tokens=[])) |
|
|
|
|
|
|
|
|
ppp = (tokens == 100580).nonzero() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prepare_inputs = vl_chat_processor.batchify(pre_data) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inputs_embeds = self.prepare_inputs_embeds( |
|
|
input_ids=tokens.cuda(), |
|
|
pixel_values=prepare_inputs['pixel_values'].to(torch.bfloat16).cuda(), |
|
|
images_emb_mask=prepare_inputs['images_emb_mask'].cuda(), |
|
|
images_seq_mask=prepare_inputs['images_seq_mask'].cuda() |
|
|
) |
|
|
|
|
|
|
|
|
input_image_pixel_values = vl_chat_processor.image_processor(images, return_tensors="pt")['pixel_values'].to(torch.bfloat16).cuda() |
|
|
quant_input, emb_loss_input, info_input = self.gen_vision_model.encode(input_image_pixel_values) |
|
|
image_tokens_input = info_input[2].detach().reshape(input_image_pixel_values.shape[0], -1) |
|
|
image_embeds_input = self.prepare_gen_img_embeds(image_tokens_input) |
|
|
|
|
|
|
|
|
for ii, ind in enumerate(ppp): |
|
|
|
|
|
if ii % 4 == 0: |
|
|
offset = ind[1] + 2 |
|
|
inputs_embeds[ind[0], offset: offset + image_embeds_input.shape[1], :] = image_embeds_input[ii // 4] |
|
|
|
|
|
generated_tokens = torch.zeros((3 * input_ids.size(0), image_token_num_per_image), dtype=torch.int).cuda() |
|
|
|
|
|
outputs = self.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=None, |
|
|
labels=labels) |
|
|
|
|
|
hidden_states = outputs.last_hidden_state |
|
|
print('HHHH',hidden_states.shape) |
|
|
|
|
|
|
|
|
logits = self.gen_head(hidden_states) |
|
|
print('logits.shape', logits.shape) |
|
|
print(labels.shape) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logit_cond_full = logits[0::3, :] |
|
|
logit_cond_part = logits[1::3, :] |
|
|
logit_uncond = logits[2::3, :] |
|
|
|
|
|
cfg_weight2 = 5 |
|
|
logit_cond = (logit_cond_full + cfg_weight2 * (logit_cond_part)) / (1 + cfg_weight2) |
|
|
all_logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond) |
|
|
|
|
|
|
|
|
|
|
|
loss_fct = CrossEntropyLoss() |
|
|
shift_logits = all_logits[..., :-1, :].contiguous() |
|
|
shift_logits = shift_logits.view(-1, self.config.gen_head_config.params.image_token_size) |
|
|
|
|
|
if labels is not None: |
|
|
shift_labels = labels[..., 1:].contiguous() |
|
|
shift_labels = shift_labels.view(-1) |
|
|
shift_labels = shift_labels.to(shift_logits.device) |
|
|
print(shift_logits.shape, shift_labels.shape) |
|
|
loss = loss_fct(shift_logits, shift_labels) |
|
|
else: |
|
|
loss = None |
|
|
if not return_dict: |
|
|
output = (logits,) + outputs[1:] |
|
|
return ((loss,) + output) if loss is not None else output |
|
|
|
|
|
return CausalLMOutputWithPast( |
|
|
loss=loss, |
|
|
logits=logits, |
|
|
past_key_values=outputs.past_key_values, |
|
|
hidden_states=outputs.hidden_states, |
|
|
attentions=outputs.attentions, |
|
|
) |
|
|
|
|
|
AutoConfig.register("vision", VisionConfig) |
|
|
AutoConfig.register("aligner", AlignerConfig) |
|
|
AutoConfig.register("gen_vision", GenVisionConfig) |
|
|
AutoConfig.register("gen_aligner", GenAlignerConfig) |
|
|
AutoConfig.register("gen_head", GenHeadConfig) |
|
|
AutoConfig.register("multi_modality", MultiModalityConfig) |
|
|
AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM) |
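

if __name__ == "__main__":
    # Minimal usage sketch (an illustration added here, not part of the
    # original module). It follows the standard Janus loading recipe; the
    # checkpoint id, the example image path, and the conversation schema are
    # assumptions to verify against your installation.
    from janus.models import VLChatProcessor
    from janus.utils.io import load_pil_images

    model_path = "deepseek-ai/Janus-1.3B"  # assumed checkpoint id
    vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
    # The registrations above let AutoModelForCausalLM resolve this class from
    # the checkpoint's "multi_modality" model_type.
    model = AutoModelForCausalLM.from_pretrained(model_path)
    model = model.to(torch.bfloat16).cuda().eval()

    conversation = [
        {
            "role": "User",
            "content": "<image_placeholder>\nDescribe this image.",
            "images": ["./example.jpg"],  # hypothetical local image path
        },
        {"role": "Assistant", "content": ""},
    ]
    pil_images = load_pil_images(conversation)
    prepare_inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(model.device)

    # Exercises the "understanding" branch of forward() defined above.
    outputs = model(
        vl_chat_processor,
        input_ids=prepare_inputs.input_ids,
        pixel_values=prepare_inputs.pixel_values.to(torch.bfloat16),
        images_seq_mask=prepare_inputs.images_seq_mask,
        images_emb_mask=prepare_inputs.images_emb_mask,
        task="understanding",
    )
    print(outputs.logits.shape)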