# File: 32x_Quantum_NLP/src/cst/quantum/run_quantum_cst.py
# Last commit: fe9a6a4 "Fix HF imports runtime error" (melhelbawi)
# CST / QCST Dual License
# Non-commercial research use only.
# Commercial use requires explicit permission.
# Copyright (c) 2025 Mohamed Mohamed Elhelbawi
# All rights reserved.
# See LICENSE file in the project root for full license information.
"""
Quantum-Enhanced CST - Unified Runner Script
Supports modes: demo, train, benchmark
"""
import argparse
import logging
import os
import sys
import torch
import time
import json
from pathlib import Path
from tabulate import tabulate
import numpy as np
# Setup paths
# This file lives at <root>/src/cst/quantum/run_quantum_cst.py, so three
# dirname() hops from this file's directory give the project root.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))
quantum_dir = os.path.join(project_root, 'src', 'cst', 'quantum')
classical_dir = os.path.join(project_root, 'src', 'cst', 'classical')
# Prepend (not append) so the project's own modules shadow any
# identically named installed packages.
sys.path.insert(0, quantum_dir)
sys.path.insert(0, classical_dir)
# Relative imports work when this module is run as part of the package;
# fall back to plain imports when executed as a standalone script
# (resolvable thanks to the sys.path entries added above).
try:
    from .quantum_cst_config import CSTConfig, QuantumConfig, ConfigPresets, TrainingConfig
    from .quantum_cst_transformer import QuantumCSTransformer
except ImportError:
    from quantum_cst_config import CSTConfig, QuantumConfig, ConfigPresets, TrainingConfig
    from quantum_cst_transformer import QuantumCSTransformer
# Setup logging FIRST: the import fallback below logs a warning, and in the
# original ordering `logger` was used before it was defined (NameError
# whenever training_pipeline was missing).
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO
)
logger = logging.getLogger(__name__)
# Try importing training utilities from classical module
try:
    from training_pipeline import CSTDataset, collate_fn
except ImportError:
    # Fallback if training_pipeline not available; run_train() checks for
    # these None sentinels and fails with a clear error.
    logger.warning("training_pipeline not found, using minimal dataset")
    CSTDataset = None
    collate_fn = None
