import torch


def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))

    # Frequency adjustments (Llama 3.x-style RoPE rescaling for longer contexts)
    if freq_config is not None:
        low_freq_wavelen = freq_config["original_context_length"] / freq_config["low_freq_factor"]
        high_freq_wavelen = freq_config["original_context_length"] / freq_config["high_freq_factor"]

        wavelen = 2 * torch.pi / inv_freq

        # Low-frequency components are scaled down by the full factor
        inv_freq_llama = torch.where(
            wavelen > low_freq_wavelen, inv_freq / freq_config["factor"], inv_freq
        )

        # Smooth interpolation for the medium-frequency band
        smooth_factor = (freq_config["original_context_length"] / wavelen - freq_config["low_freq_factor"]) / (
            freq_config["high_freq_factor"] - freq_config["low_freq_factor"]
        )

        smoothed_inv_freq = (
            (1 - smooth_factor) * (inv_freq / freq_config["factor"]) + smooth_factor * inv_freq
        )

        is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
        inv_freq = inv_freq_llama

    # Generate position indices
    positions = torch.arange(context_length, dtype=dtype)

    # Compute the angles; shape: (context_length, head_dim // 2)
    angles = positions[:, None] * inv_freq[None, :]

    # Expand angles to match head_dim; shape: (context_length, head_dim)
    angles = torch.cat([angles, angles], dim=1)

    # Precompute sine and cosine tables
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin


def apply_rope(x, cos, sin):
    # x has shape (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into its first and second half along the head dimension
    x1 = x[..., : head_dim // 2]
    x2 = x[..., head_dim // 2 :]

    # Adjust sin and cos shapes to broadcast: (1, 1, seq_len, head_dim)
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    # Cast back to the input dtype
    return x_rotated.to(dtype=x.dtype)
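

# Illustrative usage sketch (not part of the original module): precompute the RoPE
# tables once with compute_rope_params and rotate a dummy query tensor with apply_rope.
# The batch/head/sequence sizes below are arbitrary assumptions chosen for the demo.
def _rope_usage_example():
    batch_size, num_heads, seq_len, head_dim = 1, 4, 8, 16
    cos, sin = compute_rope_params(head_dim=head_dim, context_length=seq_len)
    queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
    queries_rot = apply_rope(queries, cos, sin)
    assert queries_rot.shape == queries.shape  # the rotation preserves the shape
    return queries_rot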


def model_memory_size(model, input_dtype=torch.float32):
    total_params = 0
    total_grads = 0
    for param in model.parameters():
        # Count the number of elements per parameter
        param_size = param.numel()
        total_params += param_size
        # Gradients are stored only for trainable parameters
        if param.requires_grad:
            total_grads += param_size

    # Buffers are non-parameter tensors that also occupy memory
    total_buffers = sum(buf.numel() for buf in model.buffers())

    # Bytes = number of elements * size of one element in the given dtype
    element_size = torch.tensor(0, dtype=input_dtype).element_size()
    total_memory_bytes = (total_params + total_grads + total_buffers) * element_size

    # Convert bytes to gigabytes
    total_memory_gb = total_memory_bytes / (1024**3)

    return total_memory_gb
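

# Illustrative usage sketch (not part of the original module): estimate the memory
# footprint of a throwaway stand-in module in two dtypes. The layer sizes are arbitrary
# assumptions for the demo, not a real Llama configuration.
def _memory_size_example():
    dummy_model = torch.nn.Sequential(
        torch.nn.Linear(1024, 4096),
        torch.nn.Linear(4096, 1024),
    )
    fp32_gb = model_memory_size(dummy_model, input_dtype=torch.float32)
    bf16_gb = model_memory_size(dummy_model, input_dtype=torch.bfloat16)
    return fp32_gb, bf16_gb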


import os
from pathlib import Path

import tiktoken
from tiktoken.load import load_tiktoken_bpe


class Tokenizer:
    """Thin wrapper around tiktoken that keeps track of Llama-3 special IDs."""
    def __init__(self, model_path):
        if not os.path.isfile(model_path):
            raise FileNotFoundError(model_path)

        mergeable = load_tiktoken_bpe(model_path)

        # Special token IDs used by the Llama 3 tokenizer
        self.special = {
            "<|begin_of_text|>": 128000,
            "<|end_of_text|>": 128001,
            "<|start_header_id|>": 128006,
            "<|end_header_id|>": 128007,
            "<|eot_id|>": 128009,
        }
        self.special.update({f"<|reserved_{i}|>": 128002 + i
                             for i in range(256)
                             if 128002 + i not in self.special.values()})

        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)"
                    r"|[^\r\n\p{L}\p{N}]?\p{L}+"
                    r"|\p{N}{1,3}"
                    r"| ?[^\s\p{L}\p{N}]+[\r\n]*"
                    r"|\s*[\r\n]+"
                    r"|\s+(?!\S)"
                    r"|\s+",
            mergeable_ranks=mergeable,
            special_tokens=self.special,
        )

    def encode(self, text, bos=False, eos=False):
        ids = ([self.special["<|begin_of_text|>"]] if bos else []) \
            + self.model.encode(text)
        if eos:
            ids.append(self.special["<|end_of_text|>"])
        return ids

    def decode(self, ids):
        return self.model.decode(ids)


class ChatFormat:
    """Builds Llama-3-style chat prompts from a system and a user message."""

    def __init__(self, tokenizer: Tokenizer, *,
                 default_system="You are a helpful assistant."):
        self.tok = tokenizer
        self.default_system = default_system

    def _header(self, role):
        """Encode <|start_header_id|>role<|end_header_id|>\n\n"""
        return (
            [self.tok.special["<|start_header_id|>"]]
            + self.tok.encode(role)
            + [self.tok.special["<|end_header_id|>"]]
            + self.tok.encode("\n\n")
        )

    def encode(self, user_message, system_message=None):
        sys_msg = system_message if system_message is not None else self.default_system

        ids = [self.tok.special["<|begin_of_text|>"]]

        # System message
        ids += self._header("system")
        ids += self.tok.encode(sys_msg)
        ids += [self.tok.special["<|eot_id|>"]]

        # User message
        ids += self._header("user")
        ids += self.tok.encode(user_message)
        ids += [self.tok.special["<|eot_id|>"]]

        # Assistant header; the model's reply is generated after this
        ids += self._header("assistant")

        return ids
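

# Illustrative usage sketch (assumes a local Llama-3 BPE file; the "tokenizer.model"
# path is a placeholder, nothing in this module provides it): build a chat prompt and
# decode it back to text to inspect the resulting template.
def _chat_format_example(tokenizer_path="tokenizer.model"):
    tokenizer = Tokenizer(tokenizer_path)
    chat = ChatFormat(tokenizer)
    token_ids = chat.encode("What do llamas eat?")
    return tokenizer.decode(token_ids)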


def assign(left, right, tensor_name="unknown"):
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")

    if isinstance(right, torch.Tensor):
        return torch.nn.Parameter(right.clone().detach())
    else:
        return torch.nn.Parameter(torch.tensor(right))


def load_weights_into_llama(model, param_config, params):
    model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")

    for l in range(param_config["n_layers"]):

        # Load attention weights
        model.trf_blocks[l].att.W_query.weight = assign(
            model.trf_blocks[l].att.W_query.weight,
            params[f"model.layers.{l}.self_attn.q_proj.weight"],
            f"model.layers.{l}.self_attn.q_proj.weight"
        )
        model.trf_blocks[l].att.W_key.weight = assign(
            model.trf_blocks[l].att.W_key.weight,
            params[f"model.layers.{l}.self_attn.k_proj.weight"],
            f"model.layers.{l}.self_attn.k_proj.weight"
        )
        model.trf_blocks[l].att.W_value.weight = assign(
            model.trf_blocks[l].att.W_value.weight,
            params[f"model.layers.{l}.self_attn.v_proj.weight"],
            f"model.layers.{l}.self_attn.v_proj.weight"
        )
        model.trf_blocks[l].att.out_proj.weight = assign(
            model.trf_blocks[l].att.out_proj.weight,
            params[f"model.layers.{l}.self_attn.o_proj.weight"],
            f"model.layers.{l}.self_attn.o_proj.weight"
        )
        model.trf_blocks[l].norm1.weight = assign(
            model.trf_blocks[l].norm1.weight,
            params[f"model.layers.{l}.input_layernorm.weight"],
            f"model.layers.{l}.input_layernorm.weight"
        )

        # Load FeedForward weights
        model.trf_blocks[l].ff.fc1.weight = assign(
            model.trf_blocks[l].ff.fc1.weight,
            params[f"model.layers.{l}.mlp.gate_proj.weight"],
            f"model.layers.{l}.mlp.gate_proj.weight"
        )
        model.trf_blocks[l].ff.fc2.weight = assign(
            model.trf_blocks[l].ff.fc2.weight,
            params[f"model.layers.{l}.mlp.up_proj.weight"],
            f"model.layers.{l}.mlp.up_proj.weight"
        )
        model.trf_blocks[l].ff.fc3.weight = assign(
            model.trf_blocks[l].ff.fc3.weight,
            params[f"model.layers.{l}.mlp.down_proj.weight"],
            f"model.layers.{l}.mlp.down_proj.weight"
        )
        model.trf_blocks[l].norm2.weight = assign(
            model.trf_blocks[l].norm2.weight,
            params[f"model.layers.{l}.post_attention_layernorm.weight"],
            f"model.layers.{l}.post_attention_layernorm.weight"
        )

    # Load final norm and output head weights
    model.final_norm.weight = assign(model.final_norm.weight, params["model.norm.weight"], "model.norm.weight")

    if "lm_head.weight" in params.keys():
        model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight")
    else:
        model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
        print("Model uses weight tying.")


def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())


def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):

    # For each step, condition on at most the last context_size tokens
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        # Filter logits with top-k sampling
        if top_k is not None:
            # Keep only the top_k largest logits
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)

        # Apply temperature scaling and sample
        if temperature > 0.0:
            logits = logits / temperature

            # Convert logits to probabilities
            probs = torch.softmax(logits, dim=-1)  # (batch_size, vocab_size)

            # Sample the next token from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

        # Otherwise pick the token with the highest logit (greedy decoding)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        # Stop early if the end-of-sequence token is generated (assumes batch size 1)
        if idx_next == eos_id:
            break

        # Append the sampled token to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)

    return idx
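

# Illustrative usage sketch (assumes `model` is an already-loaded Llama-3-style model and
# `tokenizer` is a Tokenizer instance; both are placeholders, nothing in this module
# instantiates them): greedy decoding of a short continuation for a prompt.
def _generate_example(model, tokenizer, prompt="Every effort moves you",
                      max_new_tokens=25, context_size=4096):
    token_ids = text_to_token_ids(prompt, tokenizer)
    out_ids = generate(
        model, token_ids,
        max_new_tokens=max_new_tokens,
        context_size=context_size,
        temperature=0.0,  # greedy: always pick the highest-probability token
        top_k=None,
        eos_id=tokenizer.special["<|end_of_text|>"],
    )
    return token_ids_to_text(out_ids, tokenizer)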