| """ | |
| # RetNPhi: Byte-Level Hybrid of Phi-3.5 and RetNet | |
| RetNPhi is an experimental architecture that transforms Phi-3.5 into a byte-level language model, incorporating RetNet-inspired mechanisms. This innovative approach enables the model to process raw byte sequences, allowing for universal file type handling. | |
| ## Key Features: | |
| 1. **Byte-Level Processing**: Operates directly on raw byte sequences, enabling universal application to any file type. | |
| 2. **RetNet Integration**: Incorporates RetNet's multi-scale exponential decay and group normalization for efficient long-range dependency modeling. | |
| 3. **Dual-mode Processing**: Supports parallel mode for efficient training and recurrent mode for inference. | |
| 4. **Selective Fine-tuning**: Trains only specific layers (e.g., token embedding, post-attention layer normalizations) while keeping most of the original Phi-3.5 weights frozen. | |
| 5. **Weight-Decomposed Low-Rank Adaptation (DoRA)**: Applies DoRA to self-attention output projections for efficient adaptation while preserving pretrained knowledge. | |
| ## Implementation Strategy: | |
| - **Weight Reuse**: Utilizes frozen weights from the original Phi-3.5 model for most layers. | |
| - **Flexible DoRA Application**: Allows configuration of which layers and targets to apply DoRA. | |
| - **Configurable Architecture**: Supports both retention-based and original attention mechanisms. | |
| - **Untied Embeddings Option**: Provides the ability to use separate input and output embeddings. | |
| ## Training and Inference: | |
| - Implements efficient training loops with customizable learning rate schedules. | |
| - Supports both training from scratch and fine-tuning from a checkpoint. | |
| - Provides a generation function for text completion tasks. | |
| ## Goals: | |
| - Explore the potential of retention-like mechanisms in a byte-level Phi architecture. | |
| - Leverage dual-mode processing for efficient training and inference. | |
| - Develop a universal model capable of processing any file type. | |
| Note: This is a highly experimental implementation, designed for research and exploration rather than production use. It demonstrates the potential of combining pretrained models with novel architectures and efficient fine-tuning techniques. | |
| Author: Josef Albers | |
| Date: Aug 28, 2024 | |
| """ | |
import glob
import json
import math
import time
from datetime import datetime
from types import SimpleNamespace

import fire
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from datasets import load_dataset
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_unflatten
class Tokenizer:
    def __init__(self, file_path=None):
        if file_path is None:
            self.vocab = list(range(256))
        else:
            with open(file_path, 'r') as f:
                content = f.read().lower().encode('utf-8')
            self.vocab = sorted(set(content))
        self.vocab_size = len(self.vocab)
        self.byte_to_index = {byte: index for index, byte in enumerate(self.vocab)}
        self.index_to_byte = {index: byte for index, byte in enumerate(self.vocab)}

    def encode(self, text):
        byte_seq = text.encode('utf-8')
        return [self.byte_to_index[byte] for byte in byte_seq]

    def decode(self, indices):
        byte_seq = bytes(self.index_to_byte[index] for index in indices)
        return byte_seq.decode('utf-8', errors='ignore')
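# Illustrative sketch (not part of the original pipeline): round-tripping text
# through the byte-level tokenizer. With no vocab file the vocabulary is all
# 256 byte values, so token ids coincide with the raw UTF-8 bytes.
def _demo_tokenizer():
    tok = Tokenizer()
    ids = tok.encode('hello')  # -> [104, 101, 108, 108, 111]
    assert tok.decode(ids) == 'hello'
    return ids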
class SuRoPE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dim = config.hidden_size // config.num_attention_heads
        self.original_max_position_embeddings = config.original_max_position_embeddings
        self.rope_theta = config.rope_theta
        self.scaling_factor = math.sqrt(1 + math.log(config.max_position_embeddings / config.original_max_position_embeddings) / math.log(config.original_max_position_embeddings))
        self._long_factor = mx.array(config.rope_scaling["long_factor"], dtype=mx.float32)
        self._short_factor = mx.array(config.rope_scaling["short_factor"], dtype=mx.float32)

    def __call__(self, q, k, position_ids):
        cos, sin = self._get_cos_sin(position_ids)
        q = (q * cos) + (self._rotate_half(q) * sin)
        k = (k * cos) + (self._rotate_half(k) * sin)
        return q, k

    def _get_cos_sin(self, position_ids):
        su_factor = self._short_factor  # only the short factor is used here; sequences stay below original_max_position_embeddings
        position_ids_expanded = position_ids[:, None, :]
        inv_freq = 1.0 / (su_factor * self.rope_theta**(mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim))
        inv_freq_expanded = mx.repeat(inv_freq[None, :, None], position_ids.shape[0], axis=0)
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1)
        emb = mx.concatenate([freqs, freqs], axis=-1)
        cos = mx.expand_dims(mx.cos(emb) * self.scaling_factor, axis=1)
        sin = mx.expand_dims(mx.sin(emb) * self.scaling_factor, axis=1)
        return cos, sin

    def _rotate_half(self, x):
        midpoint = x.shape[-1] // 2
        x1, x2 = x[..., :midpoint], x[..., midpoint:]
        return mx.concatenate([-x2, x1], axis=-1)
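# Minimal shape sketch for SuRoPE on a toy config (hypothetical values; the real
# ones come from Phi-3.5's config.json). The rotary update mixes each half of the
# head dimension with position-dependent sines/cosines and preserves all shapes.
def _demo_surope():
    cfg = SimpleNamespace(hidden_size=8, num_attention_heads=2,
                          original_max_position_embeddings=4, max_position_embeddings=8,
                          rope_theta=10000.0,
                          rope_scaling={"long_factor": [1.0, 1.0], "short_factor": [1.0, 1.0]})
    rope = SuRoPE(cfg)
    q = k = mx.ones((1, 2, 4, 4))  # (batch, heads, length, head_dim)
    q, k = rope(q, k, position_ids=mx.arange(4, dtype=mx.float32)[None])
    return q.shape  # (1, 2, 4, 4), unchanged by the rotation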
class Phi3Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        dim = config.hidden_size
        self.n_heads = n_heads = config.num_attention_heads
        self.n_kv_heads = n_kv_heads = config.num_key_value_heads
        self.num_hidden_layers = config.num_hidden_layers
        self.head_dim = head_dim = config.hidden_size // n_heads
        self.scale = head_dim**-0.5
        chop_1 = self.n_heads * self.head_dim
        chop_2 = chop_1 + self.n_kv_heads * self.head_dim
        self.chop = [chop_1, chop_2]
        op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
        self.qkv_proj = nn.Linear(dim, op_size, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
        self.rope = SuRoPE(config)

    def __call__(self, x, position_ids, attention_mask, cache, use_recurrent_mode):
        B, L, _ = x.shape
        qkv = self.qkv_proj(x)
        q, k, v = mx.split(qkv, self.chop, axis=-1)
        q = q.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        k = k.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        v = v.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        if cache is None:
            position_ids = mx.arange(q.shape[2], dtype=mx.float32)[None] if position_ids is None else position_ids
            q, k = self.rope(q, k, position_ids)
            mask = mx.triu(mx.full((v.shape[2], v.shape[2]), -mx.inf), k=1)
            if attention_mask is not None:
                mask += mx.where(attention_mask[:, :, None] * attention_mask[:, None, :] == 1, 0, -mx.inf)
                mask = mx.expand_dims(mask, 1)
            else:
                mask = mask[None, None]
        else:
            past_k, past_v, past_p, past_m = cache
            position_ids = past_p[:, -1:] + 1
            mask = mx.pad(past_m[:, :, -1:, :], ((0, 0), (0, 0), (0, 0), (0, 1)))
            q, k = self.rope(q, k, position_ids)
            k = mx.concatenate([past_k, k], axis=2)
            v = mx.concatenate([past_v, v], axis=2)
        cache = (k, v, position_ids, mask)
        w = (q * self.scale) @ k.transpose(0, 1, 3, 2)
        w += mask
        w = mx.softmax(w, axis=-1)
        o = w @ v
        o = o.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(o).astype(x.dtype), cache
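# Note: Phi3Attention's cache is a (keys, values, position_ids, mask) tuple.
# During incremental decoding the keys/values grow along the sequence axis and
# the causal mask's last row is extended by one column per generated token, so
# each single-token step still attends over the full prefix.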
class Phi3Retention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dim = dim = config.hidden_size
        self.n_heads = n_heads = config.num_attention_heads
        self.n_kv_heads = n_kv_heads = config.num_key_value_heads
        self.num_hidden_layers = config.num_hidden_layers
        self.head_dim = head_dim = config.hidden_size // n_heads
        self.scale = head_dim**-0.5
        chop_1 = self.n_heads * self.head_dim
        chop_2 = chop_1 + self.n_kv_heads * self.head_dim
        self.chop = [chop_1, chop_2]
        op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
        self.qkv_proj = nn.Linear(dim, op_size, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
        self.rope = SuRoPE(config)
        xmin, xmax = math.log(1 / 32), math.log(1 / 512)
        x = mx.linspace(xmin, xmax, num=n_heads)
        self._gamma = 1 - x.exp()
        self.gn = nn.GroupNorm(num_groups=head_dim, dims=-1, affine=False)

    def __call__(self, x, position_ids, attention_mask, cache, use_recurrent_mode):
        if use_recurrent_mode:
            return self.recurrent_mode(x, cache)
        B, L, _ = x.shape
        qkv = self.qkv_proj(x)
        q, k, v = mx.split(qkv, self.chop, axis=-1)
        q = q.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        k = k.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        v = v.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        position_ids = mx.arange(q.shape[2], dtype=mx.float32)[None] if position_ids is None else position_ids
        q, k = self.rope(q, k, position_ids)
        cache = None
        w = (q * self.scale) @ k.transpose(0, 1, 3, 2)
        w = w * self._decay(L)
        o = w @ v
        o = o.transpose(0, 2, 1, 3).reshape(B * L, -1)
        o = self.gn(o).reshape(B, L, -1)
        return self.o_proj(o).astype(x.dtype), cache

    def recurrent_mode(self, x, cache):
        if cache is None:
            # Per-head recurrent state (was hardcoded to (1, 32, 96, 96), the
            # values for Phi-3.5-mini; derive them from the config instead).
            s = mx.zeros((1, self.n_heads, self.head_dim, self.head_dim))
            n = 0
        else:
            s, n = cache
        qkv = self.qkv_proj(x)
        q, k, v = mx.split(qkv, self.chop, axis=-1)
        q = q.reshape(1, 1, self.n_heads, -1).transpose(0, 2, 1, 3)
        k = k.reshape(1, 1, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        v = v.reshape(1, 1, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        position_ids = mx.array([[n]])
        q, k = self.rope(q, k, position_ids)
        k = k * self.scale
        s = self._gamma[None, :, None, None] * s + (k.transpose(0, 1, 3, 2) @ v)
        o = q @ s
        o = o.transpose(0, 2, 1, 3).reshape(1, -1)
        o = self.gn(o).reshape(1, 1, -1)
        o = self.o_proj(o).astype(x.dtype)
        return o, (s, n + 1)

    def _decay(self, sequence_length):
        n = mx.arange(sequence_length)[:, None]
        m = mx.arange(sequence_length)[None]
        D = (self._gamma[:, None, None] ** (n - m)) * (n >= m)
        return D
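# Sketch of the per-head retention decay mask, recomputed standalone (mirrors
# Phi3Retention._decay): D[h, n, m] = gamma_h ** (n - m) for n >= m, else 0,
# so each head forgets the past at its own exponential rate.
def _demo_decay(sequence_length=4, n_heads=2):
    gamma = 1 - mx.linspace(math.log(1 / 32), math.log(1 / 512), num=n_heads).exp()
    n = mx.arange(sequence_length)[:, None]
    m = mx.arange(sequence_length)[None]
    return (gamma[:, None, None] ** (n - m)) * (n >= m)  # (n_heads, L, L), lower-triangular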
class Phi3MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def __call__(self, x):
        x = self.gate_up_proj(x)
        gate, x = mx.split(x, 2, axis=-1)
        return self.down_proj(nn.silu(gate) * x)
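# Note: gate_up_proj computes [gate; up] in a single matmul; the output is
# down_proj(silu(gate) * up), the SwiGLU-style gating used throughout Phi-3.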
class Phi3DecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.use_retention:
            self.self_attn = Phi3Retention(config)
        else:
            self.self_attn = Phi3Attention(config)
        self.mlp = Phi3MLP(config)
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(self, x, position_ids, attention_mask, cache, use_recurrent_mode):
        r, cache = self.self_attn(self.input_layernorm(x), position_ids, attention_mask, cache, use_recurrent_mode)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        return h + r, cache
class Phi3Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_new = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [Phi3DecoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(self, input_ids, pixel_values, image_sizes, position_ids, attention_mask, cache, use_recurrent_mode):
        x = self.embed_new(input_ids)
        cache = [None] * len(self.layers) if cache is None else cache
        for i, l in enumerate(self.layers):
            x, cache[i] = l(x, position_ids, attention_mask, cache[i], use_recurrent_mode)
        return self.norm(x), cache
class Phi3ForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.model = Phi3Model(config)
        if config.untie_embedding:
            self.lm_new = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            self.untie = True
        else:
            self.untie = False

    def __call__(self, input_ids, pixel_values=None, image_sizes=None, position_ids=None, attention_mask=None, cache=None, use_recurrent_mode=False):
        x, cache = self.model(input_ids, pixel_values, image_sizes, position_ids, attention_mask, cache, use_recurrent_mode)
        if self.untie:
            return self.lm_new(x), cache
        return self.model.embed_new.as_linear(x), cache

    @property
    def layers(self):
        return self.model.layers
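# Sketch of the tied-embedding head used when `untie_embedding` is False: logits
# come from reusing the input embedding matrix, i.e. x @ E.T, which is exactly
# what nn.Embedding.as_linear computes.
def _demo_tied_head():
    emb = nn.Embedding(10, 4)
    x = mx.random.normal((1, 3, 4))
    return mx.allclose(emb.as_linear(x), x @ emb.weight.T)  # True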
class DoRALinear(nn.Module):
    @staticmethod
    def from_linear(linear, r, alpha, scale, dropout):
        output_dims, input_dims = linear.weight.shape
        if isinstance(linear, nn.QuantizedLinear):
            input_dims *= 32 // linear.bits
        lora_lin = DoRALinear(input_dims=input_dims, output_dims=output_dims, r=r, alpha=alpha, scale=scale, dropout=dropout)
        lora_lin.linear = linear
        lora_lin.m = mx.linalg.norm(lora_lin._dequantized_weight(), axis=1).astype(mx.float32)  # recompute magnitudes from the wrapped pretrained weight, not the throwaway init
        return lora_lin

    def __init__(self, input_dims, output_dims, r, alpha, scale, dropout, bias=False):
        super().__init__()
        self.linear = nn.Linear(input_dims, output_dims, bias=bias)
        self.dropout = nn.Dropout(p=dropout)
        self.scale = scale * (alpha / r)
        init_scale = 1 / math.sqrt(input_dims)
        self.lora_a = mx.random.uniform(low=-init_scale, high=init_scale, shape=(input_dims, r))
        self.lora_b = mx.zeros(shape=(r, output_dims))
        self.m = mx.linalg.norm(self._dequantized_weight(), axis=1).astype(mx.float32)

    def _dequantized_weight(self):
        weight = self.linear.weight
        if isinstance(self.linear, nn.QuantizedLinear):
            weight = mx.dequantize(weight, self.linear.scales, self.linear.biases, self.linear.group_size, self.linear.bits)
        return weight

    def __call__(self, x):
        y = self.linear(x)
        z = (self.dropout(x) @ self.lora_a) @ self.lora_b
        z = y + (self.scale * z)
        adapted = self._dequantized_weight() + (self.scale * self.lora_b.T) @ self.lora_a.T
        denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
        z = (self.m / denom) * z
        return z.astype(x.dtype)
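# Sketch of the DoRA forward path on a tiny unquantized layer: the adapted
# weight W' = W + scale * (A B)^T is renormalized row-wise to the learned
# magnitude m, i.e. y = (m / ||W'||_row) * (W x + scale * x A B).
def _demo_dora():
    layer = DoRALinear(input_dims=4, output_dims=3, r=2, alpha=2, scale=0.1, dropout=0.0)
    x = mx.random.normal((1, 4))
    return layer(x).shape  # (1, 3)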
def linear_to_lora_layers(model, lora_layers, lora_targets, lora_rank, lora_scale, lora_dropout):
    if lora_layers == 'all':
        lora_layers = model.layers
    elif isinstance(lora_layers, int):
        lora_layers = model.layers[-lora_layers:]
    elif isinstance(lora_layers, list):
        lora_layers = [model.layers[i] for i in lora_layers]
    else:
        raise ValueError("Invalid type for lora_layers. Expected 'all', int (number of last layers), or list (layer indices).")
    def to_lora(layer):
        return DoRALinear.from_linear(layer, r=lora_rank, alpha=lora_rank, scale=lora_scale, dropout=lora_dropout)
    for l in lora_layers:
        lora_modules = [(k, to_lora(m)) for k, m in l.named_modules() if k in lora_targets]
        l.update_modules(tree_unflatten(lora_modules))
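# Hypothetical usage sketch: wrap the o_proj of the last two decoder layers of a
# loaded Phi3ForCausalLM in DoRA adapters (not run here):
# linear_to_lora_layers(model, lora_layers=2, lora_targets=["self_attn.o_proj"],
#                       lora_rank=32, lora_scale=0.1, lora_dropout=0.0)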
def load_base_model(model_cfg, init=False):
    model_id = 'microsoft/Phi-3.5-mini-instruct'
    model_path = snapshot_download(model_id, allow_patterns=["*.safetensors", "config.json"])
    with open(f"{model_path}/config.json", "r") as f:
        config = json.load(f)
    config = config | model_cfg
    model_config = SimpleNamespace(**config)
    model = Phi3ForCausalLM(model_config)
    model_weight = [(k, v) for wf in glob.glob(f"{model_path}/*.safetensors") for k, v in mx.load(wf).items()]
    model.load_weights(model_weight, strict=False)
    model.set_dtype(mx.float32)
    if init:
        init_fn_embed = nn.init.normal(mean=-0.000453949, std=0.0344238)
        model.apply_to_modules(lambda k, v: v.apply(init_fn_embed) if k.endswith('embed_new') else None)
        if model_config.untie_embedding:
            init_fn_lm = nn.init.normal(mean=-0.000231743, std=0.043457)
            model.apply_to_modules(lambda k, v: v.apply(init_fn_lm) if k.endswith('lm_new') else None)
    class_predicate = lambda k, m: hasattr(m, "to_quantized") and not k.endswith('new')
    nn.quantize(model, 64, 4, class_predicate)
    mx.eval(model.parameters())
    return model
def load_model_for_training(lora_cfg, model_cfg, thaws, from_path=None):
    model = load_base_model(model_cfg, init=False)
    if from_path:
        model.load_weights(from_path, strict=False)
    model.freeze()
    if len(lora_cfg['targets']) > 0:  # apply DoRA whenever at least one target is specified (the default is a single target)
        linear_to_lora_layers(model, lora_layers=lora_cfg['layers'], lora_targets=lora_cfg['targets'], lora_rank=lora_cfg['rank'], lora_scale=lora_cfg['scale'], lora_dropout=lora_cfg['dropout'])
    model.apply_to_modules(lambda k, v: v.unfreeze() if any(k.endswith(t) for t in thaws) else None)
    mx.eval(model.parameters())
    # print("Trainable parameters:", [i[0] for i in tree_flatten(model.trainable_parameters())])
    model.train()
    return model
def load_model_for_inference(lora_cfg, model_cfg):
    model = load_base_model(model_cfg, init=False)
    if len(lora_cfg['targets']) > 0:  # must mirror the training-time condition so the DoRA weights line up
        linear_to_lora_layers(model, lora_layers=lora_cfg['layers'], lora_targets=lora_cfg['targets'], lora_rank=lora_cfg['rank'], lora_scale=lora_cfg['scale'], lora_dropout=lora_cfg['dropout'])
    _path = 'trained_retnphi.safetensors' if model_cfg['use_retention'] else 'trained_orgnphi.safetensors'
    model.load_weights(_path, strict=False)
    mx.eval(model.parameters())
    model.eval()
    return model
def generate(prompt, lora_cfg, model_cfg, max_tokens=50, verbose=True):
    model = load_model_for_inference(lora_cfg=lora_cfg, model_cfg=model_cfg)
    input_ids = mx.array(tokenizer.encode(prompt))
    if model_cfg['use_retention']:
        cache = None
        for i in input_ids:
            logits, cache = model(i[None, None], cache=cache, use_recurrent_mode=True)
    else:
        logits, cache = model(input_ids[None])
    token = mx.argmax(logits[:, -1, :], axis=-1)
    mx.eval(token, cache)
    list_tokens = token.tolist()
    for i in range(max_tokens):
        logits, cache = model(token[None], cache=cache, use_recurrent_mode=True)
        token = mx.argmax(logits[:, -1, :], axis=-1)
        mx.eval(token, cache)
        list_tokens += token.tolist()
        if tokenizer.decode(list_tokens[-2:]) == '\n\n':
            break
    output = tokenizer.decode(list_tokens)
    if verbose:
        print(f'{prompt=} + {output=}\n-> {prompt+output}')
    del model
    return output
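# Hypothetical usage (assumes the trained weights saved by train_gsm exist on disk):
# generate('There are 8 candies in a carton. How many candies will be in 5 cartons?',
#          lora_cfg=dict(layers='all', targets=["self_attn.o_proj"], rank=32, scale=0.1, dropout=0.0),
#          model_cfg=dict(vocab_size=256, use_retention=True, untie_embedding=True))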
def train_gsm(learning_rates, num_epochs, batch_size, seq_length, lora_cfg, model_cfg, thaws, take, from_path=None):
    def load_gsm_data(tokenizer, is_tiny=True):
        if is_tiny:
            data = load_dataset("TinyGSM/TinyGSM")["train"]
            if take:
                data = data.take(take)
            data = data.filter(lambda x: len(x['question']) < 100 and ':' not in x['question'] and '-' not in x['question'] and "'" not in x['code'] and '\n    result =' in x['code'])
            split_point = int(len(data) * 0.8)
            train_data = data.select(range(split_point))
            eval_data = data.select(range(split_point, len(data)))
            def format_example(example):
                code_raw = example['code']
                start = code_raw.rfind('\n    """')
                if start == -1:
                    print('Wrong format to start')
                    return (example['question'].strip(), '\n' + code_raw.strip() + '\n\n')  # fall back to the raw code so the return type stays (question, answer)
                start = start + 8  # skip the newline, four-space indent, and closing docstring quotes
                end = code_raw.rfind('\n    result =')
                if end == -1:
                    print('Wrong format to end')
                    end = len(code_raw)
                code_block = code_raw[start:end]
                code_lines = code_block.split('\n    ')
                formatted_code = '\n'.join(line.rstrip() for line in code_lines if line.strip())
                formatted_code = '\n' + formatted_code.strip() + '\n\n'
                result = (example['question'].strip(), formatted_code)
                return result
        else:
            dataset = load_dataset("openai/gsm8k", "main")
            train_data = dataset["train"]
            eval_data = dataset["test"]
            def format_example(example):
                return (example['question'].strip(), '\n' + example['answer'].strip() + '\n\n')
        train_formatted = [format_example(ex) for ex in train_data]
        eval_formatted = [format_example(ex) for ex in eval_data]
        return train_formatted, eval_formatted
    def create_batches(data, tokenizer, batch_size, seq_length):
        def _encode(x):
            return [tokenizer.encode(i) for i in x]
        encoded_data = [_encode(x) for x in data]
        encoded_data = [x for x in encoded_data if len(x[0] + x[1]) <= seq_length + 1]
        if batch_size is None:
            batch_size = min(len(encoded_data), 64)
        else:
            encoded_data = encoded_data[:(len(encoded_data) // batch_size) * batch_size]
            np.random.shuffle(encoded_data)
        for i in range(0, len(encoded_data), batch_size):
            batch = encoded_data[i:i+batch_size]
            max_len = min(max(len(q + a) - 1 for q, a in batch), seq_length)
            x_batch = []
            y_batch = []
            mask_batch = []
            for q, a in batch:
                combined = (q + a)[:max_len + 1]
                x = combined[:-1]
                y = combined[1:]
                pad_length = max_len - len(x)
                x = x + [0] * pad_length
                y = y + [0] * pad_length
                mask = [False] * (len(q) - 1) + [True] * len(a) + [False] * pad_length  # compute loss only on answer tokens
                x_batch.append(x)
                y_batch.append(y)
                mask_batch.append(mask)
            yield mx.array(x_batch), mx.array(y_batch), mx.array(mask_batch)
    def loss_fn(model, X, y, mask):
        logits, _ = model(X)
        logits = logits.astype(mx.float32)
        ce = nn.losses.cross_entropy(logits, y, reduction='none')
        masked_loss = ce * mask
        return masked_loss.sum(), mask.sum()
    def evaluate(model, data, tokenizer, seq_length):
        model.eval()
        total_loss = 0
        total_samples = 0
        for X, y, mask in create_batches(data, tokenizer, None, seq_length):
            loss, ntoks = loss_fn(model, X, y, mask)
            total_loss += loss.item()
            total_samples += ntoks.item()
        return total_loss / total_samples if total_samples > 0 else -1
    def get_optimizer(train_data):
        num_batches_per_epoch = len(list(create_batches(train_data, tokenizer, batch_size, seq_length)))
        print(f'{num_batches_per_epoch=}')
        num_steps = num_epochs * num_batches_per_epoch
        num_warmup = num_steps // 10
        max_lr, min_lr = learning_rates
        if num_warmup > 2:
            warmup = optim.linear_schedule(min_lr * 0.1, max_lr, steps=num_warmup)
            cosine = optim.cosine_decay(max_lr, num_steps - num_warmup, min_lr)
            lr_schedule = optim.join_schedules([warmup, cosine], [num_warmup])
        else:
            lr_schedule = optim.cosine_decay(max_lr, num_steps, min_lr)
        return optim.Lion(learning_rate=lr_schedule), num_steps
    for arg_name in sorted(locals()):
        if arg_name != 'self':
            arg_value = locals()[arg_name]
            if not callable(arg_value):
                print(f"{arg_name}: {arg_value}")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    print(f'--- {timestamp} ---')
    train_data, eval_data = load_gsm_data(tokenizer=tokenizer)
    model = load_model_for_training(lora_cfg=lora_cfg, model_cfg=model_cfg, thaws=thaws, from_path=from_path)
    optimizer, num_steps = get_optimizer(train_data)
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    mx.eval(model, optimizer)
    metrics = {
        'steps': [],
        'learning_rates': [],
        'all_train_losses': [],
        'avg_train_losses': [],
        'val_losses': [],
        'trained_toks': [],
    }
    step = 0
    trained_toks = 0
    losses = []
    log_interval = max(num_steps // 30, 1)  # guard against num_steps < 30
    tic = time.perf_counter()
    for epoch in range(num_epochs):
        for X, y, loss_mask in create_batches(data=train_data, tokenizer=tokenizer, batch_size=batch_size, seq_length=seq_length):
            model.train()
            (loss, ntoks), grads = loss_and_grad_fn(model, X, y, loss_mask)
            optimizer.update(model, grads)
            mx.eval(loss, ntoks, model, optimizer)
            losses.append(loss.item())
            trained_toks += ntoks.item()
            step += 1
            if step % log_interval == 0:
                avg_train_loss = np.mean(losses)
                lr = optimizer.learning_rate.item()
                val_loss = evaluate(model=model, data=eval_data, tokenizer=tokenizer, seq_length=seq_length)
                print(f"{avg_train_loss:8.4f} ({val_loss:6.4f}) @ {step//log_interval:2}/30 w/ {lr:.2e} ({time.perf_counter() - tic:.2f} sec)")
                metrics['val_losses'].append(val_loss)
                tic = time.perf_counter()
                metrics['steps'].append(step)
                metrics['learning_rates'].append(lr)
                metrics['all_train_losses'].extend(losses)
                metrics['avg_train_losses'].append(avg_train_loss)
                metrics['trained_toks'].append(trained_toks)
                losses = []
                trained_toks = 0
    _path = 'trained_retnphi.safetensors' if model_cfg['use_retention'] else 'trained_orgnphi.safetensors'
    mx.save_safetensors(_path, dict(tree_flatten(model.trainable_parameters())))
    log = {
        'args': {
            'learning_rates': learning_rates,
            'num_epochs': num_epochs,
            'batch_size': batch_size,
            'seq_length': seq_length,
            'lora_cfg': lora_cfg,
            'model_cfg': model_cfg,
            'thaws': thaws,
            'from_path': from_path
        },
        'metrics': metrics
    }
    with open(f'train_log_{timestamp}.json', 'w') as f:
        json.dump(log, f, indent=2)
    del model
tokenizer = Tokenizer()

def main(take=1000, layers='all', targets=["self_attn.o_proj"], thaws=['new', 'post_attention_layernorm'], rank=32, scale=0.1, dropout=0.0, lr_max=1e-4, lr_min=1e-5, num_epochs=90, batch_size=1, seq_length=256, vocab_size=256, use_retention=True, untie_embedding=True, prompt='There are 8 candies in a carton. How many candies will be in 5 cartons?'):
    lora_cfg = dict(layers=layers, targets=targets, rank=rank, scale=scale, dropout=dropout)
    model_cfg = dict(vocab_size=vocab_size, use_retention=use_retention, untie_embedding=untie_embedding)
    train_gsm(learning_rates=(lr_max, lr_min), num_epochs=num_epochs, batch_size=batch_size, seq_length=seq_length, lora_cfg=lora_cfg, model_cfg=model_cfg, thaws=thaws, take=take)
    generate(prompt=prompt, lora_cfg=lora_cfg, model_cfg=model_cfg, max_tokens=(seq_length - len(prompt)))

if __name__ == "__main__":
    main(take=None, num_epochs=3)  # -> 240916
    main(take=None, num_epochs=3, untie_embedding=False)
    main(take=None, num_epochs=3, use_retention=False)
    main(take=None, num_epochs=3, untie_embedding=False, use_retention=False)
    # fire.Fire(main)
# Output:
# 388.7268 @ 1/30 w/ 3.36e-05 (64.73 sec)
# ...
# 4.3768 @ 30/30 w/ 1.00e-05 (64.36 sec)
# prompt='There are 8 candies in a carton. How many candies will be in 5 cartons?' + output='\ncandies_in_carton = 8 \nnumber_of_cartons = 5\ntotal_no_of_candies = candies_in_carton * number_of_cartons\n\n'
# -> There are 8 candies in a carton. How many candies will be in 5 cartons?
# candies_in_carton = 8
# number_of_cartons = 5
# total_no_of_candies = candies_in_carton * number_of_cartons