MaliosDark commited on
Commit
c5de26c
·
verified ·
1 Parent(s): 3472b9f

Add AGI module: sofia_federated.py

Browse files
Files changed (1) hide show
  1. sofia_federated.py +601 -0
sofia_federated.py ADDED
@@ -0,0 +1,601 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ SOFIA Federated Learning Framework
4
+ Implements distributed training while preserving data privacy
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import pickle
10
+ import logging
11
+ import asyncio
12
+ import threading
13
+ from typing import Dict, List, Tuple, Optional, Any, Callable
14
+ from datetime import datetime
15
+ from collections import defaultdict
16
+ import hashlib
17
+ import secrets
18
+ from concurrent.futures import ThreadPoolExecutor
19
+ import numpy as np
20
+
21
+ # For demonstration - in real implementation, use proper federated learning libraries
22
+ try:
23
+ import torch
24
+ import torch.nn as nn
25
+ from torch.utils.data import DataLoader, Dataset
26
+ TORCH_AVAILABLE = True
27
+ except ImportError:
28
+ TORCH_AVAILABLE = False
29
+ print("Warning: PyTorch not available. Federated learning will use mock implementation.")
30
+
31
+ logging.basicConfig(level=logging.INFO)
32
+ logger = logging.getLogger(__name__)
33
+
34
+ class FederatedDataset(Dataset):
35
+ """
36
+ Dataset wrapper for federated learning
37
+ """
38
+
39
+ def __init__(self, data: List[Tuple[str, str]], tokenizer=None):
40
+ self.data = data
41
+ self.tokenizer = tokenizer
42
+
43
+ def __len__(self):
44
+ return len(self.data)
45
+
46
+ def __getitem__(self, idx):
47
+ text1, text2 = self.data[idx]
48
+ if self.tokenizer:
49
+ # In real implementation, tokenize here
50
+ return {
51
+ 'text1': text1,
52
+ 'text2': text2,
53
+ 'input_ids': torch.tensor([1, 2, 3]), # Mock
54
+ 'attention_mask': torch.tensor([1, 1, 1]) # Mock
55
+ }
56
+ return {
57
+ 'text1': text1,
58
+ 'text2': text2,
59
+ 'input_ids': torch.tensor([1, 2, 3]) if TORCH_AVAILABLE else [1, 2, 3], # Always provide input_ids
60
+ 'attention_mask': torch.tensor([1, 1, 1]) if TORCH_AVAILABLE else [1, 1, 1]
61
+ }
62
+
63
+ class LocalModel:
64
+ """
65
+ Represents a local model on a client device
66
+ """
67
+
68
+ def __init__(self, client_id: str, model_config: Dict[str, Any]):
69
+ self.client_id = client_id
70
+ self.model_config = model_config
71
+ self.model = None
72
+ self.optimizer = None
73
+ self.local_epochs = 1
74
+ self.batch_size = 32
75
+ self.learning_rate = 2e-5
76
+
77
+ # Privacy parameters
78
+ self.noise_multiplier = 0.1
79
+ self.max_grad_norm = 1.0
80
+
81
+ # Training state
82
+ self.current_round = 0
83
+ self.local_loss_history = []
84
+ self.samples_processed = 0
85
+
86
+ def initialize_model(self):
87
+ """Initialize the local model"""
88
+ if not TORCH_AVAILABLE:
89
+ logger.warning("PyTorch not available, using mock model")
90
+ self.model = MockModel()
91
+ return
92
+
93
+ # In real implementation, load SOFIA model architecture
94
+ # For now, use a simple transformer-like model
95
+ self.model = nn.Sequential(
96
+ nn.Linear(768, 512),
97
+ nn.ReLU(),
98
+ nn.Linear(512, 256),
99
+ nn.ReLU(),
100
+ nn.Linear(256, 128)
101
+ )
102
+
103
+ self.optimizer = torch.optim.AdamW(
104
+ self.model.parameters(),
105
+ lr=self.learning_rate,
106
+ weight_decay=0.01
107
+ )
108
+
109
+ def train_local(self, train_loader: DataLoader, epochs: int = 1) -> Dict[str, Any]:
110
+ """
111
+ Train the local model on client's data
112
+
113
+ Args:
114
+ train_loader: DataLoader with client's training data
115
+ epochs: Number of local training epochs
116
+
117
+ Returns:
118
+ Training statistics and model updates
119
+ """
120
+ if not self.model:
121
+ self.initialize_model()
122
+
123
+ self.model.train()
124
+ total_loss = 0
125
+ num_batches = 0
126
+
127
+ for epoch in range(epochs):
128
+ epoch_loss = 0
129
+ for batch in train_loader:
130
+ self.optimizer.zero_grad()
131
+
132
+ if TORCH_AVAILABLE:
133
+ # Mock forward pass - use input_ids shape
134
+ batch_size = batch['input_ids'].shape[0] if hasattr(batch['input_ids'], 'shape') else len(batch['input_ids'])
135
+ outputs = self.model(torch.randn(batch_size, 768))
136
+ loss = torch.nn.functional.mse_loss(outputs, torch.randn_like(outputs))
137
+ else:
138
+ # Mock loss
139
+ loss = 0.5
140
+
141
+ if TORCH_AVAILABLE and hasattr(loss, 'backward'):
142
+ loss.backward()
143
+
144
+ # Apply differential privacy (simplified)
145
+ self._apply_differential_privacy()
146
+
147
+ # Gradient clipping
148
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
149
+
150
+ self.optimizer.step()
151
+
152
+ epoch_loss += loss.item() if TORCH_AVAILABLE and hasattr(loss, 'item') else loss
153
+ num_batches += 1
154
+ batch_size = batch['input_ids'].shape[0] if hasattr(batch['input_ids'], 'shape') else len(batch['input_ids'])
155
+ self.samples_processed += batch_size
156
+
157
+ avg_epoch_loss = epoch_loss / num_batches
158
+ self.local_loss_history.append(avg_epoch_loss)
159
+ total_loss += avg_epoch_loss
160
+
161
+ # Generate model update (gradient or weight differences)
162
+ model_update = self._generate_model_update()
163
+
164
+ training_stats = {
165
+ 'client_id': self.client_id,
166
+ 'round': self.current_round,
167
+ 'epochs': epochs,
168
+ 'avg_loss': total_loss / epochs,
169
+ 'samples_processed': self.samples_processed,
170
+ 'model_update_size': len(pickle.dumps(model_update))
171
+ }
172
+
173
+ self.current_round += 1
174
+ return training_stats, model_update
175
+
176
+ def _apply_differential_privacy(self):
177
+ """Apply differential privacy noise to gradients"""
178
+ if not TORCH_AVAILABLE:
179
+ return
180
+
181
+ for param in self.model.parameters():
182
+ if param.grad is not None:
183
+ # Add Gaussian noise for differential privacy
184
+ noise = torch.normal(0, self.noise_multiplier, param.grad.shape)
185
+ param.grad.data += noise
186
+
187
+ def _generate_model_update(self) -> Dict[str, Any]:
188
+ """Generate model update for aggregation"""
189
+ if not TORCH_AVAILABLE:
190
+ # Mock update
191
+ return {
192
+ 'client_id': self.client_id,
193
+ 'round': self.current_round,
194
+ 'weights': {'layer1': np.random.randn(512, 768)},
195
+ 'gradients': {'layer1': np.random.randn(512, 768)}
196
+ }
197
+
198
+ # In real implementation, compute weight differences or gradients
199
+ update = {
200
+ 'client_id': self.client_id,
201
+ 'round': self.current_round,
202
+ 'weights': {},
203
+ 'gradients': {}
204
+ }
205
+
206
+ for name, param in self.model.named_parameters():
207
+ update['weights'][name] = param.data.clone()
208
+ # In practice, you'd send gradients or weight differences
209
+ update['gradients'][name] = param.grad.clone() if param.grad is not None else torch.zeros_like(param)
210
+
211
+ return update
212
+
213
+ def update_model(self, global_update: Dict[str, Any]):
214
+ """Update local model with global aggregated update"""
215
+ if not TORCH_AVAILABLE:
216
+ logger.info(f"Client {self.client_id}: Mock model update applied")
217
+ return
218
+
219
+ # In real implementation, apply the global update
220
+ for name, param in self.model.named_parameters():
221
+ if name in global_update.get('weights', {}):
222
+ param.data = global_update['weights'][name]
223
+
224
+ logger.info(f"Client {self.client_id}: Model updated with global parameters")
225
+
226
+ class MockModel:
227
+ """Mock model for demonstration when PyTorch is not available"""
228
+
229
+ def __init__(self):
230
+ self.parameters = lambda: []
231
+ self.training = True
232
+
233
+ def train(self):
234
+ self.training = True
235
+
236
+ def eval(self):
237
+ self.training = False
238
+
239
+ def __call__(self, x):
240
+ return torch.tensor(0.5) if TORCH_AVAILABLE else 0.5
241
+
242
+ class FederatedAggregator:
243
+ """
244
+ Aggregates model updates from multiple clients
245
+ """
246
+
247
+ def __init__(self, aggregation_method: str = 'fedavg'):
248
+ self.aggregation_method = aggregation_method
249
+ self.global_model_state = {}
250
+ self.client_updates = []
251
+ self.round_number = 0
252
+
253
+ # Aggregation statistics
254
+ self.aggregation_history = []
255
+
256
+ def aggregate_updates(self, client_updates: List[Tuple[Dict[str, Any], Any]]) -> Dict[str, Any]:
257
+ """
258
+ Aggregate model updates from clients
259
+
260
+ Args:
261
+ client_updates: List of (training_stats, model_update) tuples
262
+
263
+ Returns:
264
+ Aggregated global model update
265
+ """
266
+ if not client_updates:
267
+ return {}
268
+
269
+ self.round_number += 1
270
+ self.client_updates = client_updates
271
+
272
+ if self.aggregation_method == 'fedavg':
273
+ return self._fedavg_aggregation(client_updates)
274
+ elif self.aggregation_method == 'fedprox':
275
+ return self._fedprox_aggregation(client_updates)
276
+ else:
277
+ return self._fedavg_aggregation(client_updates)
278
+
279
+ def _fedavg_aggregation(self, client_updates: List[Tuple[Dict[str, Any], Any]]) -> Dict[str, Any]:
280
+ """Federated Averaging aggregation"""
281
+ if not client_updates:
282
+ return {}
283
+
284
+ # Collect all model updates
285
+ model_updates = [update for _, update in client_updates]
286
+
287
+ # Calculate total samples for weighted averaging
288
+ total_samples = sum(stats['samples_processed'] for stats, _ in client_updates)
289
+
290
+ # Aggregate weights
291
+ aggregated_weights = {}
292
+ aggregated_gradients = {}
293
+
294
+ # Get all parameter names from first update
295
+ if model_updates:
296
+ param_names = set()
297
+ for update in model_updates:
298
+ if 'weights' in update:
299
+ param_names.update(update['weights'].keys())
300
+
301
+ # Aggregate each parameter
302
+ for param_name in param_names:
303
+ weights = []
304
+ weight_contributions = []
305
+
306
+ for (stats, update) in client_updates:
307
+ if param_name in update.get('weights', {}):
308
+ weight = update['weights'][param_name]
309
+ sample_weight = stats['samples_processed'] / total_samples
310
+
311
+ if TORCH_AVAILABLE and isinstance(weight, torch.Tensor):
312
+ weights.append(weight * sample_weight)
313
+ else:
314
+ weights.append(np.array(weight) * sample_weight)
315
+ weight_contributions.append(sample_weight)
316
+
317
+ # Average the weights
318
+ if weights:
319
+ if TORCH_AVAILABLE and isinstance(weights[0], torch.Tensor):
320
+ aggregated_weights[param_name] = torch.stack(weights).sum(dim=0)
321
+ else:
322
+ aggregated_weights[param_name] = np.stack(weights).sum(axis=0)
323
+
324
+ # Store aggregation statistics
325
+ aggregation_stats = {
326
+ 'round': self.round_number,
327
+ 'num_clients': len(client_updates),
328
+ 'total_samples': total_samples,
329
+ 'aggregation_method': self.aggregation_method,
330
+ 'timestamp': datetime.now().isoformat()
331
+ }
332
+
333
+ self.aggregation_history.append(aggregation_stats)
334
+
335
+ return {
336
+ 'weights': aggregated_weights,
337
+ 'gradients': aggregated_gradients,
338
+ 'aggregation_stats': aggregation_stats
339
+ }
340
+
341
+ def _fedprox_aggregation(self, client_updates: List[Tuple[Dict[str, Any], Any]]) -> Dict[str, Any]:
342
+ """FedProx aggregation (simplified)"""
343
+ # FedProx adds a proximal term to prevent client drift
344
+ # For simplicity, using same logic as FedAvg but could add regularization
345
+ return self._fedavg_aggregation(client_updates)
346
+
347
+ def get_aggregation_stats(self) -> Dict[str, Any]:
348
+ """Get aggregation statistics"""
349
+ if not self.aggregation_history:
350
+ return {'total_rounds': 0}
351
+
352
+ recent_stats = self.aggregation_history[-1]
353
+ return {
354
+ 'total_rounds': len(self.aggregation_history),
355
+ 'last_round_clients': recent_stats['num_clients'],
356
+ 'last_round_samples': recent_stats['total_samples'],
357
+ 'aggregation_method': self.aggregation_method
358
+ }
359
+
360
+ class PrivacyController:
361
+ """
362
+ Manages privacy-preserving techniques in federated learning
363
+ """
364
+
365
+ def __init__(self, privacy_budget: float = 1.0, delta: float = 1e-5):
366
+ self.privacy_budget = privacy_budget
367
+ self.delta = delta
368
+ self.spent_budget = 0.0
369
+
370
+ # Privacy mechanisms
371
+ self.differential_privacy_enabled = True
372
+ self.secure_aggregation_enabled = False
373
+
374
+ def check_privacy_budget(self, epsilon: float) -> bool:
375
+ """Check if privacy budget allows the operation"""
376
+ if self.spent_budget + epsilon > self.privacy_budget:
377
+ return False
378
+ return True
379
+
380
+ def spend_privacy_budget(self, epsilon: float):
381
+ """Spend privacy budget"""
382
+ self.spent_budget += epsilon
383
+
384
+ def apply_secure_aggregation(self, client_updates: List[Any]) -> Any:
385
+ """Apply secure aggregation (simplified)"""
386
+ if not self.secure_aggregation_enabled:
387
+ return client_updates
388
+
389
+ # In real implementation, use cryptographic techniques
390
+ # For now, just return updates (no-op)
391
+ logger.info("Secure aggregation applied (simplified)")
392
+ return client_updates
393
+
394
+ def generate_privacy_report(self) -> Dict[str, Any]:
395
+ """Generate privacy report"""
396
+ return {
397
+ 'total_budget': self.privacy_budget,
398
+ 'spent_budget': self.spent_budget,
399
+ 'remaining_budget': self.privacy_budget - self.spent_budget,
400
+ 'differential_privacy_enabled': self.differential_privacy_enabled,
401
+ 'secure_aggregation_enabled': self.secure_aggregation_enabled,
402
+ 'privacy_level': 'high' if self.differential_privacy_enabled else 'medium'
403
+ }
404
+
405
+ class FederatedLearningCoordinator:
406
+ """
407
+ Coordinates federated learning across multiple clients
408
+ """
409
+
410
+ def __init__(self, num_clients: int = 3, rounds: int = 5):
411
+ self.num_clients = num_clients
412
+ self.rounds = rounds
413
+ self.clients = {}
414
+ self.aggregator = FederatedAggregator()
415
+ self.privacy_controller = PrivacyController()
416
+
417
+ # Communication
418
+ self.client_updates_queue = asyncio.Queue()
419
+ self.global_model_available = asyncio.Event()
420
+
421
+ # Statistics
422
+ self.training_stats = []
423
+ self.round_times = []
424
+
425
+ async def initialize_clients(self, client_data: Dict[str, List[Tuple[str, str]]]):
426
+ """Initialize federated clients with their data"""
427
+ for client_id, data in client_data.items():
428
+ model_config = {
429
+ 'model_type': 'sofia_embedding',
430
+ 'input_dim': 768,
431
+ 'hidden_dim': 512,
432
+ 'output_dim': 256
433
+ }
434
+
435
+ client = LocalModel(client_id, model_config)
436
+ client.initialize_model()
437
+
438
+ # Create dataset and dataloader
439
+ dataset = FederatedDataset(data)
440
+ dataloader = DataLoader(dataset, batch_size=client.batch_size, shuffle=True)
441
+
442
+ self.clients[client_id] = {
443
+ 'model': client,
444
+ 'dataloader': dataloader,
445
+ 'data_size': len(data)
446
+ }
447
+
448
+ logger.info(f"Initialized {len(self.clients)} federated clients")
449
+
450
+ async def run_federated_training(self) -> Dict[str, Any]:
451
+ """Run federated training for specified rounds"""
452
+ logger.info(f"Starting federated training for {self.rounds} rounds")
453
+
454
+ for round_num in range(1, self.rounds + 1):
455
+ round_start = datetime.now()
456
+
457
+ logger.info(f"Round {round_num}/{self.rounds} starting")
458
+
459
+ # Client training phase
460
+ client_updates = await self._train_clients_in_round(round_num)
461
+
462
+ # Aggregation phase
463
+ global_update = self.aggregator.aggregate_updates(client_updates)
464
+
465
+ # Update all clients with global model
466
+ await self._update_clients_with_global_model(global_update)
467
+
468
+ # Record round statistics
469
+ round_time = (datetime.now() - round_start).total_seconds()
470
+ self.round_times.append(round_time)
471
+
472
+ round_stats = {
473
+ 'round': round_num,
474
+ 'num_clients': len(client_updates),
475
+ 'round_time': round_time,
476
+ 'aggregation_stats': global_update.get('aggregation_stats', {})
477
+ }
478
+ self.training_stats.append(round_stats)
479
+
480
+ logger.info(f"Round {round_num} completed in {round_time:.2f}s")
481
+
482
+ # Generate final report
483
+ final_report = self._generate_final_report()
484
+ return final_report
485
+
486
+ async def _train_clients_in_round(self, round_num: int) -> List[Tuple[Dict[str, Any], Any]]:
487
+ """Train all clients in parallel for one round"""
488
+ tasks = []
489
+
490
+ for client_id, client_info in self.clients.items():
491
+ task = asyncio.create_task(
492
+ self._train_single_client(client_id, client_info, round_num)
493
+ )
494
+ tasks.append(task)
495
+
496
+ # Wait for all clients to complete training
497
+ results = await asyncio.gather(*tasks)
498
+ return [r for r in results if r is not None] # Filter out failed trainings
499
+
500
+ async def _train_single_client(self, client_id: str, client_info: Dict[str, Any], round_num: int):
501
+ """Train a single client"""
502
+ try:
503
+ client = client_info['model']
504
+ dataloader = client_info['dataloader']
505
+
506
+ # Train locally
507
+ training_stats, model_update = client.train_local(dataloader, epochs=client.local_epochs)
508
+
509
+ logger.info(f"Client {client_id}: Training completed - Loss: {training_stats['avg_loss']:.4f}")
510
+
511
+ return training_stats, model_update
512
+
513
+ except Exception as e:
514
+ logger.error(f"Client {client_id} training failed: {e}")
515
+ return None
516
+
517
+ async def _update_clients_with_global_model(self, global_update: Dict[str, Any]):
518
+ """Update all clients with the new global model"""
519
+ tasks = []
520
+
521
+ for client_id, client_info in self.clients.items():
522
+ task = asyncio.create_task(
523
+ self._update_single_client(client_id, client_info['model'], global_update)
524
+ )
525
+ tasks.append(task)
526
+
527
+ await asyncio.gather(*tasks)
528
+
529
+ async def _update_single_client(self, client_id: str, client_model: LocalModel, global_update: Dict[str, Any]):
530
+ """Update a single client with global model"""
531
+ try:
532
+ client_model.update_model(global_update)
533
+ logger.info(f"Client {client_id}: Model updated with global parameters")
534
+ except Exception as e:
535
+ logger.error(f"Failed to update client {client_id}: {e}")
536
+
537
+ def _generate_final_report(self) -> Dict[str, Any]:
538
+ """Generate final training report"""
539
+ total_time = sum(self.round_times)
540
+ avg_round_time = total_time / len(self.round_times) if self.round_times else 0
541
+
542
+ return {
543
+ 'federated_training_completed': True,
544
+ 'total_rounds': len(self.training_stats),
545
+ 'total_training_time': total_time,
546
+ 'average_round_time': avg_round_time,
547
+ 'clients_participated': len(self.clients),
548
+ 'privacy_report': self.privacy_controller.generate_privacy_report(),
549
+ 'aggregation_stats': self.aggregator.get_aggregation_stats(),
550
+ 'round_stats': self.training_stats,
551
+ 'final_model_available': True
552
+ }
553
+
554
+ # Example usage and testing
555
+ async def demo_federated_learning():
556
+ """Demonstrate federated learning with mock data"""
557
+ print("SOFIA Federated Learning Demo")
558
+ print("=" * 40)
559
+
560
+ # Create mock client data
561
+ client_data = {
562
+ 'client_1': [
563
+ ("Hello world", "Hi there"),
564
+ ("How are you", "I'm fine"),
565
+ ("Machine learning", "AI models")
566
+ ] * 10, # Repeat for more data
567
+ 'client_2': [
568
+ ("Python programming", "Code development"),
569
+ ("Data science", "Analytics"),
570
+ ("Neural networks", "Deep learning")
571
+ ] * 10,
572
+ 'client_3': [
573
+ ("Natural language", "Text processing"),
574
+ ("Computer vision", "Image recognition"),
575
+ ("Reinforcement learning", "RL algorithms")
576
+ ] * 10
577
+ }
578
+
579
+ # Initialize coordinator
580
+ coordinator = FederatedLearningCoordinator(num_clients=3, rounds=3)
581
+
582
+ # Initialize clients
583
+ await coordinator.initialize_clients(client_data)
584
+
585
+ # Run federated training
586
+ print("Starting federated training...")
587
+ final_report = await coordinator.run_federated_training()
588
+
589
+ # Print results
590
+ print("\nFederated Training Results:")
591
+ print(f"Total rounds: {final_report['total_rounds']}")
592
+ print(".2f")
593
+ print(".2f")
594
+ print(f"Clients participated: {final_report['clients_participated']}")
595
+ print(f"Privacy level: {final_report['privacy_report']['privacy_level']}")
596
+
597
+ return final_report
598
+
599
+ if __name__ == "__main__":
600
+ # Run demo
601
+ asyncio.run(demo_federated_learning())