| """Training script for quantum syndrome decoder.""" | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import Dataset, DataLoader | |
| import numpy as np | |
| import json | |
| from pathlib import Path | |
| import argparse | |
| from decoder_model import QuantumSyndromeDecoder, SynergyExtractor, RLPolicyNetwork | |
class SyndromeDataset(Dataset):
    """Map-style dataset yielding ``(syndrome, error)`` tensor pairs.

    Both arrays are converted to float32 tensors once, at construction.
    Samples are indexed along axis 0. For each sample, the last axis is moved
    to the front (``permute(2, 0, 1)``) and a trailing length-1 axis is
    appended for a single time step (T=1). Presumably samples are stored as
    ``(H, W, C)`` grids, giving ``(C, H, W, 1)``; confirm against the data
    generator.
    """

    def __init__(self, syndromes, errors):
        # Convert syndromes first, then errors (same order as before).
        self.syndromes, self.errors = (
            torch.from_numpy(array).float() for array in (syndromes, errors)
        )

    def __len__(self):
        return len(self.syndromes)

    @staticmethod
    def _with_time_axis(sample):
        """Move the last axis to the front and append a length-1 time axis."""
        return sample.permute(2, 0, 1)[..., None]

    def __getitem__(self, idx):
        syndrome_sample = self._with_time_axis(self.syndromes[idx])
        error_sample = self._with_time_axis(self.errors[idx])
        return syndrome_sample, error_sample
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass of *model* over *loader*, updating weights only if *optimizer* is given.

    Targets are the loader's error tensors duplicated along dim 1 with
    ``repeat(1, 2, 1, 1, 1)`` so they match the model's output channels.
    Accuracy is the fraction of target elements where ``prediction > 0.5``
    equals the target (assumes the model emits probabilities in [0, 1], as
    ``nn.BCELoss`` requires).

    Returns:
        ``(mean_loss, accuracy, n_samples)``. ``mean_loss`` weights each
        batch's loss by its batch size, so a short final batch is not
        over-weighted. Both means are ``nan`` when the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum = 0.0
    n_samples = 0
    n_correct = 0
    n_elements = 0
    with torch.set_grad_enabled(training):
        for syndromes, errors in loader:
            syndromes = syndromes.to(device)
            # Duplicate labels along dim 1 to match the doubled output channels.
            targets = errors.to(device).repeat(1, 2, 1, 1, 1)
            if training:
                optimizer.zero_grad()
            predictions = model(syndromes)
            loss = criterion(predictions, targets)
            if training:
                loss.backward()
                optimizer.step()
            batch_size = syndromes.size(0)
            loss_sum += loss.item() * batch_size
            n_samples += batch_size
            n_correct += ((predictions > 0.5).float() == targets).sum().item()
            n_elements += targets.numel()
    if n_samples == 0:
        return float('nan'), float('nan'), 0
    return loss_sum / n_samples, n_correct / n_elements, n_samples


def train_supervised(model, train_loader, val_loader, epochs=50, lr=5e-4,
                     device='cuda', save_path='decoder.pt'):
    """Train decoder with supervised learning.

    Uses AdamW (weight_decay=3e-2), cosine annealing over *epochs* scheduler
    steps, and ``nn.BCELoss`` against the error labels duplicated along dim 1.
    After each epoch the per-sample mean train/val loss and accuracy are
    printed. Whenever the selection loss improves, a checkpoint dict with keys
    ``epoch``, ``model_state_dict``, ``optimizer_state_dict`` and ``val_loss``
    is written to *save_path*.

    The selection loss is the validation loss. When *val_loader* yields no
    samples, the training loss is used instead so a checkpoint is still
    written.

    Returns:
        The model as it is after the final epoch (not necessarily the best
        checkpoint; reload *save_path* for that).
    """
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=3e-2)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = nn.BCELoss()
    best_val_loss = float('inf')

    for epoch in range(epochs):
        train_loss, train_acc, _ = _run_epoch(
            model, train_loader, criterion, device, optimizer)
        val_loss, val_acc, n_val = _run_epoch(
            model, val_loader, criterion, device)

        print(f"Epoch {epoch+1}/{epochs}: "
              f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
              f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")

        # Without validation samples val_loss is nan; fall back to train loss.
        selection_loss = val_loss if n_val > 0 else train_loss
        if selection_loss < best_val_loss:
            best_val_loss = selection_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
            }, save_path)
            print(f" -> Saved best model (val_loss={val_loss:.4f})")

        scheduler.step()

    return model
