| """Neural decoder for quantum syndrome extraction.""" | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class ResidualBlock3D(nn.Module):
    """Two-convolution 3D residual block used to process syndrome volumes.

    Computes ``GELU(GN(conv2(GELU(GN(conv1(x))))) + x)``. Both convolutions
    keep the channel count and use padding ``kernel_size // 2``, so for an
    odd ``kernel_size`` the output has exactly the input's shape (required
    for the skip connection).

    Args:
        channels: Number of input/output channels; must be divisible by 8
            because both GroupNorm layers use 8 groups.
        kernel_size: Cubic convolution kernel size (default 3).
    """

    def __init__(self, channels, kernel_size=3):
        super().__init__()
        same_pad = kernel_size // 2
        # Registration order (conv1, norm1, conv2, norm2) fixes both the
        # state_dict key order and the RNG draw order used at initialization.
        self.conv1 = nn.Conv3d(channels, channels, kernel_size, padding=same_pad)
        self.norm1 = nn.GroupNorm(8, channels)
        self.conv2 = nn.Conv3d(channels, channels, kernel_size, padding=same_pad)
        self.norm2 = nn.GroupNorm(8, channels)
        self.activation = nn.GELU()

    def forward(self, x):
        # Main branch: conv -> norm -> GELU -> conv -> norm.
        h = self.conv1(x)
        h = self.norm1(h)
        h = self.activation(h)
        h = self.norm2(self.conv2(h))
        # Skip connection, followed by the final non-linearity.
        return self.activation(h + x)
class QuantumSyndromeDecoder(nn.Module):
    """
    3D CNN decoder for surface code syndrome extraction.

    Input: (batch, 2, D, D, T) syndrome tensor
        - Channel 0: X syndrome events
        - Channel 1: Z syndrome events
    Output: (batch, 4, D, D, T) error probabilities in [0, 1]
        - Channel 0-1: X error spatial/temporal predictions
        - Channel 2-3: Z error spatial/temporal predictions

    The network is fully convolutional with size-preserving padding, so the
    output's spatial/temporal sizes match the input's; ``distance`` and
    ``time_steps`` are stored as metadata only.

    Args:
        distance: Code distance D (stored, not used to shape layers).
        time_steps: Number of syndrome rounds T (stored, not used to shape
            layers).
        channels: Hidden width; must be divisible by 8 because every
            GroupNorm uses 8 groups.
        num_layers: Number of processing blocks.
        use_residual: If True, use ``ResidualBlock3D`` blocks; otherwise a
            plain Conv3d -> GroupNorm -> GELU stack without skip connections.
    """

    def __init__(self, distance=5, time_steps=3, channels=128, num_layers=4,
                 use_residual=True):
        super().__init__()
        self.distance = distance
        self.time_steps = time_steps
        # Initial projection: 2 syndrome channels -> `channels` features.
        self.input_conv = nn.Conv3d(2, channels, kernel_size=3, padding=1)
        self.input_norm = nn.GroupNorm(8, channels)
        # Main processing blocks. Construction order is kept stable so that
        # state_dict keys (blocks.0, blocks.1, ...) and seeded initialization
        # remain reproducible.
        self.blocks = nn.ModuleList()
        for _ in range(num_layers):
            if use_residual:
                self.blocks.append(ResidualBlock3D(channels))
            else:
                self.blocks.append(nn.Sequential(
                    nn.Conv3d(channels, channels, 3, padding=1),
                    nn.GroupNorm(8, channels),
                    nn.GELU(),
                ))
        # Output layers: 1x1x1 projection to the 4 prediction channels.
        self.output_conv = nn.Conv3d(channels, 4, kernel_size=1)
        self.activation = nn.GELU()

    def forward(self, x):
        """
        Args:
            x: (batch, 2, D, D, T) syndrome tensor
        Returns:
            (batch, 4, D, D, T) error probabilities. A sigmoid is already
            applied, so these are NOT logits; do not pass them to
            ``BCEWithLogitsLoss`` or apply another sigmoid.
        """
        x = self.activation(self.input_norm(self.input_conv(x)))
        for block in self.blocks:
            x = block(x)
        x = self.output_conv(x)
        return torch.sigmoid(x)

    def predict_errors(self, syndrome, threshold=0.5):
        """Predict binary error locations from a syndrome tensor.

        Args:
            syndrome: (batch, 2, D, D, T) syndrome tensor.
            threshold: Probability strictly above which a site is flagged
                as an error (default 0.5).

        Returns:
            Float tensor shaped like ``forward``'s output, with 1.0 where the
            predicted probability exceeds ``threshold`` and 0.0 elsewhere.
            Computed without gradient tracking.
        """
        with torch.no_grad():
            probs = self.forward(syndrome)
        return (probs > threshold).float()
