# CST / QCST Dual License
# Non-commercial research use only.
# Commercial use requires explicit permission.
# Copyright (c) 2025 Mohamed Mohamed Elhelbawi
# All rights reserved.
# See LICENSE file in the project root for full license information.
"""
Quantum-Enhanced Training Pipeline for CST Models
Implements quantum-aware loss functions, optimizers, and training procedures
Fully standalone - does not depend on the classical (non-quantum) CST training pipeline
"""
from __future__ import annotations  # defer annotation evaluation: TrainingConfig is referenced in type hints but only imported in main()
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW, SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os
import json
import logging
import wandb
from tqdm import tqdm
from typing import Dict, List, Optional, Any, Tuple
import numpy as np
from collections import defaultdict
import time
from .quantum_cst_transformer import QuantumCSTransformer
from .quantum_cst_config import QuantumConfig, CSTConfig
from .quantum_transformer_utils import collate_quantum_batch
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class QuantumInfoNCELoss(nn.Module):
"""
Quantum-Enhanced InfoNCE Loss for Contrastive Learning
Incorporates quantum entanglement measures to better capture
feature correlations in the contrastive learning objective
"""
def __init__(self, temperature: float = 0.07, quantum_weight: float = 0.2):
super().__init__()
self.temperature = temperature
self.quantum_weight = quantum_weight
def compute_quantum_entanglement_penalty(self, embeddings: torch.Tensor) -> torch.Tensor:
"""
Compute penalty based on quantum entanglement principles
Encourages embeddings to have properties similar to quantum entangled states:
- High correlation between related features
- Low correlation between unrelated features
"""
# Compute correlation matrix
normalized_emb = F.normalize(embeddings, dim=-1)
correlation_matrix = torch.matmul(normalized_emb, normalized_emb.t())
# Von Neumann entropy as entanglement measure
# Higher entropy = more entanglement = better representation
eigenvalues = torch.linalg.eigvalsh(correlation_matrix)
eigenvalues = torch.clamp(eigenvalues, min=1e-10) # Numerical stability
# Normalized eigenvalues (probability distribution)
probs = eigenvalues / eigenvalues.sum()
# Entropy: -Σ p_i log(p_i)
entropy = -(probs * torch.log(probs + 1e-10)).sum()
# Maximize entropy (negative for minimization)
return -entropy
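    # Added note on scale: for a batch of B embeddings, the entropy above is bounded
    # by log(B) and is maximal when all eigenvalues of the B x B correlation (Gram)
    # matrix are equal. Returning -entropy therefore rewards spreading correlation
    # evenly across directions instead of concentrating it in a few dominant modes.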
def forward(self, positive_pairs: torch.Tensor,
negative_pairs: torch.Tensor,
quantum_embeddings: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Args:
positive_pairs: [batch_size, embedding_dim]
negative_pairs: [batch_size * num_negatives, embedding_dim]
quantum_embeddings: Optional quantum-processed embeddings
"""
batch_size = positive_pairs.size(0)
# Normalize embeddings
positive_pairs = F.normalize(positive_pairs, dim=1)
negative_pairs = F.normalize(negative_pairs, dim=1)
# Compute similarities
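        # Note: the positive logit below is the self-similarity of the normalized
        # positive embeddings, which is identically 1 after normalization. It acts as
        # a fixed 1/temperature margin, so the loss primarily pushes down similarity
        # to the shuffled-context negatives computed afterwards.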
pos_sim = torch.sum(positive_pairs * positive_pairs, dim=1) / self.temperature
# Reshape negatives
num_negatives = negative_pairs.size(0) // batch_size
negative_pairs = negative_pairs.view(batch_size, num_negatives, -1)
neg_sims = torch.bmm(
positive_pairs.unsqueeze(1),
negative_pairs.transpose(1, 2)
).squeeze(1) / self.temperature
# Combine positive and negative similarities
all_sims = torch.cat([pos_sim.unsqueeze(1), neg_sims], dim=1)
# InfoNCE loss
labels = torch.zeros(batch_size, dtype=torch.long, device=positive_pairs.device)
contrastive_loss = F.cross_entropy(all_sims, labels)
# Add quantum entanglement penalty if quantum embeddings provided
if quantum_embeddings is not None:
entanglement_penalty = self.compute_quantum_entanglement_penalty(quantum_embeddings)
total_loss = contrastive_loss + self.quantum_weight * entanglement_penalty
return total_loss
return contrastive_loss
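# Minimal usage sketch for QuantumInfoNCELoss (illustrative only; the shapes follow
# the forward() docstring above, while batch size 8, embedding size 256, and
# 4 negatives per example are arbitrary choices):
#
#   loss_fn = QuantumInfoNCELoss(temperature=0.07, quantum_weight=0.2)
#   pos = torch.randn(8, 256, requires_grad=True)   # [batch_size, embedding_dim]
#   neg = torch.randn(8 * 4, 256)                   # [batch_size * num_negatives, embedding_dim]
#   loss = loss_fn(pos, neg, quantum_embeddings=pos)
#   loss.backward()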
class QuantumFidelityLoss(nn.Module):
"""
Quantum Fidelity Loss - Measures how well quantum states are preserved
Based on quantum state fidelity: F(ρ, σ) = [Tr √(√ρ σ √ρ)]²,
which reduces to |⟨ψ|φ⟩|² for pure states
Encourages quantum circuits to maintain high fidelity
"""
def __init__(self, target_fidelity: float = 0.95):
super().__init__()
self.target_fidelity = target_fidelity
def compute_fidelity(self, state1: torch.Tensor, state2: torch.Tensor) -> torch.Tensor:
"""
Compute quantum fidelity between two states
Simplified version using inner product for pure states
"""
# Normalize states
state1_norm = F.normalize(state1, dim=-1)
state2_norm = F.normalize(state2, dim=-1)
# Fidelity for pure states: |⟨ψ|φ⟩|²
inner_product = torch.sum(state1_norm * state2_norm, dim=-1)
fidelity = inner_product ** 2
return fidelity
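    # Added note: under this pure-state approximation the fidelity is 1 for identical
    # (up to sign) normalized states and 0 for orthogonal ones, so the loss below
    # penalizes how far that value falls from target_fidelity.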
def forward(self, quantum_outputs: torch.Tensor,
reference_outputs: torch.Tensor) -> torch.Tensor:
"""
Loss that encourages quantum outputs to maintain high fidelity
with reference (e.g., classical) outputs
"""
fidelity = self.compute_fidelity(quantum_outputs, reference_outputs)
# Loss is how far we are from target fidelity
loss = F.mse_loss(fidelity, torch.full_like(fidelity, self.target_fidelity))
return loss
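# Minimal usage sketch for QuantumFidelityLoss (illustrative only; the random tensors
# stand in for quantum-processed and reference hidden states of matching shape):
#
#   fid_loss_fn = QuantumFidelityLoss(target_fidelity=0.95)
#   quantum_out = torch.randn(8, 256, requires_grad=True)
#   reference_out = torch.randn(8, 256)
#   fid_loss = fid_loss_fn(quantum_out, reference_out)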
class QuantumRegularizer(nn.Module):
"""
Quantum-specific regularization terms
Includes:
1. Circuit depth penalty (encourage shallow circuits)
2. Entanglement regularization (control entanglement levels)
3. Quantum noise resilience (prepare for NISQ devices)
"""
def __init__(self, config: QuantumConfig):
super().__init__()
self.config = config
self.depth_penalty_weight = 0.01
self.entanglement_weight = config.quantum_entanglement_regularization
def compute_circuit_depth_penalty(self, quantum_params: torch.Tensor) -> torch.Tensor:
"""
Penalize large quantum parameter magnitudes
Larger params often indicate deeper effective circuits
"""
return torch.mean(quantum_params ** 2)
def compute_entanglement_regularization(self, embeddings: torch.Tensor) -> torch.Tensor:
"""
Regularize entanglement levels
Too much entanglement can make circuits hard to train
Too little loses quantum advantage
"""
# Compute pairwise correlations
batch_size, dim = embeddings.shape
if batch_size < 2:
return torch.tensor(0.0, device=embeddings.device)
# Center embeddings
centered = embeddings - embeddings.mean(dim=0, keepdim=True)
# Covariance matrix
cov = torch.matmul(centered.t(), centered) / (batch_size - 1)
# Off-diagonal elements represent entanglement
off_diagonal = cov - torch.diag(torch.diag(cov))
entanglement_measure = torch.sum(off_diagonal ** 2)
# Target moderate entanglement (not too high, not too low)
target_entanglement = 0.3 * dim # Heuristic target
loss = (entanglement_measure - target_entanglement) ** 2
return loss
def forward(self, quantum_params: Dict[str, torch.Tensor],
quantum_embeddings: torch.Tensor) -> torch.Tensor:
"""
Compute total quantum regularization loss
"""
total_loss = torch.tensor(0.0, device=quantum_embeddings.device)
# Circuit depth penalty
for param_name, params in quantum_params.items():
if 'quantum' in param_name.lower():
total_loss += self.depth_penalty_weight * self.compute_circuit_depth_penalty(params)
# Entanglement regularization
if self.entanglement_weight > 0:
total_loss += self.entanglement_weight * self.compute_entanglement_regularization(quantum_embeddings)
return total_loss
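# Minimal usage sketch for QuantumRegularizer (illustrative only; assumes a
# QuantumConfig instance `qcfg` with quantum_entanglement_regularization set, and a
# model whose quantum parameters contain "quantum" in their names, as in the trainer):
#
#   regularizer = QuantumRegularizer(qcfg)
#   q_params = {n: p for n, p in model.named_parameters() if 'quantum' in n.lower()}
#   reg_loss = regularizer(q_params, pooled_embeddings)   # pooled_embeddings: [B, D]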
class QuantumAwareOptimizer:
"""
Quantum-Aware Optimizer wrapper
Applies quantum-specific optimization strategies:
1. Separate learning rates for quantum vs classical parameters
2. No weight decay on quantum parameters
3. Separate gradient-clipping norm for quantum parameters
"""
def __init__(self, model: nn.Module, config: CSTConfig, training_config: TrainingConfig):
self.model = model
self.config = config
self.training_config = training_config
# Separate parameter groups
quantum_params = []
classical_params = []
for name, param in model.named_parameters():
if 'quantum' in name.lower() or 'qnode' in name.lower():
quantum_params.append(param)
else:
classical_params.append(param)
# Different learning rates for quantum and classical
param_groups = [
{
'params': quantum_params,
'lr': config.quantum_config.quantum_lr,
'weight_decay': 0.0, # No weight decay for quantum params
'name': 'quantum'
},
{
'params': classical_params,
'lr': config.learning_rate,
'weight_decay': config.weight_decay,
'name': 'classical'
}
]
# Choose optimizer based on config
if config.quantum_config.quantum_optimizer == 'adam':
self.optimizer = AdamW(param_groups, betas=(0.9, 0.98), eps=1e-6)
elif config.quantum_config.quantum_optimizer == 'sgd':
self.optimizer = SGD(param_groups, momentum=0.9)
else:
self.optimizer = AdamW(param_groups)
self.use_quantum_gradient_clipping = training_config.use_quantum_gradient_clipping
self.quantum_clip_norm = training_config.quantum_gradient_clip_norm
def step(self):
"""Optimization step with quantum-aware gradient handling"""
# Separate gradient clipping for quantum and classical parameters
if self.use_quantum_gradient_clipping:
for group in self.optimizer.param_groups:
if group['name'] == 'quantum':
torch.nn.utils.clip_grad_norm_(
group['params'],
self.quantum_clip_norm
)
else:
torch.nn.utils.clip_grad_norm_(
group['params'],
self.config.gradient_clip_norm
)
self.optimizer.step()
def zero_grad(self):
self.optimizer.zero_grad()
def state_dict(self):
return self.optimizer.state_dict()
def load_state_dict(self, state_dict):
self.optimizer.load_state_dict(state_dict)
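# Minimal usage sketch for QuantumAwareOptimizer (illustrative only; `model`,
# `cst_config`, `training_config`, and `compute_loss` are hypothetical stand-ins for
# the objects set up in QuantumCSTrainer):
#
#   optimizer = QuantumAwareOptimizer(model, cst_config, training_config)
#   loss = compute_loss(model, batch)   # hypothetical loss computation
#   optimizer.zero_grad()
#   loss.backward()
#   optimizer.step()                    # applies per-group gradient clipping first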
class QuantumSpectralRegularizer:
"""Enhanced spectral regularizer with quantum considerations"""
def __init__(self, config: CSTConfig):
self.config = config
self.reference_embeddings = {}
self.quantum_reference_embeddings = {}
self.update_frequency = config.reference_update_freq
self.step_count = 0
self.momentum = config.reference_momentum
def compute_drift_loss(self, current_embeddings: torch.Tensor,
fragment_ids: torch.Tensor,
is_quantum: bool = False) -> torch.Tensor:
"""Compute drift with separate tracking for quantum embeddings"""
reference_dict = self.quantum_reference_embeddings if is_quantum else self.reference_embeddings
drift_loss = torch.tensor(0.0, device=current_embeddings.device, requires_grad=True)
for frag_id in fragment_ids.unique():
frag_id_item = frag_id.item()
if frag_id_item in reference_dict:
current_mask = fragment_ids == frag_id
current_repr = current_embeddings[current_mask].mean(0)
reference_repr = reference_dict[frag_id_item]
drift_loss = drift_loss + F.mse_loss(current_repr, reference_repr)
return drift_loss / len(fragment_ids.unique()) if len(fragment_ids.unique()) > 0 else drift_loss
def update_references(self, embeddings: torch.Tensor,
fragment_ids: torch.Tensor,
is_quantum: bool = False):
"""Update reference embeddings with EMA"""
self.step_count += 1
if self.step_count % self.update_frequency == 0:
reference_dict = self.quantum_reference_embeddings if is_quantum else self.reference_embeddings
for frag_id in fragment_ids.unique():
frag_id_item = frag_id.item()
current_mask = fragment_ids == frag_id
current_repr = embeddings[current_mask].mean(0).detach()
if frag_id_item in reference_dict:
reference_dict[frag_id_item] = (
self.momentum * reference_dict[frag_id_item] +
(1 - self.momentum) * current_repr
)
else:
reference_dict[frag_id_item] = current_repr
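# Typical call pattern for QuantumSpectralRegularizer inside a training step
# (illustrative only; `spectral_reg` is a hypothetical instance, embeddings_flat is a
# [num_tokens, D] tensor and fragment_ids the matching [num_tokens] id tensor, as in
# language_modeling_step below):
#
#   drift = spectral_reg.compute_drift_loss(embeddings_flat, fragment_ids, is_quantum=True)
#   spectral_reg.update_references(embeddings_flat, fragment_ids, is_quantum=True)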
class QuantumCSTrainer:
"""
Main Quantum-Enhanced Trainer for CST Models
Integrates:
- Quantum-aware loss functions
- Quantum-specific optimizers
- Quantum performance monitoring
- Hybrid quantum-classical training strategies
"""
def __init__(self,
model: QuantumCSTransformer,
config: CSTConfig,
train_config: TrainingConfig):
self.model = model
self.config = config
self.train_config = train_config
# Quantum-aware losses
self.contrastive_loss = QuantumInfoNCELoss(
temperature=config.temperature,
quantum_weight=config.quantum_loss_weight
)
self.quantum_fidelity_loss = QuantumFidelityLoss(target_fidelity=0.95)
self.quantum_regularizer = QuantumRegularizer(config.quantum_config)
self.spectral_regularizer = QuantumSpectralRegularizer(config)
# Quantum-aware optimizer
self.optimizer = None # Initialized in setup_optimizer
self.scheduler = None
# Training state
self.global_step = 0
self.epoch = 0
self.best_val_loss = float('inf')
# Metrics tracking
self.train_metrics = defaultdict(list)
self.val_metrics = defaultdict(list)
self.quantum_metrics = defaultdict(list)
# Setup distributed training if needed
self.is_distributed = train_config.distributed
if self.is_distributed:
self.setup_distributed()
# Setup logging
if train_config.wandb_project and (not self.is_distributed or dist.get_rank() == 0):
wandb.init(
project=train_config.wandb_project,
config={**config.__dict__, **config.quantum_config.__dict__}
)
wandb.watch(model, log='all', log_freq=100)
logger.info("=" * 60)
logger.info("Quantum-Enhanced CST Trainer Initialized")
logger.info("=" * 60)
logger.info(f"Quantum Enabled: {config.quantum_config.enable_quantum}")
logger.info(f"Quantum Qubits: {config.quantum_config.n_qubits}")
logger.info(f"Quantum Layers: {config.quantum_config.n_layers}")
logger.info(f"State Space Dimension: {2 ** config.quantum_config.n_qubits}")
logger.info("=" * 60)
def setup_distributed(self):
"""Setup distributed training"""
dist.init_process_group(backend=self.train_config.backend)
torch.cuda.set_device(self.train_config.rank)
self.model = DDP(self.model, device_ids=[self.train_config.rank])
def setup_optimizer(self, train_loader: DataLoader):
"""Setup quantum-aware optimizer and scheduler"""
# Quantum-aware optimizer
self.optimizer = QuantumAwareOptimizer(self.model, self.config, self.train_config)
# Setup scheduler with quantum warmup
total_steps = self.train_config.max_epochs * len(train_loader)
        # Warmup phase (length set by quantum_warmup_steps; LinearLR scales all
        # parameter groups, quantum and classical alike)
quantum_warmup = LinearLR(
self.optimizer.optimizer,
start_factor=0.01,
total_iters=self.train_config.quantum_warmup_steps
)
# Main scheduler
if self.train_config.quantum_scheduler == 'cosine':
main_scheduler = CosineAnnealingLR(
self.optimizer.optimizer,
T_max=total_steps - self.train_config.quantum_warmup_steps
)
else:
main_scheduler = LinearLR(
self.optimizer.optimizer,
start_factor=1.0,
total_iters=total_steps - self.train_config.quantum_warmup_steps
)
self.scheduler = SequentialLR(
self.optimizer.optimizer,
[quantum_warmup, main_scheduler],
milestones=[self.train_config.quantum_warmup_steps]
)
logger.info(f"Optimizer setup complete")
logger.info(f" Quantum warmup steps: {self.train_config.quantum_warmup_steps}")
logger.info(f" Total training steps: {total_steps}")
def quantum_contrastive_step(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Dict[str, float]]:
"""
Quantum-enhanced contrastive learning step
"""
input_ids = batch['input_ids']
context_data = batch['context_data']
# Get quantum-enhanced embeddings
positive_embeddings = self.model.get_embeddings(input_ids, context_data)
# Create negative contexts (shuffle metadata)
negative_context_data = context_data.copy()
batch_size = input_ids.size(0)
perm = torch.randperm(batch_size)
if 'metadata' in context_data:
negative_context_data['metadata'] = {
k: v[perm] if isinstance(v, torch.Tensor) else v
for k, v in context_data['metadata'].items()
}
# Get embeddings with negative context
negative_embeddings = self.model.get_embeddings(input_ids, negative_context_data)
# Pool sequence representations
positive_pooled = positive_embeddings.mean(dim=1)
negative_pooled = negative_embeddings.mean(dim=1)
# Quantum-enhanced contrastive loss
contrastive_loss = self.contrastive_loss(
positive_pooled,
negative_pooled,
quantum_embeddings=positive_pooled # Use for entanglement penalty
)
metrics = {
'contrastive_loss': contrastive_loss.item()
}
return contrastive_loss, metrics
def language_modeling_step(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Dict[str, float]]:
"""
Language modeling step with quantum awareness
"""
outputs = self.model(
input_ids=batch['input_ids'],
context_data=batch['context_data'],
attention_mask=batch['attention_mask'],
fragment_chars=batch['fragment_chars'],
context_chars=batch['context_chars'],
labels=batch['labels'],
return_quantum_metrics=True
)
mlm_loss = outputs['loss']
quantum_metrics = outputs.get('quantum_metrics', {})
# Spectral regularization (separate for quantum and classical)
embeddings = outputs['hidden_states']
input_ids_flat = batch['input_ids'].view(-1)
embeddings_flat = embeddings.view(-1, embeddings.size(-1))
# Determine which embeddings were quantum-processed
is_quantum_batch = quantum_metrics.get('quantum_ratio', 0) > 0.5
drift_loss = self.spectral_regularizer.compute_drift_loss(
embeddings_flat,
input_ids_flat,
is_quantum=is_quantum_batch
)
# Update reference embeddings
self.spectral_regularizer.update_references(
embeddings_flat,
input_ids_flat,
is_quantum=is_quantum_batch
)
# Quantum regularization
quantum_params = {
name: param for name, param in self.model.named_parameters()
if 'quantum' in name.lower()
}
quantum_reg_loss = self.quantum_regularizer(
quantum_params,
embeddings.mean(dim=1) # Pool for regularization
)
# Total loss
total_loss = (mlm_loss +
self.config.drift_regularization_weight * drift_loss +
self.config.quantum_loss_weight * quantum_reg_loss)
metrics = {
'mlm_loss': mlm_loss.item(),
'drift_loss': drift_loss.item(),
'quantum_reg_loss': quantum_reg_loss.item(),
**{f'quantum_{k}': v for k, v in quantum_metrics.items()
if isinstance(v, (int, float))}
}
return total_loss, metrics
def train_step(self, batch: Dict[str, Any]) -> Dict[str, float]:
"""
Complete training step with quantum-aware losses
"""
self.model.train()
# Contrastive learning (with quantum entanglement penalty)
if self.config.use_quantum_contrastive:
contrastive_loss, contrastive_metrics = self.quantum_contrastive_step(batch)
else:
contrastive_loss = torch.tensor(0.0, device=batch['input_ids'].device)
contrastive_metrics = {}
# Language modeling (with quantum regularization)
mlm_loss, mlm_metrics = self.language_modeling_step(batch)
# Combined loss
total_loss = (self.config.contrastive_weight * contrastive_loss +
self.config.mlm_weight * mlm_loss)
# Backward pass
self.optimizer.zero_grad()
total_loss.backward()
# Quantum-aware optimization step
self.optimizer.step()
if self.scheduler:
self.scheduler.step()
# Compile metrics
metrics = {
'total_loss': total_loss.item(),
            'learning_rate': self.optimizer.optimizer.param_groups[1]['lr'],  # classical group
            'quantum_lr': self.optimizer.optimizer.param_groups[0]['lr'],  # quantum group
**contrastive_metrics,
**mlm_metrics
}
return metrics
def validate(self, val_loader: DataLoader) -> Dict[str, float]:
"""Validation loop with quantum metrics"""
self.model.eval()
val_losses = []
all_metrics = defaultdict(list)
with torch.no_grad():
for batch in tqdm(val_loader, desc="Validation"):
# Move to device
batch = {k: v.cuda() if isinstance(v, torch.Tensor) else v
for k, v in batch.items()}
# Forward pass
outputs = self.model(
input_ids=batch['input_ids'],
context_data=batch['context_data'],
attention_mask=batch['attention_mask'],
fragment_chars=batch['fragment_chars'],
context_chars=batch['context_chars'],
labels=batch['labels'],
return_quantum_metrics=True
)
val_losses.append(outputs['loss'].item())
# Collect quantum metrics
if 'quantum_metrics' in outputs:
for key, value in outputs['quantum_metrics'].items():
if isinstance(value, (int, float)):
all_metrics[f'val_quantum_{key}'].append(value)
# Collect CST stats
if 'cst_stats' in outputs:
for key, value in outputs['cst_stats'].items():
if isinstance(value, (int, float)):
all_metrics[f'val_cst_{key}'].append(value)
# Average metrics
avg_metrics = {
'val_loss': np.mean(val_losses),
**{k: np.mean(v) for k, v in all_metrics.items()}
}
return avg_metrics
def save_checkpoint(self, filepath: str, is_best: bool = False):
"""Save checkpoint with quantum state"""
checkpoint = {
'epoch': self.epoch,
'global_step': self.global_step,
'model_state_dict': self.model.module.state_dict() if self.is_distributed else self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
'best_val_loss': self.best_val_loss,
'config': self.config.__dict__,
'quantum_config': self.config.quantum_config.__dict__,
'spectral_regularizer': {
'reference_embeddings': self.spectral_regularizer.reference_embeddings,
'quantum_reference_embeddings': self.spectral_regularizer.quantum_reference_embeddings
},
'quantum_metrics_history': self.model.quantum_metrics_history if hasattr(self.model, 'quantum_metrics_history') else []
}
torch.save(checkpoint, filepath)
if is_best:
best_path = filepath.replace('.pt', '_best.pt')
torch.save(checkpoint, best_path)
logger.info(f"Checkpoint saved: {filepath}")
def load_checkpoint(self, filepath: str):
"""Load checkpoint with quantum state"""
checkpoint = torch.load(filepath, map_location='cuda')
if self.is_distributed:
self.model.module.load_state_dict(checkpoint['model_state_dict'])
else:
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if self.scheduler and checkpoint['scheduler_state_dict']:
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.epoch = checkpoint['epoch']
self.global_step = checkpoint['global_step']
self.best_val_loss = checkpoint['best_val_loss']
# Restore quantum-specific state
if 'spectral_regularizer' in checkpoint:
self.spectral_regularizer.reference_embeddings = checkpoint['spectral_regularizer']['reference_embeddings']
self.spectral_regularizer.quantum_reference_embeddings = checkpoint['spectral_regularizer']['quantum_reference_embeddings']
if 'quantum_metrics_history' in checkpoint:
self.model.quantum_metrics_history = checkpoint['quantum_metrics_history']
logger.info(f"Checkpoint loaded: {filepath}")
logger.info(f" Epoch: {self.epoch}, Step: {self.global_step}")
def train(self, train_loader: DataLoader, val_loader: Optional[DataLoader] = None):
"""Main training loop with quantum enhancements"""
self.setup_optimizer(train_loader)
logger.info("=" * 60)
logger.info(f"Starting Quantum-Enhanced Training")
logger.info(f" Epochs: {self.train_config.max_epochs}")
logger.info(f" Total steps: {len(train_loader) * self.train_config.max_epochs}")
logger.info(f" Quantum warmup: {self.train_config.quantum_warmup_steps} steps")
logger.info("=" * 60)
for epoch in range(self.train_config.max_epochs):
self.epoch = epoch
# Training loop
self.model.train()
epoch_metrics = defaultdict(list)
progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}")
for batch_idx, batch in enumerate(progress_bar):
# Move to device
batch = {k: v.cuda() if isinstance(v, torch.Tensor) else v
for k, v in batch.items()}
# Training step
step_metrics = self.train_step(batch)
# Update metrics
for key, value in step_metrics.items():
epoch_metrics[key].append(value)
# Update progress bar
progress_bar.set_postfix({
'loss': f"{step_metrics['total_loss']:.4f}",
'lr': f"{step_metrics['learning_rate']:.2e}",
'q_ratio': f"{step_metrics.get('quantum_quantum_ratio', 0):.2%}"
})
self.global_step += 1
# Logging
if self.global_step % self.train_config.log_every_n_steps == 0:
avg_metrics = {k: np.mean(v[-self.train_config.log_every_n_steps:])
for k, v in epoch_metrics.items()}
if wandb.run:
wandb.log(avg_metrics, step=self.global_step)
logger.info(f"Step {self.global_step}: Loss={avg_metrics['total_loss']:.4f}, "
f"Q_Ratio={avg_metrics.get('quantum_quantum_ratio', 0):.2%}")
# Validation
if val_loader and self.global_step % self.train_config.eval_every_n_steps == 0:
val_metrics = self.validate(val_loader)
if wandb.run:
wandb.log(val_metrics, step=self.global_step)
logger.info(f"Validation at step {self.global_step}:")
logger.info(f" Val Loss: {val_metrics['val_loss']:.4f}")
logger.info(f" Quantum Usage: {val_metrics.get('val_quantum_quantum_ratio', 0):.2%}")
# Save best model
if val_metrics['val_loss'] < self.best_val_loss:
self.best_val_loss = val_metrics['val_loss']
self.save_checkpoint(
f"{self.train_config.checkpoint_dir}/checkpoint_step_{self.global_step}.pt",
is_best=True
)
logger.info(f" ✅ New best model saved!")
# Checkpoint saving
if self.global_step % self.train_config.save_every_n_steps == 0:
self.save_checkpoint(
f"{self.train_config.checkpoint_dir}/checkpoint_step_{self.global_step}.pt"
)
# End of epoch validation
if val_loader:
val_metrics = self.validate(val_loader)
logger.info(f"End of epoch {epoch} validation:")
logger.info(f" Val Loss: {val_metrics['val_loss']:.4f}")
if wandb.run:
wandb.log({f"epoch_{k}": v for k, v in val_metrics.items()},
step=self.global_step)
logger.info("=" * 60)
logger.info("Training completed!")
logger.info("=" * 60)
# Final quantum summary
        quantum_summary = (self.model.module if self.is_distributed else self.model).get_quantum_summary()
logger.info("Final Quantum Usage Summary:")
logger.info(f" Total quantum tokens: {quantum_summary['cst_stats'].get('quantum_processed_tokens', 0)}")
logger.info(f" Quantum ratio: {quantum_summary['cst_stats'].get('quantum_usage_ratio', 0):.2%}")
logger.info(f" Cache hit rate: {quantum_summary['cst_stats'].get('hit_rate', 0):.2%}")
# Helper functions for data loading
class QuantumAwareDataset(Dataset):
"""Simple quantum-aware dataset for training"""
def __init__(self, data: List[Dict], config):
self.data = data
self.config = config
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
def create_quantum_aware_data_loader(data: List[Dict], config,
training_config,
split: str = 'train') -> DataLoader:
"""
Create data loader optimized for quantum training
Features:
- Balanced batching (mix ambiguous and non-ambiguous samples)
- Quantum-relevant data augmentation
- Efficient context data preparation
"""
dataset = QuantumAwareDataset(data, config)
# Quantum-aware sampler (optional: sample more ambiguous examples)
sampler = None
shuffle = (split == 'train')
loader = DataLoader(
dataset,
batch_size=training_config.batch_size if split == 'train' else training_config.batch_size * 2,
shuffle=shuffle,
sampler=sampler,
collate_fn=collate_quantum_batch,
        num_workers=0,
        pin_memory=True  # prefetch_factor omitted: it is only valid when num_workers > 0
    )
return loader
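# Minimal usage sketch for create_quantum_aware_data_loader (illustrative only; the
# example record keys mirror those consumed in train_step, but the exact schema is
# defined by collate_quantum_batch):
#
#   examples = [{'input_ids': [...], 'context_data': {...}, 'labels': [...]}, ...]
#   loader = create_quantum_aware_data_loader(examples, config, training_config, split='train')
#   batch = next(iter(loader))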
def main():
"""Main training script with quantum enhancements"""
# Load configurations
from config import CSTConfig, TrainingConfig, QuantumConfig, ConfigPresets
# Use high-accuracy quantum config
config = ConfigPresets.get_high_accuracy_config()
training_config = TrainingConfig()
logger.info("=" * 60)
logger.info("Quantum-Enhanced CST Training")
logger.info("=" * 60)
logger.info(f"Model Configuration:")
logger.info(f" d_model: {config.d_model}")
logger.info(f" num_layers: {config.num_layers}")
logger.info(f" vocab_size: {config.vocab_size}")
logger.info(f"\nQuantum Configuration:")
logger.info(f" Quantum enabled: {config.quantum_config.enable_quantum}")
logger.info(f" n_qubits: {config.quantum_config.n_qubits}")
logger.info(f" n_layers: {config.quantum_config.n_layers}")
logger.info(f" State space: {2 ** config.quantum_config.n_qubits}")
logger.info(f" Quantum weight: {config.quantum_config.quantum_weight}")
logger.info("=" * 60)
# Create model
model = QuantumCSTransformer(config, task_type='mlm')
model.cuda()
model.enable_cst_profiling(True)
# Log model info
total_params = sum(p.numel() for p in model.parameters())
quantum_params = sum(p.numel() for n, p in model.named_parameters()
if 'quantum' in n.lower())
logger.info(f"\nModel Statistics:")
logger.info(f" Total parameters: {total_params:,}")
logger.info(f" Quantum parameters: {quantum_params:,} ({quantum_params/total_params*100:.1f}%)")
logger.info(f" Classical parameters: {total_params - quantum_params:,}")
# Quantum advantage estimation
advantage = config.estimate_quantum_advantage()
logger.info(f"\nTheoretical Quantum Advantage:")
logger.info(f" Parameter efficiency: {advantage['parameter_efficiency']:.1f}x")
logger.info(f" State space dimension: {advantage['quantum_state_dimension']}")
logger.info("=" * 60)
    # Setup data loaders
    # create_quantum_aware_data_loader expects a list of example dicts, so the
    # train/val paths are loaded here first (assumed to point at JSON lists of dicts).
    with open(training_config.train_data_path) as f:
        train_data = json.load(f)
    with open(training_config.val_data_path) as f:
        val_data = json.load(f)
    train_loader = create_quantum_aware_data_loader(
        train_data,
        config,
        training_config,
        split='train'
    )
    val_loader = create_quantum_aware_data_loader(
        val_data,
        config,
        training_config,
        split='val'
    )
logger.info(f"\nData Loaders:")
logger.info(f" Train batches: {len(train_loader)}")
logger.info(f" Val batches: {len(val_loader)}")
logger.info("=" * 60)
# Setup trainer
trainer = QuantumCSTrainer(model, config, training_config)
# Start training
logger.info("\n🚀 Starting training...\n")
trainer.train(train_loader, val_loader)
# Save final model
model.save_pretrained(f"{training_config.checkpoint_dir}/final_model")
# Print final quantum summary
logger.info("\n" + "=" * 60)
logger.info("FINAL QUANTUM SUMMARY")
logger.info("=" * 60)
quantum_summary = model.get_quantum_summary()
logger.info(f"Quantum Processing Statistics:")
logger.info(f" Total tokens processed: {quantum_summary['cst_stats'].get('quantum_processed_tokens', 0) + quantum_summary['cst_stats'].get('classical_processed_tokens', 0)}")
logger.info(f" Quantum tokens: {quantum_summary['cst_stats'].get('quantum_processed_tokens', 0)}")
logger.info(f" Quantum usage ratio: {quantum_summary['cst_stats'].get('quantum_usage_ratio', 0):.2%}")
logger.info(f" Cache efficiency: {quantum_summary['cst_stats'].get('hit_rate', 0):.2%}")
logger.info(f" Quantum cache hit rate: {quantum_summary['cst_stats'].get('quantum_hit_rate', 0):.2%}")
if 'average_quantum_metrics' in quantum_summary:
logger.info(f"\nAverage Quantum Metrics:")
for key, value in quantum_summary['average_quantum_metrics'].items():
logger.info(f" {key}: {value:.4f}")
logger.info("=" * 60)
logger.info("✅ Training completed successfully!")
logger.info("=" * 60)
# Utility: Quantum training analysis
def analyze_quantum_training_run(checkpoint_dir: str):
"""
Analyze quantum usage throughout training
Loads checkpoints and analyzes:
- Quantum vs classical usage over time
- Quantum circuit parameter evolution
- Cache performance trends
"""
import glob
import matplotlib.pyplot as plt
checkpoints = sorted(glob.glob(f"{checkpoint_dir}/checkpoint_step_*.pt"))
quantum_ratios = []
cache_hit_rates = []
steps = []
for ckpt_path in checkpoints:
ckpt = torch.load(ckpt_path, map_location='cpu')
if 'quantum_metrics_history' in ckpt:
metrics = ckpt['quantum_metrics_history']
if metrics:
avg_ratio = np.mean([m.get('quantum_ratio', 0) for m in metrics])
quantum_ratios.append(avg_ratio)
step = int(ckpt_path.split('_')[-1].replace('.pt', ''))
steps.append(step)
# Plot quantum usage over training
fig, axes = plt.subplots(2, 1, figsize=(12, 8))
axes[0].plot(steps, quantum_ratios, 'b-', linewidth=2)
axes[0].set_xlabel('Training Step')
axes[0].set_ylabel('Quantum Usage Ratio')
axes[0].set_title('Quantum Processing Over Training')
axes[0].grid(True, alpha=0.3)
if cache_hit_rates:
axes[1].plot(steps, cache_hit_rates, 'g-', linewidth=2)
axes[1].set_xlabel('Training Step')
axes[1].set_ylabel('Cache Hit Rate')
axes[1].set_title('Cache Performance Over Training')
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(f"{checkpoint_dir}/quantum_training_analysis.png", dpi=300)
plt.close()
logger.info(f"Training analysis saved to {checkpoint_dir}/quantum_training_analysis.png")
return {
'steps': steps,
'quantum_ratios': quantum_ratios,
'cache_hit_rates': cache_hit_rates,
'final_quantum_ratio': quantum_ratios[-1] if quantum_ratios else 0,
'avg_quantum_ratio': np.mean(quantum_ratios) if quantum_ratios else 0
}
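# Example (illustrative): analyzing a finished run from its checkpoint directory.
# Note that cache_hit_rates is only plotted when checkpoints record it; otherwise the
# second panel stays empty.
#
#   summary = analyze_quantum_training_run("./checkpoints")
#   print(summary['avg_quantum_ratio'], summary['final_quantum_ratio'])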
if __name__ == "__main__":
# Check for required packages
try:
import pennylane
logger.info("✅ PennyLane installed")
except ImportError:
logger.error("❌ PennyLane not installed. Run: pip install pennylane")
        raise SystemExit(1)
# Run training
main()