ikaganacar committed on
Commit
18a94f8
·
1 Parent(s): a050405

Test Model

Browse files
Model_Architecture/generation.py CHANGED
@@ -171,13 +171,54 @@ def get_tokenizer(use_turkish=False, tokenizer_name="gpt2"):
171
  # EXAMPLE USAGE
172
  #####################################
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  if __name__ == "__main__":
175
  import json
176
  from pathlib import Path
 
177
 
178
  # Configuration: Set to True to use Turkish tokenizer, False for tiktoken
179
  USE_TURKISH_TOKENIZER = True # Change this to False for English text generation
180
 
 
 
 
 
 
 
 
 
 
 
181
  # Example configuration - smaller model for testing
182
  config_path = Path("config.json")
183
  if config_path.exists():
@@ -207,9 +248,21 @@ if __name__ == "__main__":
207
  print(f"📊 Updated vocab_size to {args.vocab_size:,} for Turkish tokenizer")
208
 
209
  # Initialize model
210
- print("Initializing model...")
211
  torch.manual_seed(123)
212
  model = ismail(args)
 
 
 
 
 
 
 
 
 
 
 
 
213
  model.eval()
214
 
215
  # Example 1: Greedy generation (argmax)
 
171
  # EXAMPLE USAGE
172
  #####################################
173
 
174
def load_checkpoint(model, checkpoint_path):
    """
    Load a trained checkpoint into the model.

    Supports two on-disk formats: a training checkpoint dict containing a
    'model_state_dict' key (plus optional 'step'/'loss' metadata), or a bare
    state dict saved directly via torch.save(model.state_dict(), path).

    Args:
        model: The model instance; its parameters are overwritten in place.
        checkpoint_path: Path to the checkpoint file (.pt)

    Returns:
        The loaded checkpoint object: the full dict with metadata, or the
        raw state dict when the direct format was saved.

    Raises:
        FileNotFoundError: if checkpoint_path does not exist.
        RuntimeError: if the state dict does not match the model's layout.
    """
    print(f"\n📦 Loading checkpoint: {checkpoint_path}")
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Handle different checkpoint formats
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        # Plain strings here: the originals were f-strings without
        # placeholders (ruff F541).
        print("✅ Loaded model state from checkpoint")
        if 'step' in checkpoint:
            print(f" Training step: {checkpoint['step']:,}")
        if 'loss' in checkpoint:
            print(f" Loss: {checkpoint['loss']:.4f}")
    else:
        # Direct state dict
        model.load_state_dict(checkpoint)
        print("✅ Loaded model state (direct)")

    return checkpoint
202
+
203
+
204
  if __name__ == "__main__":
205
  import json
206
  from pathlib import Path
207
+ import sys
208
 
209
  # Configuration: Set to True to use Turkish tokenizer, False for tiktoken
210
  USE_TURKISH_TOKENIZER = True # Change this to False for English text generation
211
 
212
+ # ===== CHECKPOINT LOADING =====
213
+ # Set this to the path of your trained checkpoint
214
+ # Example: CHECKPOINT_PATH = "./checkpoints/step_55000_expert_2.pt"
215
+ CHECKPOINT_PATH = None # Set to None to use random initialization
216
+
217
+ # You can also pass checkpoint path as command line argument
218
+ if len(sys.argv) > 1:
219
+ CHECKPOINT_PATH = sys.argv[1]
220
+ print(f"🔧 Using checkpoint from command line: {CHECKPOINT_PATH}")
221
+
222
  # Example configuration - smaller model for testing
223
  config_path = Path("config.json")
224
  if config_path.exists():
 
248
  print(f"📊 Updated vocab_size to {args.vocab_size:,} for Turkish tokenizer")
249
 
250
  # Initialize model
251
+ print("\n🚀 Initializing model...")
252
  torch.manual_seed(123)
253
  model = ismail(args)
254
+
255
+ # Load checkpoint if specified
256
+ if CHECKPOINT_PATH:
257
+ checkpoint_file = Path(CHECKPOINT_PATH)
258
+ if checkpoint_file.exists():
259
+ load_checkpoint(model, checkpoint_file)
260
+ else:
261
+ print(f"❌ Checkpoint not found: {CHECKPOINT_PATH}")
262
+ print(" Using random initialization instead")
263
+ else:
264
+ print("ℹ️ No checkpoint specified, using random initialization")
265
+
266
  model.eval()
267
 
268
  # Example 1: Greedy generation (argmax)
