ismail / Model_Architecture /generation.py
ikaganacar's picture
Test Model
18a94f8
import torch
import tiktoken
from model import ismail, ModelArgs
from data import TurkishTokenizerWrapper, TURKISH_TOKENIZER_AVAILABLE
#####################################
# TEXT GENERATION FUNCTIONS
#####################################
def generate_text_simple(model, idx, max_new_tokens, context_size):
    """
    Autoregressive greedy decoding: always pick the highest-scoring token.

    On every step the trailing ``context_size`` tokens are passed to the
    model, the logits at the final position are reduced with argmax, and the
    chosen token is appended to the running sequence.

    Args:
        model: Callable that maps a (batch, n_tokens) index tensor to logits
            of shape (batch, n_tokens, vocab_size)
        idx: Input token indices of shape (batch_size, seq_len)
        max_new_tokens: Number of new tokens to generate
        context_size: Maximum number of trailing tokens fed to the model

    Returns:
        Token indices of shape (batch_size, seq_len + max_new_tokens)
    """
    tokens = idx
    for _step in range(max_new_tokens):
        # Keep only the most recent tokens the model can attend to,
        # e.g. with context_size=5 and 10 tokens so far, the last 5 are used
        window = tokens[:, -context_size:]

        # Forward pass without building an autograd graph
        with torch.no_grad():
            scores = model(window)

        # Last position only: (batch, n_tokens, vocab_size) -> (batch, 1)
        best = scores[:, -1, :].argmax(dim=-1, keepdim=True)

        # Grow the sequence by one column: (batch, n_tokens + 1)
        tokens = torch.cat([tokens, best], dim=1)
    return tokens
def generate_text_with_sampling(model, idx, max_new_tokens, context_size, temperature=1.0, top_k=None):
    """
    Generate text using sampling with temperature and optional top-k filtering.

    Args:
        model: Callable that maps a (batch, n_tokens) index tensor to logits
            of shape (batch, n_tokens, vocab_size)
        idx: Input token indices of shape (batch_size, seq_len)
        max_new_tokens: Number of new tokens to generate
        context_size: Maximum context size the model can handle
        temperature: Sampling temperature (higher = more random, lower = more
            deterministic). Values below 1e-8 are clamped to 1e-8, which makes
            sampling effectively greedy.
        top_k: If set, only sample from the top k most likely tokens. Values
            larger than the vocabulary size keep every token.

    Returns:
        Generated token indices of shape (batch_size, seq_len + max_new_tokens)

    Raises:
        ValueError: If top_k is given but is smaller than 1
    """
    if top_k is not None and top_k < 1:
        raise ValueError(f"top_k must be a positive integer or None, got {top_k}")

    # Loop-invariant: clamp once to avoid division by zero / very small numbers
    temperature = max(temperature, 1e-8)

    for _ in range(max_new_tokens):
        # Crop current context if it exceeds the supported context size
        idx_cond = idx[:, -context_size:]

        # Get the predictions
        with torch.no_grad():
            logits = model(idx_cond)

        # Focus only on the last time step. Scale in float32 so that a tiny
        # temperature cannot overflow half-precision logits to inf.
        logits = logits[:, -1, :].float() / temperature

        # Optional: keep only the k highest logits per row
        if top_k is not None:
            k = min(top_k, logits.size(-1))
            kth_best = torch.topk(logits, k).values[:, [-1]]
            logits = logits.masked_fill(logits < kth_best, float("-inf"))

        probs = torch.softmax(logits, dim=-1)

        # A row containing NaN/inf logits produces non-finite probabilities.
        # Fall back to a uniform distribution for those rows ONLY, leaving
        # the valid rows of the batch untouched.
        invalid_rows = ~torch.isfinite(probs).all(dim=-1, keepdim=True)
        if invalid_rows.any():
            uniform = torch.full_like(probs, 1.0 / probs.size(-1))
            probs = torch.where(invalid_rows, uniform, probs)

        # Ensure probabilities sum to 1
        probs = probs / probs.sum(dim=-1, keepdim=True)

        # Sample from the distribution
        idx_next = torch.multinomial(probs, num_samples=1)

        # Append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)
    return idx
def text_to_token_ids(text, tokenizer):
    """
    Convert text to token IDs.

    Args:
        text: Input text string
        tokenizer: Tokenizer instance (tiktoken or TurkishTokenizerWrapper)

    Returns:
        Integer (torch.long) tensor of token IDs with shape (1, seq_len);
        an empty encoding yields shape (1, 0)
    """
    # Turkish tokenizer doesn't support allowed_special parameter
    if isinstance(tokenizer, TurkishTokenizerWrapper):
        encoded = tokenizer.encode(text)
    else:
        encoded = tokenizer.encode(text, allowed_special={"<|endoftext|>"})

    # Pin the dtype: torch.tensor([]) would otherwise infer float32, which is
    # not a valid index tensor for the model or for torch.cat with sampled ids
    return torch.tensor(encoded, dtype=torch.long).unsqueeze(0)