from torch.utils.data import DataLoader
def run_demo():
    """Run demonstration of ambiguity resolution.

    Builds a research-preset QuantumCSTransformer (MLM head), feeds it four
    sentences containing the ambiguous words "bank" and "apple", prints the
    per-sentence quantum-usage metric, and compares cosine similarities of
    the target-word contextual embeddings across senses.
    """
    import zlib  # local import: only the demo tokenizer needs it

    print("\n" + "="*60)
    print("🚀 Quantum CST Demo: Ambiguity Resolution")
    print("="*60 + "\n")
    # 1. Setup Model
    print("Initializing Quantum-Enhanced Model (Research Preset)...")
    config = ConfigPresets.get_research_config()
    # Ensure vocab size is large enough for demo
    config.vocab_size = 5000
    model = QuantumCSTransformer(config, task_type='mlm')
    model.eval()
    # 2. Demo Sentences: (text, ambiguous target word, intended sense)
    sentences = [
        ("The bank denied the loan application.", "bank", "Financial Institution"),
        ("I sat by the river bank and watched the water.", "bank", "River Edge"),
        ("The apple was red and delicious.", "apple", "Fruit"),
        ("Apple released a new iPhone today.", "apple", "Tech Company")
    ]

    # Mock tokenizer (simple mapping for demo purposes).
    # zlib.crc32 replaces the original hash(): Python 3 salts str hashes per
    # process (PYTHONHASHSEED), so hash() would NOT produce the deterministic
    # token IDs the original comment promised; crc32 is stable across runs.
    def simple_tokenize(text):
        words = text.lower().split()
        return torch.tensor([zlib.crc32(w.encode('utf-8')) % 4000 + 100 for w in words])

    print(f"Analyzing {len(sentences)} sentences for Contextual Embeddings...\n")
    results = []
    embeddings_store = {}
    for text, target_word, context_desc in sentences:
        tokens = simple_tokenize(text)
        input_ids = tokens.unsqueeze(0)  # Batch size 1
        # Determine target index; skip the sentence when the bare word is not
        # present (e.g. punctuation glued to the token).
        words = text.lower().split()
        try:
            target_idx = words.index(target_word.lower())
        except ValueError:
            continue
        # Create dummy context: random document embedding, zeroed metadata.
        context_data = {
            'document_embedding': torch.randn(1, config.raw_doc_dim),
            'metadata': {
                'author': torch.zeros(1, dtype=torch.long),
                'domain': torch.zeros(1, dtype=torch.long)
            }
        }
        # Run Inference
        with torch.no_grad():
            outputs = model(input_ids, context_data, return_quantum_metrics=True)
        # Extract embedding of target word: [batch, seq, dim] -> [dim]
        target_embedding = outputs['hidden_states'][0, target_idx, :]
        embeddings_store[(target_word, context_desc)] = target_embedding
        # Get metrics
        q_metrics = outputs.get('quantum_metrics', {})
        results.append([
            text,
            target_word,
            context_desc,
            f"{q_metrics.get('quantum_ratio', 0):.1%}"
        ])
    # 3. Display Results
    print(tabulate(results, headers=["Sentence", "Target", "Context", "Quantum Usage"], tablefmt="grid"))

    # 4. Calculate Similarity.
    # Guard: a sentence skipped above would leave a key missing and the
    # original code then raised KeyError on embeddings_store lookups.
    needed = [("bank", "Financial Institution"), ("bank", "River Edge"),
              ("apple", "Fruit"), ("apple", "Tech Company")]
    if not all(key in embeddings_store for key in needed):
        print("\n⚠️ Similarity analysis skipped: not all target words were located.")
        return

    def cosine(a, b):
        """Cosine similarity between two 1-D embeddings, as a float."""
        return torch.nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()

    print("\n\n📊 Semantic Analysis (Cosine Similarity):")
    # Bank Comparison
    e1 = embeddings_store[("bank", "Financial Institution")]
    e2 = embeddings_store[("bank", "River Edge")]
    sim_bank = cosine(e1, e2)
    # Apple Comparison
    e3 = embeddings_store[("apple", "Fruit")]
    e4 = embeddings_store[("apple", "Tech Company")]
    sim_apple = cosine(e3, e4)
    # Control (same embedding vs a clone of itself): expected ~1.0
    sim_control = cosine(e1, e1.clone())
    sim_table = [
        ["Bank (Financial) vs Bank (River)", f"{sim_bank:.4f}", "Low (Distinct Meanings)"],
        ["Apple (Fruit) vs Apple (Tech)", f"{sim_apple:.4f}", "Low (Distinct Meanings)"],
        ["Bank (Financial) vs Bank (Financial)", f"{sim_control:.4f}", "High (Same Meaning)"]
    ]
    print(tabulate(sim_table, headers=["Comparison", "Similarity", "Expected"], tablefmt="simple"))
    print("\n✅ Demo Completed Successfully!")
def run_train(epochs=2, batch_size=4):
    """Run training loop verification on a small synthetic dataset.

    Args:
        epochs: Number of training epochs (default 2).
        batch_size: DataLoader batch size (default 4).

    Raises:
        RuntimeError: If the classical training utilities (CSTDataset,
            collate_fn, CSTTrainer) are unavailable.
    """
    # Fail fast with a clear message: the fallback at import time sets these
    # to None, and the original code crashed later with a confusing
    # TypeError when calling CSTDataset(None-sentinel).
    if CSTDataset is None or collate_fn is None:
        raise RuntimeError(
            "training_pipeline (CSTDataset / collate_fn) is unavailable; "
            "--mode train requires the classical training module on sys.path."
        )
    # CSTTrainer was referenced below but never imported anywhere in the
    # original script (guaranteed NameError). Import it lazily here —
    # assumed to live in training_pipeline next to CSTDataset (TODO confirm)
    # — so demo/benchmark modes still work without it.
    try:
        from training_pipeline import CSTTrainer
    except ImportError as err:
        raise RuntimeError(
            "CSTTrainer could not be imported from training_pipeline."
        ) from err

    print("\n" + "="*60)
    print(f"🚂 Quantum CST Training Verification (Epochs: {epochs})")
    print("="*60 + "\n")
    # Config
    config = ConfigPresets.get_fast_training_config()
    train_config = TrainingConfig()
    train_config.max_epochs = epochs
    train_config.checkpoint_dir = "checkpoints_test"
    os.makedirs(train_config.checkpoint_dir, exist_ok=True)
    # Model
    print("Initializing Model...")
    model = QuantumCSTransformer(config, task_type='mlm')
    if torch.cuda.is_available():
        model.cuda()
    # Data (Synthetic)
    print("Generating Synthetic Dataset...")
    train_dataset = CSTDataset("dummy_path", config, split='train')
    # Hack to limit synthetic data size for quick test
    train_dataset.data = train_dataset.data[:20]
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )
    # Trainer
    print("Starting Training Loop...")
    trainer = CSTTrainer(model, config, train_config)
    trainer.train(train_loader)
    print(f"\n✅ Training Completed! Checkpoints saved to {train_config.checkpoint_dir}")
