"""
Parallel (Non-autoregressive) Inference for B2NL-IntelligentTokenizer
Faster inference: all output positions are decoded in a single forward pass.
"""
import torch
import torch.nn.functional as F
import sys
import time
import io
from pathlib import Path
# Fix Windows console Unicode so non-ASCII test strings print correctly
if sys.platform == 'win32':
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
    sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
# Make local modules importable (assumes the script is run from the repo root)
sys.path.insert(0, 'core')
from unified_model import IntelligentTokenizerV62
from tokenizer import ByteTokenizerV62
class ParallelTokenizer:
    """Fast parallel generation (non-autoregressive)"""

    def __init__(self, checkpoint_path: str = None):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if checkpoint_path is None:
            # Default local checkpoint; pass an explicit path on other machines
            checkpoint_path = "D:/intelligent-tokenizer/intelligent-tokenizer_v6.2.1/checkpoints/v62/16.0/epoch_100.pt"

        # Load model weights and move to the selected device
        self.model = IntelligentTokenizerV62()
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()
        print(f"Model loaded on {self.device}")

    def parallel_generate(self, text: str) -> str:
        """
        Parallel generation: predict all 48 positions in one decoder pass.
        Instead of generating token by token, the decoder receives a dummy
        input (BOS followed by MASK placeholders) in teacher-forcing style
        and fills every position at once.
        """
        tokenizer = self.model.tokenizer

        # Encode input text to byte IDs
        encoded = tokenizer.encode(text)
        if isinstance(encoded, dict):
            input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
            attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']
        else:
            input_ids = encoded.unsqueeze(0) if encoded.dim() == 1 else encoded
            attention_mask = torch.ones_like(input_ids)

        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        with torch.no_grad():
            # Encode once
            encoder_outputs = self.model.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            # Collect per-layer hidden states for the decoder's cross-attention
            if 'all_hidden_states' in encoder_outputs:
                encoder_all_hidden = encoder_outputs['all_hidden_states']
            else:
                compressed = encoder_outputs.get('compressed', encoder_outputs.get('hidden_states'))
                encoder_all_hidden = [compressed] * 4

            # Dummy decoder input: BOS at position 0, MASK everywhere else
            batch_size = input_ids.size(0)
            dummy_input = torch.full((batch_size, 48), tokenizer.MASK,
                                     dtype=torch.long, device=self.device)
            dummy_input[:, 0] = tokenizer.BOS

            # Single forward pass generates logits for all positions at once
            decoder_outputs = self.model.decoder(
                encoder_all_hidden=encoder_all_hidden,
                decoder_input_ids=dummy_input,
                attention_mask=torch.ones_like(dummy_input),
                use_cache=False
            )

            # Greedy decoding: argmax at every position
            logits = decoder_outputs['logits']        # [batch, 48, vocab]
            predicted = torch.argmax(logits, dim=-1)  # [batch, 48]

        # Decode token IDs back to text
        if predicted.dim() > 1:
            text = tokenizer.decode(predicted[0])
        else:
            text = tokenizer.decode(predicted)

        return text
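
    # A minimal, self-contained sketch of the single-pass decoding step above,
    # run on random logits instead of the real decoder. The 48-position length
    # matches this file; the vocab size is only an illustrative assumption.
    @staticmethod
    def _single_pass_decode_demo(vocab_size: int = 260) -> torch.Tensor:
        dummy_logits = torch.randn(1, 48, vocab_size)   # stand-in for decoder logits [batch, 48, vocab]
        predicted = torch.argmax(dummy_logits, dim=-1)  # every position decoded in one shot
        return predicted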

    def iterative_refinement(self, text: str, iterations: int = 2) -> str:
        """
        Iterative refinement: decode several times, keeping only confident
        predictions each round (similar to BERT-style masked prediction).
        """
        tokenizer = self.model.tokenizer

        # Encode input text to byte IDs
        encoded = tokenizer.encode(text)
        if isinstance(encoded, dict):
            input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
            attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']
        else:
            input_ids = encoded.unsqueeze(0) if encoded.dim() == 1 else encoded
            attention_mask = torch.ones_like(input_ids)

        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        with torch.no_grad():
            # Encode once; the encoder states are reused across refinement rounds
            encoder_outputs = self.model.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            if 'all_hidden_states' in encoder_outputs:
                encoder_all_hidden = encoder_outputs['all_hidden_states']
            else:
                compressed = encoder_outputs.get('compressed', encoder_outputs.get('hidden_states'))
                encoder_all_hidden = [compressed] * 4

            batch_size = input_ids.size(0)

            # Start with BOS followed by all-MASK tokens
            current = torch.full((batch_size, 48), tokenizer.MASK,
                                 dtype=torch.long, device=self.device)
            current[:, 0] = tokenizer.BOS

            # Iteratively refine: re-decode, keep confident tokens, re-mask the rest
            for iteration in range(iterations):
                decoder_outputs = self.model.decoder(
                    encoder_all_hidden=encoder_all_hidden,
                    decoder_input_ids=current,
                    attention_mask=torch.ones_like(current),
                    use_cache=False
                )
                logits = decoder_outputs['logits']

                # Confidence = max softmax probability at each position
                probs = F.softmax(logits, dim=-1)
                confidence = torch.max(probs, dim=-1)[0]  # [batch, 48]

                # Accept predictions above the threshold; relax it each round
                threshold = 0.7 - (iteration * 0.1)
                high_conf_mask = confidence > threshold
                new_tokens = torch.argmax(logits, dim=-1)
                current = torch.where(high_conf_mask, new_tokens, current)

                # Re-mask ~10% of positions between rounds to break repetition loops
                if iteration < iterations - 1:
                    rand_mask = torch.rand_like(confidence) < 0.1
                    mask_tokens = torch.full_like(current, tokenizer.MASK)
                    current = torch.where(rand_mask, mask_tokens, current)

        # Decode the final token IDs back to text
        if current.dim() > 1:
            text = tokenizer.decode(current[0])
        else:
            text = tokenizer.decode(current)

        return text
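
# A minimal, self-contained sketch of the confidence-gated unmasking used in
# iterative_refinement, run on random logits so it needs no model; MASK_ID and
# the vocab size are illustrative stand-ins for tokenizer.MASK and the real
# byte vocabulary.
def _confidence_unmask_demo(mask_id: int = 259, threshold: float = 0.7) -> torch.Tensor:
    logits = torch.randn(1, 48, 260)            # stand-in decoder logits
    current = torch.full((1, 48), mask_id, dtype=torch.long)  # all-MASK start
    probs = F.softmax(logits, dim=-1)
    confidence, new_tokens = probs.max(dim=-1)  # per-position confidence + greedy pick
    keep = confidence > threshold               # unmask only the confident positions
    return torch.where(keep, new_tokens, current)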


def test_parallel_vs_autoregressive():
    """Compare parallel vs autoregressive generation"""
    print("=" * 60)
    print("Parallel vs Autoregressive Comparison")
    print("=" * 60)

    # Load both models
    parallel = ParallelTokenizer()

    # Also test the original autoregressive tokenizer
    from inference import B2NLTokenizer
    autoregressive = B2NLTokenizer()

    test_texts = [
        "Hello, world!",
        "The quick brown fox",
        "์•ˆ๋…•ํ•˜์„ธ์š”, ๋ฐ˜๊ฐ‘์Šต๋‹ˆ๋‹ค.",
        "Testing 123",
    ]

    def report(text, result, elapsed_ms):
        # Character-level accuracy over the overlapping prefix
        matches = sum(1 for i in range(min(len(text), len(result))) if text[i] == result[i])
        accuracy = matches / len(text) * 100
        print(f"Input: {text}")
        print(f"Output: {result}")
        print(f"Accuracy: {accuracy:.1f}%, Time: {elapsed_ms:.1f}ms\n")

    print("\n1. PARALLEL GENERATION (Single Pass)")
    print("-" * 40)
    for text in test_texts:
        start = time.time()
        result = parallel.parallel_generate(text)
        report(text, result, (time.time() - start) * 1000)

    print("\n2. ITERATIVE REFINEMENT (2 iterations)")
    print("-" * 40)
    for text in test_texts:
        start = time.time()
        result = parallel.iterative_refinement(text, iterations=2)
        report(text, result, (time.time() - start) * 1000)

    print("\n3. AUTOREGRESSIVE (Original - 48 steps)")
    print("-" * 40)
    for text in test_texts:
        start = time.time()
        result = autoregressive.reconstruct(text, temperature=0.1)
        report(text, result, (time.time() - start) * 1000)


if __name__ == "__main__":
    test_parallel_vs_autoregressive()
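
# Usage sketch (assumes the default checkpoint path in ParallelTokenizer.__init__
# exists locally; pass an explicit path otherwise):
#
#   tok = ParallelTokenizer("checkpoints/v62/16.0/epoch_100.pt")
#   print(tok.parallel_generate("Hello, world!"))
#   print(tok.iterative_refinement("Hello, world!", iterations=3))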