Mango-Metrics-NLM committed on
Commit
c8b77b5
·
0 Parent(s):

feat: Phi-3.5-MoE multi-agent model repository

Browse files

- Core Phi-3.5-MoE integration and multi-agent training system
- Multi-agent dataset loading and tokenization
- CPU-optimized LoRA training with Hugging Face
- Configuration and deployment scripts
- Git LFS configuration for model files

Files changed (35) hide show
  1. .gitattributes +8 -0
  2. .gitignore +37 -0
  3. MULTI_AGENT_TRAINING_README.md +3 -0
  4. deploy_multi_agent_training.py +384 -0
  5. multi_agent_training_config.yaml +181 -0
  6. src/models/multi_agent_moe_integration.py +393 -0
  7. src/models/phi35_moe_integration.py +397 -0
  8. src/multi_agent_datasets/__init__.py +1 -0
  9. src/multi_agent_datasets/multi_agent_loader.py +515 -0
  10. src/multi_agent_tokenization/__init__.py +1 -0
  11. src/multi_agent_tokenization/agent_tokenizer.py +402 -0
  12. src/multi_agent_training/__init__.py +1 -0
  13. src/multi_agent_training/core_framework.py +232 -0
  14. src/multi_agent_training/data_processor.py +582 -0
  15. src/multi_agent_training/distillation_loss.py +232 -0
  16. src/multi_agent_training/gradient_descent/__init__.py +49 -0
  17. src/multi_agent_training/gradient_descent/backpropagation.py +337 -0
  18. src/multi_agent_training/gradient_descent/loss_functions.py +379 -0
  19. src/multi_agent_training/gradient_descent/model_wrapper.py +342 -0
  20. src/multi_agent_training/gradient_descent/monitoring.py +548 -0
  21. src/multi_agent_training/gradient_descent/optimizers.py +319 -0
  22. src/multi_agent_training/gradient_descent/schedulers.py +338 -0
  23. src/multi_agent_training/gradient_descent/training_loop.py +570 -0
  24. src/multi_agent_training/lora_trainer.py +699 -0
  25. src/multi_agent_training/modular_distillation_loss.py +263 -0
  26. src/multi_agent_training/modular_trainer.py +229 -0
  27. src/multi_agent_training/multi_agent_trainer.py +547 -0
  28. src/multi_agent_training/specialized/curiosity_module.py +516 -0
  29. src/multi_agent_training/specialized/empathy_module.py +502 -0
  30. src/multi_agent_training/specialized/ethics_module.py +293 -0
  31. src/multi_agent_training/specialized/generators/ethics_generator.py +1244 -0
  32. src/multi_agent_training/specialized/generators/reasoning_generator.py +581 -0
  33. src/multi_agent_training/specialized/memory_module.py +317 -0
  34. src/multi_agent_training/specialized/reasoning_module.py +404 -0
  35. src/multi_agent_training/synthetic_data_generator.py +162 -0
.gitattributes ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ *.pth filter=lfs diff=lfs merge=lfs -text
2
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
3
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
4
+ *.bin filter=lfs diff=lfs merge=lfs -text
5
+ *.pt filter=lfs diff=lfs merge=lfs -text
6
+ *.zip filter=lfs diff=lfs merge=lfs -text
7
+ *.json filter=lfs diff=lfs merge=lfs -text
8
+ *.md filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Large files that shouldn't be in model repository
2
+ imported_aws_backup/
3
+ artifacts/
4
+ models/gradient_descent_ensemble/
5
+ models/improved_ensemble/
6
+ models/expanded_agents_1gb/
7
+
8
+ # Python cache
9
+ __pycache__/
10
+ *.pyc
11
+ *.pyo
12
+ *.pyd
13
+ .Python
14
+ *.so
15
+
16
+ # Virtual environments
17
+ venv/
18
+ env/
19
+ ENV/
20
+
21
+ # IDE files
22
+ .vscode/
23
+ .idea/
24
+ *.swp
25
+ *.swo
26
+
27
+ # OS files
28
+ .DS_Store
29
+ Thumbs.db
30
+
31
+ # Logs
32
+ *.log
33
+ logs/
34
+
35
+ # Temporary files
36
+ *.tmp
37
+ *.temp
MULTI_AGENT_TRAINING_README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:291d3f54b6a84723e0fe0b2bfa96a8d135d67546c2e2c278d5d3ea8f8a308ed7
3
+ size 13132
deploy_multi_agent_training.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Multi-Agent Training Deployment Script
4
+
5
+ This script provides comprehensive deployment capabilities for the multi-agent
6
+ training system, including Docker container management, environment setup,
7
+ and training execution.
8
+ """
9
+
10
+ import os
11
+ import sys
12
+ import json
13
+ import yaml
14
+ import argparse
15
+ import subprocess
16
+ import logging
17
+ from pathlib import Path
18
+ from typing import Dict, List, Optional, Any
19
+ from dataclasses import dataclass
20
+
21
+ # Add src to path for imports
22
+ sys.path.append('src')
23
+
24
+ from training.multi_agent_trainer import MultiAgentTrainingConfig, MultiAgentTrainingPipeline
25
+ from datasets.multi_agent_loader import MultiAgentDatasetConfig
26
+
27
@dataclass
class DeploymentConfig:
    """Configuration for deployment.

    Groups every knob the deployment pipeline needs: host paths, Hugging
    Face repository IDs, Docker image name, and runtime flags.
    """
    project_root: str                       # root of the project checkout (set to os.getcwd() by callers)
    dataset_path: str                       # host path to the multi-agent dataset (mounted read-only into the container)
    model_repo_id: str                      # Hugging Face model repository ID passed to the training script
    dataset_repo_id: Optional[str] = None   # optional HF dataset repo to push to
    agents_file: Optional[str] = None       # optional agents YAML file (validated for existence when set)
    config_file: Optional[str] = None       # YAML config file the settings came from, if any
    docker_image_name: str = "phi35moe-cpu:latest"
    output_dir: str = "./outputs"           # host directory mapped to /app/outputs in the container
    logs_dir: str = "./logs"                # host directory for deployment.log and the JSON report
    max_steps: int = 50                     # forwarded as --max_steps to the training script
    balance_agents: bool = True             # adds --balance_agents to the training command
    push_to_hub: bool = True                # adds --push_to_hub to the training command
    dry_run: bool = False                   # when True, steps are logged but not executed
43
+
44
class MultiAgentTrainingDeployment:
    """
    Comprehensive deployment manager for multi-agent training.

    Orchestrates the full pipeline: environment validation, Docker image
    build, containerized training execution, and a JSON deployment report.
    When ``config.dry_run`` is set, each step logs what it *would* do
    without executing anything.
    """

    def __init__(self, config: "DeploymentConfig"):
        self.config = config
        self.setup_logging()

    def setup_logging(self):
        """Configure console + file logging under ``config.logs_dir``.

        The logs directory is created *before* ``logging.FileHandler`` is
        instantiated: FileHandler opens its file eagerly, so the original
        ordering (makedirs after basicConfig) crashed on a fresh checkout
        where the directory did not yet exist.
        """
        os.makedirs(self.config.logs_dir, exist_ok=True)

        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.StreamHandler(),
                logging.FileHandler(os.path.join(self.config.logs_dir, 'deployment.log'))
            ]
        )
        self.logger = logging.getLogger(__name__)

    def validate_environment(self) -> bool:
        """Check required env vars, Docker, and input paths.

        Returns:
            True when the environment is usable; False (with an error
            logged) on the first problem found.
        """
        self.logger.info("Validating deployment environment")

        # Secrets must come from the environment, never from config files.
        required_env_vars = ["HF_TOKEN"]
        missing_vars = [var for var in required_env_vars if not os.getenv(var)]

        if missing_vars:
            self.logger.error(f"Missing required environment variables: {missing_vars}")
            return False

        # Distinguish "installed but unhealthy" from "not installed at all".
        try:
            result = subprocess.run(["docker", "--version"], capture_output=True, text=True)
            if result.returncode != 0:
                self.logger.error("Docker not available")
                return False
        except FileNotFoundError:
            self.logger.error("Docker not installed")
            return False

        if not os.path.exists(self.config.dataset_path):
            self.logger.error(f"Dataset path not found: {self.config.dataset_path}")
            return False

        if self.config.agents_file and not os.path.exists(self.config.agents_file):
            self.logger.error(f"Agents file not found: {self.config.agents_file}")
            return False

        self.logger.info("Environment validation passed")
        return True

    def build_docker_image(self) -> bool:
        """Build the CPU training image; return False on any failure."""
        self.logger.info("Building Docker image")

        dockerfile_path = "docker/multi_agent_training/Dockerfile.cpu"
        if not os.path.exists(dockerfile_path):
            self.logger.error(f"Dockerfile not found: {dockerfile_path}")
            return False

        try:
            cmd = [
                "docker", "build",
                "-f", dockerfile_path,
                "-t", self.config.docker_image_name,
                "docker/multi_agent_training/"
            ]

            self.logger.info(f"Running command: {' '.join(cmd)}")

            if not self.config.dry_run:
                subprocess.run(cmd, check=True, capture_output=True, text=True)
                self.logger.info("Docker image built successfully")
            else:
                self.logger.info("Dry run: Would build Docker image")

            return True

        except subprocess.CalledProcessError as e:
            self.logger.error(f"Docker build failed: {e}")
            self.logger.error(f"Error output: {e.stderr}")
            return False

    def create_training_command(self) -> List[str]:
        """Assemble the argv for the in-container training script."""
        cmd = [
            "python", "/app/train_lora_cpu_multiagent.py",
            "--dataset_path", self.config.dataset_path,
            "--hub_repo_id", self.config.model_repo_id,
            "--output_dir", "/app/outputs",
            "--max_steps", str(self.config.max_steps),
            "--logging_steps", "5",
            "--save_steps", "50",
            "--eval_steps", "25"
        ]

        # Boolean flags are appended only when enabled.
        if self.config.balance_agents:
            cmd.append("--balance_agents")
        if self.config.push_to_hub:
            cmd.append("--push_to_hub")

        # Optional value arguments.
        if self.config.agents_file:
            cmd.extend(["--agents_file", self.config.agents_file])
        if self.config.dataset_repo_id:
            cmd.extend(["--push_dataset_repo", self.config.dataset_repo_id])

        return cmd

    def run_training(self) -> bool:
        """Launch training inside the Docker container and wait for it."""
        self.logger.info("Starting training in Docker container")

        training_cmd = self.create_training_command()

        # Ensure host-side mount points exist before `docker run`; Docker
        # would otherwise create them owned by root.
        os.makedirs(self.config.output_dir, exist_ok=True)
        os.makedirs(self.config.logs_dir, exist_ok=True)

        docker_cmd = [
            "docker", "run", "--rm", "-it",
            "-e", f"HF_TOKEN={os.getenv('HF_TOKEN')}",
            # Dataset is mounted read-only at the same path the training
            # script receives via --dataset_path.
            "-v", f"{os.path.abspath(self.config.dataset_path)}:{self.config.dataset_path}:ro",
            "-v", f"{os.path.abspath(self.config.output_dir)}:/app/outputs",
            "-v", f"{os.path.abspath(self.config.logs_dir)}:/app/logs",
            self.config.docker_image_name,
            "bash", "-lc"
        ]

        # The whole training command is handed to `bash -lc` as one string.
        full_cmd = " ".join(training_cmd)
        docker_cmd.append(full_cmd)

        self.logger.info(f"Running command: {' '.join(docker_cmd[:-1])} '{full_cmd}'")

        try:
            if not self.config.dry_run:
                subprocess.run(docker_cmd, check=True)
                self.logger.info("Training completed successfully")
            else:
                self.logger.info("Dry run: Would execute training")

            return True

        except subprocess.CalledProcessError as e:
            self.logger.error(f"Training failed: {e}")
            return False

    def generate_deployment_report(self) -> Dict[str, Any]:
        """Summarize the configuration and environment state as a dict."""
        report = {
            "deployment_config": {
                "project_root": self.config.project_root,
                "dataset_path": self.config.dataset_path,
                "model_repo_id": self.config.model_repo_id,
                "dataset_repo_id": self.config.dataset_repo_id,
                "docker_image_name": self.config.docker_image_name,
                "max_steps": self.config.max_steps,
                "balance_agents": self.config.balance_agents,
                "push_to_hub": self.config.push_to_hub
            },
            "environment": {
                "hf_token_set": bool(os.getenv("HF_TOKEN")),
                "docker_available": self._check_docker_available(),
                "dataset_exists": os.path.exists(self.config.dataset_path),
                # None (not False) when no agents file was configured.
                "agents_file_exists": os.path.exists(self.config.agents_file) if self.config.agents_file else None
            },
            "deployment_status": "completed" if not self.config.dry_run else "dry_run"
        }

        return report

    def _check_docker_available(self) -> bool:
        """Return True when the `docker` CLI responds to --version."""
        try:
            result = subprocess.run(["docker", "--version"], capture_output=True, text=True)
            return result.returncode == 0
        except FileNotFoundError:
            return False

    def save_deployment_report(self, report: Dict[str, Any]):
        """Write the report as pretty-printed JSON into the logs directory."""
        report_file = os.path.join(self.config.logs_dir, "deployment_report.json")
        with open(report_file, 'w') as f:
            json.dump(report, f, indent=2)

        self.logger.info(f"Deployment report saved to {report_file}")

    def deploy(self) -> bool:
        """Run validate -> build -> train -> report; return True on success."""
        self.logger.info("Starting multi-agent training deployment")

        try:
            if not self.validate_environment():
                return False

            if not self.build_docker_image():
                return False

            if not self.run_training():
                return False

            report = self.generate_deployment_report()
            self.save_deployment_report(report)

            self.logger.info("Deployment completed successfully")
            return True

        except Exception as e:
            self.logger.error(f"Deployment failed: {e}")
            return False
270
+
271
def load_config_from_file(config_file: str) -> Dict[str, Any]:
    """Parse *config_file* as YAML and return the resulting mapping."""
    with open(config_file, 'r') as stream:
        parsed = yaml.safe_load(stream)
    return parsed
275
+
276
def create_deployment_config_from_yaml(config_data: Dict[str, Any],
                                       dataset_path: str,
                                       model_repo_id: str,
                                       dry_run: bool = False) -> DeploymentConfig:
    """Create a DeploymentConfig from parsed YAML data.

    Args:
        config_data: Parsed YAML mapping (see multi_agent_training_config.yaml).
        dataset_path: Host path to the multi-agent dataset.
        model_repo_id: Hugging Face model repository ID.
        dry_run: When True, deployment steps are logged but not executed.
            New parameter with a False default, so existing callers keep the
            previous behavior.

    Returns:
        A DeploymentConfig populated from the YAML sections, with defaults
        applied for anything the file omits.
    """
    training_config = config_data.get("training", {})
    multi_agent_config = config_data.get("multi_agent", {})
    huggingface_config = config_data.get("huggingface", {})
    # Bug fix: in the shipped YAML, `logging` is a *top-level* section, not
    # nested under `training`, so the old `training.logging` lookup always
    # fell back to "./logs". Check top level first, keep the old nested
    # lookup as a fallback for configs written against the old code.
    logging_config = config_data.get("logging") or training_config.get("logging", {})

    return DeploymentConfig(
        project_root=os.getcwd(),
        dataset_path=dataset_path,
        model_repo_id=model_repo_id,
        dataset_repo_id=huggingface_config.get("push_dataset_repo"),
        agents_file=config_data.get("dataset", {}).get("agents_file"),
        config_file=None,  # Will be set if provided
        docker_image_name="phi35moe-cpu:latest",
        output_dir=training_config.get("output_dir", "./outputs"),
        logs_dir=logging_config.get("logging_dir", "./logs"),
        max_steps=training_config.get("max_steps", 50),
        balance_agents=multi_agent_config.get("balance_agents", True),
        push_to_hub=huggingface_config.get("push_to_hub", True),
        dry_run=dry_run
    )
298
+
299
def main():
    """CLI entry point: parse arguments, build the config, run deployment.

    Exits with status 1 on any failure; prints a short status summary on
    success.
    """
    parser = argparse.ArgumentParser(description="Deploy Multi-Agent Training System")

    # Required arguments
    parser.add_argument("--dataset_path", required=True,
                        help="Path to multi-agent dataset")
    parser.add_argument("--model_repo_id", required=True,
                        help="Hugging Face model repository ID")

    # Optional arguments
    parser.add_argument("--dataset_repo_id", default="",
                        help="Optional dataset repository ID")
    parser.add_argument("--agents_file", default="",
                        help="Optional agents YAML file")
    parser.add_argument("--config_file", default="",
                        help="Optional configuration YAML file")
    parser.add_argument("--docker_image_name", default="phi35moe-cpu:latest",
                        help="Docker image name")
    parser.add_argument("--output_dir", default="./outputs",
                        help="Output directory")
    parser.add_argument("--logs_dir", default="./logs",
                        help="Logs directory")
    parser.add_argument("--max_steps", type=int, default=50,
                        help="Maximum training steps")
    parser.add_argument("--balance_agents", action="store_true",
                        help="Balance dataset across agents")
    parser.add_argument("--push_to_hub", action="store_true",
                        help="Push model to Hugging Face Hub")
    parser.add_argument("--dry_run", action="store_true",
                        help="Perform dry run without actual execution")
    parser.add_argument("--log_level", default="INFO",
                        help="Logging level")

    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(level=getattr(logging, args.log_level.upper()))

    try:
        # Load configuration from file if provided
        if args.config_file and os.path.exists(args.config_file):
            config_data = load_config_from_file(args.config_file)
            deployment_config = create_deployment_config_from_yaml(
                config_data, args.dataset_path, args.model_repo_id
            )
            # Bug fix: runtime-only flags are CLI-controlled and were
            # dropped on this path, so --dry_run was silently ignored
            # whenever a config file was supplied.
            deployment_config.config_file = args.config_file
            deployment_config.dry_run = args.dry_run
        else:
            # Create configuration from command line arguments
            deployment_config = DeploymentConfig(
                project_root=os.getcwd(),
                dataset_path=args.dataset_path,
                model_repo_id=args.model_repo_id,
                dataset_repo_id=args.dataset_repo_id if args.dataset_repo_id else None,
                agents_file=args.agents_file if args.agents_file else None,
                config_file=args.config_file if args.config_file else None,
                docker_image_name=args.docker_image_name,
                output_dir=args.output_dir,
                logs_dir=args.logs_dir,
                max_steps=args.max_steps,
                balance_agents=args.balance_agents,
                push_to_hub=args.push_to_hub,
                dry_run=args.dry_run
            )

        # Create deployment manager and execute the pipeline.
        deployment = MultiAgentTrainingDeployment(deployment_config)
        success = deployment.deploy()

        if success:
            print("✅ Deployment completed successfully!")
            print(f"📁 Outputs: {deployment_config.output_dir}")
            print(f"📊 Logs: {deployment_config.logs_dir}")
            if deployment_config.push_to_hub:
                print(f"🚀 Model: https://huggingface.co/{deployment_config.model_repo_id}")
        else:
            print("❌ Deployment failed!")
            sys.exit(1)

    except Exception as e:
        print(f"❌ Deployment error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
multi_agent_training_config.yaml ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Multi-Agent Training Configuration
2
+ # Comprehensive configuration for multi-agent LoRA training
3
+
4
+ # Model Configuration
5
+ model:
6
+ base_model: "microsoft/Phi-3.5-MoE-instruct"
7
+ model_cache_dir: "/app/cache"
8
+ trust_remote_code: true
9
+ max_seq_length: 2048
10
+
11
+ # Training Configuration
12
+ training:
13
+ output_dir: "/app/outputs"
14
+ max_steps: 50
15
+ num_train_epochs: 1
16
+ per_device_train_batch_size: 1
17
+ per_device_eval_batch_size: 1
18
+ gradient_accumulation_steps: 8
19
+ learning_rate: 2e-5
20
+ lr_scheduler_type: "cosine"
21
+ warmup_steps: 0
22
+
23
+ # CPU Optimization
24
+ use_cpu: true
25
+ bf16: false
26
+ fp16: false
27
+ gradient_checkpointing: true
28
+ dataloader_num_workers: 0
29
+ remove_unused_columns: false
30
+
31
+ # LoRA Configuration
32
+ lora:
33
+ r: 8
34
+ alpha: 16
35
+ dropout: 0.05
36
+ target_modules: "all-linear"
37
+ bias: "none"
38
+
39
+ # Multi-Agent Configuration
40
+ multi_agent:
41
+ agent_prefix: "<|agent:"
42
+ agent_suffix: "|>"
43
+ balance_agents: true
44
+ balance_cap: null # null means no cap
45
+
46
+ # Agent-Expert Mapping
47
+ agent_expert_mapping:
48
+ SWE: "code"
49
+ SQE: "code"
50
+ DevOps: "code"
51
+ Architect: "reasoning"
52
+ Security: "reasoning"
53
+ Product_Manager: "general"
54
+ VP_Product: "general"
55
+ Philosopher: "reasoning"
56
+ EvalOps: "general"
57
+ CostOptimizer: "general"
58
+ Discovery: "general"
59
+
60
+ # Dataset Configuration
61
+ dataset:
62
+ dataset_path: "/data/datasets/phi_moe_chat"
63
+ agents_file: "/data/datasets/phi_moe_chat/agents.yaml"
64
+ validation_split: 0.1
65
+ seed: 42
66
+
67
+ # Dataset patterns supported
68
+ patterns:
69
+ - "single_folder" # Single folder with agent field in rows
70
+ - "subfolders" # Per-agent subfolders
71
+
72
+ # Logging and Monitoring
73
+ logging:
74
+ logging_steps: 5
75
+ save_steps: 50
76
+ eval_steps: 25
77
+ save_total_limit: 1
78
+ logging_dir: "/app/logs"
79
+ report_to: "none" # Options: wandb, tensorboard, none
80
+
81
+ # Log levels
82
+ log_level: "INFO"
83
+
84
+ # Hugging Face Hub Configuration
85
+ huggingface:
86
+ hub_repo_id: null # Set to your repo ID
87
+ push_to_hub: false
88
+ push_dataset_repo: null # Optional dataset repo
89
+ hub_token: null # Set via environment variable HF_TOKEN
90
+
91
+ # Performance Configuration
92
+ performance:
93
+ # Memory optimization
94
+ gradient_checkpointing: true
95
+ use_cache: false
96
+
97
+ # CPU optimization
98
+ attn_implementation: "eager"
99
+ torch_dtype: "float32"
100
+
101
+ # Training optimization
102
+ optim: "adamw_torch"
103
+ weight_decay: 0.01
104
+ max_grad_norm: 1.0
105
+
106
+ # Validation Configuration
107
+ validation:
108
+ validate_dataset: true
109
+ dry_run: false
110
+
111
+ # Dataset validation
112
+ check_agent_field: true
113
+ check_text_content: true
114
+ validate_jsonl_format: true
115
+
116
+ # Deployment Configuration
117
+ deployment:
118
+ # Docker configuration
119
+ docker:
120
+ base_image: "python:3.12-slim-bookworm"
121
+ working_dir: "/app"
122
+ cache_dir: "/app/cache"
123
+ outputs_dir: "/app/outputs"
124
+ logs_dir: "/app/logs"
125
+
126
+ # Model serving
127
+ serving:
128
+ enable_gradio: true
129
+ gradio_port: 7860
130
+ max_concurrent_requests: 10
131
+
132
+ # Environment Configuration
133
+ environment:
134
+ # Python environment
135
+ python_version: "3.12"
136
+ pip_index_url: "https://download.pytorch.org/whl/cpu"
137
+
138
+ # System requirements
139
+ system_deps:
140
+ - "git"
141
+ - "git-lfs"
142
+ - "build-essential"
143
+ - "curl"
144
+ - "wget"
145
+ - "ca-certificates"
146
+
147
+ # Environment variables
148
+ env_vars:
149
+ PYTHONDONTWRITEBYTECODE: "1"
150
+ PYTHONUNBUFFERED: "1"
151
+ HF_HOME: "/app/cache"
152
+ TRANSFORMERS_CACHE: "/app/cache/transformers"
153
+ HF_DATASETS_CACHE: "/app/cache/datasets"
154
+
155
+ # Development Configuration
156
+ development:
157
+ # Testing
158
+ run_tests: true
159
+ test_coverage: true
160
+
161
+ # Code quality
162
+ lint_code: true
163
+ format_code: true
164
+
165
+ # Documentation
166
+ generate_docs: true
167
+ update_readme: true
168
+
169
+ # Security Configuration
170
+ security:
171
+ # Token management
172
+ secure_token_storage: true
173
+ token_rotation: false
174
+
175
+ # Model security
176
+ trust_remote_code: true
177
+ verify_checksums: true
178
+
179
+ # Data security
180
+ encrypt_sensitive_data: false
181
+ secure_data_transmission: true
src/models/multi_agent_moe_integration.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Multi-Agent MoE Integration
4
+
5
+ This module provides integration between the multi-agent training system and the existing
6
+ MoE framework, allowing for seamless combination of agent-specific conditioning and
7
+ expert specialization.
8
+ """
9
+
10
+ import os
11
+ import json
12
+ import logging
13
+ from typing import Dict, List, Optional, Tuple, Any, Union
14
+ from dataclasses import dataclass
15
+ from pathlib import Path
16
+
17
+ import torch
18
+ from transformers import AutoTokenizer, AutoModelForCausalLM
19
+
20
+ from .phi35_moe_integration import EnhancedMoEFramework, Phi35MoEConfig, Phi35MoEExpert
21
+ from ..multi_agent_tokenization.agent_tokenizer import AgentTokenManager, AgentTokenConfig
22
+ from ..multi_agent_training.multi_agent_trainer import MultiAgentTrainingConfig
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
@dataclass
class MultiAgentMoEConfig:
    """Configuration for multi-agent MoE integration."""
    # Base MoE configuration, forwarded to EnhancedMoEFramework
    moe_config: Phi35MoEConfig

    # Multi-agent configuration: agent tokens are rendered as
    # f"{agent_prefix}{agent}{agent_suffix}", e.g. "<|agent:SWE|>"
    agent_prefix: str = "<|agent:"
    agent_suffix: str = "|>"
    agents_file: Optional[str] = None  # optional YAML file listing known agents

    # Integration settings
    enable_agent_conditioning: bool = True  # inject agent tokens into prompts
    enable_expert_routing: bool = True      # route requests to specialized experts
    hybrid_mode: bool = True  # Use both agent tokens and expert routing

    # Model paths
    base_model_path: str = "microsoft/Phi-3.5-MoE-instruct"
    lora_adapter_path: Optional[str] = None  # optional fine-tuned LoRA adapter

    # Agent-Expert mapping (agent name -> expert specialization); when None,
    # a default mapping is derived from the agent names at initialization
    agent_expert_mapping: Optional[Dict[str, str]] = None
48
+
49
class MultiAgentMoEExpert(Phi35MoEExpert):
    """
    Enhanced MoE expert with multi-agent support.

    Wraps a Phi35MoEExpert with an optional AgentTokenManager so prompts
    can be conditioned on a named agent before generation.
    """

    def __init__(self, expert_id: str, specialization: str, config: Phi35MoEConfig,
                 agent_manager: Optional[AgentTokenManager] = None):
        super().__init__(expert_id, specialization, config)
        self.agent_manager = agent_manager
        # Agents this expert accepts for prompt conditioning.
        self.supported_agents: List[str] = []

    def add_agent_support(self, agent: str):
        """Register *agent* with this expert (no-op if already registered)."""
        if agent not in self.supported_agents:
            self.supported_agents.append(agent)
            logger.info(f"Added agent '{agent}' support to expert '{self.expert_id}'")

    def format_agent_prompt(self, agent: str, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Return a copy of *messages* conditioned on *agent*.

        The agent token is prepended to the existing system message, or a
        new system message carrying the token is inserted at the front.
        The input list and its dicts are never mutated.
        """
        if not self.agent_manager or agent not in self.supported_agents:
            return messages

        # Bug fix: the original used a shallow list copy and then mutated
        # the shared system-message dict in place, corrupting the caller's
        # messages. Copy each message dict as well.
        formatted_messages = [dict(msg) for msg in messages]

        agent_token = f"{self.agent_manager.config.agent_prefix}{agent}{self.agent_manager.config.agent_suffix}"

        # Prepend the token to an existing system message, if any.
        for msg in formatted_messages:
            if msg.get("role") == "system":
                msg["content"] = f"{agent_token}\n{msg['content']}"
                return formatted_messages

        # No system message: create one. Bug fix: the original computed the
        # agent token on this branch but never included it, so the model
        # was not actually conditioned; include it as on the other branch.
        formatted_messages.insert(0, {
            "role": "system",
            "content": f"{agent_token}\nYou are a {agent} agent specialized in {self.specialization}."
        })
        return formatted_messages

    async def generate_response(self, messages: List[Dict[str, str]],
                                agent: Optional[str] = None, **kwargs) -> Dict[str, Any]:
        """Generate a response, conditioning the prompt on *agent* when supported."""
        if agent and agent in self.supported_agents:
            messages = self.format_agent_prompt(agent, messages)

        # Delegate actual generation to the base expert.
        return await super().generate_response(messages, **kwargs)
103
+
104
class MultiAgentMoERouter:
    """
    Enhanced router that considers both agent and content for expert selection.

    Routing priority: explicit agent-to-expert mapping first, then keyword
    matching against the conversation text, then the "general" expert, and
    finally all available experts as a last resort.
    """

    def __init__(self, agent_manager: Optional[AgentTokenManager] = None):
        self.agent_manager = agent_manager
        self.agent_expert_mapping: Dict[str, str] = {}
        # Keyword lists checked in declaration order during content routing.
        self.expert_specializations = {
            "code": ["programming", "software", "development", "coding", "algorithm", "python", "javascript", "java", "function", "code"],
            "math": ["mathematics", "calculation", "equation", "formula", "statistics", "derivative", "integral", "algebra", "calculus", "math", "solve", "calculate"],
            "reasoning": ["logic", "analysis", "reasoning", "problem-solving", "critical", "explain", "why", "how", "because"],
            "multilingual": ["translation", "language", "multilingual", "localization", "translate", "spanish", "french", "german"],
            "general": ["general", "conversation", "assistance", "help", "hello", "hi", "what", "who", "when", "where"]
        }

    def set_agent_expert_mapping(self, mapping: Dict[str, str]):
        """Replace the agent -> preferred-expert mapping."""
        self.agent_expert_mapping = mapping
        logger.info(f"Set agent-expert mapping: {mapping}")

    def extract_agent_from_messages(self, messages: List[Dict[str, str]]) -> Optional[str]:
        """Return the agent named in a system message, or None."""
        if not self.agent_manager:
            return None

        for msg in messages:
            if msg.get("role") != "system":
                continue
            found = self.agent_manager.extract_agent_from_text(msg.get("content", ""))
            if found:
                return found

        return None

    def route_experts(self, messages: List[Dict[str, str]], available_experts: List[MultiAgentMoEExpert]) -> List[MultiAgentMoEExpert]:
        """Pick the expert subset for this conversation (see class docstring)."""
        agent = self.extract_agent_from_messages(messages)

        # Flatten user/assistant turns into one lowercase blob for keyword tests.
        text = " ".join(
            msg.get("content", "")
            for msg in messages
            if msg.get("role") in ("user", "assistant")
        ).lower()

        def of_kind(kind):
            # Experts whose specialization matches *kind*.
            return [expert for expert in available_experts if expert.specialization == kind]

        # 1) Explicit agent mapping wins when it yields any expert.
        if agent is not None and agent in self.agent_expert_mapping:
            preferred_expert_type = self.agent_expert_mapping[agent]
            chosen = of_kind(preferred_expert_type)
            if chosen:
                logger.debug(f"Routing agent '{agent}' to {preferred_expert_type} expert")
                return chosen

        # 2) Keyword match against the conversation content.
        for specialization, keywords in self.expert_specializations.items():
            if any(word in text for word in keywords):
                chosen = of_kind(specialization)
                if chosen:
                    logger.debug(f"Routing based on content to {specialization} expert")
                    return chosen

        # 3) Fall back to the general expert when one exists.
        chosen = of_kind("general")
        if chosen:
            logger.debug("Routing to general expert")
            return chosen

        # 4) Last resort: every available expert.
        return available_experts
177
+
178
class MultiAgentMoEFramework(EnhancedMoEFramework):
    """
    Enhanced MoE framework with multi-agent support.

    Extends ``EnhancedMoEFramework`` with per-agent conditioning: agent names
    are mapped to expert specializations, an agent token is injected into the
    system message, and the router (when it is a ``MultiAgentMoERouter``) is
    kept in sync with the agent/expert mapping.
    """

    def __init__(self, config: MultiAgentMoEConfig):
        super().__init__(config.moe_config)
        self.multi_agent_config = config
        self.agent_manager: Optional[AgentTokenManager] = None
        self.agent_expert_mapping: Dict[str, str] = {}
        self.agents: List[str] = []

    async def initialize_agents(self, agents: List[str], agent_expert_mapping: Optional[Dict[str, str]] = None):
        """Initialize multi-agent support.

        Args:
            agents: Names of the agents to support.
            agent_expert_mapping: Optional explicit agent -> expert-type
                mapping; when omitted, a keyword-based default is derived
                from the agent names.
        """
        self.agents = agents

        # Token manager builds "<prefix>agent<suffix>" conditioning tokens.
        agent_config = AgentTokenConfig(
            agent_prefix=self.multi_agent_config.agent_prefix,
            agent_suffix=self.multi_agent_config.agent_suffix
        )
        self.agent_manager = AgentTokenManager(agent_config)

        # Use the explicit mapping when given, otherwise derive a default.
        if agent_expert_mapping:
            self.agent_expert_mapping = agent_expert_mapping
        else:
            self.agent_expert_mapping = self._create_default_mapping(agents)

        # Keep an agent-aware router in sync with the mapping.
        if hasattr(self, 'router') and isinstance(self.router, MultiAgentMoERouter):
            self.router.agent_manager = self.agent_manager
            self.router.set_agent_expert_mapping(self.agent_expert_mapping)

        logger.info(f"Initialized multi-agent support for {len(agents)} agents")
        logger.info(f"Agent-expert mapping: {self.agent_expert_mapping}")

    def _create_default_mapping(self, agents: List[str]) -> Dict[str, str]:
        """Create default agent-expert mapping from keywords in agent names.

        The first matching keyword group wins; unmatched agents map to
        "general".
        """
        mapping = {}

        for agent in agents:
            agent_lower = agent.lower()
            if any(keyword in agent_lower for keyword in ["swe", "developer", "programmer", "engineer"]):
                mapping[agent] = "code"
            elif any(keyword in agent_lower for keyword in ["sqa", "tester", "qa", "quality"]):
                mapping[agent] = "code"
            elif any(keyword in agent_lower for keyword in ["devops", "ops", "deployment"]):
                mapping[agent] = "code"
            elif any(keyword in agent_lower for keyword in ["architect", "design", "system"]):
                mapping[agent] = "reasoning"
            elif any(keyword in agent_lower for keyword in ["security", "sec", "cyber"]):
                mapping[agent] = "reasoning"
            elif any(keyword in agent_lower for keyword in ["math", "analyst", "data"]):
                mapping[agent] = "math"
            elif any(keyword in agent_lower for keyword in ["translate", "localization", "lang"]):
                mapping[agent] = "multilingual"
            else:
                mapping[agent] = "general"

        return mapping

    async def initialize_experts(self, expert_configs: List[Dict[str, str]]):
        """Initialize experts, then register agent support on each one."""
        # Let the base framework construct the expert pool.
        await super().initialize_experts(expert_configs)

        # Register every agent on experts matching its mapped specialization.
        for expert_id, expert in self.experts.items():
            if isinstance(expert, MultiAgentMoEExpert):
                for agent, expert_type in self.agent_expert_mapping.items():
                    if expert.specialization == expert_type:
                        expert.add_agent_support(agent)

    async def process_query(self, query: str, agent: Optional[str] = None,
                            system_message: Optional[str] = None,
                            use_multiple_experts: bool = True, **kwargs) -> Dict[str, Any]:
        """Process a query with optional agent conditioning.

        Bug fix: previously an agent-conditioned ``messages`` list was built
        here and then discarded, so the agent token never reached the
        underlying framework.  The token is now prepended to the
        ``system_message`` that is actually forwarded to
        ``super().process_query``.
        """
        if agent and self.agent_manager:
            token_cfg = self.agent_manager.config
            agent_token = f"{token_cfg.agent_prefix}{agent}{token_cfg.agent_suffix}"
            if system_message:
                system_message = f"{agent_token}\n{system_message}"
            else:
                system_message = f"{agent_token}\nYou are a {agent} agent."

        return await super().process_query(
            query=query,
            system_message=system_message,
            use_multiple_experts=use_multiple_experts,
            **kwargs
        )

    def get_agent_statistics(self) -> Dict[str, Any]:
        """Get statistics about agent support.

        Returns:
            Dict with the agent list, the agent-expert mapping, and per-expert
            supported-agent details for every ``MultiAgentMoEExpert``.
        """
        stats = {
            "total_agents": len(self.agents),
            "agents": self.agents,
            "agent_expert_mapping": self.agent_expert_mapping,
            "expert_agent_support": {}
        }

        for expert_id, expert in self.experts.items():
            if isinstance(expert, MultiAgentMoEExpert):
                stats["expert_agent_support"][expert_id] = {
                    "specialization": expert.specialization,
                    "supported_agents": expert.supported_agents
                }

        return stats

    def save_agent_configuration(self, output_dir: str):
        """Save agent configuration for deployment.

        Args:
            output_dir: Directory to write into (created if missing).

        Returns:
            Path of the written ``agent_config.json`` file.
        """
        os.makedirs(output_dir, exist_ok=True)

        config_data = {
            "agents": self.agents,
            "agent_expert_mapping": self.agent_expert_mapping,
            "agent_token_config": {
                "agent_prefix": self.multi_agent_config.agent_prefix,
                "agent_suffix": self.multi_agent_config.agent_suffix
            },
            "expert_agent_support": self.get_agent_statistics()["expert_agent_support"]
        }

        config_file = os.path.join(output_dir, "agent_config.json")
        with open(config_file, 'w') as f:
            json.dump(config_data, f, indent=2)

        logger.info(f"Saved agent configuration to {config_file}")
        return config_file
+
319
class MultiAgentMoEAdapter:
    """
    Adapter for integrating multi-agent MoE with existing systems.

    Thin facade over a ``MultiAgentMoEFramework`` exposing per-agent query
    processing and capability introspection.
    """

    def __init__(self, multi_agent_moe: MultiAgentMoEFramework):
        self.multi_agent_moe = multi_agent_moe

    async def process_agent_query(self, agent: str, query: str, **kwargs) -> Dict[str, Any]:
        """Process query for specific agent"""
        framework = self.multi_agent_moe
        return await framework.process_query(query=query, agent=agent, **kwargs)

    def get_agent_capabilities(self, agent: str) -> Dict[str, Any]:
        """Get capabilities for specific agent"""
        framework = self.multi_agent_moe
        if agent not in framework.agents:
            return {"error": f"Agent '{agent}' not found"}

        expert_type = framework.agent_expert_mapping.get(agent, "general")
        return {
            "agent": agent,
            "expert_type": expert_type,
            "supported": True,
            "capabilities": framework.expert_specializations.get(expert_type, []),
        }

    def list_available_agents(self) -> List[Dict[str, Any]]:
        """List all available agents and their capabilities"""
        framework = self.multi_agent_moe

        def describe(name: str) -> Dict[str, Any]:
            # One summary entry per agent, mirroring get_agent_capabilities.
            expert_type = framework.agent_expert_mapping.get(name, "general")
            return {
                "agent": name,
                "expert_type": expert_type,
                "capabilities": framework.expert_specializations.get(expert_type, []),
            }

        return [describe(name) for name in framework.agents]
362
+
363
# Example usage and testing
if __name__ == "__main__":
    # Configure logging
    logging.basicConfig(level=logging.INFO)

    # Build an example configuration.
    base_config = Phi35MoEConfig()
    framework_config = MultiAgentMoEConfig(
        moe_config=base_config,
        agent_prefix="<|agent:",
        agent_suffix="|>",
    )

    # Instantiate the multi-agent framework.
    framework = MultiAgentMoEFramework(framework_config)

    # Example agents and their expert-type assignments.
    agents = ["SWE", "SQE", "DevOps", "Architect", "Security"]
    agent_expert_mapping = {
        "SWE": "code",
        "SQE": "code",
        "DevOps": "code",
        "Architect": "reasoning",
        "Security": "reasoning",
    }

    print("Multi-agent MoE framework ready")
    print(f"Agents: {agents}")
    print(f"Agent-expert mapping: {agent_expert_mapping}")
src/models/phi35_moe_integration.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Phi-3.5-MoE Integration Module
4
+
5
+ This module integrates Microsoft's Phi-3.5-MoE-instruct model as a baseline
6
+ for expanding our current MoE (Mixture of Experts) framework.
7
+
8
+ Based on: https://huggingface.co/microsoft/Phi-3.5-MoE-instruct
9
+ - 41.9B parameters
10
+ - 128K context length
11
+ - Multilingual support
12
+ - Strong reasoning capabilities (code, math, logic)
13
+ """
14
+
15
+ import torch
16
+ import logging
17
+ from typing import Dict, List, Optional, Any, Union
18
+ from transformers import (
19
+ AutoModelForCausalLM,
20
+ AutoTokenizer,
21
+ pipeline,
22
+ Pipeline
23
+ )
24
+ from pathlib import Path
25
+ import json
26
+ import asyncio
27
+ from dataclasses import dataclass
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
@dataclass
class Phi35MoEConfig:
    """Configuration for Phi-3.5-MoE integration.

    Defaults mirror the published microsoft/Phi-3.5-MoE-instruct checkpoint
    (128K context window, 32,064-token vocabulary).
    """
    model_name: str = "microsoft/Phi-3.5-MoE-instruct"  # HF model id to load
    device_map: str = "auto"  # device placement strategy passed to transformers
    torch_dtype: str = "auto"  # let transformers pick the checkpoint dtype
    trust_remote_code: bool = False  # do not execute repo-provided code
    max_new_tokens: int = 500  # cap on generated tokens per call
    temperature: float = 0.7  # sampling temperature
    top_p: float = 0.9  # nucleus-sampling threshold
    do_sample: bool = True  # sample instead of greedy decoding
    context_length: int = 128000  # 128K context length
    vocabulary_size: int = 32064  # tokenizer vocabulary size
44
+
45
class Phi35MoEExpert:
    """Individual expert using Phi-3.5-MoE model.

    Wraps one Phi-3.5-MoE checkpoint behind an async interface.  Loading is
    lazy: the model is only pulled in by the first ``generate_response`` call,
    so constructing an expert is cheap.
    """

    def __init__(self, config: Phi35MoEConfig, expert_id: str, specialization: str):
        # config: loading/generation settings shared by all experts.
        # expert_id: unique identifier used in responses and logs.
        # specialization: routing category (e.g. "code", "math", "general").
        self.config = config
        self.expert_id = expert_id
        self.specialization = specialization
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        self.is_loaded = False

    async def load_model(self):
        """Load the Phi-3.5-MoE model and tokenizer.

        Raises:
            Exception: re-raised after logging when the model, tokenizer, or
                pipeline fails to load.
        """
        try:
            logger.info(f"Loading Phi-3.5-MoE model for expert {self.expert_id}")

            # Load model
            self.model = AutoModelForCausalLM.from_pretrained(
                self.config.model_name,
                device_map=self.config.device_map,
                torch_dtype=self.config.torch_dtype,
                trust_remote_code=self.config.trust_remote_code,
            )

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)

            # Create pipeline
            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device_map=self.config.device_map,
            )

            self.is_loaded = True
            logger.info(f"Successfully loaded Phi-3.5-MoE model for expert {self.expert_id}")

        except Exception as e:
            logger.error(f"Failed to load Phi-3.5-MoE model for expert {self.expert_id}: {e}")
            raise

    async def generate_response(
        self,
        messages: List[Dict[str, str]],
        **generation_kwargs
    ) -> Dict[str, Any]:
        """Generate response using Phi-3.5-MoE model.

        Args:
            messages: Chat-format messages (``role``/``content`` dicts).
            **generation_kwargs: Overrides for the default generation args.

        Returns:
            Dict with ``expert_id``, ``specialization``, ``response`` and a
            ``success`` flag; on failure ``response`` is empty and an
            ``error`` string is included instead of raising.
        """
        # Lazily load the model on first use.
        if not self.is_loaded:
            await self.load_model()

        try:
            # Defaults come from config; caller kwargs take precedence.
            generation_args = {
                "max_new_tokens": self.config.max_new_tokens,
                "temperature": self.config.temperature,
                "top_p": self.config.top_p,
                "do_sample": self.config.do_sample,
                "return_full_text": False,
                **generation_kwargs
            }

            # Generate response
            response = self.pipeline(messages, **generation_args)

            return {
                "expert_id": self.expert_id,
                "specialization": self.specialization,
                "response": response[0]["generated_text"] if response else "",
                "model": self.config.model_name,
                "success": True
            }

        except Exception as e:
            logger.error(f"Error generating response for expert {self.expert_id}: {e}")
            return {
                "expert_id": self.expert_id,
                "specialization": self.specialization,
                "response": "",
                "error": str(e),
                "success": False
            }
128
+
129
class Phi35MoERouter:
    """Router for selecting appropriate Phi-3.5-MoE experts"""

    def __init__(self):
        # Keyword lists used to score how relevant each specialization is
        # to a lowercased query string.
        self.expert_specializations = {
            "code": ["programming", "software", "development", "coding", "algorithm", "python", "javascript", "java", "function", "code"],
            "math": ["mathematics", "calculation", "equation", "formula", "statistics", "derivative", "integral", "algebra", "calculus", "math", "solve", "calculate"],
            "reasoning": ["logic", "analysis", "reasoning", "problem-solving", "critical", "explain", "why", "how", "because"],
            "multilingual": ["translation", "language", "multilingual", "localization", "translate", "spanish", "french", "german"],
            "general": ["general", "conversation", "assistance", "help", "hello", "hi", "what", "who", "when", "where"]
        }

    def route_query(self, query: str, available_experts: List[Phi35MoEExpert]) -> List[Phi35MoEExpert]:
        """Route query to appropriate experts based on content analysis"""
        lowered = query.lower()

        def relevance(expert: Phi35MoEExpert) -> int:
            # Count how many of the specialization's keywords occur in the query.
            keywords = self.expert_specializations.get(expert.specialization, [])
            return sum(1 for kw in keywords if kw in lowered)

        # Rank experts by score, highest first; ties keep input order
        # (sorted() is stable).
        ranked = sorted(
            ((expert, relevance(expert)) for expert in available_experts),
            key=lambda pair: pair[1],
            reverse=True,
        )

        if ranked and ranked[0][1] > 0:
            # Keep every expert that matched at least one keyword.
            chosen = [expert for expert, score in ranked if score > 0]
        else:
            # Fallback: the first general-purpose expert, if any.
            chosen = [e for e in available_experts if e.specialization == "general"][:1]

        return chosen[:3]  # Limit to top 3 experts
173
+
174
class EnhancedMoEFramework:
    """Enhanced MoE framework using Phi-3.5-MoE as baseline.

    Holds a pool of ``Phi35MoEExpert`` instances, routes queries to them via
    ``Phi35MoERouter``, and can combine multiple expert answers into one.
    """

    def __init__(self, config: Optional[Phi35MoEConfig] = None):
        # config: shared expert configuration; defaults used when omitted.
        self.config = config or Phi35MoEConfig()
        self.experts: Dict[str, Phi35MoEExpert] = {}
        self.router = Phi35MoERouter()
        self.is_initialized = False

    async def initialize_experts(self, expert_configs: List[Dict[str, str]]):
        """Initialize multiple Phi-3.5-MoE experts with different specializations.

        Args:
            expert_configs: Dicts with ``expert_id`` and ``specialization``
                keys, one per expert.  Models are loaded lazily later, not
                here.
        """
        logger.info("Initializing Phi-3.5-MoE experts...")

        for expert_config in expert_configs:
            expert_id = expert_config["expert_id"]
            specialization = expert_config["specialization"]

            expert = Phi35MoEExpert(
                config=self.config,
                expert_id=expert_id,
                specialization=specialization
            )

            self.experts[expert_id] = expert

        self.is_initialized = True
        logger.info(f"Initialized {len(self.experts)} Phi-3.5-MoE experts")

    async def process_query(
        self,
        query: str,
        system_message: Optional[str] = None,
        use_multiple_experts: bool = True
    ) -> Dict[str, Any]:
        """Process query using the enhanced MoE framework.

        Args:
            query: User query text.
            system_message: Optional system prompt prepended to the chat.
            use_multiple_experts: When True, route to up to three experts and
                combine their answers; when False, use the "general" expert
                (or the first registered expert as a fallback).

        Returns:
            Dict with the query, expert response(s), and a ``success`` flag.

        Raises:
            RuntimeError: If ``initialize_experts`` has not been called.
        """
        if not self.is_initialized:
            raise RuntimeError("MoE framework not initialized. Call initialize_experts() first.")

        # Prepare chat-format messages.
        messages = []
        if system_message:
            messages.append({"role": "system", "content": system_message})
        messages.append({"role": "user", "content": query})

        if use_multiple_experts:
            # Route to multiple experts
            selected_experts = self.router.route_query(query, list(self.experts.values()))

            # Fan generation out to the selected experts concurrently.
            tasks = []
            for expert in selected_experts:
                task = expert.generate_response(messages)
                tasks.append(task)

            # Wait for all responses; exceptions are returned, not raised.
            responses = await asyncio.gather(*tasks, return_exceptions=True)

            # Keep only well-formed, successful responses.
            successful_responses = []
            for response in responses:
                if isinstance(response, dict) and response.get("success", False):
                    successful_responses.append(response)

            # Combine responses
            combined_response = self._combine_responses(successful_responses)

            return {
                "query": query,
                "responses": successful_responses,
                "combined_response": combined_response,
                "num_experts_used": len(successful_responses),
                "success": len(successful_responses) > 0
            }
        else:
            # Use single expert (general)
            general_expert = self.experts.get("general")
            if not general_expert:
                general_expert = list(self.experts.values())[0]  # Fallback to first expert

            response = await general_expert.generate_response(messages)
            return {
                "query": query,
                "response": response,
                "success": response.get("success", False)
            }

    def _combine_responses(self, responses: List[Dict[str, Any]]) -> str:
        """Combine multiple expert responses into a coherent answer.

        A single response is returned verbatim; multiple responses are
        concatenated with per-expert markdown attribution headers.
        """
        if not responses:
            return "No responses available."

        if len(responses) == 1:
            return responses[0]["response"]

        # Combine responses with expert attribution
        combined = "Based on analysis from multiple experts:\n\n"
        for i, response in enumerate(responses, 1):
            expert_id = response.get("expert_id", f"Expert {i}")
            specialization = response.get("specialization", "general")
            expert_response = response.get("response", "")

            combined += f"**{expert_id} ({specialization}):**\n{expert_response}\n\n"

        return combined.strip()

    async def benchmark_performance(self, test_queries: List[str]) -> Dict[str, Any]:
        """Benchmark the enhanced MoE framework performance.

        Runs every query through ``process_query`` and aggregates success
        counts, per-expert usage, and wall-clock timing.

        Returns:
            Dict with totals, ``success_rate``, ``average_response_time``,
            per-expert ``expert_usage_stats``, and per-query details.
        """
        logger.info("Starting Phi-3.5-MoE framework benchmark...")

        results = {
            "total_queries": len(test_queries),
            "successful_queries": 0,
            "failed_queries": 0,
            "average_response_time": 0,
            "expert_usage_stats": {},
            "detailed_results": []
        }

        total_time = 0

        for i, query in enumerate(test_queries):
            # The event-loop clock is monotonic — suitable for durations.
            start_time = asyncio.get_event_loop().time()

            try:
                result = await self.process_query(query)
                end_time = asyncio.get_event_loop().time()
                response_time = end_time - start_time
                total_time += response_time

                if result.get("success", False):
                    results["successful_queries"] += 1
                else:
                    results["failed_queries"] += 1

                # Track expert usage
                if "responses" in result:
                    for response in result["responses"]:
                        expert_id = response.get("expert_id", "unknown")
                        results["expert_usage_stats"][expert_id] = results["expert_usage_stats"].get(expert_id, 0) + 1

                results["detailed_results"].append({
                    "query": query,
                    "success": result.get("success", False),
                    "response_time": response_time,
                    "num_experts_used": result.get("num_experts_used", 1)
                })

            except Exception as e:
                logger.error(f"Error processing query {i}: {e}")
                results["failed_queries"] += 1
                results["detailed_results"].append({
                    "query": query,
                    "success": False,
                    "error": str(e),
                    "response_time": 0
                })

        results["average_response_time"] = total_time / len(test_queries) if test_queries else 0
        results["success_rate"] = results["successful_queries"] / results["total_queries"] if test_queries else 0

        logger.info(f"Benchmark completed. Success rate: {results['success_rate']:.2%}")
        return results
336
+
337
# Example usage and testing
async def main():
    """Example usage of the enhanced MoE framework"""

    # Initialize configuration
    config = Phi35MoEConfig(
        max_new_tokens=300,
        temperature=0.7
    )

    # Initialize framework
    framework = EnhancedMoEFramework(config)

    # Define expert configurations
    expert_configs = [
        {"expert_id": "code_expert", "specialization": "code"},
        {"expert_id": "math_expert", "specialization": "math"},
        {"expert_id": "reasoning_expert", "specialization": "reasoning"},
        {"expert_id": "general_expert", "specialization": "general"}
    ]

    # Initialize experts
    await framework.initialize_experts(expert_configs)

    # Test queries
    test_queries = [
        "How do I implement a binary search algorithm in Python?",
        "What is the derivative of x^2 + 3x + 1?",
        "Explain the logical reasoning behind the Monty Hall problem",
        "Hello, how are you today?"
    ]

    separator = "=" * 50

    # Process queries
    for query in test_queries:
        print(f"\n{separator}")
        print(f"Query: {query}")
        print(separator)

        result = await framework.process_query(query)

        if not result.get("success", False):
            print("Failed to generate response")
        elif "combined_response" in result:
            print(f"Combined Response:\n{result['combined_response']}")
        else:
            print(f"Response:\n{result['response']['response']}")

    # Run benchmark
    print(f"\n{separator}")
    print("Running Performance Benchmark...")
    print(separator)

    benchmark_results = await framework.benchmark_performance(test_queries)

    print(f"Success Rate: {benchmark_results['success_rate']:.2%}")
    print(f"Average Response Time: {benchmark_results['average_response_time']:.2f}s")
    print(f"Expert Usage Stats: {benchmark_results['expert_usage_stats']}")

if __name__ == "__main__":
    asyncio.run(main())
src/multi_agent_datasets/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Multi-agent dataset loading package
src/multi_agent_datasets/multi_agent_loader.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Multi-Agent Dataset Loader
4
+
5
+ This module provides comprehensive support for loading and processing multi-agent datasets
6
+ with two supported patterns:
7
+ A) Single folder with JSONLs that include an "agent" field
8
+ B) Per-agent subfolders (agent name == folder name)
9
+
10
+ Supports agent balancing, dataset validation, and integration with existing training pipelines.
11
+ """
12
+
13
+ import os
14
+ import json
15
+ import yaml
16
+ import logging
17
+ from pathlib import Path
18
+ from typing import Dict, List, Optional, Union, Tuple, Any
19
+ from collections import Counter, defaultdict
20
+ from dataclasses import dataclass
21
+
22
+ import torch
23
+ from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets
24
+ from transformers import AutoTokenizer
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
@dataclass
class MultiAgentDatasetConfig:
    """Configuration for multi-agent dataset loading"""
    dataset_path: str  # root folder holding the JSONLs or per-agent subfolders
    agents_file: Optional[str] = None  # optional YAML file listing agent names
    agent_prefix: str = "<|agent:"  # text placed before the agent name in prefixes
    agent_suffix: str = "|>"  # text placed after the agent name in prefixes
    balance_agents: bool = False  # whether to upsample minority agents
    balance_cap: Optional[int] = None  # optional ceiling on the balancing target
    max_seq_length: int = 2048  # sequence-length limit for tokenization
    validation_split: float = 0.1  # fraction held out for validation
    seed: int = 42  # seed for shuffling/upsampling reproducibility
40
+
41
+ class MultiAgentDatasetLoader:
42
+ """
43
+ Multi-agent dataset loader supporting two patterns:
44
+ 1. Single folder with JSONLs containing 'agent' field
45
+ 2. Per-agent subfolders with agent name == folder name
46
+ """
47
+
48
    def __init__(self, config: MultiAgentDatasetConfig):
        """Store the loader configuration and initialize bookkeeping state."""
        self.config = config
        # Resolved agent names; populated by resolve_agents_list().
        self.agents = []
        # Per-dataset statistics; populated later (presumably during
        # validation — not set in the code visible here).
        self.dataset_stats = {}
52
+
53
+ def read_agents_yaml(self, path: str) -> List[str]:
54
+ """Read agents list from YAML file"""
55
+ yml_path = os.path.join(path, "agents.yaml")
56
+ if os.path.isfile(yml_path):
57
+ try:
58
+ with open(yml_path, "r") as f:
59
+ obj = yaml.safe_load(f) or {}
60
+ agents = [str(a) for a in obj.get("agents", [])]
61
+ logger.info(f"Loaded {len(agents)} agents from YAML: {agents}")
62
+ return agents
63
+ except Exception as e:
64
+ logger.warning(f"Failed to read agents.yaml: {e}")
65
+ return []
66
+
67
+ def list_agent_subdirs(self, path: str) -> List[Tuple[str, str]]:
68
+ """List agent subdirectories with train/test.jsonl files"""
69
+ items = []
70
+ if not os.path.isdir(path):
71
+ return items
72
+
73
+ for name in sorted(os.listdir(path)):
74
+ subdir_path = os.path.join(path, name)
75
+ if os.path.isdir(subdir_path):
76
+ train_file = os.path.join(subdir_path, "train.jsonl")
77
+ test_file = os.path.join(subdir_path, "test.jsonl")
78
+ if os.path.isfile(train_file) or os.path.isfile(test_file):
79
+ items.append((name, subdir_path))
80
+ logger.debug(f"Found agent subdirectory: {name}")
81
+
82
+ return items
83
+
84
+ def load_single_folder_dataset(self, dataset_path: str) -> DatasetDict:
85
+ """Load dataset from single folder with agent field in rows"""
86
+ data_files = {}
87
+
88
+ train_file = os.path.join(dataset_path, "train.jsonl")
89
+ test_file = os.path.join(dataset_path, "test.jsonl")
90
+
91
+ if os.path.isfile(train_file):
92
+ data_files["train"] = train_file
93
+ if os.path.isfile(test_file):
94
+ data_files["test"] = test_file
95
+
96
+ if not data_files:
97
+ raise FileNotFoundError(f"No dataset files found in {dataset_path}")
98
+
99
+ logger.info(f"Loading single folder dataset from {data_files}")
100
+ dataset = load_dataset("json", data_files=data_files)
101
+
102
+ # Validate that agent field exists
103
+ for split_name, split_data in dataset.items():
104
+ if "agent" not in split_data.column_names:
105
+ raise ValueError(f"Agent field not found in {split_name} split")
106
+
107
+ return dataset
108
+
109
+ def load_subfolder_dataset(self, dataset_path: str) -> DatasetDict:
110
+ """Load dataset from per-agent subfolders"""
111
+ subdirs = self.list_agent_subdirs(dataset_path)
112
+ if not subdirs:
113
+ raise FileNotFoundError(f"No agent subdirectories found in {dataset_path}")
114
+
115
+ parts_train, parts_test = [], []
116
+
117
+ for agent_name, agent_dir in subdirs:
118
+ train_file = os.path.join(agent_dir, "train.jsonl")
119
+ test_file = os.path.join(agent_dir, "test.jsonl")
120
+
121
+ def add_agent_field(example):
122
+ example["agent"] = agent_name
123
+ return example
124
+
125
+ if os.path.isfile(train_file):
126
+ logger.debug(f"Loading train data for agent: {agent_name}")
127
+ train_data = load_dataset("json", data_files={"train": train_file})["train"]
128
+ train_data = train_data.map(add_agent_field)
129
+ parts_train.append(train_data)
130
+
131
+ if os.path.isfile(test_file):
132
+ logger.debug(f"Loading test data for agent: {agent_name}")
133
+ test_data = load_dataset("json", data_files={"test": test_file})["test"]
134
+ test_data = test_data.map(add_agent_field)
135
+ parts_test.append(test_data)
136
+
137
+ dataset_dict = {}
138
+ if parts_train:
139
+ dataset_dict["train"] = concatenate_datasets(parts_train)
140
+ if parts_test:
141
+ dataset_dict["test"] = concatenate_datasets(parts_test)
142
+
143
+ if not dataset_dict:
144
+ raise ValueError("No data splits found in agent subdirectories")
145
+
146
+ return DatasetDict(dataset_dict)
147
+
148
+ def load_multiagent_dataset(self) -> DatasetDict:
149
+ """
150
+ Load multi-agent dataset supporting both patterns:
151
+ - Single folder with 'agent' field in rows
152
+ - Per-agent subfolders
153
+ """
154
+ dataset_path = self.config.dataset_path
155
+
156
+ # Try single folder pattern first
157
+ if os.path.isfile(os.path.join(dataset_path, "train.jsonl")):
158
+ logger.info("Loading dataset using single folder pattern")
159
+ return self.load_single_folder_dataset(dataset_path)
160
+
161
+ # Try subfolder pattern
162
+ logger.info("Loading dataset using subfolder pattern")
163
+ return self.load_subfolder_dataset(dataset_path)
164
+
165
+ def infer_agents_from_dataset(self, dataset: DatasetDict) -> List[str]:
166
+ """Infer agent list from dataset"""
167
+ agents = set()
168
+
169
+ for split_name, split_data in dataset.items():
170
+ if "agent" in split_data.column_names:
171
+ agent_values = [a for a in set(split_data["agent"]) if a is not None]
172
+ agents.update(agent_values)
173
+ logger.debug(f"Found agents in {split_name}: {agent_values}")
174
+
175
+ agents_list = sorted(list(agents))
176
+ logger.info(f"Inferred {len(agents_list)} agents from dataset: {agents_list}")
177
+ return agents_list
178
+
179
+ def resolve_agents_list(self, dataset: DatasetDict) -> List[str]:
180
+ """Resolve agents list from YAML file or dataset inference"""
181
+ agents = []
182
+
183
+ # Try to load from agents file first
184
+ if self.config.agents_file and os.path.isfile(self.config.agents_file):
185
+ try:
186
+ with open(self.config.agents_file, "r") as f:
187
+ obj = yaml.safe_load(f) or {}
188
+ agents = [str(a) for a in obj.get("agents", [])]
189
+ logger.info(f"Loaded agents from file: {agents}")
190
+ except Exception as e:
191
+ logger.warning(f"Failed to load agents from file: {e}")
192
+
193
+ # Fall back to dataset inference
194
+ if not agents:
195
+ agents = self.infer_agents_from_dataset(dataset)
196
+
197
+ self.agents = agents
198
+ return agents
199
+
200
    def balance_by_agent(self, dataset: Dataset, agent_col: str = "agent") -> Dataset:
        """
        Balance dataset by upsampling minority agents to the max count.

        Each agent's rows are repeated whole, then topped up with a seeded
        random remainder, until every agent reaches the target count.

        NOTE(review): ``balance_cap`` only lowers the upsampling target; an
        agent that already has more rows than the cap is NOT downsampled —
        confirm this is the intended semantics.
        """
        if agent_col not in dataset.column_names:
            logger.warning(f"Agent column '{agent_col}' not found, skipping balancing")
            return dataset

        counts = Counter(dataset[agent_col])
        if not counts:
            logger.warning("No agent counts found, skipping balancing")
            return dataset

        # Target = size of the largest agent, optionally capped.
        max_count = max(counts.values())
        if self.config.balance_cap:
            max_count = min(max_count, self.config.balance_cap)

        logger.info(f"Balancing agents. Current counts: {dict(counts)}")
        logger.info(f"Target count per agent: {max_count}")

        parts = []
        for agent, count in counts.items():
            agent_subset = dataset.filter(lambda x: x[agent_col] == agent)
            parts.append(agent_subset)

            # Calculate how many additional samples we need
            needed = max_count - count
            if needed > 0:
                agent_subset_len = len(agent_subset)
                if agent_subset_len == 0:
                    logger.warning(f"Agent '{agent}' has zero samples, cannot upsample.")
                    continue
                # Calculate repetitions needed
                reps = needed // agent_subset_len
                remainder = needed % agent_subset_len

                # Add full repetitions
                for _ in range(reps):
                    parts.append(agent_subset)

                # Add remainder samples (seeded shuffle keeps this reproducible)
                if remainder > 0:
                    remainder_subset = agent_subset.shuffle(seed=self.config.seed).select(range(remainder))
                    parts.append(remainder_subset)

        # Final shuffle interleaves the repeated rows instead of grouping them.
        balanced_dataset = concatenate_datasets(parts).shuffle(seed=self.config.seed)

        # Log final counts
        final_counts = Counter(balanced_dataset[agent_col])
        logger.info(f"Balanced dataset counts: {dict(final_counts)}")

        return balanced_dataset
252
+
253
    def apply_agent_prefix(self, dataset: Dataset, tokenizer: AutoTokenizer) -> Dataset:
        """
        Apply agent prefix to dataset text using chat template or direct text.

        Each row is rewritten to a single "text" column of the form
        "<agent_prefix><agent><agent_suffix>\\n<content>"; rows without an
        "agent" value get no prefix. All columns other than "text" and
        "agent" are dropped from the resulting dataset.

        Args:
            dataset: Source dataset; rows may carry "messages" (chat),
                "text" (plain), or "prompt"/"response" content.
            tokenizer: Tokenizer whose chat template renders "messages" rows.

        Returns:
            Dataset with the prefixed "text" column.
        """
        def add_agent_prefix(example):
            agent = example.get("agent", None)
            prefix = f"{self.config.agent_prefix}{agent}{self.config.agent_suffix}\n" if agent else ""

            # Handle different text formats, in priority order:
            # chat "messages" > plain "text" > "prompt"/"response" pair.
            if "messages" in example and example["messages"] is not None:
                # Use chat template if available
                try:
                    text = tokenizer.apply_chat_template(
                        example["messages"],
                        tokenize=False,
                        add_generation_prompt=False
                    )
                    example["text"] = prefix + text
                except Exception as e:
                    logger.warning(f"Failed to apply chat template: {e}")
                    # Fallback to simple concatenation (stringifies the raw
                    # message structure rather than losing the row)
                    text = str(example["messages"])
                    example["text"] = prefix + text

            elif "text" in example and example["text"] is not None:
                example["text"] = prefix + example["text"]

            else:
                # Handle prompt/response format; missing keys become ""
                prompt = example.get("prompt", "")
                response = example.get("response", "")
                example["text"] = prefix + prompt + ("\n" if response else "") + response

            return example

        # Get original features to preserve; everything else is removed
        # by the map() call after the prefixing function has run.
        original_features = list(dataset.features)
        features_to_remove = [f for f in original_features if f not in ["text", "agent"]]

        logger.info("Applying agent prefixes to dataset")
        processed_dataset = dataset.map(
            add_agent_prefix,
            remove_columns=features_to_remove,
            desc="Adding agent prefixes"
        )

        return processed_dataset
300
+
301
+ def validate_dataset(self, dataset: DatasetDict) -> Dict[str, Any]:
302
+ """Validate dataset and return statistics"""
303
+ stats = {
304
+ "total_samples": 0,
305
+ "agents": {},
306
+ "splits": {},
307
+ "validation_errors": []
308
+ }
309
+
310
+ for split_name, split_data in dataset.items():
311
+ split_stats = {
312
+ "samples": len(split_data),
313
+ "agents": {},
314
+ "columns": split_data.column_names
315
+ }
316
+
317
+ stats["total_samples"] += len(split_data)
318
+
319
+ # Validate required columns
320
+ if "agent" not in split_data.column_names:
321
+ stats["validation_errors"].append(f"Missing 'agent' column in {split_name}")
322
+
323
+ # Count agents
324
+ if "agent" in split_data.column_names:
325
+ agent_counts = Counter(split_data["agent"])
326
+ split_stats["agents"] = dict(agent_counts)
327
+
328
+ # Update global agent counts
329
+ for agent, count in agent_counts.items():
330
+ if agent not in stats["agents"]:
331
+ stats["agents"][agent] = 0
332
+ stats["agents"][agent] += count
333
+
334
+ stats["splits"][split_name] = split_stats
335
+
336
+ self.dataset_stats = stats
337
+ logger.info(f"Dataset validation complete. Stats: {stats}")
338
+
339
+ return stats
340
+
341
    def load_and_process(self, tokenizer: AutoTokenizer) -> Tuple[DatasetDict, List[str], Dict[str, Any]]:
        """
        Complete dataset loading and processing pipeline.

        Pipeline order: load raw dataset -> resolve agent list -> validate
        and collect statistics -> apply agent prefixes to train/test splits
        -> optionally balance the train split. Note that the returned stats
        reflect the dataset *before* prefixing and balancing.

        Args:
            tokenizer: Tokenizer used when rendering chat-template rows.

        Returns:
            Tuple of (processed DatasetDict, resolved agent names, stats).
        """
        logger.info(f"Loading multi-agent dataset from {self.config.dataset_path}")

        # Load dataset
        dataset = self.load_multiagent_dataset()

        # Resolve agents list (explicit config/file first, then inference)
        agents = self.resolve_agents_list(dataset)

        # Validate dataset and gather per-split / per-agent statistics
        stats = self.validate_dataset(dataset)

        # Apply agent prefixes to whichever splits exist
        if "train" in dataset:
            dataset["train"] = self.apply_agent_prefix(dataset["train"], tokenizer)
        if "test" in dataset:
            dataset["test"] = self.apply_agent_prefix(dataset["test"], tokenizer)

        # Balance agents if requested (train split only; test stays untouched)
        if self.config.balance_agents and "train" in dataset:
            dataset["train"] = self.balance_by_agent(dataset["train"])

        logger.info(f"Dataset processing complete. Loaded {len(agents)} agents with {stats['total_samples']} total samples")

        return dataset, agents, stats
369
+
370
class MultiAgentDatasetValidator:
    """Static validators for multi-agent dataset files and directory layout."""

    @staticmethod
    def validate_jsonl_file(file_path: str) -> List[str]:
        """Validate JSONL file format and content.

        Every non-blank line must be a JSON object carrying an "agent"
        field and at least one text field ("text", "messages", or
        "prompt"). Blank lines are tolerated.

        Args:
            file_path: Path to the .jsonl file.

        Returns:
            List of human-readable error strings (empty when valid).
        """
        errors = []

        if not os.path.isfile(file_path):
            errors.append(f"File not found: {file_path}")
            return errors

        try:
            # Explicit UTF-8: JSONL datasets are UTF-8 by convention, and the
            # platform default encoding may differ (e.g. on Windows).
            with open(file_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if not line:
                        continue

                    try:
                        data = json.loads(line)

                        # Each record must be a JSON object
                        if not isinstance(data, dict):
                            errors.append(f"Line {line_num}: Not a JSON object")
                            continue

                        # Check for agent field
                        if "agent" not in data:
                            errors.append(f"Line {line_num}: Missing 'agent' field")

                        # Check for text content in any supported form
                        has_text = any(field in data for field in ["text", "messages", "prompt"])
                        if not has_text:
                            errors.append(f"Line {line_num}: No text content found")

                    except json.JSONDecodeError as e:
                        errors.append(f"Line {line_num}: JSON decode error - {e}")

        except Exception as e:
            errors.append(f"File read error: {e}")

        return errors

    @staticmethod
    def validate_dataset_structure(dataset_path: str) -> Dict[str, Any]:
        """Validate complete dataset structure.

        Two layouts are recognized:
          * single_folder: train.jsonl (plus optional test.jsonl) directly
            under dataset_path. NOTE(review): detection keys on train.jsonl
            only — a top-level test.jsonl without train.jsonl falls through
            to the subfolder layout; confirm this is intended.
          * subfolders: one directory per agent, each holding its own
            train.jsonl / test.jsonl.

        An optional agents.yaml manifest is parsed; failures there are
        warnings, not errors.

        Returns:
            Dict with "valid", "errors", "warnings", and "structure" keys.
        """
        validation_result = {
            "valid": True,
            "errors": [],
            "warnings": [],
            "structure": {}
        }

        if not os.path.isdir(dataset_path):
            validation_result["valid"] = False
            validation_result["errors"].append(f"Dataset path is not a directory: {dataset_path}")
            return validation_result

        train_file = os.path.join(dataset_path, "train.jsonl")
        test_file = os.path.join(dataset_path, "test.jsonl")

        if os.path.isfile(train_file):
            # Single-folder layout. train.jsonl existence was established by
            # the branch condition, so validate it directly (the original
            # re-checked os.path.isfile(train_file) redundantly here).
            validation_result["structure"]["pattern"] = "single_folder"
            validation_result["structure"]["files"] = ["train.jsonl"]
            errors = MultiAgentDatasetValidator.validate_jsonl_file(train_file)
            validation_result["errors"].extend(errors)

            if os.path.isfile(test_file):
                validation_result["structure"]["files"].append("test.jsonl")
                errors = MultiAgentDatasetValidator.validate_jsonl_file(test_file)
                validation_result["errors"].extend(errors)

        else:
            # Subfolder layout: each agent directory contributes its splits
            validation_result["structure"]["pattern"] = "subfolders"
            validation_result["structure"]["agents"] = []

            for item in os.listdir(dataset_path):
                item_path = os.path.join(dataset_path, item)
                if os.path.isdir(item_path):
                    agent_train = os.path.join(item_path, "train.jsonl")
                    agent_test = os.path.join(item_path, "test.jsonl")

                    if os.path.isfile(agent_train) or os.path.isfile(agent_test):
                        validation_result["structure"]["agents"].append(item)

                        if os.path.isfile(agent_train):
                            errors = MultiAgentDatasetValidator.validate_jsonl_file(agent_train)
                            validation_result["errors"].extend([f"{item}/train.jsonl: {e}" for e in errors])

                        if os.path.isfile(agent_test):
                            errors = MultiAgentDatasetValidator.validate_jsonl_file(agent_test)
                            validation_result["errors"].extend([f"{item}/test.jsonl: {e}" for e in errors])

        # Check for the optional agents.yaml manifest
        agents_yaml = os.path.join(dataset_path, "agents.yaml")
        if os.path.isfile(agents_yaml):
            validation_result["structure"]["has_agents_yaml"] = True
            try:
                with open(agents_yaml, 'r', encoding='utf-8') as f:
                    yaml.safe_load(f)
            except Exception as e:
                validation_result["warnings"].append(f"Invalid agents.yaml: {e}")
        else:
            validation_result["structure"]["has_agents_yaml"] = False

        validation_result["valid"] = len(validation_result["errors"]) == 0

        return validation_result
484
+
485
+ # Example usage and testing
486
# Example usage: run this module directly to smoke-test the loading pipeline
# against a real dataset path (edit dataset_path below first).
if __name__ == "__main__":
    # Configure logging
    logging.basicConfig(level=logging.INFO)

    # Example configuration (placeholder path — replace before running)
    config = MultiAgentDatasetConfig(
        dataset_path="/path/to/dataset",
        balance_agents=True,
        balance_cap=1000
    )

    # Create loader
    loader = MultiAgentDatasetLoader(config)

    # Example tokenizer (downloads tokenizer files from the Hub on first use)
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct")

    try:
        # Load and process dataset through the full pipeline
        dataset, agents, stats = loader.load_and_process(tokenizer)

        print(f"Loaded dataset with {len(agents)} agents:")
        for agent in agents:
            print(f" - {agent}")

        print(f"Dataset stats: {stats}")

    except Exception as e:
        # Demo-level handling: report and exit rather than raising
        print(f"Error loading dataset: {e}")
src/multi_agent_tokenization/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Agent tokenization package
src/multi_agent_tokenization/agent_tokenizer.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Agent Token Management System
4
+
5
+ This module provides comprehensive agent token management for multi-agent training,
6
+ including special token handling, embedding management, and integration with
7
+ existing tokenization systems.
8
+ """
9
+
10
+ import os
11
+ import json
12
+ import logging
13
+ from typing import Dict, List, Optional, Tuple, Any, Union
14
+ from dataclasses import dataclass
15
+ from pathlib import Path
16
+
17
+ import torch
18
+ from transformers import AutoTokenizer, PreTrainedTokenizer
19
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
@dataclass
class AgentTokenConfig:
    """Configuration for agent token management."""
    # Literal text wrapped around the agent name, e.g. "<|agent:SWE|>"
    agent_prefix: str = "<|agent:"
    agent_suffix: str = "|>"
    # Extra special tokens (appears unused by AgentTokenManager in this
    # module — TODO confirm before relying on it)
    special_tokens: Optional[Dict[str, str]] = None
    # Whether to grow the model's embedding matrix after adding tokens
    resize_embeddings: bool = True
    # Whether to persist the generated token mapping, and the file name
    # used under the output directory
    save_tokens: bool = True
    tokens_file: str = "agent_tokens.json"
32
+
33
class AgentTokenManager:
    """
    Manages agent-specific tokens and their integration with tokenizers.

    Generates one special token per agent (e.g. "<|agent:SWE|>"),
    registers those tokens with a tokenizer, resizes model embeddings to
    match, and persists/restores the agent -> token mapping.
    """

    def __init__(self, config: AgentTokenConfig):
        self.config = config
        # agent name -> special token string
        self.agent_tokens: Dict[str, str] = {}
        # agent name -> tokenizer vocabulary id
        self.token_ids: Dict[str, int] = {}
        # vocab size before tokens were added (used when resizing embeddings)
        self.original_vocab_size: Optional[int] = None

    def generate_agent_tokens(self, agents: List[str]) -> List[str]:
        """Generate agent tokens for the given agents and record the mapping."""
        tokens = []
        for agent in agents:
            token = f"{self.config.agent_prefix}{agent}{self.config.agent_suffix}"
            tokens.append(token)
            self.agent_tokens[agent] = token

        logger.info(f"Generated {len(tokens)} agent tokens: {tokens}")
        return tokens

    def add_agent_tokens_to_tokenizer(self, tokenizer: PreTrainedTokenizer, agents: List[str]) -> Tuple[PreTrainedTokenizer, List[str]]:
        """
        Add agent tokens to tokenizer and return updated tokenizer with token list.

        Also populates self.token_ids for every agent token — whether it
        was newly added or already present in the vocabulary.
        """
        if not agents:
            logger.warning("No agents provided, skipping token addition")
            return tokenizer, []

        # Generate agent tokens
        agent_tokens = self.generate_agent_tokens(agents)

        # Check which tokens need to be added
        existing_tokens = set(tokenizer.get_vocab().keys())
        tokens_to_add = [token for token in agent_tokens if token not in existing_tokens]

        if not tokens_to_add:
            logger.info("All agent tokens already exist in tokenizer")
            # Bug fix: the original returned here WITHOUT filling token_ids,
            # so get_agent_token_id() reported None for pre-existing tokens.
            self._refresh_token_ids(tokenizer)
            return tokenizer, agent_tokens

        # Store original vocab size (needed for later embedding resize/init)
        self.original_vocab_size = len(tokenizer)

        # Add special tokens
        logger.info(f"Adding {len(tokens_to_add)} new agent tokens to tokenizer")
        tokenizer.add_special_tokens({
            "additional_special_tokens": tokens_to_add
        })

        # Update token IDs mapping
        self._refresh_token_ids(tokenizer)

        logger.info(f"Added agent tokens. New vocab size: {len(tokenizer)}")
        return tokenizer, agent_tokens

    def _refresh_token_ids(self, tokenizer: PreTrainedTokenizer) -> None:
        """Rebuild the agent -> token-id map from the tokenizer's vocabulary."""
        for agent, token in self.agent_tokens.items():
            if token in tokenizer.get_vocab():
                self.token_ids[agent] = tokenizer.convert_tokens_to_ids(token)

    def resize_model_embeddings(self, model: torch.nn.Module, tokenizer: PreTrainedTokenizer) -> torch.nn.Module:
        """
        Resize model embeddings to accommodate new agent tokens.

        New input-embedding rows are initialized from N(0, 0.02). Only the
        input embeddings are re-initialized here; resize_token_embeddings
        itself handles the output head.
        """
        if not self.config.resize_embeddings:
            logger.info("Embedding resize disabled, skipping")
            return model

        if self.original_vocab_size is None:
            logger.warning("Original vocab size not set, cannot resize embeddings")
            return model

        new_vocab_size = len(tokenizer)
        if new_vocab_size == self.original_vocab_size:
            logger.info("Vocab size unchanged, no embedding resize needed")
            return model

        logger.info(f"Resizing model embeddings from {self.original_vocab_size} to {new_vocab_size}")

        # Resize embeddings
        model.resize_token_embeddings(new_vocab_size)

        # Initialize the newly appended embedding rows
        if hasattr(model, 'get_input_embeddings'):
            embeddings = model.get_input_embeddings()
            if hasattr(embeddings, 'weight'):
                with torch.no_grad():
                    # Rows past the old vocab size are the new agent tokens
                    new_embeddings = embeddings.weight[self.original_vocab_size:]
                    torch.nn.init.normal_(new_embeddings, mean=0.0, std=0.02)

        logger.info("Model embeddings resized successfully")
        return model

    def format_agent_prompt(self, agent: str, text: str) -> str:
        """Format text with the agent token prefix, or return it unchanged
        (with a warning) when the agent is unknown."""
        if agent not in self.agent_tokens:
            logger.warning(f"Agent '{agent}' not found in token mappings")
            return text

        agent_token = self.agent_tokens[agent]
        return f"{agent_token}\n{text}"

    def extract_agent_from_text(self, text: str) -> Optional[str]:
        """Extract agent name from text if it starts with a known agent token."""
        for agent, token in self.agent_tokens.items():
            if text.startswith(token):
                return agent
        return None

    def get_agent_token_id(self, agent: str) -> Optional[int]:
        """Get the vocabulary id for an agent's token, or None if unknown."""
        return self.token_ids.get(agent)

    def save_agent_tokens(self, output_dir: str) -> str:
        """Save agent tokens (mapping, ids, config) to a JSON file.

        Returns the file path, or "" when saving is disabled.
        """
        if not self.config.save_tokens:
            return ""

        os.makedirs(output_dir, exist_ok=True)
        tokens_file = os.path.join(output_dir, self.config.tokens_file)

        tokens_data = {
            "agent_tokens": self.agent_tokens,
            "token_ids": self.token_ids,
            "config": {
                "agent_prefix": self.config.agent_prefix,
                "agent_suffix": self.config.agent_suffix,
                "original_vocab_size": self.original_vocab_size
            }
        }

        # Explicit UTF-8 so the file round-trips across platforms
        with open(tokens_file, 'w', encoding='utf-8') as f:
            json.dump(tokens_data, f, indent=2)

        logger.info(f"Saved agent tokens to {tokens_file}")
        return tokens_file

    def load_agent_tokens(self, tokens_file: str) -> bool:
        """Load agent tokens from a JSON file; returns True on success."""
        if not os.path.isfile(tokens_file):
            logger.warning(f"Agent tokens file not found: {tokens_file}")
            return False

        try:
            with open(tokens_file, 'r', encoding='utf-8') as f:
                tokens_data = json.load(f)

            self.agent_tokens = tokens_data.get("agent_tokens", {})
            self.token_ids = tokens_data.get("token_ids", {})

            config_data = tokens_data.get("config", {})
            self.original_vocab_size = config_data.get("original_vocab_size")

            logger.info(f"Loaded {len(self.agent_tokens)} agent tokens from {tokens_file}")
            return True

        except Exception as e:
            logger.error(f"Failed to load agent tokens: {e}")
            return False

    def get_agent_statistics(self) -> Dict[str, Any]:
        """Get a summary of the currently registered agent tokens."""
        return {
            "total_agents": len(self.agent_tokens),
            "agents": list(self.agent_tokens.keys()),
            "token_ids": self.token_ids,
            "original_vocab_size": self.original_vocab_size,
            "config": {
                "agent_prefix": self.config.agent_prefix,
                "agent_suffix": self.config.agent_suffix
            }
        }
204
+
205
class AgentTokenizer:
    """
    Enhanced tokenizer wrapper that integrates agent token management.

    Unknown attribute lookups are delegated to the wrapped tokenizer, so
    instances can largely be used as drop-in replacements for it.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, agent_manager: AgentTokenManager):
        self.tokenizer = tokenizer
        self.agent_manager = agent_manager

    def tokenize_agent_text(self, agent: str, text: str, **kwargs) -> Dict[str, Any]:
        """Tokenize text with its agent prefix prepended.

        kwargs pass straight through to the underlying tokenizer call.
        """
        formatted_text = self.agent_manager.format_agent_prompt(agent, text)
        return self.tokenizer(formatted_text, **kwargs)

    def decode_agent_tokens(self, token_ids: Union[List[int], torch.Tensor], **kwargs) -> str:
        """Decode token IDs back to text via the wrapped tokenizer."""
        return self.tokenizer.decode(token_ids, **kwargs)

    def get_agent_attention_mask(self, input_ids: torch.Tensor, agent: str) -> torch.Tensor:
        """Get attention mask with special handling for agent tokens.

        NOTE(review): the mask is created as all ones, so re-setting the
        agent-token positions to 1 is currently a no-op; it only matters if
        the initial mask construction ever changes (e.g. padding-aware
        masks). Confirm the intended behavior.
        """
        attention_mask = torch.ones_like(input_ids)

        # Find agent token position
        agent_token_id = self.agent_manager.get_agent_token_id(agent)
        if agent_token_id is not None:
            # Ensure agent token is attended to
            agent_positions = (input_ids == agent_token_id)
            attention_mask[agent_positions] = 1

        return attention_mask

    def __getattr__(self, name):
        """Delegate unknown attributes to underlying tokenizer.

        __getattr__ is only invoked when normal attribute lookup fails, so
        wrapper attributes such as self.tokenizer are unaffected.
        """
        return getattr(self.tokenizer, name)
239
+
240
class AgentTokenValidator:
    """Validator for agent token configurations."""

    @staticmethod
    def validate_agent_tokens(agents: List[str], config: AgentTokenConfig) -> Dict[str, Any]:
        """Validate agent token configuration.

        Checks for duplicate agent names, duplicate generated tokens, and
        unusually long tokens (warning only).

        Returns:
            Dict with "valid", "errors", "warnings", and the generated
            agent -> token mapping under "tokens".
        """
        validation_result = {
            "valid": True,
            "errors": [],
            "warnings": [],
            "tokens": {}
        }

        if not agents:
            validation_result["warnings"].append("No agents provided")
            return validation_result

        # Check for duplicate agents
        if len(agents) != len(set(agents)):
            validation_result["errors"].append("Duplicate agents found")
            validation_result["valid"] = False

        # Generate and validate tokens using a throwaway manager
        manager = AgentTokenManager(config)
        tokens = manager.generate_agent_tokens(agents)

        # Check for token conflicts (distinct agents mapping to one token)
        token_set = set(tokens)
        if len(token_set) != len(tokens):
            validation_result["errors"].append("Duplicate tokens generated")
            validation_result["valid"] = False

        # Check token length (soft limit, warning only)
        for agent, token in zip(agents, tokens):
            if len(token) > 50:  # Reasonable limit
                validation_result["warnings"].append(f"Long token for agent '{agent}': {token}")

        validation_result["tokens"] = dict(zip(agents, tokens))

        return validation_result

    @staticmethod
    def validate_tokenizer_compatibility(tokenizer: PreTrainedTokenizer, agents: List[str], config: AgentTokenConfig) -> Dict[str, Any]:
        """Validate tokenizer compatibility with agent tokens.

        Splits the generated tokens into those already present in the
        tokenizer vocabulary and those that would need to be added, and
        warns about existing tokens that are not registered as special.

        Returns:
            Dict with "compatible", "errors", "warnings",
            "existing_tokens" (agent names), and "new_tokens" (agent names).
        """
        validation_result = {
            "compatible": True,
            "errors": [],
            "warnings": [],
            "existing_tokens": [],
            "new_tokens": []
        }

        if not agents:
            return validation_result

        # Generate tokens with a throwaway manager
        manager = AgentTokenManager(config)
        tokens = manager.generate_agent_tokens(agents)

        # Check existing vocabulary
        vocab = tokenizer.get_vocab()
        for agent, token in zip(agents, tokens):
            if token in vocab:
                validation_result["existing_tokens"].append(agent)
            else:
                validation_result["new_tokens"].append(agent)

        # Check for potential conflicts with non-special vocab entries
        for token in tokens:
            if token in vocab:
                # Check if it's already a special token
                if hasattr(tokenizer, 'special_tokens_map'):
                    special_tokens = tokenizer.special_tokens_map
                    if token not in special_tokens.values():
                        validation_result["warnings"].append(f"Token '{token}' exists in vocab but not as special token")

        return validation_result
317
+
318
+ # Integration with existing MoE framework
319
class MoEAgentTokenIntegration:
    """
    Integration layer between agent tokens and MoE framework.

    Maintains an agent -> expert mapping and formats/parses prompts that
    carry both an agent token and an optional "<|expert:...|>" marker.
    """

    # Expert marker delimiters, shared by formatting and extraction
    EXPERT_OPEN = "<|expert:"
    EXPERT_CLOSE = "|>"

    def __init__(self, agent_manager: AgentTokenManager):
        self.agent_manager = agent_manager
        self.agent_to_expert_mapping: Dict[str, str] = {}

    def map_agent_to_expert(self, agent: str, expert: str):
        """Map agent to MoE expert specialization."""
        self.agent_to_expert_mapping[agent] = expert
        logger.info(f"Mapped agent '{agent}' to expert '{expert}'")

    def get_expert_for_agent(self, agent: str) -> Optional[str]:
        """Get expert specialization for agent, or None if unmapped."""
        return self.agent_to_expert_mapping.get(agent)

    def format_moe_prompt(self, agent: str, text: str, expert: Optional[str] = None) -> str:
        """Format prompt for MoE framework with agent and expert context.

        The expert marker is spliced in at the first newline, i.e. directly
        after the agent token line. Bug fix: if the formatted text contains
        no newline (e.g. the agent was unknown so no prefix was added), the
        marker is appended instead of being silently dropped as before.
        """
        # Start with agent token
        formatted_text = self.agent_manager.format_agent_prompt(agent, text)

        # Add expert context if available
        if expert:
            expert_context = f"\n{self.EXPERT_OPEN}{expert}{self.EXPERT_CLOSE}\n"
            if "\n" in formatted_text:
                formatted_text = formatted_text.replace("\n", expert_context, 1)
            else:
                formatted_text = formatted_text + expert_context

        return formatted_text

    def extract_agent_and_expert(self, text: str) -> Tuple[Optional[str], Optional[str]]:
        """Extract both agent and expert from formatted text.

        Returns (agent, expert); either element is None when its marker is
        absent or malformed (e.g. an empty expert name).
        """
        agent = self.agent_manager.extract_agent_from_text(text)

        # Extract expert if present
        expert = None
        start = text.find(self.EXPERT_OPEN)
        if start != -1:
            # Skip past the opening marker (was a magic "+ 9" offset)
            start += len(self.EXPERT_OPEN)
            end = text.find(self.EXPERT_CLOSE, start)
            if end > start:
                expert = text[start:end]

        return agent, expert
362
+
363
+ # Example usage and testing
364
# Example usage: run this module directly for a round-trip smoke test of
# token generation, tokenizer registration, formatting, and extraction.
if __name__ == "__main__":
    # Configure logging
    logging.basicConfig(level=logging.INFO)

    # Example configuration
    config = AgentTokenConfig(
        agent_prefix="<|agent:",
        agent_suffix="|>",
        resize_embeddings=True
    )

    # Example agents
    agents = ["SWE", "SQE", "DevOps", "Architect", "Security"]

    # Create agent manager
    manager = AgentTokenManager(config)

    # Generate tokens
    tokens = manager.generate_agent_tokens(agents)
    print(f"Generated tokens: {tokens}")

    # Example tokenizer (downloads tokenizer files from the Hub on first use)
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct")

    # Add tokens to tokenizer (grows the vocab if tokens are new)
    updated_tokenizer, agent_tokens = manager.add_agent_tokens_to_tokenizer(tokenizer, agents)

    print(f"Updated tokenizer vocab size: {len(updated_tokenizer)}")
    print(f"Agent token IDs: {manager.token_ids}")

    # Test formatting
    test_text = "How do I implement a binary search?"
    formatted = manager.format_agent_prompt("SWE", test_text)
    print(f"Formatted prompt: {formatted}")

    # Test extraction (round-trips the agent name back out of the prompt)
    extracted_agent = manager.extract_agent_from_text(formatted)
    print(f"Extracted agent: {extracted_agent}")
src/multi_agent_training/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Multi-agent training package
src/multi_agent_training/core_framework.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Core Training Framework for MangoMAS Local
3
+
4
+ This module provides the foundation for specialized training modules,
5
+ allowing for modular training of different cognitive capabilities.
6
+ """
7
+
8
+ import logging
9
+ from abc import ABC, abstractmethod
10
+ from dataclasses import dataclass, field
11
+ from typing import Any, Dict, List, Optional
12
+
13
+ import torch
14
+ import yaml
15
+
16
+ from .lora_trainer import LoRADistillationTrainer
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
@dataclass
class TrainingModuleConfig:
    """Configuration for a specialized training module."""

    # Human-readable module name (used in log messages)
    name: str
    # Registry key selecting the implementation
    # ("reasoning", "memory", "ethics", "empathy", "curiosity")
    module_type: str
    # Disabled modules are skipped at manager initialization
    enabled: bool = True
    # Multiplier applied to this module's loss contribution
    loss_weight: float = 1.0
    # Optional per-module overrides; None defers to the global settings
    learning_rate: Optional[float] = None
    batch_size: Optional[int] = None
    data_path: Optional[str] = None
    # Free-form module-specific options
    module_config: Dict[str, Any] = field(default_factory=dict)
+
34
+
35
class SpecializedTrainingModule(ABC):
    """
    Abstract base class for specialized training modules.
    Each cognitive capability (reasoning, memory, etc.) should implement this interface.
    """

    def __init__(self, config: TrainingModuleConfig, tokenizer):
        """
        Initialize the specialized training module.

        Args:
            config: Module configuration
            tokenizer: Tokenizer for text processing
        """
        self.config = config
        self.tokenizer = tokenizer
        self.name = config.name
        self.enabled = config.enabled
        self.loss_weight = config.loss_weight
        # Device preference: CUDA, then Apple-silicon MPS, then CPU
        self.device = torch.device(
            "cuda"
            if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available() else "cpu"
        )

        logger.info(f"Initialized {self.name} training module")
        logger.info(f"Module config: {self.config}")

    @abstractmethod
    def prepare_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Prepare a batch of data for this specific training module.

        Args:
            batch: The input batch from the dataloader

        Returns:
            Processed batch ready for the module
        """
        pass

    @abstractmethod
    def compute_loss(
        self, student_outputs: Any, teacher_outputs: Any, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute the specialized loss for this module.

        The returned loss is expected to be scaled by self.loss_weight by
        the orchestrating trainer, not by the module itself.

        Args:
            student_outputs: Outputs from the student model
            teacher_outputs: Outputs from the teacher model
            batch: The processed input batch

        Returns:
            Loss tensor for this module
        """
        pass

    @abstractmethod
    def get_metrics(self) -> Dict[str, float]:
        """
        Get metrics specific to this training module.

        Returns:
            Dictionary of metric names and values
        """
        pass
+ pass
102
+
103
+
104
class ModularTrainingManager:
    """
    Training manager that orchestrates multiple specialized training modules.
    """

    def __init__(self, config_path: str):
        """
        Initialize the modular training manager.

        Args:
            config_path: Path to the training configuration file (YAML)
        """
        with open(config_path, "r") as f:
            self.config = yaml.safe_load(f)

        # Set up core components shared by every specialized module
        self.base_trainer = LoRADistillationTrainer(config_path)
        self.tokenizer = self.base_trainer.tokenizer
        self.student_model = self.base_trainer.student_model
        self.teacher_model = self.base_trainer.teacher_manager.model

        # Initialize modules
        self.modules = self._initialize_modules()

        logger.info(
            f"Initialized ModularTrainingManager with {len(self.modules)} modules"
        )

    def _initialize_modules(self) -> List[SpecializedTrainingModule]:
        """
        Initialize all specialized training modules based on configuration.

        Modules flagged enabled=False are skipped; a module that fails to
        load is logged and dropped rather than aborting startup.

        Returns:
            List of initialized training modules
        """
        modules = []
        module_configs = self.config.get("specialized_modules", [])

        for module_config in module_configs:
            if not module_config.get("enabled", True):
                logger.info(f"Skipping disabled module: {module_config.get('name')}")
                continue

            try:
                # Convert the raw dict to a typed config object
                config_obj = TrainingModuleConfig(**module_config)

                # Import the module class dynamically from its type key
                module_type = config_obj.module_type
                module_class = self._import_module_class(module_type)

                # Initialize the module
                module = module_class(config_obj, self.tokenizer)
                modules.append(module)

                logger.info(f"Successfully loaded module: {config_obj.name}")
            except Exception as e:
                logger.error(
                    f"Failed to load module {module_config.get('name')}: {str(e)}"
                )

        return modules

    def _import_module_class(self, module_type: str) -> type:
        """
        Dynamically import a module class based on its type.

        Imports are deferred inside each branch so that only the modules
        actually configured (and their dependencies) get imported.

        Args:
            module_type: The module type identifier

        Returns:
            The module class

        Raises:
            ValueError: If module_type is not one of the known identifiers.
        """
        if module_type == "reasoning":
            from .specialized.reasoning_module import ReasoningTrainingModule

            return ReasoningTrainingModule
        elif module_type == "memory":
            from .specialized.memory_module import MemoryTrainingModule

            return MemoryTrainingModule
        elif module_type == "ethics":
            from .specialized.ethics_module import EthicsTrainingModule

            return EthicsTrainingModule
        elif module_type == "empathy":
            from .specialized.empathy_module import EmpathyTrainingModule

            return EmpathyTrainingModule
        elif module_type == "curiosity":
            from .specialized.curiosity_module import CuriosityTrainingModule

            return CuriosityTrainingModule
        else:
            raise ValueError(f"Unknown module type: {module_type}")

    def train(self, agent_type: Optional[str] = None) -> Dict[str, Any]:
        """
        Train the model using all enabled specialized modules.

        Args:
            agent_type: Optional agent type for specialized training

        Returns:
            Training metrics and results
        """
        # Delegate to base trainer for core training functionality
        # but integrate specialized module losses
        logger.info(f"Starting modular training for agent: {agent_type or 'all'}")

        # TODO: Implement the full training loop integrating all modules

        # This is a placeholder until we implement the full integration
        return self.base_trainer.train(agent_type)

    def evaluate(self, agent_type: Optional[str] = None) -> Dict[str, Any]:
        """
        Evaluate the model using all enabled specialized modules.

        Args:
            agent_type: Optional agent type for specialized evaluation

        Returns:
            Evaluation metrics and results
        """
        # TODO: Implement evaluation using specialized modules

        # This is a placeholder until we implement the full integration
        return self.base_trainer.evaluate(agent_type)
+ return self.base_trainer.evaluate(agent_type)
src/multi_agent_training/data_processor.py ADDED
@@ -0,0 +1,582 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data Processor for MangoMAS Local Training
3
+
4
+ This module processes the original MangoMAS datasets from JSONL format
5
+ into training-ready datasets with proper splits and preprocessing.
6
+ """
7
+
8
+ import argparse
9
+ import json
10
+ import logging
11
+ from pathlib import Path
12
+ from typing import Dict, List, Tuple
13
+
14
+ import yaml
15
+ from sklearn.model_selection import train_test_split
16
+ from tqdm import tqdm
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class MangoMASDataProcessor:
    """Process MangoMAS agent datasets (JSONL) into training-ready splits.

    Two initialization modes are supported:

    * **Direct** -- pass ``input_dir``/``output_dir`` and the length limits
      explicitly (used by the unit tests). No YAML config is loaded.
    * **Config-driven** -- pass ``config_path`` to a YAML file with ``data``
      and ``agents`` sections; preprocessing limits come from the config.

    All config-dependent methods fall back to sensible defaults when no
    config was loaded, so both modes are safe to use.
    """

    # Fallback agent list used when no ``agents`` config section is
    # available; mirrors the dataset mapping in process_datasets().
    DEFAULT_AGENT_TYPES = ("infrastructure", "devsecops", "risk_assessment")

    def __init__(
        self,
        input_dir,
        output_dir=None,
        min_length: int = 10,
        max_length: int = 2048,
        config_path: str = None,
    ):
        """Initialize with input/output directories or a config path.

        Args:
            input_dir: Directory containing source JSONL files (may be None
                in config-driven mode).
            output_dir: Destination for processed datasets; defaults to
                /Volumes/Mango_MAS/data/processed.
            min_length: Minimum accepted text length in characters.
            max_length: Maximum accepted text length in characters.
            config_path: Optional YAML config; when given, limits are read
                from its ``data.preprocessing`` section.
        """
        if config_path is not None:
            # Config-driven initialization (original functionality).
            with open(config_path, "r") as f:
                self.config = yaml.safe_load(f)
            self.data_config = self.config["data"]
            self.agents_config = self.config["agents"]
            self.input_dir = Path(input_dir) if input_dir else None
            self.output_dir = (
                Path(output_dir)
                if output_dir
                else Path("/Volumes/Mango_MAS/data/processed")
            )
            preprocessing = self.data_config.get("preprocessing", {})
            self.min_length = preprocessing.get("min_length", min_length)
            self.max_length = preprocessing.get("max_length", max_length)
        else:
            # Direct initialization (test interface).
            self.input_dir = Path(input_dir)
            self.output_dir = (
                Path(output_dir)
                if output_dir
                else Path("/Volumes/Mango_MAS/data/processed")
            )
            self.min_length = min_length
            self.max_length = max_length
            self.config = None
            self.data_config = None
            self.agents_config = None

        logging.basicConfig(level=logging.INFO)

    def _agent_types(self) -> List[str]:
        """Return configured agent types, or the defaults when config-less."""
        if self.agents_config:
            return list(self.agents_config.keys())
        return list(self.DEFAULT_AGENT_TYPES)

    def process_datasets(
        self, input_dir: str, output_dir: str = "/Volumes/Mango_MAS/data/processed"
    ) -> None:
        """
        Process all agent datasets from input directory.

        Args:
            input_dir: Directory containing original JSONL files
            output_dir: Directory to save processed datasets
        """
        input_path = Path(input_dir)
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Define dataset mappings (agent type -> source JSONL file).
        datasets = {
            "infrastructure": input_path
            / "infrastructure_agent_synthetic_prompts.jsonl",
            "devsecops": input_path / "devsecops_agent_synthetic_prompts.jsonl",
            "risk_assessment": input_path
            / "risk_assessment_agent_synthetic_prompts.jsonl",
        }

        for agent_type, file_path in datasets.items():
            if file_path.exists():
                logger.info(f"Processing {agent_type} dataset from {file_path}")
                self._process_single_dataset(file_path, output_path, agent_type)
            else:
                logger.warning(f"Dataset file not found: {file_path}")

    def _process_single_dataset(
        self, input_file: Path, output_dir: Path, agent_type: str
    ) -> None:
        """Load, clean, format, split, and save one agent's dataset."""
        data = self._load_jsonl(input_file)
        logger.info(f"Loaded {len(data)} samples for {agent_type}")

        cleaned_data = self._clean_data(data)
        logger.info(f"After cleaning: {len(cleaned_data)} samples")

        training_data = self._convert_to_training_format(cleaned_data, agent_type)

        train_data, val_data, test_data = self._create_splits(training_data)

        self._save_datasets(train_data, val_data, test_data, output_dir, agent_type)

        logger.info(
            f"Saved {agent_type} dataset: "
            f"{len(train_data)} train, {len(val_data)} val, {len(test_data)} test"
        )

    def _load_jsonl(self, file_path: Path) -> List[Dict]:
        """Load data from a JSONL file, skipping malformed lines."""
        data = []
        with open(file_path, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, 1):
                try:
                    data.append(json.loads(line.strip()))
                except json.JSONDecodeError as e:
                    logger.warning(f"Skipping invalid JSON on line {line_num}: {e}")
        return data

    def _clean_data(self, data: List[Dict]) -> List[Dict]:
        """Drop malformed or out-of-range samples, then deduplicate.

        Uses the length limits resolved in __init__ so this works in both
        direct and config-driven modes (previously it indexed
        ``self.data_config`` directly and crashed in direct mode).
        """
        cleaned = []

        for item in tqdm(data, desc="Cleaning data"):
            # Check required fields
            if not all(key in item for key in ["instruction", "input", "output"]):
                continue

            # Check text lengths against the resolved limits.
            input_text = f"{item['instruction']} {item['input']}"
            output_text = item["output"]

            if len(input_text) < self.min_length or len(input_text) > self.max_length:
                continue

            if (
                len(output_text) < self.min_length
                or len(output_text) > self.max_length
            ):
                continue

            cleaned.append(item)

        # Remove duplicates if configured (default: yes when no config).
        remove_dupes = True
        if self.data_config is not None:
            remove_dupes = self.data_config.get("preprocessing", {}).get(
                "remove_duplicates", True
            )
        if remove_dupes:
            cleaned = self._remove_duplicates(cleaned)

        return cleaned

    def _remove_duplicates(self, data: List[Dict]) -> List[Dict]:
        """Remove duplicate entries based on input text."""
        seen_inputs = set()
        unique_data = []

        for item in data:
            input_text = f"{item['instruction']} {item['input']}"
            if input_text not in seen_inputs:
                seen_inputs.add(input_text)
                unique_data.append(item)

        logger.info(f"Removed {len(data) - len(unique_data)} duplicates")
        return unique_data

    def _validate_sample(self, sample: Dict) -> bool:
        """Validate a sample for required fields and length constraints."""
        required_fields = ["instruction", "input", "output", "agent_type"]
        if not all(key in sample for key in required_fields):
            return False

        # The combined conversation text must fit the configured bounds.
        combined_text = f"{sample['instruction']} {sample['input']} {sample['output']}"
        if len(combined_text) < self.min_length or len(combined_text) > self.max_length:
            return False

        return True

    def _clean_text(self, text: str) -> str:
        """Normalize whitespace: collapse runs to single spaces, strip ends."""
        import re

        cleaned = re.sub(r"\s+", " ", text.strip())
        return cleaned

    def _format_conversation(self, sample: Dict) -> Dict:
        """Format a sample into Human/Assistant conversation text for training."""
        if sample.get("input", "").strip():
            conversation_text = f"Human: {sample['instruction']}\n{sample['input']}\n\nAssistant: {sample['output']}"
        else:
            conversation_text = (
                f"Human: {sample['instruction']}\n\nAssistant: {sample['output']}"
            )

        return {
            "text": conversation_text,
            "agent_type": sample["agent_type"],
            "instruction": sample["instruction"],
            "input": sample["input"],
            "output": sample["output"],
        }

    def _split_dataset(
        self,
        data: List[Dict],
        train_ratio: float = 0.8,
        val_ratio: float = 0.1,
        test_ratio: float = 0.1,
    ) -> Tuple[List[Dict], List[Dict], List[Dict]]:
        """Split dataset into train/validation/test sets.

        Raises:
            ValueError: if the three ratios do not sum to 1.0.
        """
        if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
            raise ValueError(
                f"Split ratios must sum to 1.0, got {train_ratio + val_ratio + test_ratio}"
            )

        if not data:
            return [], [], []

        # Use sklearn for consistent, seeded splitting.
        from sklearn.model_selection import train_test_split

        # A single sample cannot be split; it all goes to train.
        if len(data) == 1:
            return data, [], []

        train_data, temp_data = train_test_split(
            data, test_size=(val_ratio + test_ratio), random_state=42, shuffle=True
        )

        # Second split: val vs test (only when both slices are requested).
        if temp_data and val_ratio > 0 and test_ratio > 0:
            val_ratio_normalized = val_ratio / (val_ratio + test_ratio)
            val_data, test_data = train_test_split(
                temp_data,
                test_size=(1 - val_ratio_normalized),
                random_state=42,
                shuffle=True,
            )
        elif val_ratio > 0:
            val_data, test_data = temp_data, []
        else:
            val_data, test_data = [], temp_data

        return train_data, val_data, test_data

    def _calculate_stats(self, data: List[Dict]) -> Dict:
        """Return sample-count, length, and agent-distribution statistics."""
        if not data:
            return {
                "total_samples": 0,
                "avg_length": 0,
                "min_length": 0,
                "max_length": 0,
                "agent_distribution": {},
            }

        lengths = [len(item.get("text", "")) for item in data]
        agent_counts = {}

        for item in data:
            agent = item.get("agent_type", "unknown")
            agent_counts[agent] = agent_counts.get(agent, 0) + 1

        return {
            "total_samples": len(data),
            "avg_length": sum(lengths) / len(lengths),
            "min_length": min(lengths),
            "max_length": max(lengths),
            "agent_distribution": agent_counts,
        }

    def _load_agent_data(self, agent_type: str) -> List[Dict]:
        """Load and concatenate all JSONL files matching an agent type."""
        if not self.input_dir:
            return []

        # Look for files matching the agent type. We intentionally call glob
        # even if the directory may not exist in test environments, since
        # tests patch pathlib.Path.glob.
        pattern = f"*{agent_type}*.jsonl"
        matching_files = list(self.input_dir.glob(pattern))

        data = []
        for file_path in matching_files:
            file_data = self._load_jsonl(file_path)
            data.extend(file_data)

        return data

    def _save_jsonl(self, data: List[Dict], output_path: Path) -> None:
        """Save data to a JSONL file, creating parent directories as needed."""
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            for item in data:
                f.write(json.dumps(item, ensure_ascii=False) + "\n")

    def _save_stats(self, stats: Dict, output_path: Path) -> None:
        """Save statistics to a JSON file."""
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(stats, f, indent=2, ensure_ascii=False)

    def process_agent(
        self,
        agent_type: str,
        train_ratio: float = 0.8,
        val_ratio: float = 0.1,
        test_ratio: float = 0.1,
    ) -> None:
        """Process data for a single agent type end to end.

        Raises:
            ValueError: if no data (or no valid data) exists for the agent.
        """
        data = self._load_agent_data(agent_type)

        if not data:
            raise ValueError(f"No valid data found for agent type: {agent_type}")

        # Validate and clean data
        valid_data = []
        for sample in data:
            if self._validate_sample(sample):
                formatted = self._format_conversation(sample)
                valid_data.append(formatted)

        if not valid_data:
            raise ValueError(
                f"No valid data found after processing for agent type: {agent_type}"
            )

        unique_data = self._remove_duplicates(valid_data)

        train_data, val_data, test_data = self._split_dataset(
            unique_data, train_ratio, val_ratio, test_ratio
        )

        # Save datasets
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self._save_jsonl(train_data, self.output_dir / f"{agent_type}_train.jsonl")
        self._save_jsonl(val_data, self.output_dir / f"{agent_type}_val.jsonl")
        self._save_jsonl(test_data, self.output_dir / f"{agent_type}_test.jsonl")

        # Save statistics
        stats = self._calculate_stats(unique_data)
        self._save_stats(stats, self.output_dir / f"{agent_type}_stats.json")

        logger.info(
            f"Processed {agent_type}: {len(train_data)} train, {len(val_data)} val, {len(test_data)} test samples"
        )

    def _convert_to_training_format(
        self, data: List[Dict], agent_type: str
    ) -> List[Dict]:
        """Convert raw samples to chat-message format for language modeling."""
        training_data = []

        for item in data:
            conversation = {
                "messages": [
                    {
                        "role": "system",
                        "content": f"You are a {agent_type.replace('_', ' ')} specialist. "
                        f"Provide expert recommendations and analysis.",
                    },
                    {
                        "role": "user",
                        "content": f"{item['instruction']}\n\n{item['input']}",
                    },
                    {"role": "assistant", "content": item["output"]},
                ],
                "metadata": item.get("metadata", {}),
                "agent_type": agent_type,
            }
            training_data.append(conversation)

        return training_data

    def _create_splits(
        self, data: List[Dict]
    ) -> Tuple[List[Dict], List[Dict], List[Dict]]:
        """Create train/validation/test splits using config ratios.

        Falls back to 0.8/0.1/0.1 when no config was loaded (previously this
        indexed ``self.data_config`` directly and crashed in direct mode).
        """
        if self.data_config is not None:
            train_size = self.data_config["train_split"]
            val_size = self.data_config["validation_split"]
            test_size = self.data_config["test_split"]
        else:
            train_size, val_size, test_size = 0.8, 0.1, 0.1

        # Normalize splits to sum to 1
        total = train_size + val_size + test_size
        train_size /= total
        val_size /= total
        test_size /= total

        # First split: train vs (val + test)
        train_data, temp_data = train_test_split(
            data, test_size=(val_size + test_size), random_state=42, shuffle=True
        )

        # Second split: val vs test
        val_ratio = val_size / (val_size + test_size)
        val_data, test_data = train_test_split(
            temp_data, test_size=(1 - val_ratio), random_state=42, shuffle=True
        )

        return train_data, val_data, test_data

    def _save_datasets(
        self,
        train_data: List[Dict],
        val_data: List[Dict],
        test_data: List[Dict],
        output_dir: Path,
        agent_type: str,
    ) -> None:
        """Save train/validation/test datasets as per-split JSONL files."""
        datasets = {"train": train_data, "validation": val_data, "test": test_data}

        for split_name, split_data in datasets.items():
            output_file = output_dir / f"{agent_type}_{split_name}.jsonl"

            with open(output_file, "w", encoding="utf-8") as f:
                for item in split_data:
                    f.write(json.dumps(item, ensure_ascii=False) + "\n")

            logger.info(f"Saved {len(split_data)} samples to {output_file}")

    def create_combined_dataset(self, output_dir: str = "/Volumes/Mango_MAS/data/processed") -> None:
        """Create a combined multi-agent dataset for multi-task training."""
        output_path = Path(output_dir)

        # Collect all processed per-agent splits.
        all_train_data = []
        all_val_data = []
        all_test_data = []

        # Use configured agents, or the defaults when running config-less
        # (previously this crashed on self.agents_config being None).
        for agent_type in self._agent_types():
            for split in ["train", "validation", "test"]:
                file_path = output_path / f"{agent_type}_{split}.jsonl"
                if file_path.exists():
                    data = self._load_jsonl(file_path)

                    if split == "train":
                        all_train_data.extend(data)
                    elif split == "validation":
                        all_val_data.extend(data)
                    else:
                        all_test_data.extend(data)

        # Shuffle combined datasets deterministically.
        import random

        random.seed(42)
        random.shuffle(all_train_data)
        random.shuffle(all_val_data)
        random.shuffle(all_test_data)

        combined_datasets = {
            "train": all_train_data,
            "validation": all_val_data,
            "test": all_test_data,
        }

        for split_name, split_data in combined_datasets.items():
            output_file = output_path / f"combined_{split_name}.jsonl"

            with open(output_file, "w", encoding="utf-8") as f:
                for item in split_data:
                    f.write(json.dumps(item, ensure_ascii=False) + "\n")

            logger.info(
                f"Saved combined {split_name} dataset: {len(split_data)} samples"
            )

    def generate_statistics(self, output_dir: str = "/Volumes/Mango_MAS/data/processed") -> Dict:
        """Generate and save statistics about the processed datasets."""
        output_path = Path(output_dir)
        stats = {}

        # Same agents_config fallback as create_combined_dataset.
        for agent_type in self._agent_types() + ["combined"]:
            agent_stats = {}

            for split in ["train", "validation", "test"]:
                file_path = output_path / f"{agent_type}_{split}.jsonl"

                if file_path.exists():
                    data = self._load_jsonl(file_path)

                    # Total text length across all messages of each sample.
                    lengths = []
                    for item in data:
                        if "messages" in item:
                            total_length = sum(
                                len(msg["content"]) for msg in item["messages"]
                            )
                            lengths.append(total_length)

                    agent_stats[split] = {
                        "count": len(data),
                        "avg_length": sum(lengths) / len(lengths) if lengths else 0,
                        "min_length": min(lengths) if lengths else 0,
                        "max_length": max(lengths) if lengths else 0,
                    }

            stats[agent_type] = agent_stats

        # Save statistics
        stats_file = output_path / "dataset_statistics.json"
        with open(stats_file, "w") as f:
            json.dump(stats, f, indent=2)

        logger.info(f"Generated dataset statistics: {stats_file}")
        return stats
525
+
526
+
527
def main():
    """CLI entry point: process datasets, optionally combine, print stats."""
    parser = argparse.ArgumentParser(
        description="Process MangoMAS datasets for local training"
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        default="/Users/iancruickshank/Documents/Model/mangomas-datasets/agents/",
        help="Directory containing original JSONL files",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="/Volumes/Mango_MAS/data/processed",
        help="Directory to save processed datasets",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="config/training/distillation.yaml",
        help="Path to configuration file",
    )
    parser.add_argument(
        "--create_combined",
        action="store_true",
        help="Create combined multi-agent dataset",
    )

    args = parser.parse_args()

    # Initialize processor.
    # BUG FIX: the config path was previously passed positionally, where it
    # was consumed as input_dir -- the processor was then built WITHOUT its
    # config and crashed later on self.data_config being None. Pass the
    # directories and the config path to their proper parameters.
    processor = MangoMASDataProcessor(
        args.input_dir, args.output_dir, config_path=args.config
    )

    # Process datasets
    processor.process_datasets(args.input_dir, args.output_dir)

    # Create combined dataset if requested
    if args.create_combined:
        processor.create_combined_dataset(args.output_dir)

    # Generate statistics
    stats = processor.generate_statistics(args.output_dir)

    print("\nDataset Statistics:")
    print("=" * 50)
    for agent_type, agent_stats in stats.items():
        print(f"\n{agent_type.upper()}:")
        for split, split_stats in agent_stats.items():
            print(
                f"  {split}: {split_stats['count']} samples, "
                f"avg length: {split_stats['avg_length']:.0f} chars"
            )


if __name__ == "__main__":
    main()
src/multi_agent_training/distillation_loss.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Knowledge Distillation Loss Implementation for MangoMAS Local
3
+
4
+ This module implements custom loss functions for knowledge distillation,
5
+ balancing task-specific loss with knowledge transfer from teacher models.
6
+ """
7
+
8
+ import logging
9
+ from typing import Optional, Tuple
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class DistillationLoss:
    """
    Custom loss function for knowledge distillation combining task loss
    and distillation loss with configurable temperature and alpha parameters.
    """

    def __init__(self, alpha: float = 0.5, temperature: float = 3.0):
        """
        Initialize distillation loss with parameters.

        Args:
            alpha: Balance between task loss and distillation loss (0.0-1.0);
                alpha is the weight of the DISTILLATION term.
            temperature: Temperature for softmax in knowledge distillation

        Raises:
            ValueError: if alpha is outside [0, 1] or temperature <= 0.
        """
        self.alpha = alpha
        self.temperature = temperature

        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"Alpha must be between 0 and 1, got {alpha}")
        if temperature <= 0.0:
            raise ValueError(f"Temperature must be positive, got {temperature}")

        # Add required loss attributes for tests
        import torch.nn as nn

        self.task_loss = nn.CrossEntropyLoss(ignore_index=-100)
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")

        logger.info(
            f"Initialized DistillationLoss with alpha={alpha}, temperature={temperature}"
        )

    def compute_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, dict]:
        """
        Compute the combined distillation loss.

        Args:
            student_logits: Logits from student model [batch_size, seq_len, vocab_size]
            teacher_logits: Logits from teacher model [batch_size, seq_len, vocab_size]
            labels: Target labels [batch_size, seq_len]
            attention_mask: Attention mask for padding tokens [batch_size, seq_len]

        Returns:
            Tuple of (total_loss, loss_dict) where loss_dict contains individual losses
        """
        # Task-specific loss (standard cross-entropy)
        task_loss = self._compute_task_loss(student_logits, labels, attention_mask)

        # Knowledge distillation loss
        distill_loss = self._compute_distillation_loss(
            student_logits, teacher_logits, attention_mask
        )

        # Combined loss
        # NOTE: alpha is treated as the weight for the distillation loss to match
        # unit-test expectations (alpha=1.0 => pure distillation loss).
        total_loss = self.alpha * distill_loss + (1.0 - self.alpha) * task_loss

        loss_dict = {
            "total_loss": total_loss.item(),
            "task_loss": task_loss.item(),
            "distillation_loss": distill_loss.item(),
            "alpha": self.alpha,
            "temperature": self.temperature,
        }

        return total_loss, loss_dict

    def _compute_task_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute standard next-token cross-entropy task loss.

        Logits/labels are shifted by one position (causal LM convention);
        padded positions are dropped when attention_mask is given.
        """
        # Shift for next-token prediction: predict token t+1 from position t.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Flatten the tokens to [batch_size * (seq_len - 1), vocab_size].
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_labels = shift_labels.view(-1)

        # Apply attention mask if provided
        if attention_mask is not None:
            shift_mask = attention_mask[..., 1:].contiguous()
            flat_mask = shift_mask.view(-1)
            # Only compute loss for non-padded tokens
            valid_indices = flat_mask.bool()
            if valid_indices.sum() == 0:
                # Every position is padding: nothing to learn from.
                return torch.tensor(0.0, device=logits.device)
            flat_logits = flat_logits[valid_indices]
            flat_labels = flat_labels[valid_indices]

        # Compute cross entropy loss, ignoring padding tokens (-100 labels)
        task_loss = F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)

        return task_loss

    def _compute_distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the temperature-scaled KL divergence to the teacher.

        Both distributions are shifted by one position to line up with the
        task loss; the result is scaled by T^2 (standard Hinton correction
        so gradient magnitudes are independent of temperature).
        """
        # Apply temperature scaling (removed unused `eps` local variable).
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=-1)
        # Optionally mask out padding tokens
        if attention_mask is not None:
            mask = attention_mask[..., 1:].contiguous().view(-1)
            student_log_probs = (
                student_log_probs[..., :-1, :]
                .contiguous()
                .view(-1, student_log_probs.size(-1))[mask.bool()]
            )
            teacher_log_probs = (
                teacher_log_probs[..., :-1, :]
                .contiguous()
                .view(-1, teacher_log_probs.size(-1))[mask.bool()]
            )
        else:
            student_log_probs = (
                student_log_probs[..., :-1, :]
                .contiguous()
                .view(-1, student_log_probs.size(-1))
            )
            teacher_log_probs = (
                teacher_log_probs[..., :-1, :]
                .contiguous()
                .view(-1, teacher_log_probs.size(-1))
            )
        # KLDivLoss expects log-probabilities for both input and target if log_target=True
        if student_log_probs.shape[0] == 0:
            # All positions masked out: no distillation signal.
            return torch.tensor(0.0, device=student_logits.device)
        distill_loss = F.kl_div(
            student_log_probs, teacher_log_probs, reduction="batchmean", log_target=True
        ) * (self.temperature**2)
        return distill_loss

    def update_alpha(self, new_alpha: float) -> None:
        """Update the alpha parameter during training.

        Raises:
            ValueError: if new_alpha is outside [0, 1].
        """
        if not 0.0 <= new_alpha <= 1.0:
            raise ValueError(f"Alpha must be between 0.0 and 1.0, got {new_alpha}")
        self.alpha = new_alpha
        logger.info(f"Updated alpha to {new_alpha}")

    def update_temperature(self, new_temperature: float) -> None:
        """Update the temperature parameter during training.

        Raises:
            ValueError: if new_temperature is not strictly positive.
        """
        if new_temperature <= 0.0:
            raise ValueError(f"Temperature must be positive, got {new_temperature}")
        self.temperature = new_temperature
        logger.info(f"Updated temperature to {new_temperature}")

    def __call__(self, student_logits, teacher_logits, labels, attention_mask=None):
        """Convenience wrapper returning only the combined scalar loss."""
        total_loss, _ = self.compute_loss(
            student_logits, teacher_logits, labels, attention_mask
        )
        return total_loss
183
+
184
+
185
class AdaptiveDistillationLoss(DistillationLoss):
    """
    Adaptive distillation loss that adjusts alpha based on training progress.
    Starts with more focus on distillation, gradually shifting to task loss.

    NOTE: update_alpha() intentionally overrides the parent with an
    (epoch, total_epochs) signature; callers schedule it per epoch.
    """

    def __init__(
        self,
        initial_alpha: float = 0.5,
        final_alpha: float = 0.1,
        temperature: float = 3.0,
        warmup_steps: int = 1000,
    ):
        # Validate both alpha endpoints before delegating to the base class.
        if not 0.0 <= initial_alpha <= 1.0:
            raise ValueError("initial_alpha must be between 0 and 1")
        if not 0.0 <= final_alpha <= 1.0:
            raise ValueError("final_alpha must be between 0 and 1")
        # The schedule only ever decays, so the start must dominate the end.
        if initial_alpha < final_alpha:
            raise ValueError(
                f"initial_alpha must be >= final_alpha, got {initial_alpha} < {final_alpha}"
            )
        super().__init__(alpha=initial_alpha, temperature=temperature)
        self.initial_alpha = initial_alpha
        self.final_alpha = final_alpha
        self.current_alpha = initial_alpha
        self.warmup_steps = warmup_steps
        self.current_step = 0

    def update_alpha(self, current_epoch: int, total_epochs: int):
        """Linearly decay alpha from initial to final over the epochs.

        Raises:
            ValueError: if total_epochs <= 0 or current_epoch < 0.
        """
        if total_epochs <= 0:
            raise ValueError("total_epochs must be positive")
        if current_epoch < 0:
            raise ValueError("current_epoch must be non-negative")

        if current_epoch <= 0:
            updated = self.initial_alpha
        elif current_epoch >= total_epochs:
            updated = self.final_alpha
        else:
            fraction = current_epoch / total_epochs
            span = self.initial_alpha - self.final_alpha
            updated = self.initial_alpha - span * fraction

        self.current_alpha = updated
        self.alpha = updated

    def get_alpha(self):
        """Return the current alpha value."""
        return self.current_alpha
src/multi_agent_training/gradient_descent/__init__.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradient Descent and Backpropagation Training System
3
+ ==================================================
4
+
5
+ This module provides a comprehensive implementation of gradient descent optimization
6
+ algorithms and backpropagation for neural network training, specifically designed
7
+ for the MangoMAS multi-agent system.
8
+
9
+ Key Components:
10
+ - Optimizers: SGD, Adam, AdamW with proper mathematical implementations
11
+ - Backpropagation: Chain rule-based gradient computation
12
+ - Training Loop: Complete training orchestration with monitoring
13
+ - Loss Functions: Various loss implementations for different tasks
14
+ - Monitoring: Comprehensive gradient and training metrics tracking
15
+
16
+ Usage:
17
+ from src.training.gradient_descent import GradientDescentTrainer
18
+ trainer = GradientDescentTrainer()
19
+ results = trainer.train_agent(agent_spec)
20
+ """
21
+
22
+ from .optimizers import SGD, Adam, AdamW, Optimizer
23
+ from .backpropagation import BackpropagationEngine
24
+ from .training_loop import GradientDescentTrainer
25
+ from .loss_functions import CrossEntropyLoss, KLDivergenceLoss, LossFunction
26
+ from .monitoring import GradientMonitor, TrainingMonitor
27
+ from .model_wrapper import ModelWrapper
28
+ from .schedulers import LinearScheduler, CosineScheduler, StepScheduler
29
+
30
+ __version__ = "1.0.0"
31
+ __author__ = "MangoMAS Team"
32
+
33
+ __all__ = [
34
+ "SGD",
35
+ "Adam",
36
+ "AdamW",
37
+ "Optimizer",
38
+ "BackpropagationEngine",
39
+ "GradientDescentTrainer",
40
+ "CrossEntropyLoss",
41
+ "KLDivergenceLoss",
42
+ "LossFunction",
43
+ "GradientMonitor",
44
+ "TrainingMonitor",
45
+ "ModelWrapper",
46
+ "LinearScheduler",
47
+ "CosineScheduler",
48
+ "StepScheduler"
49
+ ]
src/multi_agent_training/gradient_descent/backpropagation.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Backpropagation Engine
3
+ =====================
4
+
5
+ This module implements the backpropagation algorithm for neural network training,
6
+ including support for transformer architectures and LoRA adapters.
7
+
8
+ The implementation follows the mathematical chain rule for gradient computation
9
+ and includes proper handling of various activation functions and layer types.
10
+ """
11
+
12
+ import logging
13
+ import math
14
+ from typing import Dict, List, Any
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class BackpropagationEngine:
    """
    Backpropagation engine for neural network training.

    Implements the chain rule for gradient computation:
        dL/dw = dL/dy * dy/dw

    Supports various architectures including transformers and LoRA adapters.

    Attributes:
        model: The wrapped ``nn.Module`` whose parameters are trained.
        device: Device the model is moved to (defaults to CPU).
        gradient_norms: History of total gradient norms, one entry per
            ``compute_gradients`` call.
        gradient_stats: Per-parameter statistics from the most recent
            ``compute_gradients`` call.
    """

    def __init__(self, model: nn.Module, device: torch.device = None):
        self.model = model
        self.device = device or torch.device('cpu')
        self.gradient_norms = []
        self.gradient_stats = {}

        # Move model to device
        self.model.to(self.device)

        logger.info(f"Initialized backpropagation engine on device: {self.device}")

    def compute_gradients(self, loss: torch.Tensor, retain_graph: bool = False) -> Dict[str, torch.Tensor]:
        """
        Compute gradients using backpropagation.

        Args:
            loss: The computed (scalar) loss tensor.
            retain_graph: Whether to retain the computational graph.

        Returns:
            Dictionary mapping parameter names to cloned gradient tensors.
        """
        # Zero existing gradients
        self.model.zero_grad()

        # Backward pass
        loss.backward(retain_graph=retain_graph)

        # Collect gradients (clones, so later zero_grad does not mutate them)
        gradients = {}
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                gradients[name] = param.grad.clone()

        # Compute gradient statistics
        self._compute_gradient_stats(gradients)

        return gradients

    def _compute_gradient_stats(self, gradients: Dict[str, torch.Tensor]):
        """Compute per-parameter and total gradient statistics for monitoring.

        Appends the total L2 norm to ``self.gradient_norms`` and refreshes
        ``self.gradient_stats``. NOTE(review): ``std`` is NaN for
        single-element gradients (torch's unbiased std) — consumers should
        tolerate that.
        """
        total_norm = 0.0
        param_count = 0

        for name, grad in gradients.items():
            if grad is not None:
                param_norm = grad.data.norm(2)
                total_norm += param_norm.item() ** 2
                param_count += 1

                # Store individual gradient stats
                self.gradient_stats[name] = {
                    'norm': param_norm.item(),
                    'mean': grad.data.mean().item(),
                    'std': grad.data.std().item(),
                    'max': grad.data.max().item(),
                    'min': grad.data.min().item()
                }

        # Compute total gradient norm (L2 over all parameters)
        total_norm = total_norm ** (1. / 2)
        self.gradient_norms.append(total_norm)

        logger.debug(f"Total gradient norm: {total_norm:.6f}")

    def apply_gradient_clipping(self, max_norm: float = 1.0, norm_type: float = 2.0) -> float:
        """
        Apply gradient clipping to prevent exploding gradients.

        Args:
            max_norm: Maximum gradient norm.
            norm_type: Type of norm to use (2.0 for L2 norm).

        Returns:
            Actual gradient norm before clipping.
        """
        # Compute total gradient norm across all parameters
        total_norm = 0.0
        for param in self.model.parameters():
            if param.grad is not None:
                param_norm = param.grad.data.norm(norm_type)
                total_norm += param_norm.item() ** norm_type

        total_norm = total_norm ** (1. / norm_type)

        # Scale gradients down in place if the norm exceeds the cap
        if total_norm > max_norm:
            clip_coef = max_norm / (total_norm + 1e-6)
            for param in self.model.parameters():
                if param.grad is not None:
                    param.grad.data.mul_(clip_coef)

            logger.debug(f"Gradient clipping applied: {total_norm:.6f} -> {max_norm:.6f}")

        return total_norm

    def compute_parameter_updates(self, gradients: Dict[str, torch.Tensor],
                                  learning_rate: float) -> Dict[str, torch.Tensor]:
        """
        Compute parameter updates from gradients (plain gradient descent).

        Args:
            gradients: Dictionary of parameter gradients.
            learning_rate: Learning rate for updates.

        Returns:
            Dictionary of parameter update tensors (``-lr * grad``).
        """
        updates = {}

        for name, param in self.model.named_parameters():
            if name in gradients and gradients[name] is not None:
                # Simple gradient descent update
                update = -learning_rate * gradients[name]
                updates[name] = update

        return updates

    def apply_parameter_updates(self, updates: Dict[str, torch.Tensor]):
        """
        Apply parameter updates to the model in place.

        Args:
            updates: Dictionary of parameter updates keyed by parameter name.
        """
        for name, param in self.model.named_parameters():
            if name in updates:
                param.data.add_(updates[name])

    def compute_hessian_diagonal(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Approximate second-order curvature information for each parameter.

        First-order gradients are computed with ``create_graph=True`` so they
        remain differentiable; a second backward pass then yields, for each
        parameter, the Hessian-vector product H @ 1 (the row sums of the
        Hessian), which has the same shape as the parameter. This is a cheap
        curvature proxy: the exact diagonal would require one backward pass
        per parameter element.

        BUG FIX: the previous implementation called ``loss.backward`` without
        ``create_graph`` and then differentiated ``param.grad`` (which is not
        part of any graph), so it always raised; it also called ``.diag()``,
        which only works on 2-D tensors.

        Args:
            loss: The computed scalar loss tensor, built from this engine's
                model parameters.

        Returns:
            Dictionary mapping parameter names to curvature tensors with the
            same shape as the parameter.
        """
        named = [(n, p) for n, p in self.model.named_parameters() if p.requires_grad]
        if not named:
            return {}
        names = [n for n, _ in named]
        params = [p for _, p in named]

        # First-order gradients, kept differentiable for the second pass.
        grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)

        hessian_diag: Dict[str, torch.Tensor] = {}
        for name, param, grad in zip(names, params, grads):
            if grad is None:
                continue
            second = torch.autograd.grad(
                grad.sum(), param, retain_graph=True, allow_unused=True
            )[0]
            if second is not None:
                hessian_diag[name] = second

        return hessian_diag

    def compute_gradient_flow(self) -> Dict[str, float]:
        """
        Compute gradient flow statistics for monitoring training health.

        Returns:
            Dictionary mapping parameter names to the ratio
            ``||grad|| / ||param||`` (parameters with zero norm are skipped).
        """
        flow_stats = {}

        for name, param in self.model.named_parameters():
            if param.grad is not None:
                # Gradient-to-parameter ratio
                grad_norm = param.grad.data.norm(2)
                param_norm = param.data.norm(2)

                if param_norm > 0:
                    flow_ratio = grad_norm / param_norm
                    flow_stats[name] = flow_ratio.item()

        return flow_stats

    def detect_gradient_anomalies(self) -> Dict[str, Any]:
        """
        Detect gradient anomalies that might indicate training issues.

        Thresholds (norm > 10 exploding, < 1e-6 vanishing, >1000x spread
        imbalance) are heuristics, not tuned constants.

        Returns:
            Dictionary of boolean anomaly flags.
        """
        anomalies = {
            'exploding_gradients': False,
            'vanishing_gradients': False,
            'gradient_imbalance': False,
            'nan_gradients': False
        }

        if not self.gradient_norms:
            return anomalies

        current_norm = self.gradient_norms[-1]

        # Check for exploding gradients
        if current_norm > 10.0:
            anomalies['exploding_gradients'] = True
            logger.warning(f"Exploding gradients detected: norm={current_norm:.6f}")

        # Check for vanishing gradients
        if current_norm < 1e-6:
            anomalies['vanishing_gradients'] = True
            logger.warning(f"Vanishing gradients detected: norm={current_norm:.6f}")

        # Check for NaN gradients
        for name, stats in self.gradient_stats.items():
            if math.isnan(stats['norm']) or math.isnan(stats['mean']):
                anomalies['nan_gradients'] = True
                logger.warning(f"NaN gradients detected in parameter: {name}")
                break

        # Check for gradient imbalance. BUG FIX: guard against division by
        # zero when one parameter's gradient norm is exactly zero (the
        # previous code crashed here); a zero norm alongside a non-zero one
        # is itself an infinite imbalance.
        if len(self.gradient_stats) > 1:
            norms = [stats['norm'] for stats in self.gradient_stats.values()]
            min_norm, max_norm = min(norms), max(norms)
            if max_norm > 0.0 and (min_norm == 0.0 or max_norm / min_norm > 1000):
                anomalies['gradient_imbalance'] = True
                logger.warning("Gradient imbalance detected between parameters")

        return anomalies

    def get_gradient_summary(self) -> Dict[str, Any]:
        """
        Get comprehensive gradient summary for monitoring.

        Returns:
            Dictionary with the current total norm, the last 10 norms,
            per-parameter stats, anomaly flags and flow ratios; empty when
            no gradients have been computed yet.
        """
        if not self.gradient_norms:
            return {}

        return {
            'total_norm': self.gradient_norms[-1],
            'norm_history': self.gradient_norms[-10:],  # Last 10 steps
            'parameter_stats': self.gradient_stats,
            'anomalies': self.detect_gradient_anomalies(),
            'flow_stats': self.compute_gradient_flow()
        }
+ }
270
+
271
+
272
class LoRABackpropagationEngine(BackpropagationEngine):
    """
    Backpropagation engine specialized for LoRA (Low-Rank Adaptation) models.

    Only the named LoRA adapter parameters receive gradients; every other
    parameter of the wrapped model is frozen at construction time.
    """

    def __init__(self, model: nn.Module, lora_params: List[str], device: torch.device = None):
        super().__init__(model, device)
        self.lora_params = lora_params

        # Freeze everything that is not listed as a LoRA adapter parameter.
        for param_name, param in self.model.named_parameters():
            if param_name not in self.lora_params:
                param.requires_grad = False

        logger.info(f"Initialized LoRA backpropagation engine with {len(lora_params)} LoRA parameters")

    def compute_gradients(self, loss: torch.Tensor, retain_graph: bool = False) -> Dict[str, torch.Tensor]:
        """
        Run a backward pass and collect gradients for the LoRA parameters only.

        Args:
            loss: The computed loss tensor.
            retain_graph: Whether to retain the computational graph.

        Returns:
            Dictionary mapping LoRA parameter names to cloned gradients.
        """
        self.model.zero_grad()
        loss.backward(retain_graph=retain_graph)

        gradients = {
            param_name: param.grad.clone()
            for param_name, param in self.model.named_parameters()
            if param_name in self.lora_params and param.grad is not None
        }

        # Refresh the shared monitoring statistics on the base engine.
        self._compute_gradient_stats(gradients)

        return gradients

    def get_lora_parameter_info(self) -> Dict[str, Any]:
        """
        Describe the LoRA parameters of the wrapped model.

        Returns:
            Dictionary mapping each LoRA parameter name to its shape,
            element count, trainability, device and dtype.
        """
        return {
            param_name: {
                'shape': list(param.shape),
                'numel': param.numel(),
                'requires_grad': param.requires_grad,
                'device': str(param.device),
                'dtype': str(param.dtype)
            }
            for param_name, param in self.model.named_parameters()
            if param_name in self.lora_params
        }
src/multi_agent_training/gradient_descent/loss_functions.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Loss Functions
3
+ =============
4
+
5
+ This module implements various loss functions for neural network training,
6
+ including cross-entropy, KL divergence, and custom loss functions for
7
+ the MangoMAS multi-agent system.
8
+ """
9
+
10
+ import logging
11
+ from abc import ABC, abstractmethod
12
+ from typing import Dict, Optional, Any
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
class LossFunction(ABC):
    """Abstract base class for all loss functions.

    Subclasses implement :meth:`forward`; instances are directly callable.
    """

    def __init__(self, reduction: str = 'mean'):
        # One of 'mean', 'sum' or 'none'; interpreted by each subclass.
        self.reduction = reduction

    @abstractmethod
    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the loss"""
        pass

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        """Call the loss function.

        Arguments are forwarded unchanged so that subclasses whose
        ``forward`` takes extra arguments (e.g. a contrastive loss that
        needs two embeddings plus a label) remain callable; two-argument
        callers are unaffected.
        """
        return self.forward(*args, **kwargs)


class CrossEntropyLoss(LossFunction):
    """
    Cross-entropy loss for classification tasks.

    Mathematical formulation:
        L = -sum(y_i * log(p_i))

    where y_i is the true label and p_i the predicted probability.
    """

    def __init__(self, reduction: str = 'mean', label_smoothing: float = 0.0,
                 weight: Optional[torch.Tensor] = None):
        super().__init__(reduction)
        self.label_smoothing = label_smoothing
        self.weight = weight

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute cross-entropy loss.

        Args:
            predictions: Model predictions (logits); shape (N, C) for
                multi-class, otherwise treated as binary logits.
            targets: Class indices of shape (N,), one-hot/soft labels of
                shape (N, C), or binary labels matching ``predictions``.

        Returns:
            Computed loss.
        """
        if predictions.dim() > 1 and predictions.size(1) > 1:
            # Multi-class classification
            if targets.dim() == 1:
                # Targets are class indices. BUG FIX: the previous version
                # passed label_smoothing to F.nll_loss, which does not
                # accept that argument (TypeError at call time);
                # F.cross_entropy supports it natively.
                loss = F.cross_entropy(
                    predictions, targets,
                    weight=self.weight,
                    reduction=self.reduction,
                    label_smoothing=self.label_smoothing,
                )
            else:
                # Targets are one-hot (or soft) label distributions.
                # NOTE(review): this branch ignores self.weight and
                # self.label_smoothing — confirm this is intended.
                log_probs = F.log_softmax(predictions, dim=1)
                loss = -(targets * log_probs).sum(dim=1)
                if self.reduction == 'mean':
                    loss = loss.mean()
                elif self.reduction == 'sum':
                    loss = loss.sum()
        else:
            # Binary classification from a single logit.
            # NOTE(review): BCE interprets `weight` as a per-element
            # rescaling weight, not a per-class weight — confirm callers
            # expect that.
            loss = F.binary_cross_entropy_with_logits(
                predictions, targets.float(),
                weight=self.weight, reduction=self.reduction,
            )

        return loss
84
+
85
+
86
class KLDivergenceLoss(LossFunction):
    """
    Kullback-Leibler divergence loss for distribution matching.

    Mathematical formulation:
        KL(P||Q) = sum P(x) * log(P(x)/Q(x))
    """

    def __init__(self, reduction: str = 'mean', log_target: bool = False):
        super().__init__(reduction)
        self.log_target = log_target

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute KL divergence loss.

        Args:
            predictions: Predicted distribution — raw logits when
                ``log_target`` is False, log probabilities otherwise.
            targets: Target distribution (probabilities, or log
                probabilities when ``log_target`` is True).

        Returns:
            Computed KL divergence loss.
        """
        if not self.log_target:
            # Targets are plain probabilities: move predictions into log
            # space before handing both to F.kl_div.
            return F.kl_div(
                F.log_softmax(predictions, dim=-1),
                targets,
                reduction=self.reduction,
                log_target=False,
            )

        # Both distributions already live in log space.
        return F.kl_div(predictions, targets, reduction=self.reduction, log_target=True)
118
+
119
+
120
class MSELoss(LossFunction):
    """
    Mean Squared Error loss for regression tasks.

    Mathematical formulation:
        L = (1/n) * sum((y_i - p_i)^2)
    """

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the (reduced) mean squared error.

        Args:
            predictions: Model predictions.
            targets: True values, same shape as ``predictions``.

        Returns:
            The MSE loss tensor, reduced per ``self.reduction``.
        """
        return F.mse_loss(predictions, targets, reduction=self.reduction)
141
+
142
+
143
class HuberLoss(LossFunction):
    """
    Huber loss (smooth L1 loss) for robust regression.

    Mathematical formulation:
        L = 0.5 * (y - p)^2               if |y - p| < delta
        L = delta * (|y - p| - 0.5*delta) otherwise
    """

    def __init__(self, reduction: str = 'mean', delta: float = 1.0):
        super().__init__(reduction)
        # Threshold where the loss switches from quadratic to linear.
        self.delta = delta

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the (reduced) Huber loss.

        Args:
            predictions: Model predictions.
            targets: True values, same shape as ``predictions``.

        Returns:
            The Huber loss tensor, reduced per ``self.reduction``.
        """
        return F.huber_loss(predictions, targets, reduction=self.reduction, delta=self.delta)
169
+
170
+
171
class FocalLoss(LossFunction):
    """
    Focal loss for addressing class imbalance.

    Mathematical formulation:
        FL = -alpha * (1 - p_t)^gamma * log(p_t)

    where p_t is the predicted probability of the true class.
    """

    def __init__(self, alpha: float = 1.0, gamma: float = 2.0, reduction: str = 'mean'):
        super().__init__(reduction)
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute focal loss.

        Args:
            predictions: Model predictions (logits), shape (N, C).
            targets: Class indices of shape (N,) or one-hot labels (N, C).

        Returns:
            Computed focal loss.
        """
        class_probs = F.softmax(predictions, dim=1)

        # Probability assigned to each sample's true class.
        if targets.dim() != 1:
            true_class_prob = (class_probs * targets).sum(dim=1)
        else:
            true_class_prob = class_probs.gather(1, targets.unsqueeze(1)).squeeze(1)

        # Down-weight easy examples; the epsilon guards log(0).
        modulating = self.alpha * (1 - true_class_prob) ** self.gamma
        per_sample = -modulating * torch.log(true_class_prob + 1e-8)

        if self.reduction == 'sum':
            return per_sample.sum()
        if self.reduction == 'mean':
            return per_sample.mean()
        return per_sample
218
+
219
+
220
class ContrastiveLoss(LossFunction):
    """
    Contrastive loss for learning representations.

    Mathematical formulation (as computed below):
        L = y * d^2 + (1 - y) * max(0, margin - d)^2

    where d is the Euclidean distance between paired embeddings and y is
    the similarity label (1 for similar pairs, 0 for dissimilar pairs).
    """

    def __init__(self, margin: float = 1.0, reduction: str = 'mean'):
        super().__init__(reduction)
        # Minimum distance dissimilar pairs are pushed towards.
        self.margin = margin

    def forward(self, embeddings1: torch.Tensor, embeddings2: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """
        Compute contrastive loss.

        Note: unlike the two-argument ``forward`` declared on the base
        class, this loss takes three arguments; call ``forward`` directly
        if the instance-call path rejects the extra argument.

        Args:
            embeddings1: First set of embeddings.
            embeddings2: Second set of embeddings.
            labels: Similarity labels (1 for similar, 0 for dissimilar).

        Returns:
            Computed contrastive loss.
        """
        pair_distance = F.pairwise_distance(embeddings1, embeddings2)
        similar = labels.float()

        # Similar pairs are pulled together; dissimilar pairs are pushed
        # apart until they are at least `margin` away.
        pull = similar * pair_distance.pow(2)
        push = (1 - similar) * F.relu(self.margin - pair_distance).pow(2)
        per_pair = pull + push

        if self.reduction == 'sum':
            return per_pair.sum()
        if self.reduction == 'mean':
            return per_pair.mean()
        return per_pair
262
+
263
+
264
class CustomAgentLoss(LossFunction):
    """
    Custom loss function for MangoMAS agents.

    Combines a classification task loss with an optional hidden-state
    consistency term and a (currently zero) regularization term.
    """

    def __init__(self, task_loss_weight: float = 1.0,
                 consistency_loss_weight: float = 0.1,
                 regularization_weight: float = 0.01):
        super().__init__('mean')
        self.task_loss_weight = task_loss_weight
        self.consistency_loss_weight = consistency_loss_weight
        self.regularization_weight = regularization_weight

        # Component losses combined in forward().
        self.task_loss = CrossEntropyLoss()
        self.consistency_loss = MSELoss()

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor,
                model_outputs: Optional[Dict[str, torch.Tensor]] = None) -> torch.Tensor:
        """
        Compute the combined agent loss.

        Args:
            predictions: Model predictions (logits).
            targets: True labels.
            model_outputs: Optional extra outputs; when a 'hidden_states'
                sequence is present, consecutive states are encouraged to
                agree via an MSE consistency term. Assumes all hidden
                states share one shape — TODO confirm against callers.

        Returns:
            Weighted sum of the loss components.
        """
        device = predictions.device

        task_term = self.task_loss(predictions, targets)

        # Average MSE between each pair of consecutive hidden states.
        consistency_term = torch.tensor(0.0, device=device)
        if model_outputs is not None and 'hidden_states' in model_outputs:
            states = model_outputs['hidden_states']
            if len(states) > 1:
                for earlier, later in zip(states[:-1], states[1:]):
                    consistency_term += self.consistency_loss(earlier, later)
                consistency_term /= (len(states) - 1)

        # Placeholder: an L2 penalty over model parameters would go here.
        regularization_term = torch.tensor(0.0, device=device)

        return (self.task_loss_weight * task_term
                + self.consistency_loss_weight * consistency_term
                + self.regularization_weight * regularization_term)
320
+
321
+
322
class LossFunctionFactory:
    """Factory for constructing :class:`LossFunction` instances by name."""

    @staticmethod
    def create_loss_function(loss_type: str, **kwargs) -> LossFunction:
        """Instantiate the loss registered under ``loss_type``.

        Args:
            loss_type: Case-insensitive registry key (e.g. 'mse', 'focal').
            **kwargs: Forwarded to the loss class constructor.

        Raises:
            ValueError: If ``loss_type`` is not a known loss name.
        """
        registry = {
            'cross_entropy': CrossEntropyLoss,
            'kl_divergence': KLDivergenceLoss,
            'mse': MSELoss,
            'huber': HuberLoss,
            'focal': FocalLoss,
            'contrastive': ContrastiveLoss,
            'custom_agent': CustomAgentLoss
        }

        key = loss_type.lower()
        if key not in registry:
            raise ValueError(f"Unknown loss function type: {loss_type}")

        return registry[key](**kwargs)

    @staticmethod
    def get_default_config(loss_type: str) -> Dict[str, Any]:
        """Return the default constructor kwargs for ``loss_type``.

        Unknown names yield an empty dict rather than raising.
        """
        defaults = {
            'cross_entropy': {
                'reduction': 'mean',
                'label_smoothing': 0.0
            },
            'kl_divergence': {
                'reduction': 'mean',
                'log_target': False
            },
            'mse': {
                'reduction': 'mean'
            },
            'huber': {
                'reduction': 'mean',
                'delta': 1.0
            },
            'focal': {
                'alpha': 1.0,
                'gamma': 2.0,
                'reduction': 'mean'
            },
            'contrastive': {
                'margin': 1.0,
                'reduction': 'mean'
            },
            'custom_agent': {
                'task_loss_weight': 1.0,
                'consistency_loss_weight': 0.1,
                'regularization_weight': 0.01
            }
        }

        return defaults.get(loss_type.lower(), {})
src/multi_agent_training/gradient_descent/model_wrapper.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model Wrapper
3
+ =============
4
+
5
+ This module provides a wrapper for neural network models to integrate
6
+ with the gradient descent training system, including support for LoRA
7
+ adapters and the MangoMAS agent system.
8
+ """
9
+
10
+ import logging
11
+ import torch
12
+ import torch.nn as nn
13
+ from typing import Dict, List, Optional, Any
14
+ from pathlib import Path
15
+ import json
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
class ModelWrapper:
    """
    Wrapper for neural network models to integrate with gradient descent training

    Provides a unified interface for different model types and handles
    LoRA adapter integration for the MangoMAS system.
    """

    def __init__(self, model: nn.Module, model_type: str = 'transformer',
                 lora_config: Optional[Dict[str, Any]] = None):
        # Wrapped model plus bookkeeping about which parameter names
        # belong to LoRA adapters (filled by _setup_lora).
        self.model = model
        self.model_type = model_type
        self.lora_config = lora_config or {}
        self.lora_params = []

        # Initialize LoRA if configured
        if lora_config:
            self._setup_lora()

        logger.info(f"Initialized ModelWrapper for {model_type} model")

    def _setup_lora(self):
        """Setup LoRA adapters for the model.

        Marks parameters whose name contains 'lora' or 'adapter' as
        trainable and freezes every other parameter.

        NOTE(review): if the model contains no such parameters, this
        freezes the entire model — confirm that is intended when a
        lora_config is passed with a plain model.
        """
        if not self.lora_config:
            return

        # Extract LoRA parameters
        for name, param in self.model.named_parameters():
            if 'lora' in name.lower() or 'adapter' in name.lower():
                self.lora_params.append(name)
                param.requires_grad = True
            else:
                param.requires_grad = False

        logger.info(f"Setup LoRA with {len(self.lora_params)} adapter parameters")

    def forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Forward pass through the wrapped model.

        Args:
            inputs: Input tensor
            **kwargs: Additional arguments passed through to the model

        Returns:
            Model output tensor
        """
        return self.model(inputs, **kwargs)

    def get_trainable_parameters(self) -> List[torch.Tensor]:
        """
        Get list of trainable parameters.

        When LoRA parameters were detected, only those are returned;
        otherwise every parameter with requires_grad=True is returned.

        Returns:
            List of trainable parameter tensors
        """
        if self.lora_params:
            # Return only LoRA parameters
            return [param for name, param in self.model.named_parameters()
                    if name in self.lora_params and param.requires_grad]
        else:
            # Return all trainable parameters
            return [param for param in self.model.parameters() if param.requires_grad]

    def get_parameter_info(self) -> Dict[str, Any]:
        """
        Get information about model parameters.

        Returns:
            Dictionary with total/trainable counts, number of LoRA
            parameter names, and per-parameter shape/grad details.
        """
        info = {
            'total_parameters': sum(p.numel() for p in self.model.parameters()),
            'trainable_parameters': sum(p.numel() for p in self.get_trainable_parameters()),
            'lora_parameters': len(self.lora_params),
            'parameter_details': {}
        }

        for name, param in self.model.named_parameters():
            info['parameter_details'][name] = {
                'shape': list(param.shape),
                'numel': param.numel(),
                'requires_grad': param.requires_grad,
                'is_lora': name in self.lora_params
            }

        return info

    def save_model(self, save_path: str, metadata: Optional[Dict[str, Any]] = None):
        """
        Save the model state dict and a JSON metadata sidecar.

        Args:
            save_path: Directory to save into (created if missing);
                writes 'model.pt' and 'metadata.json' inside it.
            metadata: Additional metadata to save; merged with the
                wrapper's own type/LoRA/parameter info.
        """
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        # Save model state
        model_path = save_path / 'model.pt'
        torch.save(self.model.state_dict(), model_path)

        # Save metadata
        if metadata is None:
            metadata = {}

        # Wrapper-derived fields overwrite same-named caller keys.
        metadata.update({
            'model_type': self.model_type,
            'lora_config': self.lora_config,
            'lora_params': self.lora_params,
            'parameter_info': self.get_parameter_info()
        })

        metadata_path = save_path / 'metadata.json'
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        logger.info(f"Model saved to {save_path}")

    def load_model(self, load_path: str):
        """
        Load the model state and metadata saved by save_model.

        Missing files are silently skipped, so a partial directory loads
        whatever is present.

        Args:
            load_path: Directory to load 'model.pt' / 'metadata.json' from.
        """
        load_path = Path(load_path)

        # Load model state
        # NOTE(review): torch.load without weights_only unpickles
        # arbitrary objects — only load trusted checkpoints.
        model_path = load_path / 'model.pt'
        if model_path.exists():
            state_dict = torch.load(model_path, map_location='cpu')
            self.model.load_state_dict(state_dict)
            logger.info(f"Model loaded from {model_path}")

        # Load metadata (restores wrapper bookkeeping, not model weights)
        metadata_path = load_path / 'metadata.json'
        if metadata_path.exists():
            with open(metadata_path, 'r') as f:
                metadata = json.load(f)

            self.model_type = metadata.get('model_type', self.model_type)
            self.lora_config = metadata.get('lora_config', self.lora_config)
            self.lora_params = metadata.get('lora_params', self.lora_params)

            logger.info(f"Metadata loaded from {metadata_path}")

    def to(self, device: torch.device):
        """Move model to device; returns self for chaining."""
        self.model.to(device)
        return self

    def train(self):
        """Set model to training mode; returns self for chaining."""
        self.model.train()
        return self

    def eval(self):
        """Set model to evaluation mode; returns self for chaining."""
        self.model.eval()
        return self

    def __call__(self, *args, **kwargs):
        """Call the wrapper like a model (delegates to forward)."""
        return self.forward(*args, **kwargs)
186
+
187
+
188
class LoRAModelWrapper(ModelWrapper):
    """
    Specialized wrapper for LoRA (Low-Rank Adaptation) models.

    Provides enhanced functionality for LoRA adapter management
    and integration with the MangoMAS system.
    """

    def __init__(self, base_model: nn.Module, lora_config: Dict[str, Any]):
        super().__init__(base_model, 'lora_transformer', lora_config)
        self.base_model = base_model
        # name -> LoRAAdapter for every module that was wrapped.
        self.adapters = {}

        # Initialize LoRA adapters
        self._initialize_lora_adapters()

    def _initialize_lora_adapters(self):
        """Wrap every configured target module in a LoRAAdapter.

        BUG FIX: candidate modules are snapshotted into a list before any
        replacement happens — the previous implementation replaced modules
        while iterating ``named_modules()``, mutating the module tree
        mid-iteration.
        """
        rank = self.lora_config.get('rank', 16)
        alpha = self.lora_config.get('alpha', 32)
        dropout = self.lora_config.get('dropout', 0.1)
        target_modules = self.lora_config.get('target_modules', ['c_attn', 'c_proj'])

        # Snapshot matching Linear/Conv2d modules before mutating the tree.
        matches = [
            (name, module)
            for name, module in self.base_model.named_modules()
            if any(target in name for target in target_modules)
            and isinstance(module, (nn.Linear, nn.Conv2d))
        ]

        for name, module in matches:
            adapter = LoRAAdapter(module, rank, alpha, dropout)
            self.adapters[name] = adapter

            # Replace original module with the adapter in place.
            self._replace_module(name, adapter)

        logger.info(f"Initialized {len(self.adapters)} LoRA adapters")

    def _replace_module(self, module_name: str, new_module: nn.Module):
        """Replace the module at dotted path ``module_name`` with ``new_module``."""
        parts = module_name.split('.')
        parent = self.base_model

        # Walk down to the direct parent of the target module.
        for part in parts[:-1]:
            parent = getattr(parent, part)

        setattr(parent, parts[-1], new_module)

    def get_lora_parameters(self) -> List[torch.Tensor]:
        """Return all trainable parameters owned by the LoRA adapters."""
        lora_params = []
        for adapter in self.adapters.values():
            lora_params.extend(adapter.parameters())
        return lora_params

    def merge_adapters(self):
        """Merge every LoRA adapter's update into its base weights."""
        for adapter in self.adapters.values():
            adapter.merge()
        logger.info("LoRA adapters merged into base model")

    def unmerge_adapters(self):
        """Remove every LoRA adapter's update from its base weights."""
        for adapter in self.adapters.values():
            adapter.unmerge()
        logger.info("LoRA adapters unmerged from base model")
252
+
253
+
254
class LoRAAdapter(nn.Module):
    """
    LoRA (Low-Rank Adaptation) adapter module.

    Wraps an existing layer and adds a trainable low-rank update:

        y = W x + (alpha / rank) * B A x

    where A (rank x in) and B (out x rank) are the only trained weights.
    B is zero-initialized so the adapter starts as an exact no-op.

    NOTE(review): although Conv2d modules are accepted in __init__, the
    forward pass applies a dense (linear) transform, so only nn.Linear
    wrapping is actually functional — confirm before wrapping convs.
    """

    def __init__(self, original_module: nn.Module, rank: int = 16,
                 alpha: float = 32, dropout: float = 0.1):
        super().__init__()
        self.original_module = original_module
        self.rank = rank
        self.alpha = alpha
        self.dropout = dropout

        # Get original module dimensions
        if isinstance(original_module, nn.Linear):
            in_features = original_module.in_features
            out_features = original_module.out_features
        elif isinstance(original_module, nn.Conv2d):
            in_features = original_module.in_channels
            out_features = original_module.out_channels
        else:
            raise ValueError(f"Unsupported module type: {type(original_module)}")

        # Low-rank factors: x -> A -> dropout -> B.
        self.lora_A = nn.Linear(in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, out_features, bias=False)
        self.dropout_layer = nn.Dropout(dropout)

        # A gets a standard init; B starts at zero so the initial LoRA
        # contribution is exactly zero (training starts from base behavior).
        nn.init.kaiming_uniform_(self.lora_A.weight)
        nn.init.zeros_(self.lora_B.weight)

        # Frozen copy of the base weight; merge()/unmerge() fold the
        # low-rank update in and out of this copy.
        self.original_weight = original_module.weight.data.clone()
        self.merged = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the base transform plus (or including) the LoRA update.

        BUG FIX: the original implementation referenced an undefined name
        ``F`` (torch.nn.functional was never imported in this module);
        functional ops are now reached through ``nn.functional``.
        """
        if self.merged:
            # The low-rank update is already folded into the weight copy.
            return nn.functional.linear(x, self.original_weight, self.original_module.bias)
        else:
            # Base output plus the scaled low-rank correction.
            lora_output = self.lora_B(self.dropout_layer(self.lora_A(x)))
            original_output = nn.functional.linear(x, self.original_weight, self.original_module.bias)
            return original_output + (self.alpha / self.rank) * lora_output

    def merge(self):
        """Fold the scaled low-rank update into the stored base weight (idempotent)."""
        if not self.merged:
            lora_weight = (self.alpha / self.rank) * torch.mm(
                self.lora_B.weight, self.lora_A.weight
            )
            self.original_weight += lora_weight
            self.merged = True

    def unmerge(self):
        """Remove the scaled low-rank update from the stored base weight (idempotent)."""
        if self.merged:
            lora_weight = (self.alpha / self.rank) * torch.mm(
                self.lora_B.weight, self.lora_A.weight
            )
            self.original_weight -= lora_weight
            self.merged = False
320
+
321
+
322
class ModelFactory:
    """Factory for constructing model wrappers."""

    @staticmethod
    def create_model_wrapper(model_type: str, model: nn.Module,
                             **kwargs) -> ModelWrapper:
        """Build the wrapper appropriate for ``model_type``.

        'lora' (case-insensitive) yields a LoRAModelWrapper; any other
        value yields a plain ModelWrapper tagged with that type.
        """
        if model_type.lower() == 'lora':
            return LoRAModelWrapper(model, kwargs.get('lora_config', {}))
        return ModelWrapper(model, model_type, kwargs.get('lora_config'))

    @staticmethod
    def get_default_lora_config() -> Dict[str, Any]:
        """Return the baseline LoRA hyper-parameter configuration."""
        return {
            'rank': 16,
            'alpha': 32,
            'dropout': 0.1,
            'target_modules': ['c_attn', 'c_proj']
        }
src/multi_agent_training/gradient_descent/monitoring.py ADDED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training Monitoring
3
+ ==================
4
+
5
+ This module provides comprehensive monitoring capabilities for gradient descent
6
+ training, including gradient tracking, loss monitoring, and performance metrics.
7
+ """
8
+
9
+ import logging
10
+ import math
11
+ import time
12
+ from typing import Dict, Optional, Any
13
+ import torch
14
+ import numpy as np
15
+ import matplotlib.pyplot as plt
16
+ from collections import defaultdict, deque
17
+ import json
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class GradientMonitor:
    """
    Monitor gradient statistics during training.

    Tracks global and per-parameter gradient norms, means, standard
    deviations and extrema over a bounded history, and detects common
    training pathologies: exploding/vanishing gradients, spikes, NaNs
    and large norm imbalance between parameters.
    """

    def __init__(self, max_history: int = 1000):
        """
        Args:
            max_history: Number of recent steps retained per statistic.
        """
        self.max_history = max_history
        # Bounded global histories; oldest entries evicted automatically.
        self.gradient_norms = deque(maxlen=max_history)
        self.gradient_means = deque(maxlen=max_history)
        self.gradient_stds = deque(maxlen=max_history)
        self.gradient_maxs = deque(maxlen=max_history)
        self.gradient_mins = deque(maxlen=max_history)

        # Per-parameter histories, created lazily on first update.
        self.parameter_stats = defaultdict(lambda: {
            'norms': deque(maxlen=max_history),
            'means': deque(maxlen=max_history),
            'stds': deque(maxlen=max_history)
        })

        self.anomaly_count = 0        # total anomalies seen so far
        self.last_anomaly_time = None  # wall-clock time of last anomaly

        logger.info("Initialized GradientMonitor")

    def update(self, gradients: Dict[str, torch.Tensor]):
        """
        Record one step of gradient statistics.

        Args:
            gradients: Mapping from parameter name to its gradient tensor;
                entries whose value is None are skipped.
        """
        total_norm = 0.0
        total_mean = 0.0
        total_std = 0.0
        total_max = float('-inf')
        total_min = float('inf')
        param_count = 0

        for name, grad in gradients.items():
            if grad is not None:
                # Per-parameter statistics.
                # NOTE(review): .std() yields NaN for single-element
                # tensors — assumes gradients are never scalars; confirm.
                param_norm = grad.data.norm(2).item()
                param_mean = grad.data.mean().item()
                param_std = grad.data.std().item()
                param_max = grad.data.max().item()
                param_min = grad.data.min().item()

                self.parameter_stats[name]['norms'].append(param_norm)
                self.parameter_stats[name]['means'].append(param_mean)
                self.parameter_stats[name]['stds'].append(param_std)

                # Accumulate global aggregates; L2 norms and stds are
                # combined in quadrature, means are simply averaged.
                total_norm += param_norm ** 2
                total_mean += param_mean
                total_std += param_std ** 2
                total_max = max(total_max, param_max)
                total_min = min(total_min, param_min)
                param_count += 1

        # Only record a step when at least one gradient was present.
        if param_count > 0:
            total_norm = math.sqrt(total_norm)
            total_mean /= param_count
            total_std = math.sqrt(total_std / param_count)

            self.gradient_norms.append(total_norm)
            self.gradient_means.append(total_mean)
            self.gradient_stds.append(total_std)
            self.gradient_maxs.append(total_max)
            self.gradient_mins.append(total_min)

    def detect_anomalies(self) -> Dict[str, Any]:
        """
        Check the most recent gradients for common pathologies.

        Returns:
            Dict of boolean flags: ``exploding_gradients``,
            ``vanishing_gradients``, ``gradient_imbalance``,
            ``nan_gradients``, ``gradient_spikes``. All False until at
            least two steps have been recorded.
        """
        anomalies = {
            'exploding_gradients': False,
            'vanishing_gradients': False,
            'gradient_imbalance': False,
            'nan_gradients': False,
            'gradient_spikes': False
        }

        if len(self.gradient_norms) < 2:
            return anomalies

        current_norm = self.gradient_norms[-1]

        # Exploding gradients: fixed absolute threshold.
        if current_norm > 10.0:
            anomalies['exploding_gradients'] = True
            self.anomaly_count += 1
            self.last_anomaly_time = time.time()
            logger.warning(f"Exploding gradients detected: norm={current_norm:.6f}")

        # Vanishing gradients: norm effectively zero.
        if current_norm < 1e-6:
            anomalies['vanishing_gradients'] = True
            self.anomaly_count += 1
            self.last_anomaly_time = time.time()
            logger.warning(f"Vanishing gradients detected: norm={current_norm:.6f}")

        # Spike: current norm much larger than the recent rolling average.
        if len(self.gradient_norms) >= 10:
            recent_norms = list(self.gradient_norms)[-10:]
            avg_norm = np.mean(recent_norms[:-1])
            if current_norm > 3 * avg_norm:
                anomalies['gradient_spikes'] = True
                logger.warning(f"Gradient spike detected: {current_norm:.6f} vs avg {avg_norm:.6f}")

        # NaN in either the norm or the mean of the last step.
        if math.isnan(current_norm) or math.isnan(self.gradient_means[-1]):
            anomalies['nan_gradients'] = True
            self.anomaly_count += 1
            self.last_anomaly_time = time.time()
            logger.warning("NaN gradients detected")

        # Norm imbalance between parameters. Fixed: the previous version
        # divided by min(param_norms) and crashed with ZeroDivisionError
        # whenever any parameter had an exactly-zero gradient norm (e.g.
        # a frozen or unused parameter); a zero norm alongside a non-zero
        # one is now treated as (infinite-ratio) imbalance instead.
        if len(self.parameter_stats) > 1:
            param_norms = [stats['norms'][-1] for stats in self.parameter_stats.values()
                           if len(stats['norms']) > 0]
            if param_norms:
                largest = max(param_norms)
                smallest = min(param_norms)
                if (smallest == 0.0 and largest > 0.0) or \
                        (smallest > 0.0 and largest / smallest > 1000):
                    anomalies['gradient_imbalance'] = True
                    logger.warning("Gradient imbalance detected between parameters")

        return anomalies

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get comprehensive gradient statistics.

        Returns:
            Dict of global and per-parameter statistics; empty dict if
            no gradients have been recorded yet.
        """
        if not self.gradient_norms:
            return {}

        stats = {
            'current_norm': self.gradient_norms[-1],
            'mean_norm': np.mean(self.gradient_norms),
            'std_norm': np.std(self.gradient_norms),
            'min_norm': min(self.gradient_norms),
            'max_norm': max(self.gradient_norms),
            'current_mean': self.gradient_means[-1],
            'current_std': self.gradient_stds[-1],
            'current_max': self.gradient_maxs[-1],
            'current_min': self.gradient_mins[-1],
            'anomaly_count': self.anomaly_count,
            'parameter_count': len(self.parameter_stats)
        }

        # Per-parameter summaries for parameters that have data.
        param_stats = {}
        for name, stats_dict in self.parameter_stats.items():
            if stats_dict['norms']:
                param_stats[name] = {
                    'current_norm': stats_dict['norms'][-1],
                    'mean_norm': np.mean(stats_dict['norms']),
                    'std_norm': np.std(stats_dict['norms'])
                }

        stats['parameter_stats'] = param_stats

        return stats

    def plot_gradients(self, save_path: Optional[str] = None):
        """
        Plot gradient norm/mean/std/range histories in a 2x2 grid.

        Args:
            save_path: If given, the figure is also saved to this path.
        """
        if not self.gradient_norms:
            logger.warning("No gradient data to plot")
            return

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        axes[0, 0].plot(self.gradient_norms)
        axes[0, 0].set_title('Gradient Norms')
        axes[0, 0].set_xlabel('Step')
        axes[0, 0].set_ylabel('Norm')
        axes[0, 0].grid(True)

        axes[0, 1].plot(self.gradient_means)
        axes[0, 1].set_title('Gradient Means')
        axes[0, 1].set_xlabel('Step')
        axes[0, 1].set_ylabel('Mean')
        axes[0, 1].grid(True)

        axes[1, 0].plot(self.gradient_stds)
        axes[1, 0].set_title('Gradient Standard Deviations')
        axes[1, 0].set_xlabel('Step')
        axes[1, 0].set_ylabel('Std')
        axes[1, 0].grid(True)

        axes[1, 1].plot(self.gradient_maxs, label='Max')
        axes[1, 1].plot(self.gradient_mins, label='Min')
        axes[1, 1].set_title('Gradient Range')
        axes[1, 1].set_xlabel('Step')
        axes[1, 1].set_ylabel('Value')
        axes[1, 1].legend()
        axes[1, 1].grid(True)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"Gradient plots saved to {save_path}")

        plt.show()
245
+
246
+
247
class TrainingMonitor:
    """
    Monitor training progress and performance.

    Tracks loss, accuracy, learning rate and elapsed time over a bounded
    history, keeps per-epoch breakdowns and best-so-far values, and offers
    convergence detection, plotting and JSON (de)serialisation.
    """

    def __init__(self, max_history: int = 1000):
        """
        Args:
            max_history: Number of recent steps retained per metric.
        """
        self.max_history = max_history
        self.losses = deque(maxlen=max_history)
        self.accuracies = deque(maxlen=max_history)
        self.learning_rates = deque(maxlen=max_history)
        self.training_times = deque(maxlen=max_history)

        # Per-epoch metric histories, keyed by epoch index.
        self.epoch_metrics = defaultdict(lambda: {
            'loss': deque(maxlen=max_history),
            'accuracy': deque(maxlen=max_history),
            'learning_rate': deque(maxlen=max_history)
        })

        self.best_loss = float('inf')
        self.best_accuracy = 0.0
        self.training_start_time = time.time()

        logger.info("Initialized TrainingMonitor")

    def update(self, loss: float, accuracy: Optional[float] = None,
               learning_rate: Optional[float] = None, epoch: Optional[int] = None):
        """
        Record one training step.

        Args:
            loss: Current loss value.
            accuracy: Current accuracy (optional).
            learning_rate: Current learning rate (optional).
            epoch: Current epoch, for per-epoch bookkeeping (optional).
        """
        current_time = time.time()

        # Global metrics; elapsed time is measured from construction.
        self.losses.append(loss)
        if accuracy is not None:
            self.accuracies.append(accuracy)
        if learning_rate is not None:
            self.learning_rates.append(learning_rate)
        self.training_times.append(current_time - self.training_start_time)

        # Epoch-specific metrics.
        if epoch is not None:
            self.epoch_metrics[epoch]['loss'].append(loss)
            if accuracy is not None:
                self.epoch_metrics[epoch]['accuracy'].append(accuracy)
            if learning_rate is not None:
                self.epoch_metrics[epoch]['learning_rate'].append(learning_rate)

        # Best-so-far values (over the whole run, not just the window).
        if loss < self.best_loss:
            self.best_loss = loss
        if accuracy is not None and accuracy > self.best_accuracy:
            self.best_accuracy = accuracy

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get comprehensive training statistics.

        Returns:
            Dict of aggregates over the retained history; empty dict if
            nothing has been recorded yet. Accuracy / learning-rate keys
            appear only when those metrics were supplied to update().
        """
        if not self.losses:
            return {}

        stats = {
            'current_loss': self.losses[-1],
            'best_loss': self.best_loss,
            'mean_loss': np.mean(self.losses),
            'std_loss': np.std(self.losses),
            'min_loss': min(self.losses),
            'max_loss': max(self.losses),
            'best_accuracy': self.best_accuracy,
            'total_steps': len(self.losses),
            'training_time': self.training_times[-1] if self.training_times else 0
        }

        if self.accuracies:
            stats.update({
                'current_accuracy': self.accuracies[-1],
                'mean_accuracy': np.mean(self.accuracies),
                'std_accuracy': np.std(self.accuracies)
            })

        if self.learning_rates:
            stats.update({
                'current_learning_rate': self.learning_rates[-1],
                'mean_learning_rate': np.mean(self.learning_rates),
                'min_learning_rate': min(self.learning_rates),
                'max_learning_rate': max(self.learning_rates)
            })

        return stats

    def detect_convergence(self, patience: int = 10, threshold: float = 1e-4) -> bool:
        """
        Detect whether the loss has stopped improving.

        Compares the best loss inside the trailing ``patience``-step window
        with the best loss recorded before that window. Fixed: the previous
        implementation compared against ``self.best_loss``, which already
        includes the recent window, so the computed improvement was never
        positive and the method reported convergence unconditionally once
        enough steps had been seen.

        Args:
            patience: Size of the trailing window, in steps.
            threshold: Minimum improvement that counts as progress.

        Returns:
            True if the recent window failed to improve on the earlier
            best by at least ``threshold``.
        """
        # Need at least one pre-window step to compare against.
        if len(self.losses) <= patience:
            return False

        history = list(self.losses)
        best_recent = min(history[-patience:])
        best_earlier = min(history[:-patience])

        return (best_earlier - best_recent) < threshold

    def plot_training_curves(self, save_path: Optional[str] = None):
        """
        Plot loss / accuracy / learning-rate / time curves in a 2x2 grid.

        Args:
            save_path: If given, the figure is also saved to this path.
        """
        if not self.losses:
            logger.warning("No training data to plot")
            return

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        axes[0, 0].plot(self.losses)
        axes[0, 0].set_title('Training Loss')
        axes[0, 0].set_xlabel('Step')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].grid(True)

        if self.accuracies:
            axes[0, 1].plot(self.accuracies)
            axes[0, 1].set_title('Training Accuracy')
            axes[0, 1].set_xlabel('Step')
            axes[0, 1].set_ylabel('Accuracy')
            axes[0, 1].grid(True)

        if self.learning_rates:
            axes[1, 0].plot(self.learning_rates)
            axes[1, 0].set_title('Learning Rate')
            axes[1, 0].set_xlabel('Step')
            axes[1, 0].set_ylabel('Learning Rate')
            axes[1, 0].grid(True)

        if self.training_times:
            axes[1, 1].plot(self.training_times)
            axes[1, 1].set_title('Training Time')
            axes[1, 1].set_xlabel('Step')
            axes[1, 1].set_ylabel('Time (seconds)')
            axes[1, 1].grid(True)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"Training curves saved to {save_path}")

        plt.show()

    def save_metrics(self, file_path: str):
        """
        Save training metrics to a JSON file.

        Args:
            file_path: Destination path for the metrics file.
        """
        metrics = {
            'losses': list(self.losses),
            'accuracies': list(self.accuracies),
            'learning_rates': list(self.learning_rates),
            'training_times': list(self.training_times),
            'best_loss': self.best_loss,
            'best_accuracy': self.best_accuracy,
            'statistics': self.get_statistics()
        }

        with open(file_path, 'w') as f:
            json.dump(metrics, f, indent=2)

        logger.info(f"Training metrics saved to {file_path}")

    def load_metrics(self, file_path: str):
        """
        Load training metrics from a JSON file written by save_metrics().

        Note: epoch_metrics and training_start_time are not serialised;
        restored monitors keep their own start time.

        Args:
            file_path: Path of the metrics file to load.
        """
        with open(file_path, 'r') as f:
            metrics = json.load(f)

        self.losses = deque(metrics['losses'], maxlen=self.max_history)
        self.accuracies = deque(metrics['accuracies'], maxlen=self.max_history)
        self.learning_rates = deque(metrics['learning_rates'], maxlen=self.max_history)
        self.training_times = deque(metrics['training_times'], maxlen=self.max_history)
        self.best_loss = metrics['best_loss']
        self.best_accuracy = metrics['best_accuracy']

        logger.info(f"Training metrics loaded from {file_path}")
460
+
461
+
462
class PerformanceMonitor:
    """
    Monitor system performance during training.

    Collects memory, compute-time and GPU-utilisation samples plus
    per-phase step timings (forward / backward / optimizer / other)
    to help analyse and optimise training efficiency.
    """

    def __init__(self):
        # Rolling windows of the most recent 1000 samples each.
        self.memory_usage = deque(maxlen=1000)
        self.compute_times = deque(maxlen=1000)
        self.gpu_usage = deque(maxlen=1000)

        # Unbounded per-phase timing lists filled by StepTimer.
        self.step_times = []
        self.forward_times = []
        self.backward_times = []
        self.optimizer_times = []

        logger.info("Initialized PerformanceMonitor")

    def update_memory(self, memory_mb: float):
        """Record a memory-usage sample (megabytes)."""
        self.memory_usage.append(memory_mb)

    def update_compute_time(self, time_seconds: float):
        """Record a compute-time sample (seconds)."""
        self.compute_times.append(time_seconds)

    def update_gpu_usage(self, gpu_percent: float):
        """Record a GPU-utilisation sample (percent)."""
        self.gpu_usage.append(gpu_percent)

    def time_step(self, step_name: str):
        """Return a context manager that times the named training phase."""
        return StepTimer(self, step_name)

    def get_statistics(self) -> Dict[str, Any]:
        """Summarise recorded samples; sections with no data are omitted."""
        report = {}

        if self.memory_usage:
            report['memory'] = {
                'current_mb': self.memory_usage[-1],
                'mean_mb': np.mean(self.memory_usage),
                'max_mb': max(self.memory_usage),
            }

        if self.compute_times:
            report['compute'] = {
                'current_seconds': self.compute_times[-1],
                'mean_seconds': np.mean(self.compute_times),
                'total_seconds': sum(self.compute_times),
            }

        if self.gpu_usage:
            report['gpu'] = {
                'current_percent': self.gpu_usage[-1],
                'mean_percent': np.mean(self.gpu_usage),
                'max_percent': max(self.gpu_usage),
            }

        return report
524
+
525
+
526
class StepTimer:
    """Context manager recording the wall-clock duration of one phase."""

    def __init__(self, monitor: PerformanceMonitor, step_name: str):
        self.monitor = monitor
        self.step_name = step_name
        self.start_time = None

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed = time.time() - self.start_time
        # Route the measurement into the bucket for this phase; any
        # unrecognised name falls back to the generic step_times list.
        buckets = {
            'forward': self.monitor.forward_times,
            'backward': self.monitor.backward_times,
            'optimizer': self.monitor.optimizer_times,
        }
        buckets.get(self.step_name, self.monitor.step_times).append(elapsed)
src/multi_agent_training/gradient_descent/optimizers.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradient Descent Optimizers
3
+ ==========================
4
+
5
+ This module implements various gradient descent optimization algorithms
6
+ with proper mathematical formulations and PyTorch compatibility.
7
+
8
+ Algorithms implemented:
9
+ - SGD (Stochastic Gradient Descent)
10
+ - Adam (Adaptive Moment Estimation)
11
+ - AdamW (Adam with Decoupled Weight Decay)
12
+
13
+ All optimizers follow the mathematical principles and include proper
14
+ momentum, learning rate scheduling, and weight decay handling.
15
+ """
16
+
17
+ import math
18
+ import logging
19
+ from abc import ABC, abstractmethod
20
+ from typing import Dict, List, Tuple, Any
21
+ import torch
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class Optimizer(ABC):
    """Abstract base class for all optimizers.

    Holds the managed parameter list, the base learning rate, a step
    counter and a per-parameter state dict; concrete subclasses implement
    :meth:`step`.
    """

    def __init__(self, params: List[torch.Tensor], lr: float = 1e-3, **kwargs):
        self.params = params      # parameters updated in-place by step()
        self.lr = lr              # base learning rate
        self.step_count = 0       # number of completed step() calls
        self.state = {}           # per-parameter optimizer state

    @abstractmethod
    def step(self, closure=None):
        """Perform a single optimization step."""

    def zero_grad(self):
        """Reset the gradient of every managed parameter to zero."""
        for p in self.params:
            if p.grad is not None:
                p.grad.zero_()

    def state_dict(self):
        """Return a snapshot of the optimizer state."""
        return {
            'state': self.state,
            'param_groups': [{'params': self.params, 'lr': self.lr}],
            'step_count': self.step_count,
        }

    def load_state_dict(self, state_dict):
        """Restore optimizer state from a snapshot made by state_dict()."""
        self.state = state_dict['state']
        self.lr = state_dict['param_groups'][0]['lr']
        self.step_count = state_dict['step_count']
59
+
60
+
61
+ class SGD(Optimizer):
62
+ """
63
+ Stochastic Gradient Descent optimizer
64
+
65
+ Mathematical formulation:
66
+ θ_{t+1} = θ_t - α * ∇_θ J(θ_t)
67
+
68
+ With momentum:
69
+ v_{t+1} = μ * v_t + ∇_θ J(θ_t)
70
+ θ_{t+1} = θ_t - α * v_{t+1}
71
+ """
72
+
73
+ def __init__(self, params: List[torch.Tensor], lr: float = 1e-3,
74
+ momentum: float = 0.0, weight_decay: float = 0.0,
75
+ dampening: float = 0.0, nesterov: bool = False):
76
+ super().__init__(params, lr)
77
+ self.momentum = momentum
78
+ self.weight_decay = weight_decay
79
+ self.dampening = dampening
80
+ self.nesterov = nesterov
81
+
82
+ # Initialize momentum buffers
83
+ for param in self.params:
84
+ if momentum > 0:
85
+ self.state[param] = {'momentum_buffer': torch.zeros_like(param)}
86
+
87
+ logger.info(f"Initialized SGD optimizer: lr={lr}, momentum={momentum}, weight_decay={weight_decay}")
88
+
89
+ def step(self, closure=None):
90
+ """Perform SGD optimization step"""
91
+ loss = None
92
+ if closure is not None:
93
+ loss = closure()
94
+
95
+ for param in self.params:
96
+ if param.grad is None:
97
+ continue
98
+
99
+ grad = param.grad.data
100
+
101
+ # Apply weight decay
102
+ if self.weight_decay != 0:
103
+ grad = grad.add(param.data, alpha=self.weight_decay)
104
+
105
+ # Apply momentum
106
+ if self.momentum != 0:
107
+ param_state = self.state[param]
108
+ if 'momentum_buffer' not in param_state:
109
+ param_state['momentum_buffer'] = torch.zeros_like(param.data)
110
+
111
+ momentum_buffer = param_state['momentum_buffer']
112
+ momentum_buffer.mul_(self.momentum).add_(grad, alpha=1 - self.dampening)
113
+
114
+ if self.nesterov:
115
+ grad = grad.add(momentum_buffer, alpha=self.momentum)
116
+ else:
117
+ grad = momentum_buffer
118
+
119
+ # Update parameters
120
+ param.data.add_(grad, alpha=-self.lr)
121
+
122
+ self.step_count += 1
123
+ return loss
124
+
125
+
126
+ class Adam(Optimizer):
127
+ """
128
+ Adam (Adaptive Moment Estimation) optimizer
129
+
130
+ Mathematical formulation:
131
+ m_t = β₁ * m_{t-1} + (1 - β₁) * g_t
132
+ v_t = β₂ * v_{t-1} + (1 - β₂) * g_t²
133
+ m̂_t = m_t / (1 - β₁ᵗ)
134
+ v̂_t = v_t / (1 - β₂ᵗ)
135
+ θ_{t+1} = θ_t - α * m̂_t / (√v̂_t + ε)
136
+ """
137
+
138
+ def __init__(self, params: List[torch.Tensor], lr: float = 1e-3,
139
+ betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8,
140
+ weight_decay: float = 0.0, amsgrad: bool = False):
141
+ super().__init__(params, lr)
142
+ self.betas = betas
143
+ self.eps = eps
144
+ self.weight_decay = weight_decay
145
+ self.amsgrad = amsgrad
146
+
147
+ # Initialize moment estimates
148
+ for param in self.params:
149
+ self.state[param] = {
150
+ 'step': 0,
151
+ 'exp_avg': torch.zeros_like(param.data),
152
+ 'exp_avg_sq': torch.zeros_like(param.data)
153
+ }
154
+ if amsgrad:
155
+ self.state[param]['max_exp_avg_sq'] = torch.zeros_like(param.data)
156
+
157
+ logger.info(f"Initialized Adam optimizer: lr={lr}, betas={betas}, eps={eps}")
158
+
159
+ def step(self, closure=None):
160
+ """Perform Adam optimization step"""
161
+ loss = None
162
+ if closure is not None:
163
+ loss = closure()
164
+
165
+ for param in self.params:
166
+ if param.grad is None:
167
+ continue
168
+
169
+ grad = param.grad.data
170
+
171
+ # Apply weight decay
172
+ if self.weight_decay != 0:
173
+ grad = grad.add(param.data, alpha=self.weight_decay)
174
+
175
+ param_state = self.state[param]
176
+ exp_avg, exp_avg_sq = param_state['exp_avg'], param_state['exp_avg_sq']
177
+ beta1, beta2 = self.betas
178
+
179
+ param_state['step'] += 1
180
+ bias_correction1 = 1 - beta1 ** param_state['step']
181
+ bias_correction2 = 1 - beta2 ** param_state['step']
182
+
183
+ # Update biased first moment estimate
184
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
185
+
186
+ # Update biased second raw moment estimate
187
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
188
+
189
+ if self.amsgrad:
190
+ # Maintains the maximum of all 2nd moment running avg. of squared gradients
191
+ torch.max(param_state['max_exp_avg_sq'], exp_avg_sq, out=param_state['max_exp_avg_sq'])
192
+ # Use the max. for normalizing running avg. of squared gradients
193
+ denom = (param_state['max_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)).add_(self.eps)
194
+ else:
195
+ denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)
196
+
197
+ # Update parameters
198
+ step_size = self.lr / bias_correction1
199
+ param.data.addcdiv_(exp_avg, denom, value=-step_size)
200
+
201
+ self.step_count += 1
202
+ return loss
203
+
204
+
205
+ class AdamW(Optimizer):
206
+ """
207
+ AdamW (Adam with Decoupled Weight Decay) optimizer
208
+
209
+ Mathematical formulation:
210
+ θ_t = θ_{t-1} - α * (m̂_t / (√v̂_t + ε) + λ * θ_{t-1})
211
+
212
+ Where weight decay is applied directly to parameters, not gradients.
213
+ """
214
+
215
+ def __init__(self, params: List[torch.Tensor], lr: float = 1e-3,
216
+ betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8,
217
+ weight_decay: float = 0.01, amsgrad: bool = False):
218
+ super().__init__(params, lr)
219
+ self.betas = betas
220
+ self.eps = eps
221
+ self.weight_decay = weight_decay
222
+ self.amsgrad = amsgrad
223
+
224
+ # Initialize moment estimates
225
+ for param in self.params:
226
+ self.state[param] = {
227
+ 'step': 0,
228
+ 'exp_avg': torch.zeros_like(param.data),
229
+ 'exp_avg_sq': torch.zeros_like(param.data)
230
+ }
231
+ if amsgrad:
232
+ self.state[param]['max_exp_avg_sq'] = torch.zeros_like(param.data)
233
+
234
+ logger.info(f"Initialized AdamW optimizer: lr={lr}, betas={betas}, weight_decay={weight_decay}")
235
+
236
+ def step(self, closure=None):
237
+ """Perform AdamW optimization step"""
238
+ loss = None
239
+ if closure is not None:
240
+ loss = closure()
241
+
242
+ for param in self.params:
243
+ if param.grad is None:
244
+ continue
245
+
246
+ grad = param.grad.data
247
+ param_state = self.state[param]
248
+ exp_avg, exp_avg_sq = param_state['exp_avg'], param_state['exp_avg_sq']
249
+ beta1, beta2 = self.betas
250
+
251
+ param_state['step'] += 1
252
+ bias_correction1 = 1 - beta1 ** param_state['step']
253
+ bias_correction2 = 1 - beta2 ** param_state['step']
254
+
255
+ # Update biased first moment estimate
256
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
257
+
258
+ # Update biased second raw moment estimate
259
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
260
+
261
+ if self.amsgrad:
262
+ # Maintains the maximum of all 2nd moment running avg. of squared gradients
263
+ torch.max(param_state['max_exp_avg_sq'], exp_avg_sq, out=param_state['max_exp_avg_sq'])
264
+ # Use the max. for normalizing running avg. of squared gradients
265
+ denom = (param_state['max_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)).add_(self.eps)
266
+ else:
267
+ denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)
268
+
269
+ # Update parameters with decoupled weight decay
270
+ step_size = self.lr / bias_correction1
271
+ param.data.mul_(1 - self.lr * self.weight_decay)
272
+ param.data.addcdiv_(exp_avg, denom, value=-step_size)
273
+
274
+ self.step_count += 1
275
+ return loss
276
+
277
+
278
class OptimizerFactory:
    """Factory helpers for constructing optimizers by name."""

    @staticmethod
    def create_optimizer(optimizer_type: str, params: List[torch.Tensor], **kwargs) -> Optimizer:
        """Build an optimizer instance.

        Args:
            optimizer_type: One of 'sgd', 'adam', 'adamw' (case-insensitive).
            params: Parameters the optimizer will manage.
            **kwargs: Forwarded to the optimizer constructor.

        Raises:
            ValueError: If ``optimizer_type`` is not recognised.
        """
        registry = {
            'sgd': SGD,
            'adam': Adam,
            'adamw': AdamW
        }

        key = optimizer_type.lower()
        if key not in registry:
            raise ValueError(f"Unknown optimizer type: {optimizer_type}")

        return registry[key](params, **kwargs)

    @staticmethod
    def get_default_config(optimizer_type: str) -> Dict[str, Any]:
        """Return the default hyper-parameters for the named optimizer,
        or an empty dict for unknown names."""
        configs = {
            'sgd': {
                'lr': 1e-3,
                'momentum': 0.9,
                'weight_decay': 1e-4
            },
            'adam': {
                'lr': 1e-3,
                'betas': (0.9, 0.999),
                'eps': 1e-8,
                'weight_decay': 1e-4
            },
            'adamw': {
                'lr': 1e-3,
                'betas': (0.9, 0.999),
                'eps': 1e-8,
                'weight_decay': 0.01
            }
        }

        return configs.get(optimizer_type.lower(), {})
src/multi_agent_training/gradient_descent/schedulers.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Learning Rate Schedulers
3
+ =======================
4
+
5
+ This module implements various learning rate scheduling strategies for
6
+ gradient descent optimization, including linear, cosine, and step schedulers.
7
+
8
+ Learning rate scheduling is crucial for stable training and convergence.
9
+ """
10
+
11
+ import logging
12
+ import math
13
+ from abc import ABC, abstractmethod
14
+ from typing import Dict, Optional, Any
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
class Scheduler(ABC):
    """Abstract base class for all learning rate schedulers"""

    def __init__(self, optimizer, **kwargs):
        # Seed the schedule from the optimizer's current learning rate.
        self.optimizer = optimizer
        self.base_lr = optimizer.lr
        self.current_lr = self.base_lr

    @abstractmethod
    def step(self, epoch: int = None, metrics: Dict[str, float] = None):
        """Advance the schedule by one step and update the learning rate."""
        pass

    def get_lr(self) -> float:
        """Return the learning rate currently in effect."""
        return self.current_lr

    def state_dict(self) -> Dict[str, Any]:
        """Serialize the scheduler state for checkpointing."""
        return {'current_lr': self.current_lr, 'base_lr': self.base_lr}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore scheduler state previously produced by :meth:`state_dict`."""
        self.current_lr = state_dict['current_lr']
        self.base_lr = state_dict['base_lr']
47
+
48
+
49
class LinearScheduler(Scheduler):
    """
    Linear learning rate scheduler.

    Mathematical formulation:
        lr_t = lr_0 * (1 - t/T)

    lr_0 is the initial learning rate, t the current step, and T the total
    number of steps; the learning rate is floored at ``min_lr``.
    """

    def __init__(self, optimizer, total_steps: int, min_lr: float = 0.0):
        super().__init__(optimizer)
        self.total_steps = total_steps
        self.min_lr = min_lr
        self.current_step = 0

        logger.info(f"Initialized LinearScheduler: base_lr={self.base_lr}, total_steps={total_steps}, min_lr={min_lr}")

    def step(self, epoch: int = None, metrics: Dict[str, float] = None):
        """Apply one step of linear decay and push the new LR to the optimizer."""
        if self.current_step < self.total_steps:
            remaining = 1.0 - (self.current_step / self.total_steps)
            self.current_lr = max(self.min_lr, self.base_lr * remaining)
            # Propagate to the wrapped optimizer; past total_steps the LR freezes.
            self.optimizer.lr = self.current_lr

        self.current_step += 1

        logger.debug(f"LinearScheduler step {self.current_step}: lr={self.current_lr:.6f}")
81
+
82
+
83
class CosineScheduler(Scheduler):
    """
    Cosine annealing learning rate scheduler with optional linear warmup.

    Mathematical formulation (after warmup):
        lr_t = lr_min + (lr_max - lr_min) * (1 + cos(π * t/T)) / 2

    Where lr_max is the maximum learning rate, lr_min is the minimum,
    t is the current step, and T is the total number of steps.
    """

    def __init__(self, optimizer, total_steps: int, min_lr: float = 0.0,
                 max_lr: Optional[float] = None, warmup_steps: int = 0):
        super().__init__(optimizer)
        self.total_steps = total_steps
        self.min_lr = min_lr
        # Fall back to the optimizer's base LR when no explicit peak is given.
        self.max_lr = max_lr or self.base_lr
        self.warmup_steps = warmup_steps
        self.current_step = 0

        logger.info(f"Initialized CosineScheduler: max_lr={self.max_lr}, min_lr={min_lr}, total_steps={total_steps}")

    def step(self, epoch: int = None, metrics: Dict[str, float] = None):
        """Update learning rate: linear warmup, then cosine annealing to min_lr."""
        if self.current_step < self.total_steps:
            if self.current_step < self.warmup_steps:
                # Warmup phase: linear ramp from min_lr towards max_lr.
                self.current_lr = self.min_lr + (self.max_lr - self.min_lr) * (self.current_step / self.warmup_steps)
            else:
                # Cosine annealing phase.
                # Bug fix: guard the denominator — when total_steps == warmup_steps
                # the original divided by zero on the first post-warmup step.
                anneal_steps = max(1, self.total_steps - self.warmup_steps)
                progress = (self.current_step - self.warmup_steps) / anneal_steps
                cosine_factor = (1 + math.cos(math.pi * progress)) / 2
                self.current_lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_factor

            # Propagate the new LR to the wrapped optimizer.
            self.optimizer.lr = self.current_lr

        self.current_step += 1

        logger.debug(f"CosineScheduler step {self.current_step}: lr={self.current_lr:.6f}")
123
+
124
+
125
class StepScheduler(Scheduler):
    """
    Step learning rate scheduler.

    Mathematical formulation:
        lr_t = lr_0 * γ^⌊t/step_size⌋

    γ is the decay factor and step_size the number of steps between decays;
    the learning rate is floored at ``min_lr``.
    """

    def __init__(self, optimizer, step_size: int, gamma: float = 0.1,
                 min_lr: float = 0.0):
        super().__init__(optimizer)
        self.step_size = step_size
        self.gamma = gamma
        self.min_lr = min_lr
        self.current_step = 0

        logger.info(f"Initialized StepScheduler: step_size={step_size}, gamma={gamma}, min_lr={min_lr}")

    def step(self, epoch: int = None, metrics: Dict[str, float] = None):
        """Decay the learning rate by gamma once every ``step_size`` calls."""
        at_boundary = self.current_step > 0 and self.current_step % self.step_size == 0
        if at_boundary:
            exponent = self.current_step // self.step_size
            self.current_lr = max(self.min_lr, self.base_lr * self.gamma ** exponent)
            # Only on decay boundaries is the optimizer's LR refreshed.
            self.optimizer.lr = self.current_lr

            logger.debug(f"StepScheduler decay at step {self.current_step}: lr={self.current_lr:.6f}")

        self.current_step += 1
159
+
160
+
161
class ExponentialScheduler(Scheduler):
    """
    Exponential learning rate scheduler.

    Mathematical formulation:
        lr_t = lr_0 * γ^t

    γ is the per-step decay factor; the LR is floored at ``min_lr``.
    """

    def __init__(self, optimizer, gamma: float = 0.95, min_lr: float = 0.0):
        super().__init__(optimizer)
        self.gamma = gamma
        self.min_lr = min_lr
        self.current_step = 0

        logger.info(f"Initialized ExponentialScheduler: gamma={gamma}, min_lr={min_lr}")

    def step(self, epoch: int = None, metrics: Dict[str, float] = None):
        """Multiply the base LR by gamma^t and push it to the optimizer."""
        decayed = self.base_lr * (self.gamma ** self.current_step)
        self.current_lr = max(self.min_lr, decayed)
        self.optimizer.lr = self.current_lr

        self.current_step += 1

        logger.debug(f"ExponentialScheduler step {self.current_step}: lr={self.current_lr:.6f}")
189
+
190
+
191
class ReduceLROnPlateauScheduler(Scheduler):
    """
    Reduce learning rate when a monitored metric has stopped improving.

    After ``patience`` consecutive steps without an improvement larger than
    ``threshold``, the learning rate is multiplied by ``factor`` (floored
    at ``min_lr``) and the patience counter resets.
    """

    def __init__(self, optimizer, mode: str = 'min', factor: float = 0.1,
                 patience: int = 10, threshold: float = 1e-4,
                 min_lr: float = 0.0, metric_name: str = 'loss'):
        super().__init__(optimizer)
        self.mode = mode
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.min_lr = min_lr
        self.metric_name = metric_name

        # Seed with +/-inf so the very first observation always counts as improved.
        self.best_metric = float('inf') if mode == 'min' else float('-inf')
        self.patience_counter = 0
        self.current_epoch = 0

        logger.info(f"Initialized ReduceLROnPlateauScheduler: mode={mode}, factor={factor}, patience={patience}")

    def step(self, epoch: int = None, metrics: Dict[str, float] = None):
        """Update learning rate based on metric improvement.

        A no-op unless ``metrics`` contains ``self.metric_name``.
        """
        if metrics is None or self.metric_name not in metrics:
            return

        current_metric = metrics[self.metric_name]
        # Bug fix: the original used `epoch or self.current_epoch`, which
        # discards a legitimate epoch 0 because 0 is falsy.
        if epoch is not None:
            self.current_epoch = epoch

        # Check if the metric improved beyond the threshold.
        if self.mode == 'min':
            improved = current_metric < (self.best_metric - self.threshold)
        else:
            improved = current_metric > (self.best_metric + self.threshold)

        if improved:
            self.best_metric = current_metric
            self.patience_counter = 0
        else:
            self.patience_counter += 1

        # Reduce learning rate once patience is exhausted.
        if self.patience_counter >= self.patience:
            old_lr = self.current_lr
            self.current_lr = max(self.min_lr, self.current_lr * self.factor)

            # Propagate to the wrapped optimizer.
            self.optimizer.lr = self.current_lr

            self.patience_counter = 0

            logger.info(f"ReduceLROnPlateauScheduler reduced LR: {old_lr:.6f} -> {self.current_lr:.6f}")
247
+
248
+
249
class WarmupScheduler(Scheduler):
    """
    Warmup learning rate scheduler.

    Linearly increases the learning rate from 0 to ``target_lr`` over
    ``warmup_steps`` steps, then holds it at ``target_lr``.
    """

    def __init__(self, optimizer, warmup_steps: int, target_lr: Optional[float] = None):
        super().__init__(optimizer)
        self.warmup_steps = warmup_steps
        # Default the warmup target to the optimizer's base learning rate.
        self.target_lr = target_lr or self.base_lr
        self.current_step = 0

        # Start with zero learning rate.
        self.current_lr = 0.0
        self.optimizer.lr = self.current_lr

        logger.info(f"Initialized WarmupScheduler: warmup_steps={warmup_steps}, target_lr={self.target_lr}")

    def step(self, epoch: int = None, metrics: Dict[str, float] = None):
        """Advance the warmup by one step."""
        if self.current_step < self.warmup_steps:
            # Linear warmup.
            self.current_lr = self.target_lr * (self.current_step / self.warmup_steps)

            self.optimizer.lr = self.current_lr

            logger.debug(f"WarmupScheduler step {self.current_step}: lr={self.current_lr:.6f}")
        elif self.current_lr != self.target_lr:
            # Bug fix: the linear ramp peaks at target_lr * (warmup_steps - 1) /
            # warmup_steps and previously froze there forever; once warmup is
            # complete, pin the learning rate at the target.
            self.current_lr = self.target_lr
            self.optimizer.lr = self.current_lr

        self.current_step += 1
281
+
282
+
283
class SchedulerFactory:
    """Factory class for creating learning rate schedulers"""

    @staticmethod
    def create_scheduler(scheduler_type: str, optimizer, **kwargs) -> Scheduler:
        """Instantiate the scheduler registered under ``scheduler_type``.

        Args:
            scheduler_type: One of 'linear', 'cosine', 'step', 'exponential',
                'plateau', 'warmup' (case-insensitive).
            optimizer: The optimizer whose learning rate will be managed.
            **kwargs: Forwarded verbatim to the scheduler constructor.

        Raises:
            ValueError: If ``scheduler_type`` is not a known scheduler name.
        """
        registry = {
            'linear': LinearScheduler,
            'cosine': CosineScheduler,
            'step': StepScheduler,
            'exponential': ExponentialScheduler,
            'plateau': ReduceLROnPlateauScheduler,
            'warmup': WarmupScheduler
        }
        key = scheduler_type.lower()
        if key not in registry:
            raise ValueError(f"Unknown scheduler type: {scheduler_type}")
        return registry[key](optimizer, **kwargs)

    @staticmethod
    def get_default_config(scheduler_type: str) -> Dict[str, Any]:
        """Return the default keyword arguments for ``scheduler_type``.

        Unknown names yield an empty dict rather than raising.
        """
        defaults = {
            'linear': {
                'total_steps': 1000,
                'min_lr': 0.0
            },
            'cosine': {
                'total_steps': 1000,
                'min_lr': 0.0,
                'warmup_steps': 100
            },
            'step': {
                'step_size': 100,
                'gamma': 0.1,
                'min_lr': 0.0
            },
            'exponential': {
                'gamma': 0.95,
                'min_lr': 0.0
            },
            'plateau': {
                'mode': 'min',
                'factor': 0.1,
                'patience': 10,
                'threshold': 1e-4,
                'min_lr': 0.0
            },
            'warmup': {
                'warmup_steps': 100
            }
        }
        return defaults.get(scheduler_type.lower(), {})
src/multi_agent_training/gradient_descent/training_loop.py ADDED
@@ -0,0 +1,570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradient Descent Training Loop
3
+ =============================
4
+
5
+ This module implements the main training loop that orchestrates gradient descent
6
+ optimization with backpropagation for the MangoMAS multi-agent system.
7
+
8
+ The training loop includes:
9
+ - Forward and backward passes
10
+ - Gradient computation and optimization
11
+ - Learning rate scheduling
12
+ - Comprehensive monitoring and logging
13
+ - Model checkpointing and validation
14
+ """
15
+
16
+ import logging
17
+ import time
18
+ import math
19
+ from typing import Dict, List, Optional, Tuple, Any
20
+ import torch
21
+ import torch.nn as nn
22
+ from pathlib import Path
23
+
24
+ from .optimizers import OptimizerFactory
25
+ from .backpropagation import BackpropagationEngine, LoRABackpropagationEngine
26
+ from .loss_functions import LossFunctionFactory
27
+ from .schedulers import SchedulerFactory
28
+ from .monitoring import GradientMonitor, TrainingMonitor, PerformanceMonitor
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ class GradientDescentTrainer:
34
+ """
35
+ Main training class that orchestrates gradient descent optimization
36
+
37
+ This class provides a complete training pipeline with:
38
+ - Real gradient descent and backpropagation
39
+ - Comprehensive monitoring and logging
40
+ - Model checkpointing and validation
41
+ - Integration with MangoMAS agent system
42
+ """
43
+
44
+ def __init__(self,
45
+ optimizer_type: str = 'adam',
46
+ learning_rate: float = 1e-3,
47
+ scheduler_type: str = 'cosine',
48
+ loss_function_type: str = 'cross_entropy',
49
+ device: torch.device = None,
50
+ max_grad_norm: float = 1.0,
51
+ gradient_accumulation_steps: int = 1,
52
+ mixed_precision: bool = False,
53
+ **kwargs):
54
+
55
+ self.optimizer_type = optimizer_type
56
+ self.learning_rate = learning_rate
57
+ self.scheduler_type = scheduler_type
58
+ self.loss_function_type = loss_function_type
59
+ self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
60
+ self.max_grad_norm = max_grad_norm
61
+ self.gradient_accumulation_steps = gradient_accumulation_steps
62
+ self.mixed_precision = mixed_precision
63
+
64
+ # Initialize components
65
+ self.optimizer = None
66
+ self.scheduler = None
67
+ self.loss_function = None
68
+ self.backprop_engine = None
69
+
70
+ # Monitoring
71
+ self.gradient_monitor = GradientMonitor()
72
+ self.training_monitor = TrainingMonitor()
73
+ self.performance_monitor = PerformanceMonitor()
74
+
75
+ # Training state
76
+ self.current_epoch = 0
77
+ self.current_step = 0
78
+ self.best_loss = float('inf')
79
+ self.training_start_time = None
80
+
81
+ # Configuration
82
+ self.config = {
83
+ 'optimizer_type': optimizer_type,
84
+ 'learning_rate': learning_rate,
85
+ 'scheduler_type': scheduler_type,
86
+ 'loss_function_type': loss_function_type,
87
+ 'max_grad_norm': max_grad_norm,
88
+ 'gradient_accumulation_steps': gradient_accumulation_steps,
89
+ 'mixed_precision': mixed_precision,
90
+ **kwargs
91
+ }
92
+
93
+ logger.info(f"Initialized GradientDescentTrainer with config: {self.config}")
94
+
95
    def setup_training(self, model: nn.Module, training_data: List[Dict[str, Any]]):
        """
        Setup training components

        Builds the optimizer, scheduler, loss function and backpropagation
        engine from their factories, using the names configured in __init__.

        Args:
            model: The neural network model to train
            training_data: Training dataset (its length seeds the scheduler)
        """
        logger.info("Setting up training components...")

        # Move model to device
        model.to(self.device)

        # Get trainable parameters
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        logger.info(f"Found {len(trainable_params)} trainable parameters")

        # Initialize optimizer
        optimizer_config = OptimizerFactory.get_default_config(self.optimizer_type)
        optimizer_config.update({'lr': self.learning_rate})

        self.optimizer = OptimizerFactory.create_optimizer(
            self.optimizer_type, trainable_params, **optimizer_config
        )

        # Initialize scheduler
        scheduler_config = SchedulerFactory.get_default_config(self.scheduler_type)
        # NOTE(review): total_steps is set to the number of training *examples*,
        # not optimizer steps (examples / batch_size * epochs) — confirm intent.
        # NOTE(review): 'plateau'/'warmup' schedulers do not accept total_steps,
        # so those scheduler_type values would raise here — verify.
        scheduler_config.update({'total_steps': len(training_data)})

        self.scheduler = SchedulerFactory.create_scheduler(
            self.scheduler_type, self.optimizer, **scheduler_config
        )

        # Initialize loss function
        loss_config = LossFunctionFactory.get_default_config(self.loss_function_type)
        self.loss_function = LossFunctionFactory.create_loss_function(
            self.loss_function_type, **loss_config
        )

        # Initialize backpropagation engine
        if hasattr(model, 'lora_params'):
            # LoRA model: engine that works on the model's LoRA parameters
            self.backprop_engine = LoRABackpropagationEngine(
                model, model.lora_params, self.device
            )
        else:
            # Standard model
            self.backprop_engine = BackpropagationEngine(model, self.device)

        logger.info("Training setup complete")
145
+
146
+ def train_epoch(self, model: nn.Module, training_data: List[Dict[str, Any]],
147
+ epoch: int) -> Dict[str, float]:
148
+ """
149
+ Train for one epoch using gradient descent and backpropagation
150
+
151
+ Args:
152
+ model: The neural network model
153
+ training_data: Training dataset
154
+ epoch: Current epoch number
155
+
156
+ Returns:
157
+ Dictionary of training metrics
158
+ """
159
+ logger.info(f"Starting epoch {epoch}")
160
+
161
+ model.train()
162
+ epoch_loss = 0.0
163
+ epoch_accuracy = 0.0
164
+ num_batches = 0
165
+
166
+ # Process training data in batches
167
+ batch_size = 32 # Default batch size
168
+ num_batches = math.ceil(len(training_data) / batch_size)
169
+
170
+ for batch_idx in range(num_batches):
171
+ start_idx = batch_idx * batch_size
172
+ end_idx = min(start_idx + batch_size, len(training_data))
173
+ batch_data = training_data[start_idx:end_idx]
174
+
175
+ # Process batch
176
+ batch_metrics = self.train_batch(model, batch_data, epoch, batch_idx)
177
+
178
+ epoch_loss += batch_metrics['loss']
179
+ epoch_accuracy += batch_metrics.get('accuracy', 0.0)
180
+
181
+ # Update step counter
182
+ self.current_step += 1
183
+
184
+ # Log progress
185
+ if batch_idx % 10 == 0:
186
+ logger.info(f"Epoch {epoch}, Batch {batch_idx}/{num_batches}, "
187
+ f"Loss: {batch_metrics['loss']:.4f}")
188
+
189
+ # Compute epoch averages
190
+ avg_loss = epoch_loss / num_batches
191
+ avg_accuracy = epoch_accuracy / num_batches if num_batches > 0 else 0.0
192
+
193
+ # Update monitors
194
+ self.training_monitor.update(
195
+ loss=avg_loss,
196
+ accuracy=avg_accuracy,
197
+ learning_rate=self.optimizer.lr,
198
+ epoch=epoch
199
+ )
200
+
201
+ # Update scheduler
202
+ self.scheduler.step(epoch=epoch, metrics={'loss': avg_loss})
203
+
204
+ logger.info(f"Epoch {epoch} complete - Loss: {avg_loss:.4f}, "
205
+ f"Accuracy: {avg_accuracy:.4f}, LR: {self.optimizer.lr:.6f}")
206
+
207
+ return {
208
+ 'loss': avg_loss,
209
+ 'accuracy': avg_accuracy,
210
+ 'learning_rate': self.optimizer.lr,
211
+ 'num_batches': num_batches
212
+ }
213
+
214
+ def train_batch(self, model: nn.Module, batch_data: List[Dict[str, Any]],
215
+ epoch: int, batch_idx: int) -> Dict[str, float]:
216
+ """
217
+ Train on a single batch using gradient descent and backpropagation
218
+
219
+ Args:
220
+ model: The neural network model
221
+ batch_data: Batch of training data
222
+ epoch: Current epoch number
223
+ batch_idx: Current batch index
224
+
225
+ Returns:
226
+ Dictionary of batch metrics
227
+ """
228
+ # Prepare batch data
229
+ inputs, targets = self._prepare_batch(batch_data)
230
+
231
+ # Forward pass
232
+ with self.performance_monitor.time_step('forward'):
233
+ outputs = model(inputs)
234
+
235
+ # Compute loss
236
+ loss = self.loss_function(outputs, targets)
237
+
238
+ # Scale loss for gradient accumulation
239
+ if self.gradient_accumulation_steps > 1:
240
+ loss = loss / self.gradient_accumulation_steps
241
+
242
+ # Backward pass
243
+ with self.performance_monitor.time_step('backward'):
244
+ loss.backward()
245
+
246
+ # Gradient accumulation
247
+ if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
248
+ # Apply gradient clipping
249
+ grad_norm = self.backprop_engine.apply_gradient_clipping(self.max_grad_norm)
250
+
251
+ # Get gradients for monitoring
252
+ gradients = self.backprop_engine.compute_gradients(loss, retain_graph=False)
253
+ self.gradient_monitor.update(gradients)
254
+
255
+ # Optimizer step
256
+ with self.performance_monitor.time_step('optimizer'):
257
+ self.optimizer.step()
258
+
259
+ # Zero gradients
260
+ self.optimizer.zero_grad()
261
+
262
+ # Update performance monitoring
263
+ self.performance_monitor.update_compute_time(time.time() - self.training_start_time)
264
+
265
+ # Compute accuracy (if applicable)
266
+ accuracy = self._compute_accuracy(outputs, targets)
267
+
268
+ return {
269
+ 'loss': loss.item() * self.gradient_accumulation_steps,
270
+ 'accuracy': accuracy,
271
+ 'grad_norm': grad_norm if 'grad_norm' in locals() else 0.0
272
+ }
273
+
274
+ def _prepare_batch(self, batch_data: List[Dict[str, Any]]) -> Tuple[torch.Tensor, torch.Tensor]:
275
+ """
276
+ Prepare batch data for training
277
+
278
+ Args:
279
+ batch_data: Raw batch data
280
+
281
+ Returns:
282
+ Tuple of (inputs, targets) tensors
283
+ """
284
+ # Extract inputs and targets
285
+ inputs = []
286
+ targets = []
287
+
288
+ for item in batch_data:
289
+ # Convert text to tokens (simplified)
290
+ if 'instruction' in item and 'response' in item:
291
+ # For text generation tasks
292
+ input_text = item['instruction']
293
+ target_text = item['response']
294
+
295
+ # Simple tokenization (in practice, use proper tokenizer)
296
+ input_tokens = self._simple_tokenize(input_text)
297
+ target_tokens = self._simple_tokenize(target_text)
298
+
299
+ inputs.append(input_tokens)
300
+ targets.append(target_tokens)
301
+
302
+ # Convert to tensors
303
+ if inputs and targets:
304
+ # Pad sequences to same length
305
+ max_len = max(len(seq) for seq in inputs + targets)
306
+ inputs = [seq + [0] * (max_len - len(seq)) for seq in inputs]
307
+ targets = [seq + [0] * (max_len - len(seq)) for seq in targets]
308
+
309
+ inputs_tensor = torch.tensor(inputs, dtype=torch.long, device=self.device)
310
+ targets_tensor = torch.tensor(targets, dtype=torch.long, device=self.device)
311
+ else:
312
+ # Fallback: create dummy data
313
+ batch_size = len(batch_data)
314
+ seq_len = 128
315
+ inputs_tensor = torch.randint(0, 1000, (batch_size, seq_len), device=self.device)
316
+ targets_tensor = torch.randint(0, 1000, (batch_size, seq_len), device=self.device)
317
+
318
+ return inputs_tensor, targets_tensor
319
+
320
+ def _simple_tokenize(self, text: str) -> List[int]:
321
+ """
322
+ Simple tokenization for demonstration
323
+
324
+ Args:
325
+ text: Input text
326
+
327
+ Returns:
328
+ List of token IDs
329
+ """
330
+ # Simple character-based tokenization
331
+ tokens = []
332
+ for char in text[:100]: # Limit length
333
+ tokens.append(ord(char) % 1000) # Map to vocabulary
334
+ return tokens
335
+
336
+ def _compute_accuracy(self, outputs: torch.Tensor, targets: torch.Tensor) -> float:
337
+ """
338
+ Compute accuracy for the batch
339
+
340
+ Args:
341
+ outputs: Model outputs
342
+ targets: Target values
343
+
344
+ Returns:
345
+ Accuracy score
346
+ """
347
+ if outputs.dim() > 1 and outputs.size(1) > 1:
348
+ # Classification task
349
+ predictions = torch.argmax(outputs, dim=1)
350
+ if targets.dim() == 1:
351
+ correct = (predictions == targets).float().sum()
352
+ accuracy = correct / targets.size(0)
353
+ else:
354
+ # Multi-label case
355
+ accuracy = 0.0
356
+ else:
357
+ # Regression task - use a simple threshold
358
+ accuracy = 0.0
359
+
360
+ return accuracy.item() if isinstance(accuracy, torch.Tensor) else accuracy
361
+
362
+ def validate(self, model: nn.Module, validation_data: List[Dict[str, Any]]) -> Dict[str, float]:
363
+ """
364
+ Validate the model
365
+
366
+ Args:
367
+ model: The neural network model
368
+ validation_data: Validation dataset
369
+
370
+ Returns:
371
+ Dictionary of validation metrics
372
+ """
373
+ logger.info("Running validation...")
374
+
375
+ model.eval()
376
+ total_loss = 0.0
377
+ total_accuracy = 0.0
378
+ num_batches = 0
379
+
380
+ with torch.no_grad():
381
+ batch_size = 32
382
+ num_batches = math.ceil(len(validation_data) / batch_size)
383
+
384
+ for batch_idx in range(num_batches):
385
+ start_idx = batch_idx * batch_size
386
+ end_idx = min(start_idx + batch_size, len(validation_data))
387
+ batch_data = validation_data[start_idx:end_idx]
388
+
389
+ # Prepare batch
390
+ inputs, targets = self._prepare_batch(batch_data)
391
+
392
+ # Forward pass
393
+ outputs = model(inputs)
394
+
395
+ # Compute loss
396
+ loss = self.loss_function(outputs, targets)
397
+ total_loss += loss.item()
398
+
399
+ # Compute accuracy
400
+ accuracy = self._compute_accuracy(outputs, targets)
401
+ total_accuracy += accuracy
402
+
403
+ # Compute averages
404
+ avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
405
+ avg_accuracy = total_accuracy / num_batches if num_batches > 0 else 0.0
406
+
407
+ logger.info(f"Validation - Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.4f}")
408
+
409
+ return {
410
+ 'val_loss': avg_loss,
411
+ 'val_accuracy': avg_accuracy
412
+ }
413
+
414
    def train(self, model: nn.Module, training_data: List[Dict[str, Any]],
             validation_data: Optional[List[Dict[str, Any]]] = None,
             num_epochs: int = 10, save_dir: Optional[str] = None) -> Dict[str, Any]:
        """
        Complete training loop with gradient descent and backpropagation

        Args:
            model: The neural network model to train
            training_data: Training dataset
            validation_data: Validation dataset (optional)
            num_epochs: Number of training epochs
            save_dir: Directory to save checkpoints (best model only, and only
                when validation_data is provided)

        Returns:
            Dictionary of training results: per-epoch histories, final
            metrics, monitor statistics and the trainer config.
        """
        logger.info(f"Starting training for {num_epochs} epochs")

        # Setup training (builds optimizer/scheduler/loss/backprop engine)
        self.setup_training(model, training_data)

        # Initialize training state
        self.training_start_time = time.time()
        self.current_epoch = 0
        self.current_step = 0

        # Training history
        training_history = []
        validation_history = []

        # Main training loop
        for epoch in range(num_epochs):
            self.current_epoch = epoch

            # Train epoch
            epoch_metrics = self.train_epoch(model, training_data, epoch)
            training_history.append(epoch_metrics)

            # Validation
            if validation_data:
                val_metrics = self.validate(model, validation_data)
                validation_history.append(val_metrics)

                # Update best loss
                if val_metrics['val_loss'] < self.best_loss:
                    self.best_loss = val_metrics['val_loss']

                    # Save best model only on improvement
                    if save_dir:
                        self.save_checkpoint(model, save_dir, epoch, val_metrics)

            # Check for convergence (early stopping via the training monitor)
            if self.training_monitor.detect_convergence():
                logger.info("Training converged, stopping early")
                break

            # Log epoch summary
            # (val_metrics is guarded by the same `if validation_data` test,
            # so it is always bound when read here)
            logger.info(f"Epoch {epoch} Summary:")
            logger.info(f" Training Loss: {epoch_metrics['loss']:.4f}")
            logger.info(f" Training Accuracy: {epoch_metrics['accuracy']:.4f}")
            if validation_data:
                logger.info(f" Validation Loss: {val_metrics['val_loss']:.4f}")
                logger.info(f" Validation Accuracy: {val_metrics['val_accuracy']:.4f}")
            logger.info(f" Learning Rate: {self.optimizer.lr:.6f}")

        # Training complete
        training_time = time.time() - self.training_start_time

        # Get final statistics from the monitors
        gradient_stats = self.gradient_monitor.get_statistics()
        training_stats = self.training_monitor.get_statistics()
        performance_stats = self.performance_monitor.get_statistics()

        results = {
            'training_history': training_history,
            'validation_history': validation_history,
            'final_metrics': {
                'best_loss': self.best_loss,
                'final_loss': training_history[-1]['loss'] if training_history else 0.0,
                'final_accuracy': training_history[-1]['accuracy'] if training_history else 0.0,
                'training_time': training_time,
                'total_steps': self.current_step,
                'total_epochs': self.current_epoch + 1
            },
            'gradient_stats': gradient_stats,
            'training_stats': training_stats,
            'performance_stats': performance_stats,
            'config': self.config
        }

        logger.info("Training complete!")
        logger.info(f"Final Loss: {results['final_metrics']['final_loss']:.4f}")
        logger.info(f"Best Loss: {results['final_metrics']['best_loss']:.4f}")
        logger.info(f"Training Time: {training_time:.2f} seconds")

        return results
510
+
511
+ def save_checkpoint(self, model: nn.Module, save_dir: str, epoch: int,
512
+ metrics: Dict[str, float]):
513
+ """
514
+ Save model checkpoint
515
+
516
+ Args:
517
+ model: The neural network model
518
+ save_dir: Directory to save checkpoint
519
+ epoch: Current epoch
520
+ metrics: Training metrics
521
+ """
522
+ save_path = Path(save_dir)
523
+ save_path.mkdir(parents=True, exist_ok=True)
524
+
525
+ checkpoint = {
526
+ 'epoch': epoch,
527
+ 'model_state_dict': model.state_dict(),
528
+ 'optimizer_state_dict': self.optimizer.state_dict(),
529
+ 'scheduler_state_dict': self.scheduler.state_dict(),
530
+ 'best_loss': self.best_loss,
531
+ 'metrics': metrics,
532
+ 'config': self.config
533
+ }
534
+
535
+ checkpoint_path = save_path / f'checkpoint_epoch_{epoch}.pt'
536
+ torch.save(checkpoint, checkpoint_path)
537
+
538
+ logger.info(f"Checkpoint saved to {checkpoint_path}")
539
+
540
+ def load_checkpoint(self, model: nn.Module, checkpoint_path: str):
541
+ """
542
+ Load model checkpoint
543
+
544
+ Args:
545
+ model: The neural network model
546
+ checkpoint_path: Path to checkpoint file
547
+ """
548
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
549
+
550
+ model.load_state_dict(checkpoint['model_state_dict'])
551
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
552
+ self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
553
+ self.best_loss = checkpoint['best_loss']
554
+
555
+ logger.info(f"Checkpoint loaded from {checkpoint_path}")
556
+
557
+ def get_training_summary(self) -> Dict[str, Any]:
558
+ """
559
+ Get comprehensive training summary
560
+
561
+ Returns:
562
+ Dictionary of training summary
563
+ """
564
+ return {
565
+ 'gradient_stats': self.gradient_monitor.get_statistics(),
566
+ 'training_stats': self.training_monitor.get_statistics(),
567
+ 'performance_stats': self.performance_monitor.get_statistics(),
568
+ 'anomalies': self.gradient_monitor.detect_anomalies(),
569
+ 'convergence': self.training_monitor.detect_convergence()
570
+ }
src/multi_agent_training/lora_trainer.py ADDED
@@ -0,0 +1,699 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LoRA Knowledge Distillation Trainer for MangoMAS Local
3
+
4
+ This module implements the main training loop for knowledge distillation
5
+ with LoRA fine-tuning optimized for Mac Mini hardware constraints.
6
+ """
7
+
8
+ import argparse
9
+ import json
10
+ import logging
11
+ import os
12
+ import sys
13
+ from datetime import datetime
14
+ from pathlib import Path
15
+ from typing import Dict, List
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import yaml
20
+ from peft import LoraConfig, TaskType, get_peft_model
21
+ from torch.utils.data import DataLoader
22
+ from torch.utils.tensorboard import SummaryWriter
23
+ from tqdm import tqdm
24
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
25
+ get_linear_schedule_with_warmup)
26
+
27
+ # Try to import context7 for enhanced training
28
+ try:
29
+ from context7 import Context7
30
+
31
+ CONTEXT7_AVAILABLE = True
32
+ except ImportError:
33
+ CONTEXT7_AVAILABLE = False
34
+ Context7 = None
35
+
36
+ # Try to import MLflow for experiment tracking
37
+ try:
38
+ import mlflow
39
+
40
+ MLFLOW_AVAILABLE = True
41
+ except ImportError:
42
+ MLFLOW_AVAILABLE = False
43
+ mlflow = None
44
+
45
+ # Fix import path issues for distillation loss
46
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
47
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
48
+
49
+ try:
50
+ from distillation_loss import AdaptiveDistillationLoss, DistillationLoss
51
+ except ImportError:
52
+ try:
53
+ from training.distillation_loss import (AdaptiveDistillationLoss,
54
+ DistillationLoss)
55
+ except ImportError:
56
+ # Fallback: create minimal distillation loss if not available
57
+ class DistillationLoss:
58
+ def __init__(self, alpha=0.5, temperature=2.0):
59
+ self.alpha = alpha
60
+ self.temperature = temperature
61
+ self.task_loss = nn.CrossEntropyLoss()
62
+
63
+ def compute_loss(
64
+ self, student_logits, teacher_logits, labels, attention_mask=None
65
+ ):
66
+ # Task loss (standard cross-entropy)
67
+ shift_logits = student_logits[..., :-1, :].contiguous()
68
+ shift_labels = labels[..., 1:].contiguous()
69
+ task_loss = self.task_loss(
70
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
71
+ )
72
+
73
+ # Distillation loss (KL divergence)
74
+ if teacher_logits is not None:
75
+ student_probs = nn.functional.log_softmax(
76
+ student_logits / self.temperature, dim=-1
77
+ )
78
+ teacher_probs = nn.functional.softmax(
79
+ teacher_logits / self.temperature, dim=-1
80
+ )
81
+ distill_loss = nn.functional.kl_div(
82
+ student_probs, teacher_probs, reduction="batchmean"
83
+ )
84
+ distill_loss *= self.temperature**2
85
+ else:
86
+ distill_loss = torch.tensor(0.0)
87
+
88
+ # Combined loss
89
+ total_loss = (1 - self.alpha) * task_loss + self.alpha * distill_loss
90
+
91
+ return total_loss, {
92
+ "total_loss": total_loss.item(),
93
+ "task_loss": task_loss.item(),
94
+ "distillation_loss": (
95
+ distill_loss.item()
96
+ if isinstance(distill_loss, torch.Tensor)
97
+ else 0.0
98
+ ),
99
+ }
100
+
101
+ AdaptiveDistillationLoss = DistillationLoss # Fallback
102
+
103
+ logger = logging.getLogger(__name__)
104
+
105
+
106
class ConversationDataset:
    """Dataset class for conversation-based training data.

    Loads JSONL conversation records eagerly and tokenizes them lazily in
    ``__getitem__``. Supports three record shapes: chat-style ``messages``
    lists, ``instruction``/``response`` pairs, and ``prompt``/``completion``
    pairs; anything else falls back to ``str(item)``.
    """

    def __init__(self, data_path: str, tokenizer, max_length: int = 512):
        """
        Args:
            data_path: Path to a JSONL file with one conversation per line.
            tokenizer: HuggingFace-style tokenizer (callable returning
                ``input_ids`` / ``attention_mask`` tensors).
            max_length: Maximum sequence length for truncation and padding.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path)

    def _load_data(self, data_path: str) -> List[Dict]:
        """Load conversation data from a JSONL file.

        Blank lines are skipped — previously a trailing newline (common in
        JSONL files) raised ``json.JSONDecodeError`` via ``json.loads("")``.
        """
        data = []
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Robustness fix: ignore empty / whitespace-only lines.
                    continue
                data.append(json.loads(line))
        return data

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get tokenized conversation item.

        Returns:
            Dict with 1-D ``input_ids`` / ``attention_mask`` / ``labels``
            tensors (labels are a clone of input_ids, causal-LM style) and
            the record's ``agent_type`` string (``"unknown"`` if absent).
        """
        item = self.data[idx]

        # Handle different data formats
        if "messages" in item:
            # Chat format with messages
            conversation_text = ""
            for message in item["messages"]:
                role = message["role"]
                content = message["content"]
                conversation_text += f"<{role}>\n{content}\n</{role}>\n\n"
        elif "instruction" in item and "response" in item:
            # Instruction-response format
            instruction = item["instruction"]
            response = item["response"]
            conversation_text = f"<user>\n{instruction}\n</user>\n\n<assistant>\n{response}\n</assistant>\n\n"
        elif "prompt" in item and "completion" in item:
            # Prompt-completion format
            prompt = item["prompt"]
            completion = item["completion"]
            conversation_text = f"<user>\n{prompt}\n</user>\n\n<assistant>\n{completion}\n</assistant>\n\n"
        else:
            # Fallback - try to extract any text
            conversation_text = str(item)

        # Tokenize
        encoding = self.tokenizer(
            conversation_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        return {
            "input_ids": encoding["input_ids"].squeeze(),
            "attention_mask": encoding["attention_mask"].squeeze(),
            "labels": encoding["input_ids"].squeeze().clone(),
            "agent_type": item.get("agent_type", "unknown"),
        }
166
+
167
+
168
class LoRADistillationTrainer:
    """Main trainer class for LoRA knowledge distillation.

    Orchestrates the full fine-tuning pipeline driven by a YAML config:
    device selection (MPS/CUDA/CPU), TensorBoard/MLflow monitoring,
    student-model loading with a LoRA adapter, teacher-logit distillation,
    per-epoch training/evaluation, and best-checkpoint saving.
    """

    def __init__(self, config_path: str):
        """Initialize trainer with configuration.

        Args:
            config_path: Path to the YAML training configuration file.
        """
        with open(config_path, "r") as f:
            self.config = yaml.safe_load(f)

        self.setup_logging()
        self.setup_device()
        self.setup_monitoring()

        logger.info("Initialized LoRA Distillation Trainer")
        logger.info(f"Device: {self.device}")
        logger.info(f"Config: {config_path}")

    def setup_logging(self) -> None:
        """Set up logging configuration (file + console handlers under ./logs)."""
        log_dir = Path("logs")
        log_dir.mkdir(exist_ok=True)

        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            handlers=[
                logging.FileHandler(log_dir / "training.log"),
                logging.StreamHandler(),
            ],
        )

    def setup_device(self) -> None:
        """Set up compute device (MPS for Mac Mini).

        Falls back to CPU when the configured backend is unavailable.
        """
        device_config = self.config["hardware"]["device"]

        if device_config == "mps" and torch.backends.mps.is_available():
            self.device = torch.device("mps")
            logger.info("Using Apple Metal Performance Shaders (MPS)")
        elif device_config == "cuda" and torch.cuda.is_available():
            self.device = torch.device("cuda")
            logger.info(f"Using CUDA: {torch.cuda.get_device_name()}")
        else:
            self.device = torch.device("cpu")
            logger.warning("Using CPU - training will be slow")

    def setup_monitoring(self) -> None:
        """Set up experiment tracking and monitoring.

        MLflow is disabled gracefully if the package is missing or its
        experiment API fails; TensorBoard writes to the configured log_dir.
        """
        self.use_tensorboard = self.config["monitoring"]["use_tensorboard"]
        self.use_mlflow = self.config["monitoring"]["use_mlflow"]

        if self.use_tensorboard:
            log_dir = self.config["monitoring"]["log_dir"]
            Path(log_dir).mkdir(parents=True, exist_ok=True)
            self.tb_writer = SummaryWriter(log_dir)
            logger.info(f"TensorBoard logging to: {log_dir}")

        if self.use_mlflow:
            try:
                # Local import intentionally shadows the module-level optional
                # import; raises ImportError here if mlflow is absent.
                import mlflow

                experiment_name = self.config["monitoring"]["experiment_name"]
                mlflow.set_experiment(experiment_name)
                logger.info(f"MLflow experiment: {experiment_name}")
            except (ImportError, AttributeError) as e:
                logger.warning(
                    f"MLflow not available or not properly initialized, disabling: {e}"
                )
                self.use_mlflow = False

    def load_models(self) -> None:
        """Load teacher and student models.

        Loads the tokenizer and student causal LM, wraps the student with a
        LoRA adapter, and constructs the TeacherModelManager.
        """
        # Load tokenizer
        model_name = self.config["models"]["student"]["base_model"]
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Add pad token if it doesn't exist
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load student model - fix deprecated torch_dtype
        # fp16 only on CUDA; MPS/CPU stay in fp32 for stability
        dtype = (
            torch.float16
            if self.config["optimization"]["use_fp16"] and self.device.type == "cuda"
            else torch.float32
        )

        self.student_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            dtype=dtype,  # Use dtype instead of torch_dtype
            device_map="auto" if self.device.type == "cuda" else None,
            trust_remote_code=True,
        )

        # Apply LoRA to student model - fix target modules for DialoGPT
        target_modules = self.config["lora"]["target_modules"]
        # If using default transformer modules but this is DialoGPT, adjust
        # NOTE(review): this rewrite triggers on the config value alone, not on
        # the actual model architecture — confirm when training non-DialoGPT models.
        if target_modules == ["q_proj", "v_proj", "k_proj", "o_proj"]:
            target_modules = ["c_attn", "c_proj", "c_fc"]  # DialoGPT modules
            logger.info("Adjusted LoRA target modules for DialoGPT architecture")

        lora_config = LoraConfig(
            r=self.config["lora"]["r"],
            lora_alpha=self.config["lora"]["lora_alpha"],
            target_modules=target_modules,
            lora_dropout=self.config["lora"]["lora_dropout"],
            bias=self.config["lora"]["bias"],
            task_type=TaskType.CAUSAL_LM,
        )

        self.student_model = get_peft_model(self.student_model, lora_config)
        self.student_model.to(self.device)

        # Setup teacher model
        self.teacher_manager = TeacherModelManager(
            self.config["models"]["teacher"], self.tokenizer
        )

        logger.info("Loaded student model with LoRA")
        logger.info(
            f"Trainable parameters: {self.student_model.num_parameters(only_trainable=True):,}"
        )
        logger.info("Loaded teacher model")

    def load_datasets(self, agent_type: str) -> tuple:
        """Load training and validation datasets for specific agent.

        Expects ``data/processed/{agent_type}_train.jsonl`` and
        ``..._validation.jsonl`` relative to the working directory.

        Raises:
            FileNotFoundError: If either JSONL file is missing.
        """
        data_dir = Path("data/processed")

        train_path = data_dir / f"{agent_type}_train.jsonl"
        val_path = data_dir / f"{agent_type}_validation.jsonl"

        if not train_path.exists():
            raise FileNotFoundError(f"Training data not found: {train_path}")
        if not val_path.exists():
            raise FileNotFoundError(f"Validation data not found: {val_path}")

        max_length = self.config["data"]["max_sequence_length"]

        train_dataset = ConversationDataset(train_path, self.tokenizer, max_length)
        val_dataset = ConversationDataset(val_path, self.tokenizer, max_length)

        logger.info(
            f"Loaded datasets: {len(train_dataset)} train, {len(val_dataset)} val"
        )

        return train_dataset, val_dataset

    def create_data_loaders(self, train_dataset, val_dataset) -> tuple:
        """Create data loaders for training and validation.

        Training loader shuffles and drops the last partial batch; the
        validation loader keeps order and all samples.
        """
        batch_size = self.config["training"]["batch_size"]
        num_workers = self.config["optimization"]["dataloader_num_workers"]
        pin_memory = self.config["optimization"]["pin_memory"]

        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=True,
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=False,
        )

        return train_loader, val_loader

    def setup_training(self, train_dataset_size: int) -> None:
        """Set up optimizer, scheduler, and loss function.

        Args:
            train_dataset_size: Number of training samples, used to derive
                the total optimizer-step count for the LR schedule.
        """
        # Calculate training steps
        batch_size = self.config["training"]["batch_size"]
        gradient_accumulation_steps = self.config["training"][
            "gradient_accumulation_steps"
        ]
        num_epochs = self.config["training"]["num_epochs"]

        # NOTE(review): integer division can yield 0 steps_per_epoch (and so a
        # zero-length LR schedule) when the dataset is smaller than
        # batch_size * gradient_accumulation_steps — confirm dataset sizes.
        steps_per_epoch = train_dataset_size // (
            batch_size * gradient_accumulation_steps
        )
        self.total_steps = steps_per_epoch * num_epochs

        # Setup optimizer
        self.optimizer = torch.optim.AdamW(
            self.student_model.parameters(),
            lr=self.config["training"]["learning_rate"],
            weight_decay=0.01,
        )

        # Setup scheduler
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.config["training"]["warmup_steps"],
            num_training_steps=self.total_steps,
        )

        # Setup loss function
        self.distill_loss = DistillationLoss(
            alpha=self.config["distillation"]["alpha"],
            temperature=self.config["distillation"]["temperature"],
        )

        logger.info(f"Setup training: {self.total_steps} total steps")

    def train_epoch(self, train_loader: DataLoader, epoch: int) -> Dict[str, float]:
        """Train for one epoch.

        Returns:
            Dict of epoch-averaged total/task/distillation losses.
        """
        self.student_model.train()

        total_loss = 0.0
        total_task_loss = 0.0
        total_distill_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}", disable=False)

        for batch_idx, batch in enumerate(progress_bar):
            # Move batch to device
            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device)
            labels = batch["labels"].to(self.device)

            # Get student outputs
            student_outputs = self.student_model(
                input_ids=input_ids, attention_mask=attention_mask
            )
            student_logits = student_outputs.logits

            # Get teacher outputs (no gradients needed for distillation targets)
            with torch.no_grad():
                teacher_logits = self.teacher_manager.get_logits(
                    input_ids, attention_mask
                )

            # Compute distillation loss
            loss, loss_dict = self.distill_loss.compute_loss(
                student_logits, teacher_logits, labels, attention_mask
            )

            # Backward pass with gradient accumulation
            loss = loss / self.config["training"]["gradient_accumulation_steps"]
            loss.backward()

            # Update model only every gradient_accumulation_steps batches
            if (batch_idx + 1) % self.config["training"][
                "gradient_accumulation_steps"
            ] == 0:
                torch.nn.utils.clip_grad_norm_(
                    self.student_model.parameters(),
                    self.config["training"]["max_grad_norm"],
                )
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            # Track metrics
            total_loss += loss_dict["total_loss"]
            total_task_loss += loss_dict["task_loss"]
            total_distill_loss += loss_dict["distillation_loss"]
            num_batches += 1

            # Update progress bar
            progress_bar.set_postfix(
                {
                    "loss": f"{loss_dict['total_loss']:.4f}",
                    "task": f"{loss_dict['task_loss']:.4f}",
                    "distill": f"{loss_dict['distillation_loss']:.4f}",
                }
            )

            # Log to tensorboard
            if (
                self.use_tensorboard
                and batch_idx % self.config["training"]["logging_steps"] == 0
            ):
                step = epoch * len(train_loader) + batch_idx
                self.tb_writer.add_scalar(
                    "train/total_loss", loss_dict["total_loss"], step
                )
                self.tb_writer.add_scalar(
                    "train/task_loss", loss_dict["task_loss"], step
                )
                self.tb_writer.add_scalar(
                    "train/distillation_loss", loss_dict["distillation_loss"], step
                )

        # Calculate epoch averages
        # NOTE(review): raises ZeroDivisionError when the loader yields no
        # batches (e.g. dataset smaller than batch_size with drop_last=True).
        epoch_metrics = {
            "avg_loss": total_loss / num_batches,
            "avg_task_loss": total_task_loss / num_batches,
            "avg_distill_loss": total_distill_loss / num_batches,
        }

        return epoch_metrics

    def evaluate(self, val_loader: DataLoader) -> Dict[str, float]:
        """Evaluate model on validation set.

        Returns:
            Dict of averaged validation total/task/distillation losses.
        """
        self.student_model.eval()

        total_loss = 0.0
        total_task_loss = 0.0
        total_distill_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Evaluating"):
                # Move batch to device
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)

                # Get model outputs
                student_outputs = self.student_model(
                    input_ids=input_ids, attention_mask=attention_mask
                )
                student_logits = student_outputs.logits

                # Get teacher outputs
                teacher_logits = self.teacher_manager.get_logits(
                    input_ids, attention_mask
                )

                # Compute loss
                loss, loss_dict = self.distill_loss.compute_loss(
                    student_logits, teacher_logits, labels, attention_mask
                )

                total_loss += loss_dict["total_loss"]
                total_task_loss += loss_dict["task_loss"]
                total_distill_loss += loss_dict["distillation_loss"]
                num_batches += 1

        # NOTE(review): ZeroDivisionError if the validation loader is empty.
        val_metrics = {
            "val_loss": total_loss / num_batches,
            "val_task_loss": total_task_loss / num_batches,
            "val_distill_loss": total_distill_loss / num_batches,
        }

        return val_metrics

    def save_model(self, output_dir: str, agent_type: str, epoch: int) -> None:
        """Save model checkpoint.

        Writes the LoRA adapter, tokenizer, and the training config to
        ``{output_dir}/{agent_type}/epoch_{epoch}``.
        """
        output_path = Path(output_dir) / agent_type / f"epoch_{epoch}"
        output_path.mkdir(parents=True, exist_ok=True)

        # Save LoRA adapter
        self.student_model.save_pretrained(output_path)

        # Save tokenizer
        self.tokenizer.save_pretrained(output_path)

        # Save training config
        config_path = output_path / "training_config.yaml"
        with open(config_path, "w") as f:
            yaml.dump(self.config, f)

        logger.info(f"Saved model to: {output_path}")

    def train_agent(self, agent_type: str) -> None:
        """Train a specific agent with knowledge distillation.

        Runs the full epoch loop for one agent, logging to MLflow/TensorBoard
        when enabled and saving a checkpoint whenever validation loss improves.
        """
        logger.info(f"Starting training for {agent_type} agent")

        # Load models if not already loaded (reused across agents)
        if not hasattr(self, "student_model"):
            self.load_models()

        # Load datasets
        train_dataset, val_dataset = self.load_datasets(agent_type)
        train_loader, val_loader = self.create_data_loaders(train_dataset, val_dataset)

        # Setup training components
        self.setup_training(len(train_dataset))

        # Start MLflow run
        if self.use_mlflow:
            mlflow.start_run(
                run_name=f"{agent_type}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
            )
            mlflow.log_params(
                {
                    "agent_type": agent_type,
                    "model_name": self.config["models"]["student"]["base_model"],
                    "lora_r": self.config["lora"]["r"],
                    "lora_alpha": self.config["lora"]["lora_alpha"],
                    "batch_size": self.config["training"]["batch_size"],
                    "learning_rate": self.config["training"]["learning_rate"],
                    "distillation_alpha": self.config["distillation"]["alpha"],
                    "temperature": self.config["distillation"]["temperature"],
                }
            )

        try:
            # Training loop
            best_val_loss = float("inf")
            num_epochs = self.config["training"]["num_epochs"]

            for epoch in range(num_epochs):
                logger.info(f"Epoch {epoch+1}/{num_epochs}")

                # Train
                train_metrics = self.train_epoch(train_loader, epoch)
                logger.info(
                    f"Train - Loss: {train_metrics['avg_loss']:.4f}, "
                    f"Task: {train_metrics['avg_task_loss']:.4f}, "
                    f"Distill: {train_metrics['avg_distill_loss']:.4f}"
                )

                # Evaluate
                val_metrics = self.evaluate(val_loader)
                logger.info(
                    f"Val - Loss: {val_metrics['val_loss']:.4f}, "
                    f"Task: {val_metrics['val_task_loss']:.4f}, "
                    f"Distill: {val_metrics['val_distill_loss']:.4f}"
                )

                # Log to MLflow
                if self.use_mlflow:
                    mlflow.log_metrics({**train_metrics, **val_metrics}, step=epoch)

                # Log to TensorBoard
                if self.use_tensorboard:
                    for key, value in train_metrics.items():
                        self.tb_writer.add_scalar(f"epoch/{key}", value, epoch)
                    for key, value in val_metrics.items():
                        self.tb_writer.add_scalar(f"epoch/{key}", value, epoch)

                # Save checkpoint if best model
                if val_metrics["val_loss"] < best_val_loss:
                    best_val_loss = val_metrics["val_loss"]
                    self.save_model(
                        self.config["output"]["base_dir"], agent_type, epoch
                    )
                    logger.info(f"New best model saved (val_loss: {best_val_loss:.4f})")

        finally:
            # Always close the MLflow run, even if training raised
            if self.use_mlflow:
                mlflow.end_run()

        logger.info(f"Training completed for {agent_type} agent")
609
+
610
+
611
class TeacherModelManager:
    """Manages teacher model interactions (API or local).

    For ``type == "api"`` the manager currently returns *random* logits as a
    placeholder — real API integration is still TODO. For any other type, a
    local HuggingFace causal LM is loaded and queried directly.
    """

    def __init__(self, teacher_config: Dict, tokenizer):
        """
        Args:
            teacher_config: Teacher section of the training config; must
                contain ``type`` and, for local teachers, optionally
                ``local_model_path``.
            tokenizer: Tokenizer shared with the student (used for vocab size
                in the API placeholder).
        """
        self.config = teacher_config
        self.tokenizer = tokenizer

        if teacher_config["type"] == "api":
            self.setup_api_teacher()
        else:
            self.setup_local_teacher()

    def setup_api_teacher(self) -> None:
        """Set up API-based teacher model."""
        self.model_name = self.config["model_name"]
        logger.info(f"Using API teacher model: {self.model_name}")

        # This would integrate with OpenAI/Anthropic APIs
        # For now, we'll use a placeholder that returns random logits
        # In production, you'd implement actual API calls here

    def setup_local_teacher(self) -> None:
        """Set up local teacher model."""
        model_path = self.config.get("local_model_path", "microsoft/DialoGPT-large")

        # Consistency fix: use ``dtype`` — the student loader in this file
        # already migrated away from the deprecated ``torch_dtype`` keyword.
        self.teacher_model = AutoModelForCausalLM.from_pretrained(
            model_path, dtype=torch.float16, device_map="auto"
        )
        logger.info(f"Loaded local teacher model: {model_path}")

    def get_logits(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Get teacher model logits.

        Returns:
            Logits tensor of shape (batch, seq_len, vocab_size) on the same
            device as ``input_ids``.
        """
        if self.config["type"] == "api":
            # Placeholder for API-based teacher
            # In practice, you'd call the API and convert responses to logits
            # NOTE(review): tokenizer.vocab_size may differ from the student's
            # logits dimension if extra tokens were added — confirm alignment.
            batch_size, seq_len = input_ids.shape
            vocab_size = self.tokenizer.vocab_size
            return torch.randn(batch_size, seq_len, vocab_size).to(input_ids.device)
        else:
            # Local teacher model; no gradients needed for distillation targets
            with torch.no_grad():
                outputs = self.teacher_model(
                    input_ids=input_ids, attention_mask=attention_mask
                )
                return outputs.logits
658
+
659
+
660
def main():
    """CLI entry point: parse arguments and run distillation training."""
    parser = argparse.ArgumentParser(
        description="Train MangoMAS agents with LoRA and knowledge distillation"
    )
    parser.add_argument(
        "--config",
        type=str,
        default="config/training/distillation.yaml",
        help="Path to training configuration file",
    )
    parser.add_argument(
        "--agent",
        type=str,
        choices=["infrastructure", "devsecops", "risk_assessment", "all"],
        default="all",
        help="Which agent to train",
    )
    parser.add_argument("--data", type=str, help="Path to training data file")

    args = parser.parse_args()

    # Build the trainer from the requested configuration
    trainer = LoRADistillationTrainer(args.config)

    # Record an explicit data override, if one was supplied
    if args.data:
        trainer.custom_data_path = args.data

    # Resolve "all" into the full agent roster, otherwise train just one
    roster = ["infrastructure", "devsecops", "risk_assessment"]
    selected = roster if args.agent == "all" else [args.agent]

    for agent_type in selected:
        trainer.train_agent(agent_type)


if __name__ == "__main__":
    main()
src/multi_agent_training/modular_distillation_loss.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Refactored Knowledge Distillation Loss using modular architecture.
3
+
4
+ This module implements a clean, testable loss function that follows the interface contracts
5
+ and provides better separation of concerns.
6
+ """
7
+
8
+ import logging
9
+ from typing import Any, Dict, Optional
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from ..core.base_components import BaseLossFunction
16
+ from ..core.exceptions import TrainingError
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class ModularDistillationLoss(BaseLossFunction):
    """
    Modular distillation loss function implementing clean interface contracts.

    Combines a next-token cross-entropy task loss with a temperature-scaled
    KL-divergence distillation loss as ``alpha * task + (1 - alpha) * distill``.

    NOTE(review): ``alpha`` weights the *task* term here, while the fallback
    ``DistillationLoss`` in lora_trainer.py weights the *distillation* term
    with ``alpha`` — the two conventions are opposite; confirm which is
    intended before swapping implementations.
    """

    def __init__(
        self,
        alpha: float = 0.5,
        temperature: float = 3.0,
        task_loss_fn: Optional[nn.Module] = None,
    ):
        """
        Initialize the modular distillation loss.

        Args:
            alpha: Balance between task loss and distillation loss (0.0-1.0)
            temperature: Temperature for softmax in knowledge distillation
            task_loss_fn: Custom task loss function (defaults to CrossEntropyLoss)

        Raises:
            ValueError: If alpha is outside [0, 1] or temperature is not positive.
        """
        super().__init__({"alpha": alpha, "temperature": temperature})

        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"Alpha must be between 0 and 1, got {alpha}")
        if temperature <= 0.0:
            raise ValueError(f"Temperature must be positive, got {temperature}")

        self.alpha = alpha
        self.temperature = temperature
        # ignore_index=-100 matches the HuggingFace convention for masked labels
        self.task_loss_fn = task_loss_fn or nn.CrossEntropyLoss(ignore_index=-100)
        self.kl_loss_fn = nn.KLDivLoss(reduction="batchmean")

        logger.info(
            f"Initialized ModularDistillationLoss with alpha={alpha}, temperature={temperature}"
        )

    def compute(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the combined distillation loss.

        Args:
            student_logits: Logits from student model
            teacher_logits: Logits from teacher model
            labels: Target labels for task loss
            attention_mask: Attention mask for valid tokens

        Returns:
            Combined loss tensor

        Raises:
            TrainingError: If any underlying computation fails (original
                exception message is embedded; chaining context is implicit).
        """
        try:
            # Compute task loss (if labels provided)
            task_loss = torch.tensor(0.0, device=student_logits.device)
            if labels is not None:
                task_loss = self._compute_task_loss(
                    student_logits, labels, attention_mask
                )

            # Compute distillation loss
            distill_loss = self._compute_distillation_loss(
                student_logits, teacher_logits, attention_mask
            )

            # Combined loss (alpha weights the TASK term — see class note)
            total_loss = self.alpha * task_loss + (1 - self.alpha) * distill_loss

            # Track metrics via the BaseLossFunction bookkeeping hooks
            metrics = {
                "task_loss": task_loss.item(),
                "distillation_loss": distill_loss.item(),
                "total_loss": total_loss.item(),
                "alpha": self.alpha,
                "temperature": self.temperature,
            }
            self._track_metrics(metrics)
            self._track_loss(total_loss.item())

            return total_loss

        except Exception as e:
            raise TrainingError(
                f"Loss computation failed: {str(e)}", "LOSS_COMPUTATION_ERROR"
            )

    def _compute_task_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute standard cross-entropy task loss (next-token shifted)."""
        # Reshape for cross entropy: [batch_size * seq_len, vocab_size]
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Flatten the tokens
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_labels = shift_labels.view(-1)

        # Apply attention mask if provided
        if attention_mask is not None:
            shift_attention_mask = attention_mask[..., 1:].contiguous()
            flat_attention_mask = shift_attention_mask.view(-1)

            # Only compute loss on non-masked tokens
            active_logits = flat_logits[flat_attention_mask == 1]
            active_labels = flat_labels[flat_attention_mask == 1]

            if active_logits.size(0) > 0:
                return self.task_loss_fn(active_logits, active_labels)
            else:
                # Fully-masked batch: return zero rather than NaN
                return torch.tensor(0.0, device=logits.device)

        return self.task_loss_fn(flat_logits, flat_labels)

    def _compute_distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute knowledge distillation loss using KL divergence."""
        # Apply temperature scaling
        student_soft = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_soft = F.softmax(teacher_logits / self.temperature, dim=-1)

        # Compute KL divergence
        if attention_mask is not None:
            # Mask out padding tokens
            # NOTE(review): multiplying the log-probabilities by the mask
            # zeroes masked positions (log-prob 0 == probability 1) rather
            # than excluding them, and 'batchmean' still averages over the
            # full batch — the masked KL value is therefore approximate;
            # confirm this is acceptable before relying on exact magnitudes.
            mask = attention_mask.unsqueeze(-1).expand_as(student_soft)
            student_soft_masked = student_soft * mask
            teacher_soft_masked = teacher_soft * mask

            # Compute loss only on valid tokens
            kl_loss = self.kl_loss_fn(student_soft_masked, teacher_soft_masked)
        else:
            kl_loss = self.kl_loss_fn(student_soft, teacher_soft)

        # Scale by temperature squared (as per distillation literature)
        return kl_loss * (self.temperature**2)

    def update_alpha(self, new_alpha: float) -> None:
        """Update the alpha parameter for dynamic loss weighting.

        Raises:
            ValueError: If new_alpha is outside [0, 1].
        """
        if not 0.0 <= new_alpha <= 1.0:
            raise ValueError(f"Alpha must be between 0 and 1, got {new_alpha}")

        old_alpha = self.alpha
        self.alpha = new_alpha
        # Keep the config dict (held by BaseLossFunction) in sync
        self.config["alpha"] = new_alpha

        logger.info(f"Updated alpha from {old_alpha} to {new_alpha}")

    def update_temperature(self, new_temperature: float) -> None:
        """Update the temperature parameter for dynamic distillation.

        Raises:
            ValueError: If new_temperature is not strictly positive.
        """
        if new_temperature <= 0.0:
            raise ValueError(f"Temperature must be positive, got {new_temperature}")

        old_temperature = self.temperature
        self.temperature = new_temperature
        # Keep the config dict in sync
        self.config["temperature"] = new_temperature

        logger.info(f"Updated temperature from {old_temperature} to {new_temperature}")
188
+
189
+
190
class AdaptiveDistillationLoss(ModularDistillationLoss):
    """
    Adaptive distillation loss that adjusts alpha based on training progress.

    Supported strategies: "linear_decay", "cosine_decay", "step_decay";
    anything else leaves alpha unchanged.
    """

    def __init__(
        self,
        alpha: float = 0.5,
        temperature: float = 3.0,
        adaptation_strategy: str = "linear_decay",
        adaptation_config: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize adaptive distillation loss.

        Args:
            alpha: Initial alpha value
            temperature: Temperature for distillation
            adaptation_strategy: Strategy for adapting alpha
                ("linear_decay", "cosine_decay", "step_decay")
            adaptation_config: Configuration for the adaptation strategy
                (e.g. "total_steps", "decay_steps", "decay_factor")
        """
        super().__init__(alpha, temperature)

        self.initial_alpha = alpha
        self.adaptation_strategy = adaptation_strategy
        self.adaptation_config = adaptation_config or {}
        self.step_count = 0

        logger.info(
            f"Initialized AdaptiveDistillationLoss with strategy: {adaptation_strategy}"
        )

    def compute(self, *args, **kwargs) -> torch.Tensor:
        """Compute loss with adaptive alpha adjustment.

        Alpha is updated *before* the step counter is incremented, so the
        first call uses the initial alpha.
        """
        # Update alpha based on training progress
        self._update_alpha_adaptive()

        # Increment step count
        self.step_count += 1

        return super().compute(*args, **kwargs)

    def _update_alpha_adaptive(self) -> None:
        """Update alpha based on the selected adaptation strategy."""
        total_steps = self.adaptation_config.get("total_steps", 1000)

        if self.adaptation_strategy == "linear_decay":
            # Linearly decay alpha from initial value to 0
            progress = min(self.step_count / total_steps, 1.0)
            new_alpha = self.initial_alpha * (1.0 - progress)

        elif self.adaptation_strategy == "cosine_decay":
            # Cosine decay from initial value to 0
            import math

            progress = min(self.step_count / total_steps, 1.0)
            new_alpha = self.initial_alpha * (1 + math.cos(math.pi * progress)) / 2

        elif self.adaptation_strategy == "step_decay":
            # Step decay at specified intervals
            decay_steps = self.adaptation_config.get("decay_steps", [500, 750])
            decay_factor = self.adaptation_config.get("decay_factor", 0.5)

            new_alpha = self.initial_alpha
            for decay_step in decay_steps:
                if self.step_count >= decay_step:
                    new_alpha *= decay_factor
        else:
            # No adaptation
            new_alpha = self.alpha

        # Update if changed significantly
        if abs(new_alpha - self.alpha) > 1e-6:
            self.alpha = new_alpha
            # Bug fix: keep the inherited config dict in sync, mirroring what
            # update_alpha() does for manual updates — previously the adaptive
            # path left config["alpha"] stale.
            self.config["alpha"] = new_alpha
src/multi_agent_training/modular_trainer.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Refactored LoRA Knowledge Distillation Trainer using modular architecture.
3
+
4
+ This module implements a clean, testable trainer that follows the interface contracts
5
+ and provides better separation of concerns.
6
+ """
7
+
8
+ import logging
9
+ from pathlib import Path
10
+ from typing import Any, Dict, Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.utils.data import DataLoader
15
+
16
+ from ..core.base_components import BaseTrainer
17
+ from ..core.exceptions import TrainingError
18
+ from ..core.interfaces import TrainingConfig
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
class ModularLoRATrainer(BaseTrainer):
    """Modular LoRA trainer with clean separation of concerns.

    Loss computation is delegated to an injected loss-function object
    (expected to provide ``compute`` and ``get_metrics``), and an optional
    frozen teacher model enables knowledge distillation.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        loss_function,
        device: str = "cpu",
        teacher_model: Optional[nn.Module] = None,
    ):
        """
        Initialize the modular LoRA trainer.

        Args:
            model: Student model to train
            optimizer: Optimizer for training
            loss_function: Loss function implementing ILossFunction
            device: Training device
            teacher_model: Optional teacher model for distillation
        """
        super().__init__(model, optimizer, device)
        self.loss_function = loss_function
        self.teacher_model = teacher_model
        if self.teacher_model:
            # The teacher is frozen: move it to the device and keep it in
            # eval mode so dropout/batch-norm behave deterministically.
            self.teacher_model.to(self.device)
            self.teacher_model.eval()

        self.custom_loss_fn = None
        # Epoch currently being trained. Previously this attribute was read
        # in the failure context but never assigned, so errors always
        # reported epoch 0.
        self.current_epoch = 0

    def set_custom_loss_fn(self, loss_fn):
        """Set custom loss function for specialized training."""
        self.custom_loss_fn = loss_fn

    def compute_distillation_loss(self, student_outputs, teacher_outputs, batch):
        """Compute standard distillation loss from student/teacher logits."""
        return self.loss_function.compute(
            student_outputs.logits,
            (
                teacher_outputs.logits
                if hasattr(teacher_outputs, "logits")
                else teacher_outputs
            ),
            labels=batch.get("labels"),
        )

    def train(self, dataloader: DataLoader, config: TrainingConfig) -> Dict[str, Any]:
        """
        Train the model with the given configuration.

        Args:
            dataloader: Training data loader
            config: Training configuration

        Returns:
            Training results and metrics

        Raises:
            TrainingError: If any exception occurs during training; the
                original exception is chained as ``__cause__``.
        """
        try:
            self.model.train()
            total_loss = 0.0
            num_batches = 0
            training_metrics = {}

            for epoch in range(config.num_epochs):
                # Keep failure context accurate if training dies mid-epoch.
                self.current_epoch = epoch
                epoch_loss = 0.0
                epoch_batches = 0

                for batch_idx, batch in enumerate(dataloader):
                    # Move batch to device
                    batch = self._move_batch_to_device(batch)

                    # Forward pass
                    self.optimizer.zero_grad()

                    # Student model forward pass
                    student_outputs = self.model(**batch)

                    # Teacher model forward pass (if available)
                    teacher_outputs = None
                    if self.teacher_model:
                        with torch.no_grad():
                            teacher_outputs = self.teacher_model(**batch)

                    # Compute loss: a custom loss takes precedence when set.
                    if self.custom_loss_fn:
                        loss = self.custom_loss_fn(
                            student_outputs, teacher_outputs, batch
                        )
                    else:
                        loss = self.loss_function.compute(
                            (
                                student_outputs.logits
                                if hasattr(student_outputs, "logits")
                                else student_outputs
                            ),
                            batch.get("labels", batch.get("input_ids")),
                        )

                    # Backward pass
                    loss.backward()
                    self.optimizer.step()

                    # Track metrics
                    epoch_loss += loss.item()
                    epoch_batches += 1

                    # Log training step. NOTE(review): uses save_steps as the
                    # logging interval — confirm a separate logging_steps was
                    # not intended in TrainingConfig.
                    if batch_idx % config.save_steps == 0:
                        step_metrics = self.loss_function.get_metrics()
                        self._log_training_step(
                            epoch, batch_idx, loss.item(), step_metrics
                        )

                        logger.info(
                            f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}"
                        )

                # End of epoch
                avg_epoch_loss = (
                    epoch_loss / epoch_batches if epoch_batches > 0 else 0.0
                )
                total_loss += epoch_loss
                num_batches += epoch_batches

                training_metrics[f"epoch_{epoch}_loss"] = avg_epoch_loss

                # Save checkpoint
                if epoch % config.save_steps == 0:
                    checkpoint_path = (
                        Path(config.output_dir) / f"checkpoint_epoch_{epoch}.pt"
                    )
                    self.save_checkpoint(checkpoint_path, epoch)

            # Final results
            avg_loss = total_loss / num_batches if num_batches > 0 else 0.0

            results = {
                "average_loss": avg_loss,
                "total_epochs": config.num_epochs,
                "total_batches": num_batches,
                "training_metrics": training_metrics,
                "loss_function_metrics": self.loss_function.get_metrics(),
            }

            logger.info(f"Training completed. Average loss: {avg_loss:.4f}")
            return results

        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise TrainingError(
                f"Training failed: {str(e)}",
                "TRAINING_FAILED",
                {"epoch": getattr(self, "current_epoch", 0)},
            ) from e

    def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
        """
        Evaluate the model on the given dataset.

        Args:
            dataloader: Evaluation data loader

        Returns:
            Evaluation metrics

        Raises:
            TrainingError: If evaluation fails; the original exception is
                chained as ``__cause__``.
        """
        try:
            self.model.eval()
            total_loss = 0.0
            num_batches = 0

            with torch.no_grad():
                for batch in dataloader:
                    batch = self._move_batch_to_device(batch)

                    # Forward pass
                    outputs = self.model(**batch)

                    # Compute loss
                    loss = self.loss_function.compute(
                        outputs.logits if hasattr(outputs, "logits") else outputs,
                        batch.get("labels", batch.get("input_ids")),
                    )

                    total_loss += loss.item()
                    num_batches += 1

            avg_loss = total_loss / num_batches if num_batches > 0 else 0.0

            results = {"eval_loss": avg_loss, "eval_batches": num_batches}
            results.update(self.loss_function.get_metrics())

            logger.info(f"Evaluation completed. Average loss: {avg_loss:.4f}")
            return results

        except Exception as e:
            raise TrainingError(
                f"Evaluation failed: {str(e)}", "EVALUATION_FAILED"
            ) from e

    def _move_batch_to_device(
        self, batch: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Move all tensor values in ``batch`` to the training device."""
        device_batch = {}
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                device_batch[key] = value.to(self.device)
            else:
                # Non-tensor entries (e.g. metadata strings) pass through.
                device_batch[key] = value
        return device_batch
src/multi_agent_training/multi_agent_trainer.py ADDED
@@ -0,0 +1,547 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ CPU-Optimized Multi-Agent Trainer
4
+
5
+ This module provides comprehensive multi-agent training capabilities optimized for CPU execution,
6
+ including LoRA fine-tuning, agent-specific conditioning, and integration with existing training infrastructure.
7
+ """
8
+
9
+ import os
10
+ import json
11
+ import math
12
+ import random
13
+ import logging
14
+ from pathlib import Path
15
+ from typing import Dict, List, Optional, Tuple, Any, Union
16
+ from dataclasses import dataclass, field
17
+
18
+ import torch
19
+ import yaml
20
+ from datasets import DatasetDict, Dataset
21
+ from transformers import (
22
+ AutoModelForCausalLM,
23
+ AutoTokenizer,
24
+ TrainingArguments,
25
+ Trainer,
26
+ DataCollatorForLanguageModeling
27
+ )
28
+ from trl import SFTTrainer
29
+ from peft import LoraConfig, get_peft_model, TaskType
30
+ from huggingface_hub import HfApi, create_repo
31
+
32
+ from ..multi_agent_datasets.multi_agent_loader import MultiAgentDatasetLoader, MultiAgentDatasetConfig
33
+ from ..multi_agent_tokenization.agent_tokenizer import AgentTokenManager, AgentTokenConfig, AgentTokenizer
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
@dataclass
class MultiAgentTrainingConfig:
    """Configuration for multi-agent training.

    Plain dataclass consumed by CPUOptimizedMultiAgentTrainer; every field
    has a CPU-friendly default so the config can be constructed with only
    the values that differ per run.
    """
    # Model configuration
    base_model: str = "microsoft/Phi-3.5-MoE-instruct"  # HF model id or local path
    model_cache_dir: Optional[str] = None  # passed as cache_dir to from_pretrained
    trust_remote_code: bool = True  # required for Phi-3.5-MoE custom code

    # Training configuration
    output_dir: str = "./outputs"
    max_steps: int = 50  # hard step cap; overrides num_train_epochs when reached
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 1
    per_device_eval_batch_size: int = 1
    gradient_accumulation_steps: int = 8  # effective batch = 1 * 8 on CPU
    learning_rate: float = 2e-5
    lr_scheduler_type: str = "cosine"
    warmup_steps: int = 0

    # LoRA configuration
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_target_modules: str = "all-linear"  # PEFT shorthand: adapt every linear layer
    lora_bias: str = "none"

    # CPU optimization: float32 training, no mixed precision, single-process
    # data loading, gradient checkpointing to trade compute for memory.
    use_cpu: bool = True
    bf16: bool = False
    fp16: bool = False
    gradient_checkpointing: bool = True
    dataloader_num_workers: int = 0
    remove_unused_columns: bool = False

    # Multi-agent specific: agent names are wrapped as e.g. "<|agent:SWE|>"
    agent_prefix: str = "<|agent:"
    agent_suffix: str = "|>"
    balance_agents: bool = False  # optionally downsample over-represented agents
    balance_cap: Optional[int] = None  # max samples per agent when balancing

    # Logging and monitoring
    logging_steps: int = 5
    save_steps: int = 50
    eval_steps: int = 25  # 0 disables step-based evaluation
    save_total_limit: int = 1
    logging_dir: str = "./logs"
    report_to: str = "none"  # no external experiment trackers by default

    # Hugging Face Hub
    hub_repo_id: Optional[str] = None
    push_to_hub: bool = False
    hub_token: Optional[str] = None

    # Dataset configuration; when None a MultiAgentDatasetConfig is built
    # from the fields above at load time.
    dataset_config: Optional[MultiAgentDatasetConfig] = None
92
+
93
class CPUOptimizedMultiAgentTrainer:
    """
    CPU-optimized multi-agent trainer with LoRA fine-tuning.

    Orchestrates model/tokenizer loading, agent-token setup, dataset
    processing, LoRA configuration, SFT training, and artifact publishing.
    """

    def __init__(self, config: MultiAgentTrainingConfig):
        self.config = config
        self.tokenizer: Optional[AutoTokenizer] = None
        self.model: Optional[torch.nn.Module] = None
        self.agent_manager: Optional[AgentTokenManager] = None
        self.dataset_loader: Optional[MultiAgentDatasetLoader] = None
        self.trainer: Optional[SFTTrainer] = None
        self.agents: List[str] = []
        self.dataset_stats: Dict[str, Any] = {}

        # Setup logging
        self._setup_logging()

    def _setup_logging(self):
        """Setup logging configuration"""
        # Create the logs directory BEFORE configuring handlers: FileHandler
        # opens its file eagerly, so it raises FileNotFoundError when the
        # directory does not exist yet (previously makedirs ran too late).
        os.makedirs(self.config.logging_dir, exist_ok=True)

        log_level = logging.INFO
        logging.basicConfig(
            level=log_level,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.StreamHandler(),
                logging.FileHandler(os.path.join(self.config.logging_dir, 'training.log'))
            ]
        )

    def load_model_and_tokenizer(self) -> Tuple[AutoTokenizer, torch.nn.Module]:
        """Load model and tokenizer optimized for CPU"""
        logger.info(f"Loading model and tokenizer: {self.config.base_model}")

        # Load tokenizer
        tokenizer_kwargs = {
            "trust_remote_code": self.config.trust_remote_code,
            "cache_dir": self.config.model_cache_dir
        }

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.base_model,
            **tokenizer_kwargs
        )

        # Configure tokenizer for CPU training
        self.tokenizer.model_max_length = 2048
        if self.tokenizer.pad_token is None:
            # Prefer unk over eos as pad so eos stays a meaningful stop token.
            self.tokenizer.pad_token = self.tokenizer.unk_token or self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)
        self.tokenizer.padding_side = "right"

        # Load model with CPU optimizations
        model_kwargs = {
            "trust_remote_code": self.config.trust_remote_code,
            "torch_dtype": torch.float32,  # Use float32 for CPU
            "device_map": "cpu",
            "attn_implementation": "eager",  # Force CPU-compatible attention
            "use_cache": False,  # Disable cache for training
            "cache_dir": self.config.model_cache_dir
        }

        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.base_model,
            **model_kwargs
        )

        logger.info(f"Model loaded with {self.model.num_parameters():,} parameters")
        return self.tokenizer, self.model

    def setup_agent_tokens(self, agents: List[str]) -> AgentTokenManager:
        """Setup agent token management.

        Adds one special token per agent to the tokenizer and resizes the
        model embedding matrix to match the grown vocabulary.
        """
        logger.info(f"Setting up agent tokens for {len(agents)} agents")

        agent_config = AgentTokenConfig(
            agent_prefix=self.config.agent_prefix,
            agent_suffix=self.config.agent_suffix,
            resize_embeddings=True
        )

        self.agent_manager = AgentTokenManager(agent_config)

        # Add agent tokens to tokenizer
        self.tokenizer, agent_tokens = self.agent_manager.add_agent_tokens_to_tokenizer(
            self.tokenizer, agents
        )

        # Resize model embeddings
        self.model = self.agent_manager.resize_model_embeddings(self.model, self.tokenizer)

        logger.info(f"Agent tokens setup complete. Tokens: {agent_tokens}")
        return self.agent_manager

    def load_dataset(self, dataset_path: str) -> Tuple[DatasetDict, List[str], Dict[str, Any]]:
        """Load and process multi-agent dataset.

        Returns the processed DatasetDict, the discovered agent names, and
        dataset statistics; the latter two are also cached on the instance.
        """
        logger.info(f"Loading dataset from: {dataset_path}")

        # Create dataset configuration
        if self.config.dataset_config is None:
            dataset_config = MultiAgentDatasetConfig(
                dataset_path=dataset_path,
                agent_prefix=self.config.agent_prefix,
                agent_suffix=self.config.agent_suffix,
                balance_agents=self.config.balance_agents,
                balance_cap=self.config.balance_cap
            )
        else:
            # Honor a user-supplied config, but point it at the given path.
            dataset_config = self.config.dataset_config
            dataset_config.dataset_path = dataset_path

        # Create dataset loader
        self.dataset_loader = MultiAgentDatasetLoader(dataset_config)

        # Load and process dataset
        dataset, agents, stats = self.dataset_loader.load_and_process(self.tokenizer)

        self.agents = agents
        self.dataset_stats = stats

        logger.info(f"Dataset loaded: {len(agents)} agents, {stats['total_samples']} samples")
        return dataset, agents, stats

    def create_lora_config(self) -> LoraConfig:
        """Create LoRA configuration optimized for CPU"""
        logger.info("Creating LoRA configuration")

        lora_config = LoraConfig(
            r=self.config.lora_r,
            lora_alpha=self.config.lora_alpha,
            lora_dropout=self.config.lora_dropout,
            bias=self.config.lora_bias,
            task_type=TaskType.CAUSAL_LM,
            target_modules=self.config.lora_target_modules
        )

        logger.info(f"LoRA config: r={lora_config.r}, alpha={lora_config.lora_alpha}, dropout={lora_config.lora_dropout}")
        return lora_config

    def create_training_arguments(self) -> TrainingArguments:
        """Create training arguments optimized for CPU"""
        logger.info("Creating training arguments")

        training_args = TrainingArguments(
            output_dir=self.config.output_dir,
            overwrite_output_dir=True,
            num_train_epochs=self.config.num_train_epochs,
            max_steps=self.config.max_steps,
            per_device_train_batch_size=self.config.per_device_train_batch_size,
            per_device_eval_batch_size=self.config.per_device_eval_batch_size,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            learning_rate=self.config.learning_rate,
            lr_scheduler_type=self.config.lr_scheduler_type,
            warmup_steps=self.config.warmup_steps,

            # CPU optimizations
            bf16=self.config.bf16,
            fp16=self.config.fp16,
            gradient_checkpointing=self.config.gradient_checkpointing,
            dataloader_num_workers=self.config.dataloader_num_workers,
            remove_unused_columns=self.config.remove_unused_columns,

            # Logging and saving
            logging_steps=self.config.logging_steps,
            save_steps=self.config.save_steps,
            eval_steps=self.config.eval_steps,
            save_total_limit=self.config.save_total_limit,
            logging_dir=self.config.logging_dir,
            report_to=self.config.report_to,

            # Evaluation
            evaluation_strategy="steps" if self.config.eval_steps > 0 else "no",

            # Optimization
            optim="adamw_torch",
            weight_decay=0.01,
            max_grad_norm=1.0,

            # Hub integration
            push_to_hub=self.config.push_to_hub,
            hub_model_id=self.config.hub_repo_id,
            hub_token=self.config.hub_token,
        )

        logger.info(f"Training arguments created: {training_args.output_dir}")
        return training_args

    def create_trainer(self, dataset: DatasetDict, lora_config: LoraConfig, training_args: TrainingArguments) -> SFTTrainer:
        """Create SFT trainer for multi-agent training"""
        logger.info("Creating SFT trainer")

        # Get training and evaluation datasets
        train_dataset = dataset["train"]
        eval_dataset = dataset.get("test", None)

        # Create trainer.
        # NOTE(review): tokenizer/max_seq_length/dataset_text_field/packing
        # kwargs match older TRL releases — confirm against the pinned TRL
        # version, newer releases moved these into SFTConfig.
        self.trainer = SFTTrainer(
            model=self.model,
            args=training_args,
            peft_config=lora_config,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=self.tokenizer,
            max_seq_length=2048,
            dataset_text_field="text",
            packing=True,  # Enable packing for efficiency
            data_collator=None,  # Use default
        )

        logger.info("SFT trainer created successfully")
        return self.trainer

    def train(self) -> Dict[str, Any]:
        """Execute training process and return the generated report.

        Raises:
            ValueError: If create_trainer() has not been called yet.
        """
        logger.info("Starting training process")

        if self.trainer is None:
            raise ValueError("Trainer not initialized. Call create_trainer() first.")

        # Start training
        training_result = self.trainer.train()

        # Save model and tokenizer
        self.save_model()

        # Save agent tokens
        if self.agent_manager:
            self.agent_manager.save_agent_tokens(self.config.output_dir)

        # Generate training report
        report = self.generate_training_report(training_result)

        logger.info("Training completed successfully")
        return report

    def save_model(self):
        """Save trained model and tokenizer.

        Raises:
            ValueError: If create_trainer() has not been called yet.
        """
        # Fail with the same explicit message as train() instead of an
        # opaque AttributeError on self.trainer.model.
        if self.trainer is None:
            raise ValueError("Trainer not initialized. Call create_trainer() first.")

        logger.info(f"Saving model to {self.config.output_dir}")

        os.makedirs(self.config.output_dir, exist_ok=True)

        # Save model
        self.trainer.model.save_pretrained(self.config.output_dir)

        # Save tokenizer
        self.tokenizer.save_pretrained(self.config.output_dir)

        # Save training configuration (default=str serializes non-JSON types)
        config_file = os.path.join(self.config.output_dir, "training_config.json")
        with open(config_file, 'w') as f:
            json.dump(self.config.__dict__, f, indent=2, default=str)

        logger.info("Model saved successfully")

    def generate_training_report(self, training_result: Any) -> Dict[str, Any]:
        """Generate comprehensive training report and write it to disk."""
        report = {
            "training_config": self.config.__dict__,
            "dataset_stats": self.dataset_stats,
            "agents": self.agents,
            "agent_tokens": self.agent_manager.get_agent_statistics() if self.agent_manager else {},
            "training_metrics": {
                "train_loss": getattr(training_result, 'train_loss', None),
                "train_runtime": getattr(training_result, 'train_runtime', None),
                "train_samples_per_second": getattr(training_result, 'train_samples_per_second', None),
                "train_steps_per_second": getattr(training_result, 'train_steps_per_second', None),
            },
            "model_info": {
                "base_model": self.config.base_model,
                "num_parameters": self.model.num_parameters() if self.model else None,
                "vocab_size": len(self.tokenizer) if self.tokenizer else None,
            }
        }

        # Save report
        report_file = os.path.join(self.config.output_dir, "training_report.json")
        with open(report_file, 'w') as f:
            json.dump(report, f, indent=2, default=str)

        logger.info(f"Training report saved to {report_file}")
        return report

    def push_to_hub(self, repo_id: Optional[str] = None, commit_message: str = "Multi-agent LoRA adapter"):
        """Push trained model to Hugging Face Hub.

        Raises:
            ValueError: If no repository id or hub token is configured.
        """
        if not self.config.push_to_hub:
            logger.info("Push to hub disabled")
            return

        repo_id = repo_id or self.config.hub_repo_id
        if not repo_id:
            raise ValueError("Repository ID not specified")

        if not self.config.hub_token:
            raise ValueError("Hub token not provided")

        logger.info(f"Pushing model to Hub: {repo_id}")

        # Create repository
        create_repo(repo_id, repo_type="model", exist_ok=True, token=self.config.hub_token)

        # Upload model (adapter weights, configs, and docs only)
        api = HfApi(token=self.config.hub_token)
        api.upload_folder(
            folder_path=self.config.output_dir,
            repo_id=repo_id,
            repo_type="model",
            commit_message=commit_message,
            allow_patterns=["*.json", "*.md", "*.bin", "*.yaml", "*.txt"]
        )

        logger.info(f"Model pushed to https://huggingface.co/{repo_id}")

    def create_readme(self) -> str:
        """Create README for the trained model and return its path."""
        readme_content = f"""# Multi-Agent LoRA Adapter for {self.config.base_model}

## Overview
This is a LoRA (Low-Rank Adaptation) adapter trained for multi-agent scenarios using {self.config.base_model}.

## Agent Conditioning Tokens
This adapter expects agent-specific tokens to condition the model behavior:

"""

        if self.agents:
            for agent in self.agents:
                token = f"{self.config.agent_prefix}{agent}{self.config.agent_suffix}"
                readme_content += f"- `{token}` - {agent} agent\n"

            readme_content += f"""
## Usage Example

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Load base model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("{self.config.base_model}")
model = AutoModelForCausalLM.from_pretrained("{self.config.base_model}")

# Load LoRA adapter
model = PeftModel.from_pretrained(model, "{self.config.hub_repo_id}")

# Example usage
prompt = "How do I implement a binary search algorithm?"
agent_token = "{self.config.agent_prefix}SWE{self.config.agent_suffix}\\n"
full_prompt = agent_token + prompt

inputs = tokenizer(full_prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
```

## Training Configuration
- **Base Model**: {self.config.base_model}
- **LoRA Rank**: {self.config.lora_r}
- **LoRA Alpha**: {self.config.lora_alpha}
- **Learning Rate**: {self.config.learning_rate}
- **Max Steps**: {self.config.max_steps}
- **Batch Size**: {self.config.per_device_train_batch_size}

## Dataset Statistics
- **Total Samples**: {self.dataset_stats.get('total_samples', 'N/A')}
- **Agents**: {', '.join(self.agents) if self.agents else 'N/A'}

## License
This model is released under the same license as the base model.
"""
        else:
            readme_content += "No specific agents were configured for this adapter.\n"

        # Save README
        readme_file = os.path.join(self.config.output_dir, "README.md")
        with open(readme_file, 'w') as f:
            f.write(readme_content)

        logger.info(f"README created: {readme_file}")
        return readme_file
474
+
475
class MultiAgentTrainingPipeline:
    """End-to-end orchestration of the multi-agent training workflow.

    Wires together model loading, dataset processing, agent-token setup,
    LoRA configuration, training, and optional Hub publishing.
    """

    def __init__(self, config: MultiAgentTrainingConfig):
        self.config = config
        self.trainer = CPUOptimizedMultiAgentTrainer(config)

    def run_training(self, dataset_path: str) -> Dict[str, Any]:
        """Execute every pipeline stage and return the training report."""
        logger.info("Starting multi-agent training pipeline")

        try:
            worker = self.trainer

            # Stage 1: base model and tokenizer.
            worker.load_model_and_tokenizer()

            # Stage 2: dataset loading and agent discovery.
            dataset, agents, stats = worker.load_dataset(dataset_path)

            # Stage 3: agent conditioning tokens.
            worker.setup_agent_tokens(agents)

            # Stage 4: LoRA and training configuration.
            peft_cfg = worker.create_lora_config()
            train_args = worker.create_training_arguments()

            # Stage 5: trainer construction and the actual training run.
            worker.create_trainer(dataset, peft_cfg, train_args)
            report = worker.train()

            # Stage 6: documentation artifact and optional publishing.
            worker.create_readme()
            if self.config.push_to_hub:
                worker.push_to_hub()

            logger.info("Training pipeline completed successfully")
            return report

        except Exception as e:
            logger.error(f"Training pipeline failed: {e}")
            raise
523
+
524
# Example usage and testing
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Small smoke-test configuration (not a real training run).
    demo_config = MultiAgentTrainingConfig(
        base_model="microsoft/Phi-3.5-MoE-instruct",
        output_dir="./outputs/multi_agent_test",
        max_steps=10,  # Small number for testing
        hub_repo_id="test/multi-agent-adapter",
        push_to_hub=False,  # Set to True for actual deployment
    )

    demo_pipeline = MultiAgentTrainingPipeline(demo_config)

    try:
        # An actual run would call:
        # demo_pipeline.run_training("/path/to/dataset")
        print("Multi-agent training pipeline ready")
    except Exception as e:
        print(f"Error: {e}")
src/multi_agent_training/specialized/curiosity_module.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Curiosity Training Module for MangoMAS Local
3
+
4
+ This module implements specialized training for curiosity and exploration capabilities,
5
+ adapted from the AWS backup system for local training.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ import os
11
+ import random
12
+ from typing import Any, Dict, List
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from torch.utils.data import Dataset
18
+
19
+ from ..core_framework import SpecializedTrainingModule, TrainingModuleConfig
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
class CuriosityDataset(Dataset):
    """Dataset for training curiosity and exploration capabilities.

    Each example is a JSONL record with "scenario", "curiosity_questions",
    and "exploration_directions" fields; blank lines are skipped and
    malformed or non-conforming records are dropped with a warning.
    """

    def __init__(self, data_path: str, tokenizer, max_length: int = 768):
        """
        Initialize the curiosity dataset.

        Args:
            data_path: Path to the curiosity data file (JSONL)
            tokenizer: Tokenizer for text processing
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path)

        logger.info(f"Loaded curiosity dataset with {len(self.data)} examples")

    def _load_data(self, data_path: str) -> List[Dict]:
        """Load and validate curiosity training data from a JSONL file."""
        data = []
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Skip blank lines quietly instead of logging them as
                    # JSON parse errors.
                    continue
                try:
                    item = json.loads(line)
                    # Validate required fields for curiosity data; the
                    # isinstance check prevents a TypeError on valid-JSON
                    # lines that are not objects (e.g. bare numbers).
                    if (
                        isinstance(item, dict)
                        and "scenario" in item
                        and "curiosity_questions" in item
                        and "exploration_directions" in item
                    ):
                        data.append(item)
                except (json.JSONDecodeError, KeyError) as e:
                    logger.warning(f"Skipping invalid curiosity data: {e}")
        return data

    def __len__(self) -> int:
        """Return the number of examples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a training example as tokenized tensors plus raw fields."""
        item = self.data[idx]

        # Format the curiosity example
        scenario = item["scenario"]
        curiosity_questions = item["curiosity_questions"]
        exploration_directions = item["exploration_directions"]

        # Format as text: scenario first, then numbered questions and
        # exploration directions.
        text = f"Scenario: {scenario}\n\n"

        text += "Curiosity Questions:\n"
        for i, question in enumerate(curiosity_questions):
            text += f"{i+1}. {question}\n"
        text += "\n"

        text += "Exploration Directions:\n"
        for i, direction in enumerate(exploration_directions):
            text += f"{i+1}. {direction}\n"

        # Tokenize
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        return {
            "input_ids": encoding["input_ids"].squeeze(),
            "attention_mask": encoding["attention_mask"].squeeze(),
            # Causal LM objective: labels mirror the inputs.
            "labels": encoding["input_ids"].squeeze().clone(),
            "scenario": scenario,
            "curiosity_questions": curiosity_questions,
            "exploration_directions": exploration_directions,
        }
102
+
103
+
104
class CuriosityEvaluator:
    """Heuristic evaluator for curiosity and exploration capabilities.

    Scores model generations on four axes — question diversity, exploration
    breadth, uncertainty identification and assumption challenging — each
    accumulated per example in ``self.metrics`` and normalized to [0, 1].
    """

    def __init__(self, tokenizer):
        """
        Initialize the curiosity evaluator.

        Args:
            tokenizer: Tokenizer for text processing
        """
        self.tokenizer = tokenizer
        # Running totals; evaluate() resets them and divides by the example count.
        self.metrics = {
            "question_diversity": 0.0,
            "exploration_breadth": 0.0,
            "uncertainty_identification": 0.0,
            "assumption_challenging": 0.0,
        }

        # Question starters used to measure the variety of question types.
        self.question_starters = [
            "what",
            "how",
            "why",
            "when",
            "where",
            "who",
            "which",
            "could",
            "would",
            "is",
            "are",
            "do",
            "does",
            "have",
            "has",
        ]

        # Markers signalling that the text explores alternatives/angles.
        self.exploration_markers = [
            "alternative",
            "perspective",
            "consider",
            "explore",
            "investigate",
            "possibility",
            "approach",
            "angle",
            "viewpoint",
            "scenario",
        ]

        # Phrases signalling that uncertainty is being acknowledged.
        self.uncertainty_phrases = [
            "unclear",
            "unknown",
            "uncertain",
            "not sure",
            "ambiguous",
            "might be",
            "could be",
            "possibly",
            "perhaps",
            "may",
        ]

        # Phrases signalling that assumptions are being surfaced/challenged.
        self.assumption_phrases = [
            "assuming",
            "assumption",
            "presuppose",
            "presupposition",
            "take for granted",
            "implicit",
            "unstated",
            "underlying",
        ]

    def evaluate(self, model, eval_dataset: "CuriosityDataset") -> Dict[str, float]:
        """
        Evaluate curiosity capabilities on the provided dataset.

        Args:
            model: The model to evaluate
            eval_dataset: Dataset of curiosity examples

        Returns:
            Dictionary of evaluation metrics (all 0.0 for an empty dataset)
        """
        # Reset metrics before accumulating.
        for key in self.metrics:
            self.metrics[key] = 0.0

        # Limit to 50 examples for efficiency.
        total_examples = min(len(eval_dataset), 50)
        # Guard: an empty dataset would otherwise divide by zero below.
        if total_examples == 0:
            return self.metrics

        model.eval()
        device = next(model.parameters()).device

        with torch.no_grad():
            for idx in range(total_examples):
                example = eval_dataset[idx]
                scenario = example["scenario"]

                # Ask the model to generate questions for the scenario.
                prompt = f"Scenario: {scenario}\n\nGenerate curious questions to explore this further:"
                input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(
                    device
                )

                # NOTE(review): `temperature` only takes effect when sampling
                # (do_sample=True); with default greedy decoding it is ignored
                # by HF generate — confirm the intended decoding mode.
                generated_ids = model.generate(
                    input_ids,
                    max_length=256,
                    temperature=0.8,  # slightly higher temperature for creativity
                    num_return_sequences=1,
                )

                generated_text = self.tokenizer.decode(
                    generated_ids[0], skip_special_tokens=True
                )

                # Score this generation and accumulate into self.metrics.
                self._evaluate_curiosity(
                    scenario=scenario,
                    expected_questions=example["curiosity_questions"],
                    expected_directions=example["exploration_directions"],
                    generated_text=generated_text,
                )

        # Convert accumulated totals into per-example averages.
        for key in self.metrics:
            self.metrics[key] /= total_examples

        return self.metrics

    def _evaluate_curiosity(
        self,
        scenario: str,
        expected_questions: List[str],
        expected_directions: List[str],
        generated_text: str,
    ) -> None:
        """
        Score one generation and accumulate the result into ``self.metrics``.

        Args:
            scenario: The scenario to explore (currently unused; reserved for
                future reference-based scoring)
            expected_questions: Expected curiosity questions (currently unused)
            expected_directions: Expected exploration directions (currently unused)
            generated_text: The text generated by the model
        """
        lowered = generated_text.lower()

        # Extract questions: any stripped line that ends with "?".
        generated_questions = [
            line.strip()
            for line in generated_text.split("\n")
            if line.strip().endswith("?")
        ]
        if not generated_questions:
            # Fall back to numbered lines that contain a question mark anywhere.
            for line in generated_text.split("\n"):
                if any(f"{i}." in line for i in range(1, 10)) and "?" in line:
                    generated_questions.append(line.strip())

        # 1. Question diversity — how many distinct question types appear.
        starter_counts = {starter: 0 for starter in self.question_starters}
        for question in generated_questions:
            question_lower = question.lower()
            for starter in self.question_starters:
                # NOTE(review): prefix match also hits longer words ("what"
                # matches "whatever"); kept as-is to preserve score behavior.
                if (
                    question_lower.startswith(starter)
                    or f" {starter} " in question_lower
                ):
                    starter_counts[starter] += 1

        unique_starters = sum(1 for count in starter_counts.values() if count > 0)
        # Five distinct starter types earn full marks.
        self.metrics["question_diversity"] += min(1.0, unique_starters / 5)

        # 2. Exploration breadth — presence of exploration markers (3 = full marks).
        exploration_marker_count = sum(
            1 for marker in self.exploration_markers if marker in lowered
        )
        self.metrics["exploration_breadth"] += min(1.0, exploration_marker_count / 3)

        # 3. Uncertainty identification — uncertainty phrases (2 = full marks).
        uncertainty_phrase_count = sum(
            1 for phrase in self.uncertainty_phrases if phrase in lowered
        )
        self.metrics["uncertainty_identification"] += min(
            1.0, uncertainty_phrase_count / 2
        )

        # 4. Assumption challenging — any single phrase earns full marks.
        assumption_phrase_count = sum(
            1 for phrase in self.assumption_phrases if phrase in lowered
        )
        self.metrics["assumption_challenging"] += min(1.0, assumption_phrase_count / 1)
302
+
303
+
304
class CuriosityTrainingModule(SpecializedTrainingModule):
    """Specialized training module for curiosity and exploration capabilities."""

    def __init__(self, config: TrainingModuleConfig, tokenizer):
        """
        Initialize the curiosity training module.

        Args:
            config: Module configuration
            tokenizer: Tokenizer for text processing
        """
        super().__init__(config, tokenizer)

        # Initialize curiosity-specific components.
        self.data_path = config.data_path or "data/processed/curiosity_train.jsonl"
        self.evaluator = CuriosityEvaluator(tokenizer)

        # Curiosity-specific loss with a higher distillation temperature.
        self.curiosity_temp = config.module_config.get("temperature", 1.5)
        self.curiosity_loss = nn.CrossEntropyLoss(ignore_index=-100)

        # Training metrics reported via get_metrics().
        self.metrics = {
            "curiosity_loss": 0.0,
            "question_generation_score": 0.0,
            "exploration_score": 0.0,
        }

        logger.info(
            f"Initialized curiosity training module with temperature: {self.curiosity_temp}"
        )

    def prepare_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Prepare a batch of data for curiosity training.

        Args:
            batch: The input batch from the dataloader

        Returns:
            Processed batch ready for curiosity training
        """
        # Already a curiosity-specific batch? Pass it through unchanged.
        if all(
            key in batch
            for key in ["scenario", "curiosity_questions", "exploration_directions"]
        ):
            return batch

        # For general conversation batches, we could extract potential
        # exploration scenarios. This is a simplified placeholder implementation.
        return batch

    def compute_loss(
        self, student_outputs: Any, teacher_outputs: Any, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute the curiosity-specific distillation loss.

        Combines KL-divergence distillation (weighted higher to encourage
        exploration) with cross-entropy against the labels.

        Args:
            student_outputs: Outputs from the student model
            teacher_outputs: Outputs from the teacher model
            batch: The processed input batch (must contain "labels")

        Returns:
            Curiosity-specific loss tensor
        """
        # Accept either a HF ModelOutput (with .logits) or a raw tensor.
        student_logits = (
            student_outputs.logits
            if hasattr(student_outputs, "logits")
            else student_outputs
        )
        teacher_logits = (
            teacher_outputs.logits
            if hasattr(teacher_outputs, "logits")
            else teacher_outputs
        )

        # Shift for next-token prediction: logits at t predict token t+1.
        student_logits = student_logits[:, :-1, :].contiguous()
        teacher_logits = teacher_logits[:, :-1, :].contiguous()
        target_ids = batch["labels"][:, 1:].contiguous()

        # Higher temperature softens both distributions, encouraging the
        # student to match the teacher's full (more diverse) distribution.
        temperature = self.curiosity_temp
        kl_loss = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.softmax(teacher_logits / temperature, dim=-1),
            reduction="batchmean",
        ) * (temperature**2)

        # Cross-entropy loss against labels.
        # NOTE(review): labels appear to include pad tokens (no -100 masking is
        # visible upstream), so padding contributes to this term — confirm intended.
        ce_loss = self.curiosity_loss(
            student_logits.view(-1, student_logits.size(-1)), target_ids.view(-1)
        )

        # KL is weighted higher than CE to emphasize exploration.
        loss = 0.4 * ce_loss + 0.6 * kl_loss

        # Record the scalar for reporting.
        self.metrics["curiosity_loss"] = loss.item()

        return loss

    def get_metrics(self) -> Dict[str, float]:
        """
        Get metrics specific to curiosity training.

        Returns:
            Dictionary of metric names and values
        """
        return self.metrics

    def generate_synthetic_curiosity_data(
        self, output_path: str, num_samples: int = 1000
    ) -> None:
        """
        Generate synthetic curiosity training data as JSONL.

        Args:
            output_path: Path to save the generated data
            num_samples: Number of samples to generate
        """
        # Local imports: only needed by this generator utility.
        import copy
        from datetime import datetime, timezone

        # This is a simplified implementation based on the AWS backup.
        # In a full implementation, this would be much more sophisticated.

        curiosity_templates = [
            {
                "scenario": "A company is developing a new voice assistant technology.",
                "curiosity_questions": [
                    "How might this technology affect people's privacy in their homes?",
                    "What unexpected ways might users interact with this technology?",
                    "How could this technology evolve over the next five years?",
                    "What ethical considerations might arise from widespread adoption?",
                    "How might this technology affect different demographic groups differently?",
                ],
                "exploration_directions": [
                    "Consider alternative interaction models beyond voice commands",
                    "Explore potential integration with other smart home systems",
                    "Investigate privacy-preserving design approaches",
                    "Consider accessibility implications for diverse user groups",
                    "Examine potential unintended consequences of ambient listening",
                ],
            },
            {
                "scenario": "Scientists have discovered a new species of deep-sea organism that can survive extreme pressure and temperature.",
                "curiosity_questions": [
                    "What adaptations allow this organism to survive such extreme conditions?",
                    "Could these adaptations be applied to human technology or medicine?",
                    "What might this discovery tell us about the possibility of life on other planets?",
                    "How might climate change affect deep-sea ecosystems and this organism?",
                    "What other undiscovered species might exist in similar environments?",
                ],
                "exploration_directions": [
                    "Examine evolutionary pathways for extreme environment adaptation",
                    "Consider biomimicry applications in engineering and materials science",
                    "Explore implications for astrobiology and extraterrestrial life",
                    "Investigate ecological relationships in extreme environments",
                    "Consider ethical dimensions of deep-sea exploration and bioprospecting",
                ],
            },
            {
                "scenario": "A small town is experiencing rapid population growth due to remote workers relocating from urban areas.",
                "curiosity_questions": [
                    "How might this demographic shift affect the town's culture and community?",
                    "What infrastructure challenges might arise from rapid population growth?",
                    "How could this trend impact local housing prices and affordability?",
                    "What economic opportunities and challenges might emerge?",
                    "How might long-term residents and newcomers develop different perspectives?",
                ],
                "exploration_directions": [
                    "Investigate similar historical population shifts and their outcomes",
                    "Consider varying perspectives from different stakeholder groups",
                    "Explore potential policy approaches to manage growth sustainably",
                    "Examine social integration mechanisms between established and new residents",
                    "Consider environmental impacts of changing land use patterns",
                ],
            },
        ]

        # Generate variations.
        output_data = []
        for _ in range(num_samples):
            template = random.choice(curiosity_templates)

            # Deep copy so variations never alias the template's nested lists
            # (a shallow copy would share the question/direction lists).
            variation = copy.deepcopy(template)

            # Add metadata.
            # Bug fix: the "timestamp" field previously stored the CUDA device
            # name; it now records an actual UTC timestamp, with the device
            # recorded separately under "device".
            variation["metadata"] = {
                "generated": True,
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "device": (
                    torch.cuda.get_device_name(0)
                    if torch.cuda.is_available()
                    else "CPU"
                ),
                "requires_exploration": True,
            }

            output_data.append(variation)

        # Save to file. Only create directories when the path has a parent;
        # os.makedirs("") raises for bare filenames.
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            for item in output_data:
                f.write(json.dumps(item) + "\n")

        logger.info(
            f"Generated {len(output_data)} synthetic curiosity examples at {output_path}"
        )
src/multi_agent_training/specialized/empathy_module.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Empathy Training Module for MangoMAS Local
3
+
4
+ This module implements specialized training for empathy and emotional intelligence,
5
+ adapted from the AWS backup system for local training.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ import os
11
+ import random
12
+ from typing import Any, Dict, List
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from torch.utils.data import Dataset
18
+
19
+ from ..core_framework import SpecializedTrainingModule, TrainingModuleConfig
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
class EmpathyDataset(Dataset):
    """Dataset for training empathy and emotional intelligence capabilities."""

    def __init__(self, data_path: str, tokenizer, max_length: int = 768):
        """
        Initialize the empathy dataset.

        Args:
            data_path: Path to the empathy data file
            tokenizer: Tokenizer for text processing
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path)

        logger.info(f"Loaded empathy dataset with {len(self.data)} examples")

    def _load_data(self, data_path: str) -> List[Dict]:
        """Read JSONL records, keeping only those with the required empathy fields."""
        required_fields = ("user_message", "emotional_state", "empathetic_response")
        records: List[Dict] = []
        with open(data_path, "r", encoding="utf-8") as handle:
            for raw_line in handle:
                try:
                    record = json.loads(raw_line.strip())
                except (json.JSONDecodeError, KeyError) as e:
                    logger.warning(f"Skipping invalid empathy data: {e}")
                    continue
                if all(field in record for field in required_fields):
                    records.append(record)
        return records

    def __len__(self) -> int:
        """Return the number of examples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Format, tokenize and return the training example at *idx*."""
        record = self.data[idx]

        user_message = record["user_message"]
        emotional_state = record["emotional_state"]
        empathetic_response = record["empathetic_response"]

        # Optional fields enrich the text when present.
        emotional_cues = record.get("emotional_cues", [])
        context = record.get("context", "")

        # Assemble the training text with empathy markers.
        pieces = [
            f"User: {user_message}\n\n",
            f"Emotional State: {emotional_state}\n",
        ]
        if emotional_cues:
            pieces.append("Emotional Cues:\n")
            for number, cue in enumerate(emotional_cues, start=1):
                pieces.append(f"{number}. {cue}\n")
            pieces.append("\n")
        if context:
            pieces.append(f"Context: {context}\n\n")
        pieces.append(f"Empathetic Response: {empathetic_response}")
        text = "".join(pieces)

        # Tokenize to a fixed length; labels mirror input_ids (causal LM style).
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        token_ids = encoding["input_ids"].squeeze()
        return {
            "input_ids": token_ids,
            "attention_mask": encoding["attention_mask"].squeeze(),
            "labels": token_ids.clone(),
            "user_message": user_message,
            "emotional_state": emotional_state,
            "empathetic_response": empathetic_response,
        }
111
+
112
+
113
class EmpathyEvaluator:
    """Heuristic evaluator for empathy and emotional intelligence capabilities.

    Scores model responses on four axes — emotional recognition, empathetic
    language, supportive tone and personalization — each accumulated per
    example in ``self.metrics`` and normalized to [0, 1].
    """

    def __init__(self, tokenizer):
        """
        Initialize the empathy evaluator.

        Args:
            tokenizer: Tokenizer for text processing
        """
        self.tokenizer = tokenizer
        # Running totals; evaluate() resets them and divides by the example count.
        self.metrics = {
            "emotional_recognition": 0.0,
            "empathetic_language": 0.0,
            "supportive_tone": 0.0,
            "personalization": 0.0,
        }

        # Phrases that mark empathetic language in a response.
        self.empathetic_phrases = [
            "understand",
            "feel",
            "appreciate",
            "recognize",
            "acknowledge",
            "must be",
            "sounds like",
            "seems like",
            "I hear you",
            "that's difficult",
            "that's challenging",
            "I'm sorry",
            "thank you for sharing",
            "I can imagine",
        ]

        # Emotional state categories used to bucket the expected emotion.
        self.emotional_states = {
            "positive": [
                "happy",
                "excited",
                "grateful",
                "proud",
                "hopeful",
                "inspired",
            ],
            "negative": [
                "sad",
                "angry",
                "frustrated",
                "anxious",
                "disappointed",
                "overwhelmed",
            ],
            "neutral": [
                "confused",
                "uncertain",
                "curious",
                "surprised",
                "contemplative",
            ],
        }

    def evaluate(self, model, eval_dataset: "EmpathyDataset") -> Dict[str, float]:
        """
        Evaluate empathy capabilities on the provided dataset.

        Args:
            model: The model to evaluate
            eval_dataset: Dataset of empathy examples

        Returns:
            Dictionary of evaluation metrics (all 0.0 for an empty dataset)
        """
        # Reset metrics before accumulating.
        for key in self.metrics:
            self.metrics[key] = 0.0

        # Limit to 50 examples for efficiency.
        total_examples = min(len(eval_dataset), 50)
        # Guard: an empty dataset would otherwise divide by zero below.
        if total_examples == 0:
            return self.metrics

        model.eval()
        device = next(model.parameters()).device

        with torch.no_grad():
            for idx in range(total_examples):
                example = eval_dataset[idx]
                user_message = example["user_message"]
                expected_emotional_state = example["emotional_state"]

                # Generate a response without revealing the emotional state.
                prompt = f"User: {user_message}\n\nProvide an empathetic response:"
                input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(
                    device
                )

                # NOTE(review): `temperature` only takes effect when sampling
                # (do_sample=True); with default greedy decoding it is ignored
                # by HF generate — confirm the intended decoding mode.
                generated_ids = model.generate(
                    input_ids, max_length=256, temperature=0.7, num_return_sequences=1
                )

                generated_text = self.tokenizer.decode(
                    generated_ids[0], skip_special_tokens=True
                )

                # Score this response and accumulate into self.metrics.
                self._evaluate_empathy(
                    user_message=user_message,
                    expected_emotional_state=expected_emotional_state,
                    expected_response=example["empathetic_response"],
                    generated_response=generated_text,
                )

        # Convert accumulated totals into per-example averages.
        for key in self.metrics:
            self.metrics[key] /= total_examples

        return self.metrics

    def _evaluate_empathy(
        self,
        user_message: str,
        expected_emotional_state: str,
        expected_response: str,
        generated_response: str,
    ) -> None:
        """
        Score one response and accumulate the result into ``self.metrics``.

        Args:
            user_message: The user's message
            expected_emotional_state: Expected identified emotional state
            expected_response: Expected empathetic response (currently unused;
                reserved for reference-based scoring)
            generated_response: The response generated by the model
        """
        generated_lower = generated_response.lower()

        # 1. Emotional recognition — bucket the expected emotion, then check
        # whether the response mentions any word from the same bucket.
        emotional_category = None
        for category, emotions in self.emotional_states.items():
            if any(emotion in expected_emotional_state.lower() for emotion in emotions):
                emotional_category = category
                break

        if emotional_category:
            emotion_words = self.emotional_states[emotional_category]
            emotion_recognition = any(
                word in generated_lower for word in emotion_words
            )
            self.metrics["emotional_recognition"] += 1.0 if emotion_recognition else 0.0
        else:
            # Default partial score when the expected state cannot be bucketed.
            self.metrics["emotional_recognition"] += 0.5

        # 2. Empathetic language — two empathetic phrases earn full marks.
        empathy_phrase_count = sum(
            1
            for phrase in self.empathetic_phrases
            if phrase in generated_lower
        )
        self.metrics["empathetic_language"] += min(1.0, empathy_phrase_count / 2)

        # 3. Supportive tone — weighted presence of supportive cues.
        supportive_score = 0.0
        if "here for you" in generated_lower or "support" in generated_lower:
            supportive_score += 0.5
        if "help" in generated_lower or "advice" in generated_lower:
            supportive_score += 0.3
        if any(
            phrase in generated_lower
            for phrase in ["let me know", "is there anything", "can i"]
        ):
            supportive_score += 0.2
        self.metrics["supportive_tone"] += min(1.0, supportive_score)

        # 4. Personalization — overlap between the user's content words
        # (minus a small stopword set) and the response; 3 shared words = full marks.
        user_specific_terms = set(user_message.lower().split()) - {
            "i",
            "me",
            "my",
            "mine",
            "am",
            "was",
            "the",
            "a",
            "an",
        }
        generated_terms = set(generated_lower.split())
        specific_term_overlap = len(user_specific_terms.intersection(generated_terms))
        self.metrics["personalization"] += min(1.0, specific_term_overlap / 3)
308
+
309
+
310
class EmpathyTrainingModule(SpecializedTrainingModule):
    """Specialized training module for empathy and emotional intelligence capabilities."""

    def __init__(self, config: TrainingModuleConfig, tokenizer):
        """
        Initialize the empathy training module.

        Args:
            config: Module configuration
            tokenizer: Tokenizer for text processing
        """
        super().__init__(config, tokenizer)

        # Initialize empathy-specific components.
        self.data_path = config.data_path or "data/processed/empathy_train.jsonl"
        self.evaluator = EmpathyEvaluator(tokenizer)

        # Empathy-specific loss.
        self.empathy_loss = nn.CrossEntropyLoss(ignore_index=-100)

        # Training metrics reported via get_metrics().
        self.metrics = {
            "empathy_loss": 0.0,
            "emotion_recognition_rate": 0.0,
            "empathetic_language_score": 0.0,
        }

        logger.info("Initialized empathy training module")

    def prepare_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Prepare a batch of data for empathy training.

        Args:
            batch: The input batch from the dataloader

        Returns:
            Processed batch ready for empathy training
        """
        # Already an empathy-specific batch? Pass it through unchanged.
        if all(
            key in batch
            for key in ["user_message", "emotional_state", "empathetic_response"]
        ):
            return batch

        # For general conversation batches, we would need to identify emotional
        # content. This is a simplified placeholder implementation.
        return batch

    def compute_loss(
        self, student_outputs: Any, teacher_outputs: Any, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute the empathy-specific distillation loss.

        Combines equally-weighted KL-divergence distillation and cross-entropy
        against the labels.

        Args:
            student_outputs: Outputs from the student model
            teacher_outputs: Outputs from the teacher model
            batch: The processed input batch (must contain "labels")

        Returns:
            Empathy-specific loss tensor
        """
        # Accept either a HF ModelOutput (with .logits) or a raw tensor.
        student_logits = (
            student_outputs.logits
            if hasattr(student_outputs, "logits")
            else student_outputs
        )
        teacher_logits = (
            teacher_outputs.logits
            if hasattr(teacher_outputs, "logits")
            else teacher_outputs
        )

        # Shift for next-token prediction: logits at t predict token t+1.
        student_logits = student_logits[:, :-1, :].contiguous()
        teacher_logits = teacher_logits[:, :-1, :].contiguous()
        target_ids = batch["labels"][:, 1:].contiguous()

        # KL divergence loss for distillation at a fixed softening temperature.
        temperature = 2.0
        kl_loss = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.softmax(teacher_logits / temperature, dim=-1),
            reduction="batchmean",
        ) * (temperature**2)

        # Cross-entropy loss against labels.
        # NOTE(review): labels appear to include pad tokens (no -100 masking is
        # visible upstream), so padding contributes to this term — confirm intended.
        ce_loss = self.empathy_loss(
            student_logits.view(-1, student_logits.size(-1)), target_ids.view(-1)
        )

        # Combined loss with balanced weighting.
        loss = 0.5 * ce_loss + 0.5 * kl_loss

        # Record the scalar for reporting.
        self.metrics["empathy_loss"] = loss.item()

        return loss

    def get_metrics(self) -> Dict[str, float]:
        """
        Get metrics specific to empathy training.

        Returns:
            Dictionary of metric names and values
        """
        return self.metrics

    def generate_synthetic_empathy_data(
        self, output_path: str, num_samples: int = 1000
    ) -> None:
        """
        Generate synthetic empathy training data as JSONL.

        Args:
            output_path: Path to save the generated data
            num_samples: Number of samples to generate
        """
        # Local imports: only needed by this generator utility.
        import copy
        from datetime import datetime, timezone

        # This is a simplified implementation based on the AWS backup.
        # In a full implementation, this would be much more sophisticated.

        empathy_templates = [
            {
                "user_message": "I just lost my job after 5 years at the company. I'm feeling really lost and worried about how I'll pay my bills.",
                "emotional_state": "anxious and worried",
                "emotional_cues": [
                    "Lost job unexpectedly",
                    "Financial insecurity",
                    "Identity/purpose uncertainty",
                    "Use of 'lost' to describe feelings",
                ],
                "empathetic_response": "I'm really sorry to hear about losing your job, especially after dedicating 5 years to the company. That's incredibly difficult. It's completely understandable to feel lost and worried about your financial situation right now. Sudden changes like this can be overwhelming. Would it help to talk through some immediate steps you might take, or would you prefer to just share how you're processing this right now?",
                "context": "Professional setback, financial stress",
            },
            {
                "user_message": "My dog passed away yesterday. He was with me for 12 years and I don't know how to handle this emptiness in my home.",
                "emotional_state": "grief and sadness",
                "emotional_cues": [
                    "Loss of longtime companion",
                    "Reference to 'emptiness'",
                    "Uncertainty about coping",
                    "Recent timing (yesterday)",
                ],
                "empathetic_response": "I'm so sorry about the loss of your dog. After 12 years together, they're truly family, and that leaves such a profound emptiness. The love you shared for so many years speaks to what a caring person you are and what a special connection you had. There's no timeline for grief, especially for a companion who was part of your daily life for so long. Would it help to talk about some of your favorite memories of them?",
                "context": "Pet loss, grief",
            },
            {
                "user_message": "I just got accepted to my dream university! I've been working toward this for years and can't believe it finally happened.",
                "emotional_state": "excited and proud",
                "emotional_cues": [
                    "Exclamation mark showing excitement",
                    "Achievement of long-term goal",
                    "Expression of disbelief at success",
                    "Years of effort mentioned",
                ],
                "empathetic_response": "Congratulations! That's absolutely wonderful news about your acceptance to your dream university! Years of hard work and dedication have paid off, and you have every right to feel excited and proud of this achievement. It's such a special feeling when a goal you've worked toward for so long finally becomes reality. I'd love to hear more about the university and what you're most looking forward to as you start this new chapter!",
                "context": "Academic achievement, celebration",
            },
        ]

        # Generate variations.
        output_data = []
        for _ in range(num_samples):
            template = random.choice(empathy_templates)

            # Deep copy so variations never alias the template's nested lists
            # (a shallow copy would share the emotional_cues list).
            variation = copy.deepcopy(template)

            # Add metadata.
            # Bug fix: the "timestamp" field previously stored the CUDA device
            # name; it now records an actual UTC timestamp, with the device
            # recorded separately under "device".
            variation["metadata"] = {
                "generated": True,
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "device": (
                    torch.cuda.get_device_name(0)
                    if torch.cuda.is_available()
                    else "CPU"
                ),
            }

            output_data.append(variation)

        # Save to file. Only create directories when the path has a parent;
        # os.makedirs("") raises for bare filenames.
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            for item in output_data:
                f.write(json.dumps(item) + "\n")

        logger.info(
            f"Generated {len(output_data)} synthetic empathy examples at {output_path}"
        )
+ )
src/multi_agent_training/specialized/ethics_module.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Ethics Training Module for MangoMAS Local
3
+
4
+ This module implements specialized training for ethical reasoning capabilities,
5
+ adapted from the AWS backup system for local training.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ import os
11
+ import random
12
+ from typing import Any, Dict, List
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from torch.utils.data import Dataset
18
+
19
+ from ..core_framework import SpecializedTrainingModule, TrainingModuleConfig
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
class EthicsDataset(Dataset):
    """Dataset for training ethical reasoning capabilities.

    Each example is a JSONL record carrying ``scenario``,
    ``ethical_analysis`` and ``recommendation`` fields, which are joined
    into a single causal-LM training prompt.
    """

    def __init__(self, data_path: str, tokenizer, max_length: int = 768):
        """
        Initialize the ethics dataset.

        Args:
            data_path: Path to the ethics JSONL data file
            tokenizer: Tokenizer for text processing (HF-style callable)
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path)

        logger.info(f"Loaded ethics dataset with {len(self.data)} examples")

    def _load_data(self, data_path: str) -> List[Dict]:
        """Load ethics training data, skipping malformed or incomplete lines.

        Best-effort by design: a bad line is dropped (and counted), never
        fatal, so one corrupt record cannot abort a training run.
        """
        data = []
        skipped = 0
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:  # tolerate blank lines / trailing newline
                    continue
                try:
                    item = json.loads(line)
                except json.JSONDecodeError:
                    skipped += 1
                    continue
                # Validate required fields for ethics data
                if (
                    "scenario" in item
                    and "ethical_analysis" in item
                    and "recommendation" in item
                ):
                    data.append(item)
                else:
                    skipped += 1
        if skipped:
            logger.warning(
                f"Skipped {skipped} malformed ethics records in {data_path}"
            )
        return data

    def __len__(self):
        """Number of usable examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize example *idx* into input_ids / attention_mask / labels.

        Labels are a copy of input_ids with padding positions set to -100
        so a cross-entropy loss with ignore_index=-100 (the convention used
        by EthicsTrainingModule) does not train on pad tokens.  The
        original implementation returned raw input_ids as labels, which
        made the model learn to emit padding.
        """
        item = self.data[idx]

        # Format the ethics prompt
        prompt = f"Scenario: {item['scenario']}\nEthical Analysis: {item['ethical_analysis']}\nRecommendation: {item['recommendation']}"

        # Tokenize
        encoding = self.tokenizer(
            prompt,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # squeeze(0) drops only the batch dim (squeeze() could also collapse
        # a length-1 sequence dim).
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        # Clone so masking the labels does not mutate input_ids in place.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
83
+
84
+
85
class EthicsTrainingModule(SpecializedTrainingModule):
    """Specialized training module for ethical reasoning capabilities.

    Combines a causal-LM cross-entropy objective over ethics-formatted
    prompts with a small KL-distillation term against a teacher model.
    """

    def __init__(self, config: TrainingModuleConfig, tokenizer):
        """
        Initialize the ethics training module.

        Args:
            config: Module configuration
            tokenizer: Tokenizer for text processing
        """
        super().__init__(config, tokenizer)

        # Initialize ethics-specific components.
        # ignore_index=-100 follows the HF convention for masked label
        # positions (e.g. padding) so they do not contribute to the loss.
        self.ethics_loss = nn.CrossEntropyLoss(ignore_index=-100)
        self.metrics = {
            "ethics_loss": 0.0,
            "ethical_consistency": 0.0,
            "principle_alignment": 0.0,
        }

        logger.info("Initialized EthicsTrainingModule")

    def prepare_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Prepare a batch of data for ethics training.

        Moves every tensor entry to ``self.device`` (assumed to be provided
        by the SpecializedTrainingModule base class -- TODO confirm) and
        passes non-tensor entries through unchanged.

        Args:
            batch: The input batch from the dataloader

        Returns:
            Processed batch ready for ethics training
        """
        return {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }

    @staticmethod
    def _extract_logits(outputs: Any):
        """Return ``outputs.logits`` when present, else treat *outputs* as raw logits."""
        return outputs.logits if hasattr(outputs, "logits") else outputs

    def compute_loss(
        self, student_outputs: Any, teacher_outputs: Any, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute the ethics-specific loss.

        Next-token cross entropy on the student, plus (when teacher logits
        are available) ``0.1 * KL(student || teacher)`` as a distillation
        term.

        Args:
            student_outputs: Outputs from the student model
            teacher_outputs: Outputs from the teacher model
            batch: The processed input batch

        Returns:
            Loss tensor for ethics training, scaled by ``self.loss_weight``
            (assumed to be set by the base class -- TODO confirm)
        """
        try:
            student_logits = self._extract_logits(student_outputs)
            teacher_logits = self._extract_logits(teacher_outputs)

            # Fall back to input_ids as labels when the batch carries none.
            labels = batch.get("labels", batch.get("input_ids"))

            # Standard causal-LM shift: logits at position t predict token t+1.
            shift_logits = student_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            ethics_loss = self.ethics_loss(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

            # Distillation term.  NOTE(review): with 3-D logits,
            # reduction="batchmean" divides by the batch dimension only, so
            # this term scales with sequence length; kept as-is to preserve
            # the existing training dynamics.
            if teacher_logits is not None:
                kl_loss = F.kl_div(
                    F.log_softmax(student_logits, dim=-1),
                    F.softmax(teacher_logits, dim=-1),
                    reduction="batchmean",
                )
                total_loss = ethics_loss + 0.1 * kl_loss
            else:
                total_loss = ethics_loss

            # Update metrics
            self.metrics["ethics_loss"] = ethics_loss.item()

            return total_loss * self.loss_weight

        except Exception as e:
            logger.error(f"Error computing ethics loss: {e}")
            # Best-effort fallback: a tiny constant loss so one bad batch
            # does not abort training.  Allocate it on the batch's device
            # when known -- the original always allocated on CPU, which
            # breaks the backward pass of a GPU training run.
            device = None
            if isinstance(batch, dict):
                ids = batch.get("input_ids")
                if isinstance(ids, torch.Tensor):
                    device = ids.device
            return torch.tensor(0.01, requires_grad=True, device=device)

    def get_metrics(self) -> Dict[str, float]:
        """
        Get metrics specific to ethics training.

        Returns:
            Dictionary of ethics metrics (a copy; safe for callers to mutate)
        """
        return self.metrics.copy()

    def generate_synthetic_ethics_data(
        self, output_path: str, num_samples: int = 1000
    ) -> None:
        """
        Generate synthetic ethics training data as JSONL.

        Args:
            output_path: Path to save the generated data
            num_samples: Number of samples to generate
        """
        # Local import: only this method needs wall-clock metadata.
        from datetime import datetime, timezone

        # This is a simplified implementation based on the AWS backup
        # In a full implementation, this would be much more sophisticated

        ethics_templates = [
            {
                "scenario": "A company wants to collect user location data to improve their mapping service.",
                "principles": [
                    "privacy",
                    "consent",
                    "transparency",
                    "data minimization",
                ],
                "considerations": [
                    "Users should be clearly informed about what data is collected",
                    "Data collection should be opt-in rather than opt-out",
                    "Only necessary data should be collected and stored",
                    "Data should be properly secured and anonymized where possible",
                ],
                "ethical_analysis": "The collection of location data raises privacy concerns, but can be ethical if done with transparency, informed consent, and data minimization practices. Users must be clearly informed about what data is collected, how it's used, and given genuine choice in the matter.",
                "recommendation": "Proceed with location data collection only with explicit opt-in consent, clear privacy notices, data minimization practices, and strong security measures.",
                "stance": "neutral",
            },
            {
                "scenario": "An AI company is developing a facial recognition system to be sold to law enforcement agencies without oversight mechanisms.",
                "principles": [
                    "privacy",
                    "justice",
                    "accountability",
                    "potential for discrimination",
                ],
                "considerations": [
                    "Facial recognition has known bias issues across different demographics",
                    "Law enforcement use creates significant civil liberties concerns",
                    "Lack of oversight could lead to misuse and privacy violations",
                    "Potential chilling effect on free speech and assembly",
                ],
                "ethical_analysis": "Deploying facial recognition in law enforcement without oversight mechanisms raises serious ethical concerns. These systems have demonstrated bias across demographic groups, potentially leading to discriminatory outcomes. Without accountability measures, there's significant risk of misuse, privacy violations, and erosion of civil liberties.",
                "recommendation": "Do not deploy facial recognition systems to law enforcement without robust oversight, accuracy testing across demographics, clear usage limitations, and strong accountability mechanisms.",
                "stance": "harmful",
            },
            {
                "scenario": "A medical AI is being developed to help doctors identify potential early signs of cancer in medical images.",
                "principles": [
                    "beneficence",
                    "non-maleficence",
                    "human oversight",
                    "transparency",
                ],
                "considerations": [
                    "Early detection could save many lives",
                    "False positives could cause unnecessary stress and procedures",
                    "False negatives could delay critical treatment",
                    "System should augment rather than replace medical expertise",
                ],
                "ethical_analysis": "A medical AI for cancer detection has significant potential benefits in improving early diagnosis and saving lives. However, it's critical that the system maintains high accuracy to minimize both false positives (causing unnecessary procedures) and false negatives (missing actual cases). The system should be designed to augment rather than replace medical professionals, with humans making final decisions.",
                "recommendation": "Proceed with development with rigorous clinical validation, transparent reporting of accuracy metrics across diverse populations, clear communication about limitations, and implementation as a decision support tool rather than autonomous system.",
                "stance": "beneficial",
            },
        ]

        # The original stored the CUDA device name under "timestamp";
        # record a real timestamp and keep the device name separately.
        generated_on = datetime.now(timezone.utc).isoformat()
        device_name = (
            torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
        )

        # Generate variations
        output_data = []
        for _ in range(num_samples):
            template = random.choice(ethics_templates)

            # Shallow copy is enough: we only add a new top-level key.
            variation = template.copy()
            variation["metadata"] = {
                "generated": True,
                "timestamp": generated_on,
                "device": device_name,
            }
            output_data.append(variation)

        # Save to file.  dirname is "" for bare filenames and makedirs("")
        # raises, so only create directories when there is one to create.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            for item in output_data:
                f.write(json.dumps(item) + "\n")

        logger.info(
            f"Generated {len(output_data)} synthetic ethics examples at {output_path}"
        )
src/multi_agent_training/specialized/generators/ethics_generator.py ADDED
@@ -0,0 +1,1244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Ethics Module Synthetic Data Generator for MangoMAS Local
3
+
4
+ This module generates synthetic training data for the ethics capability.
5
+ """
6
+
7
+ import json
8
+ import logging
9
+ import random
10
+ from pathlib import Path
11
+ from typing import Any, Dict, List
12
+
13
+ from ..synthetic_data_generator import (SyntheticDataGenerator,
14
+ SyntheticDataGeneratorRegistry)
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
# Templates for ethics scenarios
# The constants below are the raw material for EthicsDataGenerator: each
# generated example combines one domain, a few principles, and one
# scenario/analysis/recommendation template whose {placeholders} are filled
# from the fragment lists further down.

# High-level ethical topic areas used to tag generated examples.
ETHICAL_DOMAINS = [
    "privacy and data protection",
    "algorithmic fairness and bias",
    "transparency and explainability",
    "safety and risk management",
    "human autonomy and consent",
    "social impact and responsibility",
    "environmental sustainability",
    "intellectual property",
    "misinformation and content moderation",
    "inclusion and accessibility",
]

# "name - description" pairs; generate_example keeps only the text before
# " - " when listing an example's principles.
ETHICAL_PRINCIPLES = [
    "autonomy - respecting individual choice and self-determination",
    "beneficence - promoting well-being and preventing harm",
    "non-maleficence - avoiding causing harm",
    "justice - ensuring fair distribution of benefits and burdens",
    "transparency - being open about actions and decisions",
    "accountability - taking responsibility for outcomes",
    "privacy - protecting personal information",
    "consent - obtaining informed agreement",
    "dignity - respecting the inherent worth of individuals",
    "sustainability - meeting present needs without compromising future generations",
    "inclusivity - ensuring access and representation for all groups",
    "proportionality - ensuring actions are appropriate to circumstances",
]

# str.format scenario skeletons; placeholder values come from the fragment
# lists below and from the _generate_* helper methods.
SCENARIO_TEMPLATES = [
    "A company is developing {system_type} that collects {data_type} from users to {purpose}. However, there are concerns that {concern}.",
    "An organization wants to use {system_type} to {purpose} in {context}. Some stakeholders worry that {concern}.",
    "A research team has created {system_type} that can {capability}. While this could {benefit}, there's a risk that {risk}.",
    "A government agency plans to deploy {system_type} for {purpose} across {context}. Critics argue that {criticism}.",
    "A new {system_type} enables {capability}, promising to {benefit}. However, ethical questions arise when {ethical_question}.",
    "A multinational corporation is implementing {system_type} that will {capability} affecting {stakeholders}. This raises questions about {ethical_question}.",
    "An education institution is considering using {system_type} to {purpose} for students. Parents and educators are divided because {concern}.",
    "Healthcare providers are exploring {system_type} to {purpose} for patients. Medical ethicists caution that {caution}.",
    "A startup has developed {system_type} that can {capability}, potentially disrupting {industry}. This creates tension between {competing_values}.",
    "A social media platform is rolling out {system_type} to {purpose}. Civil liberties organizations are concerned that {concern}.",
]

# Long-form analysis skeletons.  Every {placeholder} must be supplied by
# _generate_analysis_context or str.format raises KeyError.
ANALYSIS_TEMPLATES = [
    "From an ethical standpoint, this scenario involves several key principles. First, {principle_1} is at stake because {principle_1_reason}. Second, {principle_2} must be considered as {principle_2_reason}. There's also a tension between {competing_value_1} and {competing_value_2}, as {tension_explanation}. Looking at potential consequences, {consequence_1} could result if {condition_1}, while {consequence_2} might occur if {condition_2}. Stakeholder perspectives vary: {stakeholder_1} would prioritize {priority_1}, whereas {stakeholder_2} would emphasize {priority_2}.",
    "This ethical dilemma centers on balancing multiple considerations. The principle of {principle_1} suggests that {principle_1_implication}, while {principle_2} indicates {principle_2_implication}. Historical precedents in similar cases show that {historical_precedent}. From a consequentialist perspective, we should consider that {consequence_analysis}. Rights-based considerations include {rights_consideration}. A virtue ethics approach would emphasize {virtue_consideration}. The broader societal implications include {societal_implication}.",
    "Analyzing this situation through multiple ethical frameworks: A deontological approach would focus on {deontological_focus}, concluding that {deontological_conclusion}. Utilitarianism would weigh {utilitarian_considerations}, suggesting {utilitarian_conclusion}. Justice considerations highlight {justice_considerations}. Professional ethics in this domain dictate {professional_ethics}. Cultural perspectives may vary, as {cultural_perspective}. Long-term versus short-term impacts create additional complexity because {temporal_complexity}.",
    "This case presents several ethical dimensions worth examining. Regarding consent and autonomy, {consent_analysis}. Privacy considerations include {privacy_analysis}. Fairness and equity issues arise because {fairness_issues}. Transparency requirements would dictate {transparency_requirements}. Accountability mechanisms should include {accountability_mechanisms}. Risk assessment suggests {risk_assessment}. Balancing innovation with caution requires {innovation_balance}.",
]

# Recommendation skeletons; placeholders are supplied by
# _generate_recommendation_context.
RECOMMENDATION_TEMPLATES = [
    "Based on this ethical analysis, I recommend {primary_recommendation}. This approach respects {respected_principle} while also acknowledging {acknowledged_concern}. To mitigate potential harms, {mitigation_strategy} should be implemented. Additionally, {secondary_recommendation} would help address remaining concerns. Ongoing {monitoring_approach} will be essential to ensure ethical compliance as the situation evolves.",
    "The most ethically sound course of action would be to {primary_recommendation}. This balances the competing values of {balanced_value_1} and {balanced_value_2}. Key safeguards should include {safeguard_1} and {safeguard_2}. Stakeholder engagement through {engagement_method} will be crucial for maintaining trust. Future evaluations should focus on {evaluation_focus} to ensure continued ethical alignment.",
    "I recommend a multi-faceted approach: First, {recommendation_1} to address immediate concerns. Second, {recommendation_2} as a medium-term solution. Third, develop {recommendation_3} for long-term sustainability. Throughout implementation, prioritize {prioritized_value}. This approach acknowledges trade-offs between {trade_off_values} but optimizes for {optimized_outcome} while minimizing {minimized_harm}.",
    "Given the competing ethical considerations, the most appropriate response is to {primary_recommendation}. This should be accompanied by {accompanying_measure} to ensure {ensured_outcome}. Transparency can be maintained through {transparency_measure}. Affected parties should be {stakeholder_treatment}. Regular {review_process} will help adjust the approach as new information or circumstances emerge.",
]

# Fragment pool: kinds of systems a scenario can be about.  Some entries
# have curated capability pools in _generate_capability.
SYSTEM_TYPES = [
    "an AI-powered recommendation system",
    "a facial recognition platform",
    "an automated decision-making algorithm",
    "a predictive analytics tool",
    "a natural language processing service",
    "a biometric authentication system",
    "an autonomous vehicle navigation system",
    "a behavioral targeting platform",
    "a content moderation algorithm",
    "a data mining infrastructure",
    "an automated surveillance system",
    "a personalized learning algorithm",
    "a sentiment analysis tool",
    "an automated hiring system",
    "a social scoring mechanism",
]

# Fragment pool: categories of data a system might collect.
DATA_TYPES = [
    "personal health information",
    "financial transaction history",
    "biometric identifiers",
    "location data",
    "social media activity",
    "behavioral patterns",
    "political opinions",
    "educational records",
    "emotional responses",
    "interpersonal communications",
    "professional qualifications",
    "demographic information",
    "browsing history",
    "genetic information",
    "consumption preferences",
]

# Fragment pool: stated purposes for deploying a system.
PURPOSES = [
    "improve user experience",
    "personalize content delivery",
    "optimize resource allocation",
    "identify potential security threats",
    "streamline decision-making processes",
    "reduce operational costs",
    "increase engagement metrics",
    "predict future behavior",
    "automate compliance verification",
    "enhance public safety",
    "accelerate research outcomes",
    "minimize human bias",
    "maximize conversion rates",
    "improve accessibility",
    "reduce carbon footprint",
]

# Fragment pool: deployment settings.  Some entries have curated
# stakeholder lists in _generate_stakeholders.
CONTEXTS = [
    "healthcare settings",
    "public transportation systems",
    "educational institutions",
    "financial services",
    "law enforcement operations",
    "social welfare programs",
    "urban planning initiatives",
    "employment processes",
    "content creation industries",
    "critical infrastructure",
    "disaster response scenarios",
    "retail environments",
    "public spaces",
    "political campaigns",
    "scientific research",
]

# Fragment pool: stakeholder worries injected into scenario templates.
CONCERNS = [
    "it may perpetuate existing societal biases",
    "privacy protections may be inadequate",
    "users might not provide meaningful informed consent",
    "the technology could be repurposed for surveillance",
    "it may disproportionately impact vulnerable populations",
    "the decision-making process lacks transparency",
    "there's no clear accountability for harmful outcomes",
    "it could undermine human autonomy",
    "long-term societal impacts haven't been adequately studied",
    "it creates security vulnerabilities that could be exploited",
    "it might amplify misinformation or manipulation",
    "the environmental impact is potentially significant",
    "it could lead to job displacement without transition plans",
    "it may erode important social skills or capacities",
    "it prioritizes efficiency over human wellbeing",
]
163
+
164
+
165
+ class EthicsDataGenerator(SyntheticDataGenerator):
166
+ """Generator for synthetic ethics data."""
167
+
168
+ def _load_templates(self) -> Dict[str, Any]:
169
+ """Load ethics templates."""
170
+ templates = {
171
+ "domains": ETHICAL_DOMAINS,
172
+ "principles": ETHICAL_PRINCIPLES,
173
+ "scenarios": SCENARIO_TEMPLATES,
174
+ "analyses": ANALYSIS_TEMPLATES,
175
+ "recommendations": RECOMMENDATION_TEMPLATES,
176
+ "system_types": SYSTEM_TYPES,
177
+ "data_types": DATA_TYPES,
178
+ "purposes": PURPOSES,
179
+ "contexts": CONTEXTS,
180
+ "concerns": CONCERNS,
181
+ }
182
+
183
+ # Additional domain-specific templates could be loaded from files
184
+ template_path = self.config.get("template_path")
185
+ if template_path and Path(template_path).exists():
186
+ try:
187
+ with open(template_path, "r", encoding="utf-8") as f:
188
+ custom_templates = json.load(f)
189
+ for key, values in custom_templates.items():
190
+ if key in templates and isinstance(values, list):
191
+ templates[key].extend(values)
192
+ except Exception as e:
193
+ logger.warning(f"Failed to load custom templates: {e}")
194
+
195
+ return templates
196
+
197
    def generate_example(self) -> Dict[str, Any]:
        """Generate a single ethics example.

        Randomly combines a domain, 2-4 principles, and one scenario /
        analysis / recommendation template each, then fills the template
        placeholders via the _generate_* helper methods.  The returned keys
        include the "scenario" / "ethical_analysis" / "recommendation"
        fields that the ethics dataset loader validates, plus tagging
        metadata.

        NOTE(review): "stance" is drawn from
        ["cautious", "balanced", "progressive"], which differs from the
        neutral/harmful/beneficial stance values used by the hand-written
        templates in ethics_module -- confirm which vocabulary downstream
        consumers expect.
        """
        # Select templates
        domain = random.choice(self.templates["domains"])
        scenario_template = random.choice(self.templates["scenarios"])
        analysis_template = random.choice(self.templates["analyses"])
        recommendation_template = random.choice(self.templates["recommendations"])

        # Select random principles (2-4)
        num_principles = random.randint(2, 4)
        principles = random.sample(self.templates["principles"], num_principles)

        # Generate context for the scenario
        system_type = random.choice(self.templates["system_types"])
        data_type = random.choice(self.templates["data_types"])
        purpose = random.choice(self.templates["purposes"])
        context = random.choice(self.templates["contexts"])
        concern = random.choice(self.templates["concerns"])

        # Generate capability and benefits/risks
        capability = self._generate_capability(system_type)
        benefit = self._generate_benefit(capability)
        risk = self._generate_risk(capability)
        criticism = self._generate_criticism(system_type, purpose)
        ethical_question = self._generate_ethical_question(system_type)
        stakeholders = self._generate_stakeholders(context)
        caution = self._generate_caution(system_type, purpose)
        industry = self._generate_industry(capability)
        competing_values = self._generate_competing_values()

        # Fill in scenario template.  The context deliberately carries more
        # keys than any single template uses; str.format ignores unused
        # keyword arguments.
        scenario_context = {
            "system_type": system_type,
            "data_type": data_type,
            "purpose": purpose,
            "context": context,
            "concern": concern,
            "capability": capability,
            "benefit": benefit,
            "risk": risk,
            "criticism": criticism,
            "ethical_question": ethical_question,
            "stakeholders": stakeholders,
            "caution": caution,
            "industry": industry,
            "competing_values": competing_values,
        }

        scenario = scenario_template.format(**scenario_context)

        # Generate analysis components.
        # NOTE(review): assumes _generate_analysis_context supplies every
        # placeholder of the chosen analysis template, else KeyError.
        analysis_context = self._generate_analysis_context(
            domain, principles, scenario_context
        )
        analysis = analysis_template.format(**analysis_context)

        # Generate recommendation components
        recommendation_context = self._generate_recommendation_context(analysis_context)
        recommendation = recommendation_template.format(**recommendation_context)

        # Return the structured example
        return {
            "scenario": scenario,
            "ethical_analysis": analysis,
            "recommendation": recommendation,
            "domain": domain,
            "principles": [
                p.split(" - ")[0] for p in principles
            ],  # Just the principle names
            "considerations": self._extract_considerations(analysis),
            "stance": random.choice(["cautious", "balanced", "progressive"]),
            "metadata": {
                "system_type": system_type,
                "context": context,
                "ethical_domain": domain,
                "complexity": random.choice(["low", "medium", "high"]),
            },
        }
275
+
276
+ # Helper methods for generating specific content
277
+ def _generate_capability(self, system_type: str) -> str:
278
+ """Generate a capability based on system type."""
279
+ capabilities = {
280
+ "an AI-powered recommendation system": [
281
+ "suggest highly personalized content based on behavioral patterns",
282
+ "predict user preferences with minimal explicit input",
283
+ "adapt recommendations in real-time based on changing behavior",
284
+ "influence user decisions through strategic content presentation",
285
+ "identify vulnerable users for targeted interventions",
286
+ ],
287
+ "a facial recognition platform": [
288
+ "identify individuals in crowded public spaces",
289
+ "track movement patterns across multiple locations",
290
+ "detect emotional states from facial expressions",
291
+ "verify identity without explicit user interaction",
292
+ "categorize individuals based on physical characteristics",
293
+ ],
294
+ "an automated decision-making algorithm": [
295
+ "evaluate complex applications without human review",
296
+ "allocate resources based on predicted outcomes",
297
+ "prioritize cases according to algorithmic criteria",
298
+ "make high-stakes decisions at scale",
299
+ "adjust criteria based on historical outcomes",
300
+ ],
301
+ }
302
+
303
+ # Return specific capability if available, otherwise generic
304
+ if system_type in capabilities:
305
+ return random.choice(capabilities[system_type])
306
+ else:
307
+ generic_capabilities = [
308
+ "process large amounts of data to make predictions",
309
+ "automate traditionally human decision processes",
310
+ "identify patterns invisible to human analysts",
311
+ "scale operations beyond human capabilities",
312
+ "integrate multiple data sources for comprehensive analysis",
313
+ ]
314
+ return random.choice(generic_capabilities)
315
+
316
+ def _generate_benefit(self, capability: str) -> str:
317
+ """Generate a potential benefit based on capability."""
318
+ benefits = [
319
+ "improve efficiency and reduce costs",
320
+ "enhance accuracy and reliability",
321
+ "provide services to previously underserved populations",
322
+ "reduce human bias in important decisions",
323
+ "accelerate innovation in critical fields",
324
+ "increase safety through constant monitoring",
325
+ "personalize experiences to individual needs",
326
+ "scale solutions to meet growing demand",
327
+ "enable more timely interventions in critical situations",
328
+ "democratize access to specialized expertise",
329
+ ]
330
+ return random.choice(benefits)
331
+
332
+ def _generate_risk(self, capability: str) -> str:
333
+ """Generate a potential risk based on capability."""
334
+ risks = [
335
+ "it could be used for unauthorized surveillance",
336
+ "sensitive data might be vulnerable to breaches",
337
+ "algorithmic biases could lead to discriminatory outcomes",
338
+ "it might create a false sense of objectivity",
339
+ "human oversight could be inappropriately reduced",
340
+ "it could amplify existing social inequalities",
341
+ "unintended consequences might emerge at scale",
342
+ "it might be difficult to correct errors once deployed",
343
+ "malicious actors could exploit vulnerabilities",
344
+ "it could erode important social norms or skills",
345
+ ]
346
+ return random.choice(risks)
347
+
348
+ def _generate_criticism(self, system_type: str, purpose: str) -> str:
349
+ """Generate criticism of the system and purpose."""
350
+ criticisms = [
351
+ "the system lacks adequate oversight mechanisms",
352
+ "informed consent processes are insufficient",
353
+ "privacy protections don't meet international standards",
354
+ "the technology disproportionately burdens marginalized communities",
355
+ "transparency requirements have been minimized for commercial reasons",
356
+ "the system creates unacceptable power imbalances",
357
+ "long-term societal implications haven't been adequately considered",
358
+ "the technology perpetuates harmful social categorizations",
359
+ "stated benefits are exaggerated while risks are downplayed",
360
+ "alternatives with fewer ethical concerns haven't been properly explored",
361
+ ]
362
+ return random.choice(criticisms)
363
+
364
+ def _generate_ethical_question(self, system_type: str) -> str:
365
+ """Generate an ethical question about the system."""
366
+ questions = [
367
+ "considering who controls access to the resulting data",
368
+ "examining who bears the costs versus who receives the benefits",
369
+ "determining appropriate limitations on its use",
370
+ "considering implications for human autonomy and dignity",
371
+ "evaluating its impact on social relationships and institutions",
372
+ "questioning whether adequate safeguards exist",
373
+ "considering how it might shape future technological development",
374
+ "examining how it might alter power dynamics",
375
+ "determining whether it creates unacceptable vulnerabilities",
376
+ "considering how responsibility for outcomes is assigned",
377
+ ]
378
+ return random.choice(questions)
379
+
380
+ def _generate_stakeholders(self, context: str) -> str:
381
+ """Generate stakeholders based on context."""
382
+ stakeholders = {
383
+ "healthcare settings": "patients, healthcare providers, and insurers",
384
+ "educational institutions": "students, educators, and parents",
385
+ "financial services": "customers, financial institutions, and regulators",
386
+ "law enforcement operations": "citizens, officers, and legal advocates",
387
+ }
388
+
389
+ # Return specific stakeholders if available, otherwise generic
390
+ if context in stakeholders:
391
+ return stakeholders[context]
392
+ else:
393
+ generic_stakeholders = [
394
+ "users, service providers, and regulators",
395
+ "consumers, businesses, and government agencies",
396
+ "citizens, private entities, and public institutions",
397
+ "individuals, communities, and organizational leaders",
398
+ "current users, potential users, and third parties",
399
+ ]
400
+ return random.choice(generic_stakeholders)
401
+
402
+ def _generate_caution(self, system_type: str, purpose: str) -> str:
403
+ """Generate a caution about the system and purpose."""
404
+ cautions = [
405
+ "the technology could compromise patient confidentiality in unexpected ways",
406
+ "algorithmic decisions might not account for important contextual factors",
407
+ "the system could create harmful dependencies on proprietary technology",
408
+ "the focus on quantifiable metrics might neglect qualitative human values",
409
+ "there's insufficient evidence for the system's long-term effectiveness",
410
+ "technical limitations could lead to harmful false positives or negatives",
411
+ "the technology might erode important human skills or capacities",
412
+ "it could create unrealistic expectations about technological solutions",
413
+ "vulnerable populations might be inadequately represented in development",
414
+ "implementation might prioritize institutional convenience over user needs",
415
+ ]
416
+ return random.choice(cautions)
417
+
418
+ def _generate_industry(self, capability: str) -> str:
419
+ """Generate an industry that might be disrupted."""
420
+ industries = [
421
+ "healthcare diagnostics",
422
+ "financial advisory services",
423
+ "educational assessment",
424
+ "creative content production",
425
+ "transportation and logistics",
426
+ "customer service operations",
427
+ "legal research and documentation",
428
+ "security and surveillance",
429
+ "human resources and recruitment",
430
+ "market research and consumer insights",
431
+ ]
432
+ return random.choice(industries)
433
+
434
+ def _generate_competing_values(self) -> str:
435
+ """Generate competing ethical values."""
436
+ value_pairs = [
437
+ "innovation and caution",
438
+ "efficiency and inclusivity",
439
+ "personalization and privacy",
440
+ "automation and human oversight",
441
+ "progress and tradition",
442
+ "individual choice and collective welfare",
443
+ "transparency and intellectual property",
444
+ "convenience and security",
445
+ "access and quality control",
446
+ "standardization and flexibility",
447
+ ]
448
+ return random.choice(value_pairs)
449
+
450
def _generate_analysis_context(
    self, domain: str, principles: List[str], scenario_context: Dict[str, str]
) -> Dict[str, str]:
    """Assemble the full template-fill context for an ethical analysis.

    Every value is produced by a sibling ``self._generate_*`` helper (or a
    direct ``random.choice``), so the dict-literal entry order below also
    fixes the sequence of RNG draws — keep the order stable if seeded
    reproducibility matters.

    Args:
        domain: Scenario domain label. NOTE(review): not read anywhere in
            this method — presumably reserved for future use; confirm.
        principles: Principle strings of the form "name - description";
            only the name part (before " - ") is used.
            NOTE(review): assumes at least two entries — raises IndexError
            otherwise; confirm callers always supply >= 2.
        scenario_context: Scenario fields (e.g. "system_type") forwarded to
            the principle-reason and consequence helpers.

    Returns:
        Mapping from template placeholder name to generated text.
    """
    # Split principles into name and description; only the names are kept.
    principle_names = [p.split(" - ")[0] for p in principles]

    context = {
        # Principles analysis — first two principles only.
        "principle_1": principle_names[0],
        "principle_1_reason": self._generate_principle_reason(
            principle_names[0], scenario_context
        ),
        "principle_2": principle_names[1],
        "principle_2_reason": self._generate_principle_reason(
            principle_names[1], scenario_context
        ),
        "principle_1_implication": self._generate_principle_implication(
            principle_names[0]
        ),
        "principle_2_implication": self._generate_principle_implication(
            principle_names[1]
        ),
        # Competing values — value_1/value_2 are drawn from disjoint pools,
        # so the pair always expresses a genuine tension.
        "competing_value_1": random.choice(
            ["innovation", "efficiency", "progress", "accessibility", "convenience"]
        ),
        "competing_value_2": random.choice(
            ["caution", "equity", "privacy", "security", "quality"]
        ),
        "tension_explanation": self._generate_tension_explanation(),
        # Consequences — one positive and one negative, each with a
        # matching enabling condition.
        "consequence_1": self._generate_consequence(
            scenario_context, positive=True
        ),
        "condition_1": self._generate_condition(positive=True),
        "consequence_2": self._generate_consequence(
            scenario_context, positive=False
        ),
        "condition_2": self._generate_condition(positive=False),
        "consequence_analysis": self._generate_consequence_analysis(),
        # Stakeholders — two independently sampled (stakeholder, priority)
        # pairs; duplicates are possible since draws are independent.
        "stakeholder_1": self._generate_stakeholder_type(),
        "priority_1": self._generate_priority(),
        "stakeholder_2": self._generate_stakeholder_type(),
        "priority_2": self._generate_priority(),
        # Historical and precedent
        "historical_precedent": self._generate_historical_precedent(),
        # Rights and virtues
        "rights_consideration": self._generate_rights_consideration(),
        "virtue_consideration": self._generate_virtue_consideration(),
        "societal_implication": self._generate_societal_implication(),
        # Ethical frameworks
        "deontological_focus": self._generate_deontological_focus(),
        "deontological_conclusion": self._generate_deontological_conclusion(),
        "utilitarian_considerations": self._generate_utilitarian_considerations(),
        "utilitarian_conclusion": self._generate_utilitarian_conclusion(),
        "justice_considerations": self._generate_justice_considerations(),
        "professional_ethics": self._generate_professional_ethics(),
        "cultural_perspective": self._generate_cultural_perspective(),
        "temporal_complexity": self._generate_temporal_complexity(),
        # Specific ethical dimensions
        "consent_analysis": self._generate_consent_analysis(),
        "privacy_analysis": self._generate_privacy_analysis(),
        "fairness_issues": self._generate_fairness_issues(),
        "transparency_requirements": self._generate_transparency_requirements(),
        "accountability_mechanisms": self._generate_accountability_mechanisms(),
        "risk_assessment": self._generate_risk_assessment(),
        "innovation_balance": self._generate_innovation_balance(),
    }

    return context
522
+
523
+ def _generate_principle_reason(
524
+ self, principle: str, scenario_context: Dict[str, str]
525
+ ) -> str:
526
+ """Generate a reason why a principle is relevant."""
527
+ system_type = scenario_context["system_type"]
528
+
529
+ reasons = {
530
+ "autonomy": [
531
+ f"the {system_type} could limit individual choices without their awareness",
532
+ f"users might not fully understand how the {system_type} influences their decisions",
533
+ "the technology might create subtle forms of manipulation or coercion",
534
+ ],
535
+ "beneficence": [
536
+ f"the {system_type} must demonstrably improve outcomes for those it affects",
537
+ "any technology should prioritize creating positive impacts for users and society",
538
+ "the claimed benefits need careful evaluation against potential harms",
539
+ ],
540
+ "non-maleficence": [
541
+ "even unintended consequences must be anticipated and mitigated",
542
+ f"the {system_type} could create new vulnerabilities or risks",
543
+ "harm prevention should take precedence over feature development",
544
+ ],
545
+ "justice": [
546
+ f"benefits and burdens of the {system_type} may be unevenly distributed",
547
+ "existing inequalities could be amplified by the technology",
548
+ "fair access to the benefits must be ensured across different groups",
549
+ ],
550
+ "transparency": [
551
+ f"users should understand how the {system_type} works and affects them",
552
+ "the inner workings of the system must not be a 'black box'",
553
+ "people have a right to know how decisions about them are being made",
554
+ ],
555
+ "accountability": [
556
+ "clear responsibility must be established for any harmful outcomes",
557
+ f"someone must be answerable when the {system_type} fails or causes harm",
558
+ "mechanisms for redress need to be accessible to those affected",
559
+ ],
560
+ "privacy": [
561
+ f"the {system_type} collects sensitive information that requires protection",
562
+ "data collection should be minimized to what's necessary",
563
+ "users should maintain control over their personal information",
564
+ ],
565
+ "consent": [
566
+ "meaningful permission requires genuine understanding and choice",
567
+ "users must be able to withdraw their participation without penalty",
568
+ "implied consent is insufficient for this level of intervention",
569
+ ],
570
+ }
571
+
572
+ # Return specific reason if available, otherwise generic
573
+ if principle.lower() in reasons:
574
+ return random.choice(reasons[principle.lower()])
575
+ else:
576
+ generic_reasons = [
577
+ f"ethical deployment of {system_type} requires careful consideration of this principle",
578
+ "this fundamental value must be protected despite technological advancement",
579
+ "stakeholders have expressed specific concerns related to this principle",
580
+ ]
581
+ return random.choice(generic_reasons)
582
+
583
+ def _generate_principle_implication(self, principle: str) -> str:
584
+ """Generate implications of a principle."""
585
+ implications = {
586
+ "autonomy": [
587
+ "user control should be maximized through opt-in features and clear choices",
588
+ "systems should be designed to enhance rather than replace human judgment",
589
+ "meaningful alternatives must be available without significant penalties",
590
+ ],
591
+ "beneficence": [
592
+ "we should prioritize features that provide demonstrable benefits to users",
593
+ "proactive measures to promote well-being should be incorporated",
594
+ "positive impacts should be measurable and verified independently",
595
+ ],
596
+ "non-maleficence": [
597
+ "rigorous testing for potential harms must precede deployment",
598
+ "ongoing monitoring for unexpected negative effects is essential",
599
+ "safety mechanisms should be redundant and fail-secure",
600
+ ],
601
+ "justice": [
602
+ "access and benefits should be equitably distributed across demographics",
603
+ "additional support may be needed for disadvantaged groups",
604
+ "pricing and availability should not create or reinforce divisions",
605
+ ],
606
+ "transparency": [
607
+ "explanations should be accessible to users with varying technical literacy",
608
+ "documentation should include both technical details and plain language summaries",
609
+ "the limits and confidence levels of the system should be clearly disclosed",
610
+ ],
611
+ }
612
+
613
+ # Return specific implication if available, otherwise generic
614
+ if principle.lower() in implications:
615
+ return random.choice(implications[principle.lower()])
616
+ else:
617
+ generic_implications = [
618
+ "we must design systems with this principle as a foundational requirement",
619
+ "technical capabilities should never override this ethical consideration",
620
+ "policies and practices should reflect this value throughout implementation",
621
+ ]
622
+ return random.choice(generic_implications)
623
+
624
+ def _generate_tension_explanation(self) -> str:
625
+ """Generate explanation of tension between competing values."""
626
+ explanations = [
627
+ "maximizing one often requires compromising the other in practical implementation",
628
+ "stakeholders prioritize these values differently based on their perspectives",
629
+ "short-term demands often favor one while long-term sustainability requires the other",
630
+ "technical constraints make it difficult to fully satisfy both simultaneously",
631
+ "existing regulatory frameworks tend to emphasize one over the other",
632
+ "market incentives typically reward one more directly than the other",
633
+ "historical precedents have created an imbalance between these values",
634
+ "public discourse often presents these as binary choices rather than complementary goals",
635
+ ]
636
+ return random.choice(explanations)
637
+
638
+ def _generate_consequence(
639
+ self, scenario_context: Dict[str, str], positive: bool
640
+ ) -> str:
641
+ """Generate a consequence based on the scenario."""
642
+ if positive:
643
+ consequences = [
644
+ "user outcomes could significantly improve",
645
+ "accessibility would increase for underserved populations",
646
+ "operational efficiency would enable broader service delivery",
647
+ "costs would decrease while maintaining quality",
648
+ "innovation in related fields would accelerate",
649
+ "human error would be reduced in critical processes",
650
+ "timely interventions would prevent negative outcomes",
651
+ "resource allocation would become more equitable",
652
+ ]
653
+ else:
654
+ consequences = [
655
+ "some users might experience new forms of exclusion",
656
+ "privacy erosion could occur incrementally and irreversibly",
657
+ "dependency on proprietary systems might limit future choices",
658
+ "important human skills could atrophy from disuse",
659
+ "accountability might become diffused and ineffective",
660
+ "unexpected failure modes could emerge at scale",
661
+ "societal norms might shift in potentially harmful ways",
662
+ "power imbalances could be exacerbated between groups",
663
+ ]
664
+ return random.choice(consequences)
665
+
666
+ def _generate_condition(self, positive: bool) -> str:
667
+ """Generate a condition for consequences."""
668
+ if positive:
669
+ conditions = [
670
+ "proper safeguards are implemented from the beginning",
671
+ "diverse perspectives are included throughout development",
672
+ "transparent operations are maintained consistently",
673
+ "regular independent audits verify compliance",
674
+ "user feedback mechanisms are robust and responsive",
675
+ "adaptive governance evolves with the technology",
676
+ "benefits are explicitly shared across stakeholder groups",
677
+ "continuous education keeps users informed and empowered",
678
+ ]
679
+ else:
680
+ conditions = [
681
+ "commercial pressures override ethical considerations",
682
+ "deployment occurs without adequate testing",
683
+ "regulatory oversight fails to keep pace with innovation",
684
+ "user understanding lags behind implementation",
685
+ "technical complexity obscures important trade-offs",
686
+ "powerful interests dominate decision-making processes",
687
+ "short-term metrics drive optimization at the expense of long-term values",
688
+ "security vulnerabilities remain unaddressed",
689
+ ]
690
+ return random.choice(conditions)
691
+
692
+ def _generate_stakeholder_type(self) -> str:
693
+ """Generate a stakeholder type."""
694
+ stakeholders = [
695
+ "direct users",
696
+ "service providers",
697
+ "regulatory bodies",
698
+ "vulnerable populations",
699
+ "industry competitors",
700
+ "civil society organizations",
701
+ "technical developers",
702
+ "academic researchers",
703
+ "future generations",
704
+ "institutional purchasers",
705
+ ]
706
+ return random.choice(stakeholders)
707
+
708
+ def _generate_priority(self) -> str:
709
+ """Generate a stakeholder priority."""
710
+ priorities = [
711
+ "maximum functionality and convenience",
712
+ "data security and privacy protections",
713
+ "cost efficiency and scalability",
714
+ "accessibility and inclusivity",
715
+ "reliability and consistency",
716
+ "innovation and competitive advantage",
717
+ "compliance and risk mitigation",
718
+ "user control and customization",
719
+ "simplicity and ease of understanding",
720
+ "integration with existing systems",
721
+ ]
722
+ return random.choice(priorities)
723
+
724
+ def _generate_consequence_analysis(self) -> str:
725
+ """Generate analysis of consequences."""
726
+ analyses = [
727
+ "immediate benefits must be weighed against potential long-term harms",
728
+ "both intended and unintended consequences require thorough examination",
729
+ "the distribution of benefits and harms across different groups is asymmetric",
730
+ "second-order effects may be more significant than primary outcomes",
731
+ "probabilistic harms must be taken seriously even if uncertain",
732
+ "comparative analysis with alternatives shows important trade-offs",
733
+ "systemic impacts may differ significantly from individual-level effects",
734
+ "threshold effects could create tipping points with cascading consequences",
735
+ ]
736
+ return random.choice(analyses)
737
+
738
+ def _generate_historical_precedent(self) -> str:
739
+ """Generate historical precedent."""
740
+ precedents = [
741
+ "early adoption without adequate safeguards has repeatedly led to harmful outcomes",
742
+ "similar technologies have evolved to include stronger protections over time",
743
+ "initial concerns often materialize in unexpected ways during scaling",
744
+ "retroactive regulation has proven difficult once systems are widely deployed",
745
+ "stakeholder inclusion has consistently improved outcomes in similar cases",
746
+ "transparency requirements have increased following documented abuses",
747
+ "ethical frameworks have matured through iterative real-world applications",
748
+ "public trust has proven difficult to rebuild once compromised",
749
+ ]
750
+ return random.choice(precedents)
751
+
752
+ def _generate_rights_consideration(self) -> str:
753
+ """Generate rights-based considerations."""
754
+ considerations = [
755
+ "the right to privacy requires meaningful control over personal information",
756
+ "informational self-determination is fundamental to human dignity",
757
+ "equal treatment before automated systems is an emerging right",
758
+ "the right to explanation is increasingly recognized in algorithmic contexts",
759
+ "freedom from manipulation requires protection from subtle influences",
760
+ "cognitive liberty includes freedom from excessive technological nudging",
761
+ "meaningful human review should be available for significant decisions",
762
+ "rights to access and rectification are essential for data subjects",
763
+ ]
764
+ return random.choice(considerations)
765
+
766
+ def _generate_virtue_consideration(self) -> str:
767
+ """Generate virtue ethics considerations."""
768
+ considerations = [
769
+ "technologies should be designed to encourage human flourishing",
770
+ "systems should promote rather than undermine ethical character development",
771
+ "virtues like honesty and fairness should be embedded in system design",
772
+ "responsible innovation requires practical wisdom and moderation",
773
+ "compassion should guide decisions about vulnerable users",
774
+ "courage is needed to prioritize ethics over expediency",
775
+ "intellectual humility recognizes the limits of algorithmic solutions",
776
+ "prudence is essential when deploying technologies with uncertain impacts",
777
+ ]
778
+ return random.choice(considerations)
779
+
780
+ def _generate_societal_implication(self) -> str:
781
+ """Generate societal implications."""
782
+ implications = [
783
+ "social cohesion could be affected as technology mediates more interactions",
784
+ "civic discourse might be transformed in both positive and negative ways",
785
+ "power dynamics between institutions and individuals may shift significantly",
786
+ "collective decision-making processes could be enhanced or undermined",
787
+ "social norms regarding privacy and disclosure may evolve rapidly",
788
+ "intergenerational equity requires considering future impacts",
789
+ "cultural diversity in ethical standards must be respected",
790
+ "democratic oversight of technological development may face new challenges",
791
+ ]
792
+ return random.choice(implications)
793
+
794
+ def _generate_deontological_focus(self) -> str:
795
+ """Generate deontological focus."""
796
+ focuses = [
797
+ "whether the technology respects the inherent dignity of all persons",
798
+ "if users are treated as ends in themselves rather than merely as means",
799
+ "whether actions would be acceptable if universally applied",
800
+ "the intentions behind the system's design and deployment",
801
+ "whether the technology honors fundamental duties to users",
802
+ "if rights are being respected regardless of consequences",
803
+ "whether all stakeholders are treated with equal moral consideration",
804
+ "if the design process respects the moral autonomy of users",
805
+ ]
806
+ return random.choice(focuses)
807
+
808
+ def _generate_deontological_conclusion(self) -> str:
809
+ """Generate deontological conclusion."""
810
+ conclusions = [
811
+ "the current approach requires modification to fully respect human dignity",
812
+ "certain elements violate the categorical imperative and must be reconsidered",
813
+ "the design treats some users merely as means rather than as ends",
814
+ "fundamental duties to users are being upheld in principle but not in practice",
815
+ "rights-based considerations suggest significant changes are needed",
816
+ "the approach is permissible if informed consent is genuinely secured",
817
+ "intentions appear sound but implementation fails to honor core duties",
818
+ "moral consistency would require extending similar protections to all affected parties",
819
+ ]
820
+ return random.choice(conclusions)
821
+
822
+ def _generate_utilitarian_considerations(self) -> str:
823
+ """Generate utilitarian considerations."""
824
+ considerations = [
825
+ "the aggregate benefits and harms across all affected parties",
826
+ "whether alternative approaches might produce greater net welfare",
827
+ "how to quantify intangible impacts like privacy or autonomy",
828
+ "the distribution of utility across different stakeholder groups",
829
+ "long-term versus short-term utility calculations",
830
+ "the reliability of predictions about future consequences",
831
+ "preference satisfaction versus objective welfare measures",
832
+ "rule-utilitarian versus act-utilitarian perspectives",
833
+ ]
834
+ return random.choice(considerations)
835
+
836
+ def _generate_utilitarian_conclusion(self) -> str:
837
+ """Generate utilitarian conclusion."""
838
+ conclusions = [
839
+ "the likely benefits appear to outweigh potential harms if safeguards are implemented",
840
+ "alternative approaches could produce greater overall utility with less risk",
841
+ "a more targeted implementation would optimize the benefit-to-harm ratio",
842
+ "the uncertainty of long-term impacts suggests a more cautious approach",
843
+ "a phased rollout would allow for utility assessment before full deployment",
844
+ "potential harms to vulnerable minorities outweigh aggregate benefits",
845
+ "the utility calculation changes significantly when considering second-order effects",
846
+ "maximizing long-term utility requires stronger oversight mechanisms",
847
+ ]
848
+ return random.choice(conclusions)
849
+
850
+ def _generate_justice_considerations(self) -> str:
851
+ """Generate justice considerations."""
852
+ considerations = [
853
+ "disparate impacts across demographic groups require mitigation",
854
+ "access barriers could create new forms of inequality",
855
+ "procedural justice demands transparent decision criteria",
856
+ "distribution of benefits should prioritize the least advantaged",
857
+ "historical injustices might be perpetuated or amplified",
858
+ "fairness requires similar treatment in similar circumstances",
859
+ "capabilities and opportunities may be affected asymmetrically",
860
+ "intergenerational justice considerations for long-lived systems",
861
+ ]
862
+ return random.choice(considerations)
863
+
864
+ def _generate_professional_ethics(self) -> str:
865
+ """Generate professional ethics considerations."""
866
+ considerations = [
867
+ "practitioners must prioritize public welfare above other interests",
868
+ "professional codes of conduct require transparency about limitations",
869
+ "conflicts of interest must be disclosed and managed appropriately",
870
+ "the duty of care extends to foreseeable impacts of technical work",
871
+ "professional standards require ongoing competence as technology evolves",
872
+ "collaboration across disciplines is needed for ethical assessment",
873
+ "whistleblower protections should be available for ethical concerns",
874
+ "professional integrity requires refusing work that violates ethical principles",
875
+ ]
876
+ return random.choice(considerations)
877
+
878
+ def _generate_cultural_perspective(self) -> str:
879
+ """Generate cultural perspective considerations."""
880
+ considerations = [
881
+ "privacy expectations vary significantly across cultural contexts",
882
+ "concepts of fairness and justice have cultural dimensions",
883
+ "some communities may value collective benefits differently than individual rights",
884
+ "technological solutionism itself reflects particular cultural assumptions",
885
+ "diverse cultural perspectives should inform ethical frameworks",
886
+ "global deployment requires cultural sensitivity in implementation",
887
+ "different historical experiences shape risk perception across cultures",
888
+ "cultural values influence acceptable trade-offs and priorities",
889
+ ]
890
+ return random.choice(considerations)
891
+
892
+ def _generate_temporal_complexity(self) -> str:
893
+ """Generate temporal complexity considerations."""
894
+ considerations = [
895
+ "short-term efficiency gains may come at the cost of long-term resilience",
896
+ "path dependencies created today will shape options for future generations",
897
+ "adaptive capacity must be preserved for responding to emerging issues",
898
+ "technological lock-in can limit future ethical choices",
899
+ "values and norms evolve while deployed systems may remain static",
900
+ "reversibility becomes increasingly difficult as adoption spreads",
901
+ "learning curves for new technologies create temporary vulnerabilities",
902
+ "long-term maintenance of ethical standards requires ongoing vigilance",
903
+ ]
904
+ return random.choice(considerations)
905
+
906
+ def _generate_consent_analysis(self) -> str:
907
+ """Generate consent analysis."""
908
+ analyses = [
909
+ "truly informed consent requires genuine understanding of complex systems",
910
+ "default opt-in approaches often fail to secure meaningful consent",
911
+ "power imbalances can render formal consent procedures inadequate",
912
+ "consent should be dynamic and revocable as circumstances change",
913
+ "alternatives must be available without punitive consequences",
914
+ "consent interfaces often prioritize convenience over comprehension",
915
+ "special considerations apply for vulnerable or dependent users",
916
+ "presumed consent is inappropriate for novel applications with uncertain risks",
917
+ ]
918
+ return random.choice(analyses)
919
+
920
+ def _generate_privacy_analysis(self) -> str:
921
+ """Generate privacy analysis."""
922
+ analyses = [
923
+ "data minimization should be a foundational design principle",
924
+ "purpose limitation prevents function creep and unexpected uses",
925
+ "privacy by design requires proactive rather than reactive measures",
926
+ "differential privacy techniques could provide mathematical guarantees",
927
+ "data retention policies should align with necessity rather than utility",
928
+ "re-identification risks increase as data sets are combined",
929
+ "privacy impact assessments should precede new data uses",
930
+ "contextual integrity considers appropriate information flows",
931
+ ]
932
+ return random.choice(analyses)
933
+
934
+ def _generate_fairness_issues(self) -> str:
935
+ """Generate fairness issues."""
936
+ issues = [
937
+ "historical data often contains embedded biases that models reproduce",
938
+ "different fairness metrics may be mutually incompatible",
939
+ "seemingly neutral criteria can create disparate impacts",
940
+ "underrepresented groups often face higher error rates",
941
+ "feedback loops can amplify initial inequalities over time",
942
+ "proxy variables may reinstate bias even when protected attributes are excluded",
943
+ "fairness requires ongoing monitoring rather than one-time certification",
944
+ "different stakeholders may have legitimate but conflicting fairness definitions",
945
+ ]
946
+ return random.choice(issues)
947
+
948
+ def _generate_transparency_requirements(self) -> str:
949
+ """Generate transparency requirements."""
950
+ requirements = [
951
+ "explanations should be accessible to users with varying technical literacy",
952
+ "decision criteria must be documentable and justifiable",
953
+ "both technical mechanisms and governance processes require transparency",
954
+ "meaningful transparency includes understanding system limitations",
955
+ "audit trails should track key decisions and modifications",
956
+ "transparency may need to be tailored differently for various stakeholders",
957
+ "trade secrets should not prevent appropriate oversight",
958
+ "algorithmic impact assessments should be publicly available",
959
+ ]
960
+ return random.choice(requirements)
961
+
962
+ def _generate_accountability_mechanisms(self) -> str:
963
+ """Generate accountability mechanisms."""
964
+ mechanisms = [
965
+ "clear lines of responsibility for outcomes regardless of complexity",
966
+ "accessible processes for contesting decisions and seeking redress",
967
+ "regular independent audits of performance and impacts",
968
+ "penalties proportionate to potential harms for non-compliance",
969
+ "whistleblower protections for reporting ethical concerns",
970
+ "ongoing monitoring with stakeholder involvement",
971
+ "documentation requirements that enable external review",
972
+ "grievance mechanisms that include support for affected parties",
973
+ ]
974
+ return random.choice(mechanisms)
975
+
976
+ def _generate_risk_assessment(self) -> str:
977
+ """Generate risk assessment."""
978
+ assessments = [
979
+ "potential harms should be evaluated systematically before deployment",
980
+ "low-probability but high-impact risks warrant serious consideration",
981
+ "risk distribution across different populations requires explicit analysis",
982
+ "compounding risks from system interactions may be overlooked",
983
+ "unknown risks require adaptive governance and ongoing vigilance",
984
+ "reversibility and containment options limit potential damage",
985
+ "risk thresholds should be determined through inclusive deliberation",
986
+ "precautionary principles apply when risks are uncertain but potentially severe",
987
+ ]
988
+ return random.choice(assessments)
989
+
990
+ def _generate_innovation_balance(self) -> str:
991
+ """Generate innovation balance."""
992
+ balances = [
993
+ "ethical guardrails should channel rather than prevent innovation",
994
+ "staged deployment can balance progress with caution",
995
+ "innovation sandboxes allow testing under controlled conditions",
996
+ "principles-based approaches provide flexibility with accountability",
997
+ "participatory design incorporates diverse perspectives from the outset",
998
+ "value-sensitive design embeds ethical considerations in development",
999
+ "anticipatory governance prepares for emerging challenges",
1000
+ "responsible innovation frameworks align technical and social progress",
1001
+ ]
1002
+ return random.choice(balances)
1003
+
1004
+ def _generate_recommendation_context(
1005
+ self, analysis_context: Dict[str, str]
1006
+ ) -> Dict[str, str]:
1007
+ """Generate context for recommendation."""
1008
+
1009
+ # Primary recommendations based on analysis
1010
+ primary_recommendations = [
1011
+ "implement the system with enhanced transparency measures and opt-in consent",
1012
+ "proceed with a limited pilot program with robust monitoring mechanisms",
1013
+ "redesign key components to address identified ethical concerns",
1014
+ "establish an independent oversight committee with diverse representation",
1015
+ "develop clear usage guidelines with explicit ethical boundaries",
1016
+ "adopt a phased implementation approach with checkpoints for review",
1017
+ "incorporate additional safeguards for vulnerable populations",
1018
+ "create a comprehensive data governance framework before proceeding",
1019
+ ]
1020
+
1021
+ # Mitigation strategies
1022
+ mitigation_strategies = [
1023
+ "regular algorithmic audits by independent third parties",
1024
+ "ongoing monitoring with clear thresholds for intervention",
1025
+ "robust anonymization techniques for all personal data",
1026
+ "user-friendly controls for privacy preferences",
1027
+ "accessible explanations of system operation and limitations",
1028
+ "formal impact assessments before major updates",
1029
+ "channels for stakeholder feedback and reported concerns",
1030
+ "sunset provisions requiring periodic reauthorization",
1031
+ ]
1032
+
1033
+ # Secondary recommendations
1034
+ secondary_recommendations = [
1035
+ "develop educational resources to promote informed user decisions",
1036
+ "establish an ethics advisory board with rotating membership",
1037
+ "create a transparent incident response protocol",
1038
+ "implement differential access levels based on necessity",
1039
+ "require regular public reporting on system performance and impacts",
1040
+ "develop industry standards for similar technologies",
1041
+ "establish compensation mechanisms for demonstrated harms",
1042
+ "invest in research addressing identified ethical challenges",
1043
+ ]
1044
+
1045
+ # Generate the recommendation context
1046
+ context = {
1047
+ "primary_recommendation": random.choice(primary_recommendations),
1048
+ "respected_principle": random.choice(
1049
+ ["autonomy", "beneficence", "justice", "transparency", "privacy"]
1050
+ ),
1051
+ "acknowledged_concern": random.choice(
1052
+ [
1053
+ "potential biases",
1054
+ "privacy risks",
1055
+ "accessibility barriers",
1056
+ "security vulnerabilities",
1057
+ "oversight limitations",
1058
+ ]
1059
+ ),
1060
+ "mitigation_strategy": random.choice(mitigation_strategies),
1061
+ "secondary_recommendation": random.choice(secondary_recommendations),
1062
+ "monitoring_approach": random.choice(
1063
+ [
1064
+ "impact assessment",
1065
+ "stakeholder feedback",
1066
+ "outcome metrics",
1067
+ "independent auditing",
1068
+ "automated monitoring",
1069
+ ]
1070
+ ),
1071
+ # For the second template
1072
+ "balanced_value_1": analysis_context["competing_value_1"],
1073
+ "balanced_value_2": analysis_context["competing_value_2"],
1074
+ "safeguard_1": random.choice(
1075
+ [
1076
+ "data minimization protocols",
1077
+ "explicit user consent mechanisms",
1078
+ "algorithmic fairness checks",
1079
+ "human oversight for critical decisions",
1080
+ "robust security measures",
1081
+ ]
1082
+ ),
1083
+ "safeguard_2": random.choice(
1084
+ [
1085
+ "regular third-party audits",
1086
+ "transparent documentation",
1087
+ "accessible appeals process",
1088
+ "ongoing stakeholder consultation",
1089
+ "sunset provisions",
1090
+ ]
1091
+ ),
1092
+ "engagement_method": random.choice(
1093
+ [
1094
+ "public consultations",
1095
+ "representative advisory panels",
1096
+ "user feedback mechanisms",
1097
+ "community review boards",
1098
+ "multi-stakeholder governance",
1099
+ ]
1100
+ ),
1101
+ "evaluation_focus": random.choice(
1102
+ [
1103
+ "impact on vulnerable groups",
1104
+ "unintended consequences",
1105
+ "efficacy of safeguards",
1106
+ "changing risk landscape",
1107
+ "evolving social norms",
1108
+ ]
1109
+ ),
1110
+ # For the third template
1111
+ "recommendation_1": random.choice(
1112
+ [
1113
+ "implement enhanced transparency measures",
1114
+ "strengthen consent procedures",
1115
+ "establish clear boundaries for use",
1116
+ "create oversight mechanisms",
1117
+ "develop explicit ethical guidelines",
1118
+ ]
1119
+ ),
1120
+ "recommendation_2": random.choice(
1121
+ [
1122
+ "regular independent audits",
1123
+ "ongoing stakeholder engagement",
1124
+ "robust monitoring protocols",
1125
+ "expanded user controls",
1126
+ "comprehensive documentation",
1127
+ ]
1128
+ ),
1129
+ "recommendation_3": random.choice(
1130
+ [
1131
+ "industry-wide ethical standards",
1132
+ "policy frameworks for governance",
1133
+ "research programs on ethical dimensions",
1134
+ "educational initiatives for users",
1135
+ "collaborative oversight mechanisms",
1136
+ ]
1137
+ ),
1138
+ "prioritized_value": random.choice(
1139
+ [
1140
+ "human dignity",
1141
+ "informed autonomy",
1142
+ "equitable access",
1143
+ "harm prevention",
1144
+ "accountability",
1145
+ ]
1146
+ ),
1147
+ "trade_off_values": random.choice(
1148
+ [
1149
+ "innovation and caution",
1150
+ "efficiency and inclusivity",
1151
+ "personalization and privacy",
1152
+ "automation and oversight",
1153
+ "scale and sensitivity",
1154
+ ]
1155
+ ),
1156
+ "optimized_outcome": random.choice(
1157
+ [
1158
+ "long-term trust",
1159
+ "sustainable innovation",
1160
+ "equitable benefits",
1161
+ "resilient systems",
1162
+ "ongoing improvement",
1163
+ ]
1164
+ ),
1165
+ "minimized_harm": random.choice(
1166
+ [
1167
+ "exclusion of vulnerable groups",
1168
+ "privacy violations",
1169
+ "erosion of autonomy",
1170
+ "unfair treatment",
1171
+ "unaccountable outcomes",
1172
+ ]
1173
+ ),
1174
+ # For the fourth template
1175
+ "accompanying_measure": random.choice(
1176
+ [
1177
+ "robust oversight mechanisms",
1178
+ "comprehensive documentation",
1179
+ "ongoing impact assessment",
1180
+ "accessible control options",
1181
+ "proactive risk mitigation",
1182
+ ]
1183
+ ),
1184
+ "ensured_outcome": random.choice(
1185
+ [
1186
+ "equitable access",
1187
+ "meaningful consent",
1188
+ "appropriate use",
1189
+ "accountability for outcomes",
1190
+ "ongoing improvement",
1191
+ ]
1192
+ ),
1193
+ "transparency_measure": random.choice(
1194
+ [
1195
+ "public documentation",
1196
+ "algorithmic explainability",
1197
+ "regular disclosure reports",
1198
+ "accessible descriptions of processes",
1199
+ "notification of significant changes",
1200
+ ]
1201
+ ),
1202
+ "stakeholder_treatment": random.choice(
1203
+ [
1204
+ "actively consulted throughout implementation",
1205
+ "provided with meaningful opt-out options",
1206
+ "given accessible channels for feedback",
1207
+ "included in oversight processes",
1208
+ "informed of their rights and options",
1209
+ ]
1210
+ ),
1211
+ "review_process": random.choice(
1212
+ [
1213
+ "independent audits",
1214
+ "stakeholder consultations",
1215
+ "impact assessments",
1216
+ "performance evaluations",
1217
+ "ethical compliance reviews",
1218
+ ]
1219
+ ),
1220
+ }
1221
+
1222
+ return context
1223
+
1224
+ def _extract_considerations(self, analysis: str) -> List[str]:
1225
+ """Extract key considerations from the analysis."""
1226
+ # This is a simplified version - in a real implementation,
1227
+ # this could use NLP techniques to extract actual considerations
1228
+ considerations = [
1229
+ "Balancing innovation with ethical safeguards",
1230
+ "Ensuring meaningful informed consent",
1231
+ "Protecting privacy and data security",
1232
+ "Preventing algorithmic bias and discrimination",
1233
+ "Maintaining appropriate human oversight",
1234
+ "Establishing clear lines of accountability",
1235
+ "Providing transparency in operation and decisions",
1236
+ "Considering impacts on vulnerable populations",
1237
+ ]
1238
+
1239
+ # Randomly select 3-5 considerations
1240
+ return random.sample(considerations, random.randint(3, 5))
1241
+
1242
+
1243
# Register this generator under the "ethics" key at import time so it can
# be looked up by name through SyntheticDataGeneratorRegistry.
SyntheticDataGeneratorRegistry.register("ethics", EthicsDataGenerator)
src/multi_agent_training/specialized/generators/reasoning_generator.py ADDED
@@ -0,0 +1,581 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Reasoning Module Synthetic Data Generator for MangoMAS Local
3
+
4
+ This module generates synthetic training data for the reasoning capability.
5
+ """
6
+
7
+ import json
8
+ import logging
9
+ import random
10
+ from pathlib import Path
11
+ from typing import Any, Dict, List
12
+
13
+ from ..synthetic_data_generator import (SyntheticDataGenerator,
14
+ SyntheticDataGeneratorRegistry)
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
# Templates for reasoning scenarios.
# All placeholders ({topic}, {observation}, ...) are filled via
# str.format(**context) in ReasoningDataGenerator.generate_example.

# Reasoning styles; also used to pick domain-specific variable sets.
REASONING_DOMAINS = [
    "logical deduction",
    "causal analysis",
    "statistical reasoning",
    "comparative analysis",
    "analogical reasoning",
    "counterfactual reasoning",
    "inductive reasoning",
    "abductive reasoning",
    "conditional reasoning",
    "syllogistic reasoning",
]

# Opening sentences that set up the scenario ("premise" field).
PREMISE_TEMPLATES = [
    "In a study of {topic}, researchers observed that {observation}.",
    "The data shows that {observation} when {condition}.",
    "Historical records indicate that {observation} throughout {period}.",
    "Given that {fact_1} and {fact_2}, we need to determine {question}.",
    "In the context of {domain}, we observe that {observation}.",
    "When analyzing {topic}, we can see that {observation} correlates with {factor}.",
    "System logs reveal that {observation} occurs whenever {condition}.",
    "The problem states that {observation} under conditions where {condition}.",
    "In the experiment, {observation} was measured when {condition}.",
    "The scenario presents a situation where {observation} after {event}.",
]

# Multi-step argument bodies ("reasoning" field).
REASONING_TEMPLATES = [
    "First, we need to identify the key variables: {variables}. Looking at the relationship between {var1} and {var2}, we can see that {relationship}. This suggests that {inference}. Furthermore, considering {var3}, we can deduce that {deduction}.",
    "Let's analyze this step by step. If {fact_1}, then {consequence_1}. Given that {fact_2}, we can infer {consequence_2}. Combining these insights, {combined_inference}. Additionally, {extra_insight} further supports this reasoning.",
    "Breaking this down systematically: 1) {step_1}, 2) {step_2}, 3) {step_3}, 4) {step_4}. The logical connection between steps 2 and 3 shows that {connection}, which leads us to {inference}.",
    "The key insight here is that {key_insight}. This means that when {condition}, the result is {result}. We can verify this by examining {evidence}, which confirms {confirmation}.",
    "To solve this, I'll use {method}. Starting with {starting_point}, I can determine that {determination}. This leads to {intermediate_conclusion}, and finally to {final_step}.",
]

# Closing sentences wrapping the generated {conclusion} text.
CONCLUSION_TEMPLATES = [
    "Therefore, we can conclude that {conclusion}.",
    "Based on this analysis, the answer is {conclusion}.",
    "The evidence strongly suggests that {conclusion}.",
    "This reasoning leads to the conclusion that {conclusion}.",
    "We can definitively state that {conclusion} based on the above analysis.",
]

# Topics and facts to fill in templates.
TOPICS = [
    "climate patterns",
    "market trends",
    "neural network performance",
    "population demographics",
    "traffic flow optimization",
    "disease spread",
    "economic indicators",
    "algorithm efficiency",
    "material properties",
    "social network dynamics",
    "cognitive biases",
    "language acquisition",
    "genetic inheritance",
    "planetary motion",
    "quantum phenomena",
]

# Factual statements drawn in pairs ({fact_1}, {fact_2}).
FACTS = [
    "energy consumption increases with temperature",
    "user engagement drops after 3 minutes",
    "performance plateaus with more than 8 layers",
    "error rates increase exponentially with load",
    "response time correlates with system memory",
    "conversion rates are highest on Tuesdays",
    "signal strength decreases with distance squared",
    "learning rate affects convergence time",
    "failure probability doubles every 5 years",
    "growth rate is proportional to nutrient concentration",
]
92
+
93
+
94
class ReasoningDataGenerator(SyntheticDataGenerator):
    """Generator for synthetic reasoning data.

    Produces premise / reasoning / conclusion triples by filling the
    module-level templates with randomly drawn content.
    """

    def _load_templates(self) -> Dict[str, List[str]]:
        """Load reasoning templates.

        Returns:
            Mapping of template category ("domains", "premises", ...) to a
            list of candidate strings.

        The module-level template lists are copied so that custom
        templates merged in from ``self.config["template_path"]`` never
        mutate the shared constants (the original in-place ``extend``
        would duplicate entries across repeated instantiations).

        NOTE: the original annotation was ``List[Dict[str, Any]]``, which
        did not match the returned dict; corrected here.
        """
        templates: Dict[str, List[str]] = {
            "domains": list(REASONING_DOMAINS),
            "premises": list(PREMISE_TEMPLATES),
            "reasoning": list(REASONING_TEMPLATES),
            "conclusions": list(CONCLUSION_TEMPLATES),
            "topics": list(TOPICS),
            "facts": list(FACTS),
        }

        # Additional domain-specific templates can be merged from a JSON file.
        template_path = self.config.get("template_path")
        if template_path and Path(template_path).exists():
            try:
                with open(template_path, "r", encoding="utf-8") as f:
                    custom_templates = json.load(f)
                for key, values in custom_templates.items():
                    # Only extend known categories with list-valued entries.
                    if key in templates and isinstance(values, list):
                        templates[key].extend(values)
            except Exception as e:
                # Best-effort: fall back to the built-in templates.
                logger.warning(f"Failed to load custom templates: {e}")

        return templates

    def generate_example(self) -> Dict[str, Any]:
        """Generate a single reasoning example.

        Returns:
            Dict with "premise", "reasoning", "conclusion", "domain",
            "complexity" and a "metadata" sub-dict.
        """
        # Select one template from each category.
        domain = random.choice(self.templates["domains"])
        premise_template = random.choice(self.templates["premises"])
        reasoning_template = random.choice(self.templates["reasoning"])
        conclusion_template = random.choice(self.templates["conclusions"])

        # Draw a topic and two distinct facts (guard avoids an infinite
        # loop when the facts pool has fewer than two entries).
        topic = random.choice(self.templates["topics"])
        fact_1 = random.choice(self.templates["facts"])
        fact_2 = random.choice(self.templates["facts"])
        while fact_2 == fact_1 and len(self.templates["facts"]) > 1:
            fact_2 = random.choice(self.templates["facts"])

        # Generate dynamic content for the placeholders.
        observation = self._generate_observation(topic)
        condition = self._generate_condition(topic)
        variables = self._generate_variables(domain)
        var1, var2, var3 = variables[:3]
        relationship = self._generate_relationship(var1, var2)
        inference = self._generate_inference(relationship)
        deduction = self._generate_deduction(var3)

        # Context supplies every placeholder any template may reference.
        context = {
            "domain": domain,
            "topic": topic,
            "observation": observation,
            "condition": condition,
            "fact_1": fact_1,
            "fact_2": fact_2,
            "variables": ", ".join(variables),
            "var1": var1,
            "var2": var2,
            "var3": var3,
            "relationship": relationship,
            "inference": inference,
            "deduction": deduction,
            "question": self._generate_question(topic),
            "period": self._generate_period(),
            "factor": self._generate_factor(topic),
            "event": self._generate_event(),
            "consequence_1": self._generate_consequence(fact_1),
            "consequence_2": self._generate_consequence(fact_2),
            "combined_inference": self._generate_combined_inference(fact_1, fact_2),
            "extra_insight": self._generate_extra_insight(topic),
            "step_1": self._generate_step(1, domain),
            "step_2": self._generate_step(2, domain),
            "step_3": self._generate_step(3, domain),
            "step_4": self._generate_step(4, domain),
            "connection": self._generate_connection(),
            "key_insight": self._generate_key_insight(domain),
            "result": self._generate_result(domain),
            "evidence": self._generate_evidence(domain),
            "confirmation": self._generate_confirmation(domain),
            "method": self._generate_method(domain),
            "starting_point": self._generate_starting_point(domain),
            "determination": self._generate_determination(domain),
            "intermediate_conclusion": self._generate_intermediate_conclusion(domain),
            "final_step": self._generate_final_step(domain),
        }

        # Fill the selected templates.
        premise = premise_template.format(**context)
        reasoning = reasoning_template.format(**context)

        # The conclusion placeholder is produced last, from the filled parts.
        conclusion = self._generate_conclusion(premise, reasoning)
        context["conclusion"] = conclusion
        conclusion_text = conclusion_template.format(**context)

        return {
            "premise": premise,
            "reasoning": reasoning,
            "conclusion": conclusion_text,
            "domain": domain,
            "complexity": random.choice(["low", "medium", "high"]),
            "metadata": {
                "variables": variables,
                "topic": topic,
                "reasoning_type": domain,
            },
        }

    # Helper methods to generate dynamic content.
    # Each draws uniformly from a fixed pool; several accept arguments
    # (topic/domain/fact) that are currently unused but kept for future
    # content conditioning.

    def _generate_observation(self, topic: str) -> str:
        """Pick an observation phrase (*topic* currently unused)."""
        observations = [
            "the rate of change increases over time",
            "there is a strong correlation between input and output variables",
            "performance degrades under specific conditions",
            "the system exhibits unexpected behavior when stressed",
            "outliers significantly impact the overall trend",
            "recurring patterns emerge after sufficient iterations",
            "the distribution follows a power law rather than normal distribution",
            "feedback loops amplify small initial differences",
            "thresholds exist beyond which behavior changes dramatically",
            "cyclical patterns emerge with a period of varying length",
        ]
        return random.choice(observations)

    def _generate_condition(self, topic: str) -> str:
        """Pick a condition phrase (*topic* currently unused)."""
        conditions = [
            "the system is under heavy load",
            "external factors remain constant",
            "all variables are optimized simultaneously",
            "specific constraints are applied",
            "the environment changes unexpectedly",
            "resource limitations come into play",
            "feedback mechanisms are activated",
            "multiple agents interact simultaneously",
            "time delays exceed a critical threshold",
            "boundary conditions are enforced",
        ]
        return random.choice(conditions)

    def _generate_variables(self, domain: str) -> List[str]:
        """Return a shuffled list of variable names for *domain*.

        Falls back to generic factor names for domains without a
        dedicated set. Always returns at least four entries, so callers
        may safely take the first three.
        """
        variable_sets = {
            "logical deduction": [
                "premise validity",
                "logical consistency",
                "conclusion strength",
                "assumption bias",
            ],
            "causal analysis": [
                "cause magnitude",
                "effect delay",
                "confounding factors",
                "intervention efficacy",
            ],
            "statistical reasoning": [
                "sample size",
                "confidence interval",
                "p-value",
                "effect size",
                "statistical power",
            ],
            "comparative analysis": [
                "baseline performance",
                "improvement margin",
                "relative efficiency",
                "comparison fairness",
            ],
            "analogical reasoning": [
                "source similarity",
                "target applicability",
                "mapping strength",
                "inference validity",
            ],
        }

        # Get domain-specific variables or use generic ones.
        variables = variable_sets.get(
            domain,
            [
                "factor A",
                "factor B",
                "factor C",
                "response variable",
                "control variable",
            ],
        )

        # Shuffle in place (safe: variable_sets is rebuilt on every call).
        random.shuffle(variables)
        return variables

    def _generate_relationship(self, var1: str, var2: str) -> str:
        """Pick a relationship statement linking *var1* and *var2*."""
        relationships = [
            f"an increase in {var1} leads to a proportional increase in {var2}",
            f"{var1} and {var2} have an inverse relationship",
            f"changes in {var1} precede changes in {var2} by a consistent time interval",
            f"{var1} influences {var2} only after exceeding a critical threshold",
            f"the relationship between {var1} and {var2} is non-linear and follows a power law",
            f"{var1} and {var2} are conditionally independent given certain conditions",
            f"extreme values of {var1} have a disproportionate effect on {var2}",
            f"the correlation between {var1} and {var2} changes direction over time",
        ]
        return random.choice(relationships)

    def _generate_inference(self, relationship: str) -> str:
        """Pick an inference phrase (*relationship* currently unused)."""
        inferences = [
            "we should focus our optimization efforts on the most sensitive parameters",
            "the system will likely reach equilibrium after sufficient time",
            "interventions should target root causes rather than symptoms",
            "small changes can potentially lead to significant improvements",
            "we need to account for interaction effects between variables",
            "the observed behavior is likely part of a larger pattern",
            "we should implement fail-safes for extreme conditions",
            "multiple pathways may lead to the same outcome",
        ]
        return random.choice(inferences)

    def _generate_deduction(self, var: str) -> str:
        """Pick a deduction statement about *var*."""
        deductions = [
            f"optimizing {var} alone will not solve the underlying problem",
            f"changes in {var} represent a leading indicator for system performance",
            f"the role of {var} has been previously underestimated",
            f"{var} acts as a moderating variable in this context",
            f"the impact of {var} follows a diminishing returns pattern",
            f"{var} exhibits threshold effects that must be accounted for",
            f"historical data on {var} supports this conclusion",
            f"contrary to conventional wisdom, {var} is not the limiting factor",
        ]
        return random.choice(deductions)

    def _generate_question(self, topic: str) -> str:
        """Pick a question phrase (*topic* currently unused)."""
        questions = [
            "how to optimize performance under these conditions",
            "whether the observed pattern will continue in the future",
            "which factors contribute most significantly to the outcome",
            "how to mitigate negative effects while preserving benefits",
            "what intervention would produce the most efficient solution",
            "how robust the system is to unexpected disturbances",
            "whether the findings can be generalized to other contexts",
            "how to distinguish correlation from causation in this case",
        ]
        return random.choice(questions)

    def _generate_period(self) -> str:
        """Pick a time-period phrase."""
        periods = [
            "the past decade",
            "periods of economic volatility",
            "the system's entire operational history",
            "multiple successive iterations",
            "both growth and decline phases",
            "controlled experimental conditions",
            "repeated stress-test cycles",
            "varying environmental conditions",
        ]
        return random.choice(periods)

    def _generate_factor(self, topic: str) -> str:
        """Pick a correlated-factor phrase (*topic* currently unused)."""
        factors = [
            "resource utilization",
            "system complexity",
            "external pressure",
            "user engagement",
            "adaptation rate",
            "failure frequency",
            "communication efficiency",
            "innovation adoption",
        ]
        return random.choice(factors)

    def _generate_event(self) -> str:
        """Pick a triggering-event phrase."""
        events = [
            "system initialization",
            "critical resource depletion",
            "unexpected environmental change",
            "crossing a performance threshold",
            "implementing a major upgrade",
            "integrating new components",
            "encountering novel inputs",
            "recovering from failure",
        ]
        return random.choice(events)

    def _generate_consequence(self, fact: str) -> str:
        """Return a fixed consequence phrase (*fact* currently unused)."""
        return "the system will adapt by adjusting its parameters accordingly"

    def _generate_combined_inference(self, fact1: str, fact2: str) -> str:
        """Return a fixed combined-inference phrase (args currently unused)."""
        return "we can establish a clear causal relationship between the observed phenomena"

    def _generate_extra_insight(self, topic: str) -> str:
        """Pick a supporting-insight phrase (*topic* currently unused)."""
        insights = [
            "temporal patterns reveal cyclical behavior",
            "boundary conditions significantly affect outcomes",
            "network effects amplify individual contributions",
            "emergent properties cannot be predicted from components alone",
            "system resilience depends on redundant pathways",
            "optimization often involves trade-offs between competing goals",
            "adaptation requires continuous feedback and adjustment",
            "complex systems often exhibit counterintuitive behavior",
        ]
        return random.choice(insights)

    def _generate_step(self, step_num: int, domain: str) -> str:
        """Pick a reasoning-step description for position *step_num* (1-4).

        Step numbers above 3 all use the concluding-step pool;
        *domain* is currently unused.
        """
        if step_num == 1:
            steps = [
                "Identify the key variables and their relationships",
                "Establish the initial conditions and constraints",
                "Define the problem space and boundaries",
                "Gather relevant data and observations",
                "Frame the question in precise terms",
            ]
        elif step_num == 2:
            steps = [
                "Analyze the patterns and correlations in the data",
                "Apply appropriate analytical methods",
                "Consider alternative explanations",
                "Map the causal relationships between factors",
                "Identify potential confounding variables",
            ]
        elif step_num == 3:
            steps = [
                "Evaluate the strength of evidence for each possibility",
                "Synthesize insights from multiple perspectives",
                "Test hypotheses against available data",
                "Assess the logical consistency of arguments",
                "Consider edge cases and exceptions",
            ]
        else:
            steps = [
                "Draw conclusions based on the strongest evidence",
                "Formulate actionable recommendations",
                "Identify remaining uncertainties",
                "Propose methods to validate findings",
                "Connect conclusions to the original question",
            ]
        return random.choice(steps)

    def _generate_connection(self) -> str:
        """Pick a logical-connection phrase."""
        connections = [
            "there's a causal relationship rather than mere correlation",
            "feedback mechanisms create self-reinforcing patterns",
            "threshold effects trigger qualitative changes in behavior",
            "multiple factors interact in non-linear ways",
            "temporal sequences reveal important dependencies",
            "structural constraints limit possible outcomes",
            "probabilistic influences accumulate deterministically",
            "conditional dependencies reveal deeper patterns",
        ]
        return random.choice(connections)

    def _generate_key_insight(self, domain: str) -> str:
        """Pick a key-insight phrase (*domain* currently unused)."""
        insights = [
            "optimizing for average cases often fails at the extremes",
            "emergent properties cannot be reduced to component behaviors",
            "apparent contradictions point to incomplete models",
            "historical patterns constrain future possibilities",
            "local optimizations can lead to global suboptimality",
            "precision must be balanced with generalizability",
            "second-order effects often dominate in the long run",
            "robust systems prioritize adaptation over optimization",
        ]
        return random.choice(insights)

    def _generate_result(self, domain: str) -> str:
        """Pick a result phrase (*domain* currently unused)."""
        results = [
            "performance improves non-linearly",
            "stability increases at the cost of responsiveness",
            "resource utilization becomes more efficient",
            "adaptability improves in novel situations",
            "resilience to disturbances increases",
            "error rates decrease systematically",
            "learning accelerates with experience",
            "coordination emerges without central control",
        ]
        return random.choice(results)

    def _generate_evidence(self, domain: str) -> str:
        """Pick an evidence-source phrase (*domain* currently unused)."""
        evidence = [
            "historical performance data",
            "controlled experimental results",
            "comparative case studies",
            "simulation outcomes under varied conditions",
            "natural experiments from system perturbations",
            "user feedback and behavioral patterns",
            "statistical analysis of large datasets",
            "theoretical models with empirical validation",
        ]
        return random.choice(evidence)

    def _generate_confirmation(self, domain: str) -> str:
        """Pick a confirmation phrase (*domain* currently unused)."""
        confirmations = [
            "the hypothesized mechanism actually operates as expected",
            "predicted outcomes match observed results",
            "alternative explanations can be ruled out",
            "the pattern holds across different contexts",
            "edge cases follow the same principles",
            "the model successfully predicts future behavior",
            "interventions produce expected effects",
            "independent measures converge on the same conclusion",
        ]
        return random.choice(confirmations)

    def _generate_method(self, domain: str) -> str:
        """Pick a method phrase (*domain* currently unused)."""
        methods = [
            "systematic decomposition into components",
            "counterfactual analysis",
            "process tracing through causal chains",
            "comparative analysis of similar cases",
            "statistical inference from patterns",
            "first principles reasoning",
            "model-based simulation",
            "abductive inference to the best explanation",
        ]
        return random.choice(methods)

    def _generate_starting_point(self, domain: str) -> str:
        """Pick a starting-point phrase (*domain* currently unused)."""
        starting_points = [
            "the fundamental constraints of the system",
            "established principles in this domain",
            "patterns observed in similar situations",
            "the minimal necessary conditions",
            "key defining relationships",
            "initial boundary conditions",
            "critical assumptions that must hold",
            "invariant properties across contexts",
        ]
        return random.choice(starting_points)

    def _generate_determination(self, domain: str) -> str:
        """Pick a determination phrase (*domain* currently unused)."""
        determinations = [
            "certain factors exert disproportionate influence",
            "system behavior follows predictable patterns under specific conditions",
            "apparent anomalies actually confirm deeper principles",
            "constraints channel possible outcomes in specific directions",
            "dynamic equilibria emerge from competing forces",
            "feedback loops stabilize or amplify depending on parameters",
            "path dependencies limit future possibilities",
            "critical thresholds separate qualitatively different regimes",
        ]
        return random.choice(determinations)

    def _generate_intermediate_conclusion(self, domain: str) -> str:
        """Pick an intermediate-conclusion phrase (*domain* currently unused)."""
        conclusions = [
            "we need to reconsider fundamental assumptions",
            "the system exhibits emergent properties not predictable from components",
            "apparent contradictions resolve at a higher level of analysis",
            "complex interactions require a more nuanced approach",
            "optimal solutions balance multiple competing objectives",
            "robustness comes at the cost of peak performance",
            "adaptability requires maintaining strategic flexibility",
            "precision must be traded off against generalizability",
        ]
        return random.choice(conclusions)

    def _generate_final_step(self, domain: str) -> str:
        """Pick a final-step phrase (*domain* currently unused)."""
        final_steps = [
            "we can formulate a general principle that applies broadly",
            "we can predict system behavior under novel conditions",
            "we can design interventions that leverage key mechanisms",
            "we can identify early warning signals for critical transitions",
            "we can optimize for robust performance across scenarios",
            "we can balance competing objectives through targeted trade-offs",
            "we can establish boundaries of applicability for our conclusions",
            "we can translate insights into actionable recommendations",
        ]
        return random.choice(final_steps)

    def _generate_conclusion(self, premise: str, reasoning: str) -> str:
        """Pick a conclusion phrase (*premise*/*reasoning* currently unused).

        A real implementation would derive the conclusion from the
        generated premise and reasoning text.
        """
        conclusions = [
            "the observed patterns indicate a fundamental relationship between key variables",
            "we should prioritize interventions that address root causes rather than symptoms",
            "the system's behavior can be predicted with reasonable accuracy under specified conditions",
            "optimizing for extreme cases provides more robust performance than optimizing for average cases",
            "adaptation mechanisms are essential for maintaining performance in changing environments",
            "complex interactions between components create emergent properties at the system level",
            "resource allocation should follow a dynamic rather than static strategy",
            "feedback loops must be carefully managed to prevent unintended consequences",
        ]
        return random.choice(conclusions)
578
+
579
+
580
# Register this generator under the "reasoning" key at import time so it
# can be looked up by name through SyntheticDataGeneratorRegistry.
SyntheticDataGeneratorRegistry.register("reasoning", ReasoningDataGenerator)
src/multi_agent_training/specialized/memory_module.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Memory Training Module for MangoMAS Local
3
+
4
+ This module implements specialized training for memory and context retention capabilities,
5
+ adapted from the AWS backup system for local training.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ import os
11
+ import random
12
+ from typing import Any, Dict, List
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from torch.utils.data import Dataset
18
+
19
+ from ..core_framework import SpecializedTrainingModule, TrainingModuleConfig
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
class MemoryDataset(Dataset):
    """Dataset of multi-turn conversations for memory / context-retention training."""

    def __init__(self, data_path: str, tokenizer, max_length: int = 1024):
        """
        Initialize the memory dataset.

        Args:
            data_path: Path to the memory training data file (JSONL)
            tokenizer: Tokenizer for text processing
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path)

        logger.info(f"Loaded memory dataset with {len(self.data)} examples")

    def _load_data(self, data_path: str) -> List[Dict]:
        """Read JSONL records, keeping only those with a list-valued 'conversation'."""
        records: List[Dict] = []
        with open(data_path, "r", encoding="utf-8") as handle:
            for raw_line in handle:
                try:
                    record = json.loads(raw_line.strip())
                except json.JSONDecodeError:
                    # Skip malformed lines silently.
                    continue
                if "conversation" in record and isinstance(record["conversation"], list):
                    records.append(record)
        return records

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        turns = record["conversation"]

        # All turns except the last form the context; the last is the target.
        history = "\n".join(f"{turn['role']}: {turn['content']}" for turn in turns[:-1])
        reply = turns[-1]["content"]
        text = f"Context:\n{history}\nResponse: {reply}"

        enc = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Causal-LM setup: labels mirror the input ids.
        return {
            "input_ids": enc["input_ids"].squeeze(),
            "attention_mask": enc["attention_mask"].squeeze(),
            "labels": enc["input_ids"].squeeze(),
        }
87
+
88
+
89
class MemoryTrainingModule(SpecializedTrainingModule):
    """Specialized training module for memory and context retention capabilities.

    Trains the student model on multi-turn conversations using a causal-LM
    cross-entropy loss, plus an optional 0.1-weighted KL-divergence
    distillation term against a teacher model.
    """

    def __init__(self, config: TrainingModuleConfig, tokenizer):
        """
        Initialize the memory training module.

        Args:
            config: Module configuration
            tokenizer: Tokenizer for text processing
        """
        super().__init__(config, tokenizer)

        # Token-level LM loss; -100 marks positions excluded from the loss.
        self.memory_loss = nn.CrossEntropyLoss(ignore_index=-100)
        # NOTE(review): compute_loss only updates "memory_loss" today;
        # "context_retention" and "coherence_score" are placeholders.
        self.metrics = {
            "memory_loss": 0.0,
            "context_retention": 0.0,
            "coherence_score": 0.0,
        }

        logger.info("Initialized MemoryTrainingModule")

    def prepare_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Prepare a batch of data for memory training.

        Moves all tensor entries to the module's device; non-tensor entries
        (strings, lists, metadata) pass through untouched.

        Args:
            batch: The input batch from the dataloader

        Returns:
            Processed batch ready for memory training
        """
        prepared_batch = {}
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                prepared_batch[key] = value.to(self.device)
            else:
                prepared_batch[key] = value
        return prepared_batch

    def compute_loss(
        self, student_outputs: Any, teacher_outputs: Any, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute the memory-specific loss.

        Next-token cross entropy on the student logits, plus a 0.1-weighted
        KL(student || teacher) term when teacher logits are available.

        Args:
            student_outputs: Outputs from the student model (object exposing
                .logits, or a raw logits tensor)
            teacher_outputs: Outputs from the teacher model, or None to skip
                distillation
            batch: The processed input batch

        Returns:
            Loss tensor scaled by self.loss_weight
        """
        try:
            # Accept both HF-style output objects and raw tensors.
            student_logits = getattr(student_outputs, "logits", student_outputs)
            teacher_logits = getattr(teacher_outputs, "logits", teacher_outputs)

            # Fall back to input_ids when explicit labels are absent.
            labels = batch.get("labels", batch.get("input_ids"))

            # Standard causal-LM shift: predict token t+1 from position t.
            shift_logits = student_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            memory_loss = self.memory_loss(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

            if teacher_logits is not None:
                # Distillation: pull the student distribution toward the teacher's.
                kl_loss = F.kl_div(
                    F.log_softmax(student_logits, dim=-1),
                    F.softmax(teacher_logits, dim=-1),
                    reduction="batchmean",
                )
                total_loss = memory_loss + 0.1 * kl_loss
            else:
                total_loss = memory_loss

            self.metrics["memory_loss"] = memory_loss.item()

            return total_loss * self.loss_weight

        except Exception as e:
            logger.error(f"Error computing memory loss: {e}")
            # Small differentiable fallback so training continues.
            # BUG FIX: created on the module's device — the original built it
            # on CPU, which breaks backward() when the model runs on GPU.
            return torch.tensor(
                0.01, device=getattr(self, "device", None), requires_grad=True
            )

    def get_metrics(self) -> Dict[str, float]:
        """
        Get metrics specific to memory training.

        Returns:
            Dictionary of memory metrics (a copy; safe for callers to mutate)
        """
        return self.metrics.copy()

    def generate_synthetic_memory_data(
        self, output_path: str, num_samples: int = 1000
    ) -> None:
        """
        Generate synthetic memory training data as a JSONL file.

        Args:
            output_path: Path to save the generated data
            num_samples: Number of samples to generate
        """
        # This is a simplified implementation based on the AWS backup.
        # In a full implementation, this would be much more sophisticated.

        conversation_templates = [
            [
                {
                    "role": "user",
                    "content": "Hi, my name is Alex and I'm interested in machine learning.",
                },
                {
                    "role": "assistant",
                    "content": "Hello Alex! I'd be happy to discuss machine learning with you. What aspects are you most interested in?",
                },
                {
                    "role": "user",
                    "content": "I'm particularly interested in natural language processing.",
                },
                {
                    "role": "assistant",
                    "content": "NLP is a fascinating field! It's used for tasks like translation, summarization, and question answering.",
                },
                {
                    "role": "user",
                    "content": "What do you think would be a good first project?",
                },
                {
                    "role": "assistant",
                    "content": "For a beginner in NLP, I'd recommend starting with a text classification project, like sentiment analysis.",
                },
            ],
            [
                {
                    "role": "user",
                    "content": "I'm planning a trip to Japan next spring.",
                },
                {
                    "role": "assistant",
                    "content": "That sounds exciting! Japan is beautiful in spring with cherry blossoms. What cities are you planning to visit?",
                },
                {
                    "role": "user",
                    "content": "I'm thinking Tokyo, Kyoto, and maybe Osaka.",
                },
                {
                    "role": "assistant",
                    "content": "Great choices! Tokyo has modern attractions, Kyoto has historical temples, and Osaka is known for amazing food.",
                },
                {
                    "role": "user",
                    "content": "What's the best way to travel between these cities?",
                },
                {
                    "role": "assistant",
                    "content": "The Shinkansen (bullet train) is the most efficient way to travel between these cities. It's fast, comfortable, and reliable.",
                },
            ],
        ]

        recall_templates = [
            {
                "recall_context": "what was my name again?",
                "recall_target": "Your name is Alex, as you mentioned at the beginning of our conversation.",
            },
            {
                "recall_context": "which cities did I say I wanted to visit?",
                "recall_target": "You mentioned you're planning to visit Tokyo, Kyoto, and possibly Osaka during your trip to Japan.",
            },
        ]

        # Generate variations
        output_data = []
        for _ in range(num_samples):
            template_idx = random.randint(0, len(conversation_templates) - 1)
            # Shallow copy is sufficient: we only append new turn dicts below
            # and never mutate the template's own turn dicts.
            conversation = conversation_templates[template_idx].copy()

            # Add a recall question if this template has a paired recall target
            if template_idx < len(recall_templates):
                recall_template = recall_templates[template_idx]

                conversation.append(
                    {"role": "user", "content": recall_template["recall_context"]}
                )
                # BUG FIX: also append the assistant's recalled answer.
                # MemoryDataset treats the FINAL turn as the training target,
                # so without this turn the model would be trained to emit the
                # user's question instead of the recalled fact.
                conversation.append(
                    {"role": "assistant", "content": recall_template["recall_target"]}
                )

                example = {
                    "conversation": conversation,
                    "recall_context": recall_template["recall_context"],
                    "recall_target": recall_template["recall_target"],
                    "metadata": {"generated": True, "requires_memory": True},
                }
            else:
                # Regular conversation without specific recall target
                example = {
                    "conversation": conversation,
                    "metadata": {"generated": True, "requires_memory": False},
                }

            output_data.append(example)

        # BUG FIX: guard against a bare filename — os.path.dirname("x.jsonl")
        # is "" and os.makedirs("") raises FileNotFoundError.
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            for item in output_data:
                f.write(json.dumps(item) + "\n")

        logger.info(
            f"Generated {len(output_data)} synthetic memory examples at {output_path}"
        )
src/multi_agent_training/specialized/reasoning_module.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Reasoning Training Module for MangoMAS Local
3
+
4
+ This module implements specialized training for reasoning capabilities,
5
+ adapted from the AWS backup system for local training.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ import os
11
+ import random
12
+ import re
13
+ from typing import Any, Dict, List
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from torch.utils.data import Dataset
19
+
20
+ from ..core_framework import SpecializedTrainingModule, TrainingModuleConfig
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
class ReasoningDataset(Dataset):
    """Dataset of question / reasoning / answer triples for reasoning training."""

    def __init__(self, data_path: str, tokenizer, max_length: int = 512):
        """
        Initialize the reasoning dataset.

        Args:
            data_path: Path to the reasoning data file (JSONL)
            tokenizer: Tokenizer for text processing
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path)

        logger.info(f"Loaded reasoning dataset with {len(self.data)} examples")

    def _load_data(self, data_path: str) -> List[Dict]:
        """Read JSONL records, keeping only those with question, reasoning and answer."""
        records: List[Dict] = []
        required = ("question", "reasoning", "answer")
        with open(data_path, "r", encoding="utf-8") as handle:
            for raw_line in handle:
                try:
                    record = json.loads(raw_line.strip())
                except json.JSONDecodeError:
                    # Skip malformed lines silently.
                    continue
                if all(field in record for field in required):
                    records.append(record)
        return records

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]

        # Lay out the example as a single chain-of-thought style prompt.
        text = (
            f"Question: {record['question']}\n"
            f"Reasoning: {record['reasoning']}\n"
            f"Answer: {record['answer']}"
        )

        enc = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Causal-LM setup: labels mirror the input ids.
        return {
            "input_ids": enc["input_ids"].squeeze(),
            "attention_mask": enc["attention_mask"].squeeze(),
            "labels": enc["input_ids"].squeeze(),
        }
80
+
81
+
82
class ReasoningEvaluator:
    """Evaluator for reasoning capabilities.

    Generates a reasoning chain from each example's premise and scores it
    with simple heuristic metrics (term overlap, flow markers, etc.).
    """

    def __init__(self, tokenizer):
        """
        Initialize the reasoning evaluator.

        Args:
            tokenizer: Tokenizer for text processing
        """
        self.tokenizer = tokenizer
        self.metrics = {
            "logical_consistency": 0.0,
            "premise_relevance": 0.0,
            "conclusion_validity": 0.0,
            "steps_coherence": 0.0,
        }

    def evaluate(self, model, eval_dataset: "ReasoningDataset") -> Dict[str, float]:
        """
        Evaluate reasoning capabilities on the provided dataset.

        Args:
            model: The model to evaluate (must support .generate)
            eval_dataset: Dataset of reasoning examples

        Returns:
            Dictionary of evaluation metrics, averaged over the examples
        """
        model.eval()
        device = next(model.parameters()).device

        # Reset metrics
        for key in self.metrics:
            self.metrics[key] = 0.0

        # Limit to 100 examples for efficiency
        total_examples = min(len(eval_dataset), 100)
        if total_examples == 0:
            # Avoid ZeroDivisionError on an empty dataset.
            return self.metrics

        with torch.no_grad():
            for idx in range(total_examples):
                # BUG FIX: use the raw record, not __getitem__ — the dataset's
                # __getitem__ returns tokenized tensors with no text fields,
                # so example["premise"] raised KeyError. ReasoningDataset
                # records use question/reasoning/answer keys; synthetic data
                # uses premise/reasoning/conclusion, so accept both.
                record = eval_dataset.data[idx]
                premise = record.get("premise", record.get("question", ""))
                expected_reasoning = record.get("reasoning", "")
                expected_conclusion = record.get(
                    "conclusion", record.get("answer", "")
                )

                # Generate reasoning and conclusion from premise
                prompt = f"Premise: {premise}\nReasoning:"
                input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(
                    device
                )

                generated_ids = model.generate(
                    input_ids, max_length=512, temperature=0.7, num_return_sequences=1
                )

                generated_text = self.tokenizer.decode(
                    generated_ids[0], skip_special_tokens=True
                )

                # Extract reasoning and conclusion from generated text
                try:
                    reasoning_match = re.search(
                        r"Reasoning:(.*?)(?:Conclusion:|$)", generated_text, re.DOTALL
                    )
                    conclusion_match = re.search(
                        r"Conclusion:(.*?)$", generated_text, re.DOTALL
                    )

                    gen_reasoning = (
                        reasoning_match.group(1).strip() if reasoning_match else ""
                    )
                    gen_conclusion = (
                        conclusion_match.group(1).strip() if conclusion_match else ""
                    )

                    # Evaluate reasoning quality
                    self._update_metrics(
                        premise=premise,
                        expected_reasoning=expected_reasoning,
                        expected_conclusion=expected_conclusion,
                        generated_reasoning=gen_reasoning,
                        generated_conclusion=gen_conclusion,
                    )
                except Exception as e:
                    logger.error(f"Error evaluating reasoning: {e}")

        # Calculate averages
        for key in self.metrics:
            self.metrics[key] /= total_examples

        return self.metrics

    def _update_metrics(
        self,
        premise: str,
        expected_reasoning: str,
        expected_conclusion: str,
        generated_reasoning: str,
        generated_conclusion: str,
    ) -> None:
        """
        Accumulate heuristic reasoning scores for a single example.

        NOTE(review): the expected_* arguments are currently unused — the
        heuristics only inspect the premise and the generated text; a real
        implementation would compare against the expected outputs.

        Args:
            premise: Input premise
            expected_reasoning: Expected reasoning steps
            expected_conclusion: Expected conclusion
            generated_reasoning: Generated reasoning steps
            generated_conclusion: Generated conclusion
        """
        # Very simplified evaluation - in a real system, this would use more
        # sophisticated semantic similarity and logical consistency checking.

        # Logical consistency - fixed placeholder credit per example.
        self.metrics["logical_consistency"] += 0.5  # Simplified placeholder

        # Premise relevance - fraction of premise words echoed in the reasoning.
        premise_terms = set(premise.lower().split())
        reasoning_terms = set(generated_reasoning.lower().split())
        term_overlap = len(premise_terms.intersection(reasoning_terms)) / max(
            len(premise_terms), 1
        )
        self.metrics["premise_relevance"] += term_overlap

        # Conclusion validity - crude check for an explicit "therefore".
        if generated_conclusion and "therefore" in generated_conclusion.lower():
            self.metrics["conclusion_validity"] += 0.7  # Simplified placeholder
        else:
            self.metrics["conclusion_validity"] += 0.3

        # Steps coherence - count logical flow markers, capped at 1.0.
        flow_markers = [
            "first",
            "second",
            "third",
            "then",
            "next",
            "finally",
            "because",
            "thus",
            "hence",
        ]
        marker_count = sum(
            1 for marker in flow_markers if marker in generated_reasoning.lower()
        )
        self.metrics["steps_coherence"] += min(1.0, marker_count / 3)
231
+
232
+
233
class ReasoningTrainingModule(SpecializedTrainingModule):
    """Specialized training module for reasoning capabilities.

    Trains the student model on question/reasoning/answer chains using a
    causal-LM cross-entropy loss, plus an optional 0.1-weighted KL
    distillation term against a teacher model.
    """

    def __init__(self, config: TrainingModuleConfig, tokenizer):
        """
        Initialize the reasoning training module.

        Args:
            config: Module configuration
            tokenizer: Tokenizer for text processing
        """
        super().__init__(config, tokenizer)

        # Token-level LM loss; -100 marks positions excluded from the loss.
        self.reasoning_loss = nn.CrossEntropyLoss(ignore_index=-100)
        # NOTE(review): compute_loss only updates "reasoning_loss" today;
        # accuracy and perplexity are placeholders.
        self.metrics = {
            "reasoning_loss": 0.0,
            "reasoning_accuracy": 0.0,
            "reasoning_perplexity": 0.0,
        }

        logger.info("Initialized ReasoningTrainingModule")

    def prepare_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Prepare a batch of data for reasoning training.

        Moves all tensor entries to the module's device; non-tensor entries
        pass through untouched.

        Args:
            batch: The input batch from the dataloader

        Returns:
            Processed batch ready for reasoning training
        """
        prepared_batch = {}
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                prepared_batch[key] = value.to(self.device)
            else:
                prepared_batch[key] = value
        return prepared_batch

    def compute_loss(
        self, student_outputs: Any, teacher_outputs: Any, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute the reasoning-specific loss.

        Next-token cross entropy on the student logits, plus a 0.1-weighted
        KL(student || teacher) term when teacher logits are available.

        Args:
            student_outputs: Outputs from the student model (object exposing
                .logits, or a raw logits tensor)
            teacher_outputs: Outputs from the teacher model, or None to skip
                distillation
            batch: The processed input batch

        Returns:
            Loss tensor scaled by self.loss_weight
        """
        try:
            # Accept both HF-style output objects and raw tensors.
            student_logits = getattr(student_outputs, "logits", student_outputs)
            teacher_logits = getattr(teacher_outputs, "logits", teacher_outputs)

            # Fall back to input_ids when explicit labels are absent.
            labels = batch.get("labels", batch.get("input_ids"))

            # Standard causal-LM shift: predict token t+1 from position t.
            shift_logits = student_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            reasoning_loss = self.reasoning_loss(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

            if teacher_logits is not None:
                # Distillation: pull the student distribution toward the teacher's.
                kl_loss = F.kl_div(
                    F.log_softmax(student_logits, dim=-1),
                    F.softmax(teacher_logits, dim=-1),
                    reduction="batchmean",
                )
                total_loss = reasoning_loss + 0.1 * kl_loss
            else:
                total_loss = reasoning_loss

            self.metrics["reasoning_loss"] = reasoning_loss.item()

            return total_loss * self.loss_weight

        except Exception as e:
            logger.error(f"Error computing reasoning loss: {e}")
            # Small differentiable fallback so training continues.
            # BUG FIX: created on the module's device — the original built it
            # on CPU, which breaks backward() when the model runs on GPU.
            return torch.tensor(
                0.01, device=getattr(self, "device", None), requires_grad=True
            )

    def get_metrics(self) -> Dict[str, float]:
        """
        Get metrics specific to reasoning training.

        Returns:
            Dictionary of reasoning metrics (a copy; safe for callers to mutate)
        """
        return self.metrics.copy()

    def generate_synthetic_reasoning_data(
        self, output_path: str, num_samples: int = 1000
    ) -> None:
        """
        Generate synthetic reasoning data as a JSONL file.

        Args:
            output_path: Path to save the generated data
            num_samples: Number of samples to generate
        """
        # Local import keeps the module's top-level dependencies unchanged.
        from datetime import datetime, timezone

        # This is a simplified implementation based on the AWS backup's
        # synthetic_generator. In a full implementation, this would be much
        # more sophisticated.

        templates = [
            {
                "premise": "If it rains, the ground gets wet. It is raining now.",
                "reasoning": "Since it is raining, and rain makes the ground wet, we can conclude that the ground is getting wet.",
                "conclusion": "Therefore, the ground is wet.",
            },
            {
                "premise": "All mammals are warm-blooded. Whales are mammals.",
                "reasoning": "Whales are classified as mammals. All mammals are warm-blooded animals. Therefore, as a mammal, a whale must be warm-blooded.",
                "conclusion": "Therefore, whales are warm-blooded.",
            },
            {
                "premise": "If you study hard, you will pass the exam. You studied hard.",
                "reasoning": "The premise states a conditional relationship between studying hard and passing the exam. Since you studied hard, the condition is met.",
                "conclusion": "Therefore, you will pass the exam.",
            },
        ]

        # Generate variations of the templates
        output_data = []
        for _ in range(num_samples):
            template = random.choice(templates)

            # Create a variation (very simplified)
            variation = {
                "premise": template["premise"],
                "reasoning": template["reasoning"],
                "conclusion": template["conclusion"],
                # Aliases so ReasoningDataset (which requires question/
                # reasoning/answer fields) can load this file directly.
                "question": template["premise"],
                "answer": template["conclusion"],
                "metadata": {
                    "generated": True,
                    # BUG FIX: the original stored the CUDA device *name*
                    # under "timestamp"; record an actual timestamp and keep
                    # the device name under its own key.
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                    "device": (
                        torch.cuda.get_device_name(0)
                        if torch.cuda.is_available()
                        else "CPU"
                    ),
                },
            }

            output_data.append(variation)

        # BUG FIX: guard against a bare filename — os.path.dirname("x.jsonl")
        # is "" and os.makedirs("") raises FileNotFoundError.
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            for item in output_data:
                f.write(json.dumps(item) + "\n")

        logger.info(
            f"Generated {len(output_data)} synthetic reasoning examples at {output_path}"
        )
src/multi_agent_training/synthetic_data_generator.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Synthetic Data Generator for MangoMAS Local
3
+
4
+ This module provides a framework for generating synthetic training data
5
+ for different specialized capabilities, adaptable across all modules.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ import os
11
+ import random
12
+ from abc import ABC, abstractmethod
13
+ from pathlib import Path
14
+ from typing import Any, Callable, Dict, List, Optional, Tuple
15
+
16
+ import yaml
17
+ from tqdm import tqdm
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class SyntheticDataGenerator(ABC):
    """
    Abstract base class for synthetic data generators.

    Concrete subclasses implement `_load_templates` and `generate_example`;
    this base class handles writing datasets to JSONL and augmenting
    existing JSONL files with synthetic examples.
    """

    def __init__(self, config: Dict[str, Any], output_dir: str = "data/processed"):
        """
        Initialize the synthetic data generator.

        Args:
            config: Configuration for the data generator. Recognized key:
                "synthetic_examples" (int, default 1000).
            output_dir: Directory to save generated data (created if missing)
        """
        self.config = config
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Default example count used when generate_dataset gets no override.
        self.num_examples = config.get("synthetic_examples", 1000)

        # Template pool supplied by the concrete subclass.
        self.templates = self._load_templates()

        logger.info(f"Initialized {self.__class__.__name__} with {self.num_examples} examples")

    @abstractmethod
    def _load_templates(self) -> List[Dict[str, Any]]:
        """
        Load templates for data generation.
        Each implementation should define its own templates.

        Returns:
            List of template dictionaries
        """
        pass

    @abstractmethod
    def generate_example(self) -> Dict[str, Any]:
        """
        Generate a single synthetic training example.

        Returns:
            Dictionary with the generated example
        """
        pass

    def generate_dataset(self, filename: str, num_examples: Optional[int] = None) -> str:
        """
        Generate a synthetic dataset and save to a JSONL file.

        Args:
            filename: Name of the output file (relative to output_dir)
            num_examples: Number of examples to generate (overrides config)

        Returns:
            Path to the generated dataset file
        """
        n = num_examples if num_examples is not None else self.num_examples
        output_file = self.output_dir / filename

        logger.info(f"Generating {n} synthetic examples for {self.__class__.__name__}")

        with open(output_file, 'w', encoding='utf-8') as f:
            # BUG FIX: the progress label was an f-string with no placeholder
            # ("(unknown)"); show the actual target filename.
            for _ in tqdm(range(n), desc=f"Generating {filename}"):
                example = self.generate_example()
                f.write(json.dumps(example) + '\n')

        logger.info(f"Generated dataset saved to {output_file}")
        return str(output_file)

    def augment_existing_dataset(self, input_file: str, output_file: Optional[str] = None,
                                 ratio: float = 0.5) -> str:
        """
        Augment an existing dataset with synthetic examples.

        Args:
            input_file: Path to the existing JSONL dataset
            output_file: Path to save the augmented dataset (or None to overwrite)
            ratio: Ratio of synthetic to original examples

        Returns:
            Path to the augmented dataset
        """
        if output_file is None:
            output_file = input_file

        # Load existing data line by line.
        # BUG FIX: skip malformed lines instead of resetting the whole list —
        # the original caught JSONDecodeError around the loop and discarded
        # every already-loaded valid row, so one bad line disabled
        # augmentation entirely.
        existing_data: List[Dict[str, Any]] = []
        try:
            with open(input_file, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        existing_data.append(json.loads(line.strip()))
                    except json.JSONDecodeError as e:
                        logger.warning(f"Error loading existing data: {e}")
        except FileNotFoundError as e:
            logger.warning(f"Error loading existing data: {e}")

        # Calculate number of synthetic examples to generate
        n_existing = len(existing_data)
        n_synthetic = int(n_existing * ratio)

        # Generate synthetic examples (plain string: no fields to format)
        synthetic_data = [self.generate_example() for _ in tqdm(range(n_synthetic),
                          desc="Generating augmentation data")]

        # Combine and shuffle so synthetic rows are interleaved
        combined_data = existing_data + synthetic_data
        random.shuffle(combined_data)

        # Save augmented dataset
        with open(output_file, 'w', encoding='utf-8') as f:
            for item in combined_data:
                f.write(json.dumps(item) + '\n')

        logger.info(f"Augmented dataset with {n_synthetic} synthetic examples, saved to {output_file}")
        return output_file
138
+
139
+
140
class SyntheticDataGeneratorRegistry:
    """Registry mapping module-type names to synthetic data generator classes."""

    # Class-level mapping shared by every call site.
    _generators = {}

    @classmethod
    def register(cls, module_type: str, generator_class):
        """Associate *generator_class* with the name *module_type*."""
        cls._generators[module_type] = generator_class

    @classmethod
    def get_generator(cls, module_type: str, config: Dict[str, Any], output_dir: str) -> "SyntheticDataGenerator":
        """Instantiate the generator registered under *module_type*.

        Raises:
            ValueError: if no generator was registered for *module_type*.
        """
        generator_cls = cls._generators.get(module_type)
        if generator_cls is None:
            raise ValueError(f"No generator registered for module type: {module_type}")
        return generator_cls(config, output_dir)

    @classmethod
    def list_generators(cls) -> List[str]:
        """Return the names of all registered generator types."""
        return list(cls._generators.keys())
162
+ """