def run_benchmark():
    """Compare Quantum vs Classical inference latency.

    Runs 50 forward passes in each mode over a fixed random batch
    (batch 16, sequence length 64) and prints the average latency plus the
    quantum mode's relative overhead.
    """
    print("\n" + "="*60)
    print("⚡ Quantum vs Classical Benchmark")
    print("="*60 + "\n")
    config = ConfigPresets.get_production_config()
    input_ids = torch.randint(0, config.vocab_size, (16, 64))  # Batch 16, Seq 64
    context_data = {
        'document_embedding': torch.randn(16, config.raw_doc_dim),
        'metadata': {
            'author': torch.zeros(16, dtype=torch.long),
            'domain': torch.zeros(16, dtype=torch.long)
        }
    }
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        input_ids = input_ids.cuda()
        context_data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in context_data.items()}
        context_data['metadata'] = {k: v.cuda() for k, v in context_data['metadata'].items()}

    def _avg_latency(model, iters=50):
        """Average seconds per forward pass, with warm-up and CUDA sync.

        The original code timed GPU work without torch.cuda.synchronize();
        CUDA kernels launch asynchronously, so time.time() around unsynced
        calls measures launch overhead, not execution time.
        """
        model.eval()
        with torch.no_grad():
            _ = model(input_ids, context_data)  # warm-up (lazy init / autotune)
            if use_cuda:
                torch.cuda.synchronize()
            start = time.time()
            for _ in range(iters):
                _ = model(input_ids, context_data)
            if use_cuda:
                torch.cuda.synchronize()
        return (time.time() - start) / iters

    results = []
    # 1. Classical Mode
    print("Benchmarking Classical Mode...")
    config.quantum_config.enable_quantum = False
    model_c = QuantumCSTransformer(config)
    if use_cuda:
        model_c.cuda()
    avg_c = _avg_latency(model_c)
    results.append(["Classical", f"{avg_c*1000:.2f} ms", "1.0x", "Baseline"])
    # 2. Quantum Mode
    print("Benchmarking Quantum Mode...")
    config.quantum_config.enable_quantum = True
    model_q = QuantumCSTransformer(config)
    if use_cuda:
        model_q.cuda()
    avg_q = _avg_latency(model_q)
    # Overhead relative to the classical baseline
    overhead = avg_q / avg_c
    results.append(["Quantum (Simulated)", f"{avg_q*1000:.2f} ms", f"{overhead:.1f}x", "Simulation Overhead"])
    print("\n" + tabulate(results, headers=["Mode", "Avg Latency (Batch 16)", "Rel. Time", "Note"], tablefmt="pretty"))
    print("\n*Note: Quantum latency includes simulation overhead. On real QPU, latency depends on circuit execution time not CPU simulation.")
def main():
    """Command-line entry point: parse --mode and dispatch to the runner."""
    parser = argparse.ArgumentParser(description="Quantum CST Runner")
    parser.add_argument("--mode", choices=['demo', 'train', 'benchmark'], required=True, help="Execution mode")
    parser.add_argument("--epochs", type=int, default=2, help="Number of training epochs")
    parser.add_argument("--batch", type=int, default=4, help="Batch size")
    args = parser.parse_args()

    # Dispatch table instead of an if/elif chain; --mode choices guarantee
    # the key is present.
    handlers = {
        'demo': run_demo,
        'train': lambda: run_train(args.epochs, args.batch),
        'benchmark': run_benchmark,
    }
    handlers[args.mode]()


if __name__ == "__main__":
    main()