BoruiXu commited on
Commit
55a580f
·
verified ·
1 Parent(s): b1e3290
Files changed (5) hide show
  1. model_utils.py +144 -0
  2. modeling_phi3.py +1622 -0
  3. phi-3-chat-w4-g128_awq.pt +3 -0
  4. run_awq.py +195 -0
  5. save_weights.py +102 -0
model_utils.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright © 2023 Advanced Micro Devices, Inc. All rights reserved.
3
+ #
4
+
5
+ import torch
6
+ import logging
7
+ import time
8
+ import random
9
+ import numpy as np
10
+
11
+ prompts = [ "What is the meaning of life?",
12
+ "Tell me something you don't know.",
13
+ "What does Xilinx do?",
14
+ "What is the mass of earth?",
15
+ "What is a poem?",
16
+ "What is recursion?",
17
+ "Tell me a one line joke.",
18
+ "Who is Gilgamesh?",
19
+ "Tell me something about cryptocurrency.",
20
+ "How did it all begin?"
21
+ ]
22
+
23
def warmup(model, tokenizer, max_new_tokens=30):
    """Prime the model with one short generation on the first canned prompt.

    Runs tokenize -> generate -> decode once so that any lazy initialization
    happens before timed runs.
    """
    print(f"Warming up ... ")
    for prompt in prompts[:1]:
        encoded = tokenizer(prompt, return_tensors="pt")
        output_ids = model.generate(
            encoded.input_ids,
            attention_mask=encoded.attention_mask,
            max_new_tokens=max_new_tokens,
        )
        _ = tokenizer.batch_decode(
            output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
    print(f"Warm up DONE!! ")
30
+
31
+
32
def decode_prompt(model, tokenizer, prompt, input_ids=None, max_new_tokens=30):
    """Generate a completion for `prompt` (or pre-tokenized `input_ids`) and
    log [PROFILE] timing lines for tokenization, generation and decoding.

    Exactly one of `prompt` / `input_ids` is expected; when `input_ids` is
    given, `prompt` is ignored and the tokenize step is skipped.

    BUG FIX: the original keyed the tokenize branch on `input_ids is None`
    but selected the generation input with `prompt is None`, raising a
    NameError when both (or neither) argument was supplied. Both decisions
    now key consistently on `input_ids`.
    """
    if input_ids is None:
        print(f"prompt: {prompt}")
        start = time.time()
        inputs = tokenizer(prompt, return_tensors="pt")
        end = time.time()
        logging.critical(f"[PROFILE][WARMUP] tokenizer: {end-start}")
        input_ids_ = inputs.input_ids
    else:
        logging.critical(f"[PROFILE][WARMUP] tokenizer: na")  # for logging consistency
        input_ids_ = input_ids

    # eos_token_id=None forces generation of the full max_new_tokens budget
    # so the time-per-token profile is comparable across prompts.
    start = time.time()
    generate_ids = model.generate(input_ids_, max_new_tokens=max_new_tokens, eos_token_id=None)
    end = time.time()

    prompt_tokens = input_ids_.shape[1]
    num_tokens_out = generate_ids.shape[1]
    new_tokens_generated = num_tokens_out - prompt_tokens
    generate_time = end - start
    # Guard against division by zero when no new tokens were produced.
    if new_tokens_generated > 0:
        time_per_token = (generate_time / new_tokens_generated) * 1e3
    else:
        time_per_token = float("inf")
    logging.critical(f"[PROFILE][AIE] generate: {generate_time} for {num_tokens_out} tokens; prompt-tokens: {prompt_tokens}; time per generated token: {time_per_token}")

    start = time.time()
    response = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    end = time.time()
    logging.critical(f"[PROFILE][WARMUP] tokenizer decode: {end-start}")

    print(f"response: {response}")
    logging.critical(f"response: {response}")
66
+
67
+
68
def decode_prompts(model, tokenizer, max_new_tokens=30):
    """Run decode_prompt() over every canned prompt, with banner separators."""
    banner = "*" * 40
    for prompt in prompts:
        logging.critical(banner)
        print(banner)
        decode_prompt(model, tokenizer, prompt, max_new_tokens=max_new_tokens)
73
+
74
+
75
def get_wikitext2(tokenizer, dataset="non-raw", nsamples=128, seqlen=2048):
    """gptq-style wikitext2 loader.

    Returns (dataloader, testenc): `dataloader` is a list of `nsamples`
    (input, target) pairs of `seqlen` tokens drawn at random from the
    *training* split (the GPTQ calibration convention); `testenc` is the
    tokenized test split for perplexity evaluation.

    Raises ValueError for an unknown `dataset` name.
    """
    from datasets import load_dataset

    if dataset == "non-raw":
        traindata = load_dataset('wikitext', 'wikitext-2-v1', split='train')
        testdata = load_dataset('wikitext', 'wikitext-2-v1', split='test')
    elif dataset == "raw":
        traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
        testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
    else:
        raise ValueError(
            "You are using an unsupported dataset, only support wikitext2-raw-v1 and wikitext2-v1."
            "Using wikitext2-raw-v1 with --dataset=raw and wikitext2-v1 with --dataset=non-raw."
        )

    trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
    testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')

    dataloader = []
    for _ in range(nsamples):
        # BUG FIX: calibration windows now come from the training split;
        # previously `trainenc` was computed but unused and the samples were
        # (incorrectly, for GPTQ-style calibration) drawn from the test split.
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        inp = trainenc.input_ids[:, i:i + seqlen]
        tar = inp.clone()
        tar[:, :-1] = -100  # only the final position contributes to the loss
        dataloader.append((inp, tar))
    return dataloader, testenc
101
+
102
+
103
def perplexity(model, tokenizer, dataset, framework="pytorch", nsamples=2):
    """Compute and print perplexity of `model` on the wikitext2 test set.

    Args:
        model: causal LM returning `.logits`; `model.seqlen` is set to 2048 here.
        tokenizer: tokenizer passed through to get_wikitext2().
        dataset: "raw" or "non-raw" wikitext2 variant (see get_wikitext2).
        framework: "pytorch" slices a precomputed all-ones attention mask;
            any other value builds the mask from the batch shape.
        nsamples: number of seqlen-token windows to evaluate. Default 2 keeps
            the original quick-check behavior; use
            `test_enc.numel() // model.seqlen` for the full test set.

    Cleanups vs. the original: removed the no-op `model = model` and the
    unused `dtype`, dropped the CrossEntropyLoss instance that was
    immediately shadowed, and hoisted `loss_fct` out of the loop.
    """
    random.seed(0)
    np.random.seed(0)
    torch.random.manual_seed(0)
    print(f"Calculating Perplexity on wikitext2 test set ...")
    dataloader, testenc = get_wikitext2(tokenizer, dataset=dataset)

    model.seqlen = 2048
    test_enc = testenc.input_ids

    loss_fct = torch.nn.CrossEntropyLoss()
    nlls = []
    with torch.no_grad():
        attention_mask = torch.ones((1, test_enc.numel()))
        for i in range(nsamples):
            batch = test_enc[:, (i * model.seqlen):((i + 1) * model.seqlen)]
            if framework == "pytorch":
                out = model(
                    batch,
                    attention_mask=attention_mask[:, (i * model.seqlen):((i + 1) * model.seqlen)].reshape((1, -1))
                )
            else:
                out = model(
                    batch,
                    attention_mask=batch.new_ones(batch.shape)
                )
            # Next-token objective: position t predicts token t+1, so drop the
            # first label and the last logit.
            shift_labels = test_enc[
                :, (i * model.seqlen):((i + 1) * model.seqlen)
            ][:, 1:]
            loss = loss_fct(out.logits[0][:-1, :], shift_labels.view(-1))
            neg_log_likelihood = loss.float() * model.seqlen
            nlls.append(neg_log_likelihood)

    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    print('Perplexity:', ppl.item())
144
+
modeling_phi3.py ADDED
@@ -0,0 +1,1622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ PyTorch Phi-3 model."""
17
+
18
+ import inspect
19
+ import math
20
+ import warnings
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.cache_utils import Cache, DynamicCache
31
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
32
+ from transformers.modeling_outputs import (
33
+ BaseModelOutputWithPast,
34
+ CausalLMOutputWithPast,
35
+ SequenceClassifierOutputWithPast,
36
+ TokenClassifierOutput,
37
+ )
38
+ from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.utils import (
40
+ add_code_sample_docstrings,
41
+ add_start_docstrings,
42
+ add_start_docstrings_to_model_forward,
43
+ is_flash_attn_2_available,
44
+ is_flash_attn_greater_or_equal_2_10,
45
+ logging,
46
+ replace_return_docstrings,
47
+ )
48
+ from .configuration_phi3 import Phi3Config
49
+
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+ # Transformers scans dependencies in the modeling file, causing issues on conditional loading. The regex only ignores try/catch blocks, but not if statements
54
+ # if is_flash_attn_2_available():
55
+ _flash_supports_window_size = False
56
+ try:
57
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
58
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
59
+
60
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
61
+ except ImportError as error:
62
+ logger.warning(
63
+ f"`flash-attention` package not found, consider installing for better performance: {error}."
64
+ )
65
+ if not _flash_supports_window_size:
66
+ logger.warning(
67
+ "Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`."
68
+ )
69
+
70
+ _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
71
+ _CONFIG_FOR_DOC = "Phi3Config"
72
+
73
+ PHI3_PRETRAINED_MODEL_ARCHIVE_LIST = [
74
+ "microsoft/Phi-3-mini-4k-instruct",
75
+ "microsoft/Phi-3-mini-128k-instruct",
76
+ # See all Phi-3 models at https://huggingface.co/models?filter=Phi-3
77
+ ]
78
+
79
+ import time
80
+ import logging
81
+
82
+
83
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3
class Phi3RMSNorm(nn.Module):
    """Root-mean-square layer norm (equivalent to T5LayerNorm): rescales by a
    learned per-channel weight without mean subtraction or bias."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        # Normalize in float32 for numerical stability, cast back afterwards.
        hs = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(hs.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        return self.weight * (hs * inv_rms).to(orig_dtype)
99
+
100
+
101
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
102
+ def _get_unpad_data(attention_mask):
103
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
104
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
105
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
106
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
107
+ return (
108
+ indices,
109
+ cu_seqlens,
110
+ max_seqlen_in_batch,
111
+ )
112
+
113
+
114
# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
class Phi3RotaryEmbedding(nn.Module):
    """Standard rotary position embedding (RoPE); inverse frequencies are
    built lazily on the first forward call, on the input's device."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Filled in lazily by forward().
        self.register_buffer("inv_freq", None, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]; only its
        # device and dtype are used here.
        if self.inv_freq is None:
            exponents = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
            self.inv_freq = 1.0 / (self.base ** exponents)
        freq_cols = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        pos_rows = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            angles = (freq_cols.float() @ pos_rows.float()).transpose(1, 2)
            emb = torch.cat((angles, angles), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
143
+
144
+
145
class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
    """RoPE with "su" long-context scaling: per-frequency extension factors
    (short vs. long) plus a sqrt-log magnitude correction on cos/sin."""

    def __init__(self, dim, config, device=None):
        super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
        self.short_factor = config.rope_scaling["short_factor"]
        self.long_factor = config.rope_scaling["long_factor"]
        self.original_max_position_embeddings = config.original_max_position_embeddings

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        seq_len = torch.max(position_ids) + 1
        # Long factors kick in only past the original training length.
        if seq_len > self.original_max_position_embeddings:
            chosen = self.long_factor
        else:
            chosen = self.short_factor
        ext_factors = torch.tensor(chosen, dtype=torch.float32, device=x.device)

        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
        self.inv_freq = 1.0 / (ext_factors * self.base**exponents)

        freq_cols = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        pos_rows = position_ids[:, None, :].float()

        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            angles = (freq_cols.float() @ pos_rows.float()).transpose(1, 2)
            emb = torch.cat((angles, angles), dim=-1)

        scale = self.max_position_embeddings / self.original_max_position_embeddings
        if scale <= 1.0:
            scaling_factor = 1.0
        else:
            scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))

        cos = emb.cos() * scaling_factor
        sin = emb.sin() * scaling_factor
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
184
+
185
+
186
class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
    """RoPE with "yarn" long-context scaling: per-frequency extension factors
    (short vs. long) plus a linear-in-log magnitude correction on cos/sin."""

    def __init__(self, dim, config, device=None):
        super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
        self.short_factor = config.rope_scaling["short_factor"]
        self.long_factor = config.rope_scaling["long_factor"]
        self.original_max_position_embeddings = config.original_max_position_embeddings

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        seq_len = torch.max(position_ids) + 1
        # Long factors kick in only past the original training length.
        if seq_len > self.original_max_position_embeddings:
            chosen = self.long_factor
        else:
            chosen = self.short_factor
        ext_factors = torch.tensor(chosen, dtype=torch.float32, device=x.device)

        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
        self.inv_freq = 1.0 / (ext_factors * self.base**exponents)

        freq_cols = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        pos_rows = position_ids[:, None, :].float()

        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            angles = (freq_cols.float() @ pos_rows.float()).transpose(1, 2)
            emb = torch.cat((angles, angles), dim=-1)

        scale = self.max_position_embeddings / self.original_max_position_embeddings
        if scale <= 1.0:
            scaling_factor = 1.0
        else:
            # YaRN attention-temperature correction (differs from "su" above).
            scaling_factor = 0.1 * math.log(scale) + 1.0

        cos = emb.cos() * scaling_factor
        sin = emb.sin() * scaling_factor
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
225
+
226
+
227
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input: (a, b) -> (-b, a)."""
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((-back, front), dim=-1)
233
+
234
+
235
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to query and key tensors.

    Args:
        q (`torch.Tensor`): query tensor, e.g. [batch, heads, seq, head_dim].
        k (`torch.Tensor`): key tensor of matching layout.
        cos (`torch.Tensor`): cosine table of shape [batch, seq, head_dim].
        sin (`torch.Tensor`): sine table of shape [batch, seq, head_dim].
        position_ids (`torch.Tensor`, *optional*): deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1): axis along which
            cos/sin gain a broadcast dimension so they line up with the heads
            axis of q and k (use 1 for [b, h, s, d] layouts, 2 for
            [b, s, h, d]).

    Returns:
        `tuple(torch.Tensor)`: the rotated (q, k) pair.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    rotated_q = q * cos_b + rotate_half(q) * sin_b
    rotated_k = k * cos_b + rotate_half(k) * sin_b
    return rotated_q, rotated_k
261
+
262
+
263
class Phi3MLP(nn.Module):
    """Gated MLP block: one fused gate+up projection, elementwise gating
    through the configured activation, then a down projection back to
    hidden_size."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Single matmul produces both the gate and the up projection halves.
        self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.activation_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        fused = self.gate_up_proj(hidden_states)
        gate, up = fused.chunk(2, dim=-1)
        return self.down_proj(up * self.activation_fn(gate))
280
+
281
+
282
# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep): expands
    (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_key_value_heads * n_rep, seqlen, head_dim) for GQA.
    """
    if n_rep == 1:
        return hidden_states
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    expanded = hidden_states[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
293
+
294
+
295
class Phi3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # GQA: how many query heads share each key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.original_max_position_embeddings = config.original_max_position_embeddings
        self.rope_theta = config.rope_theta
        self.rope_scaling = config.rope_scaling
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # Q, K and V come from one fused projection: op_size = q_width + k_width + v_width.
        op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False)
        self._init_rope()

    def _init_rope(self):
        # Pick the rotary-embedding variant from config.rope_scaling:
        # None -> plain RoPE; "su"/"yarn" -> long-context scaled variants.
        if self.rope_scaling is None:
            self.rotary_emb = Phi3RotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            if scaling_type == "su":
                self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
            elif scaling_type == "yarn":
                self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Eager (matmul + softmax) attention.

        Returns (attn_output, attn_weights or None, past_key_value).
        attn_weights is None unless output_attentions is True.
        """
        logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.")

        bsz, q_len, _ = hidden_states.size()

        # Split the fused QKV output: [q | k | v] along the last dim.
        qkv = self.qkv_proj(hidden_states)
        query_pos = self.num_heads * self.head_dim
        query_states = qkv[..., :query_pos]
        key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
        value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]

        # Reshape to [bsz, heads, seq, head_dim].
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Total kv length = new tokens plus whatever is already cached for this layer.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)

        # RoPE is applied before the cache update so cached keys are already rotated.
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Scaled dot-product scores: QK^T / sqrt(head_dim).
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            # Additive mask: masked positions carry large negative values.
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)

        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Merge heads back to [bsz, seq, hidden_size] and project out.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
428
+
429
+
430
+ class Phi3FlashAttention2(Phi3Attention):
431
+ """
432
+ Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stays
433
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
434
+ flash attention and deal with padding tokens in case the input contains any of them.
435
+ """
436
+
437
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
438
+ def __init__(self, *args, **kwargs):
439
+ super().__init__(*args, **kwargs)
440
+
441
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
442
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
443
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
444
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
445
+
446
+ def forward(
447
+ self,
448
+ hidden_states: torch.Tensor,
449
+ attention_mask: Optional[torch.LongTensor] = None,
450
+ position_ids: Optional[torch.LongTensor] = None,
451
+ past_key_value: Optional[Cache] = None,
452
+ output_attentions: bool = False,
453
+ use_cache: bool = False,
454
+ **kwargs,
455
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
456
+ # Phi3FlashAttention2 attention does not support output_attentions
457
+
458
+ if not _flash_supports_window_size:
459
+ logger.warning_once(
460
+ "The current flash attention version does not support sliding window attention. Please use `attn_implementation='eager'` or upgrade flash-attn library."
461
+ )
462
+ raise ValueError("The current flash attention version does not support sliding window attention.")
463
+
464
+ output_attentions = False
465
+
466
+ if "padding_mask" in kwargs:
467
+ warnings.warn(
468
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
469
+ )
470
+
471
+ # overwrite attention_mask with padding_mask
472
+ attention_mask = kwargs.pop("padding_mask")
473
+
474
+ bsz, q_len, _ = hidden_states.size()
475
+
476
+ qkv = self.qkv_proj(hidden_states)
477
+ query_pos = self.num_heads * self.head_dim
478
+ query_states = qkv[..., :query_pos]
479
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
480
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
481
+
482
+ # Flash attention requires the input to have the shape
483
+ # batch_size x seq_length x head_dim x hidden_dim
484
+ # therefore we just need to keep the original shape
485
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
486
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
487
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
488
+
489
+ kv_seq_len = key_states.shape[-2]
490
+ if past_key_value is not None:
491
+ if self.layer_idx is None:
492
+ raise ValueError(
493
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
494
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
495
+ "with a layer index."
496
+ )
497
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
498
+
499
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
500
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
501
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len)
502
+
503
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
504
+
505
+ use_sliding_windows = (
506
+ _flash_supports_window_size
507
+ and getattr(self.config, "sliding_window", None) is not None
508
+ and kv_seq_len > self.config.sliding_window
509
+ )
510
+
511
+ if past_key_value is not None:
512
+ # Activate slicing cache only if the config has a value `sliding_windows` attribute
513
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
514
+ if (
515
+ getattr(self.config, "sliding_window", None) is not None
516
+ and kv_seq_len > self.config.sliding_window
517
+ and cache_has_contents
518
+ ):
519
+ slicing_tokens = 1 - self.config.sliding_window
520
+
521
+ past_key = past_key_value[self.layer_idx][0]
522
+ past_value = past_key_value[self.layer_idx][1]
523
+
524
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
525
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
526
+
527
+ if past_key.shape[-2] != self.config.sliding_window - 1:
528
+ raise ValueError(
529
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
530
+ f" {past_key.shape}"
531
+ )
532
+
533
+ if attention_mask is not None:
534
+ attention_mask = attention_mask[:, slicing_tokens:]
535
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
536
+
537
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
538
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
539
+
540
+ # repeat k/v heads if n_kv_heads < n_heads
541
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
542
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
543
+
544
+ attn_dropout = self.attention_dropout if self.training else 0.0
545
+
546
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
547
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
548
+ # cast them back in the correct dtype just to be sure everything works as expected.
549
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
550
+ # in fp32.
551
+
552
+ if query_states.dtype == torch.float32:
553
+ if torch.is_autocast_enabled():
554
+ target_dtype = torch.get_autocast_gpu_dtype()
555
+ # Handle the case where the model is quantized
556
+ elif hasattr(self.config, "_pre_quantization_dtype"):
557
+ target_dtype = self.config._pre_quantization_dtype
558
+ else:
559
+ target_dtype = self.qkv_proj.weight.dtype
560
+
561
+ logger.warning_once(
562
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
563
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
564
+ f" {target_dtype}."
565
+ )
566
+
567
+ query_states = query_states.to(target_dtype)
568
+ key_states = key_states.to(target_dtype)
569
+ value_states = value_states.to(target_dtype)
570
+
571
+ # Reashape to the expected shape for Flash Attention
572
+ query_states = query_states.transpose(1, 2)
573
+ key_states = key_states.transpose(1, 2)
574
+ value_states = value_states.transpose(1, 2)
575
+
576
+ attn_output = self._flash_attention_forward(
577
+ query_states,
578
+ key_states,
579
+ value_states,
580
+ attention_mask,
581
+ q_len,
582
+ dropout=attn_dropout,
583
+ use_sliding_windows=use_sliding_windows,
584
+ )
585
+
586
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
587
+ attn_output = self.o_proj(attn_output)
588
+
589
+ if not output_attentions:
590
+ attn_weights = None
591
+
592
+ return attn_output, attn_weights, past_key_value
593
+
594
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward
595
+ def _flash_attention_forward(
596
+ self,
597
+ query_states,
598
+ key_states,
599
+ value_states,
600
+ attention_mask,
601
+ query_length,
602
+ dropout=0.0,
603
+ softmax_scale=None,
604
+ use_sliding_windows=False,
605
+ ):
606
+ """
607
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
608
+ first unpad the input, then computes the attention scores and pad the final attention scores.
609
+
610
+ Args:
611
+ query_states (`torch.Tensor`):
612
+ Input query states to be passed to Flash Attention API
613
+ key_states (`torch.Tensor`):
614
+ Input key states to be passed to Flash Attention API
615
+ value_states (`torch.Tensor`):
616
+ Input value states to be passed to Flash Attention API
617
+ attention_mask (`torch.Tensor`):
618
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
619
+ position of padding tokens and 1 for the position of non-padding tokens.
620
+ dropout (`float`):
621
+ Attention dropout
622
+ softmax_scale (`float`, *optional*):
623
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
624
+ use_sliding_windows (`bool`, *optional*):
625
+ Whether to activate sliding window attention.
626
+ """
627
+ if not self._flash_attn_uses_top_left_mask:
628
+ causal = self.is_causal
629
+ else:
630
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
631
+ causal = self.is_causal and query_length != 1
632
+
633
+ # Contains at least one padding token in the sequence
634
+ if attention_mask is not None:
635
+ batch_size = query_states.shape[0]
636
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
637
+ query_states, key_states, value_states, attention_mask, query_length
638
+ )
639
+
640
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
641
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
642
+
643
+ if not use_sliding_windows:
644
+ attn_output_unpad = flash_attn_varlen_func(
645
+ query_states,
646
+ key_states,
647
+ value_states,
648
+ cu_seqlens_q=cu_seqlens_q,
649
+ cu_seqlens_k=cu_seqlens_k,
650
+ max_seqlen_q=max_seqlen_in_batch_q,
651
+ max_seqlen_k=max_seqlen_in_batch_k,
652
+ dropout_p=dropout,
653
+ softmax_scale=softmax_scale,
654
+ causal=causal,
655
+ )
656
+ else:
657
+ attn_output_unpad = flash_attn_varlen_func(
658
+ query_states,
659
+ key_states,
660
+ value_states,
661
+ cu_seqlens_q=cu_seqlens_q,
662
+ cu_seqlens_k=cu_seqlens_k,
663
+ max_seqlen_q=max_seqlen_in_batch_q,
664
+ max_seqlen_k=max_seqlen_in_batch_k,
665
+ dropout_p=dropout,
666
+ softmax_scale=softmax_scale,
667
+ causal=causal,
668
+ window_size=(self.config.sliding_window, self.config.sliding_window),
669
+ )
670
+
671
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
672
+ else:
673
+ if not use_sliding_windows:
674
+ attn_output = flash_attn_func(
675
+ query_states,
676
+ key_states,
677
+ value_states,
678
+ dropout,
679
+ softmax_scale=softmax_scale,
680
+ causal=causal,
681
+ )
682
+ else:
683
+ attn_output = flash_attn_func(
684
+ query_states,
685
+ key_states,
686
+ value_states,
687
+ dropout,
688
+ softmax_scale=softmax_scale,
689
+ causal=causal,
690
+ window_size=(self.config.sliding_window, self.config.sliding_window),
691
+ )
692
+
693
+ return attn_output
694
+
695
    # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        # Strips padding positions so the varlen flash-attention kernel can run
        # on packed sequences. Returns the unpadded q/k/v plus the index and
        # cumulative-sequence-length metadata consumed by
        # `flash_attn_varlen_func` and by `pad_input` to restore the layout.
        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape

        # On the first iteration we need to properly re-create the padding mask
        # by slicing it on the proper place
        if kv_seq_len != attention_mask.shape[-1]:
            attention_mask_num_tokens = attention_mask.shape[-1]
            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)

        # Flatten (batch, seq) into one axis and keep only non-padded positions.
        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)

        if query_length == kv_seq_len:
            # Prefill: queries align one-to-one with keys, reuse key metadata.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # Decode step: exactly one query token per sequence.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
737
+
738
+
739
# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3
# TODO @Arthur no longer copied from LLama after static cache
class Phi3SdpaAttention(Phi3Attention):
    """
    Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `Phi3Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from Phi3Attention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # SDPA cannot return attention weights, so fall back to the eager
            # implementation inherited from Phi3Attention.
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        bsz, q_len, _ = hidden_states.size()

        # Fused QKV projection, then split into query/key/value slices.
        qkv = self.qkv_proj(hidden_states)
        query_pos = self.num_heads * self.head_dim
        query_states = qkv[..., :query_pos]
        key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
        value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]

        # (bsz, seq, heads, head_dim) -> (bsz, heads, seq, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Total key length includes any tokens already stored in the cache.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        # Append the new (post-RoPE) key/value states to the cache.
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand grouped KV heads so q/k/v all carry `num_heads` heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
            is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        # SDPA never materializes attention weights, hence the middle None.
        return attn_output, None, past_key_value
828
+
829
+
830
# Maps `config._attn_implementation` to the attention module instantiated by
# `Phi3DecoderLayer.__init__`.
PHI3_ATTENTION_CLASSES = {
    "eager": Phi3Attention,
    "flash_attention_2": Phi3FlashAttention2,
    "sdpa": Phi3SdpaAttention,
}
835
+
836
+
837
class Phi3DecoderLayer(nn.Module):
    """Single Phi-3 decoder layer: pre-norm self-attention and pre-norm MLP,
    each followed by a dropout-regularized residual connection."""

    def __init__(self, config: Phi3Config, layer_idx: int):
        super().__init__()

        self.config = config
        # Attention backend (eager / flash_attention_2 / sdpa) is selected from
        # the config at construction time.
        self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)

        self.mlp = Phi3MLP(config)
        self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
        self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
        self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
                `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        # Fix: this docstring used to be placed *after* the deprecation check
        # below, which made it a dead string expression instead of the method's
        # actual docstring.
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        attn_outputs, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

        hidden_states = residual + self.resid_attn_dropout(attn_outputs)

        # MLP sub-block with its own pre-norm and residual.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.resid_mlp_dropout(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
913
+
914
+
915
# Shared class-level documentation, injected into the model classes below via
# the `@add_start_docstrings` decorator.
PHI3_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Phi3Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
930
+
931
+
932
@add_start_docstrings(
    "The bare Phi-3 model outputting raw hidden-states without any specific head on top.",
    PHI3_START_DOCSTRING,
)
class Phi3PreTrainedModel(PreTrainedModel):
    # Base class that wires all Phi-3 model heads into the transformers
    # loading / saving / device-placement machinery.
    config_class = Phi3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Phi3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    # NOTE(review): "sdpa" is registered in PHI3_ATTENTION_CLASSES, yet SDPA
    # support is declared off here -- confirm this is intentional (e.g. for the
    # AWQ-quantized path).
    _supports_sdpa = False
    _supports_cache_class = True

    _version = "0.0.5"

    def _init_weights(self, module):
        # Standard truncated-normal-style init: N(0, initializer_range) for
        # linear/embedding weights, zeros for biases and the padding row.
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
958
+
959
+
960
# Forward-pass argument documentation, injected into `forward` methods via
# `@add_start_docstrings_to_model_forward`.
# NOTE(review): the two "indicates the head is masked" bullets under
# `attention_mask` look like leftovers from head-mask documentation -- confirm
# against upstream before changing the string.
PHI3_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
1028
+
1029
+
1030
@add_start_docstrings(
    "The bare Phi-3 model outputting raw hidden-states without any specific head on top.",
    PHI3_START_DOCSTRING,
)
class Phi3Model(Phi3PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]

    Args:
        config: Phi3Config
    """

    def __init__(self, config: Phi3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # NOTE(review): embed_dropout is defined here but never applied in
        # forward() below -- confirm this is intentional.
        self.embed_dropout = nn.Dropout(config.embd_pdrop)
        self.layers = nn.ModuleList(
            [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Returns the token embedding table.
        return self.embed_tokens

    def set_input_embeddings(self, value):
        # Replaces the token embedding table (used e.g. when resizing vocab).
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Resolve per-call flags against the config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_key_values_length = 0

        # Caching and gradient checkpointing are mutually exclusive.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if use_cache:
            # Accept both the legacy tuple cache and the `Cache` API; a legacy
            # input is converted here and converted back before returning.
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            # Positions continue from the cached length during generation.
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Flash attention requires left padding when generating with a cache.
        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
            if is_padding_right:
                raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right'"
                    " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to "
                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                )

        if self._attn_implementation == "flash_attention_2":
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
                sliding_window=self.config.sliding_window,
            )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Each layer returns the (shared) cache object as its last element.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
1201
+
1202
+
1203
+ class Phi3ForCausalLM(Phi3PreTrainedModel):
1204
+ _tied_weights_keys = ["lm_head.weight"]
1205
+
1206
    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3
    def __init__(self, config):
        """Build the Phi-3 decoder backbone plus a bias-free LM head projecting
        hidden states to vocabulary logits."""
        super().__init__(config)
        self.model = Phi3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
1215
+
1216
    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
    def get_input_embeddings(self):
        """Return the input token embedding module of the underlying Phi3Model."""
        return self.model.embed_tokens
1219
+
1220
    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
    def set_input_embeddings(self, value):
        """Replace the input token embedding module of the underlying Phi3Model."""
        self.model.embed_tokens = value
1223
+
1224
    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
    def get_output_embeddings(self):
        """Return the LM head (output projection to vocabulary logits)."""
        return self.lm_head
1227
+
1228
    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head module."""
        self.lm_head = new_embeddings
1231
+
1232
    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
    def set_decoder(self, decoder):
        """Replace the decoder backbone (`self.model`)."""
        self.model = decoder
1235
+
1236
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1237
+ def get_decoder(self):
1238
+ return self.model
1239
+
1240
+ # Ignore copy
1241
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1242
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1243
+ def forward(
1244
+ self,
1245
+ input_ids: torch.LongTensor = None,
1246
+ attention_mask: Optional[torch.Tensor] = None,
1247
+ position_ids: Optional[torch.LongTensor] = None,
1248
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1249
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1250
+ labels: Optional[torch.LongTensor] = None,
1251
+ use_cache: Optional[bool] = None,
1252
+ output_attentions: Optional[bool] = None,
1253
+ output_hidden_states: Optional[bool] = None,
1254
+ return_dict: Optional[bool] = None,
1255
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1256
+ r"""
1257
+ Args:
1258
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1259
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1260
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1261
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1262
+
1263
+ Returns:
1264
+
1265
+ Example:
1266
+
1267
+ ```python
1268
+ >>> from transformers import AutoTokenizer, Phi3ForCausalLM
1269
+
1270
+ >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
1271
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
1272
+
1273
+ >>> prompt = "This is an example script ."
1274
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1275
+
1276
+ >>> # Generate
1277
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1278
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1279
+ 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
1280
+ ```"""
1281
+
1282
+ start = time.time_ns()
1283
+
1284
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1285
+ output_hidden_states = (
1286
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1287
+ )
1288
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1289
+
1290
+ end = time.time_ns()
1291
+ preprocessing_time = end - start
1292
+
1293
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1294
+ start = time.time_ns()
1295
+ outputs = self.model(
1296
+ input_ids=input_ids,
1297
+ attention_mask=attention_mask,
1298
+ position_ids=position_ids,
1299
+ past_key_values=past_key_values,
1300
+ inputs_embeds=inputs_embeds,
1301
+ use_cache=use_cache,
1302
+ output_attentions=output_attentions,
1303
+ output_hidden_states=output_hidden_states,
1304
+ return_dict=return_dict,
1305
+ )
1306
+
1307
+ hidden_states = outputs[0]
1308
+ logits = self.lm_head(hidden_states)
1309
+ logits = logits.float()
1310
+ end = time.time_ns()
1311
+ decoder_time = end - start
1312
+
1313
+ start = time.time_ns()
1314
+ loss = None
1315
+ if labels is not None:
1316
+ # Shift so that tokens < n predict n
1317
+ shift_logits = logits[..., :-1, :].contiguous()
1318
+ shift_labels = labels[..., 1:].contiguous()
1319
+ # Flatten the tokens
1320
+ loss_fct = CrossEntropyLoss()
1321
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1322
+ shift_labels = shift_labels.view(-1)
1323
+ # Enable model parallelism
1324
+ shift_labels = shift_labels.to(shift_logits.device)
1325
+ loss = loss_fct(shift_logits, shift_labels)
1326
+
1327
+ if not return_dict:
1328
+ output = (logits,) + outputs[1:]
1329
+ return (loss,) + output if loss is not None else output
1330
+
1331
+
1332
+ end = time.time_ns()
1333
+ postprocessing_time = end - start
1334
+ logging.critical(f"[PROFILE][Phi3] model_decoder_forward {decoder_time} {preprocessing_time} {postprocessing_time}")
1335
+ return CausalLMOutputWithPast(
1336
+ loss=loss,
1337
+ logits=logits,
1338
+ past_key_values=outputs.past_key_values,
1339
+ hidden_states=outputs.hidden_states,
1340
+ attentions=outputs.attentions,
1341
+ )
1342
+
1343
+ # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
1344
+ def prepare_inputs_for_generation(
1345
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1346
+ ):
1347
+ if past_key_values is not None:
1348
+ if isinstance(past_key_values, Cache):
1349
+ cache_length = past_key_values.get_seq_length()
1350
+ past_length = past_key_values.seen_tokens
1351
+ max_cache_length = past_key_values.get_max_length()
1352
+ else:
1353
+ cache_length = past_length = past_key_values[0][0].shape[2]
1354
+ max_cache_length = None
1355
+
1356
+ # Keep only the unprocessed tokens:
1357
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1358
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1359
+ # input)
1360
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1361
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1362
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1363
+ # input_ids based on the past_length.
1364
+ elif past_length < input_ids.shape[1]:
1365
+ input_ids = input_ids[:, past_length:]
1366
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1367
+
1368
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1369
+ if (
1370
+ max_cache_length is not None
1371
+ and attention_mask is not None
1372
+ and cache_length + input_ids.shape[1] > max_cache_length
1373
+ ):
1374
+ attention_mask = attention_mask[:, -max_cache_length:]
1375
+
1376
+ position_ids = kwargs.get("position_ids", None)
1377
+ if attention_mask is not None and position_ids is None:
1378
+ # create position_ids on the fly for batch generation
1379
+ position_ids = attention_mask.long().cumsum(-1) - 1
1380
+ position_ids.masked_fill_(attention_mask == 0, 1)
1381
+ if past_key_values:
1382
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1383
+
1384
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1385
+ if inputs_embeds is not None and past_key_values is None:
1386
+ model_inputs = {"inputs_embeds": inputs_embeds}
1387
+ else:
1388
+ model_inputs = {"input_ids": input_ids}
1389
+
1390
+ model_inputs.update(
1391
+ {
1392
+ "position_ids": position_ids,
1393
+ "past_key_values": past_key_values,
1394
+ "use_cache": kwargs.get("use_cache"),
1395
+ "attention_mask": attention_mask,
1396
+ }
1397
+ )
1398
+ return model_inputs
1399
+
1400
+ @staticmethod
1401
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
1402
+ def _reorder_cache(past_key_values, beam_idx):
1403
+ reordered_past = ()
1404
+ for layer_past in past_key_values:
1405
+ reordered_past += (
1406
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1407
+ )
1408
+ return reordered_past
1409
+
1410
+
1411
@add_start_docstrings(
    """
    The [`Phi3Model`] with a sequence classification head on top (linear layer).

    [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    PHI3_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs
class Phi3ForSequenceClassification(Phi3PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Phi3Model(config)
        # Classification head over the pooled (last non-pad token) hidden state.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        model_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = model_outputs[0]
        # Per-position logits; pooled down to one position per sequence below.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            # No pad token: pool the last position of every row.
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                # Cannot locate padding without input_ids; fall back to the last position.
                sequence_lengths = -1

        # Pool: one logit vector per sequence, taken at the last (non-pad) token.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once (on first labeled call) from num_labels
            # and the label dtype, then cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + model_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )
1533
+
1534
+
1535
@add_start_docstrings(
    """
    [`Phi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    PHI3_START_DOCSTRING,
)
# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs
class Phi3ForTokenClassification(Phi3PreTrainedModel):
    def __init__(self, config: Phi3Config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.model = Phi3Model(config)
        # Dropout rate: prefer an explicit classifier_dropout, then hidden_dropout,
        # otherwise default to 0.1.
        if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        # Per-token classification head over the decoder hidden states.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        model_outputs = self.model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = model_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            batch_size, seq_length = labels.shape
            loss_fct = CrossEntropyLoss()
            # Flatten (batch, seq) so every token position is one classification sample.
            loss = loss_fct(
                logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
            )

        if not return_dict:
            # Index from 2 skips past_key_values in the underlying model tuple output.
            output = (logits,) + model_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )
phi-3-chat-w4-g128_awq.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a5bd72f11e8bf5a739f8718456887195237b16ce45dadfc32cff8cc3a1c1113
3
+ size 57621774
run_awq.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright © 2023 Advanced Micro Devices, Inc. All rights reserved.
3
+ #
4
+
5
+ import torch
6
+ import logging
7
+ import time
8
+ import argparse
9
+ import os
10
+ import psutil
11
+ from transformers import set_seed
12
+ from transformers import LlamaTokenizer,AutoTokenizer
13
+
14
+ import qlinear
15
+ from utils import Utils
16
+ from model_utils import (
17
+ warmup,
18
+ decode_prompt,
19
+ decode_prompts,
20
+ get_wikitext2,
21
+ perplexity,
22
+ )
23
+ from profiler import ProfileAIE
24
+ import gc
25
+
26
+
27
+ from phi3_mini.modeling_phi3 import Phi3ForCausalLM
28
+
29
+ from pre_quant import run_awq, apply_awq
30
+ from quantizer import real_quantize_model_weight
31
+ from qmodule import WQLinear
32
+
33
+ set_seed(123)
34
+
35
+
36
def load_model(args):
    """Build the Phi-3-mini tokenizer and model for the requested mode.

    Modes (driven by ``args``):
      * ``awq == "none"``: load the unquantized bf16 model from ``./phi3_mini``.
      * ``task == "quantize"``: apply AWQ scales (``awq == "load"``) or compute
        them (``awq == "run"``), 4/3-bit quantize the weights, replace WQLinear
        with AIE-capable QLinearPerGrp nodes, save the checkpoint and exit.
      * otherwise: load the previously quantized checkpoint from disk.

    Returns:
        (model, tokenizer) with the model in eval mode and cast to bfloat16.

    Raises:
        SystemExit: after saving a quantized checkpoint / AWQ results, or when
        the expected quantized checkpoint is missing.
    """

    # tokenizer = LlamaTokenizer.from_pretrained("./Phi-3-mini-4k-instruct-AWQ")
    tokenizer = AutoTokenizer.from_pretrained("./phi3_mini")
    if args.awq == "none":
        model = Phi3ForCausalLM.from_pretrained("./phi3_mini", torch_dtype=torch.bfloat16)

    else:
        # ckpt = "pytorch_phi3_mini_w_bit_{}_awq{}_{}amd.pt".format(args.w_bit, "_fa" if args.flash_attention else "", "lm_" if args.lm_head else "")
        ckpt = "./phi3_mini_awq_4bit_no_flash_attention.pt"
        if args.task == "quantize":
            model = Phi3ForCausalLM.from_pretrained("./phi3_mini", torch_dtype=torch.bfloat16)
            print(model)

            Utils.print_model_size(model)

            q_config = {
                "zero_point": True,
                "q_group_size": 128, }  # whether to use group quantization

            if args.awq == 'load':
                # The AWQ scales/clips are read from a fixed local file, not
                # from the AWQ_CACHE directory (the old message was misleading).
                awq_ckpt = "./phi-3-chat-w4-g128_awq.pt"
                print("Loading pre-computed AWQ results from", awq_ckpt)
                awq_results = torch.load(awq_ckpt, map_location="cpu")
                apply_awq(model, awq_results)
                print("Quantization config:", q_config)
                real_quantize_model_weight(
                    model, w_bit=args.w_bit, q_config=q_config
                )

                Utils.print_model_size(model)

                #for n, m in model.named_modules():
                #    if isinstance(m, WQLinear):
                #        print(f"AWQ Model load : {n} : {m.qweight.data.min()} {m.qweight.data.max()} {m.qweight.data.shape} {m.scales.shape} qzeros: {m.qzeros.shape} {m.qzeros.min()} {m.qzeros.max()}")

            elif args.awq == 'run':
                # Compute fresh AWQ scales on a small calibration set, save and exit.
                awq_results = run_awq(
                    model, tokenizer,
                    w_bit=args.w_bit, q_config=q_config,
                    n_samples=128, seqlen=512,
                )
                torch.save(awq_results, "./phi3-mini-w%d-g128-generated.pt" % args.w_bit)
                print(model)
                print("Saved AWQ results in ./phi3-mini-w%d-g128-generated.pt" % args.w_bit)
                raise SystemExit

            # Swap quantized WQLinear modules for the AIE-capable per-group kernels.
            Utils.replace_node(model,
                               WQLinear,
                               qlinear.QLinearPerGrp,
                               (), {'device': 'cpu', 'w_bit': args.w_bit, 'group_size': 128})
            print(model)
            gc.collect()

            Utils.print_model_size(model)
            if args.lm_head:  # Quantize lm_head
                Utils.replace_node(model,
                                   torch.nn.Linear,
                                   qlinear.QLinearPerGrp,
                                   (), {'device': 'cpu', 'w_bit': args.w_bit, 'group_size': 32})
                print(model)
                gc.collect()

            torch.save(model, ckpt)
            print(f"Quantized and saved model: {ckpt}")
            raise SystemExit
        else:
            print(f"Loading from ckpt: {ckpt}")
            if not os.path.exists(ckpt):
                print(f"\n\n ***** Run --task quantize (with/without lm_head) first to save quantized model ...!!! \n\n")
                raise SystemExit
            # NOTE: torch.load of a full pickled model executes arbitrary code;
            # only load checkpoints produced by this script.
            model = torch.load(ckpt)

    Utils.print_model_size(model)
    _ = gc.collect()
    model.eval()
    model = model.to(torch.bfloat16)
    print(model)
    return model, tokenizer
115
+
116
+
117
if __name__ == "__main__":
    # CLI driver: parse options, pin CPU affinity on STX, configure logging,
    # load/quantize the model, then run the selected task (decode / benchmark /
    # perplexity) and post-process the profiling log into a CSV.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', help="Dataset - wikitext2-raw-v1, wikitext2-v1", type=str, default="raw", choices=["non-raw", "raw"])
    parser.add_argument('--w_bit', help="weight bit size", type=int, default=3, choices=[3, 4])
    parser.add_argument('--awq', help="load awq scales, clips from pt or run awq", type=str, default="load", choices=["load", "run", "none"])
    parser.add_argument("--target", help="cpu, aie, aie_emu", type=str, default="cpu", choices=["cpu", "aie_emu", "aie"])
    parser.add_argument('--task', help="quantize: Apply AWQ and save ckpt; perplexity: Measure perplexity on wikitext2 dataset; benchmark: Benchmark latency w.r.t prompt length; benchmark_long: Benchmark long sequences (compare with flash attn); decode: Decode set of prompts;", type=str, default="decode", choices=["quantize", "decode", "benchmark", "benchmark_long", "perplexity"])
    parser.add_argument('--flash_attention', help="Enable flash attention", action='store_true')
    parser.add_argument('--lm_head', help="Enable PerGrp quantization of lm_head layer", action='store_true')
    parser.add_argument('--num_torch_threads', help="Number of torch threads", type=int, default=8, choices=[1, 2, 3, 4, 5, 6, 7, 8])
    args = parser.parse_args()
    print(f"{args}")
    dev = os.getenv("DEVICE")
    # Fixed typo in the log message ("varibale" -> "variable").
    print(f'DEVICE variable is {dev}')

    if dev == "stx":
        # Pin to the first 4 cores on Strix for reproducible benchmarking.
        p = psutil.Process()
        p.cpu_affinity([0, 1, 2, 3])
    torch.set_num_threads(args.num_torch_threads)

    log_dir = "./logs_awq_phi3_chat"
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    log_file = log_dir + "/log_awq.log"

    # All [PROFILE] records go through logging.critical into this file.
    logging.basicConfig(filename=log_file,
                        filemode='w',
                        format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
                        datefmt='%H:%M:%S',
                        level=logging.CRITICAL)

    model, tokenizer = load_model(args)

    if args.awq != "none":
        # Move quantized per-group layers onto the AIE and pack their weights.
        for n, m in model.named_modules():
            print(n)
            if isinstance(m, qlinear.QLinearPerGrp):
                print(f"Preparing weights of layer : {n}")
                m.device = "aie"
                m.quantize_weights()

    print(model)
    Utils.print_model_size(model)

    warmup(model, tokenizer)

    if (args.task == "decode"):
        decode_prompts(model, tokenizer, max_new_tokens=11)
        logging.shutdown()
        # Use a context manager so the CSV handle is closed even if the
        # profiling analysis raises (the original leaked it on error).
        out_path = log_file.replace(".log", "_profile.csv")
        with open(out_path, "w") as out_file:
            ProfileAIE.analyze_profiling(False, True, log_file, out_file)

    elif (args.task == "benchmark") or (args.task == "benchmark_long"):
        #print(model.config.max_position_embeddings) # 2048
        trainloader, testenc = get_wikitext2(tokenizer, nsamples=2, seqlen=4096)
        if (args.task == "benchmark"):
            seqlens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 60, 61, 62, 63, 64, 65, 510, 512, 513, 514, 515]
        else:
            seqlens = [512, 1024, 1536]
        input_ids = next(iter(trainloader))[0][:, :4096]
        for seqlen in seqlens:
            logging.critical("*" * 40)
            print("*" * 40)
            print(f"Benchmarking for {seqlen} tokens ...")
            input_ids_test = input_ids[:, :seqlen]
            decode_prompt(model, tokenizer, prompt=None, input_ids=input_ids_test, max_new_tokens=11)

        logging.shutdown()
        out_path = log_file.replace(".log", "_profile.csv")
        with open(out_path, "w") as out_file:
            ProfileAIE.analyze_profiling(False, True, log_file, out_file)

    elif (args.task == "perplexity"):
        start = time.time()
        perplexity(model, tokenizer, dataset=args.dataset)
        print(f"Time taken to measure ppl on RyzenAI: {time.time() - start}s")
save_weights.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright © 2023 Advanced Micro Devices, Inc. All rights reserved.
3
+ #
4
+
5
+ import torch
6
+ import argparse
7
+ from transformers import pipeline, set_seed
8
+
9
+ from modeling_llama_amd import LlamaForCausalLM
10
+ from transformers import LlamaTokenizer
11
+ import os
12
+
13
+ import gc
14
+ import smooth
15
+
16
+ import numpy as np
17
+
18
+ set_seed(123)
19
+
20
def save_weights(weights_dir):
    """Quantize the Llama-2 model and dump per-layer weights to npz files.

    Loads the HF checkpoint, optionally applies SmoothQuant scales, runs
    PyTorch dynamic int8 quantization on all nn.Linear layers, pickles the
    whole quantized model, then writes one ``<layer_name>.npz`` per quantized
    Linear into *weights_dir* containing ``weight_q`` (int8), ``weight_scale``
    (float) and, when present, ``bias`` (float).

    NOTE(review): relies on the module-level ``args`` namespace created in the
    ``__main__`` block (``args.model_name``, ``args.quant_mode``) — this
    function is only callable from this script's entry point.
    """
    model = LlamaForCausalLM.from_pretrained("./llama-2-wts-hf/%s" % args.model_name)  # , torch_dtype=torch.bfloat16)

    if args.quant_mode == "smooth":
        # Pre-scale activations/weights with SmoothQuant (alpha = 0.5) before PTDQ.
        act_scales = torch.load(os.getenv("PYTORCH_AIE_PATH") + "/ext/smoothquant/act_scales/" + "llama2-7b-gateproj.pt")
        smooth.smooth_lm(model, act_scales, 0.5)
        print(f"SmoothQuant enabled ...")

    # In-place dynamic quantization of every nn.Linear to int8.
    torch.ao.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
    torch.save(model, "./quantized_llama2_%s.pth" % args.model_name)
    count = 0  # running total of quantized parameter elements (weights + biases)

    # Save weights for onnx
    for name, module in model.named_modules():
        if isinstance(module, torch.ao.nn.quantized.dynamic.modules.linear.Linear):
            # _packed_params holds (quantized weight tensor, optional bias).
            weight_bias = module._packed_params._weight_bias()
            weight_q = torch.int_repr(
                weight_bias[0]).numpy().astype(np.int8)
            weight_scale = weight_bias[0].q_scale()

            # One npz per layer, named after its hierarchical module path.
            fname = weights_dir + "/" + name

            if weight_bias[1] is not None:
                bias = weight_bias[1].detach().numpy()
                print(f"{name} {module._get_name()} {weight_q.shape} {bias.shape} ")
                count += bias.shape[0]
                np.savez(fname, weight_q=weight_q, weight_scale=weight_scale, bias=bias)
            else:
                print(f"{name} {module._get_name()} {weight_q.shape} None ")
                bias = None
                np.savez(fname, weight_q=weight_q, weight_scale=weight_scale)

            count += weight_q.shape[0] * weight_q.shape[1]
    # "MB" here means mega-elements (count / 1024^2), not bytes.
    print(f"Num of params: {count/(1024*1024)}MB")
57
def read_weights(weights_dir):
    """Walk *weights_dir* recursively and print a one-line summary
    (shape, bias shape, value range) for every saved npz weight file."""
    for dirpath, _subdirs, filenames in os.walk(weights_dir):
        for fname in filenames:
            full_name = dirpath + "/" + fname
            archive = np.load(full_name)
            quantized = archive['weight_q']
            # Read the scale too, mirroring the writer's layout (value unused here).
            scale = archive['weight_scale']

            if 'bias' not in archive.files:
                print(f"{full_name} {quantized.shape} None ")
                continue
            bias_arr = archive['bias']
            print(f"{full_name} {quantized.shape} {bias_arr.shape} {quantized.min()} {quantized.max()}")
71
+
72
+
73
if __name__ == "__main__":
    """
    Description:
        1. Load Llama2 model
        2. Perform Smooth quant
        3. Perform PTDQ
        4. Save pytorch model
        5. Create weights directory
        6. Dump all integer weights, floating point scale and floating point bias to npz file
        7. Each npz file is the hierarchical name of the layer
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", help="Different Llama model variants", type=str, default="7B_chat", choices=["7B", "7B_chat"])
    parser.add_argument('--quant_mode', help="Quantization mode - smoothquant or pytorch dynamic-quant", type=str, default="smooth", choices=["dyn", "smooth"])
    parser.add_argument('--action', help="save to npz or read from npz", type=str, default="save", choices=["save", "read"])
    args = parser.parse_args()
    print(f"{args}")

    # Per-variant output directory, created on first run.
    weights_dir = "./weights_%s" % args.model_name
    if not os.path.exists(weights_dir):
        os.makedirs(weights_dir)

    # "save" quantizes + dumps npz files; "read" prints summaries of existing dumps.
    if args.action == "save":
        save_weights(weights_dir)
    else:
        read_weights(weights_dir)
99
+
100
+
101
+
102
+