Commit ·
18a94f8
1
Parent(s): a050405
Test Model
Browse files- Model_Architecture/generation.py +54 -1
- Model_Architecture/test_model.py +224 -0
Model_Architecture/generation.py
CHANGED
|
@@ -171,13 +171,54 @@ def get_tokenizer(use_turkish=False, tokenizer_name="gpt2"):
|
|
| 171 |
# EXAMPLE USAGE
|
| 172 |
#####################################
|
| 173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
if __name__ == "__main__":
|
| 175 |
import json
|
| 176 |
from pathlib import Path
|
|
|
|
| 177 |
|
| 178 |
# Configuration: Set to True to use Turkish tokenizer, False for tiktoken
|
| 179 |
USE_TURKISH_TOKENIZER = True # Change this to False for English text generation
|
| 180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
# Example configuration - smaller model for testing
|
| 182 |
config_path = Path("config.json")
|
| 183 |
if config_path.exists():
|
|
@@ -207,9 +248,21 @@ if __name__ == "__main__":
|
|
| 207 |
print(f"📊 Updated vocab_size to {args.vocab_size:,} for Turkish tokenizer")
|
| 208 |
|
| 209 |
# Initialize model
|
| 210 |
-
print("Initializing model...")
|
| 211 |
torch.manual_seed(123)
|
| 212 |
model = ismail(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
model.eval()
|
| 214 |
|
| 215 |
# Example 1: Greedy generation (argmax)
|
|
|
|
| 171 |
# EXAMPLE USAGE
|
| 172 |
#####################################
|
| 173 |
|
| 174 |
+
def load_checkpoint(model, checkpoint_path):
    """Load a trained checkpoint into the model.

    Supports two on-disk formats: a training checkpoint dict containing a
    'model_state_dict' key (plus optional 'step'/'loss' scalar metadata),
    or a bare state dict saved directly via torch.save(model.state_dict(), ...).

    Args:
        model: The model instance; its weights are loaded in place.
        checkpoint_path: Path to the checkpoint file (.pt)

    Returns:
        The loaded checkpoint dictionary with metadata (or the raw state
        dict when the file holds one directly).
    """
    print(f"\n📦 Loading checkpoint: {checkpoint_path}")
    # Prefer weights_only=True: it refuses to execute arbitrary pickled
    # code embedded in the file (and is the default in newer PyTorch).
    # Fall back to a full unpickling load only for legacy checkpoints
    # that stored non-tensor Python objects.
    try:
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    except Exception:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Handle different checkpoint formats
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        print("✅ Loaded model state from checkpoint")
        if 'step' in checkpoint:
            print(f"   Training step: {checkpoint['step']:,}")
        if 'loss' in checkpoint:
            print(f"   Loss: {checkpoint['loss']:.4f}")
    else:
        # Direct state dict (an OrderedDict of tensors)
        model.load_state_dict(checkpoint)
        print("✅ Loaded model state (direct)")

    return checkpoint
|
| 202 |
+
|
| 203 |
+
|
| 204 |
if __name__ == "__main__":
|
| 205 |
import json
|
| 206 |
from pathlib import Path
|
| 207 |
+
import sys
|
| 208 |
|
| 209 |
# Configuration: Set to True to use Turkish tokenizer, False for tiktoken
|
| 210 |
USE_TURKISH_TOKENIZER = True # Change this to False for English text generation
|
| 211 |
|
| 212 |
+
# ===== CHECKPOINT LOADING =====
|
| 213 |
+
# Set this to the path of your trained checkpoint
|
| 214 |
+
# Example: CHECKPOINT_PATH = "./checkpoints/step_55000_expert_2.pt"
|
| 215 |
+
CHECKPOINT_PATH = None # Set to None to use random initialization
|
| 216 |
+
|
| 217 |
+
# You can also pass checkpoint path as command line argument
|
| 218 |
+
if len(sys.argv) > 1:
|
| 219 |
+
CHECKPOINT_PATH = sys.argv[1]
|
| 220 |
+
print(f"🔧 Using checkpoint from command line: {CHECKPOINT_PATH}")
|
| 221 |
+
|
| 222 |
# Example configuration - smaller model for testing
|
| 223 |
config_path = Path("config.json")
|
| 224 |
if config_path.exists():
|
|
|
|
| 248 |
print(f"📊 Updated vocab_size to {args.vocab_size:,} for Turkish tokenizer")
|
| 249 |
|
| 250 |
# Initialize model
|
| 251 |
+
print("\n🚀 Initializing model...")
|
| 252 |
torch.manual_seed(123)
|
| 253 |
model = ismail(args)
|
| 254 |
+
|
| 255 |
+
# Load checkpoint if specified
|
| 256 |
+
if CHECKPOINT_PATH:
|
| 257 |
+
checkpoint_file = Path(CHECKPOINT_PATH)
|
| 258 |
+
if checkpoint_file.exists():
|
| 259 |
+
load_checkpoint(model, checkpoint_file)
|
| 260 |
+
else:
|
| 261 |
+
print(f"❌ Checkpoint not found: {CHECKPOINT_PATH}")
|
| 262 |
+
print(" Using random initialization instead")
|
| 263 |
+
else:
|
| 264 |
+
print("ℹ️ No checkpoint specified, using random initialization")
|
| 265 |
+
|
| 266 |
model.eval()
|
| 267 |
|
| 268 |
# Example 1: Greedy generation (argmax)
|
Model_Architecture/test_model.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Interactive script to test your trained ismAIl model.
|
| 4 |
+
Load a checkpoint and generate text with custom prompts.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import json
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import sys
|
| 11 |
+
from model import ismail, ModelArgs
|
| 12 |
+
from generation import (
|
| 13 |
+
generate_text_simple,
|
| 14 |
+
generate_text_with_sampling,
|
| 15 |
+
text_to_token_ids,
|
| 16 |
+
token_ids_to_text,
|
| 17 |
+
get_tokenizer,
|
| 18 |
+
load_checkpoint
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def interactive_generation(model, tokenizer, args):
    """Interactive mode: continuously prompt for text and generate responses.

    Reads prompts from stdin in a loop; special inputs 'quit'/'exit'/'q'
    stop the loop and 'params' opens an inline editor for the generation
    parameters. Only ``args.max_seq_len`` is read from *args* here.
    """
    print("\n" + "="*60)
    print("🎤 INTERACTIVE GENERATION MODE")
    print("="*60)
    print("Commands:")
    print(" - Type your prompt and press Enter to generate")
    print(" - Type 'quit' or 'exit' to stop")
    print(" - Type 'params' to change generation parameters")
    print("="*60 + "\n")

    # Default generation parameters
    temperature = 0.8
    top_k = 50
    max_tokens = 50
    use_sampling = True

    while True:
        try:
            prompt = input("\n💬 Prompt: ").strip()

            if prompt.lower() in ['quit', 'exit', 'q']:
                print("👋 Goodbye!")
                break

            if prompt.lower() == 'params':
                print("\n⚙️ Current parameters:")
                print(f" Temperature: {temperature}")
                print(f" Top-k: {top_k}")
                print(f" Max tokens: {max_tokens}")
                print(f" Use sampling: {use_sampling}")

                # Blank input keeps the current value; float()/int() raise
                # ValueError on malformed input, caught below.
                try:
                    temp_input = input(f" New temperature (current: {temperature}): ").strip()
                    if temp_input:
                        temperature = float(temp_input)

                    topk_input = input(f" New top-k (current: {top_k}): ").strip()
                    if topk_input:
                        top_k = int(topk_input)

                    tokens_input = input(f" New max tokens (current: {max_tokens}): ").strip()
                    if tokens_input:
                        max_tokens = int(tokens_input)

                    sampling_input = input(f" Use sampling? (y/n, current: {'y' if use_sampling else 'n'}): ").strip()
                    if sampling_input:
                        use_sampling = sampling_input.lower() in ['y', 'yes', 't', 'true']

                    print("✅ Parameters updated!")
                except ValueError as e:
                    print(f"❌ Invalid input: {e}")
                # NOTE(review): placed after the try/except so both a
                # successful update and an invalid input skip generation
                # for the 'params' command — confirm against original intent.
                continue

            if not prompt:
                print("⚠️ Empty prompt, try again")
                continue

            # Tokenize
            token_ids = text_to_token_ids(prompt, tokenizer)
            print(f"📝 Input tokens: {token_ids.shape[1]}")

            # Generate (no trailing newline so the \r below can overwrite it)
            print("🤖 Generating...", end='', flush=True)
            if use_sampling:
                generated_ids = generate_text_with_sampling(
                    model=model,
                    idx=token_ids,
                    max_new_tokens=max_tokens,
                    context_size=args.max_seq_len,
                    temperature=temperature,
                    top_k=top_k
                )
            else:
                generated_ids = generate_text_simple(
                    model=model,
                    idx=token_ids,
                    max_new_tokens=max_tokens,
                    context_size=args.max_seq_len
                )

            # Decode
            generated_text = token_ids_to_text(generated_ids, tokenizer)
            print(f"\r🤖 Generated ({generated_ids.shape[1]} tokens):")
            print(f"\n{generated_text}\n")

        except KeyboardInterrupt:
            # Ctrl+C exits cleanly instead of dumping a traceback
            print("\n\n👋 Interrupted. Goodbye!")
            break
        except Exception as e:
            # Keep the REPL alive on generation errors; show the traceback
            # for debugging.
            print(f"\n❌ Error: {e}")
            import traceback
            traceback.print_exc()
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def batch_generation(model, tokenizer, args, prompts):
    """Generate text for a list of prompts."""
    divider = "="*60
    print("\n" + divider)
    print("📋 BATCH GENERATION MODE")
    print(divider + "\n")

    total = len(prompts)
    for number, prompt in enumerate(prompts, 1):
        print(f"\n--- Prompt {number}/{total} ---")
        print(f"Input: {prompt}")

        # Encode the prompt, then sample a continuation with fixed settings.
        encoded = text_to_token_ids(prompt, tokenizer)
        sampled = generate_text_with_sampling(
            model=model,
            idx=encoded,
            max_new_tokens=50,
            context_size=args.max_seq_len,
            temperature=0.8,
            top_k=50
        )

        # Decode back to text and show the result.
        print(f"Output: {token_ids_to_text(sampled, tokenizer)}\n")
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def main():
    """CLI entry point: parse arguments, build the model, run a generation mode.

    Modes: ``--interactive``/``-i`` for a REPL, ``--prompts p1 p2 ...`` for
    batch generation, otherwise a default set of Turkish prompts.
    """
    argv = sys.argv

    # Require at least the checkpoint path.
    if len(argv) < 2:
        print("Usage: python test_model.py <checkpoint_path> [--interactive] [--prompts \"prompt1\" \"prompt2\" ...]")
        print("\nExample:")
        print(" python test_model.py checkpoints/step_55000_expert_2.pt --interactive")
        print(" python test_model.py checkpoints/step_55000_expert_2.pt --prompts \"Merhaba\" \"Yapay zeka\"")
        sys.exit(1)

    checkpoint_path = argv[1]
    interactive_mode = any(flag in argv for flag in ('--interactive', '-i'))

    # Everything after --prompts (up to the next flag) is a prompt.
    custom_prompts = []
    if '--prompts' in argv:
        start = argv.index('--prompts') + 1
        custom_prompts = [a for a in argv[start:] if not a.startswith('--')]

    banner = "="*60
    print(banner)
    print("🧠 ismAIl Model Testing Script")
    print(banner)

    # Load model configuration; the script cannot run without it.
    config_path = Path("config.json")
    if not config_path.exists():
        print("❌ config.json not found!")
        sys.exit(1)
    with open(config_path) as f:
        config = json.load(f)
    print(f"✅ Loaded config from {config_path}")
    args = ModelArgs(**config["model"])

    # Initialize tokenizer (Turkish wrapper or a tiktoken/gpt2-style one).
    tokenizer_name = getattr(args, "tokenizer_name", "gpt2")
    use_turkish = tokenizer_name.lower() == "turkish"

    tokenizer = get_tokenizer(
        use_turkish=use_turkish,
        tokenizer_name="gpt2" if use_turkish else tokenizer_name
    )

    # The Turkish tokenizer may have a different vocabulary size than the
    # config declares; keep the model config in sync with the tokenizer.
    if use_turkish:
        from data import TurkishTokenizerWrapper
        if isinstance(tokenizer, TurkishTokenizerWrapper) and args.vocab_size != tokenizer.n_vocab:
            print(f"⚠️ Updating vocab_size: {args.vocab_size:,} -> {tokenizer.n_vocab:,}")
            args.vocab_size = tokenizer.n_vocab

    # Initialize model
    print("\n🚀 Initializing model...")
    model = ismail(args)

    # Load the trained weights; a missing checkpoint is a hard error here.
    ckpt_file = Path(checkpoint_path)
    if not ckpt_file.exists():
        print(f"❌ Checkpoint not found: {checkpoint_path}")
        sys.exit(1)
    load_checkpoint(model, ckpt_file)

    model.eval()

    # Dispatch to the requested mode.
    if interactive_mode:
        interactive_generation(model, tokenizer, args)
    elif custom_prompts:
        batch_generation(model, tokenizer, args, custom_prompts)
    else:
        # Default: use some Turkish prompts
        batch_generation(model, tokenizer, args, [
            "Merhaba, ben",
            "Yapay zekanın geleceği",
            "Bir varmış bir yokmuş",
            "Türkiye'nin başkenti"
        ])
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == "__main__":
    main()
|