Debito committed on
Commit
fc54e43
·
verified ·
1 Parent(s): f74f0af

Upload 4 files

Browse files
training/data_loader.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # training/data_loader.py
3
+ # =============================================================================
4
+ import torch
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from typing import List, Dict, Iterator
7
+ import json
8
+ import random
9
+ from core.tokenizer import MambaTokenizer
10
+ from core.preprocess import TextPreprocessor
11
+
12
class MambaDataset(Dataset):
    """Dataset for Mamba language-model training.

    Loads raw text from a ``.json`` or ``.txt`` file, cleans and tokenizes
    each example on access, and returns shifted (input, target) pairs for
    next-token prediction.
    """

    def __init__(self, data_path: str, tokenizer: "MambaTokenizer",
                 preprocessor: "TextPreprocessor", config):
        self.config = config
        self.tokenizer = tokenizer
        self.preprocessor = preprocessor
        self.max_length = config.max_seq_len

        # Load data eagerly; falls back to synthetic examples on failure.
        self.data = self._load_data(data_path)

    def _load_data(self, data_path: str) -> List[str]:
        """Load training texts from *data_path*.

        Supports a JSON list (of strings or ``{'text': ...}`` records), a
        single JSON object with a ``'text'`` key, or a plain-text file that
        is chunked into max_length-sized pieces.
        """
        data = []

        try:
            if data_path.endswith('.json'):
                with open(data_path, 'r', encoding='utf-8') as f:
                    raw_data = json.load(f)
                if isinstance(raw_data, list):
                    # Accept either plain strings or {'text': ...} records.
                    data = [item['text'] if isinstance(item, dict) else str(item)
                            for item in raw_data]
                else:
                    data = [raw_data['text']]

            elif data_path.endswith('.txt'):
                with open(data_path, 'r', encoding='utf-8') as f:
                    content = f.read()
                # Split one large document into model-sized chunks.
                data = self.preprocessor.chunk_text(content, self.max_length)

            print(f"Loaded {len(data)} training examples")

        except Exception as e:
            # Best-effort fallback so the pipeline can still be smoke-tested
            # without a real dataset on disk.
            print(f"Error loading data: {e}")
            data = [f"This is example text number {i}." for i in range(1000)]

        return data

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one (input, target) pair for causal LM training."""
        text = self.data[idx]

        # Preprocess and tokenize on access (keeps memory low).
        clean_text = self.preprocessor.clean_text(text)
        encoded = self.tokenizer.encode(clean_text, max_length=self.max_length)

        input_ids = encoded['input_ids'].squeeze(0)  # [seq_len]

        # Next-token prediction: target is the input shifted left by one.
        # The previous construction appended an EOS token and then truncated
        # it away again ([:-1]), which was dead work and crashed when
        # eos_token_id was None; the direct slices are equivalent.
        return {
            'input_ids': input_ids[:-1],   # [seq_len-1]
            'target_ids': input_ids[1:],   # [seq_len-1]
            'attention_mask': encoded['attention_mask'].squeeze(0)[:-1]
        }
78
+
79
class DomainSpecificDataset(Dataset):
    """Dataset for domain-specific specialist training.

    Pulls the text list for ``domain_<domain_id>`` out of *domain_data* and
    falls back to synthetic examples when the domain has no data, so
    specialist training can always run.
    """

    def __init__(self, domain_data: Dict[str, List[str]], domain_id: int,
                 tokenizer: "MambaTokenizer", preprocessor: "TextPreprocessor", config):
        self.domain_id = domain_id
        self.tokenizer = tokenizer
        self.preprocessor = preprocessor
        self.config = config

        # Select this specialist's slice of the corpus.
        domain_name = f"domain_{domain_id}"
        self.data = domain_data.get(domain_name, [])

        if not self.data:
            # Synthetic placeholder data so training/testing never stalls
            # on a missing domain corpus.
            self.data = [f"Domain {domain_id} specific text example {i}."
                         for i in range(100)]

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one (input, target) pair tagged with this domain's id."""
        text = self.data[idx]

        # Preprocess and tokenize
        clean_text = self.preprocessor.clean_text(text)
        encoded = self.tokenizer.encode(clean_text, max_length=self.config.max_seq_len)

        input_ids = encoded['input_ids'].squeeze(0)

        # Shift-by-one targets, same simplification as MambaDataset: the old
        # append-EOS-then-truncate dance produced exactly input_ids[1:].
        return {
            'input_ids': input_ids[:-1],
            'target_ids': input_ids[1:],
            'attention_mask': encoded['attention_mask'].squeeze(0)[:-1],
            'domain_id': self.domain_id
        }
