lit69 committed on
Commit
20ee890
·
verified ·
1 Parent(s): ea08ed0

Delete chat_interface.py

Browse files
Files changed (1) hide show
  1. chat_interface.py +0 -440
chat_interface.py DELETED
@@ -1,440 +0,0 @@
1
- import os
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import json
6
- import argparse
7
- import sys
8
- import sentencepiece as spm
9
- import math
10
- from dataclasses import dataclass
11
-
12
- # --- Define the CORRECT Model Architecture (copied from train_llm.py) ---
13
@dataclass
class ModelConfig:
    """Architecture hyperparameters; must match the training-time configuration
    so that checkpoint state dicts load cleanly."""
    vocab_size: int = 32000              # SentencePiece vocabulary size
    hidden_size: int = 512               # embedding / residual-stream width
    num_layers: int = 8                  # number of transformer blocks
    num_attention_heads: int = 8         # query heads
    num_key_value_heads: int = 2         # shared key/value heads (GQA)
    intermediate_size: int = 1365        # SwiGLU inner width
    max_position_embeddings: int = 2048  # maximum sequence length for RoPE
    rms_norm_eps: float = 1e-6           # RMSNorm denominator stabilizer
    rope_theta: float = 10000.0          # RoPE base frequency
24
-
25
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization: scales by 1/rms(x), no mean
    subtraction and no bias, with a learned per-channel gain."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Attribute names match the checkpoint's state-dict keys.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalize in float32 for numerical stability, then cast back.
        orig_dtype = hidden_states.dtype
        x = hidden_states.float()
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normed = x * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normed.to(orig_dtype)
37
-
38
class RotaryEmbedding(nn.Module):
    """Precomputes rotary (RoPE) inverse frequencies and produces cos/sin
    tables for a given sequence length."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Per-channel frequency: base^(-2i/dim) for the even channels.
        exponents = torch.arange(0, dim, 2).float().to(device) / dim
        inv_freq = 1.0 / (base ** exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, seq_len=None):
        """Return (cos, sin), each of shape [seq_len, dim]; x supplies the
        device and a default sequence length (its second-to-last dim)."""
        if seq_len is None:
            seq_len = x.shape[-2]
        positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)
        # Duplicate so the table spans the full head dimension.
        table = torch.cat((angles, angles), dim=-1)
        return table.cos(), table.sin()
56
-
57
def rotate_half(x):
    """Rotate the last dimension by swapping halves and negating the back
    half: [a, b] -> [-b, a]."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
61
-
62
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
    """Apply rotary position embeddings to query and key tensors.

    cos/sin are [seq, head_dim] tables; with position_ids they are gathered
    per position, otherwise the leading seq_len rows are broadcast over the
    batch and head dimensions.
    """
    if position_ids is None:
        seq_len = q.shape[-2]
        cos = cos[:seq_len][None, None, :, :]
        sin = sin[:seq_len][None, None, :, :]
    else:
        cos = cos[position_ids].unsqueeze(1)
        sin = sin[position_ids].unsqueeze(1)

    def _rotate(t):
        # [a, b] -> [-b, a] on the last dimension.
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    return (q * cos) + (_rotate(q) * sin), (k * cos) + (_rotate(k) * sin)
73
-
74
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: down_proj( silu(gate_proj(x)) * up_proj(x) )."""

    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        # Attribute names match the checkpoint's state-dict keys.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
85
-
86
class GroupedQueryAttention(nn.Module):
    """Multi-head self-attention with grouped key/value heads (GQA) and RoPE.

    Query heads outnumber key/value heads; each group of
    ``num_heads // num_key_value_heads`` query heads shares one KV head,
    realized by repeat_interleave before the attention product.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=config.rope_theta,
        )

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        """Attend over hidden_states [batch, seq, hidden].

        attention_mask may be either:
          * a 2-D keep/pad mask [batch, seq] (1 = attend, 0 = masked), or
          * a pre-built additive float mask broadcastable to
            [batch, heads, q_len, k_len] — this is what LLMModel.forward
            passes for the causal case.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, seq_len=q_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        # Repeat k/v heads so every query head has a matching KV head.
        key_states = torch.repeat_interleave(key_states, repeats=self.num_key_value_groups, dim=1)
        value_states = torch.repeat_interleave(value_states, repeats=self.num_key_value_groups, dim=1)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # BUG FIX: the previous code unconditionally did
            # attention_mask[:, None, None, :], which turned the 4-D additive
            # causal mask built by LLMModel.forward into a 6-D tensor; the
            # subsequent softmax/matmul mis-broadcast and silently scrambled
            # the outputs. Convert only genuine 2-D padding masks here, and
            # add pre-built additive masks as-is.
            if attention_mask.dim() == 2:
                expanded_mask = attention_mask[:, None, None, :].to(attn_weights.dtype)
                attention_mask = (1.0 - expanded_mask) * torch.finfo(attn_weights.dtype).min
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for stability, then cast back to the query dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        return self.o_proj(attn_output)
141
-
142
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: GQA attention then SwiGLU MLP, each
    wrapped in a residual connection."""

    def __init__(self, config):
        super().__init__()
        # Attribute names match the checkpoint's state-dict keys.
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = GroupedQueryAttention(config)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = SwiGLU(config.hidden_size, config.intermediate_size)

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        # Attention sub-layer (normalize first, then add back the residual).
        hidden_states = hidden_states + self.self_attn(
            self.input_layernorm(hidden_states), attention_mask, position_ids
        )
        # Feed-forward sub-layer, same pre-norm residual pattern.
        return hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
