---
license: mit
language:
- en
datasets:
- wikitext
- glue
pipeline_tag: text-generation
tags:
- transformer
- attention
- mla
- research
- output-subspace
---

# DeepSeek-Tiny with MLA-o V0.1

A 6-layer DeepSeek-V3 model with MLA plus a shared output latent space ("MLA-o"), trained for research on shared subspaces in Transformer attention mechanisms.

## Model Description

- **Model Type**: Transformer Decoder (DeepSeek-V3 based)
- **Architecture**: 6-layer decoder with Mixture of Experts
- **Parameters**: 16.17M
- **Hidden Size**: 256
- **Attention Heads**: 8
- **Head Dimension**: 32
- **Sequence Length**: 1,024 tokens
- **Query Latent Dimension**: 96
- **Key-Value Latent Dimension**: 64
- **Output Latent Dimension**: 96

## Performance

- **SST-2 Accuracy**: 86.24%
- **WikiText-103 Perplexity**: 29.33

## Research Context

This model is part of the [shared-subspaces](https://github.com/chrisjmccormick/shared-subspaces) research project investigating the impact of shared output latent spaces in Transformer attention mechanisms.

### Output Subspace Decomposition

This model implements a shared output latent space in which the attention output projection W^O is decomposed into:

```
W^O = W^OA · W^OB
```

where W^OA contains the per-head projections into the latent space and W^OB is a shared projection back to the model dimension.
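
As a rough illustration of the parameter tradeoff, here is a minimal sketch (our own illustration, not code from the training repo) that computes the parameter counts implied by the dimensions listed above:

```python
# Minimal sketch (illustration only): parameter counts for a dense output
# projection vs. the MLA-o decomposition, using the dimensions from this card.
n_heads, v_head_dim, d_model, o_latent = 8, 32, 256, 96

dense_o_proj = (n_heads * v_head_dim) * d_model   # W^O:  256 * 256 = 65,536

w_oa = (n_heads * v_head_dim) * o_latent          # W^OA: 256 * 96  = 24,576
w_ob = o_latent * d_model                         # W^OB: 96  * 256 = 24,576
norm = o_latent                                   # RMSNorm over the latent: 96

print(dense_o_proj, w_oa + w_ob + norm)           # 65536 vs. 49248 per layer
```

With a 96-dimensional output latent, the decomposed projection is actually smaller than the dense W^O it replaces.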
## Usage

Rather than overwrite the entire attention layer, we simply replaced the `o_proj` module with an `nn.Sequential`. This is an easy way to modify the model prior to pre-training, but loading the weights back in is a different story. The code below applies the patch and then loads the necessary weights manually.

```python
import torch
import torch.nn as nn
from transformers import DeepseekV3ForCausalLM, AutoTokenizer
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download


def load_mla_o_model(repo_id="ChrisMcCormick/deepseek-tiny-mla-o-v0.1"):
    """Load the MLA-o model with the output subspace decomposition applied."""

    print("\nLoading base model...\n")

    # Load the base model (without the decomposed weights)
    model = DeepseekV3ForCausalLM.from_pretrained(repo_id)
    tokenizer = AutoTokenizer.from_pretrained(repo_id)

    print("\nPatching weights...\n")

    # Download the safetensors file to get the decomposed weights
    weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    weights = load_file(weights_path)

    # Apply the output subspace decomposition to every attention layer
    for layer_idx, layer in enumerate(model.model.layers):
        attn = layer.self_attn

        # Calculate dimensions
        in_features = attn.num_heads * attn.v_head_dim  # 8 * 32 = 256
        o_latent_dim = 96                               # Output latent dimension
        out_features = model.config.hidden_size         # 256
        bias = bool(getattr(model.config, "attention_bias", False))

        # Replace o_proj with the sequential decomposition
        attn.o_proj = nn.Sequential(
            nn.Linear(in_features, o_latent_dim, bias=False),         # W^OA: 256 -> 96
            nn.RMSNorm(o_latent_dim, eps=model.config.rms_norm_eps),  # Normalization
            nn.Linear(o_latent_dim, out_features, bias=bias),         # W^OB: 96 -> 256
        )

        # Load the decomposed weights
        layer_prefix = f"model.layers.{layer_idx}.self_attn.o_proj"

        # Load W^OA weights (o_proj.0.weight)
        w_oa_key = f"{layer_prefix}.0.weight"
        if w_oa_key in weights:
            attn.o_proj[0].weight.data = weights[w_oa_key]

        # Load RMSNorm weights (o_proj.1.weight)
        w_norm_key = f"{layer_prefix}.1.weight"
        if w_norm_key in weights:
            attn.o_proj[1].weight.data = weights[w_norm_key]

        # Load W^OB weights (o_proj.2.weight)
        w_ob_key = f"{layer_prefix}.2.weight"
        if w_ob_key in weights:
            attn.o_proj[2].weight.data = weights[w_ob_key]

        # Load the W^OB bias if it exists
        w_ob_bias_key = f"{layer_prefix}.2.bias"
        if w_ob_bias_key in weights and attn.o_proj[2].bias is not None:
            attn.o_proj[2].bias.data = weights[w_ob_bias_key]

    print("Model loaded and patched.")
    return model, tokenizer


# Load the model
model, tokenizer = load_mla_o_model()

# Generate text
inputs = tokenizer("The future of AI is", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_length=50,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

print("Generated text:")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

## Training Details

- **Pre-training Dataset**: WikiText-103
- **Optimizer**: AdamW
- **Learning Rate**: 5e-4
- **Weight Decay**: 0.01
- **Precision**: bfloat16
- **Compilation**: torch.compile with inductor backend
- **Training Steps**: 12,500
- **Effective Batch Size**: 1,024

## Limitations

- Small-scale model (16M parameters) intended for research purposes
- Trained on limited data compared to production models
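
As a quick sanity check after running `load_mla_o_model()` from the Usage section (a minimal sketch; the expected values are taken from this card), you can confirm that the patch produced the intended structure and parameter count:

```python
# Minimal sanity check after loading (expected values from this model card).
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params / 1e6:.2f}M")  # the card reports 16.17M

# Each layer's o_proj should now be Sequential(Linear 256->96, RMSNorm, Linear 96->256)
print(model.model.layers[0].self_attn.o_proj)
```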