"""
Intelligent Tokenizer v6.0 - Inference Module

Embedding and restoration functionality.
"""

import io
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch

# Force UTF-8 output so the multilingual test strings print correctly,
# e.g. on Windows consoles that default to a legacy code page.
if sys.stdout.encoding != 'utf-8':
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
    sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')

# Make the local packages importable when this file is run directly.
sys.path.append(str(Path(__file__).parent))

from core.boundary_aware_model import BoundaryAwareTokenizerModel
from src.core.byte_tokenizer_v6 import ByteTokenizerV6


class IntelligentTokenizer:
    """Intelligent Tokenizer for embedding and restoration."""

    def __init__(self, checkpoint_path: str = "checkpoints/latest_checkpoint.pt",
                 device: Optional[str] = None):
        """
        Initialize the tokenizer.

        Args:
            checkpoint_path: Path to the model checkpoint
            device: Device to use ('cuda', 'cpu', or None for auto-detection)
        """
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        print("Initializing Intelligent Tokenizer v6.0...")
        print(f"Device: {self.device}")

        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        # weights_only=False: the checkpoint stores the model config
        # alongside the weights, not just tensors.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

        self.model = BoundaryAwareTokenizerModel(**checkpoint['model_config'])
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        self.tokenizer = ByteTokenizerV6()
        # Maximum bytes per chunk; longer inputs are split at UTF-8 boundaries.
        self.max_chunk_size = 250

        print(f"Model loaded: Epoch {checkpoint['epoch']}, Loss {checkpoint['loss']:.4f}")
        print("Ready for inference!")

    def embed(self, text: str) -> torch.Tensor:
        """
        Convert text to embeddings.

        Args:
            text: Input text

        Returns:
            Embedding tensor of shape (1, seq_len, hidden_dim)
        """
        # Long inputs are split into UTF-8-safe chunks, embedded separately,
        # and concatenated along the sequence dimension.
        if len(text.encode('utf-8')) > self.max_chunk_size:
            chunks = self._split_text_safely(text)
            embeddings = [self._embed_single(chunk) for chunk in chunks]
            return torch.cat(embeddings, dim=1)
        else:
            return self._embed_single(text)

    def _embed_single(self, text: str) -> torch.Tensor:
        """Embed a single chunk."""
        encoded = self.tokenizer.encode(text)
        input_ids = torch.tensor([encoded['input_ids']], device=self.device)
        attention_mask = torch.tensor([encoded['attention_mask']], device=self.device)

        with torch.no_grad():
            encoder_outputs = self.model.encoder(input_ids, attention_mask)
            embeddings = encoder_outputs['last_hidden_state']

        return embeddings

    def restore(self, text: str) -> Tuple[str, float]:
        """
        Test restoration capability.

        Args:
            text: Input text

        Returns:
            Tuple of (restored_text, accuracy)
        """
        if len(text.encode('utf-8')) > self.max_chunk_size:
            chunks = self._split_text_safely(text)
            restored_chunks = []
            accuracies = []

            for chunk in chunks:
                restored, acc = self._restore_single(chunk)
                restored_chunks.append(restored)
                accuracies.append(acc)

            return ''.join(restored_chunks), sum(accuracies) / len(accuracies)
        else:
            return self._restore_single(text)

    def _restore_single(self, text: str) -> Tuple[str, float]:
        """Restore a single chunk via teacher-forced decoding."""
        encoded = self.tokenizer.encode(text)
        byte_ids = encoded['input_ids']

        # Nothing to predict for empty or single-token sequences.
        if len(byte_ids) <= 1:
            return text, 1.0

        input_ids = torch.tensor([byte_ids], device=self.device)
        attention_mask = torch.tensor([encoded['attention_mask']], device=self.device)

        with torch.no_grad():
            # Teacher forcing: the decoder sees tokens [0..n-1] and is
            # scored on predicting tokens [1..n].
            decoder_input = input_ids[:, :-1]
            labels = input_ids[:, 1:]

            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input,
                labels=labels,
                use_cross_attention=True
            )

            predictions = torch.argmax(outputs['logits'], dim=-1)
            accuracy = (predictions == labels).float().mean().item()

        try:
            pred_list = predictions[0].cpu().tolist()

            # Re-attach BOS so the sequence lines up with the original input;
            # special-token IDs fall outside 0-255 and are dropped by the
            # filter, leaving only raw byte values.
            full_sequence = [self.tokenizer.BOS] + pred_list
            filtered = [b for b in full_sequence if 0 <= b < 256]
            if filtered:
                restored_text = bytes(filtered).decode('utf-8', errors='ignore')
            else:
                restored_text = ""
        except Exception as e:
            print(f"Restoration error: {e}")
            restored_text = ""

        return restored_text, accuracy

    def compress(self, text: str) -> Dict:
        """
        Get compression statistics.

        Args:
            text: Input text

        Returns:
            Dict with compression info
        """
        text_bytes = text.encode('utf-8')
        embeddings = self.embed(text)

        original_size = len(text_bytes)
        compressed_size = embeddings.shape[1]
        compression_ratio = original_size / compressed_size if compressed_size > 0 else 0

        return {
            'original_bytes': original_size,
            'compressed_tokens': compressed_size,
            'compression_ratio': compression_ratio,
            'embedding_shape': list(embeddings.shape)
        }
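
    # Illustrative `compress` return value, shown only to document the dict
    # layout (all numbers are made up; actual values depend on the input text
    # and the loaded checkpoint):
    #   {'original_bytes': 13, 'compressed_tokens': 5,
    #    'compression_ratio': 2.6, 'embedding_shape': [1, 5, 768]}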

    def _split_text_safely(self, text: str) -> List[str]:
        """Split text into chunks at valid UTF-8 character boundaries."""
        chunks = []
        text_bytes = text.encode('utf-8')

        start = 0
        while start < len(text_bytes):
            end = min(start + self.max_chunk_size, len(text_bytes))

            # Back off byte by byte until the slice ends on a complete
            # UTF-8 character.
            while end > start and end < len(text_bytes):
                try:
                    text_bytes[start:end].decode('utf-8')
                    break
                except UnicodeDecodeError:
                    end -= 1

            if end > start:
                chunks.append(text_bytes[start:end].decode('utf-8'))
                start = end
            else:
                break

        return chunks
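
# Minimal usage sketch (assumes a trained checkpoint exists at the default
# path; shapes and numbers are indicative only):
#
#   tokenizer = IntelligentTokenizer()
#   emb = tokenizer.embed("Hello, world!")       # (1, seq_len, hidden_dim)
#   restored, acc = tokenizer.restore("Hello")   # round-trip fidelity check
#   stats = tokenizer.compress("Hello, world!")  # compression statistics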


def test_model():
    """Test model functionality."""
    print("="*70)
    print("INTELLIGENT TOKENIZER v6.0 - FUNCTIONALITY TEST")
    print("="*70)

    tokenizer = IntelligentTokenizer()

    test_samples = [
        ("English", "Hello, world!"),
        ("Korean", "안녕하세요. 반갑습니다."),
        ("Chinese", "今天天气很好"),
        ("Japanese", "こんにちは"),
        ("Arabic", "مرحبا بك"),
        ("Russian", "Привет, как дела?"),
        ("Emoji", "Hello 👋 World 🌍!"),
    ]

    print("\n" + "="*70)
    print("EMBEDDING & RESTORATION TESTS")
    print("="*70)

    total_accuracy = 0
    successful = 0

    for lang, text in test_samples:
        print(f"\n[{lang}]")
        print(f"Original: {text}")

        embeddings = tokenizer.embed(text)
        print(f"Embedding: {embeddings.shape}")

        compression = tokenizer.compress(text)
        print(f"Compression: {compression['original_bytes']} bytes → {compression['compressed_tokens']} tokens")
        print(f"Ratio: {compression['compression_ratio']:.2f}x")

        restored, accuracy = tokenizer.restore(text)
        print(f"Restored: {restored}")
        print(f"Accuracy: {accuracy:.1%}")

        # A sample passes above 70% byte accuracy; the running total feeds
        # the overall average below.
        if accuracy > 0.7:
            successful += 1
        total_accuracy += accuracy

    print("\n" + "="*70)
    print("TEST SUMMARY")
    print("="*70)
    print(f"Tests passed: {successful}/{len(test_samples)}")
    print(f"Average accuracy: {total_accuracy/len(test_samples):.1%}")

    if successful == len(test_samples):
        print("\n✅ ALL TESTS PASSED!")
        return True
    elif successful >= len(test_samples) * 0.7:
        print("\n⚠️ PARTIAL SUCCESS (70%+ tests passed)")
        return True
    else:
        print("\n❌ TESTS FAILED")
        return False


if __name__ == "__main__":
    success = test_model()
    sys.exit(0 if success else 1)