162
-
163
class LLMModel(nn.Module):
    """Decoder-only transformer language model (RMSNorm + GQA + SwiGLU).

    The token embedding and the output head are separate (untied) weights.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        """Return next-token logits of shape [batch, seq, vocab].

        attention_mask, if given, is a [batch, seq] keep/pad mask
        (1 = attend, 0 = pad); if None, a causal mask is built instead.
        """
        batch_size, seq_length = input_ids.shape

        if position_ids is None:
            position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        if attention_mask is None:
            # Additive causal mask [1, 1, seq, seq]: -inf strictly above the
            # diagonal so each position attends only to itself and the past.
            causal_mask = torch.full((seq_length, seq_length), float('-inf'), device=input_ids.device)
            causal_mask = torch.triu(causal_mask, diagonal=1)
            attention_mask = causal_mask.unsqueeze(0).unsqueeze(0)
        else:
            # Convert the [batch, seq] keep/pad mask to an additive
            # [batch, 1, 1, seq] mask.
            # NOTE(review): this padding path adds no causal component, so a
            # caller-supplied mask yields bidirectional attention — confirm
            # that is intended before using it for training.
            attention_mask = (1.0 - attention_mask) * torch.finfo(torch.float32).min
            attention_mask = attention_mask[:, None, None, :]

        hidden_states = self.embed_tokens(input_ids)

        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask, position_ids)

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)
        return logits

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=100, do_sample=True, top_p=0.9, temperature=0.7, eos_token_id=3):
        """Autoregressively extend input_ids by up to max_new_tokens tokens.

        Args:
            input_ids: [batch, seq] prompt token ids.
            max_new_tokens: maximum number of tokens to append.
            do_sample: nucleus sampling when True, greedy argmax when False.
            top_p: nucleus-sampling cumulative-probability cutoff.
            temperature: logit divisor before sampling.
            eos_token_id: token id that ends generation (default 3 keeps the
                previous hard-coded behavior).

        Returns the prompt with generated tokens appended. No KV cache is
        kept, so each step re-runs the full sequence (O(n^2) overall).
        """
        self.eval()  # disable dropout-style behavior during generation
        generated = input_ids.clone()

        for _ in range(max_new_tokens):
            # Only the last position's logits matter for the next token.
            logits = self(generated)[:, -1, :]  # [batch, vocab]

            if do_sample:
                logits = logits / temperature
                probs = torch.softmax(logits, dim=-1)

                # Top-p (nucleus) sampling: drop the low-probability tail
                # beyond cumulative mass top_p, keeping at least one token.
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                probs = probs.masked_fill(indices_to_remove, 0.0)
                probs = probs / probs.sum(dim=-1, keepdim=True)

                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token], dim=-1)

            # BUG FIX: next_token.item() raised for batch > 1; early-stop on
            # EOS only when generating a single sequence.
            if generated.size(0) == 1 and next_token.item() == eos_token_id:
                break

        return generated
236
- # --- End of Model Architecture ---
237
-
238
def load_tokenizer(tokenizer_path):
    """Load a SentencePiece tokenizer from tokenizer_path.

    Returns the loaded SentencePieceProcessor, or None if the file is
    missing or loading fails.
    """
    print(f"Debug: Attempting to load tokenizer from {tokenizer_path}")
    if not os.path.exists(tokenizer_path):
        print(f"Error: Tokenizer file {tokenizer_path} does not exist")
        return None
    try:
        sp = spm.SentencePieceProcessor()
        sp.load(tokenizer_path)
        # BUG FIX: unk_id() returns -1 (never None) when the model defines no
        # <unk> piece, so the old `is None` check could not trigger — and
        # SentencePieceProcessor has no add_unk_token() method, so the branch
        # would have raised AttributeError if it ever ran. Just warn instead.
        if sp.unk_id() < 0:
            print("Warning: No <UNK> token in tokenizer; unknown pieces cannot be mapped.")
        print(f"Debug: Tokenizer loaded successfully. Vocab size: {sp.vocab_size()}")
        return sp
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        return None
255
-
256
def load_model(model_path, config_path, device='cpu'):
    """Load the model from a checkpoint with detailed debugging.

    Args:
        model_path: path to the .pt checkpoint file.
        config_path: path to config.json; missing file falls back to defaults.
        device: torch device string to place the model on.

    Returns:
        (model, config) on success, (None, None) on any failure.
    """
    print(f"Debug: Attempting to load model from {model_path}")
    print(f"Debug: Config path: {config_path}")

    if not os.path.exists(model_path):
        print(f"Error: Model file {model_path} does not exist")
        return None, None
    if not os.path.exists(config_path):
        print(f"Warning: Config file {config_path} not found. Using default config.")

    # Defaults mirror ModelConfig; values from config.json override key-by-key.
    config_dict = {
        'vocab_size': 32000,
        'hidden_size': 512,
        'num_layers': 8,
        'num_attention_heads': 8,
        'num_key_value_heads': 2,
        'intermediate_size': 1365,
        'max_position_embeddings': 2048,
        'rms_norm_eps': 1e-6,
        'rope_theta': 10000.0
    }
    try:
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                loaded_config = json.load(f)
                # Copy only known keys; extras in config.json are ignored.
                for key in config_dict:
                    if key in loaded_config:
                        config_dict[key] = loaded_config[key]
                print(f"Debug: Config loaded: {config_dict}")
    except Exception as e:
        print(f"Warning: Failed to load config.json: {e}. Using default config.")

    # Build the ModelConfig used to instantiate the architecture.
    config = ModelConfig(**config_dict)

    try:
        print("Debug: Initializing LLMModel (correct architecture)")
        model = LLMModel(config)  # architecture must match the checkpoint
    except Exception as e:
        print(f"Error initializing model: {e}")
        return None, None

    try:
        # NOTE(review): torch.load without weights_only can execute arbitrary
        # code from an untrusted checkpoint — only load trusted files.
        checkpoint = torch.load(model_path, map_location=device)
        print(f"Debug: Checkpoint type: {type(checkpoint)}")
        if isinstance(checkpoint, dict):
            if 'model_state_dict' in checkpoint:
                # Full training checkpoint: weights live under 'model_state_dict'.
                print("Debug: Loading from full checkpoint dict")
                model.load_state_dict(checkpoint['model_state_dict'], strict=False)
            else:
                # Bare state dict saved with torch.save(model.state_dict()).
                print("Debug: Loading state dictionary directly")
                model.load_state_dict(checkpoint, strict=False)
        else:
            # Whole pickled nn.Module; replaces the freshly built model.
            print("Debug: Loading full model object (not recommended)")
            model = checkpoint
        model.to(device)
        model.eval()
        print(f"Debug: Model loaded successfully on {device}")
        return model, config
    except Exception as e:
        print(f"Error loading model checkpoint: {e}")
        return None, None
321
-
322
def preprocess_input(text, tokenizer, max_length=512):
    """Normalize whitespace and tokenize input text.

    Returns (tokens_tensor, None) on success or (None, error_message) on
    empty input or tokenizer failure.
    """
    print(f"Debug: Preprocessing input: {text}")
    cleaned = ' '.join(text.strip().split())
    if not cleaned:
        return None, "Input is empty. Please provide a valid input."

    try:
        # BOS/EOS are added to match how sequences were framed in training.
        token_ids = tokenizer.encode(cleaned, out_type=int, add_bos=True, add_eos=True)
        print(f"Debug: Tokenized input: {token_ids}")
        if len(token_ids) > max_length:
            # Truncate from the end, re-appending EOS so the sequence stays
            # well-formed at exactly max_length tokens.
            token_ids = token_ids[:max_length - 1] + [tokenizer.eos_id()]
        # No padding: generation uses the actual length and the model's
        # attention mask handles the rest.
        return torch.tensor([token_ids], dtype=torch.long), None
    except Exception as e:
        print(f"Tokenization error: {e}")
        return None, f"Failed to tokenize input: {cleaned}. Please try again."
343
-
344
def generate_response(model, tokenizer, input_tokens, max_new_tokens=100, device='cpu'):
    """Generate a reply and decode only the newly produced tokens.

    Returns (response_text, None) on success or (None, error_message) on
    failure.
    """
    print(f"Debug: Generating response with input tokens shape: {input_tokens.shape}")
    try:
        prompt = input_tokens.to(device)
        generated = model.generate(prompt, max_new_tokens=max_new_tokens)
        # Everything past the prompt length is the model's reply.
        prompt_len = prompt.shape[1]
        reply_ids = generated[0].tolist()[prompt_len:]
        response = tokenizer.decode(reply_ids)
        print(f"Debug: Generated response: {response}")
        return response, None
    except Exception as e:
        print(f"Inference error: {e}")
        return None, "Failed to generate response. Please try again."
362
-
363
def main():
    """Interactive chat REPL: load tokenizer and model, then loop on stdin
    until the user quits."""
    print("🚀 Initializing CoreX AI Chat Interface...")

    # Windows-specific defaults; override with --model_path / --tokenizer_path.
    default_checkpoint_path = r"D:\checkpoints"
    default_tokenizer_path = r"D:\CoreX\tokenizer\corex_tok.model"

    parser = argparse.ArgumentParser(description="CoreX AI Chat Interface")
    parser.add_argument('--model_path', default=default_checkpoint_path, help="Path to model checkpoints")
    parser.add_argument('--tokenizer_path', default=default_tokenizer_path, help="Path to tokenizer")
    args = parser.parse_args()

    print("📁 Default paths:")
    print(f" Model: {args.model_path}")
    print(f" Tokenizer: {args.tokenizer_path}")
    print("✅ Using default paths...")

    print(f"Loading tokenizer from {args.tokenizer_path}...")
    tokenizer = load_tokenizer(args.tokenizer_path)
    if tokenizer is None:
        print("Failed to load tokenizer. Exiting.")
        return

    # Checkpoint directory is expected to hold config.json + final_model.pt.
    config_path = os.path.join(args.model_path, "config.json")
    model_path = os.path.join(args.model_path, "final_model.pt")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print(f"Loading custom model from {args.model_path}...")
    model, config = load_model(model_path, config_path, device)
    if model is None:
        print("Failed to load model. Exiting.")
        return

    print(f"Model loaded successfully on {device}")
    print("🤖 AI Chat Interface")
    print("=" * 50)
    print("Type 'quit', 'exit', or 'bye' to end the conversation")
    print("Type 'clear' to clear the conversation history")
    print("Type 'help' for more commands")
    print("=" * 50)

    # NOTE(review): history is recorded but never fed back into the prompt —
    # each turn is generated from the latest user input only.
    conversation_history = []

    while True:
        user_input = input("\n🧑 You: ").strip()

        # Built-in commands are handled before any model work.
        if user_input.lower() in ['quit', 'exit', 'bye']:
            print("👋 Goodbye!")
            break
        elif user_input.lower() == 'clear':
            conversation_history = []
            print("🗑 Conversation history cleared.")
            continue
        elif user_input.lower() == 'help':
            print("Available commands:")
            print(" quit/exit/bye: End the conversation")
            print(" clear: Clear conversation history")
            print(" help: Show this help message")
            continue

        input_tokens, error = preprocess_input(user_input, tokenizer)
        if error:
            print(f"🤖 AI: {error}")
            # Keep an audit trail of inputs the tokenizer rejected.
            with open("rejected_inputs.log", "a") as log_file:
                log_file.write(f"Rejected input: {user_input}\nError: {error}\n")
            continue

        conversation_history.append({"role": "user", "content": user_input})

        response, error = generate_response(model, tokenizer, input_tokens, max_new_tokens=100, device=device)
        if error:
            print(f"🤖 AI: {error}")
            continue

        conversation_history.append({"role": "assistant", "content": response})
        print(f"\n🤖 AI: {response}")

if __name__ == "__main__":
    main()