Upload folder using huggingface_hub
- .gitattributes +1 -0
- README.md +201 -0
- config.json +23 -0
- generation_config.json +11 -0
- model.safetensors +3 -0
- modeling_lumees.py +233 -0
- special_tokens_map.json +7 -0
- tokenizer.json +3 -0
- tokenizer_config.json +53 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,3 +1,204 @@
---
language:
- en
tags:
- text-generation
- transformer
- educational-content
- creative-writing
- efficiency
- rope
- rmsnorm
- swiglu
license: apache-2.0
model_type: lumees-transformer
datasets:
- custom-educational-corpus
metrics:
- perplexity
pipeline_tag: text-generation
widget:
- text: "Once upon a time, in a bustling city where dreams came alive,"
  example_title: "Creative Storytelling"
- text: "The scientist looked at the data in disbelief and whispered,"
  example_title: "Scientific Narrative"
- text: "In the quiet library, an ancient book began to glow softly, revealing"
  example_title: "Fantasy Literature"
model-index:
- name: Lumees-362M
  results:
  - task:
      type: text-generation
      name: Text Generation
    dataset:
      type: validation
      name: Educational Content Validation
    metrics:
    - type: perplexity
      value: 5.47
      name: Validation Perplexity
    - type: parameters
      value: 362000000
      name: Parameters
    - type: efficiency_ratio
      value: 0.0166
      name: PPL per Million Parameters
---

# Lumees-362M: Efficient Educational Content Generator

## Model Description

Lumees-362M is a 362M-parameter transformer model optimized for educational content generation and creative writing. It reaches a validation perplexity of **5.47** on its target domain, an unusually strong result for a model of this size.

### Key Features

- **🎯 Domain Specialization**: Strong performance on educational and creative content
- **⚡ High Efficiency**: 5.47 validation PPL with only 362M parameters
- **🏗️ Modern Architecture**: RoPE positional encoding, RMSNorm, SwiGLU activation
- **📝 Coherent Generation**: Fluent, coherent long-form text
- **🌍 Multilingual Tokenizer**: 89-language tokenizer with a 250K vocabulary

## Model Architecture

```yaml
Architecture: RoPE Transformer
Parameters: 362,318,784
Hidden Size: 768
Number of Layers: 24
Number of Attention Heads: 12
Head Dimension: 64
Feed Forward Dimension: 3072 (4x hidden size)
Vocabulary Size: 250,000
Max Sequence Length: 1024
Position Encoding: Rotary Position Embedding (RoPE)
Normalization: RMS Normalization
Activation: SwiGLU
Dropout: 0.0
Weight Tying: Yes (embedding and lm_head)
```
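The reported parameter count can be roughly cross-checked from the table above. This is a back-of-the-envelope sketch under the stated shapes (bias-free projections, tied embedding/lm_head); the checkpoint's exact bookkeeping differs slightly.

```python
# Rough parameter count from the architecture table.
# With weight tying, the lm_head reuses the embedding matrix.
vocab, d, layers, ff = 250_000, 768, 24, 3072

embedding = vocab * d                       # token embeddings (shared with lm_head)
attn = 4 * d * d                            # q, k, v, out projections (no bias)
mlp = 2 * d * ff                            # up and down projections (no bias)
norms = 2 * d                               # two RMSNorm weight vectors per block
per_layer = attn + mlp + norms

total = embedding + layers * per_layer + d  # + final RMSNorm
print(f"{total:,}")                         # ~361.9M, within ~0.1% of 362,318,784
```

The small gap to the reported 362,318,784 suggests a few extra weights not captured by this simplified tally.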

## Training Details

### Training Data
- **Domain**: High-quality educational content, scientific materials, creative writing
- **Languages**: Primarily English, with multilingual tokenizer support
- **Quality**: Tier 1 quality with manual curation

### Training Results
- **Validation PPL**: 5.47
- **Training PPL**: 8.43
- **Training Stability**: Stable throughout (gradient norm ~0.4)

## Performance

### Benchmarks

| Metric | Value | Comparison |
|--------|-------|------------|
| Validation Perplexity | 5.47 | 4-5x lower than GPT-2 Medium (different eval data) |
| Parameters | 362M | Comparable to GPT-2 Medium (355M) |
| Efficiency Ratio | 0.0166 PPL/M params | High efficiency |

### Capabilities
- **Educational Content**: Strong performance (targeting 3-4 PPL in final training)
- **Creative Writing**: Fluent narrative generation with a rich vocabulary
- **Scientific Communication**: Clear explanations of complex concepts
- **Character Development**: Rich character interactions and dialogue
- **Long-form Coherence**: Maintains coherence across extended sequences

## Usage

### Direct Usage
```python
from transformers import AutoTokenizer, AutoModelForCausalLM

# trust_remote_code=True is required: the model uses a custom architecture
# registered via auto_map in config.json (modeling_lumees.py)
tokenizer = AutoTokenizer.from_pretrained("lumees/lumees-362m")
model = AutoModelForCausalLM.from_pretrained("lumees/lumees-362m", trust_remote_code=True)

# Generate text
inputs = tokenizer("Once upon a time, in a bustling city where dreams came alive,", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_length=200,
    temperature=0.8,
    top_p=0.9,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)
```

### Generation Parameters
```yaml
Recommended Settings:
  Temperature: 0.7-0.9
  Top-p: 0.85-0.95
  Top-k: 40-60
  Repetition Penalty: 1.1
  Max Length: 200-600 tokens
```
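The recommended settings can be understood with a minimal sketch of what temperature and top-p do to the next-token distribution (illustrative only; `transformers` implements this filtering internally):

```python
import math

def sample_filter(logits, temperature=0.8, top_p=0.9):
    """Return the (token, prob) pairs kept by temperature + top-p filtering."""
    # Temperature rescales logits before the softmax: <1 sharpens, >1 flattens.
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(l - m) for l in scaled]
    z = sum(exps)
    probs = [e / z for e in exps]
    # Top-p keeps the smallest set of tokens whose cumulative mass >= top_p.
    ranked = sorted(enumerate(probs), key=lambda kv: kv[1], reverse=True)
    kept, cum = [], 0.0
    for tok, p in ranked:
        kept.append((tok, p))
        cum += p
        if cum >= top_p:
            break
    return kept

kept = sample_filter([4.0, 3.0, 0.1, -2.0], temperature=0.8, top_p=0.9)
print(kept)  # only the two high-logit tokens survive the 0.9 cutoff
```

Sampling then draws from the kept tokens after renormalizing, which is why top-p trims low-probability "tail" continuations without fixing a hard token count the way top-k does.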

## Example Generations

### Creative Storytelling
**Prompt**: "Once upon a time, in a bustling city where dreams came alive,"

**Generation**: "Once upon a time, in a bustling city where dreams came alive, there was an art gallery named 'Pink & Blue.' This quaint little store offered a diverse range of artwork from various artists and eras. Among them were two young friends - Mia, who loved exploring new styles and techniques, and Ben, passionate about capturing nature's beauty on canvas..."

### Scientific Narrative
**Prompt**: "The scientist looked at the data in disbelief and whispered,"

**Generation**: "The scientist looked at the data in disbelief and whispered, 'This shows how certain conditions can lead to undesirable outcomes.' She then asked, 'But why does that matter if something causes you to get sick?' Mr. Wise Owl explained, 'You see, when we take care of our bodies, especially those living with infectious diseases...'"

## Limitations

- **Domain Focus**: Optimized for educational/creative content; may underperform on general web text
- **Context Length**: Current limit of 1024 tokens (extension to 4096+ planned)
- **Multilingual**: The tokenizer supports 89 languages, but the model is trained primarily on English
- **Specialized Training**: May require fine-tuning for domains outside educational/creative content

## Ethical Considerations

### Intended Use
- Educational content generation
- Creative writing assistance
- Science communication
- Research and academic applications

### Limitations and Biases
- Training data focused on educational content may introduce domain-specific biases
- The model should not be used to generate harmful, toxic, or misleading content
- Outputs should be reviewed for accuracy, especially factual claims
- Not suitable for high-stakes decision-making without human oversight

## Future Development

This model is the foundation of a planned scaling strategy:
- **724M Model**: Multilingual expansion with general knowledge
- **1.4B Model**: Global language coverage with advanced capabilities
- **Context Extension**: RoPE-based scaling to 4096-32768 tokens
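RoPE-based context extension is commonly done via position interpolation: positions in the longer window are rescaled so the rotation angles stay within the range seen during training. A minimal sketch of that idea under the card's numbers (1024 trained, 4096 target); this is an assumption about the approach, not the project's actual extension code:

```python
def rope_inv_freq(head_dim, base=10000.0):
    # Standard RoPE inverse frequencies for the even channel indices.
    return [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]

def interpolated_angles(pos, head_dim, train_len=1024, target_len=4096):
    # Position interpolation: squeeze target positions into the trained range.
    scale = train_len / target_len          # 0.25 for 1024 -> 4096
    return [pos * scale * f for f in rope_inv_freq(head_dim)]

# Position 4095 in the extended window maps to the angles of
# position 1023.75 in the original training range.
a = interpolated_angles(4095, head_dim=64)
b = [1023.75 * f for f in rope_inv_freq(64)]
print(round(a[0], 2))  # 1023.75
```

Because the rotations never exceed those seen in training, a short fine-tune usually suffices to adapt the model to the longer window.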

## Citation

If you use this model in your research, please cite:

```bibtex
@misc{lumees362m2025,
  title={Lumees-362M: Efficient Domain-Specialized Language Model},
  author={Hasan KURŞUN and Kerem Berkay YANIK},
  year={2025},
  note={Achieving 5.47 PPL with 362M parameters through strategic domain specialization},
  url={https://lumees.io}
}
```

## Model Card Authors

- **Developed by**: Hasan KURŞUN, Kerem Berkay YANIK
- **Model Type**: Causal Language Model
- **Language**: English (primary), with 89-language tokenizer support
- **License**: Apache 2.0
- **Contact**: hello@lumees.io

---
config.json
ADDED
@@ -0,0 +1,23 @@
{
  "architectures": [
    "LumeesForCausalLM"
  ],
  "auto_map": {
    "AutoConfig": "modeling_lumees.LumeesConfig",
    "AutoModelForCausalLM": "modeling_lumees.LumeesForCausalLM"
  },
  "model_type": "lumees",
  "torch_dtype": "float32",
  "transformers_version": "4.36.0",
  "vocab_size": 250000,
  "hidden_size": 768,
  "num_hidden_layers": 24,
  "num_attention_heads": 12,
  "max_position_embeddings": 1024,
  "dropout": 0.0,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "pad_token_id": 1,
  "tie_word_embeddings": true,
  "use_cache": true
}
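A few quantities follow directly from this config and can be cross-checked against the model card (values copied from the JSON above):

```python
import json

config = json.loads("""{
  "vocab_size": 250000,
  "hidden_size": 768,
  "num_hidden_layers": 24,
  "num_attention_heads": 12,
  "max_position_embeddings": 1024,
  "tie_word_embeddings": true
}""")

# Head dimension is hidden_size split evenly across attention heads.
head_dim = config["hidden_size"] // config["num_attention_heads"]
print(head_dim)  # 64, matching the architecture table in the README
```

`tie_word_embeddings: true` is what makes the 250K-row embedding matrix serve double duty as the output projection.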
generation_config.json
ADDED
@@ -0,0 +1,11 @@
{
  "bos_token_id": 0,
  "eos_token_id": 2,
  "pad_token_id": 1,
  "do_sample": true,
  "max_new_tokens": 256,
  "temperature": 0.8,
  "top_p": 0.9,
  "top_k": 50,
  "repetition_penalty": 1.1
}
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:43ab2e691626ccc8f310c914d0a431cba1928bc021e6b77ef48c2aecd926460d
size 1447648984
modeling_lumees.py
ADDED
@@ -0,0 +1,233 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast


class RotaryEmbedding(nn.Module):
    def __init__(self, head_dim, max_seq_len=2048):
        super().__init__()
        # Precompute cos/sin tables for all positions up to max_seq_len
        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(max_seq_len, dtype=torch.float)
        freqs = torch.einsum("i , j -> i j", t, inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, seq_len):
        return self.cos_cached[:seq_len, :], self.sin_cached[:seq_len, :]


def apply_rotary_pos_emb(q, k, cos, sin):
    B, H, S, D = q.shape
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)
    q_rot = torch.cat([-q[..., 1::2], q[..., ::2]], dim=-1)
    k_rot = torch.cat([-k[..., 1::2], k[..., ::2]], dim=-1)
    q = q * cos + q_rot * sin
    k = k * cos + k_rot * sin
    return q, k


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight


class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ff_mult=4, dropout=0.0):
        super().__init__()
        self.attn_norm = RMSNorm(dim)
        self.ff_norm = RMSNorm(dim)
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        self.ff = nn.Sequential(
            nn.Linear(dim, ff_mult * dim, bias=False),
            nn.SiLU(),
            nn.Linear(ff_mult * dim, dim, bias=False),
        )

    def forward(self, x, cos, sin):
        B, S, D = x.shape
        h = self.attn_norm(x)

        q = self.q_proj(h).view(B, S, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.k_proj(h).view(B, S, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.v_proj(h).view(B, S, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # Ensure cos and sin have the same dtype/device as q, k, v
        cos = cos.to(dtype=q.dtype, device=q.device)
        sin = sin.to(dtype=q.dtype, device=q.device)

        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        if hasattr(F, 'scaled_dot_product_attention'):
            attn_out = F.scaled_dot_product_attention(
                q, k, v,
                is_causal=True,
                dropout_p=self.dropout.p if hasattr(self.dropout, 'p') and self.training else 0.0
            )
        else:
            # Fallback for older PyTorch: explicit causal masked attention
            scale = self.head_dim ** -0.5
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale
            causal_mask = torch.triu(torch.ones(S, S, device=x.device, dtype=q.dtype), diagonal=1).bool()
            attn_scores.masked_fill_(causal_mask, torch.finfo(q.dtype).min)
            attn_weights = F.softmax(attn_scores, dim=-1)
            attn_weights = self.dropout(attn_weights)
            attn_out = torch.matmul(attn_weights, v)

        attn_out = attn_out.permute(0, 2, 1, 3).reshape(B, S, D)
        attn_out = self.out_proj(attn_out)

        x = x + self.dropout(attn_out)
        x = x + self.dropout(self.ff(self.ff_norm(x)))
        return x


class RoPETransformer(nn.Module):
    def __init__(self, dim, num_layers, num_heads, max_seq_len, vocab_size, dropout=0.0):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.pos_emb = RotaryEmbedding(dim // num_heads, max_seq_len)
        self.layers = nn.ModuleList([TransformerBlock(dim, num_heads, dropout=dropout) for _ in range(num_layers)])
        self.norm = RMSNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # Weight tying for memory efficiency
        self.lm_head.weight = self.embed.weight

    def forward(self, x):
        B, S = x.shape
        x = self.embed(x)
        x = self.dropout(x)

        cos, sin = self.pos_emb(S)
        for layer in self.layers:
            x = layer(x, cos, sin)
        x = self.norm(x)
        return self.lm_head(x)


class LumeesConfig(PretrainedConfig):
    model_type = "lumees"

    def __init__(
        self,
        vocab_size=50000,
        hidden_size=768,
        num_hidden_layers=24,
        num_attention_heads=12,
        max_position_embeddings=1024,
        dropout=0.0,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.dropout = dropout
        super().__init__(**kwargs)


class LumeesForCausalLM(PreTrainedModel):
    config_class = LumeesConfig
    _tied_weights_keys = ["transformer.lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = RoPETransformer(
            dim=config.hidden_size,
            num_layers=config.num_hidden_layers,
            num_heads=config.num_attention_heads,
            max_seq_len=config.max_position_embeddings,
            vocab_size=config.vocab_size,
            dropout=config.dropout
        )

        # Initialize weights
        self.post_init()

    def _tie_weights(self):
        """Tie the weights between the input embeddings and the output embeddings."""
        if getattr(self.config, "tie_word_embeddings", True):
            self.transformer.lm_head.weight = self.transformer.embed.weight

    def get_input_embeddings(self):
        return self.transformer.embed

    def set_input_embeddings(self, value):
        self.transformer.embed = value

    def get_output_embeddings(self):
        return self.transformer.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.transformer.lm_head = new_embeddings

    def forward(self, input_ids=None, labels=None, **kwargs):
        # Handle device placement automatically
        if input_ids is not None:
            device = next(self.parameters()).device
            if input_ids.device != device:
                input_ids = input_ids.to(device)

        if labels is not None:
            device = next(self.parameters()).device
            if labels.device != device:
                labels = labels.to(device)

        logits = self.transformer(input_ids)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def generate(self, input_ids, max_length=50, temperature=0.8, do_sample=True, **kwargs):
        """Simple generation method for backward compatibility.

        Note: this loop assumes batch size 1 (it indexes outputs[0]) and
        recomputes the full sequence each step (no KV cache).
        """
        self.eval()

        # Ensure input_ids is on the same device as the model
        device = next(self.parameters()).device
        if hasattr(input_ids, 'to'):
            input_ids = input_ids.to(device)

        with torch.no_grad():
            current_input = input_ids.clone()

            for _ in range(max_length - input_ids.shape[1]):
                outputs = self.transformer(current_input)
                next_logits = outputs[0, -1, :] / temperature

                if do_sample:
                    probs = F.softmax(next_logits, dim=-1)
                    next_token = torch.multinomial(probs, 1)
                else:
                    next_token = torch.argmax(next_logits, dim=-1, keepdim=True)

                current_input = torch.cat([current_input, next_token.unsqueeze(0)], dim=1)

                # Stop at EOS token if available
                if hasattr(self.config, 'eos_token_id') and next_token.item() == self.config.eos_token_id:
                    break

        return current_input
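The RMSNorm in the file above can be sanity-checked numerically: with a unit weight vector, the normalized output has root-mean-square ≈ 1. A dependency-free mirror of the module's arithmetic:

```python
import math

def rms_norm(x, eps=1e-6, weight=None):
    # Same arithmetic as the RMSNorm module: x * rsqrt(mean(x^2) + eps) * weight
    ms = sum(v * v for v in x) / len(x)
    r = 1.0 / math.sqrt(ms + eps)
    w = weight or [1.0] * len(x)
    return [v * r * wi for v, wi in zip(x, w)]

y = rms_norm([3.0, -4.0, 0.0, 0.0])
rms_y = math.sqrt(sum(v * v for v in y) / len(y))
print(round(rms_y, 6))  # ~1.0
```

Unlike LayerNorm, no mean is subtracted; only the magnitude is rescaled, which is cheaper and works well in pre-norm transformer blocks like the ones here.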
special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
{
  "bos_token": "<s>",
  "eos_token": "</s>",
  "mask_token": "<mask>",
  "pad_token": "<pad>",
  "unk_token": "<unk>"
}
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:02b1c3444905f8dbb60635be1a564b2c3905901b94026e38bdf68a563e733f89
size 20604353
tokenizer_config.json
ADDED
@@ -0,0 +1,53 @@
{
  "added_tokens_decoder": {
    "0": {
      "content": "<s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<pad>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "3": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "4": {
      "content": "<mask>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "</s>",
  "extra_special_tokens": {},
  "mask_token": "<mask>",
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": "<pad>",
  "tokenizer_class": "PreTrainedTokenizerFast",
  "unk_token": "<unk>"
}