Commit
·
78f6f3b
1
Parent(s):
16cc769
Renaming state dict keys from Phi2
Browse files- phi2_model.py +8 -8
- streaming_inference.py +23 -33
phi2_model.py
CHANGED
|
@@ -91,7 +91,7 @@ class Embedding(nn.Module):
|
|
| 91 |
class Phi2Model(Phi2PreTrainedModel):
|
| 92 |
def __init__(self, config: Phi2Config) -> None:
|
| 93 |
super().__init__(config)
|
| 94 |
-
self.
|
| 95 |
vocab_size=config.vocab_size,
|
| 96 |
d_embedding=config.d_embedding,
|
| 97 |
embd_pdrop=config.embd_pdrop,
|
|
@@ -117,10 +117,10 @@ class Phi2Model(Phi2PreTrainedModel):
|
|
| 117 |
|
| 118 |
"""
|
| 119 |
def get_input_embeddings(self) -> nn.Embedding:
|
| 120 |
-
return self.
|
| 121 |
|
| 122 |
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
|
| 123 |
-
self.
|
| 124 |
"""
|
| 125 |
|
| 126 |
def forward(
|
|
@@ -129,7 +129,7 @@ class Phi2Model(Phi2PreTrainedModel):
|
|
| 129 |
kv_cache: KVCache | None = None,
|
| 130 |
key_padding_mask: torch.BoolTensor | None = None,
|
| 131 |
) -> torch.FloatTensor:
|
| 132 |
-
x = self.
|
| 133 |
for block in self.parallel_blocks:
|
| 134 |
x = block(
|
| 135 |
x,
|
|
@@ -143,8 +143,8 @@ class Phi2ModelForCausalLM(Phi2PreTrainedModel):
|
|
| 143 |
def __init__(self, config: Phi2Config) -> None:
|
| 144 |
super().__init__(config)
|
| 145 |
self.pretrained_model = Phi2Model(config)
|
| 146 |
-
self.
|
| 147 |
-
self.
|
| 148 |
self.loss_fn = nn.CrossEntropyLoss()
|
| 149 |
self.post_init() # calls self._init_weights() for all modules
|
| 150 |
|
|
@@ -156,8 +156,8 @@ class Phi2ModelForCausalLM(Phi2PreTrainedModel):
|
|
| 156 |
labels: torch.LongTensor | None = None,
|
| 157 |
) -> CausalLMOutputWithPast:
|
| 158 |
x = self.pretrained_model(input_ids, kv_cache=kv_cache, key_padding_mask=key_padding_mask)
|
| 159 |
-
x = self.
|
| 160 |
-
logits = self.
|
| 161 |
loss = (
|
| 162 |
self.loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
|
| 163 |
if labels is not None
|
|
|
|
| 91 |
class Phi2Model(Phi2PreTrainedModel):
|
| 92 |
def __init__(self, config: Phi2Config) -> None:
|
| 93 |
super().__init__(config)
|
| 94 |
+
self.rotary_embedding = Embedding(
|
| 95 |
vocab_size=config.vocab_size,
|
| 96 |
d_embedding=config.d_embedding,
|
| 97 |
embd_pdrop=config.embd_pdrop,
|
|
|
|
| 117 |
|
| 118 |
"""
|
| 119 |
def get_input_embeddings(self) -> nn.Embedding:
|
| 120 |
+
return self.rotary_embedding.embeddings
|
| 121 |
|
| 122 |
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
|
| 123 |
+
self.rotary_embedding.embeddings = new_embeddings
|
| 124 |
"""
|
| 125 |
|
| 126 |
def forward(
|
|
|
|
| 129 |
kv_cache: KVCache | None = None,
|
| 130 |
key_padding_mask: torch.BoolTensor | None = None,
|
| 131 |
) -> torch.FloatTensor:
|
| 132 |
+
x = self.rotary_embedding(input_ids)
|
| 133 |
for block in self.parallel_blocks:
|
| 134 |
x = block(
|
| 135 |
x,
|
|
|
|
| 143 |
def __init__(self, config: Phi2Config) -> None:
|
| 144 |
super().__init__(config)
|
| 145 |
self.pretrained_model = Phi2Model(config)
|
| 146 |
+
self.lm_head_layer_norm = nn.LayerNorm(config.d_embedding, eps=config.layer_norm_epsilon)
|
| 147 |
+
self.lm_head_linear = nn.Linear(config.d_embedding, config.vocab_size)
|
| 148 |
self.loss_fn = nn.CrossEntropyLoss()
|
| 149 |
self.post_init() # calls self._init_weights() for all modules
|
| 150 |
|
|
|
|
| 156 |
labels: torch.LongTensor | None = None,
|
| 157 |
) -> CausalLMOutputWithPast:
|
| 158 |
x = self.pretrained_model(input_ids, kv_cache=kv_cache, key_padding_mask=key_padding_mask)
|
| 159 |
+
x = self.lm_head_layer_norm(x)
|
| 160 |
+
logits = self.lm_head_linear(x).to(torch.float32)
|
| 161 |
loss = (
|
| 162 |
self.loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
|
| 163 |
if labels is not None
|
streaming_inference.py
CHANGED
|
@@ -1,43 +1,11 @@
|
|
| 1 |
import json
|
| 2 |
from threading import Thread
|
| 3 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
| 4 |
-
import torch
|
| 5 |
|
| 6 |
from .phi2_configuration import Phi2Config
|
| 7 |
from .phi2_model import Phi2ModelForCausalLM
|
| 8 |
|
| 9 |
|
| 10 |
-
# This works, but is not streaming
|
| 11 |
-
"""
|
| 12 |
-
if __name__ == "__main__":
|
| 13 |
-
device = "cuda"
|
| 14 |
-
|
| 15 |
-
model_config = Phi2Config(**json.load(open("simplified_phi2/config.json")))
|
| 16 |
-
model = Phi2ModelForCausalLM(model_config).to(device)
|
| 17 |
-
phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
|
| 18 |
-
model.load_state_dict(phi_model.state_dict())
|
| 19 |
-
|
| 20 |
-
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
|
| 21 |
-
|
| 22 |
-
text = "Write an essay on sea monkeys: "
|
| 23 |
-
tokens = tokenizer(text, return_tensors="pt", return_attention_mask=False).to(device)
|
| 24 |
-
outputs = model.generate(**tokens, max_length=200)
|
| 25 |
-
text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
| 26 |
-
print(text)
|
| 27 |
-
"""
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
# This is streaming, but does not work because you can't set trust_remote_code=True
|
| 31 |
-
"""
|
| 32 |
-
if __name__ == "__main__":
|
| 33 |
-
client = InferenceClient(model="microsoft/phi-2")
|
| 34 |
-
text = "How do you make cheese?"
|
| 35 |
-
for token in client.text_generation(text, max_new_tokens=500, stream=True):
|
| 36 |
-
print(token, end="")
|
| 37 |
-
"""
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
# This is trying the TextIteratorStreamer class
|
| 41 |
if __name__ == "__main__":
|
| 42 |
# make and load tokenizer, use tokenizer to initialize token_streamer
|
| 43 |
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
|
|
@@ -48,7 +16,29 @@ if __name__ == "__main__":
|
|
| 48 |
model_config = Phi2Config(**json.load(open("simplified_phi2/config.json")))
|
| 49 |
model = Phi2ModelForCausalLM(model_config).to(device)
|
| 50 |
phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
thread = Thread(
|
| 53 |
target=model.generate,
|
| 54 |
kwargs=dict(
|
|
|
|
| 1 |
import json
|
| 2 |
from threading import Thread
|
| 3 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
|
|
|
| 4 |
|
| 5 |
from .phi2_configuration import Phi2Config
|
| 6 |
from .phi2_model import Phi2ModelForCausalLM
|
| 7 |
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
if __name__ == "__main__":
|
| 10 |
# make and load tokenizer, use tokenizer to initialize token_streamer
|
| 11 |
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
|
|
|
|
| 16 |
model_config = Phi2Config(**json.load(open("simplified_phi2/config.json")))
|
| 17 |
model = Phi2ModelForCausalLM(model_config).to(device)
|
| 18 |
phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
|
| 19 |
# Rename the HuggingFace Phi-2 checkpoint keys to this repo's module names,
# then load the renamed state dict into our simplified model.
#
# Key mapping (HF checkpoint -> this repo):
#   transformer.embd.wte.weight           -> pretrained_model.rotary_embedding.embeddings.weight
#   transformer.h.0.mlp.fc1.weight        -> pretrained_model.parallel_blocks.0.mlp.fc1.weight
#   transformer.h.0.ln.weight             -> pretrained_model.parallel_blocks.0.layer_norm.weight
#   transformer.h.0.mixer.Wqkv.weight     -> pretrained_model.parallel_blocks.0.multi_head_attention.Wqkv.weight
#   transformer.h.0.mixer.out_proj.weight -> pretrained_model.parallel_blocks.0.multi_head_attention.fc_out.weight
#   lm_head.ln.weight                     -> lm_head_layer_norm.weight
#   lm_head.linear.weight                 -> lm_head_linear.weight
phi_model_state_dict = phi_model.state_dict()
model_state_dict = {}
for key, value in phi_model_state_dict.items():
    # BUG FIX: str.replace returns a NEW string; the original code discarded
    # every replacement result, so no key was ever renamed.
    if key.startswith("transformer."):
        # "model." was wrong: the attribute on Phi2ModelForCausalLM is
        # `pretrained_model` (see `self.pretrained_model = Phi2Model(config)`).
        key = key.replace("transformer.", "pretrained_model.", 1)
        key = key.replace(".embd.wte.", ".rotary_embedding.embeddings.")
        # Keep the trailing dot so "h.0." becomes "parallel_blocks.0."
        # (the original ".parallel_blocks" would yield "parallel_blocks0.").
        key = key.replace(".h.", ".parallel_blocks.")
        key = key.replace(".ln.", ".layer_norm.")
        key = key.replace(".mixer.Wqkv.", ".multi_head_attention.Wqkv.")
        key = key.replace(".mixer.out_proj.", ".multi_head_attention.fc_out.")
    elif key.startswith("lm_head."):
        # lm_head.* keys have no "transformer." prefix, so they must be handled
        # outside that branch; the original patterns (".lm_head.ln.") had a
        # leading dot and could never match.
        key = key.replace("lm_head.ln.", "lm_head_layer_norm.", 1)
        key = key.replace("lm_head.linear.", "lm_head_linear.", 1)
    model_state_dict[key] = value
model.load_state_dict(model_state_dict)
| 41 |
+
|
| 42 |
thread = Thread(
|
| 43 |
target=model.generate,
|
| 44 |
kwargs=dict(
|