def train_rl_policy(policy_net, env, episodes=1000, lr=3e-4, gamma=0.99,
                    device='cuda', save_path='policy.pt'):
    """Train policy network with REINFORCE plus a learned value baseline.

    Each episode is rolled out until ``done`` using ``env.reset()`` and
    ``env.step(action)``; ``step`` must return the 4-tuple
    ``(next_state, reward, done, info)``. Discounted returns are computed with
    *gamma*. One Adam step is taken per episode on
    ``policy_loss + 0.5 * value_loss``, where the advantage is
    ``returns - values``. The final ``state_dict`` is saved to *save_path*.

    ``policy_net(state)`` must return ``(policy_logits, value)``. The logits
    are viewed as ``(-1, 3)``, and each row is sampled as an independent 3-way
    categorical choice. Assumes ``value`` holds a single scalar per step;
    TODO confirm against RLPolicyNetwork.

    Returns:
        The trained *policy_net* (moved to *device*).
    """
    policy_net = policy_net.to(device)
    optimizer = optim.Adam(policy_net.parameters(), lr=lr)

    for episode in range(episodes):
        # Generate episode
        state = env.reset()
        log_probs = []
        values = []
        rewards = []
        done = False
        while not done:
            state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(device)
            policy_logits, value = policy_net(state_tensor)

            # Sample one action per logit row
            probs = torch.softmax(policy_logits.view(-1, 3), dim=-1)
            dist = torch.distributions.Categorical(probs)
            action = dist.sample()
            # Joint log-probability of the whole factorized action: one scalar
            # per step, so it lines up with the per-step advantage below.
            log_prob = dist.log_prob(action).sum()

            # Execute action
            next_state, reward, done, _ = env.step(action.cpu().numpy())
            log_probs.append(log_prob)
            values.append(value)
            rewards.append(reward)
            state = next_state

        # Discounted returns, built back-to-front then reversed (O(T)).
        returns = []
        G = 0.0
        for r in reversed(rewards):
            G = r + gamma * G
            returns.append(G)
        returns.reverse()
        returns = torch.tensor(returns, dtype=torch.float32, device=device)
        # Flatten each step's value to 1-D so the result is (steps,) for any
        # value shape with one element per step (including T=1).
        values = torch.cat([v.reshape(-1) for v in values])
        log_probs = torch.stack(log_probs)

        # Policy gradient loss with value baseline
        advantage = returns - values
        policy_loss = -(log_probs * advantage.detach()).mean()
        value_loss = advantage.pow(2).mean()
        loss = policy_loss + 0.5 * value_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if episode % 100 == 0:
            print(f"Episode {episode}: reward={sum(rewards):.2f}, "
                  f"policy_loss={policy_loss.item():.4f}")

    torch.save(policy_net.state_dict(), save_path)
    return policy_net
