---
license: unknown
---

Checkpoints from running [facebookresearch/coconut](https://github.com/facebookresearch/coconut) on an A100 (40GB) GPU. Logs are available on [wandb](https://wandb.ai/weikaihuang-xidian-university/coconut).

Quickstart:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from coconut import Coconut


def main():
    """Load a Coconut checkpoint on top of GPT-2 and generate a continuation for one prompt."""
    load_model_path = "save_models/gsm-coconut/checkpoint_22"
    model_id = "openai-community/gpt2"
    device = torch.device("cuda")

    # Load the base model and tokenizer; the Coconut weights are applied further below.
    print(f"Loading from {load_model_path}.")
    model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    # Register the extra tokens the Coconut wrapper is constructed with.
    tokenizer.add_tokens("<|start-latent|>")
    tokenizer.add_tokens("<|end-latent|>")
    tokenizer.add_tokens("<|latent|>")
    latent_id = tokenizer.convert_tokens_to_ids("<|latent|>")
    start_id = tokenizer.convert_tokens_to_ids("<|start-latent|>")
    end_id = tokenizer.convert_tokens_to_ids("<|end-latent|>")

    # The checkpoint is passed to load_state_dict below, i.e. it is a plain
    # state dict, so there is no need to unpickle arbitrary Python objects.
    saved_weights = torch.load(load_model_path, map_location=device, weights_only=True)

    # Grow the embedding matrix so the three new token ids are valid rows.
    model.resize_token_embeddings(len(tokenizer))
    embeddings = model.get_input_embeddings()
    target_id = tokenizer.convert_tokens_to_ids("<<")

    # Initialize the new token embeddings from a known token ("<<");
    # it helps stabilize training. (These rows are presumably overwritten by
    # the checkpoint loaded below if it contains the embedding weights.)
    for token_id in [latent_id, start_id, end_id]:
        embeddings.weight.data[token_id] = embeddings.weight.data[target_id]
        # The input embeddings and LM head are tied in GPT-2, so this is
        # redundant here, but keeps the init correct for untied models.
        lm_head = model.lm_head
        lm_head.weight.data[token_id] = lm_head.weight.data[target_id]

    model = Coconut(model, latent_id, start_id, end_id, tokenizer.eos_token_id)
    # strict=False: report (rather than fail on) missing/unexpected keys.
    print(model.load_state_dict(saved_weights, strict=False))
    model = model.to(device)
    model.eval()  # inference only: disable dropout in every submodule

    prompt = "Sally received the following scores on her math quizzes: 50, 80, 80. Find her mean score."
    # NOTE(review): this prompt contains no <|start-latent|>/<|latent|>/<|end-latent|>
    # tokens, so presumably no continuous (latent) thoughts are used during
    # generation. Appending "<|start-latent|>" + "<|latent|>" * k + "<|end-latent|>"
    # with k matching the training setup may be intended — confirm.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # No gradients are needed for generation; avoid building an autograd graph.
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=20)

    for i, o in enumerate(output):
        print(f"Output {i}: {tokenizer.decode(o, skip_special_tokens=True)}")


if __name__ == "__main__":
    main()
```