import os
from typing import List, Optional

import torch
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    OPTForCausalLM,
    PretrainedConfig,
    PreTrainedModel,
)


def kmp_preprocess(pattern):
    """Build the KMP failure table: prefix_suffix[i] is the length of the
    longest proper prefix of pattern[:i + 1] that is also its suffix."""
    pattern_len = len(pattern)
    prefix_suffix = [0] * pattern_len
    j = 0

    for i in range(1, pattern_len):
        # Fall back through shorter borders until the next character matches.
        while j > 0 and pattern[i] != pattern[j]:
            j = prefix_suffix[j - 1]

        if pattern[i] == pattern[j]:
            j += 1

        prefix_suffix[i] = j

    return prefix_suffix


def kmp_search(text, pattern):
    """Return the start index of every (possibly overlapping) occurrence of
    pattern in text, using Knuth-Morris-Pratt matching."""
    text_len = len(text)
    pattern_len = len(pattern)
    prefix_suffix = kmp_preprocess(pattern)
    matches = []

    j = 0
    for i in range(text_len):
        while j > 0 and text[i] != pattern[j]:
            j = prefix_suffix[j - 1]

        if text[i] == pattern[j]:
            j += 1

        if j == pattern_len:
            matches.append(i - j + 1)
            # Restart from the longest border so overlapping matches are found.
            j = prefix_suffix[j - 1]

    return matches


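# Usage note (illustrative, not from the original file): matches may overlap,
# e.g. kmp_search("abababa", "aba") returns [0, 2, 4].

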
class ModelWrapper:
    """Inference-only wrapper: forwards attribute access to the wrapped model,
    runs calls under torch.no_grad(), and ignores train()/eval() switches."""

    def __init__(self, model):
        self.model = model

    def __getattr__(self, name):
        # Only invoked for attributes not found on the wrapper itself.
        return getattr(self.model, name)

    @torch.no_grad()
    def __call__(self, pixel_values):
        return self.model(pixel_values)

    def eval(self):
        # Intentional no-op: the wrapped model keeps its current mode.
        pass

    def train(self, mode=True):
        # Accept nn.Module's train(mode) signature but do nothing.
        pass

    def parameters(self):
        return self.model.parameters()


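# Usage sketch (hypothetical `vision_encoder`): wrap a frozen encoder so that
# downstream .train()/.eval() calls cannot flip its mode:
#   frozen_encoder = ModelWrapper(vision_encoder)
#   features = frozen_encoder(pixel_values)  # runs under torch.no_grad()

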
class CrelloModelConfig(PretrainedConfig):
    def __init__(
        self,
        old_vocab_size: int = 32000,
        vocab_size: int = 32000,
        pad_token_id: int = 2,
        ignore_ids: Optional[List[int]] = None,

        freeze_lm: bool = True,
        opt_version: str = 'facebook/opt-6.7b',

        task: str = 'captioning',

        use_lora: bool = False,
        lora_alpha: int = 32,
        lora_r: int = 8,
        lora_dropout: float = 0.05,
        lora_target_modules: str = r'.*\.(q_proj|v_proj)',

        hidden_size: int = -1,
        load_in_4bit: Optional[bool] = False,

        **kwargs,
    ):
        super().__init__(**kwargs)
        assert old_vocab_size > 0, 'old_vocab_size must be positive'
        assert vocab_size > 0, 'vocab_size must be positive'

        self.old_vocab_size = old_vocab_size
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.freeze_lm = freeze_lm
        self.opt_version = opt_version
        self.task = task
        self.use_lora = use_lora
        self.lora_alpha = lora_alpha
        self.lora_r = lora_r
        self.lora_dropout = lora_dropout
        self.lora_target_modules = lora_target_modules
        self.hidden_size = hidden_size
        self.load_in_4bit = load_in_4bit
        # Default to None, then materialize a fresh list, to avoid the shared
        # mutable-default pitfall of `ignore_ids: List[int] = []`.
        self.ignore_ids = ignore_ids if ignore_ids is not None else []


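# Illustrative construction (hedged; parameter values are examples only):
#   cfg = CrelloModelConfig(vocab_size=50304, freeze_lm=True, use_lora=True)
#   cfg.lora_target_modules  # regex over module names, e.g. q_proj / v_proj

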
class CrelloModel(PreTrainedModel):
    config_class = CrelloModelConfig
    supports_gradient_checkpointing = True

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        self.lm.gradient_checkpointing_enable()

    def __init__(self, config: CrelloModelConfig):
        super().__init__(config)

        self.pad_token_id = config.pad_token_id
        self.args = config

        opt_version = config.opt_version
        print(f"Using {opt_version} for the language model.")

        if 'facebook/opt' in opt_version:
            self.lm = OPTForCausalLM.from_pretrained(opt_version)
            word_embed_proj_dim = self.lm.config.word_embed_proj_dim
        else:
            if config.load_in_4bit:
                print("Loading the LM in 4-bit.")
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=config.load_in_4bit
                )
                # Pin each process to its own GPU for distributed training.
                local_rank = int(os.environ.get("LOCAL_RANK", 0))
                device_map = {"": local_rank}
            else:
                print("Loading the LM without 4-bit quantization.")
                quantization_config = None
                device_map = None

            self.lm = AutoModelForCausalLM.from_pretrained(
                "WYBar/LLM_For_Layout_Planning",
                subfolder="Meta-Llama-3-8B",
                quantization_config=quantization_config,
                device_map=device_map,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
            )

            word_embed_proj_dim = self.lm.config.hidden_size
        self.config.hidden_size = self.lm.config.hidden_size
        self.opt_version = opt_version

        if self.args.freeze_lm:
            self.lm.eval()
            print("Freezing the LM.")
            for param in self.lm.parameters():
                param.requires_grad = False
        else:
            print("Not freezing the LM; it will be trained.")
            self.lm.train()
            self.lm.config.gradient_checkpointing = True

        print('Resizing token embeddings to match the tokenizer:', config.vocab_size)
        self.lm.resize_token_embeddings(config.vocab_size)
        self.input_embeddings = self.lm.get_input_embeddings()

    def train(self, mode=True):
        super().train(mode=mode)
        # Keep a frozen LM in eval mode even when the rest of the model trains.
        if self.args.freeze_lm:
            self.lm.eval()

    def forward(
        self,
        labels: torch.LongTensor,
    ):
        batch_size = labels.shape[0]
        full_labels = labels.detach().clone()

        input_embs = self.input_embeddings(labels)
        # Mean L2 norm of the input embeddings, returned for monitoring.
        input_embs_norm = ((input_embs ** 2).sum(dim=-1) ** 0.5).mean()

        for ignore_id in self.config.ignore_ids:
            full_labels[full_labels == ignore_id] = -100

        # Mask everything from the first pad token onwards and record, per
        # example, the position where padding starts.
        pad_idx = []
        for label in full_labels:
            for k, token in enumerate(label):
                if token == self.pad_token_id:
                    label[k:] = -100
                    pad_idx.append(k)
                    break
                if k == len(label) - 1:  # no padding in this example
                    pad_idx.append(k + 1)
        assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)

        output = self.lm(
            inputs_embeds=input_embs,
            labels=full_labels,
            output_hidden_states=True,
        )

        return output, full_labels, input_embs_norm

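# Usage sketch (hedged; `labels` is a LongTensor of token ids, [batch, seq]):
#   output, full_labels, embs_norm = model(labels=labels)
#   loss = output.loss  # causal-LM loss with pad/ignore positions set to -100
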
if __name__ == "__main__":
    config = CrelloModelConfig(
        vocab_size=50265,
        image_reg_token=50264,
        image_gt_token=50263,
    )
    print("config:", config)
    model1 = CrelloModel(config)
    print("\nmodel1:", model1)
    model1.save_pretrained('test')
    model2 = CrelloModel.from_pretrained('test')
    print("\nmodel2:", model2)

    # Round-trip check: saving and reloading must preserve every parameter.
    state_dict1 = model1.state_dict()
    state_dict2 = model2.state_dict()
    assert set(state_dict1.keys()) == set(state_dict2.keys())
    for k in state_dict1.keys():
        assert torch.equal(state_dict1[k], state_dict2[k])
    print('all parameters are equal')
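
    # Hedged sanity check of the KMP helpers defined above (not part of the
    # original self-test): overlapping matches are all reported.
    assert kmp_search("abababa", "aba") == [0, 2, 4]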