def evaluate_decoder(model, test_loader, device='cuda'):
    """Report precision and recall of thresholded decoder outputs on *test_loader*.

    Model outputs are binarised at 0.5 and compared with the error labels
    duplicated along dim 1 (``repeat(1, 2, 1, 1, 1)``, as in training). Every
    labelled error is therefore counted twice. A summary is printed.

    Returns:
        dict with ``precision``, ``recall``, ``total_errors`` (labelled
        positives), ``corrected`` (true positives) and ``false_positives``.
        precision and recall are 0 when their denominator is 0.
    """
    model.eval()
    model = model.to(device)
    tally = {'positives': 0, 'hits': 0, 'spurious': 0}

    with torch.no_grad():
        for batch_syndromes, batch_errors in test_loader:
            batch_syndromes = batch_syndromes.to(device)
            batch_errors = batch_errors.to(device)
            flagged = (model(batch_syndromes) > 0.5).float()
            target = batch_errors.repeat(1, 2, 1, 1, 1)
            tally['positives'] += target.sum().item()
            tally['hits'] += (flagged * target).sum().item()
            tally['spurious'] += (flagged * (1 - target)).sum().item()

    total_errors = tally['positives']
    corrected_errors = tally['hits']
    false_positives = tally['spurious']
    n_flagged = corrected_errors + false_positives
    precision = corrected_errors / n_flagged if n_flagged > 0 else 0
    recall = corrected_errors / total_errors if total_errors > 0 else 0

    for line in (
        "Evaluation Results:",
        f" Total errors: {total_errors}",
        f" Corrected: {corrected_errors}",
        f" False positives: {false_positives}",
        f" Precision: {precision:.4f}",
        f" Recall: {recall:.4f}",
    ):
        print(line)

    return dict(
        precision=precision,
        recall=recall,
        total_errors=total_errors,
        corrected=corrected_errors,
        false_positives=false_positives,
    )
def main(argv=None):
    """Command-line entry point: train a decoder and evaluate its best checkpoint.

    Supervised mode does the following:
    - Loads ``syndromes.npy`` and ``errors.npy`` from ``--data_dir``.
    - Splits them 80/10/10 (train/val/test) in file order.
    - Trains with ``train_supervised``.
    - Reloads the best checkpoint from ``<output_dir>/decoder_best.pt``.
    - Writes the ``evaluate_decoder`` metrics to ``<output_dir>/results.json``.

    RL mode is not wired up yet and only prints a notice; it does not need
    the data files.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        ValueError: if the syndrome and error arrays differ in sample count.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['supervised', 'rl'], default='supervised')
    parser.add_argument('--distance', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--channels', type=int, default=128)
    parser.add_argument('--layers', type=int, default=4)
    parser.add_argument('--data_dir', default='data')
    parser.add_argument('--output_dir', default='outputs')
    args = parser.parse_args(argv)

    output_dir = Path(args.output_dir)
    # parents=True so a nested --output_dir (e.g. runs/exp1) can be created.
    output_dir.mkdir(parents=True, exist_ok=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    if args.mode != 'supervised':
        # Checked before loading data so RL mode does not require the .npy files.
        print("RL mode requires environment implementation")
        return

    # Load data
    data_dir = Path(args.data_dir)
    syndromes = np.load(data_dir / "syndromes.npy")
    errors = np.load(data_dir / "errors.npy")
    if len(syndromes) != len(errors):
        raise ValueError(
            f"syndromes ({len(syndromes)}) and errors ({len(errors)}) "
            "must have the same number of samples"
        )

    # Sequential 80/10/10 split (data is not shuffled before splitting).
    n = len(syndromes)
    n_train = int(0.8 * n)
    n_val = int(0.1 * n)
    val_end = n_train + n_val
    train_dataset = SyndromeDataset(syndromes[:n_train], errors[:n_train])
    val_dataset = SyndromeDataset(syndromes[n_train:val_end], errors[n_train:val_end])
    test_dataset = SyndromeDataset(syndromes[val_end:], errors[val_end:])

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size)

    model = QuantumSyndromeDecoder(
        distance=args.distance,
        channels=args.channels,
        num_layers=args.layers,
    )
    print(f"Training supervised decoder: {sum(p.numel() for p in model.parameters())} parameters")

    checkpoint_path = output_dir / "decoder_best.pt"
    model = train_supervised(
        model, train_loader, val_loader,
        epochs=args.epochs, lr=args.lr, device=device,
        save_path=str(checkpoint_path),
    )

    # Final evaluation on the best saved checkpoint rather than the last epoch;
    # map_location places the loaded tensors on the current device.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    results = evaluate_decoder(model, test_loader, device)
    with open(output_dir / "results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)


if __name__ == "__main__":
    main()
Xet Storage Details
- Size:
- 9.6 kB
- Xet hash:
- 48c9fd7f7606f8d3f35c009e19f245221b8c4015ed22f0f7d0c4c517eab1d95a
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.