| | """ |
| | mlx_lora_trainer.py — Real MLX LoRA training engine with autograd. |
| | |
| | Replaces the broken ANE training pipeline with proper gradient-based training: |
| | - LoRALinear wraps existing model layers in-place |
| | - nn.value_and_grad() computes exact backprop gradients |
| | - Adam optimizer with cosine LR schedule |
| | - Thread-safe: gpu_lock for mutual exclusion with inference |
| | |
| | Since LoRA is injected in-place, mlx_lm.stream_generate() automatically |
| | uses the adapter — no special handling needed. |
| | """ |
| |
|
| | import json |
| | import logging |
| | import math |
| | import threading |
| | import time |
| | from pathlib import Path |
| | from typing import Optional |
| |
|
| | import mlx.core as mx |
| | import mlx.nn as nn |
| | import mlx.optimizers as optim |
| | import mlx.utils |
| |
|
| | log = logging.getLogger("mlx_lora_trainer") |
| |
|
| |
|
| | |
| | |
| | |
| |
|
class LoRALinear(nn.Module):
    """LoRA adapter wrapping a Linear or QuantizedLinear layer.

    output = base(x) + ((x @ lora_a) @ lora_b) * scale

    ``lora_b`` starts at zero, so the low-rank branch contributes nothing and
    the wrapped layer produces the same output as ``base`` until training
    updates the adapter.
    """

    @classmethod
    def from_base(cls, base: nn.Module, rank: int = 32, alpha: float = 32.0,
                  dropout: float = 0.0):
        """Create a LoRALinear from an existing Linear or QuantizedLinear.

        Args:
            base: The layer to wrap; it is kept and called as-is.
            rank: Inner dimension of the low-rank update.
            alpha: LoRA scaling numerator (effective scale is alpha / rank).
            dropout: Dropout probability applied to the adapter input only.

        Raises:
            TypeError: If ``base`` is neither nn.Linear nor nn.QuantizedLinear.
        """
        if isinstance(base, nn.QuantizedLinear):
            # Quantized weights pack 32 // bits values per stored element,
            # so the logical input width has to be recovered from the bits.
            in_features = base.weight.shape[1] * 32 // base.bits
            out_features = base.weight.shape[0]
        elif isinstance(base, nn.Linear):
            out_features, in_features = base.weight.shape
        else:
            raise TypeError(f"Unsupported layer type: {type(base)}")

        return cls(base, in_features, out_features, rank, alpha, dropout)

    def __init__(self, base: nn.Module, in_features: int, out_features: int,
                 rank: int = 32, alpha: float = 32.0, dropout: float = 0.0):
        """Wrap ``base`` and allocate the two low-rank matrices.

        Raises:
            ValueError: If ``rank`` is not positive or ``dropout`` is outside
                [0, 1); both would otherwise lead to a division by zero.
        """
        if rank <= 0:
            raise ValueError(f"LoRA rank must be positive, got {rank}")
        if not 0.0 <= dropout < 1.0:
            raise ValueError(f"LoRA dropout must be in [0, 1), got {dropout}")

        super().__init__()
        self.base = base
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.scale = alpha / rank

        # lora_a gets a scaled normal init; lora_b is zero so the adapter
        # initially adds nothing to the base output.
        self.lora_a = mx.random.normal((in_features, rank)) * math.sqrt(2.0 / in_features)
        self.lora_b = mx.zeros((rank, out_features))

        self.dropout = dropout

    def __call__(self, x):
        """Return the base projection plus the scaled low-rank update."""
        base_out = self.base(x)

        lora_input = x
        if self.dropout > 0 and self.training:
            # Inverted dropout: rescale the kept activations so the expected
            # magnitude matches the no-dropout path.
            keep_prob = 1.0 - self.dropout
            mask = mx.random.bernoulli(keep_prob, lora_input.shape)
            lora_input = lora_input * mask / keep_prob

        lora_out = (lora_input @ self.lora_a @ self.lora_b) * self.scale
        # The adapter matrices are float32; cast the update back to the input
        # dtype so a half-precision model is not silently promoted to float32
        # downstream of this layer.
        return base_out + lora_out.astype(x.dtype)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | def _find_model_layers(model): |