class RLPolicyNetwork(nn.Module):
    """Actor-critic policy network for RL-based syndrome decoding.

    A ``QuantumSyndromeDecoder`` encodes the syndrome into 4 per-site
    channels, which feed two heads:

    * policy head: per-site logits over 3 actions (no-op, X, Z),
      shape (batch, 3, D, D, T);
    * value head: a global state-value estimate, shape (batch, 1).

    Args:
        distance: Code distance D, forwarded to the encoder.
        time_steps: Number of syndrome rounds T, forwarded to the encoder.
        channels: Hidden width of the encoder and both heads; must be
            divisible by 8 (GroupNorm with 8 groups).
    """

    def __init__(self, distance=5, time_steps=3, channels=128):
        super().__init__()
        self.distance = distance
        self.time_steps = time_steps
        # Shared encoder. Its forward applies a sigmoid, so the heads below
        # consume probabilities in [0, 1] rather than raw logits.
        self.encoder = QuantumSyndromeDecoder(
            distance, time_steps, channels, num_layers=4
        )
        # Policy head: action logits for each site (size-preserving convs).
        self.policy_head = nn.Sequential(
            nn.Conv3d(4, channels, 3, padding=1),
            nn.GroupNorm(8, channels),
            nn.GELU(),
            nn.Conv3d(channels, 3, 1),  # 3 actions: no-op, X, Z
        )
        # Value head: global average pool, then a linear map to one scalar.
        self.value_head = nn.Sequential(
            nn.Conv3d(4, channels, 3, padding=1),
            nn.GroupNorm(8, channels),
            nn.GELU(),
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Linear(channels, 1),
        )

    def forward(self, syndrome):
        """Return ``(policy_logits, value)`` for a syndrome tensor.

        Args:
            syndrome: (batch, 2, D, D, T) syndrome tensor.

        Returns:
            policy_logits: (batch, 3, D, D, T) unnormalized action scores.
            value: (batch, 1) state-value estimate.
        """
        features = self.encoder(syndrome)
        policy_logits = self.policy_head(features)
        value = self.value_head(features)
        return policy_logits, value

    def get_action(self, syndrome, deterministic=False):
        """Select one action per site from the policy.

        Args:
            syndrome: (batch, 2, D, D, T) syndrome tensor.
            deterministic: If True, take the arg-max action at each site;
                otherwise sample from the per-site softmax distribution.

        Returns:
            (action, value): ``action`` is an int64 tensor of shape
            (batch, D, D, T) with entries in {0, 1, 2}; ``value`` is the
            (batch, 1) value estimate. Gradient tracking is not disabled.
        """
        policy_logits, value = self.forward(syndrome)
        if deterministic:
            action = policy_logits.argmax(dim=1)
        else:
            probs = F.softmax(policy_logits, dim=1)
            num_actions = probs.shape[1]
            # The action axis is dim 1, not the last one. Move it to the end
            # before flattening so that each row is exactly one site's
            # distribution (a bare ``view(-1, 3)`` would group along T).
            site_probs = probs.movedim(1, -1).reshape(-1, num_actions)
            samples = torch.multinomial(site_probs, 1)
            # Reshape with the real input sizes instead of the configured
            # distance/time_steps, which need not match the input.
            action = samples.view(probs.shape[0], *probs.shape[2:])
        return action, value
class SynergyExtractor(nn.Module):
    """
    Noise-aware syndrome extractor that learns to correct
    measurement errors in syndrome extraction.

    A temporal-only convolution (kernel 1x1x3 over the last axis) mixes each
    site's syndrome history, two 3x3x3 convolutions add spatial context, and
    two 1x1x1 heads produce a denoised syndrome and a per-site confidence.
    All convolutions preserve the spatial/temporal sizes of the input.

    Args:
        distance: Code distance D (stored as metadata; not used to shape
            layers).
        time_steps: Number of syndrome rounds T (stored as metadata; not
            used to shape layers).
        channels: Hidden width; must be divisible by 8 (GroupNorm with
            8 groups).
    """

    def __init__(self, distance=5, time_steps=3, channels=64):
        super().__init__()
        self.distance = distance
        # Previously accepted but silently discarded; stored for consistency
        # with the other models in this module.
        self.time_steps = time_steps
        # Temporal convolution for syndrome history (last axis only).
        self.temporal_conv = nn.Sequential(
            nn.Conv3d(2, channels, (1, 1, 3), padding=(0, 0, 1)),
            nn.GroupNorm(8, channels),
            nn.GELU(),
        )
        # Spatial processing
        self.spatial_conv = nn.Sequential(
            nn.Conv3d(channels, channels, 3, padding=1),
            nn.GroupNorm(8, channels),
            nn.GELU(),
            nn.Conv3d(channels, channels, 3, padding=1),
            nn.GroupNorm(8, channels),
            nn.GELU(),
        )
        # Denoising output (sigmoid applied in forward).
        self.denoise = nn.Conv3d(channels, 2, 1)
        # Confidence estimation, squashed to [0, 1] by the trailing Sigmoid.
        self.confidence = nn.Sequential(
            nn.Conv3d(channels, 32, 1),
            nn.GELU(),
            nn.Conv3d(32, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, syndrome_history):
        """
        Args:
            syndrome_history: (batch, 2, D, D, T)
        Returns:
            denoised_syndrome: (batch, 2, D, D, T), values in [0, 1]
            confidence: (batch, 1, D, D, T), values in [0, 1]
        """
        x = self.temporal_conv(syndrome_history)
        x = self.spatial_conv(x)
        denoised = torch.sigmoid(self.denoise(x))
        conf = self.confidence(x)
        return denoised, conf
Xet Storage Details
- Size: 6.14 kB
- Xet hash: da80bdbbeb5d42e828510b88404c3a6e802614f35078519e0114a3ce5469e3ff

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.