def token_ids_to_text(token_ids, tokenizer):
    """
    Decode a tensor of token IDs back into a string.

    Args:
        token_ids: Tensor of token IDs, either 1D or 2D. For a 2D tensor the
            leading dimension is squeezed away (which only has an effect when
            that dimension has size 1).
        tokenizer: Tokenizer instance (tiktoken or TurkishTokenizerWrapper)
            exposing a decode() method

    Returns:
        Whatever tokenizer.decode returns for the resulting Python list
    """
    # Drop the batch dimension for 2D input; other ranks pass through as-is
    ids = token_ids.squeeze(0) if token_ids.dim() == 2 else token_ids
    return tokenizer.decode(ids.tolist())
def get_tokenizer(use_turkish=False, tokenizer_name="gpt2"):
    """
    Build the tokenizer selected by the caller and announce it on stdout.

    Args:
        use_turkish: When true, return a TurkishTokenizerWrapper; otherwise a
            tiktoken encoding
        tokenizer_name: Name passed to tiktoken.get_encoding (ignored when
            use_turkish is true)

    Returns:
        Tokenizer instance (TurkishTokenizerWrapper or tiktoken tokenizer)

    Raises:
        ImportError: If the Turkish tokenizer is requested while
            TURKISH_TOKENIZER_AVAILABLE is false
    """
    # tiktoken path first: nothing to validate
    if not use_turkish:
        tokenizer = tiktoken.get_encoding(tokenizer_name)
        print(f"📚 Using tiktoken tokenizer: {tokenizer_name} (vocab size: {tokenizer.n_vocab:,})")
        return tokenizer

    # Turkish path: fail early when the optional dependency is missing
    if not TURKISH_TOKENIZER_AVAILABLE:
        raise ImportError(
            "Turkish tokenizer requested but not available. "
            "Install it with: pip install turkish-tokenizer"
        )

    tokenizer = TurkishTokenizerWrapper()
    print(f"🇹🇷 Using Turkish Tokenizer (vocab size: {tokenizer.n_vocab:,})")
    return tokenizer
#####################################
# CHECKPOINT LOADING AND EXAMPLE USAGE
#####################################
def load_checkpoint(model, checkpoint_path):
    """
    Restore model weights from a checkpoint file and report what was loaded.

    Two layouts are accepted: a dictionary holding the weights under the
    'model_state_dict' key (its 'step' and 'loss' entries are printed when
    present), or a bare state dict that is passed to the model directly.

    Args:
        model: Model instance whose load_state_dict() receives the weights
        checkpoint_path: Path to the checkpoint file (.pt)

    Returns:
        The object returned by torch.load, i.e. the full checkpoint
    """
    print(f"\n📦 Loading checkpoint: {checkpoint_path}")

    # Tensors are mapped onto the CPU regardless of where they were saved.
    # NOTE(review): torch.load unpickles the file - only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Bare state dict: hand it to the model as-is
    if 'model_state_dict' not in checkpoint:
        model.load_state_dict(checkpoint)
        print(f"✅ Loaded model state (direct)")
        return checkpoint

    # Wrapped checkpoint: weights plus optional training metadata
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"✅ Loaded model state from checkpoint")
    if 'step' in checkpoint:
        print(f" Training step: {checkpoint['step']:,}")
    if 'loss' in checkpoint:
        print(f" Loss: {checkpoint['loss']:.4f}")
    return checkpoint
