Upload Aetheris model (Stage 2 best, 722M params, loss=2.73)
Browse files- README.md +47 -62
- aetheris/__init__.py +2 -0
- aetheris/api/schemas.py +92 -0
- aetheris/api/server.py +196 -0
- aetheris/cli/__init__.py +1 -0
- aetheris/cli/main.py +362 -0
- aetheris/config.py +58 -0
- aetheris/data.py +231 -0
- aetheris/inference.py +106 -0
- aetheris/model.py +104 -0
- aetheris/modules/__init__.py +3 -0
- aetheris/modules/expert.py +35 -0
- aetheris/modules/moe.py +83 -0
- aetheris/modules/ssm.py +119 -0
- aetheris/trainer/__init__.py +1 -0
- aetheris/trainer/trainer.py +176 -0
- aetheris/utils.py +55 -0
- config.yaml +17 -0
- pytorch_model.pt +3 -0
README.md
CHANGED
|
@@ -1,80 +1,65 @@
|
|
| 1 |
---
|
| 2 |
-
language:
|
| 3 |
-
- multilingual
|
| 4 |
-
- en
|
| 5 |
-
- es
|
| 6 |
-
- hi
|
| 7 |
-
- zh
|
| 8 |
-
- ar
|
| 9 |
-
- sw
|
| 10 |
-
- tr
|
| 11 |
-
- ja
|
| 12 |
-
- id
|
| 13 |
-
- te
|
| 14 |
license: apache-2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
tags:
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
- aya
|
| 22 |
-
library_name: aetheris
|
| 23 |
pipeline_tag: text-generation
|
| 24 |
---
|
| 25 |
|
| 26 |
# Aetheris — Hybrid Mamba-MoE Multilingual Model
|
| 27 |
|
| 28 |
-
**Aetheris** is a ~
|
| 29 |
[CohereLabs/tiny-aya-global](https://huggingface.co/CohereLabs/tiny-aya-global) (3.35B).
|
| 30 |
-
|
| 31 |
-
Built by [Wayy Research](https://github.com/Wayy-Research).
|
| 32 |
|
| 33 |
## Architecture
|
| 34 |
-
|
| 35 |
-
- **
|
| 36 |
-
- **Layers**: 24 (interleaved: even=SSM, odd=MoE)
|
| 37 |
- **Hidden dim**: 1024
|
| 38 |
-
- **Experts**: 4
|
| 39 |
-
- **
|
| 40 |
-
- **
|
| 41 |
-
- **Parameters**: ~800M
|
| 42 |
|
| 43 |
## Training
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
| Stage | Method | Data | Steps |
|
| 48 |
-
|-------|--------|------|-------|
|
| 49 |
-
| 1 | CKA-guided Layer Alignment | ClimbMix | 10,000 |
|
| 50 |
-
| 2 | KL Distillation (T=2.0, alpha=0.7) | ClimbMix | 20,000 |
|
| 51 |
-
| 3 | Supervised Fine-Tuning | aya_collection | 5,000 |
|
| 52 |
-
|
| 53 |
-
Key research findings applied:
|
| 54 |
-
- SSM 10x LR boost (compensates 27x gradient imbalance)
|
| 55 |
-
- SVD split for MoE expert initialization (CKA=0.097 diversity)
|
| 56 |
-
- Per-language KL tracking for multilingual equity
|
| 57 |
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
## Languages
|
| 66 |
-
|
| 67 |
-
Supports 70+ languages inherited from tiny-aya-global. Core evaluation
|
| 68 |
-
languages: English, Spanish, Hindi, Chinese, Arabic, Swahili, Turkish,
|
| 69 |
-
Japanese, Indonesian, Telugu.
|
| 70 |
-
|
| 71 |
-
## Citation
|
| 72 |
-
|
| 73 |
-
```bibtex
|
| 74 |
-
@misc{aetheris2026,
|
| 75 |
-
title={Aetheris: Hybrid Mamba-MoE Multilingual Model via Knowledge Distillation},
|
| 76 |
-
author={Wayy Research},
|
| 77 |
-
year={2026},
|
| 78 |
-
url={https://huggingface.co/wayyresearch/aetheris}
|
| 79 |
-
}
|
| 80 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
- es
|
| 6 |
+
- fr
|
| 7 |
+
- de
|
| 8 |
+
- zh
|
| 9 |
+
- ja
|
| 10 |
+
- ko
|
| 11 |
+
- ar
|
| 12 |
+
- hi
|
| 13 |
+
- tr
|
| 14 |
+
- sw
|
| 15 |
+
- id
|
| 16 |
+
- pt
|
| 17 |
+
- ru
|
| 18 |
tags:
|
| 19 |
+
- multilingual
|
| 20 |
+
- mamba
|
| 21 |
+
- moe
|
| 22 |
+
- distillation
|
| 23 |
+
- aya
|
|
|
|
|
|
|
| 24 |
pipeline_tag: text-generation
|
| 25 |
---
|
| 26 |
|
| 27 |
# Aetheris — Hybrid Mamba-MoE Multilingual Model
|
| 28 |
|
| 29 |
+
**Aetheris** is a ~720M parameter hybrid SSM-MoE language model distilled from
|
| 30 |
[CohereLabs/tiny-aya-global](https://huggingface.co/CohereLabs/tiny-aya-global) (3.35B).
|
| 31 |
+
It supports **67 languages** with 4.6x compression.
|
|
|
|
| 32 |
|
| 33 |
## Architecture
|
| 34 |
+
- **Type**: Hybrid Mamba-MoE (interleaved SSM + Sparse MoE layers)
|
| 35 |
+
- **Layers**: 24 (12 SSM + 12 MoE)
|
|
|
|
| 36 |
- **Hidden dim**: 1024
|
| 37 |
+
- **Experts**: 4 (top-1 routing)
|
| 38 |
+
- **Vocab**: 261,019 tokens (Aya tokenizer)
|
| 39 |
+
- **Parameters**: 722M
|
|
|
|
| 40 |
|
| 41 |
## Training
|
| 42 |
+
- **Stage 1**: CKA-guided layer alignment (10K steps)
|
| 43 |
+
- **Stage 2**: KL divergence distillation, T=2.0, alpha=0.7 (20K steps, best loss=2.73)
|
| 44 |
+
- **Stage 3**: SFT fine-tuning (pending)
|
| 45 |
+
- **Teacher**: CohereLabs/tiny-aya-global (3.35B)
|
| 46 |
+
- **Data**: ClimbMix (NVIDIA)
|
| 47 |
|
| 48 |
+
## Usage
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
+
```python
|
| 51 |
+
import torch, yaml, sys
|
| 52 |
+
sys.path.insert(0, ".")
|
| 53 |
+
from aetheris.config import AetherisConfig
|
| 54 |
+
from aetheris.model import HybridMambaMoE
|
| 55 |
|
| 56 |
+
config = AetherisConfig.from_yaml("config.yaml")
|
| 57 |
+
model = HybridMambaMoE(config)
|
| 58 |
+
sd = torch.load("pytorch_model.pt", map_location="cpu")
|
| 59 |
+
model.load_state_dict(sd)
|
| 60 |
+
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
```
|
| 62 |
+
|
| 63 |
+
## Wayy Research
|
| 64 |
+
*People for research, research for people.*
|
| 65 |
+
Buffalo, NY — Est. 2024
|
aetheris/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .model import HybridMambaMoE
|
| 2 |
+
from .config import AetherisConfig
|
aetheris/api/schemas.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Pydantic request/response schemas mirroring the OpenAI REST API surface."""
from typing import List, Optional, Union, Dict, Any
from pydantic import BaseModel, Field
import time

class ChatMessage(BaseModel):
    """A single chat turn: who said it and what was said."""
    role: str
    content: str

class ChatCompletionRequest(BaseModel):
    """Body of POST /v1/chat/completions (OpenAI-compatible)."""
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None

class ChatCompletionChoice(BaseModel):
    """One candidate answer inside a non-streaming chat response."""
    index: int
    message: ChatMessage
    finish_reason: Optional[str] = None

class ChatCompletionResponse(BaseModel):
    """Full (non-streaming) chat completion payload."""
    id: str
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionChoice]
    usage: Optional[Dict[str, int]] = None

class ChatCompletionChunkDelta(BaseModel):
    """Incremental piece of a streamed message (role and/or content)."""
    role: Optional[str] = None
    content: Optional[str] = None

class ChatCompletionChunkChoice(BaseModel):
    """One choice entry inside a streamed chunk."""
    index: int
    delta: ChatCompletionChunkDelta
    finish_reason: Optional[str] = None

class ChatCompletionChunk(BaseModel):
    """A single server-sent event in a streaming chat response."""
    id: str
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionChunkChoice]

class CompletionRequest(BaseModel):
    """Body of POST /v1/completions (legacy text-completion API)."""
    model: str
    prompt: Union[str, List[str]]
    suffix: Optional[str] = None
    max_tokens: Optional[int] = 16
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    best_of: Optional[int] = 1
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None

class CompletionChoice(BaseModel):
    """One candidate text inside a completion response."""
    text: str
    index: int
    logprobs: Optional[Any] = None
    finish_reason: Optional[str] = None

class CompletionResponse(BaseModel):
    """Full text-completion payload."""
    id: str
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionChoice]
    usage: Optional[Dict[str, int]] = None

class ModelCard(BaseModel):
    """Metadata entry for GET /v1/models."""
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "aetheris"

class ModelList(BaseModel):
    """Envelope returned by GET /v1/models."""
    object: str = "list"
    data: List[ModelCard]
|
aetheris/api/server.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import time
import uuid
import json
import asyncio
from typing import AsyncGenerator, Optional
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse
from aetheris.api.schemas import (
    ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk,
    ChatCompletionChoice, ChatMessage, ChatCompletionChunkChoice, ChatCompletionChunkDelta,
    CompletionRequest, CompletionResponse, CompletionChoice,
    ModelList, ModelCard
)
from aetheris.inference import InferenceEngine

app = FastAPI(title="Aetheris API", version="0.1.0")

# NOTE(review): CORS is wide open (all origins/methods/headers with
# credentials) — acceptable for a demo, tighten before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global engine instance, created lazily on first use.
# Fix: the annotation was `InferenceEngine` but the initial value is None,
# so the correct type is Optional[InferenceEngine].
engine: Optional[InferenceEngine] = None

def get_engine() -> InferenceEngine:
    """Return the process-wide inference engine, constructing it on demand."""
    global engine
    if engine is None:
        # Defaults, ideally loaded from config/env
        engine = InferenceEngine()
    return engine
|
| 36 |
+
|
| 37 |
+
@app.on_event("startup")
async def startup_event():
    """Warm the engine at boot so the first request does not pay load time."""
    get_engine()

@app.get("/")
async def root():
    """Liveness probe pointing callers at the main inference endpoint."""
    return {"status": "running", "message": "Aetheris API is active. Use /v1/chat/completions for inference."}

@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """OpenAI-compatible model listing (a single static entry)."""
    return ModelList(data=[ModelCard(id="aetheris-hybrid-mamba-moe")])
|
| 48 |
+
|
| 49 |
+
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat endpoint, streaming (SSE) or non-streaming.

    The prompt is a naive role-prefixed transcript (no chat template);
    token usage numbers are character counts, i.e. approximations.
    """
    engine = get_engine()

    # Simple prompt construction from messages.
    prompt = "".join(f"{msg.role}: {msg.content}\n" for msg in request.messages)
    prompt += "assistant: "

    request_id = f"chatcmpl-{uuid.uuid4()}"
    created_time = int(time.time())

    if not request.stream:
        generated_text = engine.generate_full(
            prompt=prompt,
            max_new_tokens=request.max_tokens or 100,
            temperature=request.temperature,
            top_p=request.top_p,
            repetition_penalty=1.0 + request.frequency_penalty
        )
        return ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=request.model,
            choices=[ChatCompletionChoice(
                index=0,
                message=ChatMessage(role="assistant", content=generated_text),
                finish_reason="stop"
            )],
            usage={"prompt_tokens": len(prompt), "completion_tokens": len(generated_text), "total_tokens": len(prompt) + len(generated_text)}  # Approximated
        )

    def make_chunk(delta, finish_reason=None):
        # Serialize one SSE chunk with the shared id/model/timestamp.
        return json.dumps(ChatCompletionChunk(
            id=request_id,
            created=created_time,
            model=request.model,
            choices=[ChatCompletionChunkChoice(
                index=0,
                delta=delta,
                finish_reason=finish_reason
            )]
        ).model_dump())

    async def event_generator():
        # First chunk announces the assistant role, per the OpenAI protocol.
        yield make_chunk(ChatCompletionChunkDelta(role="assistant"))

        # Offload synchronous generation to a thread to avoid blocking the
        # event loop; tokens flow back through an asyncio queue.
        queue = asyncio.Queue()
        loop = asyncio.get_running_loop()
        import threading
        stop_event = threading.Event()

        def producer():
            try:
                for token in engine.generate(
                    prompt=prompt,
                    max_new_tokens=request.max_tokens or 100,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    repetition_penalty=1.0 + request.frequency_penalty,
                    stream=True
                ):
                    if stop_event.is_set():
                        break
                    # Hand the token to the main loop's queue thread-safely.
                    asyncio.run_coroutine_threadsafe(queue.put(token), loop)
            except Exception as e:
                print(f"Generation error: {e}")
            finally:
                # None is the end-of-stream sentinel.
                asyncio.run_coroutine_threadsafe(queue.put(None), loop)

        worker = threading.Thread(target=producer, daemon=True)
        worker.start()

        try:
            while True:
                token = await queue.get()
                if token is None:
                    break
                yield make_chunk(ChatCompletionChunkDelta(content=token))

            yield make_chunk(ChatCompletionChunkDelta(), finish_reason="stop")
            yield "[DONE]"
        finally:
            # Client disconnects land here; tell the producer to stop.
            stop_event.set()

    return EventSourceResponse(event_generator())
|
| 159 |
+
|
| 160 |
+
@app.post("/v1/completions")
async def completions(request: CompletionRequest):
    """OpenAI-compatible legacy text-completion endpoint.

    Only the first prompt of a list is served, and streaming is not
    implemented — a `stream=True` request falls through to a normal
    non-streaming response.
    """
    engine = get_engine()

    prompt = request.prompt
    if isinstance(prompt, list):
        prompt = prompt[0]  # Handle single prompt for now

    request_id = f"cmpl-{uuid.uuid4()}"
    created_time = int(time.time())

    if request.stream:
        # Streaming for completions not fully implemented to match OpenAI exactly in this demo,
        # but logic is similar to chat.
        # For simplicity, returning non-streaming for now or basic stream.
        pass  # TODO: Implement streaming for completions

    generated_text = engine.generate_full(
        prompt=prompt,
        max_new_tokens=request.max_tokens or 16,
        temperature=request.temperature,
        top_p=request.top_p,
        repetition_penalty=1.0 + request.frequency_penalty
    )

    # Usage numbers are character counts, not token counts.
    usage = {
        "prompt_tokens": len(prompt),
        "completion_tokens": len(generated_text),
        "total_tokens": len(prompt) + len(generated_text),
    }

    return CompletionResponse(
        id=request_id,
        created=created_time,
        model=request.model,
        choices=[CompletionChoice(
            text=generated_text,
            index=0,
            logprobs=None,
            finish_reason="length"  # or stop
        )],
        usage=usage
    )
|
aetheris/cli/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
aetheris/cli/main.py
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
import sys
import torch
import os
import math
import torch.nn.functional as F
from aetheris.config import AetherisConfig
from aetheris.model import HybridMambaMoE
from aetheris.data import create_streaming_loader, get_tokenizer
from aetheris.utils import load_latest_checkpoint, calculate_model_stats
from aetheris.trainer import Trainer

def train_command(args):
    """Run the two-stage training pipeline (pre-training, then SFT).

    Honors --resume (restores optimizer/scaler state) vs. fine-tune mode
    (model weights only, step counter reset to 0).
    """
    print(f"\n{'='*70}")
    print(f"Aetheris Training")
    print(f"Config: {args.config}")

    if args.hf_token:
        print(f"Using Hugging Face token: {args.hf_token[:10]}...")
        from huggingface_hub import login
        login(token=args.hf_token)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        # TF32 matmuls + cudnn autotuning trade a little precision for speed.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.cuda.empty_cache()

    config = AetherisConfig.from_yaml(args.config)

    # Add special tokens if using VoxLex config (vocab_size > 50257)
    add_special = config.vocab_size > 50257
    tokenizer = get_tokenizer(add_special_tokens=add_special)

    print(f"Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
    print(f"Model Size: d_model={config.d_model}, layers={config.n_layer}")
    print(f"Vocab Size: {config.vocab_size} | Max Seq Len: {config.max_seq_len}")
    print(f"{'='*70}\n")

    model = HybridMambaMoE(config).to(device)

    # Apply weight initialization BEFORE resize (resize copies old weights)
    print("Applying proper weight initialization...")
    model.apply(model._init_weights)

    # Resize embeddings if tokenizer has special tokens (AFTER init)
    if len(tokenizer) > model.config.vocab_size:
        print(f"Resizing embeddings: {model.config.vocab_size} → {len(tokenizer)}")
        model.resize_token_embeddings(len(tokenizer))
    elif len(tokenizer) < model.config.vocab_size:
        print(f"Resizing embeddings: {model.config.vocab_size} (config) with {len(tokenizer)} tokenizer tokens")
        model.resize_token_embeddings(config.vocab_size)

    stats = calculate_model_stats(model)
    print(f"Total Parameters: {stats['total_params']:,}")
    print(f"Trainable Parameters: {stats['trainable_params']:,}")

    # Use lower learning rate for stability
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr or 1e-4,
                                  weight_decay=0.01, betas=(0.9, 0.95), eps=1e-8)

    # PyTorch 2.1 uses torch.cuda.amp.GradScaler; 2.3+ uses torch.amp.GradScaler
    try:
        scaler = torch.amp.GradScaler('cuda' if device.type == 'cuda' else 'cpu', init_scale=2**10)
    except (TypeError, AttributeError):
        scaler = torch.cuda.amp.GradScaler(init_scale=2**10)

    if args.resume:
        # Resume: load model + optimizer + scaler state
        start_step, current_stage = load_latest_checkpoint(
            model, optimizer, scaler, device, args.checkpoint_dir, args.checkpoint_name)
    else:
        # Fine-tune: load model weights only, fresh optimizer
        start_step, current_stage = load_latest_checkpoint(
            model, None, None, device, args.checkpoint_dir, args.checkpoint_name)
        if start_step > 0:
            print(f" Loaded base weights (was at step {start_step}), resetting to step 0 for fine-tuning")
            start_step = 0
            current_stage = "Pre-Training"

    if args.compile:
        print("Compiling model with torch.compile()...")
        model = torch.compile(model)

    trainer = Trainer(model, optimizer, scaler, config, device,
                      args.checkpoint_dir, grad_accum_steps=args.accumulate_grad_batches)

    # Resolve dataset names
    pretrain_dataset = args.pretrain_dataset or "cerebras/SlimPajama-627B"
    sft_dataset = args.sft_dataset or "OpenAssistant/oasst1"

    # --- STAGE 1: PRE-TRAINING ---
    if current_stage == "Pre-Training" or start_step == 0:
        print(f"\n=== STAGE 1: Pre-Training on {pretrain_dataset} ===")

        # The scheduler counts optimizer steps, so divide out grad accumulation.
        warmup = args.warmup_steps or 1000
        trainer.scheduler = _build_scheduler(
            optimizer,
            max(1, args.pretrain_steps // args.accumulate_grad_batches),
            max(1, warmup // args.accumulate_grad_batches))

        pt_loader = create_streaming_loader(pretrain_dataset, "train",
                                            tokenizer, config, args.batch_size, mode="pretrain",
                                            hf_token=args.hf_token, start_step=start_step)
        pt_val_loader = create_streaming_loader(pretrain_dataset, "validation",
                                                tokenizer, config, args.batch_size, mode="pretrain",
                                                hf_token=args.hf_token)

        start_step = trainer.train_epoch(pt_loader, total_steps=args.pretrain_steps,
                                         start_step=start_step, stage_name="Pre-Training",
                                         val_loader=pt_val_loader)
        current_stage = "SFT"
        start_step = 0

    # --- STAGE 2: SFT ---
    print(f"\n=== STAGE 2: SFT on {sft_dataset} ===")
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.sft_lr or 5e-5

    # Fresh scheduler for the SFT phase, again in optimizer-step units.
    sft_warmup = args.sft_warmup_steps or 200
    trainer.scheduler = _build_scheduler(
        optimizer,
        max(1, args.sft_steps // args.accumulate_grad_batches),
        max(1, sft_warmup // args.accumulate_grad_batches))

    sft_loader = create_streaming_loader(sft_dataset, "train",
                                         tokenizer, config, args.batch_size, mode="sft",
                                         hf_token=args.hf_token, start_step=start_step)
    sft_val_loader = create_streaming_loader(sft_dataset, "validation",
                                             tokenizer, config, args.batch_size, mode="sft",
                                             hf_token=args.hf_token)

    trainer.train_epoch(sft_loader, total_steps=args.sft_steps,
                        start_step=start_step, stage_name="SFT",
                        val_loader=sft_val_loader)

    print("\nTraining Complete!")
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _build_scheduler(optimizer, total_steps, warmup_steps):
|
| 145 |
+
"""Cosine annealing with linear warmup. LR multiplier: 0→1 (warmup) → 0.1 (cosine)."""
|
| 146 |
+
def lr_lambda(current_step):
|
| 147 |
+
if current_step < warmup_steps:
|
| 148 |
+
return float(current_step) / float(max(1, warmup_steps))
|
| 149 |
+
progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
|
| 150 |
+
return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
|
| 151 |
+
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@torch.no_grad()
def generate_command(args):
    """Load a checkpoint and sample up to `args.max_new_tokens` tokens from
    `args.prompt`, streaming each decoded token to stdout.

    Sampling order per step: repetition penalty → temperature → nucleus
    (top-p) OR top-k filtering (top-p takes precedence when < 1.0) →
    multinomial sample. Stops early on EOS.

    Fix: the forward-pass code was duplicated across the autocast and
    non-autocast branches, and the dtype decision was recomputed inside the
    loop even though it is loop-invariant. Both are deduplicated/hoisted
    here with no behavior change (nullcontext is a no-op for fp32).
    """
    from contextlib import nullcontext

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config = AetherisConfig.from_yaml(args.config)

    add_special = config.vocab_size > 50257
    tokenizer = get_tokenizer(add_special_tokens=add_special)

    model = HybridMambaMoE(config).to(device).to(config.torch_dtype)

    # Resize if needed
    if len(tokenizer) != config.vocab_size:
        model.resize_token_embeddings(config.vocab_size)

    load_latest_checkpoint(model, None, None, device, args.checkpoint_dir, args.checkpoint_name)
    model.eval()

    prompt = args.prompt
    max_new_tokens = args.max_new_tokens
    temperature = args.temperature
    top_k = args.top_k
    top_p = args.top_p
    repetition_penalty = args.repetition_penalty

    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    generated_ids = input_ids.clone()
    history_ids = set(input_ids[0].tolist())

    print("-" * 50)
    print(f"Prompt: {prompt}")
    print("Generated Continuation:")

    # Loop-invariant: fp32 runs without autocast; half/bf16 runs each
    # forward pass under amp autocast.
    use_autocast = config.torch_dtype != torch.float32
    autocast_device = 'cuda' if device.type == 'cuda' else 'cpu'

    def _forward_ctx():
        # Fresh context manager per step; no-op when autocast is disabled.
        if use_autocast:
            return torch.amp.autocast(autocast_device, dtype=model.config.torch_dtype)
        return nullcontext()

    for step in range(max_new_tokens):
        with _forward_ctx():
            outputs = model(generated_ids)
            logits = outputs['logits']
            next_token_logits = logits[:, -1, :]

        # Repetition penalty: dampen logits of tokens already emitted.
        for token_id in history_ids:
            if token_id < next_token_logits.size(-1):
                logit = next_token_logits[0, token_id].item()
                if logit > 0:
                    next_token_logits[0, token_id] = logit / repetition_penalty
                else:
                    next_token_logits[0, token_id] = logit * repetition_penalty

        # Temperature
        if temperature > 0:
            next_token_logits = next_token_logits / temperature

        # Top-p / Top-k
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            next_token_logits.scatter_(1, indices_to_remove.unsqueeze(0), float('-inf'))
        elif top_k > 0:
            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
            next_token_logits = torch.full_like(next_token_logits, float('-inf'))
            next_token_logits.scatter_(1, top_k_indices, top_k_logits)

        # Sample
        next_token_probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(next_token_probs, num_samples=1)
        next_token_item = next_token.item()

        if next_token_item == tokenizer.eos_token_id:
            break

        generated_ids = torch.cat([generated_ids, next_token], dim=-1)
        history_ids.add(next_token_item)

        new_token_text = tokenizer.decode(next_token.squeeze().tolist(), skip_special_tokens=True)
        print(new_token_text, end="", flush=True)

    print("\n" + "-" * 50)
|
| 243 |
+
|
| 244 |
+
def info_command(args):
    """Print a parameter-count breakdown (dense vs. expert) for the configured model.

    Builds the model from the YAML config and reports total checkpoint size,
    always-active (dense) parameters, expert-only parameters, and the
    per-token active compute load (dense + top-k experts).
    """
    config = AetherisConfig.from_yaml(args.config)
    model = HybridMambaMoE(config)

    total = 0
    dense = 0
    expert_only = 0
    for param_name, tensor in model.named_parameters():
        count = tensor.numel()
        total += count
        # Expert weights live under an 'experts' submodule; everything else
        # (embeddings, SSM blocks, routers, norms) is always active.
        if 'experts' in param_name:
            expert_only += count
        else:
            dense += count

    per_expert = expert_only / config.num_experts if config.num_experts > 0 else 0
    # Each token routes through top_k experts per MoE layer, so the active
    # load is the dense core plus top_k expert-sized slices.
    active = dense + per_expert * config.top_k

    def fmt(count):
        # Report in millions with two decimals, e.g. "722.35M".
        return f"{count / 1_000_000:.2f}M"

    print("=" * 50)
    print("Hybrid Mamba-MoE Model Parameter Analysis")
    print("=" * 50)
    print(f"Total Model Layers (N_Layer): {config.n_layer}")
    print(f"MoE Experts per Layer: {config.num_experts}")
    print(f"Active Experts (Top-K): {config.top_k}")
    print("-" * 50)
    print(f"Total Parameters (Checkpoint Size): {fmt(total)}")
    print(f"Dense (Always Active) Parameters: {fmt(dense)}")
    print(f"Expert-Only Parameters: {fmt(expert_only)}")
    print("-" * 50)
    print(f"**Active Parameters (Per-Token Compute Load): {fmt(active)}**")
    print(" (This is the 'Dense' parameters + the K active expert parameters)")
    print("=" * 50)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def main():
    """Entry point for the Aetheris CLI.

    Builds the argparse subcommand tree (train / generate / serve / info)
    and dispatches to the matching handler; prints help when no command is
    given.
    """
    parser = argparse.ArgumentParser(description="Aetheris CLI")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Train Command
    train_parser = subparsers.add_parser("train", help="Train the model")
    train_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
    train_parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory to save checkpoints")
    train_parser.add_argument("--hf_token", type=str, default=os.environ.get("HF_TOKEN"), help="HuggingFace Token")
    train_parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
    train_parser.add_argument("--pretrain_steps", type=int, default=50000, help="Number of pretraining steps")
    train_parser.add_argument("--sft_steps", type=int, default=1000, help="Number of SFT steps")
    train_parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name to load from")
    train_parser.add_argument("--compile", action="store_true", help="Compile model with torch.compile for speed")
    train_parser.add_argument("--accumulate_grad_batches", type=int, default=1, help="Gradient accumulation steps")
    # Custom dataset args
    train_parser.add_argument("--pretrain-dataset", type=str, default=None,
                              help="Pretraining dataset: local JSONL path or HuggingFace dataset name")
    train_parser.add_argument("--sft-dataset", type=str, default=None,
                              help="SFT dataset: local JSONL path or HuggingFace dataset name")
    # Learning rate args
    train_parser.add_argument("--lr", type=float, default=None, help="Peak learning rate for pretraining (default: 1e-4)")
    train_parser.add_argument("--sft-lr", type=float, default=None, help="Peak learning rate for SFT (default: 5e-5)")
    train_parser.add_argument("--warmup-steps", type=int, default=None, help="Warmup steps for pretraining (default: 1000)")
    train_parser.add_argument("--sft-warmup-steps", type=int, default=None, help="Warmup steps for SFT (default: 200)")
    train_parser.add_argument("--resume", action="store_true", help="Resume from checkpoint step (default: start from 0)")

    # Generate Command
    gen_parser = subparsers.add_parser("generate", help="Generate text")
    gen_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
    gen_parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory with checkpoints")
    gen_parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name")
    gen_parser.add_argument("--prompt", type=str, default="The quick brown fox", help="Prompt for generation")
    gen_parser.add_argument("--max_new_tokens", type=int, default=100, help="Max new tokens to generate")
    gen_parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
    gen_parser.add_argument("--top_k", type=int, default=0, help="Top-k sampling")
    gen_parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
    gen_parser.add_argument("--repetition_penalty", type=float, default=3.0, help="Repetition penalty")

    # Serve Command
    serve_parser = subparsers.add_parser("serve", help="Start the API server")
    serve_parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind")
    serve_parser.add_argument("--port", type=int, default=8000, help="Port to bind")
    serve_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
    serve_parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory with checkpoints")
    serve_parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name")

    # Info Command
    info_parser = subparsers.add_parser("info", help="Show model info")
    info_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")

    args = parser.parse_args()

    if args.command == "train":
        train_command(args)
    elif args.command == "generate":
        generate_command(args)
    elif args.command == "serve":
        import uvicorn
        import aetheris.api.server
        from aetheris.api.server import app
        from aetheris.inference import InferenceEngine

        # BUGFIX: the previous code called get_engine() here and discarded the
        # result before the module-level engine was configured (a dead,
        # premature call). The engine is now simply constructed once and
        # installed on the server module for request handlers to use.
        aetheris.api.server.engine = InferenceEngine(
            config_path=args.config,
            checkpoint_dir=args.checkpoint_dir,
            checkpoint_name=args.checkpoint_name
        )

        uvicorn.run(app, host=args.host, port=args.port)

    elif args.command == "info":
        info_command(args)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()
|
aetheris/config.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
import yaml
|
| 3 |
+
import torch
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
@dataclass
class AetherisConfig:
    """Hyperparameters for the hybrid Mamba-MoE model.

    Loadable from / dumpable to YAML. Derived fields (``d_inner``, ``d_ff``)
    are filled in by ``__post_init__`` when left as ``None``.
    """
    # Model dimensions
    vocab_size: int = 50257
    d_model: int = 768
    n_layer: int = 24
    num_experts: int = 4
    top_k: int = 1
    # FFN width. FIX: annotated Optional so the __post_init__ None-check is
    # reachable — pass d_ff=None to derive d_model * 3. The default (2304)
    # is kept for backward compatibility (it equals 768 * 3).
    d_ff: Optional[int] = 2304

    # SSM parameters
    ssm_d_state: int = 16
    ssm_expand: int = 2
    d_inner: Optional[int] = None  # derived: d_model * ssm_expand when None

    # Training parameters
    load_balancing_coef: float = 1e-2
    router_z_loss_coef: float = 1e-3
    max_seq_len: int = 512
    dtype: str = "float16"  # one of "float16", "float32", "bfloat16"

    # Optimization settings
    use_cpu_offload: bool = False
    gradient_checkpointing: bool = True
    checkpoint_ssm_layers: bool = True
    use_flash_attention: bool = False

    def __post_init__(self):
        # Fill in derived dimensions that were left unspecified.
        if self.d_inner is None:
            self.d_inner = self.d_model * self.ssm_expand
        if self.d_ff is None:
            self.d_ff = self.d_model * 3

    @property
    def torch_dtype(self):
        """Map the string ``dtype`` field to the corresponding torch dtype.

        Raises:
            ValueError: if ``dtype`` is not one of the supported names.
        """
        if self.dtype == "float16":
            return torch.float16
        elif self.dtype == "float32":
            return torch.float32
        elif self.dtype == "bfloat16":
            return torch.bfloat16
        else:
            raise ValueError(f"Unsupported dtype: {self.dtype}")

    @classmethod
    def from_yaml(cls, path: str):
        """Construct a config from a YAML mapping of field names to values."""
        with open(path, 'r') as f:
            config_dict = yaml.safe_load(f)
        return cls(**config_dict)

    def to_yaml(self, path: str):
        """Write all fields (including derived ones) to a YAML file."""
        with open(path, 'w') as f:
            yaml.dump(self.__dict__, f)
|
aetheris/data.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import DataLoader, IterableDataset
|
| 3 |
+
from transformers import AutoTokenizer
|
| 4 |
+
from datasets import load_dataset
|
| 5 |
+
import json
|
| 6 |
+
import random
|
| 7 |
+
from typing import Dict, Iterator, List, Optional
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
# Domain-specific special tokens registered on the tokenizer when
# get_tokenizer(add_special_tokens=True) is used: tool-calling delimiters
# and legal-citation markers expected by the SFT data format.
VOXLEX_SPECIAL_TOKENS = [
    "<tool_call>", "</tool_call>",
    "<tool_result>", "</tool_result>",
    "<legal_cite>", "</legal_cite>",
]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_tokenizer(model_name: str = "gpt2", add_special_tokens: bool = False):
    """Load a pretrained HF tokenizer, ensuring a pad token exists and
    (optionally) registering the Voxlex special tokens.

    Args:
        model_name: HuggingFace model id to load the tokenizer from.
        add_special_tokens: when True, add VOXLEX_SPECIAL_TOKENS to the vocab.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    # GPT-2 ships without a pad token; reuse EOS so batch padding works.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    if add_special_tokens:
        added = tok.add_special_tokens(
            {"additional_special_tokens": VOXLEX_SPECIAL_TOKENS}
        )
        if added > 0:
            print(f" Added {added} special tokens → vocab_size={len(tok)}")
    return tok
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class StreamingDataset(IterableDataset):
    """Streams (input_ids, labels) pairs from an underlying (possibly
    infinite) dataset, with a small shuffle buffer and optional sample
    skipping to support resuming mid-epoch.

    mode="pretrain" puts loss on all real tokens; any other mode is treated
    as SFT and masks loss to assistant-turn tokens only.
    """
    def __init__(self, dataset, tokenizer, max_seq_len, mode="pretrain", buffer_size=100, skip_samples=0):
        self.dataset = dataset            # source iterable of raw examples
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len    # fixed output length (pad/truncate)
        self.mode = mode                  # "pretrain" or SFT-style masking
        self.buffer_size = buffer_size    # shuffle-buffer capacity
        self.skip_samples = skip_samples  # samples to drop before yielding (resume)

    def _find_assistant_spans(self, text: str) -> List[tuple]:
        """Find character spans of assistant responses in SFT text."""
        spans = []
        search_from = 0
        while True:
            start = text.find("<|assistant|>", search_from)
            if start == -1:
                break
            content_start = start + len("<|assistant|>")
            # End at next role tag or end of text
            end = len(text)
            for tag in ["<|user|>", "<|system|>", "<|tool|>", "<|endoftext|>"]:
                pos = text.find(tag, content_start)
                if pos != -1:
                    end = min(end, pos)
            spans.append((content_start, end))
            search_from = end
        return spans

    def _prepare_sft_example(self, example):
        """Prepare SFT example with label masking — loss only on assistant tokens."""
        # Accept either a chat-style 'messages' list or a raw 'text' field.
        if 'messages' in example:
            # Build text with role tags
            text = ""
            for msg in example['messages']:
                role = msg.get('role', '')
                content = msg.get('content', '')
                text += f"<|{role}|>{content}"
            text += self.tokenizer.eos_token
        elif 'text' in example:
            text = example['text']
        else:
            return None

        # Drop trivially short examples.
        if len(text) < 10:
            return None

        # Pre-truncate to avoid slow tokenization of very long texts
        max_chars = self.max_seq_len * 5
        if len(text) > max_chars:
            text = text[:max_chars]

        enc = self.tokenizer(text, truncation=True, max_length=self.max_seq_len,
                             return_tensors="pt")
        input_ids = enc['input_ids'][0]

        if len(input_ids) < 2:
            return None

        # Build labels: -100 for non-assistant tokens
        labels = torch.full_like(input_ids, -100)
        assistant_spans = self._find_assistant_spans(text)

        for char_start, char_end in assistant_spans:
            # Map character offsets to token positions
            # NOTE(review): token_to_chars requires a fast (Rust) tokenizer —
            # confirm get_tokenizer() always yields one, else this raises.
            in_span = False
            for tok_idx in range(len(input_ids)):
                token_span = enc.token_to_chars(0, tok_idx)
                if token_span is None:
                    # Special token (e.g. <tool_call>) — include if neighbors are in span
                    if in_span:
                        labels[tok_idx] = input_ids[tok_idx]
                    continue
                tok_start, tok_end = token_span
                # Token overlaps with assistant span
                if tok_end > char_start and tok_start < char_end:
                    labels[tok_idx] = input_ids[tok_idx]
                    in_span = True
                else:
                    in_span = False

        # Also train on eos token at the end
        if input_ids[-1] == self.tokenizer.eos_token_id:
            labels[-1] = input_ids[-1]

        # Pad to max_seq_len (pad positions get label -100, i.e. no loss).
        if len(input_ids) < self.max_seq_len:
            pad_len = self.max_seq_len - len(input_ids)
            input_ids = torch.cat([
                input_ids,
                torch.full((pad_len,), self.tokenizer.pad_token_id, dtype=torch.long)
            ])
            labels = torch.cat([
                labels,
                torch.full((pad_len,), -100, dtype=torch.long)
            ])

        return input_ids, labels

    def _prepare_pretrain_example(self, example):
        """Prepare pretraining example — loss on all non-pad tokens."""
        text = example.get('text', '')
        if len(text) < 10:
            return None

        # Pre-truncate text to avoid tokenizing 100K+ char documents
        # GPT-2 averages ~4 chars per token; use 5x max_seq_len as safe limit
        max_chars = self.max_seq_len * 5
        if len(text) > max_chars:
            text = text[:max_chars]

        enc = self.tokenizer(text, truncation=True, max_length=self.max_seq_len,
                             return_tensors="pt")
        input_ids = enc['input_ids'][0]

        if len(input_ids) < 2:
            return None

        # Labels are a copy taken BEFORE padding, so pad positions below
        # end up as -100 and contribute no loss.
        labels = input_ids.clone()

        if len(input_ids) < self.max_seq_len:
            pad_len = self.max_seq_len - len(input_ids)
            input_ids = torch.cat([
                input_ids,
                torch.full((pad_len,), self.tokenizer.pad_token_id, dtype=torch.long)
            ])
            labels = torch.cat([
                labels,
                torch.full((pad_len,), -100, dtype=torch.long)
            ])

        return input_ids, labels

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        # Shuffle-buffer iteration: accumulate prepared examples, then pop a
        # random half of the buffer at a time. skip_samples consumes popped
        # items without yielding them (used to fast-forward on resume).
        iterator = iter(self.dataset)
        buffer = []

        for example in iterator:
            if self.mode == "pretrain":
                result = self._prepare_pretrain_example(example)
            else:
                result = self._prepare_sft_example(example)

            if result is None:
                continue

            buffer.append(result)

            if len(buffer) >= self.buffer_size:
                random.shuffle(buffer)
                for _ in range(self.buffer_size // 2):
                    item = buffer.pop()
                    if self.skip_samples > 0:
                        self.skip_samples -= 1
                        continue
                    yield item

        # Yield remaining
        random.shuffle(buffer)
        while buffer:
            item = buffer.pop()
            if self.skip_samples > 0:
                self.skip_samples -= 1
                continue
            yield item
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def _load_jsonl_dataset(path: str):
    """Stream a local JSONL file as a HuggingFace IterableDataset, one JSON
    object per line, without materializing the file in memory."""
    from datasets import IterableDataset

    def _records():
        with open(path, 'r') as fh:
            for raw in fh:
                stripped = raw.strip()
                # Tolerate blank lines instead of failing on them.
                if not stripped:
                    continue
                yield json.loads(stripped)

    return IterableDataset.from_generator(_records)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def create_streaming_loader(dataset_name, split, tokenizer, config, batch_size,
                            mode="pretrain", hf_token=None, start_step=0):
    """Build a streaming DataLoader over a local JSONL file or a HuggingFace
    hub dataset.

    Args:
        dataset_name: path to a ``.jsonl`` file, or a hub dataset name.
        start_step: training step to resume from; the loader fast-forwards
            ``start_step * batch_size`` samples before yielding.
    """
    # Local JSONL takes precedence over the hub when the path exists.
    is_local_jsonl = os.path.isfile(dataset_name) and dataset_name.endswith('.jsonl')
    if is_local_jsonl:
        print(f" Loading local dataset: {dataset_name}")
        source = _load_jsonl_dataset(dataset_name)
    else:
        source = load_dataset(dataset_name, split=split, streaming=True,
                              trust_remote_code=True, token=hf_token)

    # Calculate samples to skip: start_step * batch_size
    to_skip = start_step * batch_size
    if to_skip > 0:
        print(f" [Loader] Resuming: Fast-forwarding dataset by {to_skip} samples...")

    stream = StreamingDataset(source, tokenizer, config.max_seq_len,
                              mode=mode, skip_samples=to_skip)

    # num_workers=0 avoids 4x data duplication with IterableDataset
    # (each worker iterates the full dataset without sharding logic)
    return DataLoader(stream, batch_size=batch_size, pin_memory=True,
                      num_workers=0)
|
aetheris/inference.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from typing import Optional, List, Generator
|
| 4 |
+
from aetheris.config import AetherisConfig
|
| 5 |
+
from aetheris.model import HybridMambaMoE
|
| 6 |
+
from aetheris.data import get_tokenizer
|
| 7 |
+
from aetheris.utils import load_latest_checkpoint
|
| 8 |
+
|
| 9 |
+
class InferenceEngine:
    """Loads a trained HybridMambaMoE checkpoint and exposes sampling-based
    text generation (streaming token-by-token or as one string).

    Sampling controls: repetition penalty (CTRL-style), temperature,
    nucleus (top-p) and top-k filtering.
    """

    def __init__(self, config_path: str = "configs/default.yaml", checkpoint_dir: str = "checkpoints", checkpoint_name: str = "checkpoint_current.pth", device: str = None):
        self.device = torch.device(device if device else ('cuda' if torch.cuda.is_available() else 'cpu'))
        self.config = AetherisConfig.from_yaml(config_path)
        self.tokenizer = get_tokenizer()

        self.model = HybridMambaMoE(self.config).to(self.device).to(self.config.torch_dtype)

        # load_latest_checkpoint expects optimizer and scaler slots; for
        # inference-only loading, None is fine.
        load_latest_checkpoint(self.model, None, None, self.device, checkpoint_dir, checkpoint_name)
        self.model.eval()

    def generate(self,
                 prompt: str,
                 max_new_tokens: int = 100,
                 temperature: float = 0.8,
                 top_k: int = 0,
                 top_p: float = 0.9,
                 repetition_penalty: float = 1.0,
                 stream: bool = False) -> Generator[str, None, None] | str:
        """Generate a continuation of ``prompt``.

        Returns a generator of decoded token strings when ``stream=True``,
        otherwise the full generated text. Stops early at EOS.

        Note: each step re-runs the model on the full sequence (no KV/state
        cache), so cost grows with sequence length.
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
        generated_ids = input_ids.clone()
        # Token ids seen so far (prompt + generated) for the repetition penalty.
        history_ids = set(input_ids[0].tolist())

        def token_generator():
            nonlocal generated_ids
            # BUGFIX: run the whole loop under no_grad. The original tracked
            # autograd state for every forward pass, so memory grew with each
            # generated token; outputs are unchanged.
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    # Autocast only when the model runs in reduced precision.
                    use_autocast = self.config.torch_dtype != torch.float32
                    if use_autocast:
                        with torch.amp.autocast('cuda' if self.device.type == 'cuda' else 'cpu', dtype=self.model.config.torch_dtype):
                            outputs = self.model(generated_ids)
                    else:
                        outputs = self.model(generated_ids)
                    next_token_logits = outputs['logits'][:, -1, :]

                    # Repetition penalty: damp the logit of every seen token
                    # (divide positives, multiply negatives).
                    for token_id in history_ids:
                        if token_id < next_token_logits.size(-1):
                            logit = next_token_logits[0, token_id].item()
                            if logit > 0:
                                next_token_logits[0, token_id] = logit / repetition_penalty
                            else:
                                next_token_logits[0, token_id] = logit * repetition_penalty

                    # Temperature
                    if temperature > 0:
                        next_token_logits = next_token_logits / temperature

                    # Nucleus (top-p) takes precedence over top-k.
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift right so the first token above the threshold
                        # is still kept; never remove the top token.
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = False
                        # NOTE(review): this flattening assumes batch size 1.
                        indices_to_remove = sorted_indices[sorted_indices_to_remove]
                        next_token_logits.scatter_(1, indices_to_remove.unsqueeze(0), float('-inf'))
                    elif top_k > 0:
                        top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
                        next_token_logits = torch.full_like(next_token_logits, float('-inf'))
                        next_token_logits.scatter_(1, top_k_indices, top_k_logits)

                    # Sample from the filtered distribution.
                    next_token_probs = F.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(next_token_probs, num_samples=1)
                    next_token_item = next_token.item()

                    if next_token_item == self.tokenizer.eos_token_id:
                        break

                    generated_ids = torch.cat([generated_ids, next_token], dim=-1)
                    history_ids.add(next_token_item)

                    new_token_text = self.tokenizer.decode(next_token.squeeze().tolist(), skip_special_tokens=True)
                    yield new_token_text

        if stream:
            return token_generator()
        else:
            return "".join(list(token_generator()))

    def generate_full(self,
                      prompt: str,
                      max_new_tokens: int = 100,
                      temperature: float = 0.8,
                      top_k: int = 0,
                      top_p: float = 0.9,
                      repetition_penalty: float = 1.0) -> str:
        """Convenience wrapper: non-streaming generation returning a string."""
        return self.generate(prompt, max_new_tokens, temperature, top_k, top_p, repetition_penalty, stream=False)
|
aetheris/model.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.utils.checkpoint
|
| 4 |
+
from typing import Dict, Any, List
|
| 5 |
+
from .config import AetherisConfig
|
| 6 |
+
from .modules import SSMBlock, SparseMoELayer
|
| 7 |
+
|
| 8 |
+
class HybridMambaMoE(nn.Module):
    """Hybrid language model interleaving SSM blocks (even layer indices)
    with sparse MoE layers (odd indices), with tied input/output embeddings.

    forward() returns a dict with "logits" and, when labels are given,
    "loss" (CE + scaled MoE auxiliary loss), "ce_loss" and "aux_loss".
    """
    def __init__(self, config: AetherisConfig):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)

        # Alternate SSM (even i) and MoE (odd i) layers.
        self.layers = nn.ModuleList()
        for i in range(config.n_layer):
            if i % 2 == 0:
                self.layers.append(SSMBlock(config))
            else:
                self.layers.append(SparseMoELayer(config))

        self.final_norm = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Weight tying

        # Use -100 as ignore_index (PyTorch standard for label masking)
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
        self.gradient_checkpointing = config.gradient_checkpointing

        # Initialize embeddings with smaller scale
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)

    def resize_token_embeddings(self, new_vocab_size: int):
        """Resize embedding and lm_head for new tokens. New embeddings initialized from mean of existing.

        NOTE(review): the replacement nn.Embedding/nn.Linear are created on
        the default device/dtype — callers must re-apply .to(device/dtype)
        after resizing if the model was already moved.
        """
        old_vocab_size = self.embedding.num_embeddings
        if new_vocab_size == old_vocab_size:
            return
        old_weight = self.embedding.weight.data
        mean_embed = old_weight.mean(dim=0)
        self.embedding = nn.Embedding(new_vocab_size, self.config.d_model)
        self.embedding.weight.data[:old_vocab_size] = old_weight
        # New rows start at the mean of the existing embedding rows.
        self.embedding.weight.data[old_vocab_size:] = mean_embed.unsqueeze(0).expand(
            new_vocab_size - old_vocab_size, -1
        )
        self.lm_head = nn.Linear(self.config.d_model, new_vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Re-tie weights
        self.config.vocab_size = new_vocab_size

    def _init_weights(self, module):
        """Apply proper weight initialization

        NOTE(review): never invoked — self.apply(self._init_weights) is not
        called anywhere in this class; submodules keep their own init.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=0.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None) -> Dict[str, Any]:
        # input_ids: (batch, seq) token ids; labels: same shape with -100 to
        # mask positions from the loss.
        x = self.embedding(input_ids)
        total_aux_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)

        for i, layer in enumerate(self.layers):
            if self.gradient_checkpointing and self.training:
                # Checkpoint ALL layers for maximum memory savings
                if isinstance(layer, SparseMoELayer):
                    # Wrapper lets checkpoint() return the (x, aux_loss) pair.
                    def moe_forward(module, inp):
                        return module(inp)
                    x, aux_loss = torch.utils.checkpoint.checkpoint(
                        moe_forward, layer, x, use_reentrant=False
                    )
                    total_aux_loss = total_aux_loss + aux_loss
                else:
                    x = torch.utils.checkpoint.checkpoint(
                        layer, x, use_reentrant=False
                    )
            else:
                if isinstance(layer, SparseMoELayer):
                    x, aux_loss = layer(x)
                    total_aux_loss = total_aux_loss + aux_loss
                else:
                    x = layer(x)

        x = self.final_norm(x)
        logits = self.lm_head(x)

        if labels is not None:
            # Standard causal shift: predict token t+1 from position t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            ce_loss = self.loss_fn(shift_logits.view(-1, self.config.vocab_size),
                                   shift_labels.view(-1))

            # Scale down aux loss to prevent it from dominating
            # NOTE(review): config also defines load_balancing_coef /
            # router_z_loss_coef — confirm the aux loss is not scaled twice
            # (once inside SparseMoELayer and again by this 0.01).
            total_loss = ce_loss + 0.01 * total_aux_loss

            return {
                "loss": total_loss,
                "ce_loss": ce_loss,
                "aux_loss": total_aux_loss,
                "logits": logits
            }

        return {"logits": logits}
|
aetheris/modules/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .expert import Expert
|
| 2 |
+
from .ssm import SSMBlock, selective_scan_native
|
| 3 |
+
from .moe import SparseMoELayer
|
aetheris/modules/expert.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class Expert(nn.Module):
    """Single feed-forward (GELU MLP) expert used by the sparse MoE layer.

    The matmuls are carried out in float32 regardless of the parameter or
    input dtype to avoid half-precision overflow; the result is clamped to
    the fp16 range (when needed) and cast back to the input dtype.
    """
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.act = nn.GELU()

        # Conservative Xavier init (gain 0.5) keeps early activations small
        # and avoids NaNs in reduced precision.
        nn.init.xavier_uniform_(self.w1.weight, gain=0.5)
        nn.init.xavier_uniform_(self.w2.weight, gain=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype

        # Up-cast activations and weights to float32 for the computation;
        # the stored weights themselves may be float16.
        hidden = F.linear(x.to(torch.float32), self.w1.weight.to(torch.float32))
        hidden = self.act(hidden)
        result = F.linear(hidden, self.w2.weight.to(torch.float32))

        # Keep values representable in fp16 before the down-cast.
        if input_dtype == torch.float16:
            result = torch.clamp(result, min=-65500.0, max=65500.0)

        return result.to(input_dtype)
|
aetheris/modules/moe.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from ..config import AetherisConfig
|
| 5 |
+
from .expert import Expert
|
| 6 |
+
|
| 7 |
+
class SparseMoELayer(nn.Module):
    """Top-k sparse mixture-of-experts layer with a residual connection.

    A linear gate over the LayerNorm-ed input routes each token to its top-k
    experts; the weighted expert outputs are accumulated in float32 and added
    back onto the residual stream. Returns the layer output together with the
    auxiliary loss (load balancing + router z-loss).
    """

    def __init__(self, config: AetherisConfig):
        super().__init__()
        self.d_model = config.d_model
        self.num_experts = config.num_experts
        self.top_k = config.top_k
        self.load_balancing_coef = config.load_balancing_coef
        self.z_loss_coef = config.router_z_loss_coef

        self.gate = nn.Linear(config.d_model, config.num_experts, bias=False)
        self.experts = nn.ModuleList(
            Expert(config.d_model, config.d_ff) for _ in range(config.num_experts)
        )
        self.norm = nn.LayerNorm(config.d_model)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        batch, seq_len, dim = x.shape
        tokens = self.norm(x).view(-1, dim)

        # Router logits, clamped so softmax/logsumexp cannot overflow.
        router_logits = torch.clamp(self.gate(tokens), min=-10.0, max=10.0)

        # Router z-loss discourages the gate from producing large logits.
        z_loss = torch.mean(torch.logsumexp(router_logits, dim=-1) ** 2) * self.z_loss_coef

        # Small exploration noise, applied during training only.
        if self.training:
            router_logits = router_logits + torch.randn_like(router_logits) * 1e-3

        router_probs = F.softmax(router_logits, dim=-1)
        topk_weights, topk_experts = torch.topk(router_probs, self.top_k, dim=-1)

        # Renormalize the selected weights so they sum to one per token.
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-8)

        # Load-balancing loss based on the top-1 assignment only — kept
        # simple and consistent regardless of top_k.
        top1_mask = F.one_hot(topk_experts[:, 0], num_classes=self.num_experts).float()
        fraction_routed = top1_mask.mean(dim=0)
        mean_prob = router_probs.mean(dim=0)

        balance_loss = (self.num_experts * torch.sum(fraction_routed * mean_prob)) * self.load_balancing_coef
        total_aux_loss = balance_loss + z_loss

        # Dispatch tokens to their experts; accumulate in float32 so the
        # combine step cannot overflow under half precision.
        combined = torch.zeros_like(tokens, dtype=torch.float32)

        for slot in range(self.top_k):
            for expert_id, expert in enumerate(self.experts):
                # Tokens whose slot-th choice is this expert.
                routed = (topk_experts[:, slot] == expert_id)
                if not routed.any():
                    continue

                expert_out = expert(tokens[routed])

                slot_weights = topk_weights[routed, slot].unsqueeze(1)

                expert_out = expert_out.to(torch.float32)
                slot_weights = slot_weights.to(torch.float32)

                # Add this expert's weighted contribution for the routed tokens.
                combined[routed] += expert_out * slot_weights

        combined = combined.to(tokens.dtype)

        return x + combined.view(batch, seq_len, dim), total_aux_loss
|
aetheris/modules/ssm.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from ..config import AetherisConfig
|
| 5 |
+
|
| 6 |
+
# Try to import CUDA selective scan kernel
|
| 7 |
+
try:
|
| 8 |
+
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
|
| 9 |
+
HAS_CUDA_SSM = True
|
| 10 |
+
except ImportError:
|
| 11 |
+
HAS_CUDA_SSM = False
|
| 12 |
+
|
| 13 |
+
@torch.jit.ignore
def selective_scan_native(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
                          B: torch.Tensor, C: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
    """Reference selective-scan implementation (sequential over time, O(L)).

    Used when the fused CUDA kernel from ``mamba_ssm`` is unavailable. The
    recurrence runs entirely in float32 for numerical stability and the result
    is cast back to the dtype of ``u``.

    Shapes (as used by SSMBlock):
        u:     (batch, L, d_inner)        input sequence
        delta: (batch, L, d_inner)        per-step discretization
        A:     (batch, d_inner, d_state)  state transition (negative)
        B, C:  (batch, L, d_state)        input / output projections
        D:     (d_inner,)                 skip connection
    """
    out_dtype = u.dtype
    batch, seq_len, d_inner = u.shape
    d_state = A.shape[-1]

    u, delta, A, B, C, D = (t.float() for t in (u, delta, A, B, C, D))

    # Hidden state carried across time steps.
    state = torch.zeros(batch, d_inner, d_state, device=u.device, dtype=torch.float32)
    outputs = []

    for t in range(seq_len):
        dt = delta[:, t, :].unsqueeze(-1)                       # (batch, d_inner, 1)
        # Discretized update: h <- exp(dt*A) * h + (dt*B_t) * u_t
        state = torch.exp(dt * A) * state + (dt * B[:, t, :].unsqueeze(1)) * u[:, t, :].unsqueeze(-1)
        # Readout: y_t = <h, C_t> summed over the state dimension.
        outputs.append(torch.sum(state * C[:, t, :].unsqueeze(1), dim=-1))

    y = torch.stack(outputs, dim=1) + u * D
    return y.to(dtype=out_dtype)
|
| 45 |
+
|
| 46 |
+
class SSMBlock(nn.Module):
    """State Space Model block with optional CUDA-accelerated selective scan.

    Pre-norm residual block: LayerNorm -> input projection -> causal depthwise
    conv -> selective scan (fused CUDA kernel when available, else the native
    Python fallback) -> SiLU gating -> output projection -> residual add.
    """
    def __init__(self, config: AetherisConfig):
        super().__init__()
        self.d_model = config.d_model
        self.d_state = config.ssm_d_state
        self.d_inner = config.d_inner

        # in_proj produces 2*d_inner; it is chunked into (x_in, z_gate) in
        # forward(). NOTE(review): z_gate is never consumed — gating uses
        # gate_proj(x_norm) instead, so half of in_proj's output is discarded.
        # Confirm whether this is intentional (checkpoint compatibility).
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=False)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=False)
        # Depthwise conv, kernel 3, padding 2; the extra right padding is
        # trimmed in forward() to keep the convolution causal.
        self.conv_d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=3,
                                padding=2, groups=self.d_inner, bias=False)
        self.gate_proj = nn.Linear(self.d_model, self.d_inner, bias=False)

        # Data-dependent SSM parameter projections (selective scan).
        self.B_proj = nn.Linear(self.d_inner, self.d_state, bias=False)
        self.C_proj = nn.Linear(self.d_inner, self.d_state, bias=False)
        self.delta_proj = nn.Linear(self.d_inner, self.d_inner, bias=False)

        # A is stored in log space (initialized near exp(-4) ≈ 0.018) and
        # negated in forward() so the state transition always decays.
        self.A_log = nn.Parameter(torch.randn(self.d_inner, self.d_state) * 0.1 - 4.0)
        self.D = nn.Parameter(torch.ones(self.d_inner) * 0.1)

        self.act = nn.SiLU()
        self.norm = nn.LayerNorm(config.d_model)

        # Conservative Xavier init (gain 0.5) across all projections to keep
        # early activations small under mixed precision.
        nn.init.xavier_uniform_(self.in_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(self.out_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(self.gate_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(self.B_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(self.C_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(self.delta_proj.weight, gain=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SSM block to x of shape (batch, seq_len, d_model)."""
        B, L, D = x.shape
        x_norm = self.norm(x)

        xz = self.in_proj(x_norm)
        # z_gate is unused (see NOTE in __init__); only x_in flows onward.
        x_in, z_gate = xz.chunk(2, dim=-1)
        # Causal depthwise conv: pad by 2 on both sides, then drop the last
        # two positions so no future information leaks into the present.
        x_conv = self.conv_d(x_in.transpose(1, 2))
        x_conv = x_conv[:, :, :-2].transpose(1, 2)
        x_conv = self.act(x_conv)

        B_ssm = self.B_proj(x_conv)
        C_ssm = self.C_proj(x_conv)

        # A is (D_inner, D_state) — clamped and negated
        A = -torch.exp(torch.clamp(self.A_log, min=-10.0, max=2.0))

        if HAS_CUDA_SSM and x.is_cuda:
            # CUDA kernel expects float32 — cast inputs and cast output back
            # NOTE(review): this path applies softplus to delta inside the
            # kernel (delta_softplus=True) but NOT the clamp(max=5.0)+1e-4 of
            # the fallback path below — the two branches are not numerically
            # identical; confirm this divergence is acceptable.
            original_dtype = x_conv.dtype
            delta_raw = self.delta_proj(x_conv)
            y_ssm = selective_scan_fn(
                x_conv.transpose(1, 2).contiguous().float(),  # (B, D_inner, L)
                delta_raw.transpose(1, 2).contiguous().float(),  # (B, D_inner, L)
                A.contiguous().float(),  # (D_inner, D_state)
                B_ssm.transpose(1, 2).contiguous().float(),  # (B, D_state, L)
                C_ssm.transpose(1, 2).contiguous().float(),  # (B, D_state, L)
                self.D.float(),  # (D_inner,)
                z=None,
                delta_bias=None,
                delta_softplus=True,
                return_last_state=False,
            )
            y_ssm = y_ssm.to(dtype=original_dtype).transpose(1, 2)  # Back to (B, L, D_inner)
        else:
            # Fallback: pure Python sequential scan
            delta = torch.clamp(F.softplus(self.delta_proj(x_conv)), max=5.0) + 1e-4
            A_batched = A.unsqueeze(0).expand(B, -1, -1)
            y_ssm = selective_scan_native(x_conv, delta, A_batched, B_ssm, C_ssm, self.D)

        # SiLU gate computed from the normalized input (not from z_gate).
        y_gate = F.silu(self.gate_proj(x_norm)) * y_ssm
        output = self.out_proj(y_gate)

        return x + output
|
aetheris/trainer/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .trainer import Trainer
|
aetheris/trainer/trainer.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import time
|
| 3 |
+
import os
|
| 4 |
+
from aetheris.utils import save_checkpoint, load_latest_checkpoint, calculate_model_stats
|
| 5 |
+
from aetheris.data import get_tokenizer
|
| 6 |
+
|
| 7 |
+
class Trainer:
    """Training-loop driver with gradient accumulation, NaN guards,
    periodic checkpointing, and in-loop validation.

    The CLI is expected to construct this object, optionally assign
    ``self.scheduler``, and then call :meth:`train_epoch`.
    """

    def __init__(self, model, optimizer, scaler, config, device, checkpoint_dir, logger=None, grad_accum_steps=1):
        # model/optimizer/scaler are stored as-is; scaler is only persisted in
        # checkpoints here (backward() is unscaled in train_epoch).
        self.model = model
        self.optimizer = optimizer
        self.scaler = scaler
        self.config = config
        self.device = device
        self.checkpoint_dir = checkpoint_dir
        self.logger = logger
        # Number of micro-batches accumulated before each optimizer step.
        self.grad_accum_steps = grad_accum_steps
        self.scheduler = None  # Set by CLI before train_epoch()

        self.model.to(self.device)

    def validate(self, val_loader, global_step):
        """Run a capped validation pass (first 100 batches) and return the
        average loss; also prints loss and perplexity."""
        self.model.eval()
        total_loss = 0
        total_items = 0
        num_batches = 100  # Validate on 100 batches to save time

        print(f"\n[Validation] Starting validation at step {global_step}...")

        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                if i >= num_batches:
                    break

                input_ids, labels = batch
                input_ids = input_ids.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)

                # Auto-cast context — bf16 on Ampere+, fp16 fallback
                autocast_dtype = torch.bfloat16

                # Autocast is skipped when the model runs in full float32.
                use_autocast = True if self.config.torch_dtype != torch.float32 else False

                if use_autocast:
                    # NOTE(review): torch.cuda.amp.autocast is deprecated in
                    # recent torch in favor of torch.amp.autocast("cuda") —
                    # confirm the pinned torch version before changing.
                    with torch.cuda.amp.autocast(dtype=autocast_dtype):
                        output = self.model(input_ids, labels)
                else:
                    output = self.model(input_ids, labels)

                total_loss += output["loss"].item()
                total_items += 1

        avg_loss = total_loss / total_items if total_items > 0 else 0
        perplexity = torch.exp(torch.tensor(avg_loss)).item()

        print(f"[Validation] Step {global_step} | Loss: {avg_loss:.4f} | PPL: {perplexity:.4f}")
        # Restore training mode before returning to the main loop.
        self.model.train()
        return avg_loss

    def train_epoch(self, train_loader, total_steps, start_step=0, stage_name="Training", val_loader=None, eval_every=500):
        """Run the main training loop from start_step until total_steps.

        Cycles the data loader indefinitely (re-creating the iterator on
        StopIteration), accumulates gradients over ``grad_accum_steps``
        micro-batches, clips gradients, skips NaN/Inf batches, writes progress
        to a log file, checkpoints every 500 steps, and validates every
        ``eval_every`` steps when ``val_loader`` is given. Returns the final
        global step.
        """
        print(f"\n{'='*70}\nStarting {stage_name}: Target Steps={total_steps} (Accum={self.grad_accum_steps})\n{'='*70}", flush=True)
        self.model.train()
        global_step = start_step
        running_loss = 0

        print("Initializing data iterator...")
        train_iter = iter(train_loader)

        print("Fetching first batch...")

        # Zero gradients initially
        self.optimizer.zero_grad(set_to_none=True)

        while global_step < total_steps:
            step_start = time.time()

            # Removed periodic cache clearing for performance

            try:
                batch = next(train_iter)
                if global_step == start_step:
                    print(f"✓ First batch loaded! Starting training loop...", flush=True)
            except StopIteration:
                # Data exhausted — restart the loader to keep training.
                train_iter = iter(train_loader)
                batch = next(train_iter)

            data_time = time.time() - step_start
            input_ids, labels = batch
            input_ids = input_ids.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)

            gpu_start = time.time()
            # Determine autocast dtype — bf16 on Ampere+ (no NaN from range overflow)
            autocast_dtype = torch.bfloat16

            # Check if we should use autocast (skip if model uses float32)
            use_autocast = True
            if self.config.torch_dtype == torch.float32:
                use_autocast = False

            if use_autocast:
                with torch.cuda.amp.autocast(dtype=autocast_dtype):
                    output = self.model(input_ids, labels)
                    # Scale loss for accumulation
                    loss = output["loss"] / self.grad_accum_steps
            else:
                output = self.model(input_ids, labels)
                loss = output["loss"] / self.grad_accum_steps

            # NaN loss detection — skip batch entirely to prevent corruption
            if torch.isnan(loss) or torch.isinf(loss):
                # NOTE(review): skipped batches still increment global_step, so
                # NaN-heavy runs consume the step budget without training.
                nan_count = getattr(self, '_nan_count', 0) + 1
                self._nan_count = nan_count
                print(f"WARNING: NaN/Inf loss at step {global_step} (count={nan_count}), skipping batch", flush=True)
                self.optimizer.zero_grad(set_to_none=True)
                global_step += 1
                continue

            loss.backward()
            # Synchronize so gpu_time below measures real GPU work, not just
            # the async kernel launch.
            if self.device.type == 'cuda':
                torch.cuda.synchronize()
            gpu_time = time.time() - gpu_start

            # Gradient Accumulation Step
            if (global_step + 1) % self.grad_accum_steps == 0:
                # Gradient clipping
                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)

                if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                    print(f"WARNING: NaN/Inf gradient at step {global_step}, skipping update", flush=True)
                    self.optimizer.zero_grad(set_to_none=True)
                else:
                    self.optimizer.step()
                    self.optimizer.zero_grad(set_to_none=True)

                # Step LR scheduler
                if self.scheduler is not None:
                    self.scheduler.step()

            global_step += 1
            running_loss += (loss.item() * self.grad_accum_steps)  # Un-scale for reporting

            # Per-step progress file for monitoring (cheap I/O)
            # NOTE(review): "/workspace/training_progress.log" is hard-coded
            # here and below — fails outside the training container.
            if global_step <= 20 or global_step % 100 == 0:
                total_elapsed = time.time() - step_start
                with open("/workspace/training_progress.log", "a") as pf:
                    pf.write(f"step={global_step} loss={loss.item() * self.grad_accum_steps:.4f} total={total_elapsed:.1f}s data={data_time:.1f}s gpu={gpu_time:.1f}s\n")

            if global_step % 10 == 0:
                # NOTE(review): divides by a fixed 10, which assumes this
                # branch runs exactly every 10 accumulated losses.
                avg_loss = running_loss / 10
                t_diff = time.time() - step_start
                if self.device.type == 'cuda':
                    mem = torch.cuda.memory_allocated() / 1e9
                    max_mem = torch.cuda.max_memory_allocated() / 1e9
                    mem_str = f"VRAM: {mem:.1f}GB (peak: {max_mem:.1f}GB)"
                else:
                    mem_str = "CPU Mode"

                # Throughput estimate from the configured sequence length —
                # assumes batches are always packed to max_seq_len.
                tokens_per_sec = (self.config.max_seq_len * input_ids.size(0)) / t_diff
                current_lr = self.optimizer.param_groups[0]['lr']
                msg = (f"  Step {global_step}/{total_steps} | Loss: {avg_loss:.4f} | "
                       f"LR: {current_lr:.2e} | {mem_str} | {tokens_per_sec:.0f} tok/s")
                print(msg, flush=True)
                # Write progress to file (bypasses stdout buffering)
                with open("/workspace/training_progress.log", "a") as pf:
                    pf.write(msg + "\n")
                running_loss = 0

            if global_step % 500 == 0:
                save_checkpoint(self.model, self.optimizer, self.scaler, global_step, stage_name, self.checkpoint_dir)
                with open("/workspace/training_progress.log", "a") as pf:
                    pf.write(f"  [Checkpoint saved at step {global_step}]\n")

            if val_loader is not None and global_step % eval_every == 0 and global_step > start_step:
                self.validate(val_loader, global_step)

        return global_step
|
aetheris/utils.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
|
| 5 |
+
def save_checkpoint(model, optimizer, scaler, step, stage, checkpoint_dir, checkpoint_name="checkpoint_current.pth"):
    """Persist the full training state to ``checkpoint_dir/checkpoint_name``.

    Saves model/optimizer/scaler state dicts plus the current step and stage
    label, creating the directory if needed. Each call overwrites the same
    file, so only the latest checkpoint is kept under this name.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    payload = {
        'step': step,
        'stage': stage,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
    }
    torch.save(payload, os.path.join(checkpoint_dir, checkpoint_name))
    print(f" [Checkpoint] Saved at step {step}")
|
| 16 |
+
|
| 17 |
+
def load_latest_checkpoint(model, optimizer, scaler, device, checkpoint_dir, checkpoint_name="checkpoint_current.pth") -> Tuple[int, str]:
    """Restore training state from the checkpoint file, if present.

    Returns ``(step, stage)`` recorded in the checkpoint, or
    ``(0, "Pre-Training")`` when no checkpoint file exists. If the model's
    vocabulary is larger than the checkpoint's, the embedding and LM-head
    weights are padded with the mean embedding vector before loading.

    NOTE(review): torch.load without weights_only=True unpickles arbitrary
    code — only load checkpoints from trusted sources. Also, strict=False
    silently ignores any missing/unexpected keys.
    """
    path = os.path.join(checkpoint_dir, checkpoint_name)
    if not os.path.exists(path):
        return 0, "Pre-Training"

    print(f" [Checkpoint] Loading from {path}...")
    ckpt = torch.load(path, map_location=device)
    state = ckpt['model_state_dict']

    # Handle vocab size mismatch (base checkpoint may have fewer tokens than model)
    model_vocab = model.config.vocab_size
    for key in ("embedding.weight", "lm_head.weight"):
        if key in state and state[key].shape[0] < model_vocab:
            old = state[key]
            pad_size = model_vocab - old.shape[0]
            # New rows are filled with the mean of existing embeddings — a
            # common init that keeps new-token logits near the average.
            mean_vec = old.mean(dim=0)
            state[key] = torch.cat([old, mean_vec.unsqueeze(0).expand(pad_size, -1)])
            print(f" [Checkpoint] Padded {key}: {old.shape[0]} → {model_vocab}")

    model.load_state_dict(state, strict=False)

    if optimizer and 'optimizer_state_dict' in ckpt:
        try:
            optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        except (ValueError, KeyError):
            # Padding above changes parameter shapes, which can make the saved
            # optimizer state incompatible — start that part fresh.
            print(" [Checkpoint] Optimizer state incompatible (vocab resize), using fresh optimizer")
    if scaler and 'scaler_state_dict' in ckpt:
        scaler.load_state_dict(ckpt['scaler_state_dict'])
    return ckpt['step'], ckpt['stage']
|
| 46 |
+
|
| 47 |
+
def calculate_model_stats(model):
    """Return a dict of parameter-count statistics for *model*.

    ``active_params`` and ``sparsity_ratio`` are rough MoE approximations
    (60% of parameters assumed active per token), not measured values.
    """
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {
        'total_params': total,
        'trainable_params': trainable,
        'active_params': int(total * 0.6),  # Approximation
        'sparsity_ratio': 0.6,  # Approximation
    }
|
config.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
checkpoint_ssm_layers: true
|
| 2 |
+
d_ff: 3072
|
| 3 |
+
d_inner: 2048
|
| 4 |
+
d_model: 1024
|
| 5 |
+
dtype: float16
|
| 6 |
+
gradient_checkpointing: true
|
| 7 |
+
load_balancing_coef: 0.01
|
| 8 |
+
max_seq_len: 2048
|
| 9 |
+
n_layer: 24
|
| 10 |
+
num_experts: 4
|
| 11 |
+
router_z_loss_coef: 0.001
|
| 12 |
+
ssm_d_state: 16
|
| 13 |
+
ssm_expand: 2
|
| 14 |
+
top_k: 1
|
| 15 |
+
use_cpu_offload: false
|
| 16 |
+
use_flash_attention: false
|
| 17 |
+
vocab_size: 261019
|
pytorch_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9133520b370ce0ebab902748e74c0e60898f0ffe2c2f0d54f66f9412f40e9921
|
| 3 |
+
size 2886684406
|