| | --- |
| | license: apache-2.0 |
| | pipeline_tag: text-generation |
| | --- |
| | # MoH: Multi-Head Attention as Mixture-of-Head Attention |
| |
|
| | **Paper or resources for more information:** |
| | [[Paper](https://huggingface.co/papers/2410.11842)] [[Code](https://github.com/SkyworkAI/MoH)] |
| |
|
| | ## ⚡ Overview |
| | We propose Mixture-of-Head attention (MoH), a new architecture that treats attention heads as experts in the Mixture-of-Experts (MoE) mechanism. MoH has two significant advantages: |
| | * First, MoH enables each token to select the appropriate attention heads, enhancing inference efficiency without compromising accuracy or increasing the number of parameters. |
| | * Second, MoH replaces the standard summation in multi-head attention with a weighted summation, introducing flexibility to the attention mechanism and unlocking extra performance potential. |
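
The two points above can be illustrated with a minimal, dependency-free sketch for a single token. It is a simplification under stated assumptions, not the authors' implementation: `moh_combine`, `head_outputs`, and `router_logits` are hypothetical names, each head's output is assumed to be already projected to the model dimension, and details such as always-active shared heads or any auxiliary training loss are left out.

```python
import math
from typing import List, Sequence, Tuple


def softmax(values: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a short list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def moh_combine(
    head_outputs: Sequence[Sequence[float]],
    router_logits: Sequence[float],
    top_k: int,
) -> Tuple[List[float], List[int]]:
    """Combine the top_k routed heads of one token with a weighted sum.

    head_outputs[i] is the (hypothetical) contribution of head i, and
    router_logits[i] is the router score that the token assigns to head i.
    """
    if not 0 < top_k <= len(head_outputs):
        raise ValueError("top_k must be between 1 and the number of heads")

    # Point 1: the token keeps only its top_k highest-scoring heads.
    ranked = sorted(range(len(router_logits)),
                    key=lambda i: router_logits[i], reverse=True)
    selected = ranked[:top_k]

    # Point 2: a weighted sum replaces the plain sum over heads.
    # Assumption: weights are a softmax over the selected logits.
    weights = softmax([router_logits[i] for i in selected])
    dim = len(head_outputs[0])
    combined = [0.0] * dim
    for weight, head in zip(weights, selected):
        for d in range(dim):
            combined[d] += weight * head_outputs[head][d]
    return combined, sorted(selected)


if __name__ == "__main__":
    outputs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 2.0]]
    logits = [0.0, 0.0, -5.0, -5.0]
    print(moh_combine(outputs, logits, top_k=2))
```

Heads that are not selected contribute nothing to the token's output, which is where the inference savings would come from in this simplified picture.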
| |
|
| | ## 😮 Highlights |
| | ### 💡 General Framework |
| | We evaluate our proposed MoH across various popular model frameworks, including Vision Transformers (ViT) for image classification, Diffusion models with Transformers (DiT) for class-conditional image generation, and Large Language Models (LLMs) for language tasks. |
| |
|
| | <div align=center> |
| |
|
| | | Code | HuggingFace Model | |
| | |:-----------------------------------------:|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| |
| | | **[MoH-ViT](https://github.com/SkyworkAI/MoH/tree/main/MoH-ViT)** | 🤗 [MoH-ViT-B-75](https://huggingface.co/Chat-UniVi/MoH-ViT-B-75), [MoH-ViT-B-50](https://huggingface.co/Chat-UniVi/MoH-ViT-B-50), [MoH-ViT-S-80](https://huggingface.co/Chat-UniVi/MoH-ViT-S-80), [MoH-ViT-S-75](https://huggingface.co/Chat-UniVi/MoH-ViT-S-75) | |
| | | **[MoH-DiT](https://github.com/SkyworkAI/MoH/tree/main/MoH-DiT)** | 🤗 [MoH-DiT-90](https://huggingface.co/Chat-UniVi/MoH-DiT-XL-90) | |
| | | **[MoH-LLaMA3-8B](https://github.com/SkyworkAI/MoH/tree/main/MoH-LLaMA3)** | 🤗 [MoH-LLaMA3-8B](https://huggingface.co/Chat-UniVi/MoH-LLaMA3-8B) | |
| |
|
| | </div> |
| |
|
| | ### 🔥 High Performance |
| | Extensive experiments on ViT, DiT, and LLMs demonstrate that MoH outperforms multi-head attention while using only **50%-90%** of the attention heads. |
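
To make the percentages concrete, the small sketch below converts an activation ratio into a head count. The head totals are assumptions that this card does not state (32 query heads for LLaMA3-8B and 12 for ViT-B, the usual values for those base architectures), and `active_heads` is a hypothetical helper. How the repositories handle ratios that do not give a whole number of heads is not covered here.

```python
def active_heads(total_heads: int, keep_ratio: float) -> int:
    """Number of heads that stay active when only keep_ratio of them are used."""
    if not 0.0 < keep_ratio <= 1.0:
        raise ValueError("keep_ratio must be in (0, 1]")
    return round(total_heads * keep_ratio)


# Assumed head counts; not taken from this model card.
print(active_heads(32, 0.75))  # LLaMA3-8B-sized layer at 75%
print(active_heads(12, 0.50))  # ViT-B-sized layer at 50%
```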
| |
|
| | ### 🤗 Support Continue-Tuning Starting from the Multi-Head Attention Models |
| | We demonstrate that pre-trained multi-head attention models, such as LLaMA3-8B, can be further continue-tuned into our MoH models. Notably, MoH-LLaMA3-8B achieves an average accuracy of 64.0% across 14 benchmarks, outperforming LLaMA3-8B by 2.4% while using only 75% of the attention heads. |
| |
|
| | The MoH model quickly recovers to over **95%** of the original model's performance within a training budget of 10B tokens. Performance then continues to improve gradually as the number of training tokens increases. |
| |
|
| | ## 🤖 API for Model Inference |
| | To load the model from the Hugging Face model hub or from a local path, you can use the following code snippets. |
| |
|
| | ### Base Model Inference |
| | ```python |
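| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| | |
| | question = "Hello!" |
| | |
| | # trust_remote_code=True lets transformers run the custom model code shipped in the model repository. |
| | # device_map='auto' places the weights on the available devices. |
| | model = AutoModelForCausalLM.from_pretrained("Chat-UniVi/MoH-LLaMA3-8B", trust_remote_code=True, device_map='auto') |
| | tokenizer = AutoTokenizer.from_pretrained("Chat-UniVi/MoH-LLaMA3-8B", trust_remote_code=True) |
| | |
| | # Tokenize the prompt and move the tensors to the model's device. |
| | inputs = tokenizer(question, return_tensors='pt').to(model.device) |
| | # max_length caps the prompt plus the generated tokens at 128. |
| | response = model.generate(inputs.input_ids, max_length=128) |
| | # Decode the first (and only) returned sequence back to text. |
| | print(tokenizer.decode(response.cpu()[0], skip_special_tokens=True)) |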
| | ``` |
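
For the local case mentioned above, `from_pretrained` also accepts a directory that contains the downloaded model files. The following is a minimal sketch under that assumption; `resolve_model_source`, `load_moh`, and the `./MoH-LLaMA3-8B` directory are hypothetical names used only for illustration.

```python
import os

HUB_ID = "Chat-UniVi/MoH-LLaMA3-8B"  # repository id used in the snippet above


def resolve_model_source(local_dir=None, hub_id=HUB_ID):
    """Return local_dir if it is an existing directory, otherwise the Hub id."""
    if local_dir and os.path.isdir(local_dir):
        return local_dir
    return hub_id


def load_moh(local_dir=None):
    """Load the model and tokenizer from a local directory or from the Hub."""
    # Imported lazily so the helper above works without transformers installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    source = resolve_model_source(local_dir)
    model = AutoModelForCausalLM.from_pretrained(
        source, trust_remote_code=True, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(source, trust_remote_code=True)
    return model, tokenizer


if __name__ == "__main__":
    # "./MoH-LLaMA3-8B" is a hypothetical local checkout of the model files.
    print(resolve_model_source("./MoH-LLaMA3-8B"))
```

Falling back to the Hub id keeps the same call working on machines that have not downloaded the weights yet.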
| |
|
| | ### Chat Model Inference |
| | Coming soon... |
| |
|
| | ## 🗝️ Training & Validating |
| | * The training code is built on [Skywork-MoE](https://github.com/SkyworkAI/Skywork-MoE). Until Skywork-MoE is open-sourced, we cannot release the MoH-LLaMA3 training code on its own. We will release the training code once approval is complete. |
| | * The evaluation is performed on multiple key benchmarks using the [Eleuther AI Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness). |
| |
|
| | ```bash |
| | # For example, test MoH-LLaMA3-8B on winogrande |
| | |
| | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \ |
| | --main_process_port 2004 -m lm_eval --model hf \ |
| | --model_args pretrained=Chat-UniVi/MoH-LLaMA3-8B \ |
| | --tasks winogrande \ |
| | --batch_size 1 \ |
| | --output_path Results/winogrande |
| | ``` |
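
When several benchmarks need to be run, the same command can be generated per task. The helper below is a hypothetical sketch that reuses only the flags shown above. The task list is illustrative, since this card does not enumerate the 14 benchmarks, and `CUDA_VISIBLE_DEVICES` still has to be set in the environment before launching.

```python
import shlex


def build_eval_command(task, model_id="Chat-UniVi/MoH-LLaMA3-8B",
                       port=2004, batch_size=1, output_root="Results"):
    """Return the accelerate/lm_eval argument list for one task."""
    return [
        "accelerate", "launch",
        "--main_process_port", str(port),
        "-m", "lm_eval",
        "--model", "hf",
        "--model_args", f"pretrained={model_id}",
        "--tasks", task,
        "--batch_size", str(batch_size),
        "--output_path", f"{output_root}/{task}",
    ]


if __name__ == "__main__":
    # Illustrative task names; replace them with the benchmarks you need.
    for task in ["winogrande", "piqa"]:
        print(shlex.join(build_eval_command(task)))
```

Each printed line can be pasted into a shell, or the argument list can be passed to `subprocess.run` directly.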
| |
|
| | ## ✏️ Citation |
| | If you find this paper useful, please consider starring 🌟 this repo and citing our paper: |
| | ``` |
| | @article{jin2024moh, |
| | title={MoH: Multi-Head Attention as Mixture-of-Head Attention}, |
| | author={Peng Jin and Bo Zhu and Li Yuan and Shuicheng Yan}, |
| | year={2024} |
| | } |
| | ``` |