118
+
119
def create_data_loaders(config, tokenizer: MambaTokenizer,
                        preprocessor: TextPreprocessor) -> Dict[str, DataLoader]:
    """Build the main LM loader plus one loader per domain specialist."""
    # Primary language-modelling dataset; path configurable via config.
    main_dataset = MambaDataset(
        data_path=getattr(config, 'train_data_path', 'train_data.txt'),
        tokenizer=tokenizer,
        preprocessor=preprocessor,
        config=config
    )

    main_loader = DataLoader(
        main_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )

    # Load domain-specific data (placeholder)
    per_domain_texts = {}  # Should load actual domain-specific datasets

    # One loader per specialist; DomainSpecificDataset falls back to
    # synthetic examples when the placeholder dict is empty.
    domain_loaders = {
        specialist_idx: DataLoader(
            DomainSpecificDataset(
                domain_data=per_domain_texts,
                domain_id=specialist_idx,
                tokenizer=tokenizer,
                preprocessor=preprocessor,
                config=config
            ),
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=2
        )
        for specialist_idx in range(config.num_specialists)
    }

    return {
        'main': main_loader,
        'domains': domain_loaders
    }
training/loss.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # training/loss.py
3
+ # =============================================================================
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from typing import Dict, Optional
8
+
9
class MambaLoss(nn.Module):
    """Loss functions for Mamba swarm training.

    Combines the primary cross-entropy language-modelling loss with two
    auxiliary regularisers over specialist routing weights:
      * diversity loss - penalises pairwise cosine similarity between
        specialists (pushes them to specialise);
      * balance loss   - penalises variance in specialist activations
        (discourages one specialist dominating).
    """

    def __init__(self, config, vocab_size: int):
        super().__init__()
        self.config = config
        self.vocab_size = vocab_size

        # Primary loss; -100 marks positions to ignore (padding).
        self.lm_loss = nn.CrossEntropyLoss(ignore_index=-100)

        # Auxiliary loss weights.
        self.diversity_weight = 0.01
        self.specialist_balance_weight = 0.001

    def forward(self, logits: torch.Tensor, targets: torch.Tensor,
                specialist_weights: Optional[Dict] = None) -> Dict[str, torch.Tensor]:
        """
        Compute total loss.

        Args:
            logits: [batch, seq_len, vocab_size]
            targets: [batch, seq_len]
            specialist_weights: Dict of specialist activation weight tensors

        Returns:
            Dict with 'lm_loss', 'diversity_loss', 'balance_loss', 'total_loss'
        """
        losses = {}

        # Primary language modeling loss (flattened over batch and time).
        lm_loss = self.lm_loss(
            logits.view(-1, logits.size(-1)),
            targets.view(-1)
        )
        losses['lm_loss'] = lm_loss

        if specialist_weights is not None:
            losses['diversity_loss'] = self._compute_diversity_loss(specialist_weights)
            losses['balance_loss'] = self._compute_balance_loss(specialist_weights)
        else:
            # Keep the zeros on the logits' device for safe arithmetic.
            zero = torch.tensor(0.0, device=logits.device)
            losses['diversity_loss'] = zero
            losses['balance_loss'] = zero

        losses['total_loss'] = (
            lm_loss
            + self.diversity_weight * losses['diversity_loss']
            + self.specialist_balance_weight * losses['balance_loss']
        )

        return losses

    def _compute_diversity_loss(self, specialist_weights: Dict) -> torch.Tensor:
        """Encourage specialists to be diverse (low pairwise similarity)."""
        if len(specialist_weights) < 2:
            return torch.tensor(0.0)

        weights = torch.stack(list(specialist_weights.values()))

        # Pairwise cosine similarities between specialist weight vectors.
        normalized_weights = F.normalize(weights, dim=-1)
        similarity_matrix = torch.mm(normalized_weights, normalized_weights.t())

        # Average over the actual number of specialist pairs. The previous
        # `.triu(diagonal=1).mean()` averaged over all n*n matrix entries
        # (mostly structural zeros), so the penalty shrank as n grew.
        n = weights.size(0)
        num_pairs = n * (n - 1) // 2
        return similarity_matrix.triu(diagonal=1).sum() / num_pairs

    def _compute_balance_loss(self, specialist_weights: Dict) -> torch.Tensor:
        """Encourage balanced specialist usage (low activation variance)."""
        if not specialist_weights:
            return torch.tensor(0.0)

        activations = torch.stack(list(specialist_weights.values()))

        # unbiased=False: the default (unbiased) estimator divides by n-1
        # and returns NaN when only a single activation value is present,
        # which poisoned the total loss.
        return activations.var(unbiased=False)
