Samabe1109's picture
download
raw
9.6 kB
"""Training script for quantum syndrome decoder."""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import json
from pathlib import Path
import argparse
from decoder_model import QuantumSyndromeDecoder, SynergyExtractor, RLPolicyNetwork
class SyndromeDataset(Dataset):
    """Dataset pairing syndrome measurements with their error labels.

    Both arrays are converted to float32 tensors once, at construction time.

    Args:
        syndromes: NumPy array whose first axis indexes samples. Each sample
            must be 3-D because ``__getitem__`` permutes exactly three axes
            (presumably ``(H, W, C)`` -- confirm against the data generator).
        errors: NumPy array of labels, indexed and shaped the same way.

    Raises:
        ValueError: If the two arrays hold a different number of samples.
    """

    def __init__(self, syndromes, errors):
        # __len__ is driven by `syndromes` alone, so a mismatch would either
        # surface as a late IndexError or silently drop trailing labels.
        if len(syndromes) != len(errors):
            raise ValueError(
                f"syndromes and errors must have the same number of samples, "
                f"got {len(syndromes)} and {len(errors)}"
            )
        self.syndromes = torch.from_numpy(syndromes).float()
        self.errors = torch.from_numpy(errors).float()

    def __len__(self):
        """Return the number of samples."""
        return len(self.syndromes)

    def __getitem__(self, idx):
        """Return the ``(syndrome, error)`` pair at ``idx``.

        Each tensor has its last axis moved to the front and a trailing
        singleton axis appended (time dimension, T=1 for a single round).
        """
        syndrome = self.syndromes[idx].permute(2, 0, 1).unsqueeze(-1)
        error = self.errors[idx].permute(2, 0, 1).unsqueeze(-1)
        return syndrome, error
def _run_supervised_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    The pass updates the weights when an ``optimizer`` is supplied and is a
    gradient-free evaluation pass otherwise.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.set_grad_enabled(training):
        for syndromes, errors in loader:
            syndromes = syndromes.to(device)
            # Labels are duplicated along dim 1 to line up with the model
            # output (presumably twice as many output channels as label
            # channels -- confirm against the decoder definition).
            targets = errors.to(device).repeat(1, 2, 1, 1, 1)

            if training:
                optimizer.zero_grad()
            predictions = model(syndromes)
            loss = criterion(predictions, targets)
            if training:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            # Element-wise binary accuracy at a 0.5 decision threshold.
            pred_binary = (predictions > 0.5).float()
            correct += (pred_binary == targets).sum().item()
            total += targets.numel()

    accuracy = correct / total if total else 0.0
    return loss_sum / len(loader), accuracy