if __name__ == "__main__":
    # Demo script: build the tokenizer and model (optionally restoring a
    # checkpoint), then run greedy, temperature and top-k generation examples.
    import json
    from pathlib import Path
    import sys

    # Configuration: Set to True to use Turkish tokenizer, False for tiktoken
    USE_TURKISH_TOKENIZER = True  # Change this to False for English text generation

    # ===== CHECKPOINT LOADING =====
    # Set this to the path of your trained checkpoint
    # Example: CHECKPOINT_PATH = "./checkpoints/step_55000_expert_2.pt"
    CHECKPOINT_PATH = None  # Set to None to use random initialization

    # You can also pass checkpoint path as command line argument
    if len(sys.argv) > 1:
        CHECKPOINT_PATH = sys.argv[1]
        print(f"🔧 Using checkpoint from command line: {CHECKPOINT_PATH}")

    # Example configuration - smaller model for testing
    config_path = Path("config.json")
    if config_path.exists():
        # Explicit encoding so parsing does not depend on the platform locale
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        print(f"✅ Loaded config from {config_path}")
        args = ModelArgs(**config["model"])
    else:
        print("⚠️ config.json not found, using default ModelArgs")
        args = ModelArgs()

    # Initialize tokenizer; a missing or empty tokenizer_name falls back to "gpt2"
    tokenizer_name = getattr(args, "tokenizer_name", None) or "gpt2"

    # Auto-detect Turkish tokenizer from config (or force it via the constant)
    use_turkish = (tokenizer_name.lower() == "turkish") or USE_TURKISH_TOKENIZER
    tokenizer = get_tokenizer(
        use_turkish=use_turkish,
        tokenizer_name="gpt2" if use_turkish else tokenizer_name
    )

    # Update vocab size if using Turkish tokenizer
    if use_turkish and isinstance(tokenizer, TurkishTokenizerWrapper):
        if args.vocab_size != tokenizer.n_vocab:
            print(f"⚠️ Config vocab_size ({args.vocab_size:,}) doesn't match tokenizer ({tokenizer.n_vocab:,})")
            args.vocab_size = tokenizer.n_vocab
            print(f"📊 Updated vocab_size to {args.vocab_size:,} for Turkish tokenizer")

    # Initialize model
    print("\n🚀 Initializing model...")
    torch.manual_seed(123)
    model = ismail(args)

    # Load checkpoint if specified; a missing file falls back to random weights
    if CHECKPOINT_PATH:
        checkpoint_file = Path(CHECKPOINT_PATH)
        if checkpoint_file.exists():
            load_checkpoint(model, checkpoint_file)
        else:
            print(f"❌ Checkpoint not found: {CHECKPOINT_PATH}")
            print(" Using random initialization instead")
    else:
        print("ℹ️ No checkpoint specified, using random initialization")
    model.eval()

    # Example 1: Greedy generation (argmax)
    print(f"\n{'='*60}")
    print("EXAMPLE 1: GREEDY GENERATION (ARGMAX)")
    print(f"{'='*60}")

    # Pick the prompt language from the tokenizer actually in use
    # (use_turkish), not from the raw constant, so a config-selected Turkish
    # tokenizer also gets Turkish prompts
    if use_turkish:
        start_context = "Merhaba, ben"
    else:
        start_context = "Hello, I am"
    print(f"\nInput: '{start_context}'")
    token_ids = text_to_token_ids(start_context, tokenizer)
    print(f"Token IDs shape: {token_ids.shape}")
    generated_ids = generate_text_simple(
        model=model,
        idx=token_ids,
        max_new_tokens=20,
        context_size=args.max_seq_len
    )
    generated_text = token_ids_to_text(generated_ids, tokenizer)
    print(f"\nGenerated: '{generated_text}'")
    print(f"Total tokens: {generated_ids.shape[1]}")

    # Example 2: Sampling with temperature
    print(f"\n{'='*60}")
    print("EXAMPLE 2: SAMPLING WITH TEMPERATURE")
    print(f"{'='*60}")
    if use_turkish:
        start_context = "Bir varmış bir yokmuş"
    else:
        start_context = "Once upon a time"
    print(f"\nInput: '{start_context}'")
    token_ids = text_to_token_ids(start_context, tokenizer)

    # Generate with different temperatures
    for temp in [0.5, 1.0, 1.5]:
        print(f"\n--- Temperature: {temp} ---")
        generated_ids = generate_text_with_sampling(
            model=model,
            idx=token_ids.clone(),
            max_new_tokens=20,
            context_size=args.max_seq_len,
            temperature=temp
        )
        generated_text = token_ids_to_text(generated_ids, tokenizer)
        print(f"Generated: '{generated_text}'")

    # Example 3: Top-k sampling
    print(f"\n{'='*60}")
    print("EXAMPLE 3: TOP-K SAMPLING")
    print(f"{'='*60}")
    if use_turkish:
        start_context = "Yapay zekanın geleceği"
    else:
        start_context = "The future of AI is"
    print(f"\nInput: '{start_context}'")
    token_ids = text_to_token_ids(start_context, tokenizer)
    generated_ids = generate_text_with_sampling(
        model=model,
        idx=token_ids,
        max_new_tokens=30,
        context_size=args.max_seq_len,
        temperature=0.8,
        top_k=50
    )
    generated_text = token_ids_to_text(generated_ids, tokenizer)
    print(f"Generated: '{generated_text}'")

    print(f"\n{'='*60}")
    print("Generation examples completed!")
    print(f"{'='*60}\n")