rcgalbo committed on
Commit 8b21693 · verified · 1 Parent(s): 022db6b

Upload Aetheris model (Stage 2 best, 722M params, loss=2.73)
README.md CHANGED
@@ -1,80 +1,65 @@
 ---
-language:
-- multilingual
-- en
-- es
-- hi
-- zh
-- ar
-- sw
-- tr
-- ja
-- id
-- te
 license: apache-2.0
+language:
+- en
+- es
+- fr
+- de
+- zh
+- ja
+- ko
+- ar
+- hi
+- tr
+- sw
+- id
+- pt
+- ru
 tags:
-- mamba
-- moe
-- ssm
-- multilingual
-- distillation
-- aya
-library_name: aetheris
+- multilingual
+- mamba
+- moe
+- distillation
+- aya
 pipeline_tag: text-generation
 ---
 
 # Aetheris — Hybrid Mamba-MoE Multilingual Model
 
-**Aetheris** is a ~800M parameter hybrid SSM/MoE language model distilled from
+**Aetheris** is a ~720M parameter hybrid SSM-MoE language model distilled from
 [CohereLabs/tiny-aya-global](https://huggingface.co/CohereLabs/tiny-aya-global) (3.35B).
-
-Built by [Wayy Research](https://github.com/Wayy-Research).
+It supports **67 languages** with 4.6x compression.
 
 ## Architecture
-
-- **Type**: Hybrid Mamba (SSM) + Mixture of Experts (MoE)
-- **Layers**: 24 (interleaved: even=SSM, odd=MoE)
+- **Type**: Hybrid Mamba-MoE (interleaved SSM + Sparse MoE layers)
+- **Layers**: 24 (12 SSM + 12 MoE)
 - **Hidden dim**: 1024
-- **Experts**: 4 per MoE layer, top-1 routing
-- **SSM state dim**: 16
-- **Vocab size**: 256,000 (shared with tiny-aya-global)
-- **Parameters**: ~800M
+- **Experts**: 4 (top-1 routing)
+- **Vocab**: 261,019 tokens (Aya tokenizer)
+- **Parameters**: 722M
 
 ## Training
-
-3-stage MambaInLlama distillation pipeline:
-
-| Stage | Method | Data | Steps |
-|-------|--------|------|-------|
-| 1 | CKA-guided Layer Alignment | ClimbMix | 10,000 |
-| 2 | KL Distillation (T=2.0, alpha=0.7) | ClimbMix | 20,000 |
-| 3 | Supervised Fine-Tuning | aya_collection | 5,000 |
-
-Key research findings applied:
-- SSM 10x LR boost (compensates 27x gradient imbalance)
-- SVD split for MoE expert initialization (CKA=0.097 diversity)
-- Per-language KL tracking for multilingual equity
-
-## Current Checkpoint
-
-- **Stage**: 2 (kl-distillation)
-- **Step**: 18000
-- **Loss**: 3.4199
-- **Updated**: 2026-03-13T01:45:14.154527+00:00
-
-## Languages
-
-Supports 70+ languages inherited from tiny-aya-global. Core evaluation
-languages: English, Spanish, Hindi, Chinese, Arabic, Swahili, Turkish,
-Japanese, Indonesian, Telugu.
-
-## Citation
-
-```bibtex
-@misc{aetheris2026,
-  title={Aetheris: Hybrid Mamba-MoE Multilingual Model via Knowledge Distillation},
-  author={Wayy Research},
-  year={2026},
-  url={https://huggingface.co/wayyresearch/aetheris}
-}
+- **Stage 1**: CKA-guided layer alignment (10K steps)
+- **Stage 2**: KL divergence distillation, T=2.0, alpha=0.7 (20K steps, best loss=2.73)
+- **Stage 3**: SFT fine-tuning (pending)
+- **Teacher**: CohereLabs/tiny-aya-global (3.35B)
+- **Data**: ClimbMix (NVIDIA)
+
+## Usage
+
+```python
+import torch, yaml, sys
+sys.path.insert(0, ".")
+from aetheris.config import AetherisConfig
+from aetheris.model import HybridMambaMoE
+
+config = AetherisConfig.from_yaml("config.yaml")
+model = HybridMambaMoE(config)
+sd = torch.load("pytorch_model.pt", map_location="cpu")
+model.load_state_dict(sd)
+model.eval()
 ```
+
+## Wayy Research
+*People for research, research for people.*
+Buffalo, NY — Est. 2024
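The Stage 2 objective listed above blends a temperature-scaled KL term (T=2.0) with a hard-label term (alpha=0.7). The commit does not include the actual loss code, so the following is a minimal pure-Python sketch of that standard formulation, assuming the common `alpha * T² * KL + (1 - alpha) * CE` blend; `distill_loss` and its argument names are illustrative, not the project's API.

```python
import math

def softmax(logits, T=1.0):
    # Temperature-scaled softmax over a plain list of logits.
    exps = [math.exp(l / T) for l in logits]
    z = sum(exps)
    return [e / z for e in exps]

def distill_loss(student_logits, teacher_logits, target_idx, T=2.0, alpha=0.7):
    """Soft-target KL (scaled by T^2, as in standard distillation)
    blended with hard-label cross-entropy on the true token."""
    p_t = softmax(teacher_logits, T)
    p_s = softmax(student_logits, T)
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_t, p_s))
    ce = -math.log(softmax(student_logits)[target_idx])
    return alpha * (T * T) * kl + (1 - alpha) * ce

loss = distill_loss([2.0, 0.5, -1.0], [1.8, 0.7, -0.9], target_idx=0)
```

When student and teacher logits agree, the KL term vanishes and only the cross-entropy contribution remains.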
aetheris/__init__.py ADDED
@@ -0,0 +1,2 @@
+from .model import HybridMambaMoE
+from .config import AetherisConfig
aetheris/api/schemas.py ADDED
@@ -0,0 +1,92 @@
+from typing import List, Optional, Union, Dict, Any
+from pydantic import BaseModel, Field
+import time
+
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: List[ChatMessage]
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    n: Optional[int] = 1
+    stream: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = None
+    max_tokens: Optional[int] = None
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    user: Optional[str] = None
+
+class ChatCompletionChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: Optional[str] = None
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionChoice]
+    usage: Optional[Dict[str, int]] = None
+
+class ChatCompletionChunkDelta(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+class ChatCompletionChunkChoice(BaseModel):
+    index: int
+    delta: ChatCompletionChunkDelta
+    finish_reason: Optional[str] = None
+
+class ChatCompletionChunk(BaseModel):
+    id: str
+    object: str = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionChunkChoice]
+
+class CompletionRequest(BaseModel):
+    model: str
+    prompt: Union[str, List[str]]
+    suffix: Optional[str] = None
+    max_tokens: Optional[int] = 16
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    n: Optional[int] = 1
+    stream: Optional[bool] = False
+    logprobs: Optional[int] = None
+    echo: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = None
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0
+    best_of: Optional[int] = 1
+    logit_bias: Optional[Dict[str, float]] = None
+    user: Optional[str] = None
+
+class CompletionChoice(BaseModel):
+    text: str
+    index: int
+    logprobs: Optional[Any] = None
+    finish_reason: Optional[str] = None
+
+class CompletionResponse(BaseModel):
+    id: str
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionChoice]
+    usage: Optional[Dict[str, int]] = None
+
+class ModelCard(BaseModel):
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: str = "aetheris"
+
+class ModelList(BaseModel):
+    object: str = "list"
+    data: List[ModelCard]
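These schemas mirror the OpenAI chat-completions wire format, so a client request is plain JSON. A minimal sketch of a payload that would validate against `ChatCompletionRequest` above; the message contents are made up, and the model id is the one the server registers under `/v1/models`:

```python
import json

# Hypothetical client payload; field names follow the ChatCompletionRequest
# schema, and unspecified fields fall back to the schema defaults.
payload = {
    "model": "aetheris-hybrid-mamba-moe",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Name one official language of Kenya."},
    ],
    "temperature": 0.7,
    "max_tokens": 64,
    "stream": False,
}
body = json.dumps(payload)
```

Setting `"stream": True` instead would switch the server to the SSE chunk path and yield `ChatCompletionChunk` objects.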
aetheris/api/server.py ADDED
@@ -0,0 +1,196 @@
+import time
+import uuid
+import json
+import asyncio
+from typing import AsyncGenerator
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.middleware.cors import CORSMiddleware
+from sse_starlette.sse import EventSourceResponse
+from aetheris.api.schemas import (
+    ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk,
+    ChatCompletionChoice, ChatMessage, ChatCompletionChunkChoice, ChatCompletionChunkDelta,
+    CompletionRequest, CompletionResponse, CompletionChoice,
+    ModelList, ModelCard
+)
+from aetheris.inference import InferenceEngine
+
+app = FastAPI(title="Aetheris API", version="0.1.0")
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Global engine instance
+engine: InferenceEngine = None
+
+def get_engine():
+    global engine
+    if engine is None:
+        # Defaults, ideally loaded from config/env
+        engine = InferenceEngine()
+    return engine
+
+@app.on_event("startup")
+async def startup_event():
+    get_engine()
+
+@app.get("/")
+async def root():
+    return {"status": "running", "message": "Aetheris API is active. Use /v1/chat/completions for inference."}
+
+@app.get("/v1/models", response_model=ModelList)
+async def list_models():
+    return ModelList(data=[ModelCard(id="aetheris-hybrid-mamba-moe")])
+
+@app.post("/v1/chat/completions")
+async def chat_completions(request: ChatCompletionRequest):
+    engine = get_engine()
+
+    # Simple prompt construction from messages
+    prompt = ""
+    for msg in request.messages:
+        prompt += f"{msg.role}: {msg.content}\n"
+    prompt += "assistant: "
+
+    request_id = f"chatcmpl-{uuid.uuid4()}"
+    created_time = int(time.time())
+
+    if request.stream:
+        async def event_generator():
+            yield json.dumps(ChatCompletionChunk(
+                id=request_id,
+                created=created_time,
+                model=request.model,
+                choices=[ChatCompletionChunkChoice(
+                    index=0,
+                    delta=ChatCompletionChunkDelta(role="assistant"),
+                    finish_reason=None
+                )]
+            ).model_dump())
+
+            # Offload synchronous generation to a thread to avoid blocking the event loop
+            queue = asyncio.Queue()
+            loop = asyncio.get_running_loop()
+            import threading
+            stop_event = threading.Event()
+
+            def producer():
+                try:
+                    # Run the synchronous generator
+                    for token in engine.generate(
+                        prompt=prompt,
+                        max_new_tokens=request.max_tokens or 100,
+                        temperature=request.temperature,
+                        top_p=request.top_p,
+                        repetition_penalty=1.0 + request.frequency_penalty,
+                        stream=True
+                    ):
+                        if stop_event.is_set():
+                            break
+                        # Schedule the put() coroutine on the main loop
+                        asyncio.run_coroutine_threadsafe(queue.put(token), loop)
+                except Exception as e:
+                    print(f"Generation error: {e}")
+                finally:
+                    # Signal done
+                    asyncio.run_coroutine_threadsafe(queue.put(None), loop)
+
+            thread = threading.Thread(target=producer, daemon=True)
+            thread.start()
+
+            try:
+                while True:
+                    token = await queue.get()
+                    if token is None:
+                        break
+
+                    yield json.dumps(ChatCompletionChunk(
+                        id=request_id,
+                        created=created_time,
+                        model=request.model,
+                        choices=[ChatCompletionChunkChoice(
+                            index=0,
+                            delta=ChatCompletionChunkDelta(content=token),
+                            finish_reason=None
+                        )]
+                    ).model_dump())
+
+                yield json.dumps(ChatCompletionChunk(
+                    id=request_id,
+                    created=created_time,
+                    model=request.model,
+                    choices=[ChatCompletionChunkChoice(
+                        index=0,
+                        delta=ChatCompletionChunkDelta(),
+                        finish_reason="stop"
+                    )]
+                ).model_dump())
+
+                yield "[DONE]"
+            finally:
+                stop_event.set()
+
+        return EventSourceResponse(event_generator())
+
+    else:
+        generated_text = engine.generate_full(
+            prompt=prompt,
+            max_new_tokens=request.max_tokens or 100,
+            temperature=request.temperature,
+            top_p=request.top_p,
+            repetition_penalty=1.0 + request.frequency_penalty
+        )
+
+        return ChatCompletionResponse(
+            id=request_id,
+            created=created_time,
+            model=request.model,
+            choices=[ChatCompletionChoice(
+                index=0,
+                message=ChatMessage(role="assistant", content=generated_text),
+                finish_reason="stop"
+            )],
+            usage={"prompt_tokens": len(prompt), "completion_tokens": len(generated_text), "total_tokens": len(prompt) + len(generated_text)}  # Approximated (character counts, not tokens)
+        )
+
+@app.post("/v1/completions")
+async def completions(request: CompletionRequest):
+    engine = get_engine()
+
+    prompt = request.prompt
+    if isinstance(prompt, list):
+        prompt = prompt[0]  # Handle single prompt for now
+
+    request_id = f"cmpl-{uuid.uuid4()}"
+    created_time = int(time.time())
+
+    if request.stream:
+        # Streaming for completions is not implemented yet; fall through to a
+        # non-streaming response. The logic would mirror the chat endpoint.
+        pass  # TODO: Implement streaming for completions
+
+    generated_text = engine.generate_full(
+        prompt=prompt,
+        max_new_tokens=request.max_tokens or 16,
+        temperature=request.temperature,
+        top_p=request.top_p,
+        repetition_penalty=1.0 + request.frequency_penalty
+    )
+
+    return CompletionResponse(
+        id=request_id,
+        created=created_time,
+        model=request.model,
+        choices=[CompletionChoice(
+            text=generated_text,
+            index=0,
+            logprobs=None,
+            finish_reason="length"  # or "stop"
+        )],
+        usage={"prompt_tokens": len(prompt), "completion_tokens": len(generated_text), "total_tokens": len(prompt) + len(generated_text)}
+    )
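The chat endpoint above flattens messages into a plain `role: content` transcript rather than applying a learned chat template, and its `usage` numbers are character lengths, not token counts. A stand-alone sketch of that flattening (hypothetical helper name, same logic as the handler):

```python
def build_prompt(messages):
    # Mirrors the server's simple template: one "role: content" line per
    # message, then an open "assistant: " turn for the model to complete.
    prompt = ""
    for msg in messages:
        prompt += f"{msg['role']}: {msg['content']}\n"
    return prompt + "assistant: "

p = build_prompt([{"role": "user", "content": "hi"}])
# p == "user: hi\nassistant: "
```

Because the template is purely positional, any role string round-trips unchanged; swapping in the tokenizer's real chat template would be the natural next step once Stage 3 SFT lands.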
aetheris/cli/__init__.py ADDED
@@ -0,0 +1 @@
+
aetheris/cli/main.py ADDED
@@ -0,0 +1,362 @@
+import argparse
+import sys
+import torch
+import os
+import math
+import torch.nn.functional as F
+from aetheris.config import AetherisConfig
+from aetheris.model import HybridMambaMoE
+from aetheris.data import create_streaming_loader, get_tokenizer
+from aetheris.utils import load_latest_checkpoint, calculate_model_stats
+from aetheris.trainer import Trainer
+
+def train_command(args):
+    print(f"\n{'='*70}")
+    print("Aetheris Training")
+    print(f"Config: {args.config}")
+
+    if args.hf_token:
+        print(f"Using Hugging Face token: {args.hf_token[:10]}...")
+        from huggingface_hub import login
+        login(token=args.hf_token)
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    if device.type == 'cuda':
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+        torch.backends.cudnn.benchmark = True
+        torch.cuda.empty_cache()
+
+    config = AetherisConfig.from_yaml(args.config)
+
+    # Add special tokens if using VoxLex config (vocab_size > 50257)
+    add_special = config.vocab_size > 50257
+    tokenizer = get_tokenizer(add_special_tokens=add_special)
+
+    print(f"Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
+    print(f"Model Size: d_model={config.d_model}, layers={config.n_layer}")
+    print(f"Vocab Size: {config.vocab_size} | Max Seq Len: {config.max_seq_len}")
+    print(f"{'='*70}\n")
+
+    model = HybridMambaMoE(config).to(device)
+
+    # Apply weight initialization BEFORE resize (resize copies old weights)
+    print("Applying proper weight initialization...")
+    model.apply(model._init_weights)
+
+    # Resize embeddings if tokenizer has special tokens (AFTER init)
+    if len(tokenizer) > model.config.vocab_size:
+        print(f"Resizing embeddings: {model.config.vocab_size} → {len(tokenizer)}")
+        model.resize_token_embeddings(len(tokenizer))
+    elif len(tokenizer) < model.config.vocab_size:
+        print(f"Keeping embeddings at {model.config.vocab_size} (config) with {len(tokenizer)} tokenizer tokens")
+        model.resize_token_embeddings(config.vocab_size)
+
+    # Calculate model stats
+    stats = calculate_model_stats(model)
+    print(f"Total Parameters: {stats['total_params']:,}")
+    print(f"Trainable Parameters: {stats['trainable_params']:,}")
+
+    # Use lower learning rate for stability
+    lr = args.lr if args.lr else 1e-4
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01,
+                                  betas=(0.9, 0.95), eps=1e-8)
+    # PyTorch 2.1 uses torch.cuda.amp.GradScaler; 2.3+ uses torch.amp.GradScaler
+    try:
+        scaler = torch.amp.GradScaler('cuda' if device.type == 'cuda' else 'cpu', init_scale=2**10)
+    except (TypeError, AttributeError):
+        scaler = torch.cuda.amp.GradScaler(init_scale=2**10)
+
+    if args.resume:
+        # Resume: load model + optimizer + scaler state
+        start_step, current_stage = load_latest_checkpoint(model, optimizer, scaler, device, args.checkpoint_dir, args.checkpoint_name)
+    else:
+        # Fine-tune: load model weights only, fresh optimizer
+        start_step, current_stage = load_latest_checkpoint(model, None, None, device, args.checkpoint_dir, args.checkpoint_name)
+        if start_step > 0:
+            print(f"  Loaded base weights (was at step {start_step}), resetting to step 0 for fine-tuning")
+            start_step = 0
+            current_stage = "Pre-Training"
+
+    if args.compile:
+        print("Compiling model with torch.compile()...")
+        model = torch.compile(model)
+
+    trainer = Trainer(model, optimizer, scaler, config, device, args.checkpoint_dir, grad_accum_steps=args.accumulate_grad_batches)
+
+    # Resolve dataset names
+    pretrain_dataset = args.pretrain_dataset or "cerebras/SlimPajama-627B"
+    sft_dataset = args.sft_dataset or "OpenAssistant/oasst1"
+
+    # --- STAGE 1: PRE-TRAINING ---
+    if current_stage == "Pre-Training" or start_step == 0:
+        print(f"\n=== STAGE 1: Pre-Training on {pretrain_dataset} ===")
+
+        # Build LR scheduler for pretraining (adjust for gradient accumulation)
+        warmup_steps = args.warmup_steps if args.warmup_steps else 1000
+        effective_steps = max(1, args.pretrain_steps // args.accumulate_grad_batches)
+        effective_warmup = max(1, warmup_steps // args.accumulate_grad_batches)
+        scheduler = _build_scheduler(optimizer, effective_steps, effective_warmup)
+        trainer.scheduler = scheduler
+
+        pt_loader = create_streaming_loader(pretrain_dataset, "train",
+                                            tokenizer, config, args.batch_size, mode="pretrain",
+                                            hf_token=args.hf_token, start_step=start_step)
+
+        pt_val_loader = create_streaming_loader(pretrain_dataset, "validation",
+                                                tokenizer, config, args.batch_size, mode="pretrain",
+                                                hf_token=args.hf_token)
+
+        start_step = trainer.train_epoch(pt_loader, total_steps=args.pretrain_steps,
+                                         start_step=start_step, stage_name="Pre-Training",
+                                         val_loader=pt_val_loader)
+        current_stage = "SFT"
+        start_step = 0
+
+    # --- STAGE 2: SFT ---
+    print(f"\n=== STAGE 2: SFT on {sft_dataset} ===")
+    sft_lr = args.sft_lr if args.sft_lr else 5e-5
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = sft_lr
+
+    # Build LR scheduler for SFT (adjust for gradient accumulation)
+    sft_warmup = args.sft_warmup_steps if args.sft_warmup_steps else 200
+    effective_sft_steps = max(1, args.sft_steps // args.accumulate_grad_batches)
+    effective_sft_warmup = max(1, sft_warmup // args.accumulate_grad_batches)
+    scheduler = _build_scheduler(optimizer, effective_sft_steps, effective_sft_warmup)
+    trainer.scheduler = scheduler
+
+    sft_loader = create_streaming_loader(sft_dataset, "train",
+                                         tokenizer, config, args.batch_size, mode="sft",
+                                         hf_token=args.hf_token, start_step=start_step)
+
+    sft_val_loader = create_streaming_loader(sft_dataset, "validation",
+                                             tokenizer, config, args.batch_size, mode="sft",
+                                             hf_token=args.hf_token)
+
+    trainer.train_epoch(sft_loader, total_steps=args.sft_steps,
+                        start_step=start_step, stage_name="SFT",
+                        val_loader=sft_val_loader)
+
+    print("\nTraining Complete!")
+
+
+def _build_scheduler(optimizer, total_steps, warmup_steps):
+    """Cosine annealing with linear warmup. LR multiplier: 0→1 (warmup) → 0.1 (cosine)."""
+    def lr_lambda(current_step):
+        if current_step < warmup_steps:
+            return float(current_step) / float(max(1, warmup_steps))
+        progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
+        return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
+    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+
+
+@torch.no_grad()
+def generate_command(args):
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    config = AetherisConfig.from_yaml(args.config)
+
+    add_special = config.vocab_size > 50257
+    tokenizer = get_tokenizer(add_special_tokens=add_special)
+
+    model = HybridMambaMoE(config).to(device).to(config.torch_dtype)
+
+    # Resize if needed
+    if len(tokenizer) != config.vocab_size:
+        model.resize_token_embeddings(config.vocab_size)
+
+    load_latest_checkpoint(model, None, None, device, args.checkpoint_dir, args.checkpoint_name)
+    model.eval()
+
+    prompt = args.prompt
+    max_new_tokens = args.max_new_tokens
+    temperature = args.temperature
+    top_k = args.top_k
+    top_p = args.top_p
+    repetition_penalty = args.repetition_penalty
+
+    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
+    generated_ids = input_ids.clone()
+    history_ids = set(input_ids[0].tolist())
+
+    print("-" * 50)
+    print(f"Prompt: {prompt}")
+    print("Generated Continuation:")
+
+    for step in range(max_new_tokens):
+        # Autocast only for reduced-precision dtypes
+        use_autocast = config.torch_dtype != torch.float32
+
+        if use_autocast:
+            with torch.amp.autocast('cuda' if device.type == 'cuda' else 'cpu', dtype=model.config.torch_dtype):
+                outputs = model(generated_ids)
+                logits = outputs['logits']
+                next_token_logits = logits[:, -1, :]
+        else:
+            outputs = model(generated_ids)
+            logits = outputs['logits']
+            next_token_logits = logits[:, -1, :]
+
+        # Repetition penalty
+        for token_id in history_ids:
+            if token_id < next_token_logits.size(-1):
+                logit = next_token_logits[0, token_id].item()
+                if logit > 0:
+                    next_token_logits[0, token_id] = logit / repetition_penalty
+                else:
+                    next_token_logits[0, token_id] = logit * repetition_penalty
+
+        # Temperature
+        if temperature > 0:
+            next_token_logits = next_token_logits / temperature
+
+        # Top-p / Top-k
+        if top_p < 1.0:
+            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+            sorted_indices_to_remove = cumulative_probs > top_p
+            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+            sorted_indices_to_remove[..., 0] = False
+            indices_to_remove = sorted_indices[sorted_indices_to_remove]
+            next_token_logits.scatter_(1, indices_to_remove.unsqueeze(0), float('-inf'))
+        elif top_k > 0:
+            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
+            next_token_logits = torch.full_like(next_token_logits, float('-inf'))
+            next_token_logits.scatter_(1, top_k_indices, top_k_logits)
+
+        # Sample
+        next_token_probs = F.softmax(next_token_logits, dim=-1)
+        next_token = torch.multinomial(next_token_probs, num_samples=1)
+        next_token_item = next_token.item()
+
+        if next_token_item == tokenizer.eos_token_id:
+            break
+
+        generated_ids = torch.cat([generated_ids, next_token], dim=-1)
+        history_ids.add(next_token_item)
+
+        new_token_text = tokenizer.decode(next_token.squeeze().tolist(), skip_special_tokens=True)
+        print(new_token_text, end="", flush=True)
+
+    print("\n" + "-" * 50)
+
+def info_command(args):
+    config = AetherisConfig.from_yaml(args.config)
+    model = HybridMambaMoE(config)
+
+    total_params = 0
+    dense_params = 0
+    expert_params = 0
+
+    for name, param in model.named_parameters():
+        numel = param.numel()
+        total_params += numel
+
+        if 'experts' in name:
+            expert_params += numel
+        else:
+            dense_params += numel
+
+    single_expert_size = expert_params / config.num_experts if config.num_experts > 0 else 0
+    active_per_token_params = dense_params + (single_expert_size * config.top_k)
+
+    def format_count(count):
+        return f"{count / 1_000_000:.2f}M"
+
+    print("=" * 50)
+    print("Hybrid Mamba-MoE Model Parameter Analysis")
+    print("=" * 50)
+    print(f"Total Model Layers (N_Layer): {config.n_layer}")
+    print(f"MoE Experts per Layer: {config.num_experts}")
+    print(f"Active Experts (Top-K): {config.top_k}")
+    print("-" * 50)
+    print(f"Total Parameters (Checkpoint Size): {format_count(total_params)}")
+    print(f"Dense (Always Active) Parameters: {format_count(dense_params)}")
+    print(f"Expert-Only Parameters: {format_count(expert_params)}")
+    print("-" * 50)
+    print(f"**Active Parameters (Per-Token Compute Load): {format_count(active_per_token_params)}**")
+    print("  (This is the 'Dense' parameters + the K active expert parameters)")
+    print("=" * 50)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Aetheris CLI")
+    subparsers = parser.add_subparsers(dest="command", help="Available commands")
+
+    # Train Command
+    train_parser = subparsers.add_parser("train", help="Train the model")
+    train_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
+    train_parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory to save checkpoints")
+    train_parser.add_argument("--hf_token", type=str, default=os.environ.get("HF_TOKEN"), help="HuggingFace Token")
+    train_parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
+    train_parser.add_argument("--pretrain_steps", type=int, default=50000, help="Number of pretraining steps")
+    train_parser.add_argument("--sft_steps", type=int, default=1000, help="Number of SFT steps")
+    train_parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name to load from")
+    train_parser.add_argument("--compile", action="store_true", help="Compile model with torch.compile for speed")
+    train_parser.add_argument("--accumulate_grad_batches", type=int, default=1, help="Gradient accumulation steps")
+    # Custom dataset args
+    train_parser.add_argument("--pretrain-dataset", type=str, default=None,
+                              help="Pretraining dataset: local JSONL path or HuggingFace dataset name")
+    train_parser.add_argument("--sft-dataset", type=str, default=None,
+                              help="SFT dataset: local JSONL path or HuggingFace dataset name")
+    # Learning rate args
+    train_parser.add_argument("--lr", type=float, default=None, help="Peak learning rate for pretraining (default: 1e-4)")
+    train_parser.add_argument("--sft-lr", type=float, default=None, help="Peak learning rate for SFT (default: 5e-5)")
+    train_parser.add_argument("--warmup-steps", type=int, default=None, help="Warmup steps for pretraining (default: 1000)")
+    train_parser.add_argument("--sft-warmup-steps", type=int, default=None, help="Warmup steps for SFT (default: 200)")
+    train_parser.add_argument("--resume", action="store_true", help="Resume from checkpoint step (default: start from 0)")
+
+    # Generate Command
+    gen_parser = subparsers.add_parser("generate", help="Generate text")
+    gen_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
+    gen_parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory with checkpoints")
+    gen_parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name")
+    gen_parser.add_argument("--prompt", type=str, default="The quick brown fox", help="Prompt for generation")
+    gen_parser.add_argument("--max_new_tokens", type=int, default=100, help="Max new tokens to generate")
+    gen_parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
+    gen_parser.add_argument("--top_k", type=int, default=0, help="Top-k sampling")
+    gen_parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
+    gen_parser.add_argument("--repetition_penalty", type=float, default=3.0, help="Repetition penalty")
+
+    # Serve Command
+    serve_parser = subparsers.add_parser("serve", help="Start the API server")
+    serve_parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind")
+    serve_parser.add_argument("--port", type=int, default=8000, help="Port to bind")
+    serve_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
+    serve_parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory with checkpoints")
+    serve_parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name")
+
+    # Info Command
+    info_parser = subparsers.add_parser("info", help="Show model info")
+    info_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
+
+    args = parser.parse_args()
+
+    if args.command == "train":
+        train_command(args)
+    elif args.command == "generate":
+        generate_command(args)
+    elif args.command == "serve":
+        import uvicorn
+        from aetheris.api.server import app
+        from aetheris.inference import InferenceEngine
+        import aetheris.api.server
+
+        # Install an engine built from CLI args before the server's default
+        # get_engine() can construct one with hardcoded defaults
+        aetheris.api.server.engine = InferenceEngine(
+            config_path=args.config,
+            checkpoint_dir=args.checkpoint_dir,
+            checkpoint_name=args.checkpoint_name
+        )
+
+        uvicorn.run(app, host=args.host, port=args.port)
+
+    elif args.command == "info":
+        info_command(args)
+    else:
+        parser.print_help()
+
+if __name__ == "__main__":
+    main()
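The schedule built by `_build_scheduler` is easiest to sanity-check as a pure function. `lr_multiplier` below reproduces `lr_lambda` outside the optimizer so the linear warmup and the 0.1 cosine floor can be verified directly (the function name is illustrative):

```python
import math

def lr_multiplier(step, total_steps, warmup_steps):
    # Linear warmup 0 -> 1, then cosine decay floored at 0.1,
    # matching lr_lambda inside _build_scheduler.
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
```

Halfway through warmup the multiplier is 0.5; at the end of warmup it reaches 1.0; by the final step the cosine term has decayed to 0 and the 0.1 floor takes over, so the LR never anneals below 10% of its peak.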
aetheris/config.py ADDED
@@ -0,0 +1,58 @@
+from dataclasses import dataclass, field
+import yaml
+import torch
+from typing import Optional
+
+@dataclass
+class AetherisConfig:
+    # Model dimensions
+    vocab_size: int = 50257
+    d_model: int = 768
+    n_layer: int = 24
+    num_experts: int = 4
+    top_k: int = 1
+    d_ff: int = 2304  # d_model * 3
+
+    # SSM parameters
+    ssm_d_state: int = 16
+    ssm_expand: int = 2
+    d_inner: Optional[int] = None  # Will be d_model * ssm_expand if None
+
+    # Training parameters
+    load_balancing_coef: float = 1e-2
+    router_z_loss_coef: float = 1e-3
+    max_seq_len: int = 512
+    dtype: str = "float16"  # "float16", "float32", "bfloat16"
+
+    # Optimization settings
+    use_cpu_offload: bool = False
+    gradient_checkpointing: bool = True
+    checkpoint_ssm_layers: bool = True
+    use_flash_attention: bool = False
+
+    def __post_init__(self):
+        if self.d_inner is None:
+            self.d_inner = self.d_model * self.ssm_expand
+        if self.d_ff is None:
+            self.d_ff = self.d_model * 3
+
+    @property
+    def torch_dtype(self):
+        if self.dtype == "float16":
+            return torch.float16
+        elif self.dtype == "float32":
+            return torch.float32
+        elif self.dtype == "bfloat16":
+            return torch.bfloat16
+        else:
+            raise ValueError(f"Unsupported dtype: {self.dtype}")
+
+    @classmethod
+    def from_yaml(cls, path: str):
+        with open(path, 'r') as f:
+            config_dict = yaml.safe_load(f)
+        return cls(**config_dict)
+
+    def to_yaml(self, path: str):
+        with open(path, 'w') as f:
+            yaml.dump(self.__dict__, f)
aetheris/data.py ADDED
@@ -0,0 +1,231 @@
+import torch
+from torch.utils.data import DataLoader, IterableDataset
+from transformers import AutoTokenizer
+from datasets import load_dataset
+import json
+import random
+from typing import Iterator, List
+import os
+
+VOXLEX_SPECIAL_TOKENS = [
+    "<tool_call>", "</tool_call>",
+    "<tool_result>", "</tool_result>",
+    "<legal_cite>", "</legal_cite>",
+]
+
+
+def get_tokenizer(model_name: str = "gpt2", add_special_tokens: bool = False):
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    if add_special_tokens:
+        num_added = tokenizer.add_special_tokens(
+            {"additional_special_tokens": VOXLEX_SPECIAL_TOKENS}
+        )
+        if num_added > 0:
+            print(f"  Added {num_added} special tokens → vocab_size={len(tokenizer)}")
+    return tokenizer
+
+
+class StreamingDataset(IterableDataset):
+    def __init__(self, dataset, tokenizer, max_seq_len, mode="pretrain", buffer_size=100, skip_samples=0):
+        self.dataset = dataset
+        self.tokenizer = tokenizer
+        self.max_seq_len = max_seq_len
+        self.mode = mode
+        self.buffer_size = buffer_size
+        self.skip_samples = skip_samples
+
+    def _find_assistant_spans(self, text: str) -> List[tuple]:
+        """Find character spans of assistant responses in SFT text."""
+        spans = []
+        search_from = 0
+        while True:
+            start = text.find("<|assistant|>", search_from)
+            if start == -1:
+                break
+            content_start = start + len("<|assistant|>")
+            # End at next role tag or end of text
+            end = len(text)
+            for tag in ["<|user|>", "<|system|>", "<|tool|>", "<|endoftext|>"]:
+                pos = text.find(tag, content_start)
+                if pos != -1:
+                    end = min(end, pos)
+            spans.append((content_start, end))
+            search_from = end
+        return spans
+
+    def _prepare_sft_example(self, example):
+        """Prepare SFT example with label masking — loss only on assistant tokens."""
+        if 'messages' in example:
+            # Build text with role tags
+            text = ""
+            for msg in example['messages']:
+                role = msg.get('role', '')
+                content = msg.get('content', '')
+                text += f"<|{role}|>{content}"
+            text += self.tokenizer.eos_token
+        elif 'text' in example:
+            text = example['text']
+        else:
+            return None
+
+        if len(text) < 10:
+            return None
+
+        # Pre-truncate to avoid slow tokenization of very long texts
+        max_chars = self.max_seq_len * 5
+        if len(text) > max_chars:
+            text = text[:max_chars]
+
+        enc = self.tokenizer(text, truncation=True, max_length=self.max_seq_len,
+                             return_tensors="pt")
+        input_ids = enc['input_ids'][0]
+
+        if len(input_ids) < 2:
+            return None
+
+        # Build labels: -100 for non-assistant tokens
+        labels = torch.full_like(input_ids, -100)
+        assistant_spans = self._find_assistant_spans(text)
+
+        for char_start, char_end in assistant_spans:
+            # Map character offsets to token positions
+            in_span = False
+            for tok_idx in range(len(input_ids)):
+                token_span = enc.token_to_chars(0, tok_idx)
+                if token_span is None:
+                    # Special token (e.g. <tool_call>) — include if neighbors are in span
+                    if in_span:
+                        labels[tok_idx] = input_ids[tok_idx]
+                    continue
+                tok_start, tok_end = token_span
+                # Token overlaps with assistant span
+                if tok_end > char_start and tok_start < char_end:
+                    labels[tok_idx] = input_ids[tok_idx]
+                    in_span = True
+                else:
+                    in_span = False
+
+        # Also train on eos token at the end
+        if input_ids[-1] == self.tokenizer.eos_token_id:
+            labels[-1] = input_ids[-1]
+
+        # Pad to max_seq_len
+        if len(input_ids) < self.max_seq_len:
+            pad_len = self.max_seq_len - len(input_ids)
+            input_ids = torch.cat([
+                input_ids,
+                torch.full((pad_len,), self.tokenizer.pad_token_id, dtype=torch.long)
+            ])
+            labels = torch.cat([
+                labels,
+                torch.full((pad_len,), -100, dtype=torch.long)
+            ])
+
+        return input_ids, labels
+
+    def _prepare_pretrain_example(self, example):
+        """Prepare pretraining example — loss on all non-pad tokens."""
+        text = example.get('text', '')
+        if len(text) < 10:
+            return None
+
+        # Pre-truncate text to avoid tokenizing 100K+ char documents
+        # GPT-2 averages ~4 chars per token; use 5x max_seq_len as safe limit
+        max_chars = self.max_seq_len * 5
+        if len(text) > max_chars:
+            text = text[:max_chars]
+
+        enc = self.tokenizer(text, truncation=True, max_length=self.max_seq_len,
+                             return_tensors="pt")
+        input_ids = enc['input_ids'][0]
+
+        if len(input_ids) < 2:
+            return None
+
+        labels = input_ids.clone()
+
+        if len(input_ids) < self.max_seq_len:
+            pad_len = self.max_seq_len - len(input_ids)
+            input_ids = torch.cat([
+                input_ids,
+                torch.full((pad_len,), self.tokenizer.pad_token_id, dtype=torch.long)
+            ])
+            labels = torch.cat([
+                labels,
+                torch.full((pad_len,), -100, dtype=torch.long)
+            ])
+
+        return input_ids, labels
+
+    def __iter__(self) -> Iterator[tuple]:
+        iterator = iter(self.dataset)
+        buffer = []
+
+        for example in iterator:
+            if self.mode == "pretrain":
+                result = self._prepare_pretrain_example(example)
+            else:
+                result = self._prepare_sft_example(example)
+
+            if result is None:
+                continue
+
+            buffer.append(result)
+
+            if len(buffer) >= self.buffer_size:
+                random.shuffle(buffer)
+                for _ in range(self.buffer_size // 2):
+                    item = buffer.pop()
+                    if self.skip_samples > 0:
+                        self.skip_samples -= 1
+                        continue
+                    yield item
+
+        # Yield remaining
+        random.shuffle(buffer)
+        while buffer:
+            item = buffer.pop()
+            if self.skip_samples > 0:
+                self.skip_samples -= 1
+                continue
+            yield item
+
+
+def _load_jsonl_dataset(path: str):
+    """Load a local JSONL file as a streaming iterable (no memory materialization)."""
+    from datasets import IterableDataset
+
+    def gen():
+        with open(path, 'r') as f:
+            for line in f:
+                line = line.strip()
+                if line:
+                    yield json.loads(line)
+
+    return IterableDataset.from_generator(gen)
+
+
+def create_streaming_loader(dataset_name, split, tokenizer, config, batch_size,
+                            mode="pretrain", hf_token=None, start_step=0):
+    # Support local JSONL files
+    if os.path.isfile(dataset_name) and dataset_name.endswith('.jsonl'):
+        print(f"  Loading local dataset: {dataset_name}")
+        raw_dataset = _load_jsonl_dataset(dataset_name)
+    else:
+        raw_dataset = load_dataset(dataset_name, split=split, streaming=True,
+                                   trust_remote_code=True, token=hf_token)
+
+    # Calculate samples to skip: start_step * batch_size
+    skip_samples = start_step * batch_size
+    if skip_samples > 0:
+        print(f"  [Loader] Resuming: Fast-forwarding dataset by {skip_samples} samples...")
+
+    stream_ds = StreamingDataset(raw_dataset, tokenizer, config.max_seq_len,
+                                 mode=mode, skip_samples=skip_samples)
+
+    # num_workers=0 avoids 4x data duplication with IterableDataset
+    # (each worker iterates the full dataset without sharding logic)
+    return DataLoader(stream_ds, batch_size=batch_size, pin_memory=True,
+                      num_workers=0)
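The character-to-token label masking above hinges on `_find_assistant_spans`. A dependency-free sketch of the same scan (as a standalone function, not the class method) shows exactly which substrings end up supervised; everything outside these spans is masked to -100:

```python
def find_assistant_spans(text):
    """Character spans of assistant turns, each ending at the next role tag."""
    spans, search_from = [], 0
    while True:
        start = text.find("<|assistant|>", search_from)
        if start == -1:
            break
        content_start = start + len("<|assistant|>")
        end = len(text)
        for tag in ["<|user|>", "<|system|>", "<|tool|>", "<|endoftext|>"]:
            pos = text.find(tag, content_start)
            if pos != -1:
                end = min(end, pos)
        spans.append((content_start, end))
        search_from = end
    return spans

chat = "<|user|>Hi<|assistant|>Hello!<|user|>Bye<|assistant|>Bye!"
spans = find_assistant_spans(chat)
print([chat[a:b] for a, b in spans])  # → ['Hello!', 'Bye!']
```

Only the two assistant replies survive as supervised spans; the user turns and role tags themselves carry no loss.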
aetheris/inference.py ADDED
@@ -0,0 +1,106 @@
+import torch
+import torch.nn.functional as F
+from typing import Optional, Generator
+from aetheris.config import AetherisConfig
+from aetheris.model import HybridMambaMoE
+from aetheris.data import get_tokenizer
+from aetheris.utils import load_latest_checkpoint
+
+class InferenceEngine:
+    def __init__(self, config_path: str = "configs/default.yaml", checkpoint_dir: str = "checkpoints", checkpoint_name: str = "checkpoint_current.pth", device: Optional[str] = None):
+        self.device = torch.device(device if device else ('cuda' if torch.cuda.is_available() else 'cpu'))
+        self.config = AetherisConfig.from_yaml(config_path)
+        self.tokenizer = get_tokenizer()
+
+        self.model = HybridMambaMoE(self.config).to(self.device).to(self.config.torch_dtype)
+
+        # Load checkpoint
+        # Note: load_latest_checkpoint expects optimizer and scaler, but for inference we can pass None
+        load_latest_checkpoint(self.model, None, None, self.device, checkpoint_dir, checkpoint_name)
+        self.model.eval()
+
+    def generate(self,
+                 prompt: str,
+                 max_new_tokens: int = 100,
+                 temperature: float = 0.8,
+                 top_k: int = 0,
+                 top_p: float = 0.9,
+                 repetition_penalty: float = 1.0,
+                 stream: bool = False) -> Generator[str, None, None] | str:
+
+        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
+        generated_ids = input_ids.clone()
+        history_ids = set(input_ids[0].tolist())
+
+        def token_generator():
+            nonlocal generated_ids
+            for _ in range(max_new_tokens):
+                # Skip autocast if the model runs in float32
+                use_autocast = self.config.torch_dtype != torch.float32
+
+                if use_autocast:
+                    with torch.amp.autocast('cuda' if self.device.type == 'cuda' else 'cpu', dtype=self.model.config.torch_dtype):
+                        outputs = self.model(generated_ids)
+                        logits = outputs['logits']
+                        next_token_logits = logits[:, -1, :]
+                else:
+                    outputs = self.model(generated_ids)
+                    logits = outputs['logits']
+                    next_token_logits = logits[:, -1, :]
+
+                # Repetition penalty
+                for token_id in history_ids:
+                    if token_id < next_token_logits.size(-1):
+                        logit = next_token_logits[0, token_id].item()
+                        if logit > 0:
+                            next_token_logits[0, token_id] = logit / repetition_penalty
+                        else:
+                            next_token_logits[0, token_id] = logit * repetition_penalty
+
+                # Temperature
+                if temperature > 0:
+                    next_token_logits = next_token_logits / temperature
+
+                # Top-p / Top-k
+                if top_p < 1.0:
+                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+                    sorted_indices_to_remove = cumulative_probs > top_p
+                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                    sorted_indices_to_remove[..., 0] = False
+                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                    next_token_logits.scatter_(1, indices_to_remove.unsqueeze(0), float('-inf'))
+                elif top_k > 0:
+                    top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
+                    next_token_logits = torch.full_like(next_token_logits, float('-inf'))
+                    next_token_logits.scatter_(1, top_k_indices, top_k_logits)
+
+                # Sample
+                next_token_probs = F.softmax(next_token_logits, dim=-1)
+                next_token = torch.multinomial(next_token_probs, num_samples=1)
+                next_token_item = next_token.item()
+
+                if next_token_item == self.tokenizer.eos_token_id:
+                    break
+
+                generated_ids = torch.cat([generated_ids, next_token], dim=-1)
+                history_ids.add(next_token_item)
+
+                new_token_text = self.tokenizer.decode(next_token.squeeze().tolist(), skip_special_tokens=True)
+                yield new_token_text
+
+        if stream:
+            return token_generator()
+        else:
+            return "".join(list(token_generator()))
+
+    def generate_full(self,
+                      prompt: str,
+                      max_new_tokens: int = 100,
+                      temperature: float = 0.8,
+                      top_k: int = 0,
+                      top_p: float = 0.9,
+                      repetition_penalty: float = 1.0) -> str:
+        return self.generate(prompt, max_new_tokens, temperature, top_k, top_p, repetition_penalty, stream=False)
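The top-p branch in `generate` keeps the smallest prefix of the sorted distribution whose cumulative mass exceeds `top_p`; the right-shift of `sorted_indices_to_remove` is what retains the first token that crosses the threshold. The same semantics in a pure-Python sketch, on plain lists rather than tensors:

```python
def top_p_filter(probs, top_p=0.9):
    """Keep the smallest set of highest-probability tokens whose cumulative
    probability exceeds top_p, then renormalize over the survivors."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cum = [], 0.0
    for i in order:
        keep.append(i)           # the token that crosses the threshold is kept
        cum += probs[i]
        if cum > top_p:
            break
    total = sum(probs[i] for i in keep)
    return {i: probs[i] / total for i in keep}

filtered = top_p_filter([0.5, 0.3, 0.15, 0.05], top_p=0.9)
print(sorted(filtered))  # → [0, 1, 2]
```

With this distribution, tokens 0-2 carry cumulative mass 0.95 > 0.9, so token 3 is dropped and the rest are renormalized to sum to 1.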
aetheris/model.py ADDED
@@ -0,0 +1,104 @@
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from typing import Dict, Any
+from .config import AetherisConfig
+from .modules import SSMBlock, SparseMoELayer
+
+class HybridMambaMoE(nn.Module):
+    def __init__(self, config: AetherisConfig):
+        super().__init__()
+        self.config = config
+        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
+
+        self.layers = nn.ModuleList()
+        for i in range(config.n_layer):
+            if i % 2 == 0:
+                self.layers.append(SSMBlock(config))
+            else:
+                self.layers.append(SparseMoELayer(config))
+
+        self.final_norm = nn.LayerNorm(config.d_model)
+        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+        self.lm_head.weight = self.embedding.weight  # Weight tying
+
+        # Use -100 as ignore_index (PyTorch standard for label masking)
+        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
+        self.gradient_checkpointing = config.gradient_checkpointing
+
+        # Initialize embeddings with smaller scale
+        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)
+
+    def resize_token_embeddings(self, new_vocab_size: int):
+        """Resize embedding and lm_head for new tokens. New embeddings initialized from mean of existing."""
+        old_vocab_size = self.embedding.num_embeddings
+        if new_vocab_size == old_vocab_size:
+            return
+        old_weight = self.embedding.weight.data
+        mean_embed = old_weight.mean(dim=0)
+        self.embedding = nn.Embedding(new_vocab_size, self.config.d_model)
+        self.embedding.weight.data[:old_vocab_size] = old_weight
+        self.embedding.weight.data[old_vocab_size:] = mean_embed.unsqueeze(0).expand(
+            new_vocab_size - old_vocab_size, -1
+        )
+        self.lm_head = nn.Linear(self.config.d_model, new_vocab_size, bias=False)
+        self.lm_head.weight = self.embedding.weight  # Re-tie weights
+        self.config.vocab_size = new_vocab_size
+
+    def _init_weights(self, module):
+        """Apply proper weight initialization"""
+        if isinstance(module, nn.Linear):
+            nn.init.xavier_uniform_(module.weight, gain=0.5)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            nn.init.normal_(module.weight, mean=0.0, std=0.02)
+        elif isinstance(module, nn.LayerNorm):
+            nn.init.ones_(module.weight)
+            nn.init.zeros_(module.bias)
+
+    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None) -> Dict[str, Any]:
+        x = self.embedding(input_ids)
+        total_aux_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)
+
+        for i, layer in enumerate(self.layers):
+            if self.gradient_checkpointing and self.training:
+                # Checkpoint ALL layers for maximum memory savings
+                if isinstance(layer, SparseMoELayer):
+                    def moe_forward(module, inp):
+                        return module(inp)
+                    x, aux_loss = torch.utils.checkpoint.checkpoint(
+                        moe_forward, layer, x, use_reentrant=False
+                    )
+                    total_aux_loss = total_aux_loss + aux_loss
+                else:
+                    x = torch.utils.checkpoint.checkpoint(
+                        layer, x, use_reentrant=False
+                    )
+            else:
+                if isinstance(layer, SparseMoELayer):
+                    x, aux_loss = layer(x)
+                    total_aux_loss = total_aux_loss + aux_loss
+                else:
+                    x = layer(x)
+
+        x = self.final_norm(x)
+        logits = self.lm_head(x)
+
+        if labels is not None:
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            ce_loss = self.loss_fn(shift_logits.view(-1, self.config.vocab_size),
+                                   shift_labels.view(-1))
+
+            # Scale down aux loss to prevent it from dominating
+            total_loss = ce_loss + 0.01 * total_aux_loss
+
+            return {
+                "loss": total_loss,
+                "ce_loss": ce_loss,
+                "aux_loss": total_aux_loss,
+                "logits": logits
+            }
+
+        return {"logits": logits}
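The constructor interleaves block types by index parity (even = SSM, odd = MoE), so the 24-layer model from the config splits 12/12. A toy sketch of the resulting layer plan:

```python
def layer_plan(n_layer=24):
    """Even indices are SSM blocks, odd indices are sparse-MoE blocks."""
    return ["ssm" if i % 2 == 0 else "moe" for i in range(n_layer)]

plan = layer_plan()
print(plan[:4], plan.count("ssm"), plan.count("moe"))
# → ['ssm', 'moe', 'ssm', 'moe'] 12 12
```

Because the count is even, the stack always starts with an SSM block and ends with an MoE block.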
aetheris/modules/__init__.py ADDED
@@ -0,0 +1,3 @@
+from .expert import Expert
+from .ssm import SSMBlock, selective_scan_native
+from .moe import SparseMoELayer
aetheris/modules/expert.py ADDED
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Expert(nn.Module):
+    """Memory-efficient Feed-Forward Network expert with proper initialization."""
+    def __init__(self, d_model: int, d_ff: int):
+        super().__init__()
+        self.w1 = nn.Linear(d_model, d_ff, bias=False)
+        self.w2 = nn.Linear(d_ff, d_model, bias=False)
+        self.act = nn.GELU()
+
+        # Proper initialization to prevent NaN
+        nn.init.xavier_uniform_(self.w1.weight, gain=0.5)
+        nn.init.xavier_uniform_(self.w2.weight, gain=0.5)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        orig_dtype = x.dtype
+        # Force float32 for internal computation to prevent overflow in half precision
+        x = x.to(torch.float32)
+
+        # Cast weights to float32 for calculation
+        # This is necessary because the module weights might be float16
+        w1_weight = self.w1.weight.to(torch.float32)
+        w2_weight = self.w2.weight.to(torch.float32)
+
+        h = F.linear(x, w1_weight)
+        h = self.act(h)
+        out = F.linear(h, w2_weight)
+
+        # Clamp to avoid Inf when casting back to float16
+        if orig_dtype == torch.float16:
+            out = torch.clamp(out, min=-65500.0, max=65500.0)
+
+        return out.to(orig_dtype)
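The ±65500 clamp exists because the largest finite float16 value is 65504; any fp32 value beyond that becomes `inf` on the downcast. A scalar sketch of the same guard:

```python
def clamp_for_fp16(x, bound=65500.0):
    """Clamp a float before casting to float16 so values beyond the fp16
    range (max finite value 65504) saturate at the bound instead of inf."""
    return max(-bound, min(bound, x))

print(clamp_for_fp16(1e6), clamp_for_fp16(-1e6), clamp_for_fp16(123.0))
# → 65500.0 -65500.0 123.0
```

Values already inside the representable range pass through unchanged; only the overflow cases saturate.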
aetheris/modules/moe.py ADDED
@@ -0,0 +1,83 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ..config import AetherisConfig
+from .expert import Expert
+
+class SparseMoELayer(nn.Module):
+    """Memory-optimized Sparse MoE with efficient routing."""
+    def __init__(self, config: AetherisConfig):
+        super().__init__()
+        self.d_model = config.d_model
+        self.num_experts = config.num_experts
+        self.top_k = config.top_k
+        self.load_balancing_coef = config.load_balancing_coef
+        self.z_loss_coef = config.router_z_loss_coef
+
+        self.gate = nn.Linear(config.d_model, config.num_experts, bias=False)
+        self.experts = nn.ModuleList([Expert(config.d_model, config.d_ff)
+                                      for _ in range(config.num_experts)])
+        self.norm = nn.LayerNorm(config.d_model)
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        B, L, D = x.shape
+        x_norm = self.norm(x)
+        flat_x = x_norm.view(-1, D)
+
+        # Routing logits with stability
+        gate_logits = self.gate(flat_x)
+
+        # Clamp logits to prevent overflow
+        gate_logits = torch.clamp(gate_logits, min=-10.0, max=10.0)
+
+        # Z-loss for stability
+        z_loss = torch.mean(torch.logsumexp(gate_logits, dim=-1)**2) * self.z_loss_coef
+
+        if self.training:
+            # Low-magnitude router noise (kept small for stability)
+            gate_logits = gate_logits + torch.randn_like(gate_logits) * 1e-3
+
+        gate_probs = F.softmax(gate_logits, dim=-1)
+        gate_weights, expert_indices = torch.topk(gate_probs, self.top_k, dim=-1)
+
+        # Normalize weights for stability
+        gate_weights = gate_weights / (gate_weights.sum(dim=-1, keepdim=True) + 1e-8)
+
+        # Load balancing loss
+        # Use only the top-1 expert for load balancing calculation to keep it simple and consistent
+        expert_mask = F.one_hot(expert_indices[:, 0], num_classes=self.num_experts).float()
+        fraction_routed = expert_mask.mean(dim=0)
+        mean_prob = gate_probs.mean(dim=0)
+
+        aux_loss = (self.num_experts * torch.sum(fraction_routed * mean_prob)) * self.load_balancing_coef
+        total_aux_loss = aux_loss + z_loss
+
+        # Efficient dispatch with in-place operations
+        # Accumulate in float32 to prevent overflow during aggregation
+        final_output = torch.zeros_like(flat_x, dtype=torch.float32)
+
+        # Iterate over all k selected experts
+        for k_idx in range(self.top_k):
+            for i, expert in enumerate(self.experts):
+                # Find tokens routed to expert 'i' at the k-th position
+                mask = (expert_indices[:, k_idx] == i)
+                if not mask.any():
+                    continue
+
+                expert_input = flat_x[mask]
+                expert_out = expert(expert_input)
+
+                # Apply weights
+                weights = gate_weights[mask, k_idx].unsqueeze(1)
+
+                # Cast to float32 for accumulation
+                expert_out = expert_out.to(torch.float32)
+                weights = weights.to(torch.float32)
+
+                # Accumulate output (add to existing results from other experts)
+                final_output[mask] += expert_out * weights
+
+        # Cast back to original dtype
+        final_output = final_output.to(flat_x.dtype)
+
+        return x + final_output.view(B, L, D), total_aux_loss
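The load-balancing term above is the Switch-Transformer-style loss N · Σᵢ fᵢ·Pᵢ (fᵢ = fraction of tokens whose top-1 choice is expert i, Pᵢ = mean router probability for i), scaled by `load_balancing_coef`. A minimal sketch on plain lists shows that it bottoms out under uniform routing and grows when one expert dominates:

```python
def load_balancing_loss(fraction_routed, mean_prob, coef=1e-2):
    """Switch-style auxiliary loss: coef * N * sum_i f_i * P_i."""
    n = len(fraction_routed)
    return coef * n * sum(f * p for f, p in zip(fraction_routed, mean_prob))

# Perfectly balanced routing over 4 experts: 1e-2 * 4 * 4 * (1/4 * 1/4) = 0.01
balanced = load_balancing_loss([0.25] * 4, [0.25] * 4)
# All tokens collapsed onto expert 0: 1e-2 * 4 * (1.0 * 0.7) = 0.028
skewed = load_balancing_loss([1.0, 0.0, 0.0, 0.0], [0.7, 0.1, 0.1, 0.1])
print(balanced, skewed)  # → 0.01 0.028
```

Minimizing this term therefore pushes the router toward spreading tokens evenly across the four experts.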
aetheris/modules/ssm.py ADDED
@@ -0,0 +1,119 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ..config import AetherisConfig
+
+# Try to import CUDA selective scan kernel
+try:
+    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
+    HAS_CUDA_SSM = True
+except ImportError:
+    HAS_CUDA_SSM = False
+
+@torch.jit.ignore
+def selective_scan_native(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
+                          B: torch.Tensor, C: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
+    """Fallback pure-Python scan (slow, O(L) sequential)."""
+    B_size, L, D_inner = u.shape
+    D_state = A.shape[-1]
+    original_dtype = u.dtype
+
+    h = torch.zeros(B_size, D_inner, D_state, device=u.device, dtype=torch.float32)
+    ys = []
+
+    u = u.float()
+    delta = delta.float()
+    A = A.float()
+    B = B.float()
+    C = C.float()
+    D = D.float()
+
+    for l in range(L):
+        dt = delta[:, l, :].unsqueeze(-1)
+        dA = torch.exp(dt * A)
+        B_l = B[:, l, :].unsqueeze(1)
+        dB = dt * B_l
+        u_t = u[:, l, :].unsqueeze(-1)
+        h = dA * h + dB * u_t
+        C_l = C[:, l, :].unsqueeze(1)
+        y_t = torch.sum(h * C_l, dim=-1)
+        ys.append(y_t)
+
+    y = torch.stack(ys, dim=1)
+    y = y + u * D
+    return y.to(dtype=original_dtype)
+
+class SSMBlock(nn.Module):
+    """State Space Model block with optional CUDA-accelerated selective scan."""
+    def __init__(self, config: AetherisConfig):
+        super().__init__()
+        self.d_model = config.d_model
+        self.d_state = config.ssm_d_state
+        self.d_inner = config.d_inner
+
+        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=False)
+        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=False)
+        self.conv_d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=3,
+                                padding=2, groups=self.d_inner, bias=False)
+        self.gate_proj = nn.Linear(self.d_model, self.d_inner, bias=False)
+
+        self.B_proj = nn.Linear(self.d_inner, self.d_state, bias=False)
+        self.C_proj = nn.Linear(self.d_inner, self.d_state, bias=False)
+        self.delta_proj = nn.Linear(self.d_inner, self.d_inner, bias=False)
+
+        self.A_log = nn.Parameter(torch.randn(self.d_inner, self.d_state) * 0.1 - 4.0)
+        self.D = nn.Parameter(torch.ones(self.d_inner) * 0.1)
+
+        self.act = nn.SiLU()
+        self.norm = nn.LayerNorm(config.d_model)
+
+        nn.init.xavier_uniform_(self.in_proj.weight, gain=0.5)
+        nn.init.xavier_uniform_(self.out_proj.weight, gain=0.5)
+        nn.init.xavier_uniform_(self.gate_proj.weight, gain=0.5)
+        nn.init.xavier_uniform_(self.B_proj.weight, gain=0.5)
+        nn.init.xavier_uniform_(self.C_proj.weight, gain=0.5)
+        nn.init.xavier_uniform_(self.delta_proj.weight, gain=0.5)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, L, D = x.shape
+        x_norm = self.norm(x)
+
+        xz = self.in_proj(x_norm)
+        x_in, z_gate = xz.chunk(2, dim=-1)
+        x_conv = self.conv_d(x_in.transpose(1, 2))
+        x_conv = x_conv[:, :, :-2].transpose(1, 2)
+        x_conv = self.act(x_conv)
+
+        B_ssm = self.B_proj(x_conv)
+        C_ssm = self.C_proj(x_conv)
+
+        # A is (D_inner, D_state) — clamped and negated
+        A = -torch.exp(torch.clamp(self.A_log, min=-10.0, max=2.0))
+
+        if HAS_CUDA_SSM and x.is_cuda:
+            # CUDA kernel expects float32 — cast inputs and cast output back
+            original_dtype = x_conv.dtype
+            delta_raw = self.delta_proj(x_conv)
+            y_ssm = selective_scan_fn(
+                x_conv.transpose(1, 2).contiguous().float(),     # (B, D_inner, L)
+                delta_raw.transpose(1, 2).contiguous().float(),  # (B, D_inner, L)
+                A.contiguous().float(),                          # (D_inner, D_state)
+                B_ssm.transpose(1, 2).contiguous().float(),      # (B, D_state, L)
+                C_ssm.transpose(1, 2).contiguous().float(),      # (B, D_state, L)
+                self.D.float(),                                  # (D_inner,)
+                z=None,
+                delta_bias=None,
+                delta_softplus=True,
+                return_last_state=False,
+            )
+            y_ssm = y_ssm.to(dtype=original_dtype).transpose(1, 2)  # Back to (B, L, D_inner)
+        else:
+            # Fallback: pure Python sequential scan
+            delta = torch.clamp(F.softplus(self.delta_proj(x_conv)), max=5.0) + 1e-4
+            A_batched = A.unsqueeze(0).expand(B, -1, -1)
+            y_ssm = selective_scan_native(x_conv, delta, A_batched, B_ssm, C_ssm, self.D)
+
+        y_gate = F.silu(self.gate_proj(x_norm)) * y_ssm
+        output = self.out_proj(y_gate)
+
+        return x + output
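`selective_scan_native` implements the discretized recurrence h_t = exp(Δt·A)·h_{t-1} + Δt·B_t·u_t with readout y_t = C_t·h_t + D·u_t. A scalar sketch (d_inner = d_state = 1, plain floats instead of tensors) makes the sequential state dependence explicit; with constant inputs the state relaxes toward its fixed point:

```python
import math

def selective_scan_1d(u, delta, A, B, C, D):
    """Scalar version of the recurrence in selective_scan_native:
    h_t = exp(dt*A) * h_{t-1} + dt*B_t*u_t ;  y_t = C_t*h_t + D*u_t."""
    h, ys = 0.0, []
    for u_t, dt, b_t, c_t in zip(u, delta, B, C):
        h = math.exp(dt * A) * h + dt * b_t * u_t  # discretized state update
        ys.append(c_t * h + D * u_t)               # readout plus skip term
    return ys

y = selective_scan_1d(u=[1.0] * 3, delta=[0.1] * 3, A=-1.0,
                      B=[1.0] * 3, C=[1.0] * 3, D=0.0)
print(y)  # monotonically increasing toward the fixed point dt*B*u / (1 - exp(dt*A))
```

With A < 0 the factor exp(Δt·A) is below 1, so the state decays rather than explodes; that is why `A_log` is clamped and negated in the block above.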
aetheris/trainer/__init__.py ADDED
@@ -0,0 +1 @@
+from .trainer import Trainer
aetheris/trainer/trainer.py ADDED
@@ -0,0 +1,176 @@
+import torch
+import time
+import os
+from aetheris.utils import save_checkpoint, load_latest_checkpoint, calculate_model_stats
+from aetheris.data import get_tokenizer
+
+class Trainer:
+    def __init__(self, model, optimizer, scaler, config, device, checkpoint_dir, logger=None, grad_accum_steps=1):
+        self.model = model
+        self.optimizer = optimizer
+        self.scaler = scaler
+        self.config = config
+        self.device = device
+        self.checkpoint_dir = checkpoint_dir
+        self.logger = logger
+        self.grad_accum_steps = grad_accum_steps
+        self.scheduler = None  # Set by CLI before train_epoch()
+
+        self.model.to(self.device)
+
+    def validate(self, val_loader, global_step):
+        self.model.eval()
+        total_loss = 0
+        total_items = 0
+        num_batches = 100  # Validate on 100 batches to save time
+
+        print(f"\n[Validation] Starting validation at step {global_step}...")
+
+        with torch.no_grad():
+            for i, batch in enumerate(val_loader):
+                if i >= num_batches:
+                    break
+
+                input_ids, labels = batch
+                input_ids = input_ids.to(self.device, non_blocking=True)
+                labels = labels.to(self.device, non_blocking=True)
+
+                # bf16 autocast (assumes Ampere or newer), skipped when the model runs in float32
+                autocast_dtype = torch.bfloat16
+                use_autocast = self.config.torch_dtype != torch.float32
+
+                if use_autocast:
+                    with torch.amp.autocast('cuda', dtype=autocast_dtype):
+                        output = self.model(input_ids, labels)
+                else:
+                    output = self.model(input_ids, labels)
+
+                total_loss += output["loss"].item()
+                total_items += 1
+
+        avg_loss = total_loss / total_items if total_items > 0 else 0
+        perplexity = torch.exp(torch.tensor(avg_loss)).item()
+
+        print(f"[Validation] Step {global_step} | Loss: {avg_loss:.4f} | PPL: {perplexity:.4f}")
+        self.model.train()
+        return avg_loss
+
+    def train_epoch(self, train_loader, total_steps, start_step=0, stage_name="Training", val_loader=None, eval_every=500):
+        print(f"\n{'='*70}\nStarting {stage_name}: Target Steps={total_steps} (Accum={self.grad_accum_steps})\n{'='*70}", flush=True)
+        self.model.train()
+        global_step = start_step
+        running_loss = 0
+
+        print("Initializing data iterator...")
+        train_iter = iter(train_loader)
+
+        print("Fetching first batch...")
+
+        # Zero gradients initially
+        self.optimizer.zero_grad(set_to_none=True)
+
+        while global_step < total_steps:
+            step_start = time.time()
+
+            # Removed periodic cache clearing for performance
+
+            try:
+                batch = next(train_iter)
+                if global_step == start_step:
+                    print("✓ First batch loaded! Starting training loop...", flush=True)
+            except StopIteration:
+                train_iter = iter(train_loader)
+                batch = next(train_iter)
+
+            data_time = time.time() - step_start
+            input_ids, labels = batch
+            input_ids = input_ids.to(self.device, non_blocking=True)
+            labels = labels.to(self.device, non_blocking=True)
+
+            gpu_start = time.time()
+            # bf16 autocast (assumes Ampere or newer; bf16 avoids fp16 range overflow)
+            autocast_dtype = torch.bfloat16
+
+            # Skip autocast entirely if the model runs in float32
+            use_autocast = self.config.torch_dtype != torch.float32
+
+            if use_autocast:
+                with torch.amp.autocast('cuda', dtype=autocast_dtype):
+                    output = self.model(input_ids, labels)
+                    # Scale loss for accumulation
+                    loss = output["loss"] / self.grad_accum_steps
+            else:
+                output = self.model(input_ids, labels)
+                loss = output["loss"] / self.grad_accum_steps
+
+            # NaN loss detection — skip batch entirely to prevent corruption
+            if torch.isnan(loss) or torch.isinf(loss):
+                nan_count = getattr(self, '_nan_count', 0) + 1
+                self._nan_count = nan_count
+                print(f"WARNING: NaN/Inf loss at step {global_step} (count={nan_count}), skipping batch", flush=True)
+                self.optimizer.zero_grad(set_to_none=True)
+                global_step += 1
+                continue
+
+            loss.backward()
+            if self.device.type == 'cuda':
+                torch.cuda.synchronize()
+            gpu_time = time.time() - gpu_start
+
+            # Gradient accumulation step
+            if (global_step + 1) % self.grad_accum_steps == 0:
+                # Gradient clipping
+                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
+
+                if torch.isnan(grad_norm) or torch.isinf(grad_norm):
+                    print(f"WARNING: NaN/Inf gradient at step {global_step}, skipping update", flush=True)
+                    self.optimizer.zero_grad(set_to_none=True)
+                else:
+                    self.optimizer.step()
+                    self.optimizer.zero_grad(set_to_none=True)
+
+                # Step LR scheduler once per optimizer update
+                if self.scheduler is not None:
+                    self.scheduler.step()
+
+            global_step += 1
+            running_loss += loss.item() * self.grad_accum_steps  # Un-scale for reporting
+
+            # Per-step progress file for monitoring (cheap I/O)
+            if global_step <= 20 or global_step % 100 == 0:
+                total_elapsed = time.time() - step_start
+                with open("/workspace/training_progress.log", "a") as pf:
+                    pf.write(f"step={global_step} loss={loss.item() * self.grad_accum_steps:.4f} total={total_elapsed:.1f}s data={data_time:.1f}s gpu={gpu_time:.1f}s\n")
+
+            if global_step % 10 == 0:
+                avg_loss = running_loss / 10
+                t_diff = time.time() - step_start
+                if self.device.type == 'cuda':
+                    mem = torch.cuda.memory_allocated() / 1e9
+                    max_mem = torch.cuda.max_memory_allocated() / 1e9
+                    mem_str = f"VRAM: {mem:.1f}GB (peak: {max_mem:.1f}GB)"
+                else:
+                    mem_str = "CPU Mode"
+
+                tokens_per_sec = (self.config.max_seq_len * input_ids.size(0)) / t_diff
+                current_lr = self.optimizer.param_groups[0]['lr']
+                msg = (f"  Step {global_step}/{total_steps} | Loss: {avg_loss:.4f} | "
+                       f"LR: {current_lr:.2e} | {mem_str} | {tokens_per_sec:.0f} tok/s")
+                print(msg, flush=True)
+ # Write progress to file (bypasses stdout buffering)
164
+ with open("/workspace/training_progress.log", "a") as pf:
165
+ pf.write(msg + "\n")
166
+ running_loss = 0
167
+
168
+ if global_step % 500 == 0:
169
+ save_checkpoint(self.model, self.optimizer, self.scaler, global_step, stage_name, self.checkpoint_dir)
170
+ with open("/workspace/training_progress.log", "a") as pf:
171
+ pf.write(f" [Checkpoint saved at step {global_step}]\n")
172
+
173
+ if val_loader is not None and global_step % eval_every == 0 and global_step > start_step:
174
+ self.validate(val_loader, global_step)
175
+
176
+ return global_step
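
The loss scaling inside `train_epoch` (`loss = output["loss"] / self.grad_accum_steps`) is what makes gradient accumulation equivalent to a larger batch. A minimal self-contained sketch with hypothetical tensors (not the Aetheris model) shows that accumulating scaled micro-batch gradients reproduces the full-batch mean gradient:

```python
import torch

# Minimal sketch: dividing each micro-batch loss by grad_accum_steps before
# backward() makes the accumulated gradient equal the gradient of the mean
# loss over the whole effective batch.
torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)
data = torch.randn(8, 4)
grad_accum_steps = 4

# Accumulated path: 4 micro-batches of 2 samples each
for chunk in data.split(2):
    loss = (chunk @ w).pow(2).mean() / grad_accum_steps
    loss.backward()  # gradients accumulate into w.grad
accum_grad = w.grad.clone()

# Reference path: one full batch of 8 samples
w.grad = None
(data @ w).pow(2).mean().backward()

assert torch.allclose(accum_grad, w.grad, atol=1e-6)
```

This is also why the logging code multiplies `loss.item()` back by `grad_accum_steps`: the scaled loss is only meaningful for backprop, not for reporting.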
aetheris/utils.py ADDED
@@ -0,0 +1,55 @@
+ import os
+ import torch
+ from typing import Tuple
+
+ def save_checkpoint(model, optimizer, scaler, step, stage, checkpoint_dir, checkpoint_name="checkpoint_current.pth"):
+     os.makedirs(checkpoint_dir, exist_ok=True)
+     path = os.path.join(checkpoint_dir, checkpoint_name)
+     torch.save({
+         'step': step,
+         'stage': stage,
+         'model_state_dict': model.state_dict(),
+         'optimizer_state_dict': optimizer.state_dict(),
+         'scaler_state_dict': scaler.state_dict()
+     }, path)
+     print(f" [Checkpoint] Saved at step {step}")
+
+ def load_latest_checkpoint(model, optimizer, scaler, device, checkpoint_dir, checkpoint_name="checkpoint_current.pth") -> Tuple[int, str]:
+     path = os.path.join(checkpoint_dir, checkpoint_name)
+     if not os.path.exists(path):
+         return 0, "Pre-Training"
+
+     print(f" [Checkpoint] Loading from {path}...")
+     ckpt = torch.load(path, map_location=device)
+     state = ckpt['model_state_dict']
+
+     # Handle vocab size mismatch (base checkpoint may have fewer tokens than model)
+     model_vocab = model.config.vocab_size
+     for key in ("embedding.weight", "lm_head.weight"):
+         if key in state and state[key].shape[0] < model_vocab:
+             old = state[key]
+             pad_size = model_vocab - old.shape[0]
+             mean_vec = old.mean(dim=0)
+             state[key] = torch.cat([old, mean_vec.unsqueeze(0).expand(pad_size, -1)])
+             print(f" [Checkpoint] Padded {key}: {old.shape[0]} → {model_vocab}")
+
+     model.load_state_dict(state, strict=False)
+
+     if optimizer and 'optimizer_state_dict' in ckpt:
+         try:
+             optimizer.load_state_dict(ckpt['optimizer_state_dict'])
+         except (ValueError, KeyError):
+             print(" [Checkpoint] Optimizer state incompatible (vocab resize), using fresh optimizer")
+     if scaler and 'scaler_state_dict' in ckpt:
+         scaler.load_state_dict(ckpt['scaler_state_dict'])
+     return ckpt['step'], ckpt['stage']
+
+ def calculate_model_stats(model):
+     total_params = sum(p.numel() for p in model.parameters())
+     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     return {
+         'total_params': total_params,
+         'trainable_params': trainable_params,
+         'active_params': int(total_params * 0.6),  # Approximation
+         'sparsity_ratio': 0.6  # Approximation
+     }
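
The vocab-mismatch handling in `load_latest_checkpoint` pads new embedding rows with the mean of the existing rows, which keeps the new tokens near the center of the learned embedding distribution rather than at random positions. A small-tensor sketch of that exact operation:

```python
import torch

# Sketch of the mean-vector padding used in load_latest_checkpoint:
# grow a (2, 2) embedding table to a vocab of 4 rows, where each new
# row is the mean of the existing rows.
old = torch.tensor([[1.0, 2.0], [3.0, 4.0]])   # checkpoint vocab = 2
model_vocab = 4
pad_size = model_vocab - old.shape[0]
mean_vec = old.mean(dim=0)                      # tensor([2., 3.])
padded = torch.cat([old, mean_vec.unsqueeze(0).expand(pad_size, -1)])

assert padded.shape == (4, 2)
assert torch.equal(padded[2], torch.tensor([2.0, 3.0]))
assert torch.equal(padded[3], torch.tensor([2.0, 3.0]))
```

Note that because the padded rows all start identical, they only differentiate through subsequent fine-tuning; the `strict=False` load then tolerates any remaining key mismatches.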
config.yaml ADDED
@@ -0,0 +1,17 @@
+ checkpoint_ssm_layers: true
+ d_ff: 3072
+ d_inner: 2048
+ d_model: 1024
+ dtype: float16
+ gradient_checkpointing: true
+ load_balancing_coef: 0.01
+ max_seq_len: 2048
+ n_layer: 24
+ num_experts: 4
+ router_z_loss_coef: 0.001
+ ssm_d_state: 16
+ ssm_expand: 2
+ top_k: 1
+ use_cpu_offload: false
+ use_flash_attention: false
+ vocab_size: 261019
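
The `num_experts: 4` / `top_k: 1` pair in this config is what makes the MoE layers sparse: each token routes to one of four expert FFNs, so only a quarter of the expert parameters are active per token. A back-of-envelope sketch (assuming a standard two-matrix FFN per expert, ignoring biases and router weights, which the config does not specify):

```python
# Sketch: what top_k=1 routing over num_experts=4 implies for the
# per-layer expert parameter count, using d_model and d_ff from config.yaml.
d_model, d_ff = 1024, 3072
num_experts, top_k = 4, 1

per_expert = 2 * d_model * d_ff           # up-projection + down-projection
total_expert_params = num_experts * per_expert
active_expert_params = top_k * per_expert  # only the routed expert runs

assert per_expert == 6_291_456
assert total_expert_params == 25_165_824
assert active_expert_params == total_expert_params // 4
```

This is consistent with the hard-coded `sparsity_ratio: 0.6` approximation in `calculate_model_stats`, since the dense SSM layers and the shared embedding keep the overall active fraction well above 1/4.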
pytorch_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9133520b370ce0ebab902748e74c0e60898f0ffe2c2f0d54f66f9412f40e9921
+ size 2886684406
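
The `pytorch_model.pt` entry above is a Git LFS pointer file, not the weights themselves; the three `key value` lines fully describe the blob. A sketch of parsing one (the pointer text is copied from this commit):

```python
# Sketch: a git-lfs pointer file is plain text with "key value" lines.
pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:9133520b370ce0ebab902748e74c0e60898f0ffe2c2f0d54f66f9412f40e9921
size 2886684406
"""
fields = dict(line.split(" ", 1) for line in pointer.strip().splitlines())
algo, digest = fields["oid"].split(":", 1)

assert fields["version"].endswith("/spec/v1")
assert algo == "sha256" and len(digest) == 64
# ~2.89 GB, consistent with roughly 722M fp32 parameters (4 bytes each)
assert int(fields["size"]) == 2_886_684_406
```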