| | """Find the transformer layers in the model, handling different architectures. |
| | |
| | Returns the layers list, supporting: |
| | - Standard: model.model.layers (Qwen2.5, Llama, etc.) |
| | - VL/Hybrid: model.language_model.model.layers (Qwen3.5) |
| | - Flat: model.layers (some models) |
| | """ |
| | |
| | for path in [ |
| | lambda m: m.model.layers, |
| | lambda m: m.language_model.model.layers, |
| | lambda m: m.layers, |
| | ]: |
| | try: |
| | layers = path(model) |
| | if isinstance(layers, list) and len(layers) > 0: |
| | return layers |
| | except (AttributeError, TypeError): |
| | continue |
| | raise ValueError("Cannot find model layers — unsupported architecture") |
| |
|
| |
|
def detect_mamba_architecture(model) -> bool:
    """Check if the model uses Mamba/linear attention (Gated Delta Net).

    Mamba-based models (e.g., Qwen3.5) have linear_attn layers with custom
    Metal scan kernels. These kernels don't support VJP, but calling
    model.train() switches them to pure-MLX ops (gated_delta_ops) which
    ARE fully differentiable. model.eval() switches back to fast Metal kernels
    for inference. See qwen3_5.py: use_kernel=not self.training.

    Only the first layer is inspected: the model is reported as Mamba-style
    when any of that layer's parameter names contains "linear_attn" or
    "conv1d". Detection is best-effort and returns False on any failure.
    """
    try:
        first_layer = _find_model_layers(model)[0]
        flat_params = mlx.utils.tree_flatten(first_layer.parameters())
        return any(
            "linear_attn" in name or "conv1d" in name
            for name, _ in flat_params
        )
    except Exception as exc:
        # Deliberately non-fatal, but leave a trace instead of swallowing the
        # failure silently so a wrong "not Mamba" answer can be diagnosed.
        log.debug("Mamba architecture detection failed: %s", exc)
        return False
| |
|
| |
|
| | def _find_target_in_layer(layer, target_name): |
| | """Find a target projection within a layer, handling different architectures. |
| | |
| | Supports: |
| | - Standard attention: layer.self_attn.{q,k,v,o}_proj |
| | - Linear attention: layer.linear_attn.{out_proj, in_proj_qkv} |
| | - MLP: layer.mlp.{gate,up,down}_proj |
| | """ |
| | |
| | attn_targets = {"q_proj", "k_proj", "v_proj", "o_proj"} |
| | |
| | linear_attn_targets = {"out_proj", "in_proj_qkv", "in_proj_z"} |
| | |
| | mlp_targets = {"gate_proj", "up_proj", "down_proj"} |
| |
|
| | if target_name in attn_targets: |
| | parent = getattr(layer, "self_attn", None) |
| | elif target_name in linear_attn_targets: |
| | parent = getattr(layer, "linear_attn", None) |
| | elif target_name in mlp_targets: |
| | parent = getattr(layer, "mlp", None) |
| | else: |
| | |
| | for pname in ["self_attn", "linear_attn", "mlp"]: |
| | parent = getattr(layer, pname, None) |
| | if parent and hasattr(parent, target_name): |
| | return parent, getattr(parent, target_name) |
| | return None, None |
| |
|
| | if parent is None: |
| | return None, None |
| |
|
| | base = getattr(parent, target_name, None) |
| | return parent, base |
| |
|
| |
|
def inject_lora_into_model(model, config) -> int:
    """Inject LoRA adapters into model layers in-place.

    Freezes the whole model, then walks the selected layers and replaces each
    target projection with a LoRALinear wrapping the original layer. Targets
    that are missing from a layer, or that are not Linear/QuantizedLinear, are
    skipped; projections that are already LoRALinear are left alone.

    Args:
        model: MLX model (from mlx_lm.load())
        config: NeuralConfig with lora_rank, lora_alpha, lora_targets,
            lora_num_layers and optionally lora_dropout (defaults to 0.0).
            lora_num_layers == -1 (or a value >= the layer count) selects all
            layers; otherwise the last lora_num_layers layers are used.

    Returns:
        Number of adapters injected by this call.
    """
    rank = config.lora_rank
    alpha = config.lora_alpha
    targets = config.lora_targets
    # lora_dropout is not part of the documented minimum config, so tolerate
    # its absence instead of failing with AttributeError.
    dropout = getattr(config, "lora_dropout", 0.0)
    num_layers = config.lora_num_layers

    # Freeze first so that only the adapter matrices created below remain
    # trainable.
    model.freeze()

    layers = _find_model_layers(model)
    n_layers = len(layers)

    if num_layers == -1 or num_layers >= n_layers:
        layer_indices = range(n_layers)
    else:
        layer_indices = range(n_layers - num_layers, n_layers)

    count = 0
    skipped_targets = set()
    for i in layer_indices:
        layer = layers[i]
        for target in targets:
            parent, base_layer = _find_target_in_layer(layer, target)

            if parent is None or base_layer is None:
                skipped_targets.add(target)
                continue

            # Already wrapped (e.g. repeated injection): nothing to do.
            if isinstance(base_layer, LoRALinear):
                continue

            if not isinstance(base_layer, (nn.Linear, nn.QuantizedLinear)):
                skipped_targets.add(target)
                continue

            lora_layer = LoRALinear.from_base(base_layer, rank=rank, alpha=alpha,
                                              dropout=dropout)
            setattr(parent, target, lora_layer)
            count += 1

    if skipped_targets:
        log.info(f"Some targets skipped in certain layers: {skipped_targets} "
                 f"(expected for hybrid architectures)")

    log.info(f"Injected {count} LoRA adapters (rank={rank}, alpha={alpha}, "
             f"targets={targets}, layers={len(layer_indices)})")

    return count
