ryota39 committed on
Commit
3afe408
·
verified ·
1 Parent(s): df99996

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +74 -0
README.md CHANGED
@@ -12,6 +12,80 @@ license: cc-by-nc-4.0
12
  Tora-7B-v0.2 = NTQAI/chatntq-ja-7b-v1.0 + (NousResearch/Hermes-2-Pro-Mistral-7B - mistralai/Mistral-7B-v0.1)
13
  ```
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  ## Benchmark (Japanese MT bench)
16
 
17
  |model|category|score|ver|
 
12
  Tora-7B-v0.2 = NTQAI/chatntq-ja-7b-v1.0 + (NousResearch/Hermes-2-Pro-Mistral-7B - mistralai/Mistral-7B-v0.1)
13
  ```
14
 
15
+ ## 実装
16
+
17
+ @jovyan様の実装を参考に下記のコードでモデルを作成しました。
18
+
19
+ ```python
20
+ import torch
21
+ from transformers import AutoModelForCausalLM
22
+
23
+
24
def build_chat_vector_model(
    base_model_name,
    inst_model_name,
    target_model_name,
    skip_layers,
    output_dir="./chat_model",
):
    """Add a chat vector to a target model and save the merged weights.

    The chat vector is the per-tensor difference ``inst - base``. It is added
    in place to every tensor of the target model that is not skipped, and the
    resulting model is written with ``save_pretrained``.

    All three checkpoints are loaded in bfloat16; the base and instruction
    models are placed on CPU and the target model on CUDA.

    Args:
        base_model_name: Checkpoint name or path of the base model.
        inst_model_name: Checkpoint name or path of the instruction-tuned
            model; its delta from the base model forms the chat vector.
        target_model_name: Checkpoint name or path of the model that receives
            the chat vector.
        skip_layers: State-dict keys to leave untouched. Keys containing the
            substring ``"layernorm"`` are always skipped in addition.
        output_dir: Directory passed to ``save_pretrained``. Defaults to
            ``"./chat_model"``.

    Raises:
        KeyError: If a non-skipped key of the target state dict is missing
            from the base or instruction state dict.
    """
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    inst_model = AutoModelForCausalLM.from_pretrained(
        inst_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    target_model = AutoModelForCausalLM.from_pretrained(
        target_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )

    # Build each state dict once; calling state_dict() inside the merge loop
    # would rebuild the whole mapping for every tensor.
    base_state = base_model.state_dict()
    inst_state = inst_model.state_dict()
    target_state = target_model.state_dict()

    # English base model
    for k, v in base_state.items():
        print(k, v.shape)

    # Japanese continued-pretraining model
    for k, v in target_state.items():
        print(k, v.shape)

    # Exclusion targets: honour the caller's list instead of overwriting it.
    excluded = frozenset(skip_layers or ())

    for k, v in target_state.items():
        # Layernorm tensors are excluded as well.
        # NOTE(review): the substring test only matches keys that contain
        # "layernorm"; confirm whether a final norm key (if the architecture
        # has one under a different name) should be skipped too.
        if k in excluded or "layernorm" in k:
            continue
        chat_vector = inst_state[k] - base_state[k]
        # The delta is computed on the source device and moved to the target
        # tensor's device before the in-place update.
        v.add_(chat_vector.to(v.device))

    target_model.save_pretrained(output_dir)
70
+
71
+
72
if __name__ == "__main__":
    # Recipe: target + (inst - base), with the embedding and output-head
    # weights left untouched.
    model_names = {
        "base_model_name": "mistralai/Mistral-7B-v0.1",
        "inst_model_name": "NousResearch/Hermes-2-Pro-Mistral-7B",
        "target_model_name": "NTQAI/chatntq-ja-7b-v1.0",
    }
    excluded_keys = ["model.embed_tokens.weight", "lm_head.weight"]

    build_chat_vector_model(skip_layers=excluded_keys, **model_names)
86
+
87
+ ```
88
+
89
  ## Benchmark (Japanese MT bench)
90
 
91
  |model|category|score|ver|