training/optimizer.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # training/optimizer.py
3
+ # =============================================================================
4
+ import torch
5
+ import torch.optim as optim
6
+ from torch.optim.lr_scheduler import LambdaLR
7
+ import math
8
+ from typing import Dict, List
9
+
10
class MambaOptimizer:
    """AdamW wrapper with warmup + cosine schedule and gradient clipping.

    Biases, norm layers and embeddings are excluded from weight decay,
    following standard transformer/Mamba training practice.
    """

    def __init__(self, model, config):
        self.config = config
        self.model = model

        # Separate parameters that should and shouldn't have weight decay.
        decay_params = []
        no_decay_params = []

        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            # Don't apply weight decay to biases, norms, or embeddings.
            if 'bias' in name or 'norm' in name or 'embedding' in name:
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        param_groups = [
            {'params': decay_params, 'weight_decay': config.weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0}
        ]

        self.optimizer = optim.AdamW(
            param_groups,
            lr=config.learning_rate,
            betas=(0.9, 0.95),
            eps=1e-8
        )

        self.scheduler = self._create_scheduler()

    def _create_scheduler(self):
        """Linear warmup followed by cosine decay to zero at max_steps."""
        # Guard against warmup_steps == 0 (division by zero below).
        warmup = max(1, self.config.warmup_steps)

        def lr_lambda(step):
            if step < warmup:
                # Linear warmup from 0 to the base learning rate.
                return step / warmup
            # Clamp progress to [0, 1]: without the clamp the cosine keeps
            # oscillating past max_steps and the LR would rise again.
            span = max(1, self.config.max_steps - warmup)
            progress = min(1.0, (step - warmup) / span)
            return 0.5 * (1 + math.cos(math.pi * progress))

        return LambdaLR(self.optimizer, lr_lambda)

    def step(self):
        """Clip gradients, step optimizer and scheduler; return current LR."""
        # Gradient clipping keeps updates bounded.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        self.optimizer.step()
        self.scheduler.step()

        return self.scheduler.get_last_lr()[0]

    def zero_grad(self):
        """Zero gradients on all parameter groups."""
        self.optimizer.zero_grad()