| |
|
| |
|
| | |
| | |
| | |
| |
|
class MLXLoRATrainer:
    """MLX LoRA training engine built on autograd.

    On construction the trainer injects LoRA adapters into ``model`` in place,
    creates an Adam optimizer and a loss+grad function via
    ``nn.value_and_grad()``, and leaves the model in eval mode. The learning
    rate follows a warmup + cosine schedule recomputed before every step.

    ``gpu_lock`` is created here for mutual exclusion with inference; no
    method of this class acquires it, so callers are presumably expected to
    hold it around training and generation — TODO confirm against callers.
    """

    def __init__(self, model, tokenizer, config):
        """Inject adapters into ``model`` and set up optimizer and counters.

        Args:
            model: MLX model; modified in place (frozen, adapters injected).
            tokenizer: Tokenizer providing ``encode`` and, optionally,
                ``apply_chat_template``.
            config: Configuration object read for LoRA and schedule settings.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.gpu_lock = threading.Lock()
        self.is_mamba = detect_mamba_architecture(model)

        if self.is_mamba:
            log.info("Model uses Mamba/linear attention (Gated Delta Net). "
                     "Training uses model.train() to route through pure-MLX ops "
                     "(gated_delta_ops) for autograd. Inference uses model.eval() "
                     "to route through fast Metal kernels.")

        # Order matters: parameters are counted after injection so the
        # trainable/total figures include the adapter matrices.
        self.n_adapters = inject_lora_into_model(model, config)
        self._count_params()

        self.optimizer = optim.Adam(learning_rate=config.learning_rate)
        self._create_compiled_train_fn()

        # Default to inference mode; training methods switch modes themselves.
        model.eval()

        self.total_steps = 0
        self.total_cycles = 0
        self.last_loss = float("inf")
        self.adapter_version = 0
        self.best_loss = float("inf")
        self._start_time = time.time()

        log.info(f"MLXLoRATrainer initialized: {self.n_adapters} adapters, "
                 f"{self.trainable_params:,} trainable / {self.total_params:,} total "
                 f"({self.trainable_pct:.1f}%)")

    def _create_compiled_train_fn(self):
        """Create the loss+grad function.

        mx.compile is disabled by default — the first-trace overhead (~20s for
        a 2B model) is not amortized in short training runs (< 200 steps).
        The standard path at ~0.22s/step is fast enough with early stopping.
        """
        self._raw_loss_and_grad = nn.value_and_grad(self.model, self._loss_fn)
        self._use_compiled = False

    def _count_params(self):
        """Count total and trainable parameter elements and store them."""
        flatten = mlx.utils.tree_flatten
        total = sum(param.size for _, param in flatten(self.model.parameters()))
        trainable = sum(
            param.size for _, param in flatten(self.model.trainable_parameters())
        )
        self.total_params = total
        self.trainable_params = trainable
        self.trainable_pct = 100.0 * trainable / total if total > 0 else 0

    def _loss_fn(self, model, tokens, lengths):
        """Causal LM cross-entropy loss with padding mask.

        Args:
            model: The MLX model (passed by nn.value_and_grad)
            tokens: Input token IDs [batch, seq_len+1] — last token is target only
            lengths: Actual sequence lengths (before padding) [batch]

        Returns:
            Mean per-token loss over the unmasked target positions.
        """
        inputs = tokens[:, :-1]
        targets = tokens[:, 1:]

        logits = model(inputs)

        # A sequence of ``length`` tokens has ``length - 1`` prediction
        # targets; everything past that is padding and must not contribute.
        seq_len = targets.shape[1]
        positions = mx.arange(seq_len)
        mask = positions[None, :] < (lengths[:, None] - 1)
        mask = mask.astype(mx.float32)

        token_losses = nn.losses.cross_entropy(logits, targets, reduction="none")

        # Clamp the denominator so an all-padding batch yields 0, not NaN.
        masked_loss = (token_losses * mask).sum() / mx.clip(mask.sum(), a_min=1, a_max=None)
        return masked_loss

    def _get_lr(self) -> float:
        """Return the learning rate for the current step.

        Linear warmup over the first ``cosine_period_steps * warmup_fraction``
        steps, then cosine decay from ``learning_rate`` to
        ``min_learning_rate``. The decay phase wraps around (warm restart)
        once a full period has elapsed.
        """
        step = self.total_steps
        cfg = self.config
        warmup_steps = int(cfg.cosine_period_steps * cfg.warmup_fraction)

        if step < warmup_steps:
            return cfg.learning_rate * (step + 1) / max(warmup_steps, 1)

        progress = (step - warmup_steps) / max(cfg.cosine_period_steps - warmup_steps, 1)
        # Wrap into [0, 1) so the schedule restarts after each period.
        progress = progress % 1.0
        cos_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
        return cfg.min_learning_rate + (cfg.learning_rate - cfg.min_learning_rate) * cos_decay

    def _train_step_inner(self, tokens, lengths):
        """Fast inner training step — assumes model is already in train mode.

        Called by run_training_cycle() and train_step(), which manage the
        train/eval mode switch around it.

        Returns:
            The step's loss as a Python float.
        """
        lr = self._get_lr()
        self.optimizer.learning_rate = lr

        loss, grads = self._raw_loss_and_grad(self.model, tokens, lengths)
        if self.config.gradient_clip > 0:
            grads, _ = optim.clip_grad_norm(grads, max_norm=self.config.gradient_clip)
        self.optimizer.update(self.model, grads)
        # Force evaluation of the lazily built graph before reading the loss.
        mx.eval(self.model.parameters(), self.optimizer.state, loss)
        loss_val = loss.item()

        self.total_steps += 1
        self.last_loss = loss_val
        if loss_val < self.best_loss:
            self.best_loss = loss_val

        return loss_val

    def train_step(self, tokens, lengths):
        """Single training step with automatic train/eval mode switching.

        Use this for standalone calls (e.g., self-test). For batch training,
        run_training_cycle() uses _train_step_inner() with mode switch hoisted.

        Returns:
            The step's loss as a Python float.
        """
        self.model.train()
        try:
            return self._train_step_inner(tokens, lengths)
        finally:
            # Always return to eval mode, even if the step raised.
            self.model.eval()

    def _tokenize_batch(self, batch):
        """Convert examples into a list of (tokens, lengths) array pairs.

        Examples that are empty, fail to tokenize, or produce fewer than 3
        tokens are dropped. Sequences longer than ``max_seq_len + 1`` keep
        only their trailing tokens.
        """
        max_len = self.config.max_seq_len + 1
        tokenized = []
        for example in batch:
            messages = example.messages if hasattr(example, 'messages') else example
            if not messages:
                continue

            try:
                if hasattr(self.tokenizer, 'apply_chat_template'):
                    text = self.tokenizer.apply_chat_template(
                        messages, tokenize=False, add_generation_prompt=False)
                else:
                    text = "\n".join(f"{m['role']}: {m['content']}" for m in messages)

                token_ids = self.tokenizer.encode(text)
            except Exception as e:
                log.warning(f"Tokenization failed: {e}")
                continue

            if len(token_ids) < 3:
                continue

            if len(token_ids) > max_len:
                token_ids = token_ids[-max_len:]

            tokens = mx.array([token_ids])
            lengths = mx.array([len(token_ids)])
            tokenized.append((tokens, lengths))
        return tokenized

    def run_training_cycle(self, batch, epochs: int = 1) -> dict:
        """Run a training cycle on a batch of conversation examples.

        Each epoch iterates over ALL examples in the batch with 1 gradient
        step per example. When more than one epoch is requested, training may
        stop early once the average epoch loss stays below
        ``config.early_stop_loss`` (default 0.5) for
        ``config.early_stop_patience`` (default 2) consecutive epochs, but
        never before ``min(3, epochs)`` epochs have run.

        Args:
            batch: List of training examples from TrainingDataManager
            epochs: Number of full passes over all examples (default 1)

        Returns:
            dict with training stats, or ``{"trained": False, "reason": ...}``
            when there is nothing to train on.
        """
        if not batch:
            return {"trained": False, "reason": "empty_batch"}

        total_loss = 0.0
        n_steps = 0
        start = time.time()

        tokenized = self._tokenize_batch(batch)
        if not tokenized:
            return {"trained": False, "reason": "no_valid_examples"}

        n_examples = len(tokenized)

        min_epochs = min(3, epochs)
        early_stop_threshold = getattr(self.config, 'early_stop_loss', 0.5)
        patience = getattr(self.config, 'early_stop_patience', 2)
        converge_count = 0
        actual_epochs = 0

        # Switch modes once per cycle rather than once per step.
        self.model.train()
        try:
            for epoch in range(epochs):
                epoch_loss = 0.0
                for tokens, lengths in tokenized:
                    loss = self._train_step_inner(tokens, lengths)
                    epoch_loss += loss
                    total_loss += loss
                    n_steps += 1

                actual_epochs += 1
                avg_epoch_loss = epoch_loss / n_examples

                if epochs > 1 and (epoch % 5 == 0 or epoch == epochs - 1):
                    log.info(f" Epoch {epoch}/{epochs}: loss={avg_epoch_loss:.4f}, lr={self._get_lr():.2e}")

                if epochs > 1 and epoch >= min_epochs and early_stop_threshold > 0:
                    if avg_epoch_loss < early_stop_threshold:
                        converge_count += 1
                        if converge_count >= patience:
                            log.info(f" Early stopping at epoch {epoch}: "
                                     f"loss={avg_epoch_loss:.4f} < {early_stop_threshold} "
                                     f"for {patience} epochs")
                            break
                    else:
                        # The streak must be consecutive.
                        converge_count = 0
        finally:
            self.model.eval()

        elapsed = time.time() - start
        avg_loss = total_loss / n_steps if n_steps > 0 else 0

        self.total_cycles += 1

        result = {
            "trained": True,
            "steps": n_steps,
            "epochs": actual_epochs,
            "requested_epochs": epochs,
            "examples": n_examples,
            "avg_loss": round(avg_loss, 4),
            "last_loss": round(self.last_loss, 4),
            "lr": self._get_lr(),
            "elapsed_sec": round(elapsed, 2),
            "total_steps": self.total_steps,
            "cycle": self.total_cycles,
        }
        log.info(f"Training cycle {self.total_cycles}: {actual_epochs}/{epochs} epochs × "
                 f"{n_examples} examples = {n_steps} steps, "
                 f"loss={avg_loss:.4f}, lr={self._get_lr():.2e}, {elapsed:.1f}s")
        return result

    def save_adapter(self, path: str = "") -> bool:
        """Save LoRA adapter weights, optimizer state and metadata to disk.

        Writes ``lora_weights.safetensors`` and ``adapter_meta.json`` into
        ``path`` (or ``config.adapter_dir`` when ``path`` is empty), plus
        ``optimizer_state.safetensors`` on a best-effort basis.

        Returns:
            True on success, False when the model has no LoRA weights.
        """
        save_dir = Path(path or self.config.adapter_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        lora_weights = {
            name: param
            for name, param in mlx.utils.tree_flatten(self.model.parameters())
            if "lora_a" in name or "lora_b" in name
        }

        if not lora_weights:
            log.warning("No LoRA weights to save")
            return False

        weights_path = save_dir / "lora_weights.safetensors"
        mx.save_safetensors(str(weights_path), lora_weights)

        # The optimizer state is a nested tree mirroring the model, so it has
        # to be flattened to dotted names; a one-level walk finds no arrays.
        try:
            opt_arrays = {
                name: value
                for name, value in mlx.utils.tree_flatten(self.optimizer.state)
                if isinstance(value, mx.array)
            }
            if opt_arrays:
                mx.save_safetensors(str(save_dir / "optimizer_state.safetensors"),
                                    opt_arrays)
        except Exception as e:
            log.warning(f"Could not save optimizer state: {e}")

        meta = {
            "backend": "mlx",
            "total_steps": self.total_steps,
            "total_cycles": self.total_cycles,
            "last_loss": self.last_loss,
            "best_loss": self.best_loss,
            "adapter_version": self.adapter_version,
            "lora_rank": self.config.lora_rank,
            "lora_alpha": self.config.lora_alpha,
            "lora_targets": self.config.lora_targets,
            "trainable_params": self.trainable_params,
            "trainable_pct": round(self.trainable_pct, 2),
            "learning_rate": self.config.learning_rate,
            "timestamp": time.time(),
            "n_weights": len(lora_weights),
        }
        with open(save_dir / "adapter_meta.json", "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2)

        log.info(f"Adapter saved: {len(lora_weights)} tensors, "
                 f"step={self.total_steps}, loss={self.last_loss:.4f} → {save_dir}")
        return True

    def load_adapter(self, path: str = "") -> bool:
        """Load LoRA adapter weights and training counters from disk.

        Reads from ``path`` (or ``config.adapter_dir`` when ``path`` is
        empty). Optimizer state is not restored by this method.

        Returns:
            True when weights were loaded, False when the weights file is
            missing or loading failed.
        """
        load_dir = Path(path or self.config.adapter_dir)
        weights_path = load_dir / "lora_weights.safetensors"
        meta_path = load_dir / "adapter_meta.json"

        if not weights_path.exists():
            log.info(f"No adapter at {weights_path}")
            return False

        try:
            lora_weights = mx.load(str(weights_path))

            # strict=False: the file holds only adapter tensors, not the
            # full set of model weights.
            model_weights = list(lora_weights.items())
            self.model.load_weights(model_weights, strict=False)
            mx.eval(self.model.parameters())

            if meta_path.exists():
                with open(meta_path, encoding="utf-8") as f:
                    meta = json.load(f)
                self.total_steps = meta.get("total_steps", 0)
                self.total_cycles = meta.get("total_cycles", 0)
                self.last_loss = meta.get("last_loss", float("inf"))
                self.best_loss = meta.get("best_loss", float("inf"))
                self.adapter_version = meta.get("adapter_version", 0)

            log.info(f"Adapter loaded: step={self.total_steps}, "
                     f"loss={self.last_loss:.4f} ← {load_dir}")
            return True

        except Exception as e:
            log.error(f"Failed to load adapter: {e}")
            return False

    def reset_adapter(self):
        """Reinitialize the adapter and reset optimizer and counters.

        ``lora_a`` matrices are redrawn from the scaled normal init and
        ``lora_b`` matrices are zeroed, so the adapter again contributes
        nothing to the base model's output.
        """
        updates = []
        for name, param in mlx.utils.tree_flatten(self.model.parameters()):
            if "lora_a" in name:
                in_features = param.shape[0]
                new_val = mx.random.normal(param.shape) * math.sqrt(2.0 / in_features)
                updates.append((name, new_val))
            elif "lora_b" in name:
                updates.append((name, mx.zeros(param.shape)))
        if updates:
            self.model.load_weights(updates, strict=False)
            mx.eval(self.model.parameters())

        # A fresh optimizer discards the accumulated Adam moments.
        self.optimizer = optim.Adam(learning_rate=self.config.learning_rate)
        self._create_compiled_train_fn()

        self.total_steps = 0
        self.total_cycles = 0
        self.last_loss = float("inf")
        self.best_loss = float("inf")
        self.adapter_version = 0

        log.info("Adapter reset to initial state")

    def update_learning_rate(self, lr: float):
        """Update the base learning rate used by the schedule.

        Takes effect on the next step, because the schedule reads
        ``config.learning_rate`` every time it is evaluated.
        """
        self.config.learning_rate = lr
        log.info(f"Learning rate updated to {lr}")

    def stats(self) -> dict:
        """Return training statistics.

        ``last_loss`` and ``best_loss`` are reported as None while they are
        not finite (no step taken yet, or a NaN loss).
        """
        def _finite_or_none(value):
            return round(value, 6) if math.isfinite(value) else None

        return {
            "backend": "mlx",
            "mamba_architecture": self.is_mamba,
            "training_supported": True,
            "total_steps": self.total_steps,
            "total_cycles": self.total_cycles,
            "last_loss": _finite_or_none(self.last_loss),
            "best_loss": _finite_or_none(self.best_loss),
            "adapter_version": self.adapter_version,
            "current_lr": self._get_lr(),
            "trainable_params": self.trainable_params,
            "total_params": self.total_params,
            "trainable_pct": round(self.trainable_pct, 2),
            "n_adapters": self.n_adapters,
            "lora_rank": self.config.lora_rank,
            "lora_targets": self.config.lora_targets,
            "uptime_sec": round(time.time() - self._start_time),
        }

    def cleanup(self):
        """Clean up resources (currently only logs)."""
        log.info("MLXLoRATrainer cleanup")
| |
|
| |
|
| | |
| | |
| | |
| |
|
if __name__ == "__main__":
    # Quick self-test: load a small model, inject LoRA, train 5 steps, then
    # exercise save/load, reset and inference. Requires mlx_lm and the
    # project's neural_config module next to this file.
    import sys
    import tempfile

    sys.path.insert(0, str(Path(__file__).parent))
    from neural_config import NeuralConfig
    import mlx_lm

    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s [%(levelname)s] %(message)s")

    print("=" * 60)
    print("MLX LoRA Trainer Self-Test")
    print("=" * 60)

    test_model = "Qwen/Qwen2.5-0.5B-Instruct"
    print(f"\n1. Loading model: {test_model}")
    model, tokenizer = mlx_lm.load(test_model)

    config = NeuralConfig()
    config.lora_rank = 32
    config.lora_alpha = 32.0
    config.lora_targets = ["q_proj", "v_proj", "down_proj"]
    config.learning_rate = 5e-5
    config.min_learning_rate = 5e-6
    config.cosine_period_steps = 100
    config.warmup_fraction = 0.1
    config.gradient_clip = 1.0
    config.ensure_dirs()

    print("\n2. Creating MLXLoRATrainer...")
    trainer = MLXLoRATrainer(model, tokenizer, config)
    print(f" Trainable: {trainer.trainable_params:,} / {trainer.total_params:,} "
          f"({trainer.trainable_pct:.1f}%)")

    print("\n3. Training on test data (5 steps)...")
    messages = [
        {"role": "user", "content": "What is the capital of Zorblaxia?"},
        {"role": "assistant", "content": "The capital of Zorblaxia is Quenthorp."},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    token_ids = tokenizer.encode(text)
    tokens = mx.array([token_ids])
    lengths = mx.array([len(token_ids)])

    losses = []
    for i in range(5):
        loss = trainer.train_step(tokens, lengths)
        losses.append(loss)
        print(f" Step {i+1}: loss={loss:.4f}, lr={trainer._get_lr():.2e}")

    assert losses[-1] < losses[0], f"Loss should decrease: {losses[0]:.4f} → {losses[-1]:.4f}"
    print(f" Loss decreased: {losses[0]:.4f} → {losses[-1]:.4f} ✓")

    print("\n4. Testing save/load...")
    # Use a fresh temporary directory: a fixed /tmp path is not portable and
    # files left by an earlier run could make the existence checks pass even
    # if saving is broken.
    with tempfile.TemporaryDirectory(prefix="mlx_lora_test_") as tmp_dir:
        save_path = Path(tmp_dir)
        assert trainer.save_adapter(str(save_path))
        assert (save_path / "lora_weights.safetensors").exists()
        assert (save_path / "adapter_meta.json").exists()
        print(" Save ✓")

        old_steps = trainer.total_steps
        old_loss = trainer.last_loss
        trainer.total_steps = 0
        trainer.last_loss = float("inf")
        assert trainer.load_adapter(str(save_path))
        assert trainer.total_steps == old_steps
        assert math.isclose(trainer.last_loss, old_loss)
        print(f" Load ✓ (steps={trainer.total_steps}, loss={trainer.last_loss:.4f})")

    print("\n5. Testing reset...")
    trainer.reset_adapter()
    assert trainer.total_steps == 0
    print(" Reset ✓")

    print("\n6. Testing inference with LoRA...")
    from mlx_lm.sample_utils import make_sampler
    sampler = make_sampler(temp=0.3)
    response_text = ""
    for r in mlx_lm.stream_generate(model, tokenizer,
                                    "What is the capital of France?",
                                    max_tokens=30, sampler=sampler):
        response_text += r.text
    print(f" Response: {response_text[:100]}")
    assert len(response_text) > 5, "Model should generate text with LoRA active"
    print(" Inference ✓")

    print("\n" + "=" * 60)
    print("ALL SELF-TESTS PASSED ✓")
    print("=" * 60)
| |
|