""" Parallel (Non-autoregressive) Inference for B2NL-IntelligentTokenizer Faster inference by generating all tokens at once """ import torch import torch.nn.functional as F import sys import time import io from pathlib import Path # Fix Windows Unicode if sys.platform == 'win32': sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8') # Add paths sys.path.insert(0, 'core') from unified_model import IntelligentTokenizerV62 from tokenizer import ByteTokenizerV62 class ParallelTokenizer: """Fast parallel generation (non-autoregressive)""" def __init__(self, checkpoint_path: str = None): self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if checkpoint_path is None: checkpoint_path = "D:/intelligent-tokenizer/intelligent-tokenizer_v6.2.1/checkpoints/v62/16.0/epoch_100.pt" # Load model self.model = IntelligentTokenizerV62() checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False) self.model.load_state_dict(checkpoint['model_state_dict']) self.model = self.model.to(self.device) self.model.eval() print(f"Model loaded on {self.device}") def parallel_generate(self, text: str) -> str: """ Parallel generation - generate all 48 tokens at once This uses teacher forcing with dummy inputs """ tokenizer = self.model.tokenizer # Encode input encoded = tokenizer.encode(text) if isinstance(encoded, dict): input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids'] attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask'] else: input_ids = encoded.unsqueeze(0) if encoded.dim() == 1 else encoded attention_mask = torch.ones_like(input_ids) input_ids = input_ids.to(self.device) attention_mask = attention_mask.to(self.device) # Encode with torch.no_grad(): encoder_outputs = self.model.encoder( input_ids=input_ids, attention_mask=attention_mask ) # Prepare all hidden states if 'all_hidden_states' in encoder_outputs: encoder_all_hidden = encoder_outputs['all_hidden_states'] else: compressed = encoder_outputs.get('compressed', encoder_outputs.get('hidden_states')) encoder_all_hidden = [compressed] * 4 # Create dummy decoder input (all MASK tokens) batch_size = input_ids.size(0) dummy_input = torch.full((batch_size, 48), tokenizer.MASK, device=self.device) dummy_input[:, 0] = tokenizer.BOS # Start with BOS # Single forward pass - generate all tokens at once decoder_outputs = self.model.decoder( encoder_all_hidden=encoder_all_hidden, decoder_input_ids=dummy_input, attention_mask=torch.ones_like(dummy_input), use_cache=False ) # Get predictions for all positions logits = decoder_outputs['logits'] # [batch, 48, vocab] # Take argmax to get predicted tokens predicted = torch.argmax(logits, dim=-1) # [batch, 48] # Decode to text if predicted.dim() > 1: text = tokenizer.decode(predicted[0]) else: text = tokenizer.decode(predicted) return text def iterative_refinement(self, text: str, iterations: int = 2) -> str: """ Iterative refinement - generate multiple times and refine Similar to BERT-style masked prediction """ tokenizer = self.model.tokenizer # Encode input encoded = tokenizer.encode(text) if isinstance(encoded, dict): input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids'] attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask'] else: input_ids = encoded.unsqueeze(0) if encoded.dim() == 1 else encoded attention_mask = torch.ones_like(input_ids) input_ids = input_ids.to(self.device) attention_mask = attention_mask.to(self.device) # Encode once with torch.no_grad(): encoder_outputs = self.model.encoder( input_ids=input_ids, attention_mask=attention_mask ) if 'all_hidden_states' in encoder_outputs: encoder_all_hidden = encoder_outputs['all_hidden_states'] else: compressed = encoder_outputs.get('compressed', encoder_outputs.get('hidden_states')) encoder_all_hidden = [compressed] * 4 batch_size = input_ids.size(0) # Start with all MASK tokens current = torch.full((batch_size, 48), tokenizer.MASK, device=self.device) current[:, 0] = tokenizer.BOS # Iteratively refine for iteration in range(iterations): # Forward pass with current tokens decoder_outputs = self.model.decoder( encoder_all_hidden=encoder_all_hidden, decoder_input_ids=current, attention_mask=torch.ones_like(current), use_cache=False ) logits = decoder_outputs['logits'] # Gradually unmask tokens (confidence-based) probs = F.softmax(logits, dim=-1) confidence = torch.max(probs, dim=-1)[0] # [batch, 48] # Update tokens with high confidence threshold = 0.7 - (iteration * 0.1) # Lower threshold over iterations high_conf_mask = confidence > threshold new_tokens = torch.argmax(logits, dim=-1) current = torch.where(high_conf_mask, new_tokens, current) # Add some randomness to break loops if iteration < iterations - 1: # Randomly mask 10% of tokens for next iteration rand_mask = torch.rand_like(confidence) < 0.1 current = torch.where(rand_mask, tokenizer.MASK, current) # Final decode if current.dim() > 1: text = tokenizer.decode(current[0]) else: text = tokenizer.decode(current) return text def test_parallel_vs_autoregressive(): """Compare parallel vs autoregressive generation""" print("="*60) print("Parallel vs Autoregressive Comparison") print("="*60) # Load both models parallel = ParallelTokenizer() # Also test with the original autoregressive from inference import B2NLTokenizer autoregressive = B2NLTokenizer() test_texts = [ "Hello, world!", "The quick brown fox", "안녕하세요, 반갑습니다.", "Testing 123", ] print("\n1. PARALLEL GENERATION (Single Pass)") print("-"*40) for text in test_texts: start = time.time() result = parallel.parallel_generate(text) elapsed = (time.time() - start) * 1000 accuracy = sum(1 for i in range(min(len(text), len(result))) if text[i] == result[i]) / len(text) * 100 print(f"Input: {text}") print(f"Output: {result}") print(f"Accuracy: {accuracy:.1f}%, Time: {elapsed:.1f}ms\n") print("\n2. ITERATIVE REFINEMENT (2 iterations)") print("-"*40) for text in test_texts: start = time.time() result = parallel.iterative_refinement(text, iterations=2) elapsed = (time.time() - start) * 1000 accuracy = sum(1 for i in range(min(len(text), len(result))) if text[i] == result[i]) / len(text) * 100 print(f"Input: {text}") print(f"Output: {result}") print(f"Accuracy: {accuracy:.1f}%, Time: {elapsed:.1f}ms\n") print("\n3. AUTOREGRESSIVE (Original - 48 steps)") print("-"*40) for text in test_texts: start = time.time() result = autoregressive.reconstruct(text, temperature=0.1) elapsed = (time.time() - start) * 1000 accuracy = sum(1 for i in range(min(len(text), len(result))) if text[i] == result[i]) / len(text) * 100 print(f"Input: {text}") print(f"Output: {result}") print(f"Accuracy: {accuracy:.1f}%, Time: {elapsed:.1f}ms\n") if __name__ == "__main__": test_parallel_vs_autoregressive()