| """
|
| Intelligent Tokenizer v6.2.0 - Unified Model
|
| Integrates encoder, decoder, and tokenizer with all GPT improvements
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from typing import Dict, List, Optional, Tuple, Union
|
| import math
|
|
|
|
|
| try:
|
| from .encoder import EncoderV62
|
| from .decoder import DecoderV62
|
| from .tokenizer import ByteTokenizerV62
|
| except ImportError:
|
|
|
| from encoder import EncoderV62
|
| from decoder import DecoderV62
|
| from tokenizer import ByteTokenizerV62
|
|
|
|
|
class IntelligentTokenizerV62(nn.Module):
    """
    Complete v6.2.0 model: byte tokenizer + encoder + decoder in one module.

    Design notes carried over from the original docstring (the mechanisms
    themselves live in the encoder/decoder/tokenizer modules, not here):

    - 48-byte chunks (46+2 with BOS/EOS)
    - Progressive splitting: 48 -> 1 -> N -> M tokens
    - Multi-level cross-attention
    - KV cache optimization (8x reduction)

    This class wires the three components together, computes the combined
    training loss, and provides generate / compress / checkpoint helpers.
    """

    # Compression ratio that the auxiliary compression loss pulls toward.
    # Override on the class or instance to change it.
    target_compression_ratio: float = 24.0
    # Label smoothing applied to the reconstruction cross-entropy.
    label_smoothing: float = 0.1

    def __init__(self, config: Optional[Dict] = None):
        super().__init__()

        self.config = config or {}

        # Sub-modules receive the raw `config` (possibly None), as before.
        self.tokenizer = ByteTokenizerV62(config)
        self.encoder = EncoderV62(config)
        self.decoder = DecoderV62(config)

        # Initial loss weights; update_training_state() re-tunes them.
        self.compression_weight = 0.1
        self.reconstruction_weight = 0.1
        self.boundary_weight = 0.1

        # Registered as buffers so they are included in state_dict().
        self.register_buffer('training_step', torch.tensor(0))
        self.register_buffer('current_epoch', torch.tensor(0))

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _encode_text(self, text: str, **encode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        """Tokenize *text* and return ``(input_ids, attention_mask, encoded)``.

        1-D id/mask tensors returned by the tokenizer get a leading batch
        dimension; tensors that already have one are returned unchanged.
        Extra keyword arguments are forwarded to ``tokenizer.encode``.
        """
        encoded = self.tokenizer.encode(text, add_special_tokens=True, **encode_kwargs)
        input_ids = encoded['input_ids']
        attention_mask = encoded['attention_mask']
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        if attention_mask.dim() == 1:
            attention_mask = attention_mask.unsqueeze(0)
        return input_ids, attention_mask, encoded

    @staticmethod
    def _teacher_forcing_inputs(labels: torch.Tensor,
                                attention_mask: Optional[torch.Tensor]):
        """Drop the last position so the decoder predicts ``labels[1:]`` from ``labels[:-1]``.

        Returns ``(decoder_input_ids, decoder_mask)``. The mask is None unless a
        batched (>= 2-D) attention mask was supplied.
        """
        decoder_input = labels[:, :-1] if labels.dim() > 1 else labels[:-1]
        if attention_mask is not None and attention_mask.dim() > 1:
            decoder_mask = attention_mask[:, :-1]
        else:
            decoder_mask = None
        return decoder_input, decoder_mask

    def _set_counter(self, name: str, value) -> None:
        """Write *value* into buffer *name* in place, keeping its device and dtype.

        Rebinding a buffer to ``torch.tensor(value)`` would silently move it to
        the CPU even when the model lives on an accelerator.
        """
        current = getattr(self, name, None)
        if torch.is_tensor(current):
            current.fill_(value)
        else:
            setattr(self, name, torch.tensor(value))

    @staticmethod
    def _sample_next_token(logits: torch.Tensor,
                           temperature: float,
                           top_k: int,
                           top_p: float) -> torch.Tensor:
        """Pick one token id per row of *logits* (``(batch, vocab)``).

        ``temperature <= 0`` selects greedy decoding (argmax). Otherwise the
        logits are temperature-scaled, optionally restricted to the ``top_k``
        best tokens and to the top-p (nucleus) set, then sampled with
        ``torch.multinomial``. Returns a ``(batch, 1)`` long tensor.
        """
        if temperature <= 0:
            # Dividing by zero would give inf/NaN probabilities.
            return logits.argmax(dim=-1, keepdim=True)

        logits = logits / temperature
        neg_inf = float('-inf')

        if top_k > 0:
            # torch.topk raises if k exceeds the vocabulary size, so clamp it.
            k = min(top_k, logits.size(-1))
            kth_best = torch.topk(logits, k, dim=-1).values[..., -1, None]
            logits = logits.masked_fill(logits < kth_best, neg_inf)

        if 0.0 < top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            # Probability mass of strictly more likely tokens. The best token
            # always has 0 here, so at least one token survives the filter.
            mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
            sorted_remove = mass_before > top_p
            remove = sorted_remove.scatter(-1, sorted_idx, sorted_remove)
            logits = logits.masked_fill(remove, neg_inf)

        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def forward(self,
                input_ids: Optional[Union[torch.Tensor, str]] = None,
                attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None,
                text: Optional[str] = None,
                return_loss: bool = True,
                temperature: float = 1.0) -> Dict[str, torch.Tensor]:
        """
        Unified forward pass: encode, decode, and optionally compute the loss.

        Args:
            input_ids: Pre-tokenized input, or a raw string (treated as ``text``).
            attention_mask: Attention mask for ``input_ids`` (optional).
            labels: Target ids for teacher forcing (optional). When omitted and
                ``return_loss`` is True, ``input_ids`` are used as the targets.
            text: Raw text input (takes precedence over ``input_ids``).
            return_loss: Whether to compute ``outputs['loss']``.
            temperature: Passed to the encoder (Gumbel-Softmax temperature).

        Returns:
            Dict with every encoder output prefixed ``enc_``, every decoder
            output prefixed ``dec_``, and ``loss`` when a loss was computed.
        """
        # A string passed as input_ids is treated as raw text.
        if text is None and isinstance(input_ids, str):
            text = input_ids
        if text is not None:
            input_ids, attention_mask, _ = self._encode_text(text)

        device = next(self.parameters()).device
        if torch.is_tensor(input_ids):
            input_ids = input_ids.to(device)
        if torch.is_tensor(attention_mask):
            attention_mask = attention_mask.to(device)
        if torch.is_tensor(labels):
            labels = labels.to(device)

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            temperature=temperature,
        )

        # With no explicit labels, training reconstructs the input itself.
        if labels is None and return_loss and input_ids is not None:
            labels = input_ids

        if labels is not None:
            decoder_input, decoder_mask = self._teacher_forcing_inputs(labels, attention_mask)
            decoder_outputs = self.decoder(
                encoder_all_hidden=encoder_outputs['all_hidden_states'],
                decoder_input_ids=decoder_input,
                attention_mask=decoder_mask,
            )
        else:
            decoder_outputs = self.decoder(
                encoder_all_hidden=encoder_outputs['all_hidden_states'],
                decoder_input_ids=None,
                attention_mask=attention_mask,
            )

        outputs = {f'enc_{key}': value for key, value in encoder_outputs.items()}
        outputs.update({f'dec_{key}': value for key, value in decoder_outputs.items()})

        if return_loss and labels is not None:
            outputs['loss'] = self.compute_loss(outputs, labels, attention_mask)

        return outputs

    def compute_loss(self,
                     outputs: Dict[str, torch.Tensor],
                     labels: torch.Tensor,
                     attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute the weighted sum of the available training objectives.

        Components (each only if the matching key is in *outputs*):
            1. Reconstruction: label-smoothed cross-entropy of ``dec_logits``
               against ``labels`` shifted left by one, ignoring PAD and
               positions where ``attention_mask[:, 1:]`` is zero.
            2. Compression: smooth-L1 between ``enc_compression_ratio`` and
               ``target_compression_ratio``.
            3. Boundary: sparsity plus smoothness regularizer on ``enc_boundaries``.

        The weighted components are stored in ``self.last_losses``. If no
        component applies, a zero tensor is returned.
        """
        losses: Dict[str, torch.Tensor] = {}

        if 'dec_logits' in outputs:
            logits = outputs['dec_logits']

            # The decoder was fed labels[:-1]; its targets are the next tokens.
            target_labels = labels[:, 1:] if labels.dim() > 1 else labels[1:]
            if attention_mask is not None and attention_mask.dim() > 1:
                target_mask = attention_mask[:, 1:]
            else:
                target_mask = None

            vocab_size = logits.size(-1)
            logits_flat = logits.reshape(-1, vocab_size)
            labels_flat = target_labels.reshape(-1)
            if target_mask is not None:
                keep = target_mask.reshape(-1).bool()
                logits_flat = logits_flat[keep]
                labels_flat = labels_flat[keep]

            pad_id = self.tokenizer.PAD
            if (labels_flat != pad_id).any():
                reconstruction_loss = F.cross_entropy(
                    logits_flat,
                    labels_flat,
                    ignore_index=pad_id,
                    label_smoothing=self.label_smoothing,
                )
            else:
                # Nothing to score: cross_entropy over an empty/all-ignored
                # selection returns NaN. Keep the graph attached with a zero.
                reconstruction_loss = logits.sum() * 0.0

            losses['reconstruction'] = reconstruction_loss * self.reconstruction_weight

        current_ratio = outputs.get('enc_compression_ratio')
        if current_ratio is not None:
            if torch.is_tensor(current_ratio):
                current_ratio = current_ratio.float()
            else:
                current_ratio = labels.new_tensor(current_ratio, dtype=torch.float32)
            # Matching shapes avoids PyTorch's input/target broadcast warning
            # when the encoder reports one ratio per sample.
            target_ratio = torch.full_like(current_ratio, self.target_compression_ratio)

            compression_loss = F.smooth_l1_loss(current_ratio, target_ratio, beta=2.0)
            losses['compression'] = compression_loss * self.compression_weight

        boundary_scores = outputs.get('enc_boundaries')
        if boundary_scores is not None:
            # Sparsity: discourage predicting boundaries everywhere.
            sparsity_loss = torch.sigmoid(boundary_scores).mean() * 0.1

            # Smoothness: penalize jumps between neighbouring positions (dim 1).
            if boundary_scores.dim() > 1 and boundary_scores.size(1) > 1:
                diff = boundary_scores[:, 1:] - boundary_scores[:, :-1]
                smoothness_loss = diff.pow(2).mean() * 0.01
            else:
                smoothness_loss = 0.0

            losses['boundary'] = (sparsity_loss + smoothness_loss) * self.boundary_weight

        if losses:
            total_loss = sum(losses.values())
        else:
            total_loss = labels.new_zeros((), dtype=torch.float32)

        self.last_losses = losses

        return total_loss

    def generate(self,
                 text: Optional[str] = None,
                 input_ids: Optional[torch.Tensor] = None,
                 max_length: int = 48,
                 temperature: float = 0.1,
                 top_k: int = 10,
                 top_p: float = 0.95) -> str:
        """
        Generate/reconstruct text autoregressively from the encoder's states.

        Args:
            text: Input text to encode and reconstruct. Text longer than
                ``tokenizer.content_size`` bytes is encoded in chunks.
            input_ids: Pre-encoded input, used when ``text`` is None
                (1-D input is treated as a batch of one).
            max_length: Maximum generated length, including the leading BOS.
            temperature: Sampling temperature; ``<= 0`` means greedy decoding.
            top_k: Keep only the k most likely tokens (0 disables).
            top_p: Nucleus sampling threshold (values outside (0, 1) disable it).

        Returns:
            The tokenizer's decoding of the first generated row. For chunked
            text, the tokenizer's reconstruction of all chunks is returned.

        Raises:
            ValueError: if neither ``text`` nor ``input_ids`` is given.
        """
        chunk_positions = None
        if text is not None:
            if len(text.encode('utf-8')) > self.tokenizer.content_size:
                input_ids, attention_mask, encoded = self._encode_text(text, return_chunks=True)
                chunk_positions = encoded.get('chunk_positions', None)
            else:
                input_ids, attention_mask, _ = self._encode_text(text)
        elif input_ids is None:
            raise ValueError("generate() requires either `text` or `input_ids`")
        else:
            if input_ids.dim() == 1:
                input_ids = input_ids.unsqueeze(0)
            attention_mask = (input_ids != self.tokenizer.PAD).bool()

        device = next(self.parameters()).device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        with torch.no_grad():
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )

            if 'all_hidden_states' in encoder_outputs:
                encoder_all_hidden = encoder_outputs['all_hidden_states']
            else:
                # Fallback: replicate a single representation for each decoder
                # cross-attention level (4 here, as in the original code).
                compressed = encoder_outputs.get('compressed', encoder_outputs.get('hidden_states'))
                encoder_all_hidden = [compressed] * 4

            generated_ids = torch.full(
                (input_ids.size(0), 1), self.tokenizer.BOS, dtype=torch.long, device=device
            )

            for _ in range(max_length - 1):
                decoder_outputs = self.decoder(
                    encoder_all_hidden=encoder_all_hidden,
                    decoder_input_ids=generated_ids,
                    attention_mask=torch.ones_like(generated_ids),
                    use_cache=False,
                )
                next_token = self._sample_next_token(
                    decoder_outputs['logits'][:, -1, :], temperature, top_k, top_p
                )
                generated_ids = torch.cat([generated_ids, next_token], dim=1)

                # Stop once every row emitted EOS at this step.
                if (next_token == self.tokenizer.EOS).all():
                    break

        # generated_ids is always (rows, seq_len). Chunked text yields one row
        # per chunk, and all of them must be stitched back together.
        # NOTE(review): assumes tokenizer.reconstruct accepts a 2-D
        # (num_chunks, seq_len) id tensor. Confirm against the tokenizer.
        if chunk_positions is not None and generated_ids.size(0) > 1:
            return self.tokenizer.reconstruct(
                generated_ids,
                positions=chunk_positions,
                overlap=self.tokenizer.chunk_overlap,
            )
        return self.tokenizer.decode(generated_ids[0])

    def compress(self, text: str) -> Dict[str, Union[torch.Tensor, float]]:
        """
        Encode *text* (without gradients) and report compression statistics.

        Args:
            text: Input text to compress.

        Returns:
            Dict with the encoder's ``compressed``, ``num_tokens`` and
            ``compression_ratio`` outputs, plus ``original_bytes`` (UTF-8
            length of *text*) and ``compressed_size`` (``num_tokens * 2``).
        """
        input_ids, attention_mask, _ = self._encode_text(text)

        device = next(self.parameters()).device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        with torch.no_grad():
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )

        num_tokens = encoder_outputs['num_tokens']
        return {
            'compressed': encoder_outputs['compressed'],
            'num_tokens': num_tokens,
            'compression_ratio': encoder_outputs['compression_ratio'],
            'original_bytes': len(text.encode('utf-8')),
            # NOTE(review): "* 2" presumably means 2 bytes per token id. Confirm.
            'compressed_size': num_tokens * 2,
        }

    def update_training_state(self, epoch: int, step: int = 0, reconstruction_loss: Optional[float] = None):
        """
        Update training counters and re-tune the loss weights adaptively.

        Args:
            epoch: Current epoch (stored in the ``current_epoch`` buffer).
            step: Current training step (stored in ``training_step`` and
                forwarded to ``encoder.set_warmup_step``).
            reconstruction_loss: Current reconstruction quality. When given,
                a loss above 1.0 raises the reconstruction weight to 1.0
                (otherwise 0.5), and the value is forwarded to
                ``encoder.adaptive_compression_control``.
        """
        # Update buffers in place so they keep the model's device/dtype.
        self._set_counter('current_epoch', epoch)
        self._set_counter('training_step', step)

        self.encoder.set_warmup_step(step)

        if reconstruction_loss is not None and reconstruction_loss > 1.0:
            # Reconstruction is still poor: emphasise it.
            self.reconstruction_weight = 1.0
        else:
            self.reconstruction_weight = 0.5
        self.compression_weight = 0.1
        self.boundary_weight = 0.1

        if reconstruction_loss is not None:
            self.encoder.adaptive_compression_control(reconstruction_loss)

    def get_model_stats(self) -> Dict[str, float]:
        """
        Collect monitoring statistics.

        Returns:
            Dict merging encoder monitoring stats (``encoder_*``), decoder
            memory usage (``decoder_*``), the most recent weighted loss terms
            (``loss_*``) and the epoch/step counters.
        """
        stats: Dict[str, float] = {}

        encoder_stats = self.encoder.get_monitoring_stats()
        stats.update({f'encoder_{k}': v for k, v in encoder_stats.items()})

        decoder_memory = self.decoder.get_memory_usage()
        stats.update({f'decoder_{k}': v for k, v in decoder_memory.items()})

        for k, v in getattr(self, 'last_losses', {}).items():
            if isinstance(v, torch.Tensor):
                stats[f'loss_{k}'] = v.item() if v.numel() == 1 else v.mean().item()
            else:
                stats[f'loss_{k}'] = float(v)

        stats['current_epoch'] = self.current_epoch.item()
        stats['training_step'] = self.training_step.item()

        return stats

    def save_checkpoint(self, path: str):
        """
        Save the state dict, config, counters and current stats to *path*.

        Args:
            path: Destination file for ``torch.save``.
        """
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'config': self.config,
            'epoch': self.current_epoch.item(),
            'step': self.training_step.item(),
            'stats': self.get_model_stats(),
        }
        torch.save(checkpoint, path)
        print(f"Checkpoint saved to {path}")

    @classmethod
    def from_checkpoint(cls, path: str, device: str = 'cuda'):
        """
        Load a model written by :meth:`save_checkpoint`.

        ``torch.load`` unpickles the file, so only load trusted checkpoints.

        Args:
            path: Path to the checkpoint.
            device: Device to map tensors to and move the model onto.

        Returns:
            The loaded model instance.
        """
        checkpoint = torch.load(path, map_location=device)

        model = cls(checkpoint.get('config', {}))
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)

        # Fill the counters in place so they stay on `device`.
        if 'epoch' in checkpoint:
            model._set_counter('current_epoch', checkpoint['epoch'])
        if 'step' in checkpoint:
            model._set_counter('training_step', checkpoint['step'])

        print(f"Model loaded from {path} (Epoch {checkpoint.get('epoch', 0)})")
        return model
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build a default model, compress and reconstruct a few
    # multilingual samples, then dump the monitoring statistics.
    print("Testing Intelligent Tokenizer v6.2.0")

    demo_model = IntelligentTokenizerV62()
    param_count = sum(p.numel() for p in demo_model.parameters())
    print(f"Model created with {param_count/1e6:.1f}M parameters")

    samples = (
        "Hello, world!",
        "안녕하세요, 만나서 반갑습니다. 오늘 날씨가 좋네요!",
        "今天天气很好。",
    )

    for sample in samples:
        print(f"\nInput: {sample}")

        result = demo_model.compress(sample)
        print(f" Compression ratio: {result['compression_ratio']:.1f}:1")
        print(f" Tokens: {result['num_tokens']}")

        print(f" Reconstructed: {demo_model.generate(sample, temperature=0.1)}")

    model_stats = demo_model.get_model_stats()
    print("\nModel Statistics:")
    for stat_name, stat_value in model_stats.items():
        shown = f"{stat_value:.4f}" if isinstance(stat_value, float) else f"{stat_value}"
        print(f" {stat_name}: {shown}")