Commit 1df0e33 by Pomilon: Deploy Aetheris to HF Space
Files changed:
- .gitattributes +1 -0
- .gitignore +28 -0
- Dockerfile +38 -0
- Dockerfile-nvidia +41 -0
- LICENSE +21 -0
- README.md +146 -0
- aetheris/__init__.py +2 -0
- aetheris/api/schemas.py +92 -0
- aetheris/api/server.py +162 -0
- aetheris/cli/__init__.py +1 -0
- aetheris/cli/main.py +287 -0
- aetheris/config.py +58 -0
- aetheris/data.py +105 -0
- aetheris/inference.py +106 -0
- aetheris/model.py +86 -0
- aetheris/modules/__init__.py +3 -0
- aetheris/modules/expert.py +35 -0
- aetheris/modules/moe.py +83 -0
- aetheris/modules/ssm.py +91 -0
- aetheris/trainer/__init__.py +1 -0
- aetheris/trainer/trainer.py +145 -0
- aetheris/utils.py +39 -0
- configs/debug.yaml +16 -0
- configs/default.yaml +16 -0
- configs/inference.yaml +16 -0
- configs/large.yaml +16 -0
- requirements.txt +12 -0
- scripts/generate.py +16 -0
- scripts/info.py +11 -0
- scripts/train.py +6 -0
- scripts/validate.py +87 -0
- tests/test_api.py +88 -0
- tests/test_inference.py +71 -0
- tests/test_model.py +55 -0
- tests/test_overflow.py +67 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
checkpoints/checkpoint_current.pth filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,28 @@
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
env/
venv/
.env
.venv
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
checkpoints/
*.log
.DS_Store
legacy/
Dockerfile
ADDED
@@ -0,0 +1,38 @@
FROM python:3.10-slim

# Set environment variables
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1

# Install system dependencies
RUN apt-get update && apt-get install -y \
    git \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Set working directory
WORKDIR /app

# Create a user first to handle permissions correctly from the start
RUN useradd -m -u 1000 user

# Switch to user
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

# Set up application directory with correct permissions
WORKDIR $HOME/app

# Copy requirements and install
COPY --chown=user requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu

# Copy application code
COPY --chown=user . .

# Expose port
EXPOSE 7860

# Command to run the application
CMD ["python3", "-m", "aetheris.cli.main", "serve", "--host", "0.0.0.0", "--port", "7860"]
Dockerfile-nvidia
ADDED
@@ -0,0 +1,41 @@
# Use NVIDIA CUDA base image for GPU support
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04

# Set environment variables
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    DEBIAN_FRONTEND=noninteractive

# Install system dependencies
RUN apt-get update && apt-get install -y \
    python3-pip \
    python3-dev \
    git \
    && rm -rf /var/lib/apt/lists/*

# Set working directory
WORKDIR /app

# Install Python dependencies
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Expose port (7860 is the default for Hugging Face Spaces)
EXPOSE 7860

# Create a non-root user (good practice, and sometimes required by HF)
# Hugging Face Spaces typically run as user 1000.
RUN useradd -m -u 1000 user
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

WORKDIR $HOME/app
COPY --chown=user . $HOME/app

# Command to run the application
# We use the CLI serve command added in this commit
CMD ["python3", "-m", "aetheris.cli.main", "serve", "--host", "0.0.0.0", "--port", "7860"]
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Pomilon

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
ADDED
@@ -0,0 +1,146 @@
# Aetheris: Hybrid Mamba-MoE Experiment

<p align="center">
  <img src="https://img.shields.io/badge/Status-Experimental-yellow.svg" alt="Status">
  <img src="https://img.shields.io/badge/License-MIT-green.svg" alt="License">
  <img src="https://img.shields.io/badge/Python-3.10+-blue.svg" alt="Python">
  <img src="https://img.shields.io/badge/PyTorch-2.0+-orange.svg" alt="PyTorch">
  <img src="https://img.shields.io/badge/API-FastAPI-009688.svg" alt="FastAPI">
</p>

**Aetheris** is a hobbyist research project and experimental implementation exploring the intersection of **State Space Models (Mamba)** and **Mixture of Experts (MoE)**.

The goal of this project was to learn by doing: attempting to combine the linear-time inference of Mamba with the sparse scaling capacity of MoE from scratch in PyTorch. It is designed as a playground for understanding these modern architectures, not as a published academic paper or production-ready foundation model.

## 🧪 The Experiment

Current LLM architectures are evolving rapidly. I built Aetheris to investigate a specific question:
> *Can we successfully interleave Mamba blocks (for long context) with sparse MoE layers (for capacity) to train an efficient model on consumer hardware?*

This project implements a hybrid architecture that attempts to:
1. **Replace Attention:** Use Mamba (SSM) blocks to achieve $O(N)$ sequence scaling.
2. **Scale Parameters Sparsely:** Use MoE layers to increase model size without exploding the computational cost per token.
3. **Run Locally:** Optimize the implementation for single-GPU training (gradient checkpointing, efficient routing).

## 🏗️ Architecture Implementation

Aetheris alternates between custom implementations of two core modules (a minimal sketch of the resulting layer stack follows below):

* **SSMBlock (The Backbone):** Implements the selective scan mechanism described in the [Mamba paper](https://arxiv.org/abs/2312.00752). This handles the sequence mixing and "memory" of the model.
* **SparseMoELayer (The Scaling):** A router-based layer that dispatches tokens to Top-K experts (Feed-Forward Networks). This allows the model to "specialize" parts of its parameters for different types of tokens.

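For orientation, here is a minimal sketch of how the two modules are interleaved into one stack. It mirrors `aetheris/model.py` in this commit (even-indexed layers are SSM blocks, odd-indexed layers are sparse MoE layers); `build_hybrid_stack` is an illustrative helper, not part of the package.

```python
import torch.nn as nn

def build_hybrid_stack(n_layer, make_ssm, make_moe) -> nn.ModuleList:
    """Alternate sequence-mixing (SSM) layers with sparse-capacity (MoE) layers."""
    layers = nn.ModuleList()
    for i in range(n_layer):
        layers.append(make_ssm() if i % 2 == 0 else make_moe())
    return layers

# With the real modules this becomes roughly:
#   from aetheris.modules import SSMBlock, SparseMoELayer
#   layers = build_hybrid_stack(config.n_layer,
#                               lambda: SSMBlock(config),
#                               lambda: SparseMoELayer(config))
```
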
## 🚀 Quick Start

This code is provided for educational purposes and for others who want to experiment with hybrid architectures.

### Installation

**Option 1: Local Python Environment**

```bash
git clone https://github.com/Pomilon/Aetheris.git
cd Aetheris
pip install -r requirements.txt
```

**Option 2: Docker**

We provide Dockerfiles for both CPU (slim) and GPU (NVIDIA) environments.

```bash
# CPU Version
docker build -t aetheris-cpu -f Dockerfile .
docker run -p 7860:7860 aetheris-cpu

# GPU Version (Requires NVIDIA Container Toolkit)
docker build -t aetheris-gpu -f Dockerfile-nvidia .
docker run --gpus all -p 7860:7860 aetheris-gpu
```

### Usage (CLI)

Aetheris includes a CLI to train the model, generate text, and serve an API.

**1. Training (From Scratch)**

```bash
# Trains a small model defined in configs/default.yaml
python -m aetheris.cli.main train --config configs/default.yaml
```

**2. Generation (CLI)**

```bash
python -m aetheris.cli.main generate --prompt "The quick brown fox" --checkpoint_dir checkpoints
```

**3. API Server (OpenAI-Compatible)**

Start a local API server that simulates OpenAI's chat completions endpoint.

```bash
python -m aetheris.cli.main serve --host 0.0.0.0 --port 8000
```

You can then interact with it using standard tools:

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "aetheris-hybrid",
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": true
  }'
```

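Because the endpoint mimics OpenAI's chat completions API, the official `openai` Python client (v1+) can also be pointed at it. A sketch, assuming the server from the previous step is running; the API key is a dummy value since the server does not check it:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="aetheris-hybrid",
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=64,
)
print(response.choices[0].message.content)
```
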
### Development & Testing

To run the test suite:

```bash
pytest tests/
```

## ⚙️ Configuration

You can tweak the hyperparameters in `configs/`. I've included a "Debug" config that is small enough to train on a laptop CPU for testing the code flow.

| Config File | Description |
| :--- | :--- |
| `configs/default.yaml` | Standard experimental setup (requires GPU). |
| `configs/debug.yaml` | Tiny model (2 layers) for code debugging. |

## 📚 Acknowledgements & References

This project is an implementation study and relies heavily on the brilliant theoretical work of others. It is not an original invention of the Mamba or MoE concepts.

* **Mamba Architecture:** Gu, A., & Dao, T. (2023). *Mamba: Linear-Time Sequence Modeling with Selective State Spaces*. [arXiv:2312.00752](https://arxiv.org/abs/2312.00752)
* **Mixture of Experts:** Shazeer, N., et al. (2017). *Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer*. [arXiv:1701.06538](https://arxiv.org/abs/1701.06538)
* **Inspiration:** Jamba (AI21 Labs) and OpenMoE.

## 🧠 Model Weights & Checkpoints

All pre-trained checkpoints are hosted on the [Hugging Face Hub](https://huggingface.co/Pomilon).

| Model Artifact | Step | Description | Download |
| :--- | :--- | :--- | :--- |
| **Aetheris-Base** | 10k | Early convergence checkpoint (Loss ~3.66). Good for analyzing router behavior. | [🤗 Hugging Face](https://huggingface.co/Pomilon/Aetheris-MoE-300M-A125M-base) |
| **Aetheris-Chat** | -- | *Coming Soon (Post-SFT)* | -- |

> **⚠️ Important:** Aetheris uses a custom Hybrid Mamba-MoE architecture. You **cannot** load it directly with `transformers.AutoModel`. You must use the interface provided in this repository.

### 🐍 How to Load

```bash
python -m aetheris.cli.main generate --prompt "The quick brown fox" --checkpoint_dir path/to/checkpoints_folder  # rename the checkpoint inside to checkpoint_current.pth
```
> **Note:** Better inference support will be added later down the line; for now, use this admittedly scuffed CLI path (a Python-level sketch follows below). :D

> **Note:** These weights are from an experimental run. While they demonstrate the architectural capabilities, do not expect GPT-5 or even Google Bard levels of coherence. :D
> This project was made for learning and fun!

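For programmatic loading, a minimal sketch using the `InferenceEngine` from `aetheris/inference.py` in this repo (the paths below are placeholders; the checkpoint inside the directory must be named `checkpoint_current.pth`):

```python
from aetheris.inference import InferenceEngine

engine = InferenceEngine(
    config_path="configs/default.yaml",
    checkpoint_dir="path/to/checkpoints_folder",
    checkpoint_name="checkpoint_current.pth",
)

text = engine.generate_full("The quick brown fox", max_new_tokens=50, temperature=0.8)
print(text)
```
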
## License

MIT
aetheris/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .model import HybridMambaMoE
from .config import AetherisConfig
aetheris/api/schemas.py
ADDED
@@ -0,0 +1,92 @@
from typing import List, Optional, Union, Dict, Any
from pydantic import BaseModel, Field
import time

class ChatMessage(BaseModel):
    role: str
    content: str

class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None

class ChatCompletionChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Optional[str] = None

class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionChoice]
    usage: Optional[Dict[str, int]] = None

class ChatCompletionChunkDelta(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None

class ChatCompletionChunkChoice(BaseModel):
    index: int
    delta: ChatCompletionChunkDelta
    finish_reason: Optional[str] = None

class ChatCompletionChunk(BaseModel):
    id: str
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionChunkChoice]

class CompletionRequest(BaseModel):
    model: str
    prompt: Union[str, List[str]]
    suffix: Optional[str] = None
    max_tokens: Optional[int] = 16
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    best_of: Optional[int] = 1
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None

class CompletionChoice(BaseModel):
    text: str
    index: int
    logprobs: Optional[Any] = None
    finish_reason: Optional[str] = None

class CompletionResponse(BaseModel):
    id: str
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionChoice]
    usage: Optional[Dict[str, int]] = None

class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "aetheris"

class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard]
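As an illustration of how these schemas are meant to be used, a small sketch validating a hypothetical request payload with pydantic (not part of the file above):

```python
from aetheris.api.schemas import ChatCompletionRequest

payload = {
    "model": "aetheris-hybrid",
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 32,
}
req = ChatCompletionRequest(**payload)  # pydantic coerces the message dicts into ChatMessage objects
print(req.messages[0].content)          # "Hello!"
print(req.temperature)                  # 1.0 (default)
```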
aetheris/api/server.py
ADDED
@@ -0,0 +1,162 @@
import time
import uuid
import json
import asyncio
from typing import AsyncGenerator, Optional
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse
from aetheris.api.schemas import (
    ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk,
    ChatCompletionChoice, ChatMessage, ChatCompletionChunkChoice, ChatCompletionChunkDelta,
    CompletionRequest, CompletionResponse, CompletionChoice,
    ModelList, ModelCard
)
from aetheris.inference import InferenceEngine

app = FastAPI(title="Aetheris API", version="0.1.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global engine instance
engine: Optional[InferenceEngine] = None

def get_engine():
    global engine
    if engine is None:
        # Defaults, ideally loaded from config/env
        engine = InferenceEngine()
    return engine

@app.on_event("startup")
async def startup_event():
    get_engine()

@app.get("/v1/models", response_model=ModelList)
async def list_models():
    return ModelList(data=[ModelCard(id="aetheris-hybrid-mamba-moe")])

@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    engine = get_engine()

    # Simple prompt construction from messages
    prompt = ""
    for msg in request.messages:
        prompt += f"{msg.role}: {msg.content}\n"
    prompt += "assistant: "

    request_id = f"chatcmpl-{uuid.uuid4()}"
    created_time = int(time.time())

    if request.stream:
        async def event_generator():
            yield json.dumps(ChatCompletionChunk(
                id=request_id,
                created=created_time,
                model=request.model,
                choices=[ChatCompletionChunkChoice(
                    index=0,
                    delta=ChatCompletionChunkDelta(role="assistant"),
                    finish_reason=None
                )]
            ).model_dump())

            for token in engine.generate(
                prompt=prompt,
                max_new_tokens=request.max_tokens or 100,
                temperature=request.temperature,
                top_p=request.top_p,
                repetition_penalty=1.0 + request.frequency_penalty,  # Approximating
                stream=True
            ):
                yield json.dumps(ChatCompletionChunk(
                    id=request_id,
                    created=created_time,
                    model=request.model,
                    choices=[ChatCompletionChunkChoice(
                        index=0,
                        delta=ChatCompletionChunkDelta(content=token),
                        finish_reason=None
                    )]
                ).model_dump())

            yield json.dumps(ChatCompletionChunk(
                id=request_id,
                created=created_time,
                model=request.model,
                choices=[ChatCompletionChunkChoice(
                    index=0,
                    delta=ChatCompletionChunkDelta(),
                    finish_reason="stop"
                )]
            ).model_dump())

            yield "[DONE]"

        return EventSourceResponse(event_generator())

    else:
        generated_text = engine.generate_full(
            prompt=prompt,
            max_new_tokens=request.max_tokens or 100,
            temperature=request.temperature,
            top_p=request.top_p,
            repetition_penalty=1.0 + request.frequency_penalty
        )

        return ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=request.model,
            choices=[ChatCompletionChoice(
                index=0,
                message=ChatMessage(role="assistant", content=generated_text),
                finish_reason="stop"
            )],
            usage={"prompt_tokens": len(prompt), "completion_tokens": len(generated_text), "total_tokens": len(prompt) + len(generated_text)}  # Approximated (character counts, not tokens)
        )

@app.post("/v1/completions")
async def completions(request: CompletionRequest):
    engine = get_engine()

    prompt = request.prompt
    if isinstance(prompt, list):
        prompt = prompt[0]  # Handle a single prompt for now

    request_id = f"cmpl-{uuid.uuid4()}"
    created_time = int(time.time())

    if request.stream:
        # Streaming for completions is not fully implemented to match OpenAI exactly;
        # the logic would mirror the chat endpoint above.
        pass  # TODO: Implement streaming for completions

    generated_text = engine.generate_full(
        prompt=prompt,
        max_new_tokens=request.max_tokens or 16,
        temperature=request.temperature,
        top_p=request.top_p,
        repetition_penalty=1.0 + request.frequency_penalty
    )

    return CompletionResponse(
        id=request_id,
        created=created_time,
        model=request.model,
        choices=[CompletionChoice(
            text=generated_text,
            index=0,
            logprobs=None,
            finish_reason="length"  # or "stop"
        )],
        usage={"prompt_tokens": len(prompt), "completion_tokens": len(generated_text), "total_tokens": len(prompt) + len(generated_text)}
    )
aetheris/cli/__init__.py
ADDED
@@ -0,0 +1 @@
aetheris/cli/main.py
ADDED
@@ -0,0 +1,287 @@
import argparse
import sys
import torch
import os
import torch.nn.functional as F
from aetheris.config import AetherisConfig
from aetheris.model import HybridMambaMoE
from aetheris.data import create_streaming_loader, get_tokenizer
from aetheris.utils import load_latest_checkpoint, calculate_model_stats
from aetheris.trainer import Trainer

def train_command(args):
    print(f"\n{'='*70}")
    print(f"Aetheris Training")
    print(f"Config: {args.config}")

    if args.hf_token:
        print(f"Using Hugging Face token: {args.hf_token[:10]}...")
        from huggingface_hub import login
        login(token=args.hf_token)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.cuda.empty_cache()

    config = AetherisConfig.from_yaml(args.config)
    tokenizer = get_tokenizer()

    print(f"Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
    print(f"Model Size: d_model={config.d_model}, layers={config.n_layer}")
    print(f"{'='*70}\n")

    model = HybridMambaMoE(config).to(device)

    # Apply weight initialization
    print("Applying proper weight initialization...")
    model.apply(model._init_weights)

    # Calculate model stats
    stats = calculate_model_stats(model)
    print(f"Total Parameters: {stats['total_params']:,}")
    print(f"Trainable Parameters: {stats['trainable_params']:,}")

    # Use lower learning rate for stability
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01,
                                  betas=(0.9, 0.95), eps=1e-8, fused=False if device.type == 'cpu' else True)
    scaler = torch.amp.GradScaler('cuda' if device.type == 'cuda' else 'cpu', init_scale=2**10)

    start_step, current_stage = load_latest_checkpoint(model, optimizer, scaler, device, args.checkpoint_dir, args.checkpoint_name)

    trainer = Trainer(model, optimizer, scaler, config, device, args.checkpoint_dir)

    # --- STAGE 1: PRE-TRAINING ---
    if current_stage == "Pre-Training" or start_step == 0:
        pt_loader = create_streaming_loader("cerebras/SlimPajama-627B", "train",
                                            tokenizer, config, args.batch_size, mode="pretrain",
                                            hf_token=args.hf_token, start_step=start_step)

        # Validation loader (no skipping needed, always from start of val set)
        pt_val_loader = create_streaming_loader("cerebras/SlimPajama-627B", "validation",
                                                tokenizer, config, args.batch_size, mode="pretrain",
                                                hf_token=args.hf_token)

        start_step = trainer.train_epoch(pt_loader, total_steps=args.pretrain_steps,
                                         start_step=start_step, stage_name="Pre-Training",
                                         val_loader=pt_val_loader)
        current_stage = "SFT"
        start_step = 0

    # --- STAGE 2: SFT ---
    print("\n=== STAGE 2: SFT ===")
    for param_group in optimizer.param_groups:
        param_group['lr'] = 5e-5

    sft_loader = create_streaming_loader("OpenAssistant/oasst1", "train",
                                         tokenizer, config, args.batch_size, mode="sft",
                                         hf_token=args.hf_token, start_step=start_step)

    sft_val_loader = create_streaming_loader("OpenAssistant/oasst1", "validation",
                                             tokenizer, config, args.batch_size, mode="sft",
                                             hf_token=args.hf_token)

    trainer.train_epoch(sft_loader, total_steps=args.sft_steps,
                        start_step=start_step, stage_name="SFT",
                        val_loader=sft_val_loader)

    print("\nTraining Complete!")

@torch.no_grad()
def generate_command(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config = AetherisConfig.from_yaml(args.config)
    tokenizer = get_tokenizer()

    model = HybridMambaMoE(config).to(device).to(config.torch_dtype)

    load_latest_checkpoint(model, None, None, device, args.checkpoint_dir, args.checkpoint_name)
    model.eval()

    prompt = args.prompt
    max_new_tokens = args.max_new_tokens
    temperature = args.temperature
    top_k = args.top_k
    top_p = args.top_p
    repetition_penalty = args.repetition_penalty

    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    generated_ids = input_ids.clone()
    history_ids = set(input_ids[0].tolist())

    print("-" * 50)
    print(f"Prompt: {prompt}")
    print("Generated Continuation:")

    for _ in range(max_new_tokens):
        # Check if we should use autocast (skip if model uses float32)
        use_autocast = True
        if config.torch_dtype == torch.float32:
            use_autocast = False

        if use_autocast:
            with torch.amp.autocast('cuda' if device.type == 'cuda' else 'cpu', dtype=model.config.torch_dtype):
                outputs = model(generated_ids)
                logits = outputs['logits']
                next_token_logits = logits[:, -1, :]
        else:
            outputs = model(generated_ids)
            logits = outputs['logits']
            next_token_logits = logits[:, -1, :]

        # Repetition penalty
        for token_id in history_ids:
            if token_id < next_token_logits.size(-1):
                logit = next_token_logits[0, token_id].item()
                if logit > 0:
                    next_token_logits[0, token_id] = logit / repetition_penalty
                else:
                    next_token_logits[0, token_id] = logit * repetition_penalty

        # Temperature
        if temperature > 0:
            next_token_logits = next_token_logits / temperature

        # Top-p / Top-k
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            next_token_logits.scatter_(1, indices_to_remove.unsqueeze(0), float('-inf'))
        elif top_k > 0:
            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
            next_token_logits = torch.full_like(next_token_logits, float('-inf'))
            next_token_logits.scatter_(1, top_k_indices, top_k_logits)

        # Sample
        next_token_probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(next_token_probs, num_samples=1)
        next_token_item = next_token.item()

        if next_token_item == tokenizer.eos_token_id:
            break

        generated_ids = torch.cat([generated_ids, next_token], dim=-1)
        history_ids.add(next_token_item)

        new_token_text = tokenizer.decode(next_token.squeeze().tolist(), skip_special_tokens=True)
        print(new_token_text, end="", flush=True)

    print("\n" + "-" * 50)

def info_command(args):
    config = AetherisConfig.from_yaml(args.config)
    model = HybridMambaMoE(config)

    total_params = 0
    dense_params = 0   # Parameters active for EVERY token
    expert_params = 0  # Parameters in all MoE Experts

    for name, param in model.named_parameters():
        numel = param.numel()
        total_params += numel

        if 'experts' in name:
            expert_params += numel
        else:
            dense_params += numel

    single_expert_size = expert_params / config.num_experts if config.num_experts > 0 else 0
    active_per_token_params = dense_params + (single_expert_size * config.top_k)

    def format_count(count):
        return f"{count / 1_000_000:.2f}M"

    print("=" * 50)
    print("Hybrid Mamba-MoE Model Parameter Analysis")
    print("=" * 50)
    print(f"Total Model Layers (N_Layer): {config.n_layer}")
    print(f"MoE Experts per Layer: {config.num_experts}")
    print(f"Active Experts (Top-K): {config.top_k}")
    print("-" * 50)
    print(f"Total Parameters (Checkpoint Size): {format_count(total_params)}")
    print(f"Dense (Always Active) Parameters: {format_count(dense_params)}")
    print(f"Expert-Only Parameters: {format_count(expert_params)}")
    print("-" * 50)
    print(f"**Active Parameters (Per-Token Compute Load): {format_count(active_per_token_params)}**")
    print("  (This is the 'Dense' parameters + the K active expert parameters)")
    print("=" * 50)


def main():
    parser = argparse.ArgumentParser(description="Aetheris CLI")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Train Command
    train_parser = subparsers.add_parser("train", help="Train the model")
    train_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
    train_parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory to save checkpoints")
    train_parser.add_argument("--hf_token", type=str, default=os.environ.get("HF_TOKEN"), help="HuggingFace Token")
    train_parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
    train_parser.add_argument("--pretrain_steps", type=int, default=50000, help="Number of pretraining steps")
    train_parser.add_argument("--sft_steps", type=int, default=1000, help="Number of SFT steps")
    train_parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name to load from")

    # Generate Command
    gen_parser = subparsers.add_parser("generate", help="Generate text")
    gen_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
    gen_parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory with checkpoints")
    gen_parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name")
    gen_parser.add_argument("--prompt", type=str, default="The quick brown fox", help="Prompt for generation")
    gen_parser.add_argument("--max_new_tokens", type=int, default=100, help="Max new tokens to generate")
    gen_parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
    gen_parser.add_argument("--top_k", type=int, default=0, help="Top-k sampling")
    gen_parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
    gen_parser.add_argument("--repetition_penalty", type=float, default=3.0, help="Repetition penalty")

    # Serve Command
    serve_parser = subparsers.add_parser("serve", help="Start the API server")
    serve_parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind")
    serve_parser.add_argument("--port", type=int, default=8000, help="Port to bind")
    serve_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
    serve_parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory with checkpoints")
    serve_parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name")

    # Info Command
    info_parser = subparsers.add_parser("info", help="Show model info")
    info_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")

    args = parser.parse_args()

    if args.command == "train":
        train_command(args)
    elif args.command == "generate":
        generate_command(args)
    elif args.command == "serve":
        import uvicorn
        from aetheris.api.server import app

        # Initialize the global engine with the CLI arguments before starting the server,
        # instead of letting the server's get_engine() fall back to its defaults.
        from aetheris.inference import InferenceEngine
        import aetheris.api.server

        aetheris.api.server.engine = InferenceEngine(
            config_path=args.config,
            checkpoint_dir=args.checkpoint_dir,
            checkpoint_name=args.checkpoint_name
        )

        uvicorn.run(app, host=args.host, port=args.port)

    elif args.command == "info":
        info_command(args)
    else:
        parser.print_help()

if __name__ == "__main__":
    main()
aetheris/config.py
ADDED
@@ -0,0 +1,58 @@
from dataclasses import dataclass, field
import yaml
import torch
from typing import Optional

@dataclass
class AetherisConfig:
    # Model dimensions
    vocab_size: int = 50257
    d_model: int = 768
    n_layer: int = 24
    num_experts: int = 4
    top_k: int = 1
    d_ff: int = 2304  # d_model * 3

    # SSM parameters
    ssm_d_state: int = 16
    ssm_expand: int = 2
    d_inner: Optional[int] = None  # Will be d_model * ssm_expand if None

    # Training parameters
    load_balancing_coef: float = 1e-2
    router_z_loss_coef: float = 1e-3
    max_seq_len: int = 512
    dtype: str = "float16"  # "float16", "float32", "bfloat16"

    # Optimization settings
    use_cpu_offload: bool = False
    gradient_checkpointing: bool = True
    checkpoint_ssm_layers: bool = True
    use_flash_attention: bool = False

    def __post_init__(self):
        if self.d_inner is None:
            self.d_inner = self.d_model * self.ssm_expand
        if self.d_ff is None:
            self.d_ff = self.d_model * 3

    @property
    def torch_dtype(self):
        if self.dtype == "float16":
            return torch.float16
        elif self.dtype == "float32":
            return torch.float32
        elif self.dtype == "bfloat16":
            return torch.bfloat16
        else:
            raise ValueError(f"Unsupported dtype: {self.dtype}")

    @classmethod
    def from_yaml(cls, path: str):
        with open(path, 'r') as f:
            config_dict = yaml.safe_load(f)
        return cls(**config_dict)

    def to_yaml(self, path: str):
        with open(path, 'w') as f:
            yaml.dump(self.__dict__, f)
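A small usage sketch of this config class (the values below are illustrative, not the shipped `configs/*.yaml`):

```python
from aetheris.config import AetherisConfig

# Build a tiny config in code; d_inner is derived in __post_init__.
cfg = AetherisConfig(d_model=128, n_layer=2, num_experts=2, top_k=1, dtype="float32")
print(cfg.d_inner)      # 256 (d_model * ssm_expand)
print(cfg.torch_dtype)  # torch.float32

# Round-trip through YAML, as the CLI does with configs/*.yaml.
cfg.to_yaml("/tmp/aetheris_debug.yaml")
cfg2 = AetherisConfig.from_yaml("/tmp/aetheris_debug.yaml")
```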
aetheris/data.py
ADDED
@@ -0,0 +1,105 @@
import torch
from torch.utils.data import DataLoader, IterableDataset
from transformers import AutoTokenizer
from datasets import load_dataset
import random
from typing import Dict, Iterator
import os

def get_tokenizer(model_name: str = "gpt2"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

class StreamingDataset(IterableDataset):
    def __init__(self, dataset, tokenizer, max_seq_len, mode="pretrain", buffer_size=500, skip_samples=0):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.mode = mode
        self.buffer_size = buffer_size
        self.skip_samples = skip_samples

    def _prepare_sft_text(self, example):
        if 'messages' in example:
            text = ""
            for msg in example['messages']:
                role = msg.get('role', '')
                content = msg.get('content', '')
                if role == 'assistant':
                    text += f"Assistant: {content}{self.tokenizer.eos_token}"
                else:
                    text += f"User: {content}\n"
            return text
        elif 'text' in example:
            return example['text']
        else:
            return ""

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        iterator = iter(self.dataset)
        buffer = []

        # Samples skipped on resume are handled in the yield loop below.

        for example in iterator:
            text = (example.get('text', '') if self.mode == "pretrain"
                    else self._prepare_sft_text(example))

            if len(text) < 10:
                continue

            enc = self.tokenizer(text, truncation=True, max_length=self.max_seq_len,
                                 return_tensors="pt")
            input_ids = enc['input_ids'][0]

            if len(input_ids) < 2:
                continue

            pad_len = 0
            if len(input_ids) < self.max_seq_len:
                pad_len = self.max_seq_len - len(input_ids)
                input_ids = torch.cat([
                    input_ids,
                    torch.full((pad_len,), self.tokenizer.pad_token_id, dtype=torch.long)
                ])

            labels = input_ids.clone()
            if pad_len > 0:
                # Mask padding positions out of the loss (HF-style -100 labels)
                labels[-pad_len:] = -100

            buffer.append((input_ids, labels))

            if len(buffer) >= self.buffer_size:
                random.shuffle(buffer)
                for _ in range(self.buffer_size // 2):
                    item = buffer.pop()
                    if self.skip_samples > 0:
                        self.skip_samples -= 1
                        continue
                    yield item

        # Yield remaining
        random.shuffle(buffer)
        while buffer:
            item = buffer.pop()
            if self.skip_samples > 0:
                self.skip_samples -= 1
                continue
            yield item

def create_streaming_loader(dataset_name, split, tokenizer, config, batch_size, mode="pretrain", hf_token=None, start_step=0):
    raw_dataset = load_dataset(dataset_name, split=split, streaming=True,
                               trust_remote_code=True, token=hf_token)

    # Calculate samples to skip: start_step * batch_size
    skip_samples = start_step * batch_size
    if skip_samples > 0:
        print(f"  [Loader] Resuming: Fast-forwarding dataset by {skip_samples} samples...")

    stream_ds = StreamingDataset(raw_dataset, tokenizer, config.max_seq_len, mode=mode, skip_samples=skip_samples)

    # Increase num_workers for better utilization
    return DataLoader(stream_ds, batch_size=batch_size, pin_memory=True,
                      num_workers=4, prefetch_factor=4)
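A quick sketch of wiring the loader together (the dataset name matches the one used by the train command; batch size and config values are illustrative, and streaming access to the dataset is assumed):

```python
from aetheris.config import AetherisConfig
from aetheris.data import get_tokenizer, create_streaming_loader

config = AetherisConfig(max_seq_len=128)
tokenizer = get_tokenizer()

# Streams and tokenizes examples on the fly; each batch is a pair (input_ids, labels).
loader = create_streaming_loader("cerebras/SlimPajama-627B", "train",
                                 tokenizer, config, batch_size=2, mode="pretrain")
input_ids, labels = next(iter(loader))
print(input_ids.shape)  # torch.Size([2, 128])
```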
aetheris/inference.py
ADDED
@@ -0,0 +1,106 @@
import torch
import torch.nn.functional as F
from typing import Optional, List, Generator
from aetheris.config import AetherisConfig
from aetheris.model import HybridMambaMoE
from aetheris.data import get_tokenizer
from aetheris.utils import load_latest_checkpoint

class InferenceEngine:
    def __init__(self, config_path: str = "configs/default.yaml", checkpoint_dir: str = "checkpoints", checkpoint_name: str = "checkpoint_current.pth", device: str = None):
        self.device = torch.device(device if device else ('cuda' if torch.cuda.is_available() else 'cpu'))
        self.config = AetherisConfig.from_yaml(config_path)
        self.tokenizer = get_tokenizer()

        self.model = HybridMambaMoE(self.config).to(self.device).to(self.config.torch_dtype)

        # Load checkpoint
        # Note: load_latest_checkpoint expects optimizer and scaler, but for inference we can pass None
        load_latest_checkpoint(self.model, None, None, self.device, checkpoint_dir, checkpoint_name)
        self.model.eval()

    def generate(self,
                 prompt: str,
                 max_new_tokens: int = 100,
                 temperature: float = 0.8,
                 top_k: int = 0,
                 top_p: float = 0.9,
                 repetition_penalty: float = 1.0,
                 stream: bool = False) -> Generator[str, None, None] | str:

        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
        generated_ids = input_ids.clone()
        history_ids = set(input_ids[0].tolist())

        def token_generator():
            nonlocal generated_ids
            for _ in range(max_new_tokens):
                # Check if we should use autocast (skip if model uses float32)
                use_autocast = True
                if self.config.torch_dtype == torch.float32:
                    use_autocast = False

                if use_autocast:
                    with torch.amp.autocast('cuda' if self.device.type == 'cuda' else 'cpu', dtype=self.model.config.torch_dtype):
                        outputs = self.model(generated_ids)
                        logits = outputs['logits']
                        next_token_logits = logits[:, -1, :]
                else:
                    outputs = self.model(generated_ids)
                    logits = outputs['logits']
                    next_token_logits = logits[:, -1, :]

                # Repetition penalty
                for token_id in history_ids:
                    if token_id < next_token_logits.size(-1):
                        logit = next_token_logits[0, token_id].item()
                        if logit > 0:
                            next_token_logits[0, token_id] = logit / repetition_penalty
                        else:
                            next_token_logits[0, token_id] = logit * repetition_penalty

                # Temperature
                if temperature > 0:
                    next_token_logits = next_token_logits / temperature

                # Top-p / Top-k
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = False
                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
                    next_token_logits.scatter_(1, indices_to_remove.unsqueeze(0), float('-inf'))
                elif top_k > 0:
                    top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
                    next_token_logits = torch.full_like(next_token_logits, float('-inf'))
                    next_token_logits.scatter_(1, top_k_indices, top_k_logits)

                # Sample
                next_token_probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(next_token_probs, num_samples=1)
                next_token_item = next_token.item()

                if next_token_item == self.tokenizer.eos_token_id:
                    break

                generated_ids = torch.cat([generated_ids, next_token], dim=-1)
                history_ids.add(next_token_item)

                new_token_text = self.tokenizer.decode(next_token.squeeze().tolist(), skip_special_tokens=True)
                yield new_token_text

        if stream:
            return token_generator()
        else:
            return "".join(list(token_generator()))

    def generate_full(self,
                      prompt: str,
                      max_new_tokens: int = 100,
                      temperature: float = 0.8,
                      top_k: int = 0,
                      top_p: float = 0.9,
                      repetition_penalty: float = 1.0) -> str:
        return self.generate(prompt, max_new_tokens, temperature, top_k, top_p, repetition_penalty, stream=False)
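A minimal streaming usage sketch of this engine (the config and checkpoint paths are placeholders; a trained checkpoint is assumed to exist):

```python
from aetheris.inference import InferenceEngine

engine = InferenceEngine(config_path="configs/default.yaml", checkpoint_dir="checkpoints")

# stream=True returns a generator of decoded token strings.
for token in engine.generate("Once upon a time", max_new_tokens=30, stream=True):
    print(token, end="", flush=True)
print()
```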
aetheris/model.py
ADDED
@@ -0,0 +1,86 @@
import torch
import torch.nn as nn
import torch.utils.checkpoint  # explicit import so checkpointing does not rely on side-effect imports
from typing import Dict, Any, List
from .config import AetherisConfig
from .modules import SSMBlock, SparseMoELayer

class HybridMambaMoE(nn.Module):
    def __init__(self, config: AetherisConfig):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)

        self.layers = nn.ModuleList()
        for i in range(config.n_layer):
            if i % 2 == 0:
                self.layers.append(SSMBlock(config))
            else:
                self.layers.append(SparseMoELayer(config))

        self.final_norm = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Weight tying

        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
        self.gradient_checkpointing = config.gradient_checkpointing

        # Initialize embeddings with smaller scale
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)

    def _init_weights(self, module):
        """Apply proper weight initialization."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=0.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None) -> Dict[str, Any]:
        x = self.embedding(input_ids)
        total_aux_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)

        for i, layer in enumerate(self.layers):
            if self.gradient_checkpointing and self.training:
                # Checkpoint ALL layers for maximum memory savings
                if isinstance(layer, SparseMoELayer):
                    def moe_forward(module, inp):
                        return module(inp)
                    x, aux_loss = torch.utils.checkpoint.checkpoint(
                        moe_forward, layer, x, use_reentrant=False
                    )
                    total_aux_loss = total_aux_loss + aux_loss
                else:
                    x = torch.utils.checkpoint.checkpoint(
                        layer, x, use_reentrant=False
                    )
            else:
                if isinstance(layer, SparseMoELayer):
                    x, aux_loss = layer(x)
                    total_aux_loss = total_aux_loss + aux_loss
                else:
                    x = layer(x)

        x = self.final_norm(x)
        logits = self.lm_head(x)

        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            ce_loss = self.loss_fn(shift_logits.view(-1, self.config.vocab_size),
                                   shift_labels.view(-1))

            # Scale down aux loss to prevent it from dominating
            total_loss = ce_loss + 0.01 * total_aux_loss

            return {
                "loss": total_loss,
                "ce_loss": ce_loss,
                "aux_loss": total_aux_loss,
                "logits": logits
            }

        return {"logits": logits}
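Note: a minimal usage sketch of the model above, mirroring tests/test_model.py; the config values are illustrative, not recommended settings.

import torch
from aetheris.config import AetherisConfig
from aetheris.model import HybridMambaMoE

config = AetherisConfig(vocab_size=100, d_model=32, n_layer=4, num_experts=2,
                        top_k=1, d_ff=64, ssm_d_state=8, ssm_expand=2, max_seq_len=64)
model = HybridMambaMoE(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))
out = model(input_ids, labels=input_ids.clone())
print(out["loss"], out["logits"].shape)  # scalar total loss, logits of shape (2, 16, 100)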
aetheris/modules/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .expert import Expert
from .ssm import SSMBlock, selective_scan_native
from .moe import SparseMoELayer
aetheris/modules/expert.py
ADDED
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class Expert(nn.Module):
    """Memory-efficient Feed-Forward Network expert with proper initialization."""
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.act = nn.GELU()

        # Proper initialization to prevent NaN
        nn.init.xavier_uniform_(self.w1.weight, gain=0.5)
        nn.init.xavier_uniform_(self.w2.weight, gain=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        # Force float32 for internal computation to prevent overflow in half precision
        x = x.to(torch.float32)

        # Cast weights to float32 for calculation
        # This is necessary because the module weights might be float16
        w1_weight = self.w1.weight.to(torch.float32)
        w2_weight = self.w2.weight.to(torch.float32)

        h = F.linear(x, w1_weight)
        h = self.act(h)
        out = F.linear(h, w2_weight)

        # Clamp to avoid Inf when casting back to float16
        if orig_dtype == torch.float16:
            out = torch.clamp(out, min=-65500.0, max=65500.0)

        return out.to(orig_dtype)
aetheris/modules/moe.py
ADDED
@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..config import AetherisConfig
from .expert import Expert

class SparseMoELayer(nn.Module):
    """Memory-optimized Sparse MoE with efficient routing."""
    def __init__(self, config: AetherisConfig):
        super().__init__()
        self.d_model = config.d_model
        self.num_experts = config.num_experts
        self.top_k = config.top_k
        self.load_balancing_coef = config.load_balancing_coef
        self.z_loss_coef = config.router_z_loss_coef

        self.gate = nn.Linear(config.d_model, config.num_experts, bias=False)
        self.experts = nn.ModuleList([Expert(config.d_model, config.d_ff)
                                      for _ in range(config.num_experts)])
        self.norm = nn.LayerNorm(config.d_model)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        B, L, D = x.shape
        x_norm = self.norm(x)
        flat_x = x_norm.view(-1, D)

        # Routing logits with stability
        gate_logits = self.gate(flat_x)

        # Clamp logits to prevent overflow
        gate_logits = torch.clamp(gate_logits, min=-10.0, max=10.0)

        # Z-loss for stability
        z_loss = torch.mean(torch.logsumexp(gate_logits, dim=-1)**2) * self.z_loss_coef

        if self.training:
            # Reduce noise for stability
            gate_logits = gate_logits + torch.randn_like(gate_logits) * 1e-3

        gate_probs = F.softmax(gate_logits, dim=-1)
        gate_weights, expert_indices = torch.topk(gate_probs, self.top_k, dim=-1)

        # Normalize weights for stability
        gate_weights = gate_weights / (gate_weights.sum(dim=-1, keepdim=True) + 1e-8)

        # Load balancing loss
        # Use only the top-1 expert for load balancing calculation to keep it simple and consistent
        expert_mask = F.one_hot(expert_indices[:, 0], num_classes=self.num_experts).float()
        fraction_routed = expert_mask.mean(dim=0)
        mean_prob = gate_probs.mean(dim=0)

        aux_loss = (self.num_experts * torch.sum(fraction_routed * mean_prob)) * self.load_balancing_coef
        total_aux_loss = aux_loss + z_loss

        # Efficient dispatch with in-place operations
        # Accumulate in float32 to prevent overflow during aggregation
        final_output = torch.zeros_like(flat_x, dtype=torch.float32)

        # Iterate over all k selected experts
        for k_idx in range(self.top_k):
            for i, expert in enumerate(self.experts):
                # Find tokens routed to expert 'i' at the k-th position
                mask = (expert_indices[:, k_idx] == i)
                if not mask.any():
                    continue

                expert_input = flat_x[mask]
                expert_out = expert(expert_input)

                # Apply weights
                weights = gate_weights[mask, k_idx].unsqueeze(1)

                # Cast to float32 for accumulation
                expert_out = expert_out.to(torch.float32)
                weights = weights.to(torch.float32)

                # Accumulate output (add to existing results from other experts)
                final_output[mask] += expert_out * weights

        # Cast back to original dtype
        final_output = final_output.to(flat_x.dtype)

        return x + final_output.view(B, L, D), total_aux_loss
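Note: a shape-level sketch of how this layer is consumed (as model.py does above); it returns the residual hidden states plus a scalar auxiliary routing loss. Config values are illustrative only.

import torch
from aetheris.config import AetherisConfig
from aetheris.modules.moe import SparseMoELayer

config = AetherisConfig(vocab_size=100, d_model=32, n_layer=2, num_experts=2,
                        top_k=1, d_ff=64, ssm_d_state=8, ssm_expand=2, max_seq_len=64)
layer = SparseMoELayer(config)

x = torch.randn(2, 16, config.d_model)   # (batch, seq_len, d_model)
out, aux_loss = layer(x)                  # residual output and load-balancing + z-loss term
assert out.shape == x.shape and aux_loss.dim() == 0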
aetheris/modules/ssm.py
ADDED
@@ -0,0 +1,91 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..config import AetherisConfig

def selective_scan_native(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
                          B: torch.Tensor, C: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
    """Memory-efficient scan with reduced intermediate tensors."""
    B_size, L, D_inner = u.shape
    D_state = A.shape[-1]

    # Use in-place operations where possible
    h = torch.zeros(B_size, D_inner, D_state, device=u.device, dtype=u.dtype)
    ys = []

    for l in range(L):
        dt = delta[:, l, :].unsqueeze(-1)
        dA = torch.exp(dt * A)

        B_l = B[:, l, :].unsqueeze(1)
        dB = dt * B_l

        u_t = u[:, l, :].unsqueeze(-1)
        h = dA * h + dB * u_t

        C_l = C[:, l, :].unsqueeze(1)
        y_t = torch.sum(h * C_l, dim=-1)
        ys.append(y_t)

    y = torch.stack(ys, dim=1)
    return y + u * D

class SSMBlock(nn.Module):
    """Memory-optimized State Space Model with stability improvements."""
    def __init__(self, config: AetherisConfig):
        super().__init__()
        self.d_model = config.d_model
        self.d_state = config.ssm_d_state
        self.d_inner = config.d_inner

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=False)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=False)
        self.conv_d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=3,
                                padding=2, groups=self.d_inner, bias=False)
        self.gate_proj = nn.Linear(self.d_model, self.d_inner, bias=False)

        self.B_proj = nn.Linear(self.d_inner, self.d_state, bias=False)
        self.C_proj = nn.Linear(self.d_inner, self.d_state, bias=False)
        self.delta_proj = nn.Linear(self.d_inner, self.d_inner, bias=False)

        # Initialize A to be more stable (closer to -1)
        self.A_log = nn.Parameter(torch.randn(self.d_inner, self.d_state) * 0.1 - 4.0)
        self.D = nn.Parameter(torch.ones(self.d_inner) * 0.1)

        self.act = nn.SiLU()
        self.norm = nn.LayerNorm(config.d_model)

        # Proper initialization
        nn.init.xavier_uniform_(self.in_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(self.out_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(self.gate_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(self.B_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(self.C_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(self.delta_proj.weight, gain=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, L, D = x.shape
        x_norm = self.norm(x)

        xz = self.in_proj(x_norm)
        x_in, z_gate = xz.chunk(2, dim=-1)
        x_conv = self.conv_d(x_in.transpose(1, 2))
        # Slice off the last 2 elements (the "future" leakage)
        x_conv = x_conv[:, :, :-2].transpose(1, 2)
        x_conv = self.act(x_conv)

        # Add small epsilon to prevent numerical issues and clamp max value
        delta = torch.clamp(F.softplus(self.delta_proj(x_conv)), max=5.0) + 1e-4
        B_ssm = self.B_proj(x_conv)
        C_ssm = self.C_proj(x_conv)

        # Clamp A to prevent extreme values
        A_fixed = -torch.exp(torch.clamp(self.A_log, min=-10.0, max=2.0))
        A_batched = A_fixed.unsqueeze(0).expand(B, -1, -1)

        y_ssm = selective_scan_native(x_conv, delta, A_batched, B_ssm, C_ssm, self.D)

        y_gate = F.silu(self.gate_proj(x_norm)) * y_ssm
        output = self.out_proj(y_gate)

        return x + output
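Note: a shape-level sketch of calling selective_scan_native directly, matching how SSMBlock.forward prepares its arguments; the sizes and random inputs are illustrative only.

import torch
from aetheris.modules.ssm import selective_scan_native

B, L, d_inner, d_state = 2, 8, 16, 4
u     = torch.randn(B, L, d_inner)         # post-conv activations
delta = torch.rand(B, L, d_inner) + 1e-4   # positive step sizes, as softplus would produce
A     = -torch.rand(B, d_inner, d_state)   # negative dynamics, as -exp(A_log) produces
B_ssm = torch.randn(B, L, d_state)
C_ssm = torch.randn(B, L, d_state)
D     = torch.ones(d_inner) * 0.1

y = selective_scan_native(u, delta, A, B_ssm, C_ssm, D)
assert y.shape == (B, L, d_inner)          # one output per timestep plus the skip term u * D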
aetheris/trainer/__init__.py
ADDED
@@ -0,0 +1 @@
from .trainer import Trainer
aetheris/trainer/trainer.py
ADDED
@@ -0,0 +1,145 @@
import torch
import time
import os
from aetheris.utils import save_checkpoint, load_latest_checkpoint, calculate_model_stats

class Trainer:
    def __init__(self, model, optimizer, scaler, config, device, checkpoint_dir, logger=None):
        self.model = model
        self.optimizer = optimizer
        self.scaler = scaler
        self.config = config
        self.device = device
        self.checkpoint_dir = checkpoint_dir
        self.logger = logger

        self.model.to(self.device)

    def validate(self, val_loader, global_step):
        self.model.eval()
        total_loss = 0
        total_items = 0
        num_batches = 100  # Validate on 100 batches to save time

        print(f"\n[Validation] Starting validation at step {global_step}...")

        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                if i >= num_batches:
                    break

                input_ids, labels = batch
                input_ids = input_ids.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)

                # Auto-cast context
                if self.device.type == 'cuda':
                    autocast_dtype = torch.float16
                else:
                    autocast_dtype = torch.bfloat16

                use_autocast = True if self.config.torch_dtype != torch.float32 else False

                if use_autocast:
                    with torch.amp.autocast('cuda' if self.device.type == 'cuda' else 'cpu', dtype=autocast_dtype):
                        output = self.model(input_ids, labels)
                else:
                    output = self.model(input_ids, labels)

                total_loss += output["loss"].item()
                total_items += 1

        avg_loss = total_loss / total_items if total_items > 0 else 0
        perplexity = torch.exp(torch.tensor(avg_loss)).item()

        print(f"[Validation] Step {global_step} | Loss: {avg_loss:.4f} | PPL: {perplexity:.4f}")
        self.model.train()
        return avg_loss

    def train_epoch(self, train_loader, total_steps, start_step=0, stage_name="Training", val_loader=None, eval_every=500):
        print(f"\n{'='*70}\nStarting {stage_name}: Target Steps={total_steps}\n{'='*70}")
        self.model.train()
        global_step = start_step
        running_loss = 0

        print("Initializing data iterator...")
        train_iter = iter(train_loader)

        print("Fetching first batch...")

        while global_step < total_steps:
            step_start = time.time()

            # Removed periodic cache clearing for performance

            self.optimizer.zero_grad(set_to_none=True)

            try:
                batch = next(train_iter)
                if global_step == start_step:
                    print(f"✓ First batch loaded! Starting training loop...")
            except StopIteration:
                train_iter = iter(train_loader)
                batch = next(train_iter)

            input_ids, labels = batch
            input_ids = input_ids.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)

            # Determine autocast dtype
            if self.device.type == 'cuda':
                autocast_dtype = torch.float16
            else:
                autocast_dtype = torch.bfloat16

            # Check if we should use autocast (skip if model uses float32)
            use_autocast = True
            if self.config.torch_dtype == torch.float32:
                use_autocast = False

            if use_autocast:
                with torch.amp.autocast('cuda' if self.device.type == 'cuda' else 'cpu', dtype=autocast_dtype):
                    output = self.model(input_ids, labels)
                    loss = output["loss"]
            else:
                output = self.model(input_ids, labels)
                loss = output["loss"]

            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)

            # Gradient clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)

            if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                print(f"WARNING: NaN/Inf gradient at step {global_step}, skipping update")
            else:
                self.scaler.step(self.optimizer)

            self.scaler.update()

            global_step += 1
            running_loss += loss.item()

            if global_step % 10 == 0:
                avg_loss = running_loss / 10
                t_diff = time.time() - step_start
                if self.device.type == 'cuda':
                    mem = torch.cuda.memory_allocated() / 1e9
                    max_mem = torch.cuda.max_memory_allocated() / 1e9
                    mem_str = f"VRAM: {mem:.1f}GB (peak: {max_mem:.1f}GB)"
                else:
                    mem_str = "CPU Mode"

                tokens_per_sec = (self.config.max_seq_len * input_ids.size(0)) / t_diff
                print(f" Step {global_step}/{total_steps} | Loss: {avg_loss:.4f} | "
                      f"{mem_str} | {tokens_per_sec:.0f} tok/s")
                running_loss = 0

            if global_step % 500 == 0:
                save_checkpoint(self.model, self.optimizer, self.scaler, global_step, stage_name, self.checkpoint_dir)

            if val_loader is not None and global_step % eval_every == 0 and global_step > start_step:
                self.validate(val_loader, global_step)

        return global_step
aetheris/utils.py
ADDED
@@ -0,0 +1,39 @@
import os
import torch
from typing import Tuple

def save_checkpoint(model, optimizer, scaler, step, stage, checkpoint_dir, checkpoint_name="checkpoint_current.pth"):
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, checkpoint_name)
    torch.save({
        'step': step,
        'stage': stage,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict()
    }, path)
    print(f" [Checkpoint] Saved at step {step}")

def load_latest_checkpoint(model, optimizer, scaler, device, checkpoint_dir, checkpoint_name="checkpoint_current.pth") -> Tuple[int, str]:
    path = os.path.join(checkpoint_dir, checkpoint_name)
    if not os.path.exists(path):
        return 0, "Pre-Training"

    print(f" [Checkpoint] Loading from {path}...")
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    if optimizer:
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    if scaler:
        scaler.load_state_dict(ckpt['scaler_state_dict'])
    return ckpt['step'], ckpt['stage']

def calculate_model_stats(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {
        'total_params': total_params,
        'trainable_params': trainable_params,
        'active_params': int(total_params * 0.6),  # Approximation
        'sparsity_ratio': 0.6  # Approximation
    }
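Note: a minimal sketch of the checkpoint round trip these helpers provide; the optimizer, learning rate, and GradScaler here are illustrative assumptions, and optimizer/scaler may be None when only weights are needed (as scripts/validate.py does).

import torch
from aetheris.config import AetherisConfig
from aetheris.model import HybridMambaMoE
from aetheris.utils import save_checkpoint, load_latest_checkpoint

config = AetherisConfig(vocab_size=100, d_model=32, n_layer=2, num_experts=2,
                        top_k=1, d_ff=64, ssm_d_state=8, ssm_expand=2, max_seq_len=64)
model = HybridMambaMoE(config)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)          # hypothetical optimizer settings
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

save_checkpoint(model, optimizer, scaler, step=100, stage="Pre-Training", checkpoint_dir="checkpoints")
step, stage = load_latest_checkpoint(model, None, None, torch.device('cpu'), "checkpoints")
print(step, stage)  # 100 Pre-Training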
configs/debug.yaml
ADDED
@@ -0,0 +1,16 @@
vocab_size: 50257
d_model: 128
n_layer: 4
num_experts: 4
top_k: 1
d_ff: 384
ssm_d_state: 8
ssm_expand: 2
load_balancing_coef: 0.01
router_z_loss_coef: 0.001
max_seq_len: 128
dtype: "float32"  # Use float32 for debugging on CPU
use_cpu_offload: false
gradient_checkpointing: false
checkpoint_ssm_layers: false
use_flash_attention: false
configs/default.yaml
ADDED
@@ -0,0 +1,16 @@
vocab_size: 50257
d_model: 768
n_layer: 24
num_experts: 4
top_k: 1
d_ff: 2304
ssm_d_state: 16
ssm_expand: 2
load_balancing_coef: 0.01
router_z_loss_coef: 0.001
max_seq_len: 512
dtype: "float16"
use_cpu_offload: false
gradient_checkpointing: true
checkpoint_ssm_layers: true
use_flash_attention: false
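Note: the YAML fields above map onto AetherisConfig; a minimal sketch of loading and inspecting them via the from_yaml helper that scripts/validate.py relies on. The printed attribute names follow the model and trainer code above.

from aetheris.config import AetherisConfig

config = AetherisConfig.from_yaml("configs/default.yaml")
print(config.d_model, config.n_layer, config.num_experts)  # 768 24 4
print(config.torch_dtype)                                  # torch dtype resolved from the "dtype" string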
configs/inference.yaml
ADDED
@@ -0,0 +1,16 @@
vocab_size: 50257
d_model: 768
n_layer: 24
num_experts: 4
top_k: 1
d_ff: 2304
ssm_d_state: 16
ssm_expand: 2
load_balancing_coef: 0.0
router_z_loss_coef: 0.0
max_seq_len: 1024
dtype: "float16"
use_cpu_offload: true  # Offload to CPU during inference to save VRAM
gradient_checkpointing: false
checkpoint_ssm_layers: false
use_flash_attention: true
configs/large.yaml
ADDED
@@ -0,0 +1,16 @@
vocab_size: 50257
d_model: 1600
n_layer: 48
num_experts: 8
top_k: 2
d_ff: 4800
ssm_d_state: 64
ssm_expand: 2
load_balancing_coef: 0.01
router_z_loss_coef: 0.001
max_seq_len: 2048
dtype: "float16"
use_cpu_offload: false
gradient_checkpointing: true
checkpoint_ssm_layers: true
use_flash_attention: true
requirements.txt
ADDED
@@ -0,0 +1,12 @@
torch>=2.0.0
transformers
datasets
huggingface_hub
pyyaml
zstandard
fastapi
uvicorn
pydantic
sse-starlette
pytest
httpx
scripts/generate.py
ADDED
@@ -0,0 +1,16 @@
import sys
from pathlib import Path
from aetheris.cli.main import main

if __name__ == "__main__":
    # Simulate arguments if needed, but since we are replacing the script, we can just rely on argparse to parse sys.argv
    # The original script parsed arguments like --prompt, etc.
    # The new CLI expects a subcommand, e.g., 'generate'

    # Check if 'generate' is already in argv, if not prepend it
    if len(sys.argv) > 1 and sys.argv[1] != 'generate':
        sys.argv.insert(1, 'generate')
    elif len(sys.argv) == 1:
        sys.argv.append('generate')

    sys.exit(main())
scripts/info.py
ADDED
@@ -0,0 +1,11 @@
import sys
from pathlib import Path
from aetheris.cli.main import main

if __name__ == "__main__":
    if len(sys.argv) > 1 and sys.argv[1] != 'info':
        sys.argv.insert(1, 'info')
    elif len(sys.argv) == 1:
        sys.argv.append('info')

    sys.exit(main())
scripts/train.py
ADDED
@@ -0,0 +1,6 @@
import sys
from pathlib import Path
from aetheris.cli.main import main

if __name__ == "__main__":
    sys.exit(main())
scripts/validate.py
ADDED
@@ -0,0 +1,87 @@
import argparse
import os
import torch
import math
import time
import sys
from pathlib import Path

# Add project root to path
sys.path.append(str(Path(__file__).resolve().parent.parent))

from aetheris.config import AetherisConfig
from aetheris.model import HybridMambaMoE
from aetheris.data import create_streaming_loader, get_tokenizer
from aetheris.utils import load_latest_checkpoint

@torch.no_grad()
def evaluate_model(model, val_loader, device, max_batches=100):
    print(f"\n{'='*50}\nStarting Validation (Max {max_batches} batches)\n{'='*50}")

    model.eval()
    total_loss = 0.0
    num_batches = 0
    start_time = time.time()

    for batch in val_loader:
        if num_batches >= max_batches:
            break

        input_ids, labels = batch
        input_ids = input_ids.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        with torch.amp.autocast('cuda', dtype=torch.float16):
            output = model(input_ids, labels)
            loss = output["loss"]

        total_loss += loss.item()
        num_batches += 1

        if num_batches % 20 == 0:
            print(f"-> Processed {num_batches}/{max_batches} batches...")

    end_time = time.time()

    if num_batches == 0:
        print("No validation batches were processed.")
        return float('inf')

    avg_loss = total_loss / num_batches
    perplexity = math.exp(avg_loss)

    print(f"\n--- Validation Results ---")
    print(f"Total batches processed: {num_batches}")
    print(f"Time taken: {end_time - start_time:.2f} seconds")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"Perplexity: {perplexity:.2f}")
    print(f"--------------------------\n")

    return avg_loss

def main():
    parser = argparse.ArgumentParser(description="Validate Aetheris Model")
    parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
    parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory with checkpoints")
    parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name")
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
    parser.add_argument("--hf_token", type=str, default=os.environ.get("HF_TOKEN"), help="HuggingFace Token")
    parser.add_argument("--dataset", type=str, default="cerebras/SlimPajama-627B", help="Dataset to validate on")
    parser.add_argument("--dataset_mode", type=str, default="pretrain", help="pretrain or sft")

    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config = AetherisConfig.from_yaml(args.config)
    tokenizer = get_tokenizer()

    model = HybridMambaMoE(config).to(device).to(config.torch_dtype)

    load_latest_checkpoint(model, None, None, device, args.checkpoint_dir, args.checkpoint_name)

    val_loader = create_streaming_loader(args.dataset, "validation", tokenizer, config, args.batch_size, mode=args.dataset_mode, hf_token=args.hf_token)

    evaluate_model(model, val_loader, device)

if __name__ == "__main__":
    main()
tests/test_api.py
ADDED
@@ -0,0 +1,88 @@
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch
from aetheris.api.server import app, get_engine
import aetheris.api.server

# Mock the engine globally
@pytest.fixture
def mock_engine():
    with patch("aetheris.api.server.engine") as mock_eng:
        # Mock generate_full
        mock_eng.generate_full.return_value = "This is a generated response."

        # Mock generate (streaming)
        def mock_stream(*args, **kwargs):
            yield "This "
            yield "is "
            yield "streamed."
        mock_eng.generate.side_effect = mock_stream

        # Need to ensure get_engine returns this mock
        # We can also just set aetheris.api.server.engine
        aetheris.api.server.engine = mock_eng
        yield mock_eng

client = TestClient(app)

def test_list_models(mock_engine):
    response = client.get("/v1/models")
    assert response.status_code == 200
    data = response.json()
    assert data["object"] == "list"
    assert len(data["data"]) > 0
    assert data["data"][0]["id"] == "aetheris-hybrid-mamba-moe"

def test_chat_completions_non_stream(mock_engine):
    payload = {
        "model": "aetheris-hybrid-mamba-moe",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False
    }
    response = client.post("/v1/chat/completions", json=payload)
    assert response.status_code == 200
    data = response.json()
    assert data["object"] == "chat.completion"
    assert len(data["choices"]) == 1
    assert data["choices"][0]["message"]["content"] == "This is a generated response."

def test_chat_completions_stream(mock_engine):
    payload = {
        "model": "aetheris-hybrid-mamba-moe",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True
    }
    response = client.post("/v1/chat/completions", json=payload)
    assert response.status_code == 200
    # SSE format checking
    assert "text/event-stream" in response.headers["content-type"]

    # We can iterate over the response lines to check content
    content = ""
    for line in response.iter_lines():
        if line:
            # TestClient iter_lines yields strings, not bytes, unless configured otherwise
            # or depending on the version. If it's bytes, we decode. If it's str, we don't.
            if isinstance(line, bytes):
                line = line.decode("utf-8")

            if line.startswith("data: ") and line != "data: [DONE]":
                import json
                chunk = json.loads(line[6:])
                if chunk["choices"][0]["delta"].get("content"):
                    content += chunk["choices"][0]["delta"]["content"]

    assert content == "This is streamed."

def test_completions(mock_engine):
    payload = {
        "model": "aetheris-hybrid-mamba-moe",
        "prompt": "Once upon a time",
        "max_tokens": 10
    }
    response = client.post("/v1/completions", json=payload)
    assert response.status_code == 200
    data = response.json()
    assert data["object"] == "text_completion"
    assert len(data["choices"]) == 1
    assert data["choices"][0]["text"] == "This is a generated response."
tests/test_inference.py
ADDED
@@ -0,0 +1,71 @@
import pytest
import torch
from unittest.mock import MagicMock, patch
from aetheris.inference import InferenceEngine

@pytest.fixture
def mock_model():
    with patch("aetheris.inference.HybridMambaMoE") as MockModel:
        mock_instance = MockModel.return_value
        # Mock model output
        mock_instance.to.return_value = mock_instance
        mock_instance.eval.return_value = None

        # Mock forward pass
        mock_output = MagicMock()
        # Shape: (batch_size, seq_len, vocab_size)
        mock_output.__getitem__.return_value = torch.randn(1, 1, 50257)
        # Actually we need 'logits' key access
        mock_instance.return_value = {'logits': torch.randn(1, 10, 50257)}

        yield mock_instance

@pytest.fixture
def mock_tokenizer():
    with patch("aetheris.inference.get_tokenizer") as mock_get_tokenizer:
        mock_tok = MagicMock()
        mock_tok.encode.return_value = torch.tensor([[1, 2, 3]])
        mock_tok.decode.return_value = "token"
        mock_tok.eos_token_id = 50256
        mock_get_tokenizer.return_value = mock_tok
        yield mock_tok

@pytest.fixture
def mock_utils():
    with patch("aetheris.inference.load_latest_checkpoint") as mock_load:
        yield mock_load

def test_inference_initialization(mock_model, mock_tokenizer, mock_utils):
    engine = InferenceEngine(config_path="configs/default.yaml")
    assert engine.model is not None
    assert engine.tokenizer is not None
    mock_utils.assert_called_once()

def test_generate_full(mock_model, mock_tokenizer, mock_utils):
    engine = InferenceEngine()

    # Mock model output for generation loop
    # We need to ensure the model returns logits of correct shape
    # The loop calls model(generated_ids)

    # Let's mock the actual model call inside generate
    engine.model.config.torch_dtype = torch.float32

    # We need to return a dict with logits
    # Shape: (batch, seq_len, vocab_size)
    engine.model.side_effect = lambda x: {'logits': torch.randn(1, x.shape[1], 50257)}

    output = engine.generate_full("test prompt", max_new_tokens=5)
    assert isinstance(output, str)
    assert len(output) > 0

def test_generate_stream(mock_model, mock_tokenizer, mock_utils):
    engine = InferenceEngine()
    engine.model.config.torch_dtype = torch.float32
    engine.model.side_effect = lambda x: {'logits': torch.randn(1, x.shape[1], 50257)}

    generator = engine.generate("test prompt", max_new_tokens=5, stream=True)
    tokens = list(generator)
    assert len(tokens) == 5
    assert all(isinstance(t, str) for t in tokens)
tests/test_model.py
ADDED
@@ -0,0 +1,55 @@
import unittest
import torch
import sys
from pathlib import Path

# Add project root to path
sys.path.append(str(Path(__file__).resolve().parent.parent))

from aetheris.config import AetherisConfig
from aetheris.model import HybridMambaMoE

class TestHybridMambaMoE(unittest.TestCase):
    def setUp(self):
        self.config = AetherisConfig(
            vocab_size=100,
            d_model=32,
            n_layer=4,
            num_experts=2,
            top_k=1,
            d_ff=64,
            ssm_d_state=8,
            ssm_expand=2,
            max_seq_len=64
        )
        self.model = HybridMambaMoE(self.config)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

    def test_forward_pass(self):
        batch_size = 2
        seq_len = 16
        input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)

        output = self.model(input_ids)

        self.assertIn('logits', output)
        self.assertEqual(output['logits'].shape, (batch_size, seq_len, self.config.vocab_size))

    def test_forward_pass_with_labels(self):
        batch_size = 2
        seq_len = 16
        input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
        labels = input_ids.clone()

        output = self.model(input_ids, labels=labels)

        self.assertIn('loss', output)
        self.assertIn('ce_loss', output)
        self.assertIn('aux_loss', output)
        self.assertIn('logits', output)

        self.assertTrue(output['loss'] > 0)

if __name__ == '__main__':
    unittest.main()
tests/test_overflow.py
ADDED
@@ -0,0 +1,67 @@
import unittest
import torch
import sys
from pathlib import Path

# Add project root to path
sys.path.append(str(Path(__file__).resolve().parent.parent))

from aetheris.modules.expert import Expert
from aetheris.modules.moe import SparseMoELayer
from aetheris.config import AetherisConfig

class TestOverflow(unittest.TestCase):
    def setUp(self):
        self.config = AetherisConfig(
            vocab_size=100,
            d_model=128,
            n_layer=2,
            num_experts=2,
            top_k=1,
            d_ff=512,  # Large enough to potentially cause issues
            ssm_d_state=16,
            ssm_expand=2,
            max_seq_len=64
        )
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def test_expert_overflow_protection(self):
        """Test if Expert handles large inputs without producing NaNs in float16"""
        expert = Expert(self.config.d_model, self.config.d_ff).to(self.device)
        # Manually cast weights to float16 to simulate mixed precision training environment
        expert.half()

        # Create a large input in float16 that would normally cause overflow in intermediate layers
        # The limit of float16 is ~65504.
        # If w1 projects this up, it can easily exceed that.
        large_input = torch.ones(1, self.config.d_model, dtype=torch.float16).to(self.device) * 100.0

        # Force weights to be large to guarantee overflow if protection isn't working
        with torch.no_grad():
            expert.w1.weight.fill_(10.0)
            expert.w2.weight.fill_(0.1)

        # 100 * 10 = 1000. Sum over d_model(128) -> 128000.
        # This summation happens in the matrix multiplication.
        # If the matmul internal accumulation is float16, it effectively overflows.

        output = expert(large_input)

        self.assertFalse(torch.isnan(output).any(), "Output contains NaNs")
        self.assertFalse(torch.isinf(output).any(), "Output contains Infs")

    def test_moe_accumulation_stability(self):
        """Test if MoE layer handles accumulation in float32"""
        moe = SparseMoELayer(self.config).to(self.device)
        moe.half()

        x = torch.randn(2, 10, self.config.d_model, dtype=torch.float16).to(self.device)

        # Pass through
        output, loss = moe(x)

        self.assertFalse(torch.isnan(output).any(), "MoE Output contains NaNs")
        self.assertEqual(output.dtype, torch.float16)

if __name__ == '__main__':
    unittest.main()