Model_Architecture/test_model.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Interactive script to test your trained ismAIl model.
4
+ Load a checkpoint and generate text with custom prompts.
5
+ """
6
+
7
+ import torch
8
+ import json
9
+ from pathlib import Path
10
+ import sys
11
+ from model import ismail, ModelArgs
12
+ from generation import (
13
+ generate_text_simple,
14
+ generate_text_with_sampling,
15
+ text_to_token_ids,
16
+ token_ids_to_text,
17
+ get_tokenizer,
18
+ load_checkpoint
19
+ )
20
+
21
+
22
def interactive_generation(model, tokenizer, args):
    """Interactive mode: continuously prompt for text and generate responses.

    Reads prompts from stdin in a loop and prints the generated
    continuation. Special inputs: 'quit'/'exit'/'q' end the loop, 'params'
    edits the generation parameters in place.

    Args:
        model: The generation model (caller is expected to have it in eval mode).
        tokenizer: Tokenizer passed through to the encode/decode helpers.
        args: Model args; only args.max_seq_len is read (context window).
    """
    print("\n" + "="*60)
    print("🎤 INTERACTIVE GENERATION MODE")
    print("="*60)
    print("Commands:")
    print(" - Type your prompt and press Enter to generate")
    print(" - Type 'quit' or 'exit' to stop")
    print(" - Type 'params' to change generation parameters")
    print("="*60 + "\n")

    # Default generation parameters
    temperature = 0.8
    top_k = 50
    max_tokens = 50
    use_sampling = True

    while True:
        try:
            prompt = input("\n💬 Prompt: ").strip()

            if prompt.lower() in ['quit', 'exit', 'q']:
                print("👋 Goodbye!")
                break

            if prompt.lower() == 'params':
                print("\n⚙️ Current parameters:")
                print(f" Temperature: {temperature}")
                print(f" Top-k: {top_k}")
                print(f" Max tokens: {max_tokens}")
                print(f" Use sampling: {use_sampling}")

                # Empty input keeps the current value for each parameter.
                try:
                    temp_input = input(f" New temperature (current: {temperature}): ").strip()
                    if temp_input:
                        temperature = float(temp_input)

                    topk_input = input(f" New top-k (current: {top_k}): ").strip()
                    if topk_input:
                        top_k = int(topk_input)

                    tokens_input = input(f" New max tokens (current: {max_tokens}): ").strip()
                    if tokens_input:
                        max_tokens = int(tokens_input)

                    sampling_input = input(f" Use sampling? (y/n, current: {'y' if use_sampling else 'n'}): ").strip()
                    if sampling_input:
                        use_sampling = sampling_input.lower() in ['y', 'yes', 't', 'true']

                    print("✅ Parameters updated!")
                except ValueError as e:
                    print(f"❌ Invalid input: {e}")
                # BUG FIX: previously only the error path continued, so a
                # successful 'params' update fell through and generated
                # text for the literal prompt "params". Always go back to
                # the prompt loop after handling the command.
                continue

            if not prompt:
                print("⚠️ Empty prompt, try again")
                continue

            # Tokenize
            token_ids = text_to_token_ids(prompt, tokenizer)
            print(f"📝 Input tokens: {token_ids.shape[1]}")

            # Generate with the currently selected strategy
            print("🤖 Generating...", end='', flush=True)
            if use_sampling:
                generated_ids = generate_text_with_sampling(
                    model=model,
                    idx=token_ids,
                    max_new_tokens=max_tokens,
                    context_size=args.max_seq_len,
                    temperature=temperature,
                    top_k=top_k
                )
            else:
                generated_ids = generate_text_simple(
                    model=model,
                    idx=token_ids,
                    max_new_tokens=max_tokens,
                    context_size=args.max_seq_len
                )

            # Decode; '\r' overwrites the "Generating..." status line
            generated_text = token_ids_to_text(generated_ids, tokenizer)
            print(f"\r🤖 Generated ({generated_ids.shape[1]} tokens):")
            print(f"\n{generated_text}\n")

        except KeyboardInterrupt:
            print("\n\n👋 Interrupted. Goodbye!")
            break
        except Exception as e:
            # Report and keep the session alive rather than crashing out.
            print(f"\n❌ Error: {e}")
            import traceback
            traceback.print_exc()
115
+
116
+
117
def batch_generation(model, tokenizer, args, prompts):
    """Generate text for a list of prompts."""
    banner = "=" * 60
    print("\n" + banner)
    print("📋 BATCH GENERATION MODE")
    print(banner + "\n")

    for i, prompt in enumerate(prompts, 1):
        print(f"\n--- Prompt {i}/{len(prompts)} ---")
        print(f"Input: {prompt}")

        # Encode the prompt, sample a continuation with fixed settings,
        # then decode and show the result.
        encoded = text_to_token_ids(prompt, tokenizer)
        sampled = generate_text_with_sampling(
            model=model,
            idx=encoded,
            max_new_tokens=50,
            context_size=args.max_seq_len,
            temperature=0.8,
            top_k=50,
        )
        decoded = token_ids_to_text(sampled, tokenizer)
        print(f"Output: {decoded}\n")
141
+
142
+
143
def main():
    """Entry point for the checkpoint testing script.

    CLI: python test_model.py <checkpoint_path> [--interactive|-i]
         [--prompts "p1" "p2" ...]
    Without --interactive or --prompts, a built-in set of Turkish prompts
    is run through batch_generation.
    """
    # Parse command line arguments
    if len(sys.argv) < 2:
        print("Usage: python test_model.py <checkpoint_path> [--interactive] [--prompts \"prompt1\" \"prompt2\" ...]")
        print("\nExample:")
        print(" python test_model.py checkpoints/step_55000_expert_2.pt --interactive")
        print(" python test_model.py checkpoints/step_55000_expert_2.pt --prompts \"Merhaba\" \"Yapay zeka\"")
        sys.exit(1)

    # First positional argument is always the checkpoint path.
    checkpoint_path = sys.argv[1]
    interactive_mode = '--interactive' in sys.argv or '-i' in sys.argv

    # Extract prompts from command line
    # NOTE(review): every non-'--' token after --prompts is treated as a
    # prompt, so a later flag's value would also be captured — acceptable
    # for this simple ad-hoc parser.
    custom_prompts = []
    if '--prompts' in sys.argv:
        idx = sys.argv.index('--prompts')
        custom_prompts = [arg for arg in sys.argv[idx+1:] if not arg.startswith('--')]

    print("="*60)
    print("🧠 ismAIl Model Testing Script")
    print("="*60)

    # Load config — required; the script exits without it.
    config_path = Path("config.json")
    if config_path.exists():
        with open(config_path) as f:
            config = json.load(f)
        print(f"✅ Loaded config from {config_path}")
        args = ModelArgs(**config["model"])
    else:
        print("❌ config.json not found!")
        sys.exit(1)

    # Initialize tokenizer
    tokenizer_name = getattr(args, "tokenizer_name", "gpt2")
    use_turkish = tokenizer_name.lower() == "turkish"

    # "gpt2" is passed as the name when the Turkish tokenizer is selected —
    # presumably ignored by get_tokenizer in that case; confirm against
    # generation.get_tokenizer.
    tokenizer = get_tokenizer(
        use_turkish=use_turkish,
        tokenizer_name="gpt2" if use_turkish else tokenizer_name
    )

    # Keep the model's vocab size in sync with the Turkish tokenizer's
    # actual vocabulary, otherwise the embedding size would mismatch.
    if use_turkish:
        from data import TurkishTokenizerWrapper
        if isinstance(tokenizer, TurkishTokenizerWrapper):
            if args.vocab_size != tokenizer.n_vocab:
                print(f"⚠️ Updating vocab_size: {args.vocab_size:,} -> {tokenizer.n_vocab:,}")
                args.vocab_size = tokenizer.n_vocab

    # Initialize model
    print("\n🚀 Initializing model...")
    model = ismail(args)

    # Load checkpoint — a missing file is fatal here (no random-init fallback).
    checkpoint_file = Path(checkpoint_path)
    if checkpoint_file.exists():
        load_checkpoint(model, checkpoint_file)
    else:
        print(f"❌ Checkpoint not found: {checkpoint_path}")
        sys.exit(1)

    model.eval()

    # Run appropriate mode: interactive wins over explicit prompts.
    if interactive_mode:
        interactive_generation(model, tokenizer, args)
    elif custom_prompts:
        batch_generation(model, tokenizer, args, custom_prompts)
    else:
        # Default: use some Turkish prompts
        default_prompts = [
            "Merhaba, ben",
            "Yapay zekanın geleceği",
            "Bir varmış bir yokmuş",
            "Türkiye'nin başkenti"
        ]
        batch_generation(model, tokenizer, args, default_prompts)


if __name__ == "__main__":
    main()