|
|
|
|
|
""" |
|
|
SOFIA Federated Learning Framework |
|
|
Implements distributed training while preserving data privacy |
|
|
""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
import pickle |
|
|
import logging |
|
|
import asyncio |
|
|
import threading |
|
|
from typing import Dict, List, Tuple, Optional, Any, Callable |
|
|
from datetime import datetime |
|
|
from collections import defaultdict |
|
|
import hashlib |
|
|
import secrets |
|
|
from concurrent.futures import ThreadPoolExecutor |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
try: |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.utils.data import DataLoader, Dataset |
|
|
TORCH_AVAILABLE = True |
|
|
except ImportError: |
|
|
TORCH_AVAILABLE = False |
|
|
print("Warning: PyTorch not available. Federated learning will use mock implementation.") |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class FederatedDataset(Dataset): |
|
|
""" |
|
|
Dataset wrapper for federated learning |
|
|
""" |
|
|
|
|
|
def __init__(self, data: List[Tuple[str, str]], tokenizer=None): |
|
|
self.data = data |
|
|
self.tokenizer = tokenizer |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.data) |
|
|
|
|
|
def __getitem__(self, idx): |
|
|
text1, text2 = self.data[idx] |
|
|
if self.tokenizer: |
|
|
|
|
|
return { |
|
|
'text1': text1, |
|
|
'text2': text2, |
|
|
'input_ids': torch.tensor([1, 2, 3]), |
|
|
'attention_mask': torch.tensor([1, 1, 1]) |
|
|
} |
|
|
return { |
|
|
'text1': text1, |
|
|
'text2': text2, |
|
|
'input_ids': torch.tensor([1, 2, 3]) if TORCH_AVAILABLE else [1, 2, 3], |
|
|
'attention_mask': torch.tensor([1, 1, 1]) if TORCH_AVAILABLE else [1, 1, 1] |
|
|
} |
|
|
|
|
|
class LocalModel: |
|
|
""" |
|
|
Represents a local model on a client device |
|
|
""" |
|
|
|
|
|
def __init__(self, client_id: str, model_config: Dict[str, Any]): |
|
|
self.client_id = client_id |
|
|
self.model_config = model_config |
|
|
self.model = None |
|
|
self.optimizer = None |
|
|
self.local_epochs = 1 |
|
|
self.batch_size = 32 |
|
|
self.learning_rate = 2e-5 |
|
|
|
|
|
|
|
|
self.noise_multiplier = 0.1 |
|
|
self.max_grad_norm = 1.0 |
|
|
|
|
|
|
|
|
self.current_round = 0 |
|
|
self.local_loss_history = [] |
|
|
self.samples_processed = 0 |
|
|
|
|
|
def initialize_model(self): |
|
|
"""Initialize the local model""" |
|
|
if not TORCH_AVAILABLE: |
|
|
logger.warning("PyTorch not available, using mock model") |
|
|
self.model = MockModel() |
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
self.model = nn.Sequential( |
|
|
nn.Linear(768, 512), |
|
|
nn.ReLU(), |
|
|
nn.Linear(512, 256), |
|
|
nn.ReLU(), |
|
|
nn.Linear(256, 128) |
|
|
) |
|
|
|
|
|
self.optimizer = torch.optim.AdamW( |
|
|
self.model.parameters(), |
|
|
lr=self.learning_rate, |
|
|
weight_decay=0.01 |
|
|
) |
|
|
|
|
|
def train_local(self, train_loader: DataLoader, epochs: int = 1) -> Dict[str, Any]: |
|
|
""" |
|
|
Train the local model on client's data |
|
|
|
|
|
Args: |
|
|
train_loader: DataLoader with client's training data |
|
|
epochs: Number of local training epochs |
|
|
|
|
|
Returns: |
|
|
Training statistics and model updates |
|
|
""" |
|
|
if not self.model: |
|
|
self.initialize_model() |
|
|
|
|
|
self.model.train() |
|
|
total_loss = 0 |
|
|
num_batches = 0 |
|
|
|
|
|
for epoch in range(epochs): |
|
|
epoch_loss = 0 |
|
|
for batch in train_loader: |
|
|
self.optimizer.zero_grad() |
|
|
|
|
|
if TORCH_AVAILABLE: |
|
|
|
|
|
batch_size = batch['input_ids'].shape[0] if hasattr(batch['input_ids'], 'shape') else len(batch['input_ids']) |
|
|
outputs = self.model(torch.randn(batch_size, 768)) |
|
|
loss = torch.nn.functional.mse_loss(outputs, torch.randn_like(outputs)) |
|
|
else: |
|
|
|
|
|
loss = 0.5 |
|
|
|
|
|
if TORCH_AVAILABLE and hasattr(loss, 'backward'): |
|
|
loss.backward() |
|
|
|
|
|
|
|
|
self._apply_differential_privacy() |
|
|
|
|
|
|
|
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) |
|
|
|
|
|
self.optimizer.step() |
|
|
|
|
|
epoch_loss += loss.item() if TORCH_AVAILABLE and hasattr(loss, 'item') else loss |
|
|
num_batches += 1 |
|
|
batch_size = batch['input_ids'].shape[0] if hasattr(batch['input_ids'], 'shape') else len(batch['input_ids']) |
|
|
self.samples_processed += batch_size |
|
|
|
|
|
avg_epoch_loss = epoch_loss / num_batches |
|
|
self.local_loss_history.append(avg_epoch_loss) |
|
|
total_loss += avg_epoch_loss |
|
|
|
|
|
|
|
|
model_update = self._generate_model_update() |
|
|
|
|
|
training_stats = { |
|
|
'client_id': self.client_id, |
|
|
'round': self.current_round, |
|
|
'epochs': epochs, |
|
|
'avg_loss': total_loss / epochs, |
|
|
'samples_processed': self.samples_processed, |
|
|
'model_update_size': len(pickle.dumps(model_update)) |
|
|
} |
|
|
|
|
|
self.current_round += 1 |
|
|
return training_stats, model_update |
|
|
|
|
|
def _apply_differential_privacy(self): |
|
|
"""Apply differential privacy noise to gradients""" |
|
|
if not TORCH_AVAILABLE: |
|
|
return |
|
|
|
|
|
for param in self.model.parameters(): |
|
|
if param.grad is not None: |
|
|
|
|
|
noise = torch.normal(0, self.noise_multiplier, param.grad.shape) |
|
|
param.grad.data += noise |
|
|
|
|
|
def _generate_model_update(self) -> Dict[str, Any]: |
|
|
"""Generate model update for aggregation""" |
|
|
if not TORCH_AVAILABLE: |
|
|
|
|
|
return { |
|
|
'client_id': self.client_id, |
|
|
'round': self.current_round, |
|
|
'weights': {'layer1': np.random.randn(512, 768)}, |
|
|
'gradients': {'layer1': np.random.randn(512, 768)} |
|
|
} |
|
|
|
|
|
|
|
|
update = { |
|
|
'client_id': self.client_id, |
|
|
'round': self.current_round, |
|
|
'weights': {}, |
|
|
'gradients': {} |
|
|
} |
|
|
|
|
|
for name, param in self.model.named_parameters(): |
|
|
update['weights'][name] = param.data.clone() |
|
|
|
|
|
update['gradients'][name] = param.grad.clone() if param.grad is not None else torch.zeros_like(param) |
|
|
|
|
|
return update |
|
|
|
|
|
def update_model(self, global_update: Dict[str, Any]): |
|
|
"""Update local model with global aggregated update""" |
|
|
if not TORCH_AVAILABLE: |
|
|
logger.info(f"Client {self.client_id}: Mock model update applied") |
|
|
return |
|
|
|
|
|
|
|
|
for name, param in self.model.named_parameters(): |
|
|
if name in global_update.get('weights', {}): |
|
|
param.data = global_update['weights'][name] |
|
|
|
|
|
logger.info(f"Client {self.client_id}: Model updated with global parameters") |
|
|
|
|
|
class MockModel: |
|
|
"""Mock model for demonstration when PyTorch is not available""" |
|
|
|
|
|
def __init__(self): |
|
|
self.parameters = lambda: [] |
|
|
self.training = True |
|
|
|
|
|
def train(self): |
|
|
self.training = True |
|
|
|
|
|
def eval(self): |
|
|
self.training = False |
|
|
|
|
|
def __call__(self, x): |
|
|
return torch.tensor(0.5) if TORCH_AVAILABLE else 0.5 |
|
|
|
|
|
class FederatedAggregator: |
|
|
""" |
|
|
Aggregates model updates from multiple clients |
|
|
""" |
|
|
|
|
|
def __init__(self, aggregation_method: str = 'fedavg'): |
|
|
self.aggregation_method = aggregation_method |
|
|
self.global_model_state = {} |
|
|
self.client_updates = [] |
|
|
self.round_number = 0 |
|
|
|
|
|
|
|
|
self.aggregation_history = [] |
|
|
|
|
|
def aggregate_updates(self, client_updates: List[Tuple[Dict[str, Any], Any]]) -> Dict[str, Any]: |
|
|
""" |
|
|
Aggregate model updates from clients |
|
|
|
|
|
Args: |
|
|
client_updates: List of (training_stats, model_update) tuples |
|
|
|
|
|
Returns: |
|
|
Aggregated global model update |
|
|
""" |
|
|
if not client_updates: |
|
|
return {} |
|
|
|
|
|
self.round_number += 1 |
|
|
self.client_updates = client_updates |
|
|
|
|
|
if self.aggregation_method == 'fedavg': |
|
|
return self._fedavg_aggregation(client_updates) |
|
|
elif self.aggregation_method == 'fedprox': |
|
|
return self._fedprox_aggregation(client_updates) |
|
|
else: |
|
|
return self._fedavg_aggregation(client_updates) |
|
|
|
|
|
def _fedavg_aggregation(self, client_updates: List[Tuple[Dict[str, Any], Any]]) -> Dict[str, Any]: |
|
|
"""Federated Averaging aggregation""" |
|
|
if not client_updates: |
|
|
return {} |
|
|
|
|
|
|
|
|
model_updates = [update for _, update in client_updates] |
|
|
|
|
|
|
|
|
total_samples = sum(stats['samples_processed'] for stats, _ in client_updates) |
|
|
|
|
|
|
|
|
aggregated_weights = {} |
|
|
aggregated_gradients = {} |
|
|
|
|
|
|
|
|
if model_updates: |
|
|
param_names = set() |
|
|
for update in model_updates: |
|
|
if 'weights' in update: |
|
|
param_names.update(update['weights'].keys()) |
|
|
|
|
|
|
|
|
for param_name in param_names: |
|
|
weights = [] |
|
|
weight_contributions = [] |
|
|
|
|
|
for (stats, update) in client_updates: |
|
|
if param_name in update.get('weights', {}): |
|
|
weight = update['weights'][param_name] |
|
|
sample_weight = stats['samples_processed'] / total_samples |
|
|
|
|
|
if TORCH_AVAILABLE and isinstance(weight, torch.Tensor): |
|
|
weights.append(weight * sample_weight) |
|
|
else: |
|
|
weights.append(np.array(weight) * sample_weight) |
|
|
weight_contributions.append(sample_weight) |
|
|
|
|
|
|
|
|
if weights: |
|
|
if TORCH_AVAILABLE and isinstance(weights[0], torch.Tensor): |
|
|
aggregated_weights[param_name] = torch.stack(weights).sum(dim=0) |
|
|
else: |
|
|
aggregated_weights[param_name] = np.stack(weights).sum(axis=0) |
|
|
|
|
|
|
|
|
aggregation_stats = { |
|
|
'round': self.round_number, |
|
|
'num_clients': len(client_updates), |
|
|
'total_samples': total_samples, |
|
|
'aggregation_method': self.aggregation_method, |
|
|
'timestamp': datetime.now().isoformat() |
|
|
} |
|
|
|
|
|
self.aggregation_history.append(aggregation_stats) |
|
|
|
|
|
return { |
|
|
'weights': aggregated_weights, |
|
|
'gradients': aggregated_gradients, |
|
|
'aggregation_stats': aggregation_stats |
|
|
} |
|
|
|
|
|
def _fedprox_aggregation(self, client_updates: List[Tuple[Dict[str, Any], Any]]) -> Dict[str, Any]: |
|
|
"""FedProx aggregation (simplified)""" |
|
|
|
|
|
|
|
|
return self._fedavg_aggregation(client_updates) |
|
|
|
|
|
def get_aggregation_stats(self) -> Dict[str, Any]: |
|
|
"""Get aggregation statistics""" |
|
|
if not self.aggregation_history: |
|
|
return {'total_rounds': 0} |
|
|
|
|
|
recent_stats = self.aggregation_history[-1] |
|
|
return { |
|
|
'total_rounds': len(self.aggregation_history), |
|
|
'last_round_clients': recent_stats['num_clients'], |
|
|
'last_round_samples': recent_stats['total_samples'], |
|
|
'aggregation_method': self.aggregation_method |
|
|
} |
|
|
|
|
|
class PrivacyController: |
|
|
""" |
|
|
Manages privacy-preserving techniques in federated learning |
|
|
""" |
|
|
|
|
|
def __init__(self, privacy_budget: float = 1.0, delta: float = 1e-5): |
|
|
self.privacy_budget = privacy_budget |
|
|
self.delta = delta |
|
|
self.spent_budget = 0.0 |
|
|
|
|
|
|
|
|
self.differential_privacy_enabled = True |
|
|
self.secure_aggregation_enabled = False |
|
|
|
|
|
def check_privacy_budget(self, epsilon: float) -> bool: |
|
|
"""Check if privacy budget allows the operation""" |
|
|
if self.spent_budget + epsilon > self.privacy_budget: |
|
|
return False |
|
|
return True |
|
|
|
|
|
def spend_privacy_budget(self, epsilon: float): |
|
|
"""Spend privacy budget""" |
|
|
self.spent_budget += epsilon |
|
|
|
|
|
def apply_secure_aggregation(self, client_updates: List[Any]) -> Any: |
|
|
"""Apply secure aggregation (simplified)""" |
|
|
if not self.secure_aggregation_enabled: |
|
|
return client_updates |
|
|
|
|
|
|
|
|
|
|
|
logger.info("Secure aggregation applied (simplified)") |
|
|
return client_updates |
|
|
|
|
|
def generate_privacy_report(self) -> Dict[str, Any]: |
|
|
"""Generate privacy report""" |
|
|
return { |
|
|
'total_budget': self.privacy_budget, |
|
|
'spent_budget': self.spent_budget, |
|
|
'remaining_budget': self.privacy_budget - self.spent_budget, |
|
|
'differential_privacy_enabled': self.differential_privacy_enabled, |
|
|
'secure_aggregation_enabled': self.secure_aggregation_enabled, |
|
|
'privacy_level': 'high' if self.differential_privacy_enabled else 'medium' |
|
|
} |
|
|
|
|
|
class FederatedLearningCoordinator: |
|
|
""" |
|
|
Coordinates federated learning across multiple clients |
|
|
""" |
|
|
|
|
|
def __init__(self, num_clients: int = 3, rounds: int = 5): |
|
|
self.num_clients = num_clients |
|
|
self.rounds = rounds |
|
|
self.clients = {} |
|
|
self.aggregator = FederatedAggregator() |
|
|
self.privacy_controller = PrivacyController() |
|
|
|
|
|
|
|
|
self.client_updates_queue = asyncio.Queue() |
|
|
self.global_model_available = asyncio.Event() |
|
|
|
|
|
|
|
|
self.training_stats = [] |
|
|
self.round_times = [] |
|
|
|
|
|
async def initialize_clients(self, client_data: Dict[str, List[Tuple[str, str]]]): |
|
|
"""Initialize federated clients with their data""" |
|
|
for client_id, data in client_data.items(): |
|
|
model_config = { |
|
|
'model_type': 'sofia_embedding', |
|
|
'input_dim': 768, |
|
|
'hidden_dim': 512, |
|
|
'output_dim': 256 |
|
|
} |
|
|
|
|
|
client = LocalModel(client_id, model_config) |
|
|
client.initialize_model() |
|
|
|
|
|
|
|
|
dataset = FederatedDataset(data) |
|
|
dataloader = DataLoader(dataset, batch_size=client.batch_size, shuffle=True) |
|
|
|
|
|
self.clients[client_id] = { |
|
|
'model': client, |
|
|
'dataloader': dataloader, |
|
|
'data_size': len(data) |
|
|
} |
|
|
|
|
|
logger.info(f"Initialized {len(self.clients)} federated clients") |
|
|
|
|
|
async def run_federated_training(self) -> Dict[str, Any]: |
|
|
"""Run federated training for specified rounds""" |
|
|
logger.info(f"Starting federated training for {self.rounds} rounds") |
|
|
|
|
|
for round_num in range(1, self.rounds + 1): |
|
|
round_start = datetime.now() |
|
|
|
|
|
logger.info(f"Round {round_num}/{self.rounds} starting") |
|
|
|
|
|
|
|
|
client_updates = await self._train_clients_in_round(round_num) |
|
|
|
|
|
|
|
|
global_update = self.aggregator.aggregate_updates(client_updates) |
|
|
|
|
|
|
|
|
await self._update_clients_with_global_model(global_update) |
|
|
|
|
|
|
|
|
round_time = (datetime.now() - round_start).total_seconds() |
|
|
self.round_times.append(round_time) |
|
|
|
|
|
round_stats = { |
|
|
'round': round_num, |
|
|
'num_clients': len(client_updates), |
|
|
'round_time': round_time, |
|
|
'aggregation_stats': global_update.get('aggregation_stats', {}) |
|
|
} |
|
|
self.training_stats.append(round_stats) |
|
|
|
|
|
logger.info(f"Round {round_num} completed in {round_time:.2f}s") |
|
|
|
|
|
|
|
|
final_report = self._generate_final_report() |
|
|
return final_report |
|
|
|
|
|
async def _train_clients_in_round(self, round_num: int) -> List[Tuple[Dict[str, Any], Any]]: |
|
|
"""Train all clients in parallel for one round""" |
|
|
tasks = [] |
|
|
|
|
|
for client_id, client_info in self.clients.items(): |
|
|
task = asyncio.create_task( |
|
|
self._train_single_client(client_id, client_info, round_num) |
|
|
) |
|
|
tasks.append(task) |
|
|
|
|
|
|
|
|
results = await asyncio.gather(*tasks) |
|
|
return [r for r in results if r is not None] |
|
|
|
|
|
async def _train_single_client(self, client_id: str, client_info: Dict[str, Any], round_num: int): |
|
|
"""Train a single client""" |
|
|
try: |
|
|
client = client_info['model'] |
|
|
dataloader = client_info['dataloader'] |
|
|
|
|
|
|
|
|
training_stats, model_update = client.train_local(dataloader, epochs=client.local_epochs) |
|
|
|
|
|
logger.info(f"Client {client_id}: Training completed - Loss: {training_stats['avg_loss']:.4f}") |
|
|
|
|
|
return training_stats, model_update |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Client {client_id} training failed: {e}") |
|
|
return None |
|
|
|
|
|
async def _update_clients_with_global_model(self, global_update: Dict[str, Any]): |
|
|
"""Update all clients with the new global model""" |
|
|
tasks = [] |
|
|
|
|
|
for client_id, client_info in self.clients.items(): |
|
|
task = asyncio.create_task( |
|
|
self._update_single_client(client_id, client_info['model'], global_update) |
|
|
) |
|
|
tasks.append(task) |
|
|
|
|
|
await asyncio.gather(*tasks) |
|
|
|
|
|
async def _update_single_client(self, client_id: str, client_model: LocalModel, global_update: Dict[str, Any]): |
|
|
"""Update a single client with global model""" |
|
|
try: |
|
|
client_model.update_model(global_update) |
|
|
logger.info(f"Client {client_id}: Model updated with global parameters") |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to update client {client_id}: {e}") |
|
|
|
|
|
def _generate_final_report(self) -> Dict[str, Any]: |
|
|
"""Generate final training report""" |
|
|
total_time = sum(self.round_times) |
|
|
avg_round_time = total_time / len(self.round_times) if self.round_times else 0 |
|
|
|
|
|
return { |
|
|
'federated_training_completed': True, |
|
|
'total_rounds': len(self.training_stats), |
|
|
'total_training_time': total_time, |
|
|
'average_round_time': avg_round_time, |
|
|
'clients_participated': len(self.clients), |
|
|
'privacy_report': self.privacy_controller.generate_privacy_report(), |
|
|
'aggregation_stats': self.aggregator.get_aggregation_stats(), |
|
|
'round_stats': self.training_stats, |
|
|
'final_model_available': True |
|
|
} |
|
|
|
|
|
|
|
|
async def demo_federated_learning(): |
|
|
"""Demonstrate federated learning with mock data""" |
|
|
print("SOFIA Federated Learning Demo") |
|
|
print("=" * 40) |
|
|
|
|
|
|
|
|
client_data = { |
|
|
'client_1': [ |
|
|
("Hello world", "Hi there"), |
|
|
("How are you", "I'm fine"), |
|
|
("Machine learning", "AI models") |
|
|
] * 10, |
|
|
'client_2': [ |
|
|
("Python programming", "Code development"), |
|
|
("Data science", "Analytics"), |
|
|
("Neural networks", "Deep learning") |
|
|
] * 10, |
|
|
'client_3': [ |
|
|
("Natural language", "Text processing"), |
|
|
("Computer vision", "Image recognition"), |
|
|
("Reinforcement learning", "RL algorithms") |
|
|
] * 10 |
|
|
} |
|
|
|
|
|
|
|
|
coordinator = FederatedLearningCoordinator(num_clients=3, rounds=3) |
|
|
|
|
|
|
|
|
await coordinator.initialize_clients(client_data) |
|
|
|
|
|
|
|
|
print("Starting federated training...") |
|
|
final_report = await coordinator.run_federated_training() |
|
|
|
|
|
|
|
|
print("\nFederated Training Results:") |
|
|
print(f"Total rounds: {final_report['total_rounds']}") |
|
|
print(".2f") |
|
|
print(".2f") |
|
|
print(f"Clients participated: {final_report['clients_participated']}") |
|
|
print(f"Privacy level: {final_report['privacy_report']['privacy_level']}") |
|
|
|
|
|
return final_report |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
asyncio.run(demo_federated_learning()) |
|
|
|