camenduru committed on
Commit 7f64a5a · verified · 1 parent: 2757c1c

thanks to facebook ❤

.gitattributes CHANGED
@@ -23,7 +23,6 @@
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
 *.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +32,7 @@
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+base/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+instruct/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+base_int4_cpu/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+base_int4_accelerator/tokenizer.json filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,124 @@
FAIR Noncommercial Research License
v1 Last Updated: September 23, 2025

“Acceptable Use Policy” means the FAIR Acceptable Use Policy, applicable to Research Materials, that is incorporated into this Agreement.

“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Research Materials set forth herein.

“Documentation” means the specifications, manuals and documentation accompanying Research Materials distributed by Meta.

“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.

“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).

“Noncommercial Research Uses” means noncommercial research use cases related to research, development, education, processing, or analysis and in each case, is not primarily intended for commercial advantage or monetary compensation to you or others.

“Research Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta and made available under this Agreement.

By clicking “I Accept” below or by using or distributing any portion or element of the Research Materials, you agree to be bound by this Agreement.

1. License Rights and Redistribution.

a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Research Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Research Materials.

b. Redistribution and Use.

i. You will not use the Research Materials or any outputs or results of the Research Materials in connection with any commercial uses or for any uses other than Noncommercial Research Uses;

ii. Distribution of Research Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Research Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party.

iii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Research Materials, you must acknowledge the use of Research Materials in your publication.

iv. Your use of the Research Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the FAIR Acceptable Use Policy, which is hereby incorporated by reference into this Agreement.

2. User Support. Your Noncommercial Research Use of the Research Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Research Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.

3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE RESEARCH MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS.

4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.

5. Intellectual Property.

a. Subject to Meta’s ownership of Research Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Research Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.

b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Research Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Research Materials.

6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Research Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Research Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.

7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.

8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that the modifications are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Research Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.

FAIR Acceptable Use Policy

The Fundamental AI Research (FAIR) team at Meta seeks to further understanding of new and existing research domains with the mission of advancing the state-of-the-art in artificial intelligence through open research for the benefit of all.

As part of this mission, Meta makes certain research materials available for noncommercial research use. Meta is committed to promoting the safe and responsible use of such research materials.

Prohibited Uses

You agree you will not use, or allow others to use, Research Materials to:

1. Violate the law or others’ rights, including to:

Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
Violence or terrorism
Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
Human trafficking, exploitation, and sexual violence
The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials
Sexual solicitation
Any other criminal activity

Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals

Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services

Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices

Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws

Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any technology using FAIR research materials

Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system

2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of research artifacts related to the following:

Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic in Arms Regulations (ITAR) maintained by the United States Department of State

Guns and illegal weapons (including weapon development)

Illegal drugs and regulated/controlled substances

Operation of critical infrastructure, transportation technologies, or heavy machinery

Self-harm or harm to others, including suicide, cutting, and eating disorders

Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual

3. Intentionally deceive or mislead others, including use of FAIR Research Materials related to the following:

Generating, promoting, or furthering fraud or the creation or promotion of disinformation

Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content

Generating, promoting, or further distributing spam

Impersonating another individual without consent, authorization, or legal right

Representing that outputs of FAIR research materials or outputs from technology using FAIR research materials are human-generated

Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement

4. Fail to appropriately disclose to end users any known dangers of your Research Materials.

Please report any violation of this Policy or other problems that could lead to a violation of this Policy by submitting a report here [https://docs.google.com/forms/d/e/1FAIpQLSeb11cryAopJ7LNrC4nxEUXrHY26hfkXQMf_uH-oFgA3WlYZQ/viewform].
README.md ADDED
@@ -0,0 +1,390 @@
---
license: fair-noncommercial-research-license
extra_gated_fields:
  First Name: text
  Last Name: text
  Date of birth: date_picker
  Country: country
  Affiliation: text
  Job title:
    type: select
    options:
      - Student
      - Research Graduate
      - AI researcher
      - AI developer/engineer
      - Reporter
      - Other
  geo: ip_location
  By clicking Submit below I accept the terms of the license and acknowledge that the information I provide will be collected stored processed and shared in accordance with the Meta Privacy Policy: checkbox
extra_gated_description: >-
  The information you provide will be collected, stored, processed and shared in
  accordance with the [Meta Privacy
  Policy](https://www.facebook.com/privacy/policy/).
extra_gated_button_content: Submit
language:
  - en
library_name: transformers
tags:
  - facebook
  - meta
  - pytorch
---

# MobileLLM-Pro Model Card
We are introducing MobileLLM-Pro (P1), a 1B-parameter foundational language model in the MobileLLM series, designed to deliver high-quality, efficient on-device inference across a wide range of general language-modeling tasks. <br>
We open-source two variants of the model: a **pre-trained base model** along with **quantized checkpoints** for CPU and accelerator inference, and an **instruction-tuned version** that shows competitive performance against models in this size range on tasks like tool calling, question answering, rewriting and summarization.

<p align="center">🤗 &nbsp;&nbsp;<a href="https://huggingface.co/spaces/akhaliq/MobileLLM-Pro">Chat with MobileLLM-Pro</a></p>

## Key Features
- **Strong Pre-training Performance:** MobileLLM-Pro base achieves impressive pre-training results, outperforming Gemma 3 1B and Llama 3.2 1B by 5.7% and 7.9% on average, respectively, on reasoning, knowledge, and long-context retrieval benchmarks. This performance is achieved by pre-training on fewer than 2T fully open-source tokens.
- **128k Context Window:** The model supports up to 128k tokens, enabling long-context understanding for applications such as document summarization and information retrieval, implicitly learned from a large teacher model.
- **Efficient Long-Context Inference:** Interleaving local and global attention layers at a 3:1 ratio with a 512-token local attention window, MobileLLM-Pro reduces prefill latency by 1.8x* and lowers KV cache size from 117MB to 40MB* compared to fully global attention, enabling faster and more memory-efficient inference. (*Assuming 8k context length)
- **Near-Lossless int4 Quantization:** We provide int4 quantization-ready checkpoints for our pre-trained model with less than 1.3% quality degradation compared to floating-point baselines:
  - CPU: int4 weights (group size 32), int8 dynamic activations, int8 KV cache, with only 0.4% regression.
  - Accelerators: int4 per-channel weights, with only 1.3% quality regression.
- **Instruction Fine-Tuned Model:** We provide a competitive instruction fine-tuned (IFT) model specializing in use cases such as tool calling, question answering, rewriting and summarization.

MobileLLM-Pro sets a new standard for efficient, high-quality on-device language modeling. We invite the community to explore, evaluate, and build upon this model.
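The long-context memory figures above can be sanity-checked with back-of-envelope arithmetic, using the architecture numbers from the Model Information section (30 layers, 4 KV heads, head dimension 1280/20 = 64). The int8 KV cache and the exact split of 8 global and 22 local layers under the 3:1 interleaving are assumptions for illustration, not figures from the model card:

```python
# Back-of-envelope int8 KV-cache sizing at 8k context.
N_LAYERS = 30
KV_HEADS = 4
HEAD_DIM = 1280 // 20   # model dim / attention heads = 64
SEQ_LEN = 8 * 1024
LOCAL_WINDOW = 512      # local-attention window from the feature list
GLOBAL_LAYERS = 8       # assumed number of global layers (3:1 interleaving)
BYTES_PER_VALUE = 1     # int8 KV cache (assumption)

def kv_cache_bytes(cached_tokens_over_layers: int) -> int:
    # Factor of 2 accounts for both keys and values.
    return 2 * KV_HEADS * HEAD_DIM * BYTES_PER_VALUE * cached_tokens_over_layers

# Fully global attention caches SEQ_LEN tokens in every layer.
full_global = kv_cache_bytes(N_LAYERS * SEQ_LEN)
# Interleaved LGA caches SEQ_LEN tokens only in global layers,
# and at most LOCAL_WINDOW tokens in each local layer.
interleaved = kv_cache_bytes(
    GLOBAL_LAYERS * SEQ_LEN + (N_LAYERS - GLOBAL_LAYERS) * LOCAL_WINDOW
)
print(f"fully global: {full_global / 2**20:.1f} MiB")  # 120.0 MiB
print(f"interleaved:  {interleaved / 2**20:.1f} MiB")  # 37.5 MiB
```

Under these assumptions the estimate comes out at roughly 120 MiB versus 38 MiB, in the same ballpark as the 117MB and 40MB quoted above.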

## Model Information
**Layers:** 30<br>
**Attention Heads:** 20<br>
**KV Heads:** 4<br>
**Dimension:** 1280<br>
**Hidden Dimension:** 6144<br>
**Vocabulary Size:** 202,048<br>
**Total Parameters:** 1,084M (1.08B)

**Input Modality:** Text<br>
**Output Modality:** Text<br>
**Languages:** English

**Training Method:** Knowledge Distillation<br>
**Context Length:** 128k tokens<br>
**Teacher Model:** [Llama 4-Scout](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E)<br>
**Loss Function:** KL Divergence<br>
**Quantization:** 16-bit, 4-bit<br>
**Other Features:** Shared Embeddings, Local-Global Attention

**Model Developer:** Meta Reality Labs<br>
**Model Release Date:** October 2025<br>
**License:** MobileLLM-Pro is FAIR NC licensed
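The 1,084M total can be reconstructed from the architecture numbers above. A minimal sketch, assuming a SwiGLU-style FFN (gate/up/down projections), no biases, a single shared input/output embedding matrix, and ignoring normalization parameters as negligible (none of these details are stated in the card):

```python
# Sanity-check of the total parameter count from the figures above.
DIM, HIDDEN, N_LAYERS = 1280, 6144, 30
HEADS, KV_HEADS = 20, 4
VOCAB = 202_048
HEAD_DIM = DIM // HEADS  # 64

attn = DIM * HEADS * HEAD_DIM          # query projection
attn += 2 * DIM * KV_HEADS * HEAD_DIM  # key and value projections
attn += HEADS * HEAD_DIM * DIM         # output projection
ffn = 3 * DIM * HIDDEN                 # gate, up and down projections (assumed SwiGLU)
embedding = VOCAB * DIM                # counted once due to shared embeddings

total = N_LAYERS * (attn + ffn) + embedding
print(f"{total / 1e6:.0f}M parameters")  # 1084M
```

Under these assumptions the count lands on 1,084M exactly, matching the card.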

## Results
### Base Pretrained Model
| Benchmark | **P1 (FP)** | **P1 (Q-CPU)** | **P1 (Q-Acc)** | **Gemma 3 1B** | **Llama 3.2 1B** |
|-----------------|---------------|---------------------|----------------|----------------|------------------|
| HellaSwag | **67.11%** | 64.89% | 65.10% | 62.30% | 65.69% |
| BoolQ | **76.24%** | **77.49%** | **76.36%** | 63.20% | 62.51% |
| PIQA | **76.55%** | **76.66%** | **75.52%** | 73.80% | 75.14% |
| SocialIQA | **50.87%** | **51.18%** | **50.05%** | 48.90% | 45.60% |
| TriviaQA | **39.85%** | 37.26% | 36.42% | 39.80% | 23.81% |
| NatQ | **15.76%** | **15.43%** | **13.19%** | 9.48% | 5.48% |
| ARC-c | **52.62%** | **52.45%** | **51.24%** | 38.40% | 38.28% |
| ARC-e | **76.28%** | **76.58%** | **75.73%** | 73.00% | 63.47% |
| WinoGrande | **62.83%** | **62.43%** | **61.96%** | 58.20% | 61.09% |
| OBQA | **43.60%** | **44.20%** | **40.40%** | | 37.20% |
| NIH | **100.00%** | 96.44% | **98.67%** | | |

FP = Full precision, bf16<br>
Q-CPU = int4, group-wise quantized (for CPU)<br>
Q-Acc = int4, channel-wise quantized (for accelerators (ANE & HTP))

### Instruction Tuned Model
| Benchmark | **P1 (IFT)** | **Gemma 3 1B (IFT)** | **Llama 3.2 1B (IFT)** |
|---------------|--------------|----------------------|------------------------|
| MMLU | 44.8% | 29.9% | **49.3%** |
| IFEval | 62.0% | **80.2%** | 59.5% |
| MBPP | **46.8%** | 35.2% | 39.6% |
| HumanEval | **59.8%** | 41.5% | 37.8% |
| ARC-C | **62.7%** | | 59.4% |
| HellaSwag | **58.4%** | | 41.2% |
| BFCL v2 | **29.4%** | | 25.7% |
| Open Rewrite | **51.0%** | | 41.6% |
| TLDR9+ | **16.8%** | | **16.8%** |

## Training Data

We constructed our data mix by selecting publicly available datasets that cover a range of domains. Using data-specific simulation runs, each dataset's contribution to the training process was carefully balanced by assigning it a specific sampling weight. These weights remained consistent throughout the base model pretraining and were informed by the extended work of [Automixer](https://scholar.google.com/citations?view_op=view_citation&hl=en&user=FbR5cAMAAAAJ&sortby=pubdate&citation_for_view=FbR5cAMAAAAJ:cFHS6HbyZ2cC) and additional ablation studies. <br>
The pre-training data mix primarily consists of a large educational web dataset, which makes up the vast majority of the training data. Smaller but significant portions come from coding data, mathematics, Wikipedia, scientific papers, Q&A forums, and algebraic content. In total, the data mix includes approximately 1.5 billion rows and 1,640 billion tokens. <br>
For our instruction fine-tuning data mix, we focus on data diversity from existing open-source fine-tuning corpora. Specifically, we combine datasets for general instruction tuning with chat, science, safety, coding and math domains. For our final DPO phase, we rely on completely synthetic datasets.

## Training Process
### Pretraining
Our general pre-training process contains three distinct phases, using logit-based knowledge distillation from the [Llama 4-Scout](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E) model and a novel model-merging paradigm:

**Phase 1 (KD)**: Language learning – learn general language skills from high-quality, well-balanced pre-training data <br>
**Phase 2 (KD)**: Long-context awareness – extend the model context length to 128k tokens using implicit positional distillation from the teacher model <br>
**Phase 3 (KD)**: Domain abilities – acquire domain understanding by annealing multiple models in parallel and merging the specialist models, resulting in improvements across a diverse range of domains

![image](https://cdn-uploads.huggingface.co/production/uploads/68c1aa07c02e455d06f93a42/DpI3Yk1fxWA789N76fvjr.png)

On top of the three pre-training phases, we add a fourth phase of Quantization-Aware Training (QAT) for our 4-bit quantized model checkpoint.

### Instruction Fine-Tuning
We split the instruction fine-tuning stage into three distinct phases combining SFT and DPO methods:

**Phase 1 (SFT)**: Learn general instruction-following with a focus on data diversity <br>
**Phase 2 (SFT)**: Domain-weight the Phase 1 data given its shortcomings (e.g. upsample code data to improve logical reasoning) <br>
**Phase 3 (SFT + DPO)**: Train and align the model for safety and self-identification

![image](https://cdn-uploads.huggingface.co/production/uploads/68c1aa07c02e455d06f93a42/wBAO_0Bu3dnCn8R2K9HXD.png)

## Quantization

![image/png](https://cdn-uploads.huggingface.co/production/uploads/68c1aa07c02e455d06f93a42/NJ_d8jyeVwkLIp9kwZRtR.png)

We apply Quantization-Aware Training (QAT) to our baseline and instruction fine-tuned models, yielding quantization-ready checkpoints that can either be directly converted to integer datatypes (with minimal quality loss) or used for QAT on additional data. We release two quantization-ready checkpoints:

- **4-bit groupwise weight quantization** with block size 32, 8-bit dynamic activations, and 8-bit KV-cache quantization — optimized for CPU/GPU backends ([XNNPACK](https://docs.pytorch.org/executorch/0.5/native-delegates-executorch-xnnpack-delegate.html)).
- **4-bit channelwise quantization** without activation quantization and with 8-bit KV-cache quantization — designed for edge hardware accelerators such as the Apple Neural Engine ([ANE](https://apple.github.io/coremltools/docs-guides/source/opt-quantization-overview.html)) and Qualcomm’s Hexagon Tensor Processor ([HTP](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/htp_guidelines_int4_weights.html)).

Our QAT approach incorporates long-context awareness (up to 128k tokens) and self-knowledge distillation using the full-precision teacher model. We compared the QAT-trained model to a standard round-to-nearest Post-Training Quantization (PTQ) baseline. In the groupwise pre-training setting, we observe a 34% (absolute) regression in average benchmark score when using PTQ and only a 1.5% (absolute) regression for QAT. For instruction fine-tuning, we observe less than 1% average regression using QAT.

## How to use
### Full Precision

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

login(token="<HF_TOKEN>")
MODEL_ID = "facebook/MobileLLM-Pro"

def generate(user_input: str, model, tokenizer, chat: bool) -> str:
    if chat:
        user_input = [{"role": "user", "content": user_input}]
        inputs = tokenizer.apply_chat_template(
            user_input, return_tensors="pt", add_generation_prompt=True
        ).to(model.device)
    else:
        inputs = tokenizer(user_input, return_tensors="pt")["input_ids"].to(model.device)
    outputs = model.generate(inputs, max_new_tokens=128)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def main():
    version = "instruct"  # "base" | "instruct"
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID, trust_remote_code=True, subfolder=version
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, trust_remote_code=True, subfolder=version
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    prompt = "Why are open-source on-device language models great?"
    result = generate(prompt, model, tokenizer, chat=(version == "instruct"))
    print(result)

if __name__ == "__main__":
    main()
```

### Quantized Checkpoints

#### 4-bit Groupwise Quantization

```python
import torch
from transformers import AutoModelForCausalLM
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    QATConfig,
    IntxFakeQuantizeConfig,
)

model_id = "facebook/MobileLLM-Pro"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True
)

# Prepare for QAT.
# 8-bit dynamic per-token quantization for activations
activation_config = IntxFakeQuantizeConfig(
    torch.int8, "per_token", is_symmetric=False,
)
# 4-bit symmetric weight quantization with group size 32
weight_config = IntxFakeQuantizeConfig(
    torch.int4,
    group_size=32,
    is_symmetric=True,
    is_dynamic=True,
)
qat_config = QATConfig(
    activation_config=activation_config,
    weight_config=weight_config,
    step="prepare",
)
quantize_(model, qat_config)

embedding_filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding)
embedding_qat_config = IntxFakeQuantizeConfig(
    torch.int4,
    group_size=32,
    is_symmetric=True,
    is_dynamic=True,
)
quantize_(
    model,
    QATConfig(
        weight_config=embedding_qat_config,
        step="prepare"
    ),
    embedding_filter_fn
)

# The model is now ready for Quantization-Aware Training (QAT)
# trainer.train()
model.save_pretrained(
    save_directory="<QAT_save_directory>",
    safe_serialization=False
)

# Convert model after training
from torchao.quantization import (
    IntxWeightOnlyConfig,
    Int8DynamicActivationIntxWeightConfig
)
from torchao.quantization.granularity import PerGroup

qat_convert_config = QATConfig(
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(32),
    ),
    step="convert",
)
quantize_(model, qat_convert_config)
embedding_convert_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int4,
    granularity=PerGroup(32)
)
quantize_(
    model,
    QATConfig(
        embedding_convert_config,
        step="convert"
    ),
    embedding_filter_fn
)

# Save model after convert
model.save_pretrained(
    save_directory="<quantized_model_directory>",
    safe_serialization=False
)
```

#### 4-bit Channelwise Quantization

```python
import torch
from transformers import AutoModelForCausalLM
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerAxis
from torchao.quantization.qat import (
    initialize_fake_quantizers,
    IntxFakeQuantizeConfig,
    QATConfig
)

model_id = "facebook/MobileLLM-Pro"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True
)

# 4-bit per-channel with range_learning=True for weights
weight_config = IntxFakeQuantizeConfig(
    torch.int4,
    granularity=PerAxis(0),
    is_symmetric=True,
    is_dynamic=False,
    range_learning=True,
)
qat_config = QATConfig(
    weight_config=weight_config,
    step="prepare",
)
quantize_(model, qat_config)

embedding_filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding)
quantize_(model, qat_config, embedding_filter_fn)

# Initialize the fake quantizers for range learning
example_inputs = (torch.tensor([[1]], dtype=torch.long),)
initialize_fake_quantizers(model, example_inputs)

# The model is now ready for Quantization-Aware Training (QAT)
# trainer.train()
model.save_pretrained(
    save_directory="<QAT_save_directory>",
    safe_serialization=False
)

# Convert model after training
from torchao.quantization import IntxWeightOnlyConfig

wt_convert_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int4,
    granularity=PerAxis(0)
)
qat_convert_config = QATConfig(
    wt_convert_config,
    step="convert",
)
quantize_(model, qat_convert_config)
quantize_(model, qat_convert_config, embedding_filter_fn)

# Save model after convert
model.save_pretrained(
    save_directory="<quantized_model_directory>",
    safe_serialization=False
)
```

## Latency benchmarking

Latency benchmarking was done on a Samsung Galaxy S25 CPU and a Samsung Galaxy S24 Hexagon Tensor Processor (HTP). Models were exported to ExecuTorch with the XNNPACK backend (for CPU) and the HTP backend (for the accelerator). The size of the CPU model with 4-bit groupwise quantization is 590MB. CPU and HTP prefill latency for input prompt lengths of 2k, 4k and 8k, along with decode speed when generating 1k tokens, is shown in the following table.
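The 590MB figure is roughly what 4-bit groupwise packing predicts. A back-of-envelope estimate, assuming 4-bit weights, one 16-bit scale per group of 32 weights, and ignoring zero-points and non-weight tensors (none of which the card specifies):

```python
# Rough estimate of the 4-bit groupwise model size.
PARAMS = 1_084_000_000  # total parameters from the Model Information section
GROUP_SIZE = 32

weights_mb = PARAMS * 0.5 / 1e6            # 4 bits = 0.5 bytes per weight
scales_mb = PARAMS / GROUP_SIZE * 2 / 1e6  # assumed 16-bit scale per group
print(f"~{weights_mb + scales_mb:.0f} MB")  # ~610 MB
```

The estimate lands close to the reported 590MB; the remaining gap plausibly comes from some tensors being packed differently in the actual export, so treat this purely as a sanity check.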

| Model / Prompt length | 2k | 4k | 8k |
|---------------------------|--------|--------|-------|
| CPU Prefill Latency (s) | 8.9 | 24.8 | 63.5 |
| CPU Decode Speed (tok/s) | 33.6 | 24.8 | 19.7 |
| HTP Prefill Latency (s) | 1.96 | 3.38 | 9.82 |
| HTP Decode Speed (tok/s) | 31.60 | 28.95 | 22.77 |
| KV Cache Size (MB) | 14 | 23 | 40 |

To validate the benefit of interleaved local-global attention (LGA), we benchmark models across different prompt lengths and measure the speed-up in prefill and decode relative to using global attention at every layer:

![image](https://cdn-uploads.huggingface.co/production/uploads/68c1aa07c02e455d06f93a42/_p8JT_Wtljwyp23TmKsTc.png)

## Citation

```bibtex
@misc{mobilellm_pro,
  title={MobileLLM-Pro Model Card},
  author={Patrick Huber* and Ernie Chang* and Wei Wen* and Igor Fedorov* and Tarek Elgamal and Hanxian Huang and Naveen Suda and Chinnadhurai Sankar and Vish Vogeti and Yanghan Wang and Alex Gladkov and Kai Sheng Tai and Abdelrahman Elogeel and Tarek Hefny and Vikas Chandra and Ahmed Aly and Anuj Kumar and Raghuraman Krishnamoorthi** and Adithya Sagar**},
  year={2025},
  month={October},
  url={https://huggingface.co/facebook/MobileLLM-Pro}
}
```

## Contact

Patrick Huber, Meta Inc, Reality Labs ([patrickhuber@meta.com](mailto:patrickhuber@meta.com))<br>
Ernie Chang, Meta Inc, Reality Labs ([erniecyc@meta.com](mailto:erniecyc@meta.com))<br>
Wei Wen, Meta Inc, Reality Labs ([wewen@meta.com](mailto:wewen@meta.com))<br>
Igor Fedorov, Meta Inc, Reality Labs ([ifedorov@meta.com](mailto:ifedorov@meta.com))<br>
Raghuraman Krishnamoorthi, Meta Inc, Reality Labs ([raghuraman@meta.com](mailto:raghuraman@meta.com))<br>
Adithya Sagar, Meta Inc, Reality Labs ([adithyasagar@meta.com](mailto:adithyasagar@meta.com))

## Acknowledgements

We want to thank the team involved in this project, especially: Kimish Patel, Andrew Or, Min Guo, Shen Xu, Brian Moran, Maho Takahashi, Claire Lesage, Rylan Conway, Karan Chadha, Matthew Grange, Tomasz Wołcyrz, Shiv Desai, Amarlin Anand, Joele Sires, Robert Carrillo, Francisc Bungiu, Jayden Yu, AJ Brush, Yang Li, Samuel Selvan, Anand Sharma, Peng Shan, Anand Dass, Abhishek Sharma.

## License

MobileLLM-Pro is distributed under the [FAIR NC license](https://huggingface.co/facebook/MobileLLM-Pro/blob/main/LICENSE).
base/config.json ADDED
@@ -0,0 +1,118 @@
1
+ {
2
+ "architectures": [
3
+ "MobileLLMP1ForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_mobilellm_p1.MobileLLMP1TextConfig",
7
+ "AutoModelForCausalLM": "modeling_mobilellm_p1.MobileLLMP1ForCausalLM"
8
+ },
9
+ "attention_bias": false,
10
+ "attention_chunk_size": 8192,
11
+ "attention_dropout": 0.0,
12
+ "attn_scale": 0.1,
13
+ "bos_token_id": 200000,
14
+ "eos_token_id": [
15
+ 200001,
16
+ 200007,
17
+ 200008
18
+ ],
19
+ "pad_token_id": 200018,
20
+ "for_llm_compressor": false,
21
+ "head_dim": 64,
22
+ "hidden_act": "silu",
23
+ "hidden_size": 1280,
24
+ "initializer_range": 0.02,
25
+ "intermediate_size": 6144,
26
+ "intermediate_size_mlp": 6144,
27
+ "layer_types": [
28
+ "full_attention",
29
+ "sliding_attention",
30
+ "sliding_attention",
31
+ "sliding_attention",
32
+ "full_attention",
33
+ "sliding_attention",
34
+ "sliding_attention",
35
+ "sliding_attention",
36
+ "full_attention",
37
+ "sliding_attention",
38
+ "sliding_attention",
39
+ "sliding_attention",
40
+ "full_attention",
41
+ "sliding_attention",
42
+ "sliding_attention",
43
+ "sliding_attention",
44
+ "full_attention",
45
+ "sliding_attention",
46
+ "sliding_attention",
47
+ "sliding_attention",
48
+ "full_attention",
49
+ "sliding_attention",
50
+ "sliding_attention",
51
+ "sliding_attention",
52
+ "full_attention",
53
+ "sliding_attention",
54
+ "sliding_attention",
55
+ "sliding_attention",
56
+ "full_attention",
57
+ "full_attention"
58
+ ],
59
+ "sliding_window": 512,
60
+ "max_position_embeddings": 131072,
61
+ "model_type": "llama4_text",
62
+ "moe_layers": [],
63
+ "no_rope_layers": [
64
+ 1,
65
+ 1,
66
+ 1,
67
+ 1,
68
+ 1,
69
+ 1,
70
+ 1,
71
+ 1,
72
+ 1,
73
+ 1,
74
+ 1,
75
+ 1,
76
+ 1,
77
+ 1,
78
+ 1,
79
+ 1,
80
+ 1,
81
+ 1,
82
+ 1,
83
+ 1,
84
+ 1,
85
+ 1,
86
+ 1,
87
+ 1,
88
+ 1,
89
+ 1,
90
+ 1,
91
+ 1,
92
+ 1,
93
+ 1
94
+ ],
95
+ "num_attention_heads": 20,
96
+ "num_experts_per_tok": 0,
97
+ "num_hidden_layers": 30,
98
+ "num_key_value_heads": 4,
99
+ "num_local_experts": 0,
100
+ "output_router_logits": false,
101
+ "rms_norm_eps": 1e-05,
102
+ "rope_scaling": {
103
+ "factor": 16.0,
104
+ "high_freq_factor": 1.0,
105
+ "low_freq_factor": 1.0,
106
+ "original_max_position_embeddings": 8192,
107
+ "rope_type": "llama3"
108
+ },
109
+ "rope_theta": 500000.0,
110
+ "router_aux_loss_coef": 0.001,
111
+ "router_jitter_noise": 0.0,
112
+ "tie_word_embeddings": true,
113
+ "torch_dtype": "bfloat16",
114
+ "transformers_version": "4.53.0.dev0",
115
+ "use_cache": true,
116
+ "use_qk_norm": false,
117
+ "vocab_size": 202048
118
+ }
base/generation_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+     "_from_model_config": true,
+     "bos_token_id": 200000,
+     "eos_token_id": [
+         200001,
+         200007,
+         200008
+     ],
+     "pad_token_id": 200018,
+     "temperature": 0.6,
+     "top_p": 0.9,
+     "do_sample": true,
+     "transformers_version": "4.55.0"
+ }
base/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0fbc0be0c62f63ef9770116b4e7081db712ed147c157be324eb5990daa67293b
+ size 2168938424
base/special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
+ {
+     "bos_token": "<|begin_of_text|>",
+     "eos_token": "<|eot|>",
+     "pad_token": "<|finetune_right_pad|>"
+ }
base/tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:172c9eb4beafc72601690da3ccfcede5c2e6806a8d5ec1fca33e22acea8023a4
+ size 27948578
base/tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d0bdbaf59b0762c8c807617e2d8ea51420eb1b1de266df2495be755c8e0ed6ed
+ size 3622230
base/tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
base_int4_accelerator/config.json ADDED
@@ -0,0 +1,118 @@
1
+ {
2
+ "architectures": [
3
+ "MobileLLMP1ForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_mobilellm_p1.MobileLLMP1TextConfig",
7
+ "AutoModelForCausalLM": "modeling_mobilellm_p1.MobileLLMP1ForCausalLM"
8
+ },
9
+ "attention_bias": false,
10
+ "attention_chunk_size": 8192,
11
+ "attention_dropout": 0.0,
12
+ "attn_scale": 0.1,
13
+ "bos_token_id": 200000,
14
+ "eos_token_id": [
15
+ 200001,
16
+ 200007,
17
+ 200008
18
+ ],
19
+ "pad_token_id": 200018,
20
+ "for_llm_compressor": false,
21
+ "head_dim": 64,
22
+ "hidden_act": "silu",
23
+ "hidden_size": 1280,
24
+ "initializer_range": 0.02,
25
+ "intermediate_size": 6144,
26
+ "intermediate_size_mlp": 6144,
27
+ "layer_types": [
28
+ "full_attention",
29
+ "sliding_attention",
30
+ "sliding_attention",
31
+ "sliding_attention",
32
+ "full_attention",
33
+ "sliding_attention",
34
+ "sliding_attention",
35
+ "sliding_attention",
36
+ "full_attention",
37
+ "sliding_attention",
38
+ "sliding_attention",
39
+ "sliding_attention",
40
+ "full_attention",
41
+ "sliding_attention",
42
+ "sliding_attention",
43
+ "sliding_attention",
44
+ "full_attention",
45
+ "sliding_attention",
46
+ "sliding_attention",
47
+ "sliding_attention",
48
+ "full_attention",
49
+ "sliding_attention",
50
+ "sliding_attention",
51
+ "sliding_attention",
52
+ "full_attention",
53
+ "sliding_attention",
54
+ "sliding_attention",
55
+ "sliding_attention",
56
+ "full_attention",
57
+ "full_attention"
58
+ ],
59
+ "sliding_window": 512,
60
+ "max_position_embeddings": 131072,
61
+ "model_type": "llama4_text",
62
+ "moe_layers": [],
63
+ "no_rope_layers": [
64
+ 1,
65
+ 1,
66
+ 1,
67
+ 1,
68
+ 1,
69
+ 1,
70
+ 1,
71
+ 1,
72
+ 1,
73
+ 1,
74
+ 1,
75
+ 1,
76
+ 1,
77
+ 1,
78
+ 1,
79
+ 1,
80
+ 1,
81
+ 1,
82
+ 1,
83
+ 1,
84
+ 1,
85
+ 1,
86
+ 1,
87
+ 1,
88
+ 1,
89
+ 1,
90
+ 1,
91
+ 1,
92
+ 1,
93
+ 1
94
+ ],
95
+ "num_attention_heads": 20,
96
+ "num_experts_per_tok": 0,
97
+ "num_hidden_layers": 30,
98
+ "num_key_value_heads": 4,
99
+ "num_local_experts": 0,
100
+ "output_router_logits": false,
101
+ "rms_norm_eps": 1e-05,
102
+ "rope_scaling": {
103
+ "factor": 16.0,
104
+ "high_freq_factor": 1.0,
105
+ "low_freq_factor": 1.0,
106
+ "original_max_position_embeddings": 8192,
107
+ "rope_type": "llama3"
108
+ },
109
+ "rope_theta": 500000.0,
110
+ "router_aux_loss_coef": 0.001,
111
+ "router_jitter_noise": 0.0,
112
+ "tie_word_embeddings": true,
113
+ "torch_dtype": "bfloat16",
114
+ "transformers_version": "4.53.0.dev0",
115
+ "use_cache": true,
116
+ "use_qk_norm": false,
117
+ "vocab_size": 202048
118
+ }
base_int4_accelerator/generation_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+     "_from_model_config": true,
+     "bos_token_id": 200000,
+     "eos_token_id": [
+         200001,
+         200007,
+         200008
+     ],
+     "pad_token_id": 200018,
+     "temperature": 0.6,
+     "top_p": 0.9,
+     "do_sample": true,
+     "transformers_version": "4.55.0"
+ }
base_int4_accelerator/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:12fef2380a85b302729d08d73a34b399e32afa04c1d503a85252915588afae1d
+ size 2170471151
base_int4_accelerator/special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
+ {
+     "bos_token": "<|begin_of_text|>",
+     "eos_token": "<|eot|>",
+     "pad_token": "<|finetune_right_pad|>"
+ }
base_int4_accelerator/tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:172c9eb4beafc72601690da3ccfcede5c2e6806a8d5ec1fca33e22acea8023a4
+ size 27948578
base_int4_accelerator/tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d0bdbaf59b0762c8c807617e2d8ea51420eb1b1de266df2495be755c8e0ed6ed
+ size 3622230
base_int4_accelerator/tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
base_int4_cpu/config.json ADDED
@@ -0,0 +1,118 @@
1
+ {
2
+ "architectures": [
3
+ "MobileLLMP1ForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_mobilellm_p1.MobileLLMP1TextConfig",
7
+ "AutoModelForCausalLM": "modeling_mobilellm_p1.MobileLLMP1ForCausalLM"
8
+ },
9
+ "attention_bias": false,
10
+ "attention_chunk_size": 8192,
11
+ "attention_dropout": 0.0,
12
+ "attn_scale": 0.1,
13
+ "bos_token_id": 200000,
14
+ "eos_token_id": [
15
+ 200001,
16
+ 200007,
17
+ 200008
18
+ ],
19
+ "pad_token_id": 200018,
20
+ "for_llm_compressor": false,
21
+ "head_dim": 64,
22
+ "hidden_act": "silu",
23
+ "hidden_size": 1280,
24
+ "initializer_range": 0.02,
25
+ "intermediate_size": 6144,
26
+ "intermediate_size_mlp": 6144,
27
+ "layer_types": [
28
+ "full_attention",
29
+ "sliding_attention",
30
+ "sliding_attention",
31
+ "sliding_attention",
32
+ "full_attention",
33
+ "sliding_attention",
34
+ "sliding_attention",
35
+ "sliding_attention",
36
+ "full_attention",
37
+ "sliding_attention",
38
+ "sliding_attention",
39
+ "sliding_attention",
40
+ "full_attention",
41
+ "sliding_attention",
42
+ "sliding_attention",
43
+ "sliding_attention",
44
+ "full_attention",
45
+ "sliding_attention",
46
+ "sliding_attention",
47
+ "sliding_attention",
48
+ "full_attention",
49
+ "sliding_attention",
50
+ "sliding_attention",
51
+ "sliding_attention",
52
+ "full_attention",
53
+ "sliding_attention",
54
+ "sliding_attention",
55
+ "sliding_attention",
56
+ "full_attention",
57
+ "full_attention"
58
+ ],
59
+ "sliding_window": 512,
60
+ "max_position_embeddings": 131072,
61
+ "model_type": "llama4_text",
62
+ "moe_layers": [],
63
+ "no_rope_layers": [
64
+ 1,
65
+ 1,
66
+ 1,
67
+ 1,
68
+ 1,
69
+ 1,
70
+ 1,
71
+ 1,
72
+ 1,
73
+ 1,
74
+ 1,
75
+ 1,
76
+ 1,
77
+ 1,
78
+ 1,
79
+ 1,
80
+ 1,
81
+ 1,
82
+ 1,
83
+ 1,
84
+ 1,
85
+ 1,
86
+ 1,
87
+ 1,
88
+ 1,
89
+ 1,
90
+ 1,
91
+ 1,
92
+ 1,
93
+ 1
94
+ ],
95
+ "num_attention_heads": 20,
96
+ "num_experts_per_tok": 0,
97
+ "num_hidden_layers": 30,
98
+ "num_key_value_heads": 4,
99
+ "num_local_experts": 0,
100
+ "output_router_logits": false,
101
+ "rms_norm_eps": 1e-05,
102
+ "rope_scaling": {
103
+ "factor": 16.0,
104
+ "high_freq_factor": 1.0,
105
+ "low_freq_factor": 1.0,
106
+ "original_max_position_embeddings": 8192,
107
+ "rope_type": "llama3"
108
+ },
109
+ "rope_theta": 500000.0,
110
+ "router_aux_loss_coef": 0.001,
111
+ "router_jitter_noise": 0.0,
112
+ "tie_word_embeddings": true,
113
+ "torch_dtype": "bfloat16",
114
+ "transformers_version": "4.53.0.dev0",
115
+ "use_cache": true,
116
+ "use_qk_norm": false,
117
+ "vocab_size": 202048
118
+ }
base_int4_cpu/generation_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+     "_from_model_config": true,
+     "bos_token_id": 200000,
+     "eos_token_id": [
+         200001,
+         200007,
+         200008
+     ],
+     "pad_token_id": 200018,
+     "temperature": 0.6,
+     "top_p": 0.9,
+     "do_sample": true,
+     "transformers_version": "4.55.0"
+ }
base_int4_cpu/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5be59329a78ba1bc9a1be7a439355e259b635834187136ecb0218fa1ee9d3afb
+ size 2168995683
base_int4_cpu/special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
+ {
+     "bos_token": "<|begin_of_text|>",
+     "eos_token": "<|eot|>",
+     "pad_token": "<|finetune_right_pad|>"
+ }
base_int4_cpu/tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:172c9eb4beafc72601690da3ccfcede5c2e6806a8d5ec1fca33e22acea8023a4
+ size 27948578
base_int4_cpu/tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d0bdbaf59b0762c8c807617e2d8ea51420eb1b1de266df2495be755c8e0ed6ed
+ size 3622230
base_int4_cpu/tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
config.json ADDED
@@ -0,0 +1,118 @@
1
+ {
2
+ "architectures": [
3
+ "MobileLLMP1ForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_mobilellm_p1.MobileLLMP1TextConfig",
7
+ "AutoModelForCausalLM": "modeling_mobilellm_p1.MobileLLMP1ForCausalLM"
8
+ },
9
+ "attention_bias": false,
10
+ "attention_chunk_size": 8192,
11
+ "attention_dropout": 0.0,
12
+ "attn_scale": 0.1,
13
+ "bos_token_id": 200000,
14
+ "eos_token_id": [
15
+ 200001,
16
+ 200007,
17
+ 200008
18
+ ],
19
+ "pad_token_id": 200018,
20
+ "for_llm_compressor": false,
21
+ "head_dim": 64,
22
+ "hidden_act": "silu",
23
+ "hidden_size": 1280,
24
+ "initializer_range": 0.02,
25
+ "intermediate_size": 6144,
26
+ "intermediate_size_mlp": 6144,
27
+ "layer_types": [
28
+ "full_attention",
29
+ "sliding_attention",
30
+ "sliding_attention",
31
+ "sliding_attention",
32
+ "full_attention",
33
+ "sliding_attention",
34
+ "sliding_attention",
35
+ "sliding_attention",
36
+ "full_attention",
37
+ "sliding_attention",
38
+ "sliding_attention",
39
+ "sliding_attention",
40
+ "full_attention",
41
+ "sliding_attention",
42
+ "sliding_attention",
43
+ "sliding_attention",
44
+ "full_attention",
45
+ "sliding_attention",
46
+ "sliding_attention",
47
+ "sliding_attention",
48
+ "full_attention",
49
+ "sliding_attention",
50
+ "sliding_attention",
51
+ "sliding_attention",
52
+ "full_attention",
53
+ "sliding_attention",
54
+ "sliding_attention",
55
+ "sliding_attention",
56
+ "full_attention",
57
+ "full_attention"
58
+ ],
59
+ "sliding_window": 512,
60
+ "max_position_embeddings": 131072,
61
+ "model_type": "llama4_text",
62
+ "moe_layers": [],
63
+ "no_rope_layers": [
64
+ 1,
65
+ 1,
66
+ 1,
67
+ 1,
68
+ 1,
69
+ 1,
70
+ 1,
71
+ 1,
72
+ 1,
73
+ 1,
74
+ 1,
75
+ 1,
76
+ 1,
77
+ 1,
78
+ 1,
79
+ 1,
80
+ 1,
81
+ 1,
82
+ 1,
83
+ 1,
84
+ 1,
85
+ 1,
86
+ 1,
87
+ 1,
88
+ 1,
89
+ 1,
90
+ 1,
91
+ 1,
92
+ 1,
93
+ 1
94
+ ],
95
+ "num_attention_heads": 20,
96
+ "num_experts_per_tok": 0,
97
+ "num_hidden_layers": 30,
98
+ "num_key_value_heads": 4,
99
+ "num_local_experts": 0,
100
+ "output_router_logits": false,
101
+ "rms_norm_eps": 1e-05,
102
+ "rope_scaling": {
103
+ "factor": 16.0,
104
+ "high_freq_factor": 1.0,
105
+ "low_freq_factor": 1.0,
106
+ "original_max_position_embeddings": 8192,
107
+ "rope_type": "llama3"
108
+ },
109
+ "rope_theta": 500000.0,
110
+ "router_aux_loss_coef": 0.001,
111
+ "router_jitter_noise": 0.0,
112
+ "tie_word_embeddings": true,
113
+ "torch_dtype": "bfloat16",
114
+ "transformers_version": "4.53.0.dev0",
115
+ "use_cache": true,
116
+ "use_qk_norm": false,
117
+ "vocab_size": 202048
118
+ }
configuration_mobilellm_p1.py ADDED
@@ -0,0 +1,270 @@
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class MobileLLMP1TextConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`MobileLLMP1TextModel`]. It is used to instantiate a
+     MobileLLMP1 text model according to the specified arguments, defining the model architecture. Instantiating a
+     configuration with the defaults will yield a similar configuration to that of the MobileLLMP1 1B model.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 202048):
+             Vocabulary size of the MobileLLMP1 text model. Defines the maximum number of different tokens that can be
+             represented by the `inputs_ids` passed when calling [`MobileLLMP1TextModel`].
+         hidden_size (`int`, *optional*, defaults to 1280):
+             Dimensionality of the embeddings and hidden states.
+         intermediate_size (`int`, *optional*, defaults to 6144):
+             Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+         intermediate_size_mlp (`int`, *optional*, defaults to 6144): TODO
+         num_hidden_layers (`int`, *optional*, defaults to 30):
+             Number of hidden layers in the Transformer encoder.
+         num_attention_heads (`int`, *optional*, defaults to 20):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         num_key_value_heads (`int`, *optional*, defaults to 4):
+             This is the number of key_value heads that should be used to implement Grouped Query Attention. If not
+             specified, will default to `num_attention_heads`.
+         head_dim (`int`, *optional*, defaults to 64): TODO
+         hidden_act (`str` or `Callable`, *optional*, defaults to `"silu"`):
+             The non-linear activation function (function or string) in the encoder and pooler.
+         max_position_embeddings (`int`, *optional*, defaults to 131072):
+             The maximum sequence length that this model might ever be used with.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         sliding_window (`int`, *optional*, defaults to 512):
+             In MobileLLMP1, every 3 out of 4 layers use sliding window attention. This is the size of the sliding window.
+         rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+             The epsilon used by the rms normalization layers.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions.
+         pad_token_id (`int`, *optional*):
+             The id of the padding token.
+         bos_token_id (`int`, *optional*, defaults to 1):
+             The id of the beginning of sentence token.
+         eos_token_id (`int`, *optional*, defaults to 2):
+             The id of the end of sentence token.
+         tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+             Whether to tie weight embeddings.
+         rope_theta (`float`, *optional*, defaults to 500000.0):
+             The base period of the RoPE embeddings.
+         attention_dropout (`float`, *optional*, defaults to 0.0): TODO
+         num_experts_per_tok (`int`, *optional*, defaults to 1): TODO
+         num_local_experts (`int`, *optional*, defaults to 16): TODO
+         moe_layers (`list[int]`, *optional*): TODO
+         interleave_moe_layer_step (`int`, *optional*, defaults to 1): TODO
+         use_qk_norm (`bool`, *optional*, defaults to `False`): TODO
+         output_router_logits (`bool`, *optional*, defaults to `False`): TODO
+         router_aux_loss_coef (`float`, *optional*, defaults to 0.001): TODO
+         router_jitter_noise (`float`, *optional*, defaults to 0.0): TODO
+         rope_scaling (`Dict`, *optional*):
+             Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+             and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+             accordingly.
+             Expected contents:
+                 `rope_type` (`str`):
+                     The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                     'llama3'], with 'default' being the original RoPE implementation.
+                 `factor` (`float`, *optional*):
+                     Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                     most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                     original maximum pre-trained length.
+                 `original_max_position_embeddings` (`int`, *optional*):
+                     Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                     pretraining.
+                 `attention_factor` (`float`, *optional*):
+                     Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                     computation. If unspecified, it defaults to the value recommended by the implementation, using the
+                     `factor` field to infer the suggested value.
+                 `beta_fast` (`float`, *optional*):
+                     Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                     ramp function. If unspecified, it defaults to 32.
+                 `beta_slow` (`float`, *optional*):
+                     Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                     ramp function. If unspecified, it defaults to 1.
+                 `short_factor` (`list[float]`, *optional*):
+                     Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                     `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                     size divided by the number of attention heads divided by 2.
+                 `long_factor` (`list[float]`, *optional*):
+                     Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                     `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                     size divided by the number of attention heads divided by 2.
+                 `low_freq_factor` (`float`, *optional*):
+                     Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
+                 `high_freq_factor` (`float`, *optional*):
+                     Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
+         <TODO>
+         <TODO>
+         no_rope_layers (`list[int]`, *optional*):
+             List with at least the same length as the number of layers in the model.
+             A `1` at an index position indicates that the corresponding layer will use RoPE,
+             while a `0` indicates that it's a NoPE layer.
+         no_rope_layer_interval (`int`, *optional*, defaults to 4):
+             If `no_rope_layers` is `None`, it will be created using a NoPE layer every
+             `no_rope_layer_interval` layers.
+         attention_chunk_size (`int`, *optional*, defaults to 8192):
+             <TODO>
+         layer_types (`list`, *optional*):
+             Attention pattern for each layer.
+         attn_temperature_tuning (`bool`, *optional*, defaults to `True`):
+             Whether to dynamically scale the attention temperature for each query token based on sequence length.
+             Recommended for long sequences (e.g., >32k tokens) to maintain stable output results.
+         floor_scale (`int`, *optional*, defaults to 8192): TODO
+         attn_scale (`float`, *optional*, defaults to 0.1): TODO
+
+     Example:
+     """
+
+     model_type = "llama4_text"
+     keys_to_ignore_at_inference = ["past_key_values"]
+     base_model_tp_plan = {
+         "layers.*.self_attn.q_proj": "colwise",
+         "layers.*.self_attn.k_proj": "colwise",
+         "layers.*.self_attn.v_proj": "colwise",
+         "layers.*.self_attn.o_proj": "rowwise",
+         "layers.*.feed_forward.shared_expert.gate_proj": "local_colwise",
+         "layers.*.feed_forward.shared_expert.up_proj": "local_colwise",
+         "layers.*.feed_forward.shared_expert.down_proj": "local_rowwise",
+         "layers.*.feed_forward.experts.gate_up_proj": "local_packed_rowwise",  # row because not linear
+         "layers.*.feed_forward.experts.down_proj": "local_colwise",  # col because not linear
+         "layers.*.feed_forward.experts": "local",
+         "layers.*.feed_forward.gate_proj": "local_colwise",
+         "layers.*.feed_forward.up_proj": "local_colwise",
+         "layers.*.feed_forward.down_proj": "local_rowwise",
+         "layers.*.feed_forward": "gather",
+     }
+     base_model_ep_plan = {
+         "layers.*.self_attn.q_proj": "colwise",
+         "layers.*.self_attn.k_proj": "colwise",
+         "layers.*.self_attn.v_proj": "colwise",
+         "layers.*.self_attn.o_proj": "rowwise",
+         "layers.*.feed_forward.experts.gate_up_proj": "grouped_gemm",  # row because not linear
+         "layers.*.feed_forward.experts.down_proj": "grouped_gemm",  # col because not linear
+         "layers.*.feed_forward.experts": "gather",  # all reduce
+         "layers.*.feed_forward.gate_proj": "local_colwise",
+         "layers.*.feed_forward.up_proj": "local_colwise",
+         "layers.*.feed_forward.down_proj": "local_rowwise",
+         "layers.*.feed_forward.router": "ep_router",
+     }
+
+     def __init__(
+         self,
+         vocab_size=202048,
+         hidden_size=1280,
+         intermediate_size=6144,
+         intermediate_size_mlp=6144,
+         num_hidden_layers=30,
+         num_attention_heads=20,
+         num_key_value_heads=4,
+         head_dim=64,
+         hidden_act="silu",
+         max_position_embeddings=131072,
+         initializer_range=0.02,
+         rms_norm_eps=1e-5,
+         use_cache=True,
+         pad_token_id=None,
+         sliding_window=512,
+         bos_token_id=1,
+         eos_token_id=2,
+         tie_word_embeddings=True,
+         rope_theta=500000,
+         attention_dropout=0.0,
+         num_experts_per_tok=1,
+         num_local_experts=16,
+         moe_layers=None,
+         interleave_moe_layer_step=1,
+         use_qk_norm=False,
+         output_router_logits=False,
+         router_aux_loss_coef=0.001,
+         router_jitter_noise=0.0,
+         rope_scaling=None,
+         no_rope_layers=None,
+         no_rope_layer_interval=4,
+         attention_chunk_size=8192,
+         layer_types=None,
+         attn_temperature_tuning=True,
+         floor_scale=8192,
+         attn_scale=0.1,
+         **kwargs,
+     ):
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+         self.attn_temperature_tuning = attn_temperature_tuning
+         self.attn_scale = attn_scale
+         self.floor_scale = floor_scale
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.intermediate_size_mlp = intermediate_size_mlp
+         self.num_hidden_layers = num_hidden_layers
+         self.sliding_window = sliding_window
+         self.num_attention_heads = num_attention_heads
+         self.rope_scaling = rope_scaling
+         self.attention_bias = False
+         # for backward compatibility
+         if num_key_value_heads is None:
+             num_key_value_heads = num_attention_heads
+
+         self.num_key_value_heads = num_key_value_heads
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.use_cache = use_cache
+         self.rope_theta = rope_theta
+         self.attention_dropout = attention_dropout
+         self.head_dim = (
+             head_dim
+             if head_dim is not None
+             else self.hidden_size // self.num_attention_heads
+         )
+         self.use_qk_norm = use_qk_norm
+
+         self.num_experts_per_tok = num_experts_per_tok
+         self.num_local_experts = num_local_experts
+
+         self.output_router_logits = output_router_logits
+         self.router_aux_loss_coef = router_aux_loss_coef
+         self.router_jitter_noise = router_jitter_noise
+         self.layer_types = layer_types
+
+         # Backwards compatibility
+         if no_rope_layers == []:
+             no_rope_layers = None
+
+         default_no_rope_layers = [
+             int((layer_idx + 1) % no_rope_layer_interval != 0)
+             for layer_idx in range(self.num_hidden_layers)
+         ]
+
+         self.no_rope_layers = (
+             no_rope_layers if no_rope_layers else default_no_rope_layers
+         )
+
+         # If no pattern is set, use our default pattern: full attention every
+         # 4th layer, and the last layer is always full attention.
+         if self.layer_types is None:
+             self.layer_types = [
+                 "sliding_attention" if bool(i % 4) else "full_attention"
+                 for i in range(self.num_hidden_layers - 1)
+             ] + [
+                 "full_attention"
+             ]
+
+         self.interleave_moe_layer_step = interleave_moe_layer_step
+         self.moe_layers = (
+             moe_layers
+             if moe_layers is not None
+             else list(
+                 range(
+                     interleave_moe_layer_step - 1,
+                     num_hidden_layers,
+                     interleave_moe_layer_step,
+                 )
+             )
+         )
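Since the shipped `config.json` files enumerate `layer_types` explicitly, the interleaved local-global layout can be reproduced standalone. A small sketch (a hypothetical helper, written to match the 30-entry pattern in `config.json` rather than being the class's own code): every layer whose index is a multiple of 4 is global, the rest are sliding-window, and the final layer is forced to full attention.

```python
def default_layer_types(num_hidden_layers: int) -> list:
    # Full (global) attention every 4th layer, sliding-window (local)
    # attention elsewhere; the final layer is always full attention,
    # matching the explicit layer_types list shipped in config.json.
    types = [
        "full_attention" if i % 4 == 0 else "sliding_attention"
        for i in range(num_hidden_layers)
    ]
    types[-1] = "full_attention"
    return types

pattern = default_layer_types(30)
```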
instruct/chat_template.jinja ADDED
@@ -0,0 +1,18 @@
+ {{- bos_token }}
+ {%- set intro = "You are MobileLLM Pro, a helpful assistant created by Meta Reality Labs." %}
+ {%- if messages[0]['role'] == 'system' %}
+ {%- set user_system_message = messages[0]['content']|trim %}
+ {%- set system_message = intro + "\n" + user_system_message %}
+ {%- set messages = messages[1:] %}
+ {%- else %}
+ {%- set system_message = intro %}
+ {%- endif %}
+ {{- "<|header_start|>system<|header_end|>\n\n" }}
+ {{- system_message }}
+ {{- "<|eot|>" }}
+ {%- for message in messages %}
+ {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' + message['content'] | trim + '<|eot|>' }}
+ {%- endfor %}
+ {%- if add_generation_prompt %}
+ {{- '<|header_start|>assistant<|header_end|>\n\n' }}
+ {%- endif %}
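For readers without a Jinja renderer handy, the template above can be mirrored in plain Python. This is a sketch of the template's logic, not the canonical `tokenizer.apply_chat_template` path: fold an optional user system message into the built-in intro, then emit header/eot-delimited turns.

```python
INTRO = "You are MobileLLM Pro, a helpful assistant created by Meta Reality Labs."

def render_chat(messages, bos_token="<|begin_of_text|>", add_generation_prompt=True):
    # Mirror of instruct/chat_template.jinja.
    if messages and messages[0]["role"] == "system":
        system_message = INTRO + "\n" + messages[0]["content"].strip()
        messages = messages[1:]
    else:
        system_message = INTRO
    out = bos_token + "<|header_start|>system<|header_end|>\n\n" + system_message + "<|eot|>"
    for m in messages:
        out += "<|header_start|>" + m["role"] + "<|header_end|>\n\n" + m["content"].strip() + "<|eot|>"
    if add_generation_prompt:
        out += "<|header_start|>assistant<|header_end|>\n\n"
    return out

prompt = render_chat([{"role": "user", "content": "Hello!"}])
```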
instruct/config.json ADDED
@@ -0,0 +1,118 @@
1
+ {
2
+ "architectures": [
3
+ "MobileLLMP1ForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_mobilellm_p1.MobileLLMP1TextConfig",
7
+ "AutoModelForCausalLM": "modeling_mobilellm_p1.MobileLLMP1ForCausalLM"
8
+ },
9
+ "attention_bias": false,
10
+ "attention_chunk_size": 8192,
11
+ "attention_dropout": 0.0,
12
+ "attn_scale": 0.1,
13
+ "bos_token_id": 200000,
14
+ "eos_token_id": [
15
+ 200001,
16
+ 200007,
17
+ 200008
18
+ ],
19
+ "pad_token_id": 200018,
20
+ "for_llm_compressor": false,
21
+ "head_dim": 64,
22
+ "hidden_act": "silu",
23
+ "hidden_size": 1280,
24
+ "initializer_range": 0.02,
25
+ "intermediate_size": 6144,
26
+ "intermediate_size_mlp": 6144,
27
+ "layer_types": [
28
+ "full_attention",
29
+ "sliding_attention",
30
+ "sliding_attention",
31
+ "sliding_attention",
32
+ "full_attention",
33
+ "sliding_attention",
34
+ "sliding_attention",
35
+ "sliding_attention",
36
+ "full_attention",
37
+ "sliding_attention",
38
+ "sliding_attention",
39
+ "sliding_attention",
40
+ "full_attention",
41
+ "sliding_attention",
42
+ "sliding_attention",
43
+ "sliding_attention",
44
+ "full_attention",
45
+ "sliding_attention",
46
+ "sliding_attention",
47
+ "sliding_attention",
48
+ "full_attention",
49
+ "sliding_attention",
50
+ "sliding_attention",
51
+ "sliding_attention",
52
+ "full_attention",
53
+ "sliding_attention",
54
+ "sliding_attention",
55
+ "sliding_attention",
56
+ "full_attention",
57
+ "full_attention"
58
+ ],
59
+ "sliding_window": 512,
60
+ "max_position_embeddings": 131072,
61
+ "model_type": "llama4_text",
62
+ "moe_layers": [],
63
+ "no_rope_layers": [
64
+ 1,
65
+ 1,
66
+ 1,
67
+ 1,
68
+ 1,
69
+ 1,
70
+ 1,
71
+ 1,
72
+ 1,
73
+ 1,
74
+ 1,
75
+ 1,
76
+ 1,
77
+ 1,
78
+ 1,
79
+ 1,
80
+ 1,
81
+ 1,
82
+ 1,
83
+ 1,
84
+ 1,
85
+ 1,
86
+ 1,
87
+ 1,
88
+ 1,
89
+ 1,
90
+ 1,
91
+ 1,
92
+ 1,
93
+ 1
94
+ ],
95
+ "num_attention_heads": 20,
96
+ "num_experts_per_tok": 0,
97
+ "num_hidden_layers": 30,
98
+ "num_key_value_heads": 4,
99
+ "num_local_experts": 0,
100
+ "output_router_logits": false,
101
+ "rms_norm_eps": 1e-05,
102
+ "rope_scaling": {
103
+ "factor": 16.0,
104
+ "high_freq_factor": 1.0,
105
+ "low_freq_factor": 1.0,
106
+ "original_max_position_embeddings": 8192,
107
+ "rope_type": "llama3"
108
+ },
109
+ "rope_theta": 500000.0,
110
+ "router_aux_loss_coef": 0.001,
111
+ "router_jitter_noise": 0.0,
112
+ "tie_word_embeddings": true,
113
+ "torch_dtype": "bfloat16",
114
+ "transformers_version": "4.53.0.dev0",
115
+ "use_cache": true,
116
+ "use_qk_norm": false,
117
+ "vocab_size": 202048
118
+ }
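The 30-entry `layer_types` list above follows a fixed pattern. A standalone check (not library code): every 4th layer uses full attention, the rest use the 512-token sliding window, and the final layer is always full attention:

```python
num_hidden_layers = 30

# Pattern: index divisible by 4 -> full attention; the trailing entry keeps
# the last layer on full attention regardless of where the cycle lands.
layer_types = [
    "sliding_attention" if i % 4 else "full_attention"
    for i in range(num_hidden_layers - 1)
] + ["full_attention"]

assert len(layer_types) == num_hidden_layers
assert layer_types[-2:] == ["full_attention", "full_attention"]
assert layer_types.count("full_attention") == 9  # layers 0, 4, ..., 28 plus the final one
```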
instruct/generation_config.json ADDED
@@ -0,0 +1,14 @@
+{
+  "_from_model_config": true,
+  "bos_token_id": 200000,
+  "eos_token_id": [
+    200001,
+    200007,
+    200008
+  ],
+  "pad_token_id": 200018,
+  "temperature": 0.6,
+  "top_p": 0.9,
+  "do_sample": true,
+  "transformers_version": "4.55.0"
+}
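The decoding defaults above (`temperature=0.6`, `top_p=0.9`, `do_sample=true`) mean temperature-scaled nucleus sampling. A plain-Python sketch of that procedure, for intuition only — the real sampling happens inside `transformers`' `generate()`:

```python
import math
import random

def sample_top_p(logits, temperature=0.6, top_p=0.9, rng=None):
    """Temperature + nucleus (top-p) sampling over a list of raw logits."""
    rng = rng or random.Random(0)  # fixed seed so the sketch is deterministic
    # 1. Temperature: divide logits before the softmax.
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]  # stable softmax
    total = sum(exps)
    probs = sorted(
        ((i, e / total) for i, e in enumerate(exps)), key=lambda x: -x[1]
    )
    # 2. Nucleus: keep the smallest prefix whose cumulative mass reaches top_p.
    kept, cum = [], 0.0
    for i, p in probs:
        kept.append((i, p))
        cum += p
        if cum >= top_p:
            break
    # 3. Renormalize over the kept tokens and draw one.
    z = sum(p for _, p in kept)
    r = rng.random() * z
    for i, p in kept:
        r -= p
        if r <= 0:
            return i
    return kept[-1][0]
```

With a strongly peaked distribution such as `[10.0, 1.0, 0.0, -5.0]`, the nucleus collapses to the single dominant token, so `sample_top_p` returns index 0 regardless of the random draw.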
instruct/model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4ad9e2b937e2daf12859a6367d081c57e5be7dc4cc4b76a3dc44fa4a5ad21a4
+size 2168938424
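What lands in git here is only a Git LFS pointer file; the ~2.17 GB of bf16 weights live in LFS storage. A small sketch parsing that pointer format (each line is `key value`):

```python
def parse_lfs_pointer(text):
    """Parse a Git LFS pointer file into a dict of its key/value lines."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

ptr = parse_lfs_pointer(
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:b4ad9e2b937e2daf12859a6367d081c57e5be7dc4cc4b76a3dc44fa4a5ad21a4\n"
    "size 2168938424\n"
)
assert ptr["oid"].startswith("sha256:")
assert int(ptr["size"]) == 2168938424  # bytes, ~2.17 GB
```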
instruct/special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
+{
+  "bos_token": "<|begin_of_text|>",
+  "eos_token": "<|eot|>",
+  "pad_token": "<|finetune_right_pad|>"
+}
instruct/tokenizer.json ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:172c9eb4beafc72601690da3ccfcede5c2e6806a8d5ec1fca33e22acea8023a4
+size 27948578
instruct/tokenizer.model ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0bdbaf59b0762c8c807617e2d8ea51420eb1b1de266df2495be755c8e0ed6ed
+size 3622230
instruct/tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_mobilellm_p1.py ADDED
@@ -0,0 +1,447 @@
+import math
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+import copy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.generation import GenerationMixin
+from transformers.integrations import use_kernel_forward_from_hub
+from transformers.masking_utils import (
+    create_causal_mask,
+    create_sliding_window_causal_mask,
+)
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_layers import GradientCheckpointingLayer
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    ModelOutput,
+)
+from transformers.modeling_rope_utils import dynamic_rope_update, ROPE_INIT_FUNCTIONS
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+
+from transformers.models.llama4.configuration_llama4 import (
+    Llama4Config,
+    Llama4TextConfig,
+    Llama4VisionConfig,
+)
+from transformers.models.llama4.modeling_llama4 import (
+    apply_rotary_emb,
+    eager_attention_forward,
+    Llama4PreTrainedModel,
+    Llama4TextDecoderLayer,
+    Llama4TextL2Norm,
+    Llama4TextMLP,
+    Llama4TextMoe,
+    Llama4TextRMSNorm,
+    Llama4TextRotaryEmbedding,
+    Llama4TextAttention,
+    Llama4ForCausalLM,
+)
+from transformers.processing_utils import Unpack
+from transformers.utils import (
+    auto_docstring,
+    can_return_tuple,
+    logging,
+    TransformersKwargs,
+)
+from transformers.utils.deprecation import deprecate_kwarg
+from transformers.utils.generic import check_model_inputs
+
+from .configuration_mobilellm_p1 import MobileLLMP1TextConfig
+
+logger = logging.get_logger(__name__)
+
+class MobileLLMP1TextAttention(nn.Module):
+    """Multi-headed attention from the 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: MobileLLMP1TextConfig, layer_idx):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        self.head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        self.num_attention_heads = config.num_attention_heads
+        self.num_key_value_groups = (
+            config.num_attention_heads // config.num_key_value_heads
+        )
+        self.num_key_value_heads = config.num_key_value_heads
+        self.scaling = self.head_dim**-0.5
+        self.attn_scale = config.attn_scale
+        self.floor_scale = config.floor_scale
+        self.attn_temperature_tuning = config.attn_temperature_tuning
+        self.attention_dropout = config.attention_dropout
+        self.is_causal = True
+        self.use_rope = config.no_rope_layers[layer_idx]
+        self.sliding_window = config.sliding_window if self.is_sliding else None
+        self.q_proj = nn.Linear(
+            config.hidden_size,
+            config.num_attention_heads * self.head_dim,
+            bias=config.attention_bias,
+        )
+        self.k_proj = nn.Linear(
+            config.hidden_size,
+            config.num_key_value_heads * self.head_dim,
+            bias=config.attention_bias,
+        )
+        self.v_proj = nn.Linear(
+            config.hidden_size,
+            config.num_key_value_heads * self.head_dim,
+            bias=config.attention_bias,
+        )
+        self.o_proj = nn.Linear(
+            config.num_attention_heads * self.head_dim,
+            config.hidden_size,
+            bias=config.attention_bias,
+        )
+        if self.config.use_qk_norm and self.use_rope:
+            self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape)
+        key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        if self.use_rope:  # the 16E model skips RoPE for long context on certain layers
+            query_states, key_states = apply_rotary_emb(
+                query_states, key_states, position_embeddings.to(query_states.device)
+            )
+
+        if hasattr(self, "qk_norm"):  # the 128E model does not use qk_norm
+            query_states = self.qk_norm(query_states)
+            key_states = self.qk_norm(key_states)
+
+        # Apply temperature tuning (https://huggingface.co/papers/2501.19399) to NoPE layers
+        if self.attn_temperature_tuning and not self.use_rope:
+            attn_scales = (
+                torch.log(
+                    torch.floor((cache_position.float() + 1.0) / self.floor_scale) + 1.0
+                )
+                * self.attn_scale
+                + 1.0
+            )
+            attn_scales = attn_scales.view((1, input_shape[-1], 1, 1)).expand(
+                (*input_shape, 1, 1)
+            )  # batch size > 1
+            query_states = (query_states * attn_scales).to(query_states.dtype)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+
+        if past_key_values is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"cache_position": cache_position}
+            key_states, value_states = past_key_values.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[
+                self.config._attn_implementation
+            ]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            sliding_window=self.sliding_window,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class MobileLLMP1TextDecoderLayer(GradientCheckpointingLayer):
+    def __init__(self, config, layer_idx):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.layer_idx = layer_idx
+        self.attention_type = config.layer_types[layer_idx]
+        self.self_attn = MobileLLMP1TextAttention(config, layer_idx)
+        self.is_moe_layer = layer_idx in config.moe_layers
+        if self.is_moe_layer:  # the 128E model interleaves dense / sparse
+            self.feed_forward = Llama4TextMoe(config)
+        else:
+            self.feed_forward = Llama4TextMLP(
+                config, intermediate_size=config.intermediate_size_mlp
+            )
+
+        self.input_layernorm = Llama4TextRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = Llama4TextRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[
+        torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
+    ]:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        attention_states, _ = self.self_attn(
+            hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+        hidden_states = residual + attention_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.feed_forward(hidden_states)
+        if self.is_moe_layer:
+            hidden_states, _ = hidden_states
+        hidden_states = residual + hidden_states.view(residual.shape)
+        return hidden_states
+
+
+class MobileLLMP1TextModel(Llama4PreTrainedModel):
+    _no_split_modules = ["MobileLLMP1TextDecoderLayer"]
+    base_model_prefix = "model"
+    config: MobileLLMP1TextConfig
+    _can_record_outputs = {
+        "attentions": MobileLLMP1TextAttention,
+        "hidden_states": MobileLLMP1TextDecoderLayer,
+        "router_logits": Llama4TextMoe,
+    }
+
+    def __init__(self, config: MobileLLMP1TextConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(
+            config.vocab_size, config.hidden_size, self.padding_idx
+        )
+        self.layers = nn.ModuleList(
+            [
+                MobileLLMP1TextDecoderLayer(config, layer_idx)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = Llama4TextRotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> Union[tuple, BaseModelOutputWithPast]:
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(
+                input_ids.to(self.embed_tokens.weight.device)
+            )
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache(config=self.config)
+
+        if cache_position is None:
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
+            cache_position = torch.arange(
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        # It may already have been prepared by e.g. `generate`
+        if not isinstance(causal_mask_mapping := attention_mask, dict):
+            # Prepare mask arguments
+            mask_kwargs = {
+                "config": self.config,
+                "input_embeds": inputs_embeds,
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values,
+                "position_ids": position_ids,
+            }
+            sliding_mask_kwargs = mask_kwargs.copy()
+            del sliding_mask_kwargs["position_ids"]
+
+            # Create the masks
+            causal_mask_mapping = {
+                "full_attention": create_causal_mask(**mask_kwargs),
+                "sliding_attention": create_sliding_window_causal_mask(
+                    **sliding_mask_kwargs
+                ),
+            }
+
+        hidden_states = inputs_embeds
+
+        # Create position embeddings to be shared across the decoder layers
+        freq_cis = self.rotary_emb(hidden_states, position_ids)
+
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+            hidden_states = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=freq_cis,
+                **kwargs,
+            )
+        hidden_states = self.norm(hidden_states)
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values if use_cache else None,
+        )
+
+
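Each decoder layer receives one of the two masks built above, selected by its `attention_type`. A plain-Python sketch of the visibility rule those masks encode (not `transformers`' `masking_utils`, which also handles padding and caches): a query at position `q` may attend to key position `k` if `k` is causal and, on `sliding_attention` layers, within the last `window` positions (`sliding_window=512` in the config):

```python
def allowed(q, k, window=None):
    """True if key position k is visible to query position q."""
    if k > q:
        return False  # causal: no attending to future positions
    if window is not None and q - k >= window:
        return False  # sliding: only the most recent `window` keys
    return True

# full_attention layer: everything causal is visible
assert allowed(700, 0) and not allowed(0, 1)
# sliding_attention layer with window 512: key 0 drops out of range at q=512
assert allowed(511, 0, window=512)
assert not allowed(512, 0, window=512)
```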
+class MobileLLMP1ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
+    _no_split_modules = ["MobileLLMP1TextDecoderLayer"]
+    base_model_prefix = "language_model"
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    config: MobileLLMP1TextConfig
+
+    def __init__(self, config: MobileLLMP1TextConfig):
+        super().__init__(config)
+        self.model = MobileLLMP1TextModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> Union[tuple, CausalLMOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, Llama4ForCausalLM
+
+        >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
+        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        hidden_states = outputs[0]
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = (
+            slice(-logits_to_keep, None)
+            if isinstance(logits_to_keep, int)
+            else logits_to_keep
+        )
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+__all__ = [
+    "MobileLLMP1ForCausalLM",
+    "MobileLLMP1TextModel",
+    "MobileLLMP1TextDecoderLayer",
+    "MobileLLMP1TextAttention",
+]
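The `logits_to_keep` slicing in `forward` has one subtlety worth noting. A standalone sketch of just that expression: an int `n` becomes `slice(-n, None)`, i.e. "last `n` positions only", and the edge case is `n == 0`, since `slice(-0, None) == slice(0, None)` keeps every position:

```python
def logit_slice(logits_to_keep):
    """Mirror of the slice_indices expression in MobileLLMP1ForCausalLM.forward."""
    return (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )

hidden = ["h0", "h1", "h2", "h3"]          # stand-in for per-position hidden states
assert hidden[logit_slice(2)] == ["h2", "h3"]  # decoding: keep only the last steps
assert hidden[logit_slice(0)] == hidden        # 0 (the default) keeps everything
```

Keeping only the final positions avoids projecting every hidden state through the 202k-row `lm_head` during generation, where only the last token's logits are needed.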