"""
Parallel (Non-autoregressive) Inference for B2NL-IntelligentTokenizer

Faster inference by generating all 48 token positions in a single decoder
pass instead of decoding one token per step.
"""
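
# Quick usage sketch (assumes the default checkpoint path hard-coded in
# ParallelTokenizer.__init__ exists on your machine; pass checkpoint_path
# explicitly otherwise):
#
#     tok = ParallelTokenizer()
#     print(tok.parallel_generate("Hello, world!"))
#     print(tok.iterative_refinement("Hello, world!", iterations=2))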

import torch
import torch.nn.functional as F
import sys
import time
import io

# Force UTF-8 console output on Windows so multilingual test strings print correctly
if sys.platform == 'win32':
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
    sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')

# Make the project's core/ modules importable
sys.path.insert(0, 'core')

from unified_model import IntelligentTokenizerV62
from tokenizer import ByteTokenizerV62


class ParallelTokenizer:
    """Fast parallel generation (non-autoregressive)"""

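    # Two decoding strategies are provided:
    #   parallel_generate    - a single decoder pass over an all-MASK dummy input
    #   iterative_refinement - repeated passes that keep high-confidence tokens
    #                          and re-mask the rest (BERT-style masked prediction)
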
    def __init__(self, checkpoint_path: str = None):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if checkpoint_path is None:
            checkpoint_path = "D:/intelligent-tokenizer/intelligent-tokenizer_v6.2.1/checkpoints/v62/16.0/epoch_100.pt"

        self.model = IntelligentTokenizerV62()
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        print(f"Model loaded on {self.device}")

    def parallel_generate(self, text: str) -> str:
        """
        Parallel generation: predict all 48 token positions in one decoder pass.

        This reuses the decoder's teacher-forcing code path, but feeds it a
        dummy all-MASK sequence (BOS at position 0) instead of ground-truth
        tokens, so every position is predicted simultaneously.
        """
        tokenizer = self.model.tokenizer

        # Encode the input; the tokenizer may return a dict or a bare tensor
        encoded = tokenizer.encode(text)
        if isinstance(encoded, dict):
            input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
            attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']
        else:
            input_ids = encoded.unsqueeze(0) if encoded.dim() == 1 else encoded
            attention_mask = torch.ones_like(input_ids)

        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        with torch.no_grad():
            # One encoder pass over the input bytes
            encoder_outputs = self.model.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            # Prefer per-layer hidden states; otherwise reuse the compressed
            # representation for all 4 hidden-state slots the decoder expects
            if 'all_hidden_states' in encoder_outputs:
                encoder_all_hidden = encoder_outputs['all_hidden_states']
            else:
                compressed = encoder_outputs.get('compressed', encoder_outputs.get('hidden_states'))
                encoder_all_hidden = [compressed] * 4

            # Dummy decoder input: all MASK tokens, with BOS at position 0
            batch_size = input_ids.size(0)
            dummy_input = torch.full((batch_size, 48), tokenizer.MASK, device=self.device)
            dummy_input[:, 0] = tokenizer.BOS

            # Single decoder pass produces logits for all 48 positions at once
            decoder_outputs = self.model.decoder(
                encoder_all_hidden=encoder_all_hidden,
                decoder_input_ids=dummy_input,
                attention_mask=torch.ones_like(dummy_input),
                use_cache=False
            )

            logits = decoder_outputs['logits']

            # Greedy per-position selection
            predicted = torch.argmax(logits, dim=-1)

        # Decode the first sequence in the batch
        if predicted.dim() > 1:
            text = tokenizer.decode(predicted[0])
        else:
            text = tokenizer.decode(predicted)

        return text

    def iterative_refinement(self, text: str, iterations: int = 2) -> str:
        """
        Iterative refinement: decode several times, keeping high-confidence
        predictions and re-masking the rest, similar to BERT-style masked
        prediction.
        """
        tokenizer = self.model.tokenizer

        # Encode the input; the tokenizer may return a dict or a bare tensor
        encoded = tokenizer.encode(text)
        if isinstance(encoded, dict):
            input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
            attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']
        else:
            input_ids = encoded.unsqueeze(0) if encoded.dim() == 1 else encoded
            attention_mask = torch.ones_like(input_ids)

        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        with torch.no_grad():
            encoder_outputs = self.model.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            if 'all_hidden_states' in encoder_outputs:
                encoder_all_hidden = encoder_outputs['all_hidden_states']
            else:
                compressed = encoder_outputs.get('compressed', encoder_outputs.get('hidden_states'))
                encoder_all_hidden = [compressed] * 4

            batch_size = input_ids.size(0)

            # Start from a fully masked sequence with BOS at position 0
            current = torch.full((batch_size, 48), tokenizer.MASK, device=self.device)
            current[:, 0] = tokenizer.BOS

            # Refinement loop: each pass keeps confident predictions and
            # leaves the rest to be re-predicted
            for iteration in range(iterations):
                decoder_outputs = self.model.decoder(
                    encoder_all_hidden=encoder_all_hidden,
                    decoder_input_ids=current,
                    attention_mask=torch.ones_like(current),
                    use_cache=False
                )

                logits = decoder_outputs['logits']

                # Per-position confidence = max softmax probability
                probs = F.softmax(logits, dim=-1)
                confidence = torch.max(probs, dim=-1)[0]

                # Accept tokens above a threshold that relaxes each iteration
                # (0.7 on the first pass, 0.6 on the second, ...)
                threshold = 0.7 - (iteration * 0.1)
                high_conf_mask = confidence > threshold

                new_tokens = torch.argmax(logits, dim=-1)
                current = torch.where(high_conf_mask, new_tokens, current)

                # Between passes, randomly re-mask ~10% of positions so they
                # can be re-predicted with more context (BOS is kept fixed)
                if iteration < iterations - 1:
                    rand_mask = torch.rand_like(confidence) < 0.1
                    rand_mask[:, 0] = False
                    current = torch.where(rand_mask, torch.full_like(current, tokenizer.MASK), current)

        if current.dim() > 1:
            text = tokenizer.decode(current[0])
        else:
            text = tokenizer.decode(current)

        return text
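

# Small helper factored out of the three comparison loops below (the name
# _char_accuracy is new here): character-level accuracy over the overlapping
# prefix of the input and its reconstruction, as a percentage.
def _char_accuracy(src: str, out: str) -> float:
    if not src:
        return 0.0
    return sum(1 for a, b in zip(src, out) if a == b) / len(src) * 100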


def test_parallel_vs_autoregressive():
    """Compare parallel, iterative-refinement, and autoregressive generation"""
    print("=" * 60)
    print("Parallel vs Autoregressive Comparison")
    print("=" * 60)

    parallel = ParallelTokenizer()

    # Baseline: the original autoregressive tokenizer
    from inference import B2NLTokenizer
    autoregressive = B2NLTokenizer()

    test_texts = [
        "Hello, world!",
        "The quick brown fox",
        "안녕하세요, 반갑습니다.",  # Korean: "Hello, nice to meet you."
        "Testing 123",
    ]

print("\n1. PARALLEL GENERATION (Single Pass)")
|
|
|
print("-"*40)
|
|
|
for text in test_texts:
|
|
|
start = time.time()
|
|
|
result = parallel.parallel_generate(text)
|
|
|
elapsed = (time.time() - start) * 1000
|
|
|
|
|
|
accuracy = sum(1 for i in range(min(len(text), len(result))) if text[i] == result[i]) / len(text) * 100
|
|
|
print(f"Input: {text}")
|
|
|
print(f"Output: {result}")
|
|
|
print(f"Accuracy: {accuracy:.1f}%, Time: {elapsed:.1f}ms\n")
|
|
|
|
|
|
print("\n2. ITERATIVE REFINEMENT (2 iterations)")
|
|
|
print("-"*40)
|
|
|
for text in test_texts:
|
|
|
start = time.time()
|
|
|
result = parallel.iterative_refinement(text, iterations=2)
|
|
|
elapsed = (time.time() - start) * 1000
|
|
|
|
|
|
accuracy = sum(1 for i in range(min(len(text), len(result))) if text[i] == result[i]) / len(text) * 100
|
|
|
print(f"Input: {text}")
|
|
|
print(f"Output: {result}")
|
|
|
print(f"Accuracy: {accuracy:.1f}%, Time: {elapsed:.1f}ms\n")
|
|
|
|
|
|
print("\n3. AUTOREGRESSIVE (Original - 48 steps)")
|
|
|
print("-"*40)
|
|
|
for text in test_texts:
|
|
|
start = time.time()
|
|
|
result = autoregressive.reconstruct(text, temperature=0.1)
|
|
|
elapsed = (time.time() - start) * 1000
|
|
|
|
|
|
accuracy = sum(1 for i in range(min(len(text), len(result))) if text[i] == result[i]) / len(text) * 100
|
|
|
print(f"Input: {text}")
|
|
|
print(f"Output: {result}")
|
|
|
print(f"Accuracy: {accuracy:.1f}%, Time: {elapsed:.1f}ms\n")
|
|
|
|
|
|
|
|
|

if __name__ == "__main__":
    test_parallel_vs_autoregressive()