73
+
training/trainer.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # training/trainer.py
3
+ # =============================================================================
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.utils.data import DataLoader
7
+ from typing import Dict, List, Optional
8
+ import time
9
+ import logging
10
+ from pathlib import Path
11
+
12
+ from core.config import MambaConfig
13
+ from routing.tlm_manager import TLMManager
14
+ from routing.aggregator import AttentionAggregator
15
+ from training.optimizer import MambaOptimizer
16
+ from training.loss import MambaLoss
17
+ from training.data_loader import create_data_loaders
18
+ from core.tokenizer import MambaTokenizer
19
+ from core.preprocess import TextPreprocessor
20
+
21
class MambaSwarmTrainer:
    """Multi-phase trainer for the Mamba swarm architecture.

    Training proceeds in four phases:
      1. foundation  - train shared weights on one reference specialist
      2. specialists - train each domain specialist on its domain data
      3. aggregator  - train the aggregator over frozen specialists
      4. end_to_end  - joint fine-tuning of the whole system
    """

    def __init__(self, config: MambaConfig):
        self.config = config
        self.device = config.device

        # Text pipeline components.
        self.tokenizer = MambaTokenizer(config)
        self.preprocessor = TextPreprocessor(config)

        # Specialist pool and output aggregator.
        self.tlm_manager = TLMManager(config)
        self.aggregator = AttentionAggregator(config)
        self.aggregator.to(self.device)

        # Loss function shared across phases.
        self.loss_fn = MambaLoss(config, config.vocab_size)

        # Data loaders: {'main': DataLoader, 'domains': {id: DataLoader}}.
        self.data_loaders = create_data_loaders(config, self.tokenizer, self.preprocessor)

        # Persistent DataLoader iterators keyed by loader identity.
        # The previous code called next(iter(loader)) every step, which
        # restarts the loader (and its workers) each time and only ever
        # samples the first batch of a fresh shuffle.
        self._loader_iters = {}

        # Training state
        self.global_step = 0
        self.phase = "foundation"  # foundation, specialists, aggregator, end_to_end

        self.setup_logging()

    def _next_batch(self, loader):
        """Return the next batch from *loader*, cycling at epoch end."""
        key = id(loader)
        it = self._loader_iters.get(key)
        if it is None:
            it = iter(loader)
            self._loader_iters[key] = it
        try:
            return next(it)
        except StopIteration:
            # Epoch exhausted: restart the iterator and continue.
            it = iter(loader)
            self._loader_iters[key] = it
            return next(it)

    def setup_logging(self):
        """Configure file + console logging for training."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler('training.log'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def train_foundation_phase(self, num_steps: int = 10000):
        """Phase 1: Train shared foundation weights on a reference specialist."""
        self.logger.info("Starting foundation training phase...")
        self.phase = "foundation"

        # Train one reference specialist, then copy its shared layers out.
        reference_specialist = list(self.tlm_manager.specialists.values())[0]
        optimizer = MambaOptimizer(reference_specialist.model, self.config)

        reference_specialist.model.train()

        for step in range(num_steps):
            batch = self._next_batch(self.data_loaders['main'])

            input_ids = batch['input_ids'].to(self.device)
            target_ids = batch['target_ids'].to(self.device)

            # Forward pass (model returns (logits, loss) given targets).
            logits, loss = reference_specialist.model(input_ids, target_ids)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            lr = optimizer.step()

            self.global_step += 1

            if step % 100 == 0:
                self.logger.info(f"Foundation step {step}, loss: {loss.item():.4f}, lr: {lr:.6f}")

        # Propagate the trained shared weights to every specialist.
        self._copy_foundation_weights(reference_specialist)

        self.logger.info("Foundation training phase completed!")

    def _copy_foundation_weights(self, reference_specialist):
        """Copy shared foundation weights to all other specialists.

        Shares the first half of the layer stack plus the embeddings;
        assumes parameter names look like 'layers.<idx>....' — TODO confirm
        against the specialist model definition.
        """
        reference_state = reference_specialist.model.state_dict()

        for specialist in self.tlm_manager.specialists.values():
            if specialist != reference_specialist:
                specialist_state = specialist.model.state_dict()

                for name, param in reference_state.items():
                    if 'layers.' in name:
                        # Extract layer number from the parameter name.
                        layer_num = int(name.split('.')[1])
                        if layer_num < self.config.n_layers // 2:  # Share first half
                            specialist_state[name] = param.clone()
                    elif 'embedding' in name:  # Share embeddings
                        specialist_state[name] = param.clone()

                specialist.model.load_state_dict(specialist_state)

    def train_specialists_phase(self, num_steps: int = 5000):
        """Phase 2: Train domain specialists on their own domain data."""
        self.logger.info("Starting specialist training phase...")
        self.phase = "specialists"

        # One optimizer per specialist.
        specialist_optimizers = {}
        for specialist_id, specialist in self.tlm_manager.specialists.items():
            specialist_optimizers[specialist_id] = MambaOptimizer(
                specialist.model, self.config
            )
            specialist.model.train()

        # Sequential round-robin over specialists (stand-in for true
        # parallel training).
        for step in range(num_steps):
            total_loss = 0.0

            for specialist_id in range(min(10, self.config.num_specialists)):  # Limit for demo
                if specialist_id in self.data_loaders['domains']:
                    try:
                        batch = self._next_batch(self.data_loaders['domains'][specialist_id])

                        input_ids = batch['input_ids'].to(self.device)
                        target_ids = batch['target_ids'].to(self.device)

                        specialist = self.tlm_manager.specialists[specialist_id]
                        optimizer = specialist_optimizers[specialist_id]

                        logits, loss = specialist.model(input_ids, target_ids)

                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                        total_loss += loss.item()

                    except Exception as e:
                        # Best-effort: skip a failing specialist this step.
                        self.logger.warning(f"Error training specialist {specialist_id}: {e}")
                        continue

            self.global_step += 1

            if step % 100 == 0:
                avg_loss = total_loss / min(10, self.config.num_specialists)
                self.logger.info(f"Specialists step {step}, avg loss: {avg_loss:.4f}")

        self.logger.info("Specialist training phase completed!")

    def train_aggregator_phase(self, num_steps: int = 3000):
        """Phase 3: Train the aggregator over frozen specialist outputs."""
        self.logger.info("Starting aggregator training phase...")
        self.phase = "aggregator"

        # Freeze specialist models so only the aggregator learns.
        for specialist in self.tlm_manager.specialists.values():
            specialist.model.eval()
            for param in specialist.model.parameters():
                param.requires_grad = False

        aggregator_optimizer = MambaOptimizer(self.aggregator, self.config)
        self.aggregator.train()

        for step in range(num_steps):
            try:
                batch = self._next_batch(self.data_loaders['main'])

                # Simulated routing/specialist outputs (demo simplification).
                specialist_outputs = self._simulate_specialist_outputs(batch)

                target_ids = batch['target_ids'].to(self.device)

                logits = self.aggregator(specialist_outputs)

                loss_dict = self.loss_fn(logits, target_ids)
                loss = loss_dict['total_loss']

                aggregator_optimizer.zero_grad()
                loss.backward()
                aggregator_optimizer.step()

                self.global_step += 1

                if step % 100 == 0:
                    self.logger.info(f"Aggregator step {step}, loss: {loss.item():.4f}")

            except Exception as e:
                self.logger.warning(f"Error in aggregator training step {step}: {e}")
                continue

        self.logger.info("Aggregator training phase completed!")

    def _simulate_specialist_outputs(self, batch) -> Dict[int, List[Dict]]:
        """Simulate specialist outputs for aggregator training.

        Simplified simulation: real training would route the text through
        the router and the selected specialists.
        """
        input_ids = batch['input_ids'].to(self.device)

        # Simulate 3 chunks, each handled by 2-3 specialists.
        specialist_outputs = {}

        for chunk_id in range(3):
            chunk_results = []

            for i in range(2 + chunk_id % 2):
                specialist_id = (chunk_id * 3 + i) % self.config.num_specialists

                if specialist_id in self.tlm_manager.specialists:
                    specialist = self.tlm_manager.specialists[specialist_id]

                    # Encode without tracking gradients through specialists.
                    with torch.no_grad():
                        encoding = specialist.encode(input_ids[:1])  # Single sample

                    chunk_results.append({
                        'chunk_id': chunk_id,
                        'specialist_id': specialist_id,
                        # Random confidence in [0.8, 1.0) for the simulation.
                        'confidence': 0.8 + 0.2 * torch.rand(1).item(),
                        'encoding': encoding[0],  # Remove batch dim
                        'domain': f'domain_{specialist_id}'
                    })

            specialist_outputs[chunk_id] = chunk_results

        return specialist_outputs

    def train_end_to_end_phase(self, num_steps: int = 2000):
        """Phase 4: End-to-end fine-tuning of the entire system."""
        self.logger.info("Starting end-to-end training phase...")
        self.phase = "end_to_end"

        # Unfreeze all parameters.
        for specialist in self.tlm_manager.specialists.values():
            specialist.model.train()
            for param in specialist.model.parameters():
                param.requires_grad = True

        self.aggregator.train()

        # Collect every trainable parameter in the system.
        all_params = []
        for specialist in self.tlm_manager.specialists.values():
            all_params.extend(specialist.model.parameters())
        all_params.extend(self.aggregator.parameters())

        # Reduced learning rate for fine-tuning. NOTE: the previous code
        # aliased the shared config object and mutated its learning_rate in
        # place, permanently changing it for the rest of the run; use a
        # local value instead.
        fine_tune_lr = self.config.learning_rate * 0.1

        system_optimizer = torch.optim.AdamW(
            all_params,
            lr=fine_tune_lr,
            weight_decay=self.config.weight_decay
        )

        for step in range(num_steps):
            try:
                batch = self._next_batch(self.data_loaders['main'])

                # Full system forward pass (simplified).
                specialist_outputs = self._simulate_specialist_outputs(batch)
                logits = self.aggregator(specialist_outputs)

                target_ids = batch['target_ids'].to(self.device)
                loss_dict = self.loss_fn(logits, target_ids)
                loss = loss_dict['total_loss']

                system_optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
                system_optimizer.step()

                self.global_step += 1

                if step % 100 == 0:
                    self.logger.info(f"End-to-end step {step}, loss: {loss.item():.4f}")

            except Exception as e:
                self.logger.warning(f"Error in end-to-end training step {step}: {e}")
                continue

        self.logger.info("End-to-end training phase completed!")

    def full_training_pipeline(self):
        """Run the complete 4-phase training pipeline."""
        self.logger.info("Starting full Mamba swarm training pipeline...")

        start_time = time.time()

        try:
            self.train_foundation_phase(num_steps=1000)   # Reduced for demo
            self.train_specialists_phase(num_steps=500)   # Reduced for demo
            self.train_aggregator_phase(num_steps=300)    # Reduced for demo
            self.train_end_to_end_phase(num_steps=200)    # Reduced for demo

            total_time = time.time() - start_time
            self.logger.info(f"Training completed in {total_time:.2f} seconds!")

        except Exception as e:
            self.logger.error(f"Training failed: {e}")
            raise

    def save_checkpoint(self, checkpoint_path: str):
        """Save training state, aggregator, and all specialist weights."""
        checkpoint = {
            'global_step': self.global_step,
            'phase': self.phase,
            'config': self.config.__dict__,
            'aggregator_state': self.aggregator.state_dict(),
            'specialist_states': {}
        }

        for specialist_id, specialist in self.tlm_manager.specialists.items():
            checkpoint['specialist_states'][specialist_id] = specialist.model.state_dict()

        torch.save(checkpoint, checkpoint_path)
        self.logger.info(f"Checkpoint saved to {checkpoint_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """Load training state, aggregator, and specialist weights."""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.global_step = checkpoint['global_step']
        self.phase = checkpoint['phase']

        self.aggregator.load_state_dict(checkpoint['aggregator_state'])

        # Restore only specialists that exist in the current configuration.
        for specialist_id, state_dict in checkpoint['specialist_states'].items():
            if specialist_id in self.tlm_manager.specialists:
                self.tlm_manager.specialists[specialist_id].model.load_state_dict(state_dict)

        self.logger.info(f"Checkpoint loaded from {checkpoint_path}")

    def evaluate(self, eval_steps: int = 100) -> Dict[str, float]:
        """Evaluate the trained model; returns loss/perplexity metrics."""
        self.logger.info("Starting evaluation...")

        # Eval mode everywhere.
        for specialist in self.tlm_manager.specialists.values():
            specialist.model.eval()
        self.aggregator.eval()

        total_loss = 0.0
        num_steps = 0

        with torch.no_grad():
            for step in range(eval_steps):
                try:
                    batch = self._next_batch(self.data_loaders['main'])

                    specialist_outputs = self._simulate_specialist_outputs(batch)
                    logits = self.aggregator(specialist_outputs)

                    target_ids = batch['target_ids'].to(self.device)
                    loss_dict = self.loss_fn(logits, target_ids)

                    total_loss += loss_dict['total_loss'].item()
                    num_steps += 1

                except Exception as e:
                    self.logger.warning(f"Error in evaluation step {step}: {e}")
                    continue

        # Guard against all steps failing (num_steps == 0).
        avg_loss = total_loss / max(num_steps, 1)
        perplexity = torch.exp(torch.tensor(avg_loss)).item()

        results = {
            'eval_loss': avg_loss,
            'perplexity': perplexity,
            'num_steps': num_steps
        }

        self.logger.info(f"Evaluation results: {results}")
        return results