Pomilon committed
Commit 1df0e33 · 0 Parent(s)

Deploy Aetheris to HF Space

.gitattributes ADDED
@@ -0,0 +1 @@
+ checkpoints/checkpoint_current.pth filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,28 @@
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ .Python
+ env/
+ venv/
+ .env
+ .venv
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ checkpoints/
+ *.log
+ .DS_Store
+ legacy/
Dockerfile ADDED
@@ -0,0 +1,38 @@
+ FROM python:3.10-slim
+
+ # Set environment variables
+ ENV PYTHONUNBUFFERED=1 \
+     PYTHONDONTWRITEBYTECODE=1
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     git \
+     build-essential \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Set working directory
+ WORKDIR /app
+
+ # Create a user first to handle permissions correctly from the start
+ RUN useradd -m -u 1000 user
+
+ # Switch to user
+ USER user
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ # Set up application directory with correct permissions
+ WORKDIR $HOME/app
+
+ # Copy requirements and install
+ COPY --chown=user requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu
+
+ # Copy application code
+ COPY --chown=user . .
+
+ # Expose port
+ EXPOSE 7860
+
+ # Command to run the application
+ CMD ["python3", "-m", "aetheris.cli.main", "serve", "--host", "0.0.0.0", "--port", "7860"]
Dockerfile-nvidia ADDED
@@ -0,0 +1,41 @@
+ # Use NVIDIA CUDA base image for GPU support
+ FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
+
+ # Set environment variables
+ ENV PYTHONUNBUFFERED=1 \
+     PYTHONDONTWRITEBYTECODE=1 \
+     DEBIAN_FRONTEND=noninteractive
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     python3-pip \
+     python3-dev \
+     git \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Set working directory
+ WORKDIR /app
+
+ # Install Python dependencies
+ COPY requirements.txt .
+ RUN pip3 install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY . .
+
+ # Expose port (7860 is the default for Hugging Face Spaces)
+ EXPOSE 7860
+
+ # Create a non-root user (good practice; Hugging Face Spaces typically runs as UID 1000)
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ WORKDIR $HOME/app
+ COPY --chown=user . $HOME/app
+
+ # Command to run the application (the CLI `serve` command)
+ CMD ["python3", "-m", "aetheris.cli.main", "serve", "--host", "0.0.0.0", "--port", "7860"]
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Pomilon
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,146 @@
+ # Aetheris: Hybrid Mamba-MoE Experiment
+
+ <p align="center">
+   <img src="https://img.shields.io/badge/Status-Experimental-yellow.svg" alt="Status">
+   <img src="https://img.shields.io/badge/License-MIT-green.svg" alt="License">
+   <img src="https://img.shields.io/badge/Python-3.10+-blue.svg" alt="Python">
+   <img src="https://img.shields.io/badge/PyTorch-2.0+-orange.svg" alt="PyTorch">
+   <img src="https://img.shields.io/badge/API-FastAPI-009688.svg" alt="FastAPI">
+ </p>
+
+
+ **Aetheris** is a hobbyist research project and experimental implementation exploring the intersection of **State Space Models (Mamba)** and **Mixture of Experts (MoE)**.
+
+ The goal of this project was to learn by doing: attempting to combine the linear-time inference of Mamba with the sparse scaling capacity of MoE from scratch in PyTorch. It is designed as a playground for understanding these modern architectures, not as a published academic paper or production-ready foundation model.
+
+ ## 🧪 The Experiment
+
+ Current LLM architectures are evolving rapidly. I built Aetheris to investigate a specific question:
+ > *Can we successfully interleave Mamba blocks (for long context) with sparse MoE layers (for capacity) to train an efficient model on consumer hardware?*
+
+ This project implements a hybrid architecture that attempts to:
+ 1. **Replace Attention:** Use Mamba (SSM) blocks to achieve $O(N)$ sequence scaling.
+ 2. **Scale Parameters Sparsely:** Use MoE layers to increase model size without exploding the computational cost per token.
+ 3. **Run Locally:** Optimize the implementation for single-GPU training (gradient checkpointing, efficient routing).
+
+ ## 🏗️ Architecture Implementation
+
+ Aetheris alternates between custom implementations of two core modules:
+
+ * **SSMBlock (The Backbone):** Implements the selective scan mechanism described in the [Mamba paper](https://arxiv.org/abs/2312.00752). This handles the sequence mixing and "memory" of the model.
+ * **SparseMoELayer (The Scaling):** A router-based layer that dispatches tokens to Top-K experts (Feed-Forward Networks). This allows the model to "specialize" parts of its parameters for different types of tokens.
+
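For orientation, here is a condensed sketch of how the two modules are interleaved. It mirrors the layer-stacking loop in `aetheris/model.py` from this commit (even layers are SSM blocks, odd layers are MoE layers), with the gradient checkpointing and loss bookkeeping stripped out; treat it as an illustration, not the full model.

```python
import torch.nn as nn
from aetheris.modules import SSMBlock, SparseMoELayer

class HybridSketch(nn.Module):
    """Toy illustration of the Mamba/MoE interleaving (not the full HybridMambaMoE)."""
    def __init__(self, config):
        super().__init__()
        # Even layers mix the sequence (SSM); odd layers add sparse capacity (MoE).
        self.layers = nn.ModuleList(
            [SSMBlock(config) if i % 2 == 0 else SparseMoELayer(config)
             for i in range(config.n_layer)]
        )

    def forward(self, x):
        aux_total = 0.0
        for layer in self.layers:
            if isinstance(layer, SparseMoELayer):
                x, aux = layer(x)        # MoE layers also return a load-balancing loss
                aux_total = aux_total + aux
            else:
                x = layer(x)             # SSM layers are plain residual blocks
        return x, aux_total
```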
+ ## 🚀 Quick Start
+
+ This code is provided for educational purposes and for others who want to experiment with hybrid architectures.
+
+ ### Installation
+
+ **Option 1: Local Python Environment**
+
+ ```bash
+ git clone https://github.com/Pomilon/Aetheris.git
+ cd Aetheris
+ pip install -r requirements.txt
+ ```
+
+ **Option 2: Docker**
+
+ We provide Dockerfiles for both CPU (slim) and GPU (NVIDIA) environments.
+
+ ```bash
+ # CPU Version
+ docker build -t aetheris-cpu -f Dockerfile .
+ docker run -p 7860:7860 aetheris-cpu
+
+ # GPU Version (Requires NVIDIA Container Toolkit)
+ docker build -t aetheris-gpu -f Dockerfile-nvidia .
+ docker run --gpus all -p 7860:7860 aetheris-gpu
+ ```
+
+ ### Usage (CLI)
+
+ Aetheris includes a CLI to train the model, run inference, or serve it over an API.
+
+ **1. Training (From Scratch)**
+
+ ```bash
+ # Trains a small model defined in configs/default.yaml
+ python -m aetheris.cli.main train --config configs/default.yaml
+ ```
+
+ **2. Generation (CLI)**
+
+ ```bash
+ python -m aetheris.cli.main generate --prompt "The quick brown fox" --checkpoint_dir checkpoints
+ ```
+
+ **3. API Server (OpenAI-Compatible)**
+
+ Start a local API server that mimics OpenAI's chat completions endpoint.
+
+ ```bash
+ python -m aetheris.cli.main serve --host 0.0.0.0 --port 8000
+ ```
+
+ You can then interact with it using standard tools:
+
+ ```bash
+ curl http://localhost:8000/v1/chat/completions \
+   -H "Content-Type: application/json" \
+   -d '{
+     "model": "aetheris-hybrid",
+     "messages": [{"role": "user", "content": "Hello!"}],
+     "stream": true
+   }'
+ ```
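Because the endpoint follows the OpenAI wire format, you can also point the official `openai` Python client at it. The package is not in `requirements.txt`, so install it separately; this is a sketch of what that could look like, assuming the server's SSE stream round-trips cleanly through the client:

```python
from openai import OpenAI  # pip install "openai>=1.0"

# The api_key is unused by the local server but required by the client.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

stream = client.chat.completions.create(
    model="aetheris-hybrid",
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
)
for chunk in stream:
    # Each chunk carries a partial assistant message in choices[0].delta
    print(chunk.choices[0].delta.content or "", end="", flush=True)
```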
+
+ ### Development & Testing
+
+ To run the test suite:
+
+ ```bash
+ pytest tests/
+ ```
+
+ ## ⚙️ Configuration
+
+ You can tweak the hyperparameters in `configs/`. I've included a "Debug" config that is small enough to train on a laptop CPU for testing the code flow (an illustrative set of debug-sized values is sketched after the table).
+
+ | Config File | Description |
+ | :--- | :--- |
+ | `configs/default.yaml` | Standard experimental setup (requires GPU). |
+ | `configs/debug.yaml` | Tiny model (2 layers) for code debugging. |
+
+ ## 📚 Acknowledgements & References
116
+
117
+ This project is an implementation study and relies heavily on the brilliant theoretical work of others. It is not an original invention of the Mamba or MoE concepts.
118
+
119
+ * **Mamba Architecture:** Gu, A., & Dao, T. (2023). *Mamba: Linear-Time Sequence Modeling with Selective State Spaces*. [arXiv:2312.00752](https://arxiv.org/abs/2312.00752)
120
+ * **Mixture of Experts:** Shazeer, N., et al. (2017). *Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer*. [arXiv:1701.06538](https://arxiv.org/abs/1701.06538)
121
+ * **Inspiration:** Jamba (AI21 Labs) and OpenMoE.
122
+
123
+ ## 🧠 Model Weights & Checkpoints
124
+
125
+ All pre-trained checkpoints are hosted on the [Hugging Face Hub](https://huggingface.co/Pomilon).
126
+
127
+ | Model Artifact | Step | Description | Download |
128
+ | :--- | :--- | :--- | :--- |
129
+ | **Aetheris-Base** | 10k | Early convergence checkpoint (Loss ~3.66). Good for analyzing router behavior. | [🤗 Hugging Face](https://huggingface.co/Pomilon/Aetheris-MoE-300M-A125M-base) |
130
+ | **Aetheris-Chat** | -- | *Coming Soon (Post-SFT)* | -- |
131
+
132
+ > **⚠️ Important:** Aetheris uses a custom Hybrid Mamba-MoE architecture. You **cannot** load it directly with `transformers.AutoModel`. You must use the interface provided in this repository.
133
+
134
+ ### 🐍 How to Load
135
+
136
+ ```python
137
+ python -m aetheris.cli.main generate --prompt "The quick brown fox" --checkpoint_dir path/to/checkpoints_folder # rename the checkpoint inside to checkpoint_current.pth
138
+ ```
139
+ > **Note:** will add better inference later down the line, for now used this scuffed version. :D
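If you would rather call the checkpoint from Python than shell out to the CLI, the commit also ships a small `InferenceEngine` wrapper in `aetheris/inference.py`. A minimal sketch, assuming the config matches the checkpoint and `checkpoints/checkpoint_current.pth` exists:

```python
from aetheris.inference import InferenceEngine

# Paths are placeholders; point them at your own config and checkpoint folder.
engine = InferenceEngine(
    config_path="configs/default.yaml",
    checkpoint_dir="checkpoints",
    checkpoint_name="checkpoint_current.pth",
)
print(engine.generate_full("The quick brown fox", max_new_tokens=50, temperature=0.8))
```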
+
+ > **Note:** These weights are from an experimental run. While they demonstrate the architectural capabilities, do not expect GPT-5 or even Google Bard levels of coherence. :D
+ > This project was made for learning and fun!
+
+ ## License
+
+ MIT
aetheris/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .model import HybridMambaMoE
+ from .config import AetherisConfig
aetheris/api/schemas.py ADDED
@@ -0,0 +1,92 @@
+ from typing import List, Optional, Union, Dict, Any
+ from pydantic import BaseModel, Field
+ import time
+
+ class ChatMessage(BaseModel):
+     role: str
+     content: str
+
+ class ChatCompletionRequest(BaseModel):
+     model: str
+     messages: List[ChatMessage]
+     temperature: Optional[float] = 1.0
+     top_p: Optional[float] = 1.0
+     n: Optional[int] = 1
+     stream: Optional[bool] = False
+     stop: Optional[Union[str, List[str]]] = None
+     max_tokens: Optional[int] = None
+     presence_penalty: Optional[float] = 0.0
+     frequency_penalty: Optional[float] = 0.0
+     logit_bias: Optional[Dict[str, float]] = None
+     user: Optional[str] = None
+
+ class ChatCompletionChoice(BaseModel):
+     index: int
+     message: ChatMessage
+     finish_reason: Optional[str] = None
+
+ class ChatCompletionResponse(BaseModel):
+     id: str
+     object: str = "chat.completion"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     model: str
+     choices: List[ChatCompletionChoice]
+     usage: Optional[Dict[str, int]] = None
+
+ class ChatCompletionChunkDelta(BaseModel):
+     role: Optional[str] = None
+     content: Optional[str] = None
+
+ class ChatCompletionChunkChoice(BaseModel):
+     index: int
+     delta: ChatCompletionChunkDelta
+     finish_reason: Optional[str] = None
+
+ class ChatCompletionChunk(BaseModel):
+     id: str
+     object: str = "chat.completion.chunk"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     model: str
+     choices: List[ChatCompletionChunkChoice]
+
+ class CompletionRequest(BaseModel):
+     model: str
+     prompt: Union[str, List[str]]
+     suffix: Optional[str] = None
+     max_tokens: Optional[int] = 16
+     temperature: Optional[float] = 1.0
+     top_p: Optional[float] = 1.0
+     n: Optional[int] = 1
+     stream: Optional[bool] = False
+     logprobs: Optional[int] = None
+     echo: Optional[bool] = False
+     stop: Optional[Union[str, List[str]]] = None
+     presence_penalty: Optional[float] = 0.0
+     frequency_penalty: Optional[float] = 0.0
+     best_of: Optional[int] = 1
+     logit_bias: Optional[Dict[str, float]] = None
+     user: Optional[str] = None
+
+ class CompletionChoice(BaseModel):
+     text: str
+     index: int
+     logprobs: Optional[Any] = None
+     finish_reason: Optional[str] = None
+
+ class CompletionResponse(BaseModel):
+     id: str
+     object: str = "text_completion"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     model: str
+     choices: List[CompletionChoice]
+     usage: Optional[Dict[str, int]] = None
+
+ class ModelCard(BaseModel):
+     id: str
+     object: str = "model"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     owned_by: str = "aetheris"
+
+ class ModelList(BaseModel):
+     object: str = "list"
+     data: List[ModelCard]
aetheris/api/server.py ADDED
@@ -0,0 +1,162 @@
+ import time
+ import uuid
+ import json
+ import asyncio
+ from typing import AsyncGenerator
+ from fastapi import FastAPI, HTTPException, Request
+ from fastapi.middleware.cors import CORSMiddleware
+ from sse_starlette.sse import EventSourceResponse
+ from aetheris.api.schemas import (
+     ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk,
+     ChatCompletionChoice, ChatMessage, ChatCompletionChunkChoice, ChatCompletionChunkDelta,
+     CompletionRequest, CompletionResponse, CompletionChoice,
+     ModelList, ModelCard
+ )
+ from aetheris.inference import InferenceEngine
+
+ app = FastAPI(title="Aetheris API", version="0.1.0")
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Global engine instance
+ engine: InferenceEngine = None
+
+ def get_engine():
+     global engine
+     if engine is None:
+         # Defaults, ideally loaded from config/env
+         engine = InferenceEngine()
+     return engine
+
+ @app.on_event("startup")
+ async def startup_event():
+     get_engine()
+
+ @app.get("/v1/models", response_model=ModelList)
+ async def list_models():
+     return ModelList(data=[ModelCard(id="aetheris-hybrid-mamba-moe")])
+
+ @app.post("/v1/chat/completions")
+ async def chat_completions(request: ChatCompletionRequest):
+     engine = get_engine()
+
+     # Simple prompt construction from messages
+     prompt = ""
+     for msg in request.messages:
+         prompt += f"{msg.role}: {msg.content}\n"
+     prompt += "assistant: "
+
+     request_id = f"chatcmpl-{uuid.uuid4()}"
+     created_time = int(time.time())
+
+     if request.stream:
+         async def event_generator():
+             yield json.dumps(ChatCompletionChunk(
+                 id=request_id,
+                 created=created_time,
+                 model=request.model,
+                 choices=[ChatCompletionChunkChoice(
+                     index=0,
+                     delta=ChatCompletionChunkDelta(role="assistant"),
+                     finish_reason=None
+                 )]
+             ).model_dump())
+
+             for token in engine.generate(
+                 prompt=prompt,
+                 max_new_tokens=request.max_tokens or 100,
+                 temperature=request.temperature,
+                 top_p=request.top_p,
+                 repetition_penalty=1.0 + request.frequency_penalty,  # Approximating
+                 stream=True
+             ):
+                 yield json.dumps(ChatCompletionChunk(
+                     id=request_id,
+                     created=created_time,
+                     model=request.model,
+                     choices=[ChatCompletionChunkChoice(
+                         index=0,
+                         delta=ChatCompletionChunkDelta(content=token),
+                         finish_reason=None
+                     )]
+                 ).model_dump())
+
+             yield json.dumps(ChatCompletionChunk(
+                 id=request_id,
+                 created=created_time,
+                 model=request.model,
+                 choices=[ChatCompletionChunkChoice(
+                     index=0,
+                     delta=ChatCompletionChunkDelta(),
+                     finish_reason="stop"
+                 )]
+             ).model_dump())
+
+             yield "[DONE]"
+
+         return EventSourceResponse(event_generator())
+
+     else:
+         generated_text = engine.generate_full(
+             prompt=prompt,
+             max_new_tokens=request.max_tokens or 100,
+             temperature=request.temperature,
+             top_p=request.top_p,
+             repetition_penalty=1.0 + request.frequency_penalty
+         )
+
+         return ChatCompletionResponse(
+             id=request_id,
+             created=created_time,
+             model=request.model,
+             choices=[ChatCompletionChoice(
+                 index=0,
+                 message=ChatMessage(role="assistant", content=generated_text),
+                 finish_reason="stop"
+             )],
+             usage={"prompt_tokens": len(prompt), "completion_tokens": len(generated_text), "total_tokens": len(prompt) + len(generated_text)}  # Approximated (character counts, not tokens)
+         )
+
+ @app.post("/v1/completions")
+ async def completions(request: CompletionRequest):
+     engine = get_engine()
+
+     prompt = request.prompt
+     if isinstance(prompt, list):
+         prompt = prompt[0]  # Handle single prompt for now
+
+     request_id = f"cmpl-{uuid.uuid4()}"
+     created_time = int(time.time())
+
+     if request.stream:
+         # Streaming for completions is not implemented yet; fall through to the
+         # non-streaming path (the logic would mirror chat completions above).
+         pass  # TODO: Implement streaming for completions
+
+     generated_text = engine.generate_full(
+         prompt=prompt,
+         max_new_tokens=request.max_tokens or 16,
+         temperature=request.temperature,
+         top_p=request.top_p,
+         repetition_penalty=1.0 + request.frequency_penalty
+     )
+
+     return CompletionResponse(
+         id=request_id,
+         created=created_time,
+         model=request.model,
+         choices=[CompletionChoice(
+             text=generated_text,
+             index=0,
+             logprobs=None,
+             finish_reason="length"  # or "stop"
+         )],
+         usage={"prompt_tokens": len(prompt), "completion_tokens": len(generated_text), "total_tokens": len(prompt) + len(generated_text)}
+     )
aetheris/cli/__init__.py ADDED
@@ -0,0 +1 @@
+
aetheris/cli/main.py ADDED
@@ -0,0 +1,287 @@
+ import argparse
+ import sys
+ import torch
+ import os
+ import torch.nn.functional as F
+ from aetheris.config import AetherisConfig
+ from aetheris.model import HybridMambaMoE
+ from aetheris.data import create_streaming_loader, get_tokenizer
+ from aetheris.utils import load_latest_checkpoint, calculate_model_stats
+ from aetheris.trainer import Trainer
+
+ def train_command(args):
+     print(f"\n{'='*70}")
+     print(f"Aetheris Training")
+     print(f"Config: {args.config}")
+
+     if args.hf_token:
+         print(f"Using Hugging Face token: {args.hf_token[:10]}...")
+         from huggingface_hub import login
+         login(token=args.hf_token)
+
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     if device.type == 'cuda':
+         torch.backends.cuda.matmul.allow_tf32 = True
+         torch.backends.cudnn.allow_tf32 = True
+         torch.cuda.empty_cache()
+
+     config = AetherisConfig.from_yaml(args.config)
+     tokenizer = get_tokenizer()
+
+     print(f"Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
+     print(f"Model Size: d_model={config.d_model}, layers={config.n_layer}")
+     print(f"{'='*70}\n")
+
+     model = HybridMambaMoE(config).to(device)
+
+     # Apply weight initialization
+     print("Applying proper weight initialization...")
+     model.apply(model._init_weights)
+
+     # Calculate model stats
+     stats = calculate_model_stats(model)
+     print(f"Total Parameters: {stats['total_params']:,}")
+     print(f"Trainable Parameters: {stats['trainable_params']:,}")
+
+     # Use lower learning rate for stability
+     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01,
+                                   betas=(0.9, 0.95), eps=1e-8, fused=False if device.type == 'cpu' else True)
+     scaler = torch.amp.GradScaler('cuda' if device.type == 'cuda' else 'cpu', init_scale=2**10)
+
+     start_step, current_stage = load_latest_checkpoint(model, optimizer, scaler, device, args.checkpoint_dir, args.checkpoint_name)
+
+     trainer = Trainer(model, optimizer, scaler, config, device, args.checkpoint_dir)
+
+     # --- STAGE 1: PRE-TRAINING ---
+     if current_stage == "Pre-Training" or start_step == 0:
+         pt_loader = create_streaming_loader("cerebras/SlimPajama-627B", "train",
+                                             tokenizer, config, args.batch_size, mode="pretrain",
+                                             hf_token=args.hf_token, start_step=start_step)
+
+         # Validation loader (no skipping needed, always from start of val set)
+         pt_val_loader = create_streaming_loader("cerebras/SlimPajama-627B", "validation",
+                                                 tokenizer, config, args.batch_size, mode="pretrain",
+                                                 hf_token=args.hf_token)
+
+         start_step = trainer.train_epoch(pt_loader, total_steps=args.pretrain_steps,
+                                          start_step=start_step, stage_name="Pre-Training",
+                                          val_loader=pt_val_loader)
+         current_stage = "SFT"
+         start_step = 0
+
+     # --- STAGE 2: SFT ---
+     print("\n=== STAGE 2: SFT ===")
+     for param_group in optimizer.param_groups:
+         param_group['lr'] = 5e-5
+
+     sft_loader = create_streaming_loader("OpenAssistant/oasst1", "train",
+                                          tokenizer, config, args.batch_size, mode="sft",
+                                          hf_token=args.hf_token, start_step=start_step)
+
+     sft_val_loader = create_streaming_loader("OpenAssistant/oasst1", "validation",
+                                              tokenizer, config, args.batch_size, mode="sft",
+                                              hf_token=args.hf_token)
+
+     trainer.train_epoch(sft_loader, total_steps=args.sft_steps,
+                         start_step=start_step, stage_name="SFT",
+                         val_loader=sft_val_loader)
+
+     print("\nTraining Complete!")
+
+ @torch.no_grad()
+ def generate_command(args):
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     config = AetherisConfig.from_yaml(args.config)
+     tokenizer = get_tokenizer()
+
+     model = HybridMambaMoE(config).to(device).to(config.torch_dtype)
+
+     load_latest_checkpoint(model, None, None, device, args.checkpoint_dir, args.checkpoint_name)
+     model.eval()
+
+     prompt = args.prompt
+     max_new_tokens = args.max_new_tokens
+     temperature = args.temperature
+     top_k = args.top_k
+     top_p = args.top_p
+     repetition_penalty = args.repetition_penalty
+
+     input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
+     generated_ids = input_ids.clone()
+     history_ids = set(input_ids[0].tolist())
+
+     print("-" * 50)
+     print(f"Prompt: {prompt}")
+     print("Generated Continuation:")
+
+     for _ in range(max_new_tokens):
+         # Check if we should use autocast (skip if model uses float32)
+         use_autocast = True
+         if config.torch_dtype == torch.float32:
+             use_autocast = False
+
+         if use_autocast:
+             with torch.amp.autocast('cuda' if device.type == 'cuda' else 'cpu', dtype=model.config.torch_dtype):
+                 outputs = model(generated_ids)
+                 logits = outputs['logits']
+                 next_token_logits = logits[:, -1, :]
+         else:
+             outputs = model(generated_ids)
+             logits = outputs['logits']
+             next_token_logits = logits[:, -1, :]
+
+         # Repetition penalty
+         for token_id in history_ids:
+             if token_id < next_token_logits.size(-1):
+                 logit = next_token_logits[0, token_id].item()
+                 if logit > 0:
+                     next_token_logits[0, token_id] = logit / repetition_penalty
+                 else:
+                     next_token_logits[0, token_id] = logit * repetition_penalty
+
+         # Temperature
+         if temperature > 0:
+             next_token_logits = next_token_logits / temperature
+
+         # Top-p / Top-k
+         if top_p < 1.0:
+             sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+             cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+             sorted_indices_to_remove = cumulative_probs > top_p
+             sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+             sorted_indices_to_remove[..., 0] = False
+             indices_to_remove = sorted_indices[sorted_indices_to_remove]
+             next_token_logits.scatter_(1, indices_to_remove.unsqueeze(0), float('-inf'))
+         elif top_k > 0:
+             top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
+             next_token_logits = torch.full_like(next_token_logits, float('-inf'))
+             next_token_logits.scatter_(1, top_k_indices, top_k_logits)
+
+         # Sample
+         next_token_probs = F.softmax(next_token_logits, dim=-1)
+         next_token = torch.multinomial(next_token_probs, num_samples=1)
+         next_token_item = next_token.item()
+
+         if next_token_item == tokenizer.eos_token_id:
+             break
+
+         generated_ids = torch.cat([generated_ids, next_token], dim=-1)
+         history_ids.add(next_token_item)
+
+         new_token_text = tokenizer.decode(next_token.squeeze().tolist(), skip_special_tokens=True)
+         print(new_token_text, end="", flush=True)
+
+     print("\n" + "-" * 50)
+
+ def info_command(args):
+     config = AetherisConfig.from_yaml(args.config)
+     model = HybridMambaMoE(config)
+
+     total_params = 0
+     dense_params = 0   # Parameters active for EVERY token
+     expert_params = 0  # Parameters in all MoE Experts
+
+     for name, param in model.named_parameters():
+         numel = param.numel()
+         total_params += numel
+
+         if 'experts' in name:
+             expert_params += numel
+         else:
+             dense_params += numel
+
+     single_expert_size = expert_params / config.num_experts if config.num_experts > 0 else 0
+     active_per_token_params = dense_params + (single_expert_size * config.top_k)
+
+     def format_count(count):
+         return f"{count / 1_000_000:.2f}M"
+
+     print("=" * 50)
+     print("Hybrid Mamba-MoE Model Parameter Analysis")
+     print("=" * 50)
+     print(f"Total Model Layers (N_Layer): {config.n_layer}")
+     print(f"MoE Experts per Layer: {config.num_experts}")
+     print(f"Active Experts (Top-K): {config.top_k}")
+     print("-" * 50)
+     print(f"Total Parameters (Checkpoint Size): {format_count(total_params)}")
+     print(f"Dense (Always Active) Parameters: {format_count(dense_params)}")
+     print(f"Expert-Only Parameters: {format_count(expert_params)}")
+     print("-" * 50)
+     print(f"**Active Parameters (Per-Token Compute Load): {format_count(active_per_token_params)}**")
+     print("  (This is the 'Dense' parameters + the K active expert parameters)")
+     print("=" * 50)
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Aetheris CLI")
+     subparsers = parser.add_subparsers(dest="command", help="Available commands")
+
+     # Train Command
+     train_parser = subparsers.add_parser("train", help="Train the model")
+     train_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
+     train_parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory to save checkpoints")
+     train_parser.add_argument("--hf_token", type=str, default=os.environ.get("HF_TOKEN"), help="HuggingFace Token")
+     train_parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
+     train_parser.add_argument("--pretrain_steps", type=int, default=50000, help="Number of pretraining steps")
+     train_parser.add_argument("--sft_steps", type=int, default=1000, help="Number of SFT steps")
+     train_parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name to load from")
+
+     # Generate Command
+     gen_parser = subparsers.add_parser("generate", help="Generate text")
+     gen_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
+     gen_parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory with checkpoints")
+     gen_parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name")
+     gen_parser.add_argument("--prompt", type=str, default="The quick brown fox", help="Prompt for generation")
+     gen_parser.add_argument("--max_new_tokens", type=int, default=100, help="Max new tokens to generate")
+     gen_parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
+     gen_parser.add_argument("--top_k", type=int, default=0, help="Top-k sampling")
+     gen_parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
+     gen_parser.add_argument("--repetition_penalty", type=float, default=3.0, help="Repetition penalty")
+
+     # Serve Command
+     serve_parser = subparsers.add_parser("serve", help="Start the API server")
+     serve_parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind")
+     serve_parser.add_argument("--port", type=int, default=8000, help="Port to bind")
+     serve_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
+     serve_parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory with checkpoints")
+     serve_parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name")
+
+     # Info Command
+     info_parser = subparsers.add_parser("info", help="Show model info")
+     info_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
+
+     args = parser.parse_args()
+
+     if args.command == "train":
+         train_command(args)
+     elif args.command == "generate":
+         generate_command(args)
+     elif args.command == "serve":
+         import uvicorn
+         from aetheris.api.server import app
+         from aetheris.inference import InferenceEngine
+         import aetheris.api.server
+
+         # Initialize the global engine with the CLI arguments before the server
+         # starts, so the startup hook does not fall back to the default paths.
+         aetheris.api.server.engine = InferenceEngine(
+             config_path=args.config,
+             checkpoint_dir=args.checkpoint_dir,
+             checkpoint_name=args.checkpoint_name
+         )
+
+         uvicorn.run(app, host=args.host, port=args.port)
+
+     elif args.command == "info":
+         info_command(args)
+     else:
+         parser.print_help()
+
+ if __name__ == "__main__":
+     main()
aetheris/config.py ADDED
@@ -0,0 +1,58 @@
+ from dataclasses import dataclass, field
+ import yaml
+ import torch
+ from typing import Optional
+
+ @dataclass
+ class AetherisConfig:
+     # Model dimensions
+     vocab_size: int = 50257
+     d_model: int = 768
+     n_layer: int = 24
+     num_experts: int = 4
+     top_k: int = 1
+     d_ff: int = 2304  # d_model * 3
+
+     # SSM parameters
+     ssm_d_state: int = 16
+     ssm_expand: int = 2
+     d_inner: Optional[int] = None  # Will be d_model * ssm_expand if None
+
+     # Training parameters
+     load_balancing_coef: float = 1e-2
+     router_z_loss_coef: float = 1e-3
+     max_seq_len: int = 512
+     dtype: str = "float16"  # "float16", "float32", "bfloat16"
+
+     # Optimization settings
+     use_cpu_offload: bool = False
+     gradient_checkpointing: bool = True
+     checkpoint_ssm_layers: bool = True
+     use_flash_attention: bool = False
+
+     def __post_init__(self):
+         if self.d_inner is None:
+             self.d_inner = self.d_model * self.ssm_expand
+         if self.d_ff is None:
+             self.d_ff = self.d_model * 3
+
+     @property
+     def torch_dtype(self):
+         if self.dtype == "float16":
+             return torch.float16
+         elif self.dtype == "float32":
+             return torch.float32
+         elif self.dtype == "bfloat16":
+             return torch.bfloat16
+         else:
+             raise ValueError(f"Unsupported dtype: {self.dtype}")
+
+     @classmethod
+     def from_yaml(cls, path: str):
+         with open(path, 'r') as f:
+             config_dict = yaml.safe_load(f)
+         return cls(**config_dict)
+
+     def to_yaml(self, path: str):
+         with open(path, 'w') as f:
+             yaml.dump(self.__dict__, f)
aetheris/data.py ADDED
@@ -0,0 +1,105 @@
+ import torch
+ from torch.utils.data import DataLoader, IterableDataset
+ from transformers import AutoTokenizer
+ from datasets import load_dataset
+ import random
+ from typing import Dict, Iterator
+ import os
+
+ def get_tokenizer(model_name: str = "gpt2"):
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+     return tokenizer
+
+ class StreamingDataset(IterableDataset):
+     def __init__(self, dataset, tokenizer, max_seq_len, mode="pretrain", buffer_size=500, skip_samples=0):
+         self.dataset = dataset
+         self.tokenizer = tokenizer
+         self.max_seq_len = max_seq_len
+         self.mode = mode
+         self.buffer_size = buffer_size
+         self.skip_samples = skip_samples
+
+     def _prepare_sft_text(self, example):
+         if 'messages' in example:
+             text = ""
+             for msg in example['messages']:
+                 role = msg.get('role', '')
+                 content = msg.get('content', '')
+                 if role == 'assistant':
+                     text += f"Assistant: {content}{self.tokenizer.eos_token}"
+                 else:
+                     text += f"User: {content}\n"
+             return text
+         elif 'text' in example:
+             return example['text']
+         else:
+             return ""
+
+     def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
+         iterator = iter(self.dataset)
+         buffer = []
+
+         # Skipping (for resume) is handled in the yield loops below
+
+         for example in iterator:
+             text = (example.get('text', '') if self.mode == "pretrain"
+                     else self._prepare_sft_text(example))
+
+             if len(text) < 10:
+                 continue
+
+             enc = self.tokenizer(text, truncation=True, max_length=self.max_seq_len,
+                                  return_tensors="pt")
+             input_ids = enc['input_ids'][0]
+
+             if len(input_ids) < 2:
+                 continue
+
+             pad_len = 0
+             if len(input_ids) < self.max_seq_len:
+                 pad_len = self.max_seq_len - len(input_ids)
+                 input_ids = torch.cat([
+                     input_ids,
+                     torch.full((pad_len,), self.tokenizer.pad_token_id, dtype=torch.long)
+                 ])
+
+             labels = input_ids.clone()
+             if pad_len > 0:
+                 labels[-pad_len:] = -100  # mask padding positions out of the loss
+
+             buffer.append((input_ids, labels))
+
+             if len(buffer) >= self.buffer_size:
+                 random.shuffle(buffer)
+                 for _ in range(self.buffer_size // 2):
+                     item = buffer.pop()
+                     if self.skip_samples > 0:
+                         self.skip_samples -= 1
+                         continue
+                     yield item
+
+         # Yield remaining
+         random.shuffle(buffer)
+         while buffer:
+             item = buffer.pop()
+             if self.skip_samples > 0:
+                 self.skip_samples -= 1
+                 continue
+             yield item
+
+ def create_streaming_loader(dataset_name, split, tokenizer, config, batch_size, mode="pretrain", hf_token=None, start_step=0):
+     raw_dataset = load_dataset(dataset_name, split=split, streaming=True,
+                                trust_remote_code=True, token=hf_token)
+
+     # Calculate samples to skip: start_step * batch_size
+     skip_samples = start_step * batch_size
+     if skip_samples > 0:
+         print(f"  [Loader] Resuming: Fast-forwarding dataset by {skip_samples} samples...")
+
+     stream_ds = StreamingDataset(raw_dataset, tokenizer, config.max_seq_len, mode=mode, skip_samples=skip_samples)
+
+     # Increase num_workers for better utilization
+     return DataLoader(stream_ds, batch_size=batch_size, pin_memory=True,
+                       num_workers=4, prefetch_factor=4)
aetheris/inference.py ADDED
@@ -0,0 +1,106 @@
+ import torch
+ import torch.nn.functional as F
+ from typing import Optional, List, Generator
+ from aetheris.config import AetherisConfig
+ from aetheris.model import HybridMambaMoE
+ from aetheris.data import get_tokenizer
+ from aetheris.utils import load_latest_checkpoint
+
+ class InferenceEngine:
+     def __init__(self, config_path: str = "configs/default.yaml", checkpoint_dir: str = "checkpoints", checkpoint_name: str = "checkpoint_current.pth", device: str = None):
+         self.device = torch.device(device if device else ('cuda' if torch.cuda.is_available() else 'cpu'))
+         self.config = AetherisConfig.from_yaml(config_path)
+         self.tokenizer = get_tokenizer()
+
+         self.model = HybridMambaMoE(self.config).to(self.device).to(self.config.torch_dtype)
+
+         # Load checkpoint
+         # Note: load_latest_checkpoint expects optimizer and scaler, but for inference we can pass None
+         load_latest_checkpoint(self.model, None, None, self.device, checkpoint_dir, checkpoint_name)
+         self.model.eval()
+
+     def generate(self,
+                  prompt: str,
+                  max_new_tokens: int = 100,
+                  temperature: float = 0.8,
+                  top_k: int = 0,
+                  top_p: float = 0.9,
+                  repetition_penalty: float = 1.0,
+                  stream: bool = False) -> Generator[str, None, None] | str:
+
+         input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
+         generated_ids = input_ids.clone()
+         history_ids = set(input_ids[0].tolist())
+
+         def token_generator():
+             nonlocal generated_ids
+             for _ in range(max_new_tokens):
+                 # Check if we should use autocast (skip if model uses float32)
+                 use_autocast = True
+                 if self.config.torch_dtype == torch.float32:
+                     use_autocast = False
+
+                 if use_autocast:
+                     with torch.amp.autocast('cuda' if self.device.type == 'cuda' else 'cpu', dtype=self.model.config.torch_dtype):
+                         outputs = self.model(generated_ids)
+                         logits = outputs['logits']
+                         next_token_logits = logits[:, -1, :]
+                 else:
+                     outputs = self.model(generated_ids)
+                     logits = outputs['logits']
+                     next_token_logits = logits[:, -1, :]
+
+                 # Repetition penalty
+                 for token_id in history_ids:
+                     if token_id < next_token_logits.size(-1):
+                         logit = next_token_logits[0, token_id].item()
+                         if logit > 0:
+                             next_token_logits[0, token_id] = logit / repetition_penalty
+                         else:
+                             next_token_logits[0, token_id] = logit * repetition_penalty
+
+                 # Temperature
+                 if temperature > 0:
+                     next_token_logits = next_token_logits / temperature
+
+                 # Top-p / Top-k
+                 if top_p < 1.0:
+                     sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                     cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+                     sorted_indices_to_remove = cumulative_probs > top_p
+                     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                     sorted_indices_to_remove[..., 0] = False
+                     indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                     next_token_logits.scatter_(1, indices_to_remove.unsqueeze(0), float('-inf'))
+                 elif top_k > 0:
+                     top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
+                     next_token_logits = torch.full_like(next_token_logits, float('-inf'))
+                     next_token_logits.scatter_(1, top_k_indices, top_k_logits)
+
+                 # Sample
+                 next_token_probs = F.softmax(next_token_logits, dim=-1)
+                 next_token = torch.multinomial(next_token_probs, num_samples=1)
+                 next_token_item = next_token.item()
+
+                 if next_token_item == self.tokenizer.eos_token_id:
+                     break
+
+                 generated_ids = torch.cat([generated_ids, next_token], dim=-1)
+                 history_ids.add(next_token_item)
+
+                 new_token_text = self.tokenizer.decode(next_token.squeeze().tolist(), skip_special_tokens=True)
+                 yield new_token_text
+
+         if stream:
+             return token_generator()
+         else:
+             return "".join(list(token_generator()))
+
+     def generate_full(self,
+                       prompt: str,
+                       max_new_tokens: int = 100,
+                       temperature: float = 0.8,
+                       top_k: int = 0,
+                       top_p: float = 0.9,
+                       repetition_penalty: float = 1.0) -> str:
+         return self.generate(prompt, max_new_tokens, temperature, top_k, top_p, repetition_penalty, stream=False)
aetheris/model.py ADDED
@@ -0,0 +1,86 @@
+ import torch
+ import torch.nn as nn
+ from typing import Dict, Any, List
+ from .config import AetherisConfig
+ from .modules import SSMBlock, SparseMoELayer
+
+ class HybridMambaMoE(nn.Module):
+     def __init__(self, config: AetherisConfig):
+         super().__init__()
+         self.config = config
+         self.embedding = nn.Embedding(config.vocab_size, config.d_model)
+
+         self.layers = nn.ModuleList()
+         for i in range(config.n_layer):
+             if i % 2 == 0:
+                 self.layers.append(SSMBlock(config))
+             else:
+                 self.layers.append(SparseMoELayer(config))
+
+         self.final_norm = nn.LayerNorm(config.d_model)
+         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+         self.lm_head.weight = self.embedding.weight  # Weight tying
+
+         self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)  # matches the -100 padding labels from the data loader
+         self.gradient_checkpointing = config.gradient_checkpointing
+
+         # Initialize embeddings with smaller scale
+         nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)
+
+     def _init_weights(self, module):
+         """Apply proper weight initialization"""
+         if isinstance(module, nn.Linear):
+             nn.init.xavier_uniform_(module.weight, gain=0.5)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+         elif isinstance(module, nn.LayerNorm):
+             nn.init.ones_(module.weight)
+             nn.init.zeros_(module.bias)
+
+     def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None) -> Dict[str, Any]:
+         x = self.embedding(input_ids)
+         total_aux_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)
+
+         for i, layer in enumerate(self.layers):
+             if self.gradient_checkpointing and self.training:
+                 # Checkpoint ALL layers for maximum memory savings
+                 if isinstance(layer, SparseMoELayer):
+                     def moe_forward(module, inp):
+                         return module(inp)
+                     x, aux_loss = torch.utils.checkpoint.checkpoint(
+                         moe_forward, layer, x, use_reentrant=False
+                     )
+                     total_aux_loss = total_aux_loss + aux_loss
+                 else:
+                     x = torch.utils.checkpoint.checkpoint(
+                         layer, x, use_reentrant=False
+                     )
+             else:
+                 if isinstance(layer, SparseMoELayer):
+                     x, aux_loss = layer(x)
+                     total_aux_loss = total_aux_loss + aux_loss
+                 else:
+                     x = layer(x)
+
+         x = self.final_norm(x)
+         logits = self.lm_head(x)
+
+         if labels is not None:
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             ce_loss = self.loss_fn(shift_logits.view(-1, self.config.vocab_size),
+                                    shift_labels.view(-1))
+
+             # Scale down aux loss to prevent it from dominating
+             total_loss = ce_loss + 0.01 * total_aux_loss
+
+             return {
+                 "loss": total_loss,
+                 "ce_loss": ce_loss,
+                 "aux_loss": total_aux_loss,
+                 "logits": logits
+             }
+
+         return {"logits": logits}
aetheris/modules/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .expert import Expert
+ from .ssm import SSMBlock, selective_scan_native
+ from .moe import SparseMoELayer
aetheris/modules/expert.py ADDED
@@ -0,0 +1,35 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class Expert(nn.Module):
+     """Memory-efficient Feed-Forward Network expert with proper initialization."""
+     def __init__(self, d_model: int, d_ff: int):
+         super().__init__()
+         self.w1 = nn.Linear(d_model, d_ff, bias=False)
+         self.w2 = nn.Linear(d_ff, d_model, bias=False)
+         self.act = nn.GELU()
+
+         # Proper initialization to prevent NaN
+         nn.init.xavier_uniform_(self.w1.weight, gain=0.5)
+         nn.init.xavier_uniform_(self.w2.weight, gain=0.5)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         orig_dtype = x.dtype
+         # Force float32 for internal computation to prevent overflow in half precision
+         x = x.to(torch.float32)
+
+         # Cast weights to float32 for calculation
+         # This is necessary because the module weights might be float16
+         w1_weight = self.w1.weight.to(torch.float32)
+         w2_weight = self.w2.weight.to(torch.float32)
+
+         h = F.linear(x, w1_weight)
+         h = self.act(h)
+         out = F.linear(h, w2_weight)
+
+         # Clamp to avoid Inf when casting back to float16
+         if orig_dtype == torch.float16:
+             out = torch.clamp(out, min=-65500.0, max=65500.0)
+
+         return out.to(orig_dtype)
aetheris/modules/moe.py ADDED
@@ -0,0 +1,83 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from ..config import AetherisConfig
+ from .expert import Expert
+
+ class SparseMoELayer(nn.Module):
+     """Memory-optimized Sparse MoE with efficient routing."""
+     def __init__(self, config: AetherisConfig):
+         super().__init__()
+         self.d_model = config.d_model
+         self.num_experts = config.num_experts
+         self.top_k = config.top_k
+         self.load_balancing_coef = config.load_balancing_coef
+         self.z_loss_coef = config.router_z_loss_coef
+
+         self.gate = nn.Linear(config.d_model, config.num_experts, bias=False)
+         self.experts = nn.ModuleList([Expert(config.d_model, config.d_ff)
+                                       for _ in range(config.num_experts)])
+         self.norm = nn.LayerNorm(config.d_model)
+
+     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         B, L, D = x.shape
+         x_norm = self.norm(x)
+         flat_x = x_norm.view(-1, D)
+
+         # Routing logits with stability
+         gate_logits = self.gate(flat_x)
+
+         # Clamp logits to prevent overflow
+         gate_logits = torch.clamp(gate_logits, min=-10.0, max=10.0)
+
+         # Z-Loss for stability
+         z_loss = torch.mean(torch.logsumexp(gate_logits, dim=-1)**2) * self.z_loss_coef
+
+         if self.training:
+             # Reduce noise for stability
+             gate_logits = gate_logits + torch.randn_like(gate_logits) * 1e-3
+
+         gate_probs = F.softmax(gate_logits, dim=-1)
+         gate_weights, expert_indices = torch.topk(gate_probs, self.top_k, dim=-1)
+
+         # Normalize weights for stability
+         gate_weights = gate_weights / (gate_weights.sum(dim=-1, keepdim=True) + 1e-8)
+
+         # Load balancing loss
+         # Use only the top-1 expert for load balancing calculation to keep it simple and consistent
+         expert_mask = F.one_hot(expert_indices[:, 0], num_classes=self.num_experts).float()
+         fraction_routed = expert_mask.mean(dim=0)
+         mean_prob = gate_probs.mean(dim=0)
+
+         aux_loss = (self.num_experts * torch.sum(fraction_routed * mean_prob)) * self.load_balancing_coef
+         total_aux_loss = aux_loss + z_loss
+
+         # Efficient dispatch with in-place operations
+         # Accumulate in float32 to prevent overflow during aggregation
+         final_output = torch.zeros_like(flat_x, dtype=torch.float32)
+
+         # Iterate over all k selected experts
+         for k_idx in range(self.top_k):
+             for i, expert in enumerate(self.experts):
+                 # Find tokens routed to expert 'i' at the k-th position
+                 mask = (expert_indices[:, k_idx] == i)
+                 if not mask.any():
+                     continue
+
+                 expert_input = flat_x[mask]
+                 expert_out = expert(expert_input)
+
+                 # Apply weights
+                 weights = gate_weights[mask, k_idx].unsqueeze(1)
+
+                 # Cast to float32 for accumulation
+                 expert_out = expert_out.to(torch.float32)
+                 weights = weights.to(torch.float32)
+
+                 # Accumulate output (add to existing results from other experts)
+                 final_output[mask] += expert_out * weights
+
+         # Cast back to original dtype
+         final_output = final_output.to(flat_x.dtype)
+
+         return x + final_output.view(B, L, D), total_aux_loss
aetheris/modules/ssm.py ADDED
@@ -0,0 +1,91 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from ..config import AetherisConfig
+
+ def selective_scan_native(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
+                           B: torch.Tensor, C: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
+     """Memory-efficient scan with reduced intermediate tensors."""
+     B_size, L, D_inner = u.shape
+     D_state = A.shape[-1]
+
+     # Use in-place operations where possible
+     h = torch.zeros(B_size, D_inner, D_state, device=u.device, dtype=u.dtype)
+     ys = []
+
+     for l in range(L):
+         dt = delta[:, l, :].unsqueeze(-1)
+         dA = torch.exp(dt * A)
+
+         B_l = B[:, l, :].unsqueeze(1)
+         dB = dt * B_l
+
+         u_t = u[:, l, :].unsqueeze(-1)
+         h = dA * h + dB * u_t
+
+         C_l = C[:, l, :].unsqueeze(1)
+         y_t = torch.sum(h * C_l, dim=-1)
+         ys.append(y_t)
+
+     y = torch.stack(ys, dim=1)
+     return y + u * D
+
+ class SSMBlock(nn.Module):
+     """Memory-optimized State Space Model with stability improvements."""
+     def __init__(self, config: AetherisConfig):
+         super().__init__()
+         self.d_model = config.d_model
+         self.d_state = config.ssm_d_state
+         self.d_inner = config.d_inner
+
+         self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=False)
+         self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=False)
+         self.conv_d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=3,
+                                 padding=2, groups=self.d_inner, bias=False)
+         self.gate_proj = nn.Linear(self.d_model, self.d_inner, bias=False)
+
+         self.B_proj = nn.Linear(self.d_inner, self.d_state, bias=False)
+         self.C_proj = nn.Linear(self.d_inner, self.d_state, bias=False)
+         self.delta_proj = nn.Linear(self.d_inner, self.d_inner, bias=False)
+
+         # Initialize A to be more stable (closer to -1)
+         self.A_log = nn.Parameter(torch.randn(self.d_inner, self.d_state) * 0.1 - 4.0)
+         self.D = nn.Parameter(torch.ones(self.d_inner) * 0.1)
+
+         self.act = nn.SiLU()
+         self.norm = nn.LayerNorm(config.d_model)
+
+         # Proper initialization
+         nn.init.xavier_uniform_(self.in_proj.weight, gain=0.5)
+         nn.init.xavier_uniform_(self.out_proj.weight, gain=0.5)
+         nn.init.xavier_uniform_(self.gate_proj.weight, gain=0.5)
+         nn.init.xavier_uniform_(self.B_proj.weight, gain=0.5)
+         nn.init.xavier_uniform_(self.C_proj.weight, gain=0.5)
+         nn.init.xavier_uniform_(self.delta_proj.weight, gain=0.5)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         B, L, D = x.shape
+         x_norm = self.norm(x)
+
+         xz = self.in_proj(x_norm)
+         x_in, z_gate = xz.chunk(2, dim=-1)
+         x_conv = self.conv_d(x_in.transpose(1, 2))
+         # Slice off the last 2 elements (the "future" leakage)
+         x_conv = x_conv[:, :, :-2].transpose(1, 2)
+         x_conv = self.act(x_conv)
+
+         # Add small epsilon to prevent numerical issues and clamp max value
+         delta = torch.clamp(F.softplus(self.delta_proj(x_conv)), max=5.0) + 1e-4
+         B_ssm = self.B_proj(x_conv)
+         C_ssm = self.C_proj(x_conv)
+
+         # Clamp A to prevent extreme values
+         A_fixed = -torch.exp(torch.clamp(self.A_log, min=-10.0, max=2.0))
+         A_batched = A_fixed.unsqueeze(0).expand(B, -1, -1)
+
+         y_ssm = selective_scan_native(x_conv, delta, A_batched, B_ssm, C_ssm, self.D)
+
+         y_gate = F.silu(self.gate_proj(x_norm)) * y_ssm
+         output = self.out_proj(y_gate)
+
+         return x + output
aetheris/trainer/__init__.py ADDED
@@ -0,0 +1 @@
+ from .trainer import Trainer
aetheris/trainer/trainer.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import time
3
+ import os
4
+ from aetheris.utils import save_checkpoint, load_latest_checkpoint, calculate_model_stats
5
+
6
+ class Trainer:
7
+ def __init__(self, model, optimizer, scaler, config, device, checkpoint_dir, logger=None):
8
+ self.model = model
9
+ self.optimizer = optimizer
10
+ self.scaler = scaler
11
+ self.config = config
12
+ self.device = device
13
+ self.checkpoint_dir = checkpoint_dir
14
+ self.logger = logger
15
+
16
+ self.model.to(self.device)
17
+
18
+ def validate(self, val_loader, global_step):
19
+ self.model.eval()
20
+ total_loss = 0
21
+ total_items = 0
22
+ num_batches = 100 # Validate on 100 batches to save time
23
+
24
+ print(f"\n[Validation] Starting validation at step {global_step}...")
25
+
26
+ with torch.no_grad():
27
+ for i, batch in enumerate(val_loader):
28
+ if i >= num_batches:
29
+ break
30
+
31
+ input_ids, labels = batch
32
+ input_ids = input_ids.to(self.device, non_blocking=True)
33
+ labels = labels.to(self.device, non_blocking=True)
34
+
35
+ # Auto-cast context
36
+ if self.device.type == 'cuda':
37
+ autocast_dtype = torch.float16
38
+ else:
39
+ autocast_dtype = torch.bfloat16
40
+
41
+ use_autocast = True if self.config.torch_dtype != torch.float32 else False
42
+
43
+ if use_autocast:
44
+ with torch.amp.autocast('cuda' if self.device.type == 'cuda' else 'cpu', dtype=autocast_dtype):
45
+ output = self.model(input_ids, labels)
46
+ else:
47
+ output = self.model(input_ids, labels)
48
+
49
+ total_loss += output["loss"].item()
50
+ total_items += 1
51
+
52
+ avg_loss = total_loss / total_items if total_items > 0 else 0
53
+ perplexity = torch.exp(torch.tensor(avg_loss)).item()
54
+
55
+ print(f"[Validation] Step {global_step} | Loss: {avg_loss:.4f} | PPL: {perplexity:.4f}")
56
+ self.model.train()
57
+ return avg_loss
58
+
59
+ def train_epoch(self, train_loader, total_steps, start_step=0, stage_name="Training", val_loader=None, eval_every=500):
60
+ print(f"\n{'='*70}\nStarting {stage_name}: Target Steps={total_steps}\n{'='*70}")
61
+ self.model.train()
62
+ global_step = start_step
63
+ running_loss = 0
64
+
65
+ print("Initializing data iterator...")
66
+ train_iter = iter(train_loader)
67
+
68
+ print("Fetching first batch...")
69
+
70
+ while global_step < total_steps:
71
+ step_start = time.time()
72
+
73
+ # Removed periodic cache clearing for performance
74
+
75
+ self.optimizer.zero_grad(set_to_none=True)
76
+
77
+ try:
78
+ batch = next(train_iter)
79
+ if global_step == start_step:
80
+ print(f"✓ First batch loaded! Starting training loop...")
81
+ except StopIteration:
82
+ train_iter = iter(train_loader)
83
+ batch = next(train_iter)
84
+
85
+ input_ids, labels = batch
86
+ input_ids = input_ids.to(self.device, non_blocking=True)
87
+ labels = labels.to(self.device, non_blocking=True)
88
+
89
+ # Determine autocast dtype
90
+ if self.device.type == 'cuda':
91
+ autocast_dtype = torch.float16
92
+ else:
93
+ autocast_dtype = torch.bfloat16
94
+
95
+ # Check if we should use autocast (skip if model uses float32)
96
+ use_autocast = True
97
+ if self.config.torch_dtype == torch.float32:
98
+ use_autocast = False
99
+
100
+ if use_autocast:
101
+ with torch.amp.autocast('cuda' if self.device.type == 'cuda' else 'cpu', dtype=autocast_dtype):
102
+ output = self.model(input_ids, labels)
103
+ loss = output["loss"]
104
+ else:
105
+ output = self.model(input_ids, labels)
106
+ loss = output["loss"]
107
+
108
+ self.scaler.scale(loss).backward()
109
+ self.scaler.unscale_(self.optimizer)
110
+
111
+ # Gradient clipping
112
+ grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
113
+
114
+ if torch.isnan(grad_norm) or torch.isinf(grad_norm):
115
+ print(f"WARNING: NaN/Inf gradient at step {global_step}, skipping update")
116
+ else:
117
+ self.scaler.step(self.optimizer)
118
+
119
+ self.scaler.update()
120
+
121
+ global_step += 1
122
+ running_loss += loss.item()
123
+
124
+ if global_step % 10 == 0:
125
+ avg_loss = running_loss / 10
126
+ t_diff = time.time() - step_start
127
+ if self.device.type == 'cuda':
128
+ mem = torch.cuda.memory_allocated() / 1e9
129
+ max_mem = torch.cuda.max_memory_allocated() / 1e9
130
+ mem_str = f"VRAM: {mem:.1f}GB (peak: {max_mem:.1f}GB)"
131
+ else:
132
+ mem_str = "CPU Mode"
133
+
134
+ tokens_per_sec = (input_ids.size(0) * input_ids.size(1)) / t_diff  # actual tokens in this batch
135
+ print(f" Step {global_step}/{total_steps} | Loss: {avg_loss:.4f} | "
136
+ f"{mem_str} | {tokens_per_sec:.0f} tok/s")
137
+ running_loss = 0
138
+
139
+ if global_step % 500 == 0:
140
+ save_checkpoint(self.model, self.optimizer, self.scaler, global_step, stage_name, self.checkpoint_dir)
141
+
142
+ if val_loader is not None and global_step % eval_every == 0 and global_step > start_step:
143
+ self.validate(val_loader, global_step)
144
+
145
+ return global_step
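The Trainer expects its collaborators to be built by the CLI. As a hedged usage sketch (the optimizer, learning rate, and checkpoint directory here are illustrative choices, not the repository's training script), it could be wired up like this:

import torch
from aetheris.config import AetherisConfig
from aetheris.model import HybridMambaMoE
from aetheris.trainer import Trainer

config = AetherisConfig.from_yaml("configs/debug.yaml")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = HybridMambaMoE(config)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# GradScaler becomes a no-op when disabled, so the same loop runs on CPU
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))

trainer = Trainer(model, optimizer, scaler, config, device, checkpoint_dir="checkpoints")
# train_loader / val_loader would come from aetheris.data.create_streaming_loader(...)
# global_step = trainer.train_epoch(train_loader, total_steps=1000,
#                                   stage_name="Pre-Training", val_loader=val_loader)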
aetheris/utils.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from typing import Tuple
4
+
5
+ def save_checkpoint(model, optimizer, scaler, step, stage, checkpoint_dir, checkpoint_name="checkpoint_current.pth"):
6
+ os.makedirs(checkpoint_dir, exist_ok=True)
7
+ path = os.path.join(checkpoint_dir, checkpoint_name)
8
+ torch.save({
9
+ 'step': step,
10
+ 'stage': stage,
11
+ 'model_state_dict': model.state_dict(),
12
+ 'optimizer_state_dict': optimizer.state_dict(),
13
+ 'scaler_state_dict': scaler.state_dict()
14
+ }, path)
15
+ print(f" [Checkpoint] Saved at step {step}")
16
+
17
+ def load_latest_checkpoint(model, optimizer, scaler, device, checkpoint_dir, checkpoint_name="checkpoint_current.pth") -> Tuple[int, str]:
18
+ path = os.path.join(checkpoint_dir, checkpoint_name)
19
+ if not os.path.exists(path):
20
+ return 0, "Pre-Training"
21
+
22
+ print(f" [Checkpoint] Loading from {path}...")
23
+ ckpt = torch.load(path, map_location=device)
24
+ model.load_state_dict(ckpt['model_state_dict'])
25
+ if optimizer:
26
+ optimizer.load_state_dict(ckpt['optimizer_state_dict'])
27
+ if scaler:
28
+ scaler.load_state_dict(ckpt['scaler_state_dict'])
29
+ return ckpt['step'], ckpt['stage']
30
+
31
+ def calculate_model_stats(model):
32
+ total_params = sum(p.numel() for p in model.parameters())
33
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
34
+ return {
35
+ 'total_params': total_params,
36
+ 'trainable_params': trainable_params,
37
+ 'active_params': int(total_params * 0.6), # Approximation
38
+ 'sparsity_ratio': 0.6 # Approximation
39
+ }
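utils.py keeps checkpointing deliberately simple: a single file, overwritten on every save. A small self-contained sketch of the intended load path follows (passing None for optimizer and scaler restores weights only, matching how scripts/validate.py below calls it); the config path is illustrative.

import torch
from aetheris.config import AetherisConfig
from aetheris.model import HybridMambaMoE
from aetheris.utils import load_latest_checkpoint, calculate_model_stats

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = AetherisConfig.from_yaml("configs/debug.yaml")
model = HybridMambaMoE(config).to(device)

# A missing checkpoint file yields (0, "Pre-Training") instead of raising,
# so fresh runs and resumed runs share one code path.
step, stage = load_latest_checkpoint(model, None, None, device, "checkpoints")
stats = calculate_model_stats(model)
print(f"Loaded stage '{stage}' at step {step}; {stats['total_params']:,} total parameters")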
configs/debug.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ vocab_size: 50257
2
+ d_model: 128
3
+ n_layer: 4
4
+ num_experts: 4
5
+ top_k: 1
6
+ d_ff: 384
7
+ ssm_d_state: 8
8
+ ssm_expand: 2
9
+ load_balancing_coef: 0.01
10
+ router_z_loss_coef: 0.001
11
+ max_seq_len: 128
12
+ dtype: "float32" # Use float32 for debugging on CPU
13
+ use_cpu_offload: false
14
+ gradient_checkpointing: false
15
+ checkpoint_ssm_layers: false
16
+ use_flash_attention: false
configs/default.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ vocab_size: 50257
2
+ d_model: 768
3
+ n_layer: 24
4
+ num_experts: 4
5
+ top_k: 1
6
+ d_ff: 2304
7
+ ssm_d_state: 16
8
+ ssm_expand: 2
9
+ load_balancing_coef: 0.01
10
+ router_z_loss_coef: 0.001
11
+ max_seq_len: 512
12
+ dtype: "float16"
13
+ use_cpu_offload: false
14
+ gradient_checkpointing: true
15
+ checkpoint_ssm_layers: true
16
+ use_flash_attention: false
configs/inference.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ vocab_size: 50257
2
+ d_model: 768
3
+ n_layer: 24
4
+ num_experts: 4
5
+ top_k: 1
6
+ d_ff: 2304
7
+ ssm_d_state: 16
8
+ ssm_expand: 2
9
+ load_balancing_coef: 0.0
10
+ router_z_loss_coef: 0.0
11
+ max_seq_len: 1024
12
+ dtype: "float16"
13
+ use_cpu_offload: true # Offload to CPU during inference to save VRAM
14
+ gradient_checkpointing: false
15
+ checkpoint_ssm_layers: false
16
+ use_flash_attention: true
configs/large.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ vocab_size: 50257
2
+ d_model: 1600
3
+ n_layer: 48
4
+ num_experts: 8
5
+ top_k: 2
6
+ d_ff: 4800
7
+ ssm_d_state: 64
8
+ ssm_expand: 2
9
+ load_balancing_coef: 0.01
10
+ router_z_loss_coef: 0.001
11
+ max_seq_len: 2048
12
+ dtype: "float16"
13
+ use_cpu_offload: false
14
+ gradient_checkpointing: true
15
+ checkpoint_ssm_layers: true
16
+ use_flash_attention: true
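The four configs share one schema and differ mainly in scale (d_model, n_layer, num_experts, d_ff, max_seq_len) and memory strategy (dtype, gradient_checkpointing, use_cpu_offload, use_flash_attention). Assuming AetherisConfig.from_yaml maps these keys one-to-one onto attributes, as its use in scripts/validate.py below suggests, selecting a preset is a one-liner:

from aetheris.config import AetherisConfig

# debug.yaml: tiny float32 model for CPU smoke tests; large.yaml: 48-layer, 8-expert model
config = AetherisConfig.from_yaml("configs/default.yaml")
print(config.d_model, config.n_layer, config.num_experts)  # expected: 768 24 4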
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers
3
+ datasets
4
+ huggingface_hub
5
+ pyyaml
6
+ zstandard
7
+ fastapi
8
+ uvicorn
9
+ pydantic
10
+ sse-starlette
11
+ pytest
12
+ httpx
scripts/generate.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+ from aetheris.cli.main import main
4
+
5
+ if __name__ == "__main__":
6
+ # This wrapper forwards to the unified CLI, which parses sys.argv with argparse.
7
+ # The original script took flags such as --prompt directly,
8
+ # while the new CLI expects a subcommand (e.g., 'generate').
9
+
10
+ # Prepend 'generate' if it is not already the first argument.
11
+ if len(sys.argv) > 1 and sys.argv[1] != 'generate':
12
+ sys.argv.insert(1, 'generate')
13
+ elif len(sys.argv) == 1:
14
+ sys.argv.append('generate')
15
+
16
+ sys.exit(main())
scripts/info.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+ from aetheris.cli.main import main
4
+
5
+ if __name__ == "__main__":
6
+ if len(sys.argv) > 1 and sys.argv[1] != 'info':
7
+ sys.argv.insert(1, 'info')
8
+ elif len(sys.argv) == 1:
9
+ sys.argv.append('info')
10
+
11
+ sys.exit(main())
scripts/train.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+ from aetheris.cli.main import main
4
+
5
+ if __name__ == "__main__":
6
+ sys.exit(main())
scripts/validate.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import torch
4
+ import math
5
+ import time
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ # Add project root to path
10
+ sys.path.append(str(Path(__file__).resolve().parent.parent))
11
+
12
+ from aetheris.config import AetherisConfig
13
+ from aetheris.model import HybridMambaMoE
14
+ from aetheris.data import create_streaming_loader, get_tokenizer
15
+ from aetheris.utils import load_latest_checkpoint
16
+
17
+ @torch.no_grad()
18
+ def evaluate_model(model, val_loader, device, max_batches=100):
19
+ print(f"\n{'='*50}\nStarting Validation (Max {max_batches} batches)\n{'='*50}")
20
+
21
+ model.eval()
22
+ total_loss = 0.0
23
+ num_batches = 0
24
+ start_time = time.time()
25
+
26
+ for batch in val_loader:
27
+ if num_batches >= max_batches:
28
+ break
29
+
30
+ input_ids, labels = batch
31
+ input_ids = input_ids.to(device, non_blocking=True)
32
+ labels = labels.to(device, non_blocking=True)
33
+
34
+ with torch.amp.autocast('cuda', dtype=torch.float16):
35
+ output = model(input_ids, labels)
36
+ loss = output["loss"]
37
+
38
+ total_loss += loss.item()
39
+ num_batches += 1
40
+
41
+ if num_batches % 20 == 0:
42
+ print(f"-> Processed {num_batches}/{max_batches} batches...")
43
+
44
+ end_time = time.time()
45
+
46
+ if num_batches == 0:
47
+ print("No validation batches were processed.")
48
+ return float('inf')
49
+
50
+ avg_loss = total_loss / num_batches
51
+ perplexity = math.exp(avg_loss)
52
+
53
+ print(f"\n--- Validation Results ---")
54
+ print(f"Total batches processed: {num_batches}")
55
+ print(f"Time taken: {end_time - start_time:.2f} seconds")
56
+ print(f"Average Loss: {avg_loss:.4f}")
57
+ print(f"Perplexity: {perplexity:.2f}")
58
+ print(f"--------------------------\n")
59
+
60
+ return avg_loss
61
+
62
+ def main():
63
+ parser = argparse.ArgumentParser(description="Validate Aetheris Model")
64
+ parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
65
+ parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory with checkpoints")
66
+ parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name")
67
+ parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
68
+ parser.add_argument("--hf_token", type=str, default=os.environ.get("HF_TOKEN"), help="HuggingFace Token")
69
+ parser.add_argument("--dataset", type=str, default="cerebras/SlimPajama-627B", help="Dataset to validate on")
70
+ parser.add_argument("--dataset_mode", type=str, default="pretrain", help="pretrain or sft")
71
+
72
+ args = parser.parse_args()
73
+
74
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
75
+ config = AetherisConfig.from_yaml(args.config)
76
+ tokenizer = get_tokenizer()
77
+
78
+ model = HybridMambaMoE(config).to(device).to(config.torch_dtype)
79
+
80
+ load_latest_checkpoint(model, None, None, device, args.checkpoint_dir, args.checkpoint_name)
81
+
82
+ val_loader = create_streaming_loader(args.dataset, "validation", tokenizer, config, args.batch_size, mode=args.dataset_mode, hf_token=args.hf_token)
83
+
84
+ evaluate_model(model, val_loader, device)
85
+
86
+ if __name__ == "__main__":
87
+ main()
tests/test_api.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from fastapi.testclient import TestClient
3
+ from unittest.mock import MagicMock, patch
4
+ from aetheris.api.server import app, get_engine
5
+ import aetheris.api.server
6
+
7
+ # Mock the engine globally
8
+ @pytest.fixture
9
+ def mock_engine():
10
+ with patch("aetheris.api.server.engine") as mock_eng:
11
+ # Mock generate_full
12
+ mock_eng.generate_full.return_value = "This is a generated response."
13
+
14
+ # Mock generate (streaming)
15
+ def mock_stream(*args, **kwargs):
16
+ yield "This "
17
+ yield "is "
18
+ yield "streamed."
19
+ mock_eng.generate.side_effect = mock_stream
20
+
21
+ # Ensure the app serves responses from this mock rather than a real model
22
+ # by overriding the module-level engine attribute directly.
23
+ aetheris.api.server.engine = mock_eng
24
+ yield mock_eng
25
+
26
+ client = TestClient(app)
27
+
28
+ def test_list_models(mock_engine):
29
+ response = client.get("/v1/models")
30
+ assert response.status_code == 200
31
+ data = response.json()
32
+ assert data["object"] == "list"
33
+ assert len(data["data"]) > 0
34
+ assert data["data"][0]["id"] == "aetheris-hybrid-mamba-moe"
35
+
36
+ def test_chat_completions_non_stream(mock_engine):
37
+ payload = {
38
+ "model": "aetheris-hybrid-mamba-moe",
39
+ "messages": [{"role": "user", "content": "Hello"}],
40
+ "stream": False
41
+ }
42
+ response = client.post("/v1/chat/completions", json=payload)
43
+ assert response.status_code == 200
44
+ data = response.json()
45
+ assert data["object"] == "chat.completion"
46
+ assert len(data["choices"]) == 1
47
+ assert data["choices"][0]["message"]["content"] == "This is a generated response."
48
+
49
+ def test_chat_completions_stream(mock_engine):
50
+ payload = {
51
+ "model": "aetheris-hybrid-mamba-moe",
52
+ "messages": [{"role": "user", "content": "Hello"}],
53
+ "stream": True
54
+ }
55
+ response = client.post("/v1/chat/completions", json=payload)
56
+ assert response.status_code == 200
57
+ # SSE format checking
58
+ assert "text/event-stream" in response.headers["content-type"]
59
+
60
+ # We can iterate over the response lines to check content
61
+ content = ""
62
+ for line in response.iter_lines():
63
+ if line:
64
+ # TestClient iter_lines yields strings, not bytes, unless configured otherwise
65
+ # or depending on the version. If it's bytes, we decode. If it's str, we don't.
66
+ if isinstance(line, bytes):
67
+ line = line.decode("utf-8")
68
+
69
+ if line.startswith("data: ") and line != "data: [DONE]":
70
+ import json
71
+ chunk = json.loads(line[6:])
72
+ if chunk["choices"][0]["delta"].get("content"):
73
+ content += chunk["choices"][0]["delta"]["content"]
74
+
75
+ assert content == "This is streamed."
76
+
77
+ def test_completions(mock_engine):
78
+ payload = {
79
+ "model": "aetheris-hybrid-mamba-moe",
80
+ "prompt": "Once upon a time",
81
+ "max_tokens": 10
82
+ }
83
+ response = client.post("/v1/completions", json=payload)
84
+ assert response.status_code == 200
85
+ data = response.json()
86
+ assert data["object"] == "text_completion"
87
+ assert len(data["choices"]) == 1
88
+ assert data["choices"][0]["text"] == "This is a generated response."
tests/test_inference.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from unittest.mock import MagicMock, patch
3
+ from aetheris.inference import InferenceEngine
4
+
5
+ @pytest.fixture
6
+ def mock_model():
7
+ with patch("aetheris.inference.HybridMambaMoE") as MockModel:
8
+ mock_instance = MockModel.return_value
9
+ # Mock model output
10
+ mock_instance.to.return_value = mock_instance
11
+ mock_instance.eval.return_value = None
12
+
13
+ # Mock forward pass
14
+ mock_output = MagicMock()
15
+ # Shape: (batch_size, seq_len, vocab_size)
16
+ mock_output.__getitem__.return_value = torch.randn(1, 1, 50257)
17
+ # Actually we need 'logits' key access
18
+ mock_instance.return_value = {'logits': torch.randn(1, 10, 50257)}
19
+
20
+ yield mock_instance
21
+
22
+ @pytest.fixture
23
+ def mock_tokenizer():
24
+ with patch("aetheris.inference.get_tokenizer") as mock_get_tokenizer:
25
+ mock_tok = MagicMock()
26
+ mock_tok.encode.return_value = torch.tensor([[1, 2, 3]])
27
+ mock_tok.decode.return_value = "token"
28
+ mock_tok.eos_token_id = 50256
29
+ mock_get_tokenizer.return_value = mock_tok
30
+ yield mock_tok
31
+
32
+ @pytest.fixture
33
+ def mock_utils():
34
+ with patch("aetheris.inference.load_latest_checkpoint") as mock_load:
35
+ yield mock_load
36
+
37
+ import torch
38
+
39
+ def test_inference_initialization(mock_model, mock_tokenizer, mock_utils):
40
+ engine = InferenceEngine(config_path="configs/default.yaml")
41
+ assert engine.model is not None
42
+ assert engine.tokenizer is not None
43
+ mock_utils.assert_called_once()
44
+
45
+ def test_generate_full(mock_model, mock_tokenizer, mock_utils):
46
+ engine = InferenceEngine()
47
+
48
+ # Mock model output for generation loop
49
+ # We need to ensure the model returns logits of correct shape
50
+ # The loop calls model(generated_ids)
51
+
52
+ # Let's mock the actual model call inside generate
53
+ engine.model.config.torch_dtype = torch.float32
54
+
55
+ # We need to return a dict with logits
56
+ # Shape: (batch, seq_len, vocab_size)
57
+ engine.model.side_effect = lambda x: {'logits': torch.randn(1, x.shape[1], 50257)}
58
+
59
+ output = engine.generate_full("test prompt", max_new_tokens=5)
60
+ assert isinstance(output, str)
61
+ assert len(output) > 0
62
+
63
+ def test_generate_stream(mock_model, mock_tokenizer, mock_utils):
64
+ engine = InferenceEngine()
65
+ engine.model.config.torch_dtype = torch.float32
66
+ engine.model.side_effect = lambda x: {'logits': torch.randn(1, x.shape[1], 50257)}
67
+
68
+ generator = engine.generate("test prompt", max_new_tokens=5, stream=True)
69
+ tokens = list(generator)
70
+ assert len(tokens) == 5
71
+ assert all(isinstance(t, str) for t in tokens)
tests/test_model.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import torch
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ # Add project root to path
7
+ sys.path.append(str(Path(__file__).resolve().parent.parent))
8
+
9
+ from aetheris.config import AetherisConfig
10
+ from aetheris.model import HybridMambaMoE
11
+
12
+ class TestHybridMambaMoE(unittest.TestCase):
13
+ def setUp(self):
14
+ self.config = AetherisConfig(
15
+ vocab_size=100,
16
+ d_model=32,
17
+ n_layer=4,
18
+ num_experts=2,
19
+ top_k=1,
20
+ d_ff=64,
21
+ ssm_d_state=8,
22
+ ssm_expand=2,
23
+ max_seq_len=64
24
+ )
25
+ self.model = HybridMambaMoE(self.config)
26
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
27
+ self.model.to(self.device)
28
+
29
+ def test_forward_pass(self):
30
+ batch_size = 2
31
+ seq_len = 16
32
+ input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
33
+
34
+ output = self.model(input_ids)
35
+
36
+ self.assertIn('logits', output)
37
+ self.assertEqual(output['logits'].shape, (batch_size, seq_len, self.config.vocab_size))
38
+
39
+ def test_forward_pass_with_labels(self):
40
+ batch_size = 2
41
+ seq_len = 16
42
+ input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
43
+ labels = input_ids.clone()
44
+
45
+ output = self.model(input_ids, labels=labels)
46
+
47
+ self.assertIn('loss', output)
48
+ self.assertIn('ce_loss', output)
49
+ self.assertIn('aux_loss', output)
50
+ self.assertIn('logits', output)
51
+
52
+ self.assertTrue(output['loss'] > 0)
53
+
54
+ if __name__ == '__main__':
55
+ unittest.main()
tests/test_overflow.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import torch
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ # Add project root to path
7
+ sys.path.append(str(Path(__file__).resolve().parent.parent))
8
+
9
+ from aetheris.modules.expert import Expert
10
+ from aetheris.modules.moe import SparseMoELayer
11
+ from aetheris.config import AetherisConfig
12
+
13
+ class TestOverflow(unittest.TestCase):
14
+ def setUp(self):
15
+ self.config = AetherisConfig(
16
+ vocab_size=100,
17
+ d_model=128,
18
+ n_layer=2,
19
+ num_experts=2,
20
+ top_k=1,
21
+ d_ff=512, # Large enough to potentially cause issues
22
+ ssm_d_state=16,
23
+ ssm_expand=2,
24
+ max_seq_len=64
25
+ )
26
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
27
+
28
+ def test_expert_overflow_protection(self):
29
+ """Test if Expert handles large inputs without producing NaNs in float16"""
30
+ expert = Expert(self.config.d_model, self.config.d_ff).to(self.device)
31
+ # Manually cast weights to float16 to simulate mixed precision training environment
32
+ expert.half()
33
+
34
+ # Create a large input in float16 that would normally cause overflow in intermediate layers
35
+ # The limit of float16 is ~65504.
36
+ # If w1 projects this up, it can easily exceed that.
37
+ large_input = torch.ones(1, self.config.d_model, dtype=torch.float16).to(self.device) * 100.0
38
+
39
+ # Force weights to be large to guarantee overflow if protection isn't working
40
+ with torch.no_grad():
41
+ expert.w1.weight.fill_(10.0)
42
+ expert.w2.weight.fill_(0.1)
43
+
44
+ # 100 * 10 = 1000. Sum over d_model(128) -> 128000.
45
+ # This summation happens in the matrix multiplication.
46
+ # If the matmul internal accumulation is float16, it effectively overflows.
47
+
48
+ output = expert(large_input)
49
+
50
+ self.assertFalse(torch.isnan(output).any(), "Output contains NaNs")
51
+ self.assertFalse(torch.isinf(output).any(), "Output contains Infs")
52
+
53
+ def test_moe_accumulation_stability(self):
54
+ """Test if MoE layer handles accumulation in float32"""
55
+ moe = SparseMoELayer(self.config).to(self.device)
56
+ moe.half()
57
+
58
+ x = torch.randn(2, 10, self.config.d_model, dtype=torch.float16).to(self.device)
59
+
60
+ # Pass through
61
+ output, loss = moe(x)
62
+
63
+ self.assertFalse(torch.isnan(output).any(), "MoE Output contains NaNs")
64
+ self.assertEqual(output.dtype, torch.float16)
65
+
66
+ if __name__ == '__main__':
67
+ unittest.main()