tsrigo commited on
Commit
c88fc87
·
verified ·
1 Parent(s): d9fdad5

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +62 -1
README.md CHANGED
@@ -3,4 +3,65 @@ license: unknown
3
  ---
4
  Checkpoints of [facebookresearch/coconut](https://github.com/facebookresearch/coconut) running on an A100 40GB.
5
 
6
- Logs are available at [wandb](https://wandb.ai/weikaihuang-xidian-university/coconut).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  ---
4
  Checkpoints of [facebookresearch/coconut](https://github.com/facebookresearch/coconut) running on an A100 40GB.
5
 
6
+ Logs are available at [wandb](https://wandb.ai/weikaihuang-xidian-university/coconut).
7
+
8
+ Quickstart:
9
+
10
+ ```python
11
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
12
+ # All rights reserved.
13
+
14
+ import torch
15
+ import torch.optim as optim
16
+ from transformers import AutoModelForCausalLM, AutoTokenizer
17
+ from coconut import Coconut
18
+
19
+ def main():
20
+ load_model_path = "save_models/gsm-coconut/checkpoint_22"
21
+ model_id = "openai-community/gpt2"
22
+ # load the configuration file
23
+ print(f"Loading from {load_model_path}.")
24
+
25
+ model = AutoModelForCausalLM.from_pretrained(model_id)
26
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
27
+ tokenizer.pad_token = tokenizer.eos_token
28
+ tokenizer.add_tokens("<|start-latent|>")
29
+ tokenizer.add_tokens("<|end-latent|>")
30
+ tokenizer.add_tokens("<|latent|>")
31
+ latent_id = tokenizer.convert_tokens_to_ids("<|latent|>")
32
+ start_id = tokenizer.convert_tokens_to_ids("<|start-latent|>")
33
+ end_id = tokenizer.convert_tokens_to_ids("<|end-latent|>")
34
+
35
+ saved_weights = torch.load(
36
+ load_model_path, map_location=torch.device("cuda")
37
+ )
38
+
39
+ model.resize_token_embeddings(len(tokenizer))
40
+ embeddings = model.get_input_embeddings()
41
+ target_id = tokenizer.convert_tokens_to_ids("<<")
42
+ # initialize the new token embeddings with a known token
43
+ # it helps stabilize the training
44
+ for token_id in [latent_id, start_id, end_id]:
45
+ target_embedding = embeddings.weight.data[target_id]
46
+ embeddings.weight.data[token_id] = target_embedding
47
+ # The input embeddings and lm heads are tied in GPT2. So the code below is not necessary
48
+ lm_head = model.lm_head
49
+ lm_head.weight.data[token_id] = lm_head.weight.data[target_id]
50
+
51
+ model = Coconut(model, latent_id, start_id, end_id, tokenizer.eos_token_id)
52
+ print(model.load_state_dict(saved_weights, strict=False))
53
+ model = model.to("cuda")
54
+
55
+ prompt = "Sally received the following scores on her math quizzes: 50, 80, 80. Find her mean score."
56
+ prompt = tokenizer(prompt, return_tensors="pt").to("cuda")
57
+ output = model.generate(
58
+ **prompt,
59
+ max_new_tokens=20
60
+ )
61
+ for i, o in enumerate(output):
62
+ print(f"Output {i}: {tokenizer.decode(o, skip_special_tokens=True)}")
63
+
64
+ if __name__ == "__main__":
65
+ main()
66
+
67
+ ```