def train_supervised(model, train_loader, val_loader, epochs=50, lr=5e-4,
                     device='cuda', save_path='decoder.pt'):
    """Train decoder with supervised learning.

    Optimises binary cross-entropy with AdamW (weight decay 3e-2) under a
    cosine-annealed learning rate, and writes a checkpoint to ``save_path``
    every time the validation loss improves.

    Args:
        model: Module mapping a syndrome batch to predictions. ``nn.BCELoss``
            requires values in [0, 1], so the model is assumed to end in a
            sigmoid -- confirm.
        train_loader: Sized iterable of ``(syndromes, errors)`` batches.
        val_loader: Sized iterable of ``(syndromes, errors)`` batches.
        epochs: Number of epochs; also the cosine schedule period.
        lr: Initial learning rate.
        device: Device the model and batches are moved to.
        save_path: Destination of the best-validation checkpoint.

    Returns:
        The model with its weights as of the final epoch, which is not
        necessarily the best checkpoint; reload ``save_path`` for that.

    Raises:
        ValueError: If either loader yields no batches. Previously this
            surfaced as a ZeroDivisionError only after a full epoch had run.
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader yields no batches")
    if len(val_loader) == 0:
        raise ValueError("val_loader yields no batches")

    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=3e-2)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = nn.BCELoss()
    best_val_loss = float('inf')

    for epoch in range(epochs):
        train_loss, train_acc = _run_supervised_epoch(
            model, train_loader, criterion, device, optimizer
        )
        val_loss, val_acc = _run_supervised_epoch(
            model, val_loader, criterion, device
        )

        print(f"Epoch {epoch+1}/{epochs}: "
              f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
              f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
            }, save_path)
            print(f"  -> Saved best model (val_loss={val_loss:.4f})")

        scheduler.step()

    return model
def train_rl_policy(policy_net, env, episodes=1000, lr=3e-4, gamma=0.99,
                    device='cuda', save_path='policy.pt'):
    """Train policy network with REINFORCE and a learned value baseline.

    Each episode is rolled out to completion, discounted returns are
    computed, and a single Adam step is taken on
    ``policy_loss + 0.5 * value_loss``. The final weights are written to
    ``save_path`` once, after the last episode.

    Args:
        policy_net: Module returning ``(policy_logits, value)`` for a state
            batch of size 1. The logits are reshaped to ``(-1, 3)``, i.e. one
            or more independent 3-way choices per step.
        env: Environment exposing ``reset() -> state`` and
            ``step(action) -> (next_state, reward, done, info)`` where states
            are NumPy arrays and ``action`` is a NumPy array with one index
            per 3-way choice.
        episodes: Number of episodes to run.
        lr: Adam learning rate.
        gamma: Discount factor.
        device: Device used for the network and tensors.
        save_path: Destination of the final ``state_dict``.

    Returns:
        The trained ``policy_net``.
    """
    policy_net = policy_net.to(device)
    optimizer = optim.Adam(policy_net.parameters(), lr=lr)

    for episode in range(episodes):
        # Generate episode
        state = env.reset()
        log_probs = []
        values = []
        rewards = []
        done = False

        while not done:
            state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(device)
            policy_logits, value = policy_net(state_tensor)

            # Sample one action per 3-way choice. Building the distribution
            # from logits avoids a separate softmax and is numerically safer.
            dist = torch.distributions.Categorical(
                logits=policy_logits.view(-1, 3)
            )
            action = dist.sample()

            # Execute action
            next_state, reward, done, _ = env.step(action.cpu().numpy())

            # The choices are sampled independently, so the step's joint
            # log-probability is the sum over components. Keeping them
            # separate would yield T*N entries that cannot be paired with
            # the T per-step advantages.
            log_probs.append(dist.log_prob(action).sum())
            values.append(value.reshape(-1))
            rewards.append(reward)
            state = next_state

        # Compute discounted returns back-to-front, then restore time order.
        returns = []
        running = 0.0
        for r in reversed(rewards):
            running = r + gamma * running
            returns.append(float(running))
        returns.reverse()
        returns = torch.tensor(returns, dtype=torch.float32, device=device)

        values = torch.cat(values)
        log_probs = torch.stack(log_probs)

        # Policy gradient loss; the advantage is detached so the policy term
        # does not push gradients into the value head.
        advantage = returns - values
        policy_loss = -(log_probs * advantage.detach()).mean()
        value_loss = advantage.pow(2).mean()
        loss = policy_loss + 0.5 * value_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if episode % 100 == 0:
            print(f"Episode {episode}: reward={sum(rewards):.2f}, "
                  f"policy_loss={policy_loss.item():.4f}")

    torch.save(policy_net.state_dict(), save_path)
    return policy_net
def evaluate_decoder(model, test_loader, device='cuda', threshold=0.5):
    """Evaluate decoder performance.

    Predictions are binarised at ``threshold`` and compared element-wise
    against the labels, which are duplicated along dim 1 to line up with
    the model output. The summary is printed and returned.

    Args:
        model: Module mapping a syndrome batch to per-element scores.
        test_loader: Iterable of ``(syndromes, errors)`` batches.
        device: Device the model and batches are moved to.
        threshold: Score strictly above which an element counts as a
            predicted error. Defaults to 0.5.

    Returns:
        Dict with ``precision``, ``recall``, ``total_errors``, ``corrected``
        (true positives) and ``false_positives``. Precision and recall fall
        back to 0 when their denominator is zero.
    """
    model.eval()
    model = model.to(device)

    total_errors = 0
    corrected_errors = 0
    false_positives = 0

    with torch.no_grad():
        for syndromes, errors in test_loader:
            syndromes = syndromes.to(device)
            errors = errors.to(device)

            predictions = model(syndromes)
            pred_errors = (predictions > threshold).float()
            true_errors = errors.repeat(1, 2, 1, 1, 1)

            total_errors += true_errors.sum().item()
            corrected_errors += (pred_errors * true_errors).sum().item()
            false_positives += (pred_errors * (1 - true_errors)).sum().item()

    flagged = corrected_errors + false_positives
    precision = corrected_errors / flagged if flagged > 0 else 0
    recall = corrected_errors / total_errors if total_errors > 0 else 0

    print("Evaluation Results:")
    print(f" Total errors: {total_errors}")
    print(f" Corrected: {corrected_errors}")
    print(f" False positives: {false_positives}")
    print(f" Precision: {precision:.4f}")
    print(f" Recall: {recall:.4f}")

    return {
        'precision': precision,
        'recall': recall,
        'total_errors': total_errors,
        'corrected': corrected_errors,
        'false_positives': false_positives,
    }
def main():
    """Command-line entry point: train and evaluate the syndrome decoder.

    In ``supervised`` mode this loads ``syndromes.npy`` and ``errors.npy``
    from ``--data_dir``, splits them 80/10/10 in file order, trains the
    decoder, reloads the best checkpoint and writes test metrics to
    ``results.json`` in ``--output_dir``. ``rl`` mode only prints a notice.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['supervised', 'rl'], default='supervised')
    parser.add_argument('--distance', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--channels', type=int, default=128)
    parser.add_argument('--layers', type=int, default=4)
    parser.add_argument('--data_dir', default='data')
    parser.add_argument('--output_dir', default='outputs')
    args = parser.parse_args()

    # parents=True so a nested --output_dir such as "runs/exp1" works.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    if args.mode != 'supervised':
        # Bail out before touching the dataset: RL mode has nothing to run.
        print("RL mode requires environment implementation")
        return

    # Load data
    data_dir = Path(args.data_dir)
    syndromes = np.load(data_dir / "syndromes.npy")
    errors = np.load(data_dir / "errors.npy")

    # Split 80/10/10; the test split takes whatever remains after rounding.
    n = len(syndromes)
    n_train = int(0.8 * n)
    n_val = int(0.1 * n)
    val_end = n_train + n_val

    train_dataset = SyndromeDataset(syndromes[:n_train], errors[:n_train])
    val_dataset = SyndromeDataset(syndromes[n_train:val_end], errors[n_train:val_end])
    test_dataset = SyndromeDataset(syndromes[val_end:], errors[val_end:])

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size)

    model = QuantumSyndromeDecoder(
        distance=args.distance,
        channels=args.channels,
        num_layers=args.layers,
    )
    print(f"Training supervised decoder: {sum(p.numel() for p in model.parameters())} parameters")

    checkpoint_path = str(output_dir / "decoder_best.pt")
    model = train_supervised(
        model, train_loader, val_loader,
        epochs=args.epochs, lr=args.lr, device=device,
        save_path=checkpoint_path,
    )

    # Final evaluation on the best checkpoint rather than the last epoch.
    # map_location keeps the load valid if the checkpoint was written on a
    # different device than the one available now.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    results = evaluate_decoder(model, test_loader, device)

    with open(output_dir / "results.json", "w") as f:
        json.dump(results, f, indent=2)


if __name__ == "__main__":
    main()

Xet Storage Details

Size:
9.6 kB
·
Xet hash:
48c9fd7f7606f8d3f35c009e19f245221b8c4015ed22f0f7d0c4c517eab1d95a

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.