akkiisfrommars committed on
Commit de2d09a · verified · 1 Parent(s): b6bbee5

Upload chat1.py

Files changed (1)
chat1.py +1122 -0
chat1.py ADDED
@@ -0,0 +1,1122 @@
"""
Chat interface for the CosmicFish model directly from the Hugging Face Hub.
Automatically downloads and caches the model from the HF Hub.
"""

import os
import sys
import time
import argparse
import torch
from termcolor import colored
import logging
import readline  # noqa: F401 -- enables arrow-key history in terminal input
import textwrap
import json

# Required imports for HF Hub
try:
    from transformers import GPT2Tokenizer
    from huggingface_hub import snapshot_download
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False
    print("❌ Required libraries not available.")
    print("Install with: pip install transformers huggingface-hub")
    sys.exit(1)

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Default model repository
DEFAULT_MODEL_REPO = "Mistyoz-AI/CosmicFish-120M"

# Default prompt template
DEFAULT_PROMPT_TEMPLATE = "Below is a conversation between a helpful AI assistant and a human. The assistant is knowledgeable, friendly, and provides detailed and accurate responses.\n\n"


class CosmicConfig:
    """Configuration class for CosmicFish."""

    def __init__(self,
                 vocab_size=50257,
                 block_size=512,
                 n_layer=12,
                 n_head=16,
                 n_embd=704,
                 bias=True,
                 dropout=0.0,  # Always 0 for inference
                 n_query_groups=4,
                 eps=1e-6,
                 use_rotary=True,
                 use_swiglu=True,
                 use_qk_norm=False,
                 use_gqa=True):
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.bias = bias
        self.dropout = dropout
        self.eps = eps
        self.use_rotary = use_rotary
        self.use_swiglu = use_swiglu
        self.use_qk_norm = use_qk_norm
        self.use_gqa = use_gqa
        self.n_query_groups = n_query_groups if use_gqa else n_head
        # Ensure n_head is divisible by n_query_groups
        assert n_head % self.n_query_groups == 0, "n_head must be divisible by n_query_groups"


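# Derived sizes (from the defaults above, for orientation): head_dim = n_embd // n_head
# = 704 // 16 = 44, and the attention below uses n_head // n_query_groups = 4 shared
# key/value heads, so each KV head serves 4 query heads.

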
class RMSNorm(torch.nn.Module):
    """Root Mean Square Normalization"""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)


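# Worked example (illustrative): for x = [3.0, 4.0] with unit weight and eps ≈ 0,
# rms = sqrt((9 + 16) / 2) ≈ 3.536, so the output is ≈ [0.849, 1.131].
# Unlike LayerNorm, no mean is subtracted and no bias is added.

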
def precompute_freqs_cis(dim, end, theta=10000.0):
    """Precompute the frequency tensor for complex exponentials (cis)"""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


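# Illustrative shapes: with the defaults (head_dim=44, block_size=512),
# precompute_freqs_cis(44, 512) returns a complex64 tensor of shape (512, 22) --
# one rotation e^{i * t * theta^(-2k/dim)} per position t and channel pair k.

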
def apply_rotary_emb(xq, xk, freqs_cis):
    """Apply rotary embeddings to input tensors"""
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    seq_len = xq_.size(2)
    if freqs_cis.size(0) < seq_len:
        raise ValueError(f"freqs_cis has only {freqs_cis.size(0)} values but sequence length is {seq_len}")

    freqs_cis_seq = freqs_cis[:seq_len]
    xq_out = torch.view_as_real(xq_ * freqs_cis_seq.unsqueeze(0)).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis_seq.unsqueeze(0)).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)


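# The complex multiply above is the standard RoPE rotation: each channel pair
# (x_{2k}, x_{2k+1}) is treated as one complex number and rotated by the angle
# t * theta^(-2k/dim), which makes the q·k dot product depend only on the
# relative position between query and key.

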
class GroupedQueryAttention(torch.nn.Module):
    """Grouped Query Attention (GQA) implementation"""

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        head_dim = config.n_embd // config.n_head
        self.head_dim = head_dim
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.n_query_groups = config.n_query_groups

        self.kv_heads = config.n_head // config.n_query_groups if config.use_gqa else config.n_head
        qkv_proj_size = (config.n_head + 2 * self.kv_heads) * head_dim

        self.c_attn = torch.nn.Linear(config.n_embd, qkv_proj_size, bias=config.bias)
        self.c_proj = torch.nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # Flash attention support
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                 .view(1, 1, config.block_size, config.block_size))

        # Query-key normalization
        self.qk_norm = getattr(config, 'use_qk_norm', False)
        if self.qk_norm:
            self.q_norm = RMSNorm(head_dim, eps=getattr(config, 'eps', 1e-6))
            self.k_norm = RMSNorm(head_dim, eps=getattr(config, 'eps', 1e-6))

    def forward(self, x, freqs_cis=None):
        B, T, C = x.size()
        qkv = self.c_attn(x)
        head_dim = C // self.n_head

        q_size = self.n_head * head_dim
        k_size = self.kv_heads * head_dim
        v_size = self.kv_heads * head_dim

        q, k, v = qkv.split([q_size, k_size, v_size], dim=2)

        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = k.view(B, T, self.kv_heads, head_dim).transpose(1, 2)
        v = v.view(B, T, self.kv_heads, head_dim).transpose(1, 2)

        # Repeat k and v if needed for GQA
        if self.kv_heads < self.n_head:
            repeats = self.n_head // self.kv_heads
            k = k.repeat_interleave(repeats, dim=1)
            v = v.repeat_interleave(repeats, dim=1)

        # Apply rotary embeddings
        if freqs_cis is not None:
            q, k = apply_rotary_emb(q, k, freqs_cis)

        # Apply query-key normalization
        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        # Compute attention
        if self.flash:
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
            )
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / torch.sqrt(torch.tensor(k.size(-1), dtype=torch.float32)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
            att = torch.nn.functional.softmax(att, dim=-1)
            y = att @ v

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        return y


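# Sizing note (derived from the defaults): the fused QKV projection has
# (16 + 2*4) * 44 = 1056 output features instead of 3 * 704 = 2112 for full
# multi-head attention, roughly halving the QKV projection parameters.

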
class Block(torch.nn.Module):
    """Transformer block"""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = RMSNorm(config.n_embd, eps=config.eps)
        self.ln_2 = RMSNorm(config.n_embd, eps=config.eps)
        self.attn = GroupedQueryAttention(config)

        # MLP implementation based on configuration
        if config.use_swiglu:
            # SwiGLU MLP
            self.mlp = torch.nn.ModuleDict(dict(
                gate=torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
                up=torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
                down=torch.nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias),
                act=torch.nn.SiLU(),
            ))
            m = self.mlp
            self.mlpf = lambda x: m.down(m.act(m.up(x)) * m.gate(x))
        else:
            # Traditional MLP
            self.mlp = torch.nn.ModuleDict(dict(
                c_fc=torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
                c_proj=torch.nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias),
                act=torch.nn.GELU(),
            ))
            m = self.mlp
            self.mlpf = lambda x: m.c_proj(m.act(m.c_fc(x)))

    def forward(self, x, freqs_cis=None):
        x = x + self.attn(self.ln_1(x), freqs_cis)
        x = x + self.mlpf(self.ln_2(x))
        return x


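# Note: the SwiGLU branch computes down(SiLU(up(x)) * gate(x)); the activation
# sits on the "up" projection rather than the "gate" as in some conventions,
# but the naming must stay as-is to match the checkpoint's parameter names.

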
class CosmicFish(torch.nn.Module):
    """
    CosmicFish model for inference only.
    Features: Rotary Positional Embeddings, Grouped-Query Attention, SwiGLU, RMSNorm
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = torch.nn.ModuleDict(dict(
            wte=torch.nn.Embedding(config.vocab_size, config.n_embd),
            h=torch.nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=RMSNorm(config.n_embd, eps=config.eps),
        ))

        self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Share weights between embedding and output
        self.transformer.wte.weight = self.lm_head.weight

        # Precompute rotary embedding frequencies
        if config.use_rotary:
            head_dim = config.n_embd // config.n_head
            self.freqs_cis = precompute_freqs_cis(head_dim, config.block_size)
        else:
            self.freqs_cis = None
            self.transformer.wpe = torch.nn.Embedding(config.block_size, config.n_embd)

    def get_num_params(self, non_embedding=True):
        """Return the number of parameters in the model."""
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding and hasattr(self.transformer, 'wpe'):
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def forward(self, idx, targets=None):
        """Forward pass through the model."""
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

        # Get token embeddings
        tok_emb = self.transformer.wte(idx)

        # Handle positional embeddings
        if self.config.use_rotary:
            x = tok_emb
            freqs_cis = self.freqs_cis.to(device) if self.freqs_cis is not None else None
        else:
            pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)
            pos_emb = self.transformer.wpe(pos)
            x = tok_emb + pos_emb
            freqs_cis = None

        # Apply transformer blocks
        for block in self.transformer.h:
            x = block(x, freqs_cis)

        # Apply final normalization
        x = self.transformer.ln_f(x)

        # Calculate outputs
        if targets is not None:
            logits = self.lm_head(x)
            loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # For inference, only compute logits for the last token
            logits = self.lm_head(x[:, [-1], :])
            loss = None

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Generate text by sampling from the model, token by token.
        """
        for _ in range(max_new_tokens):
            # Crop sequence to block size if needed
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]

            # Forward pass
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            # Apply top-k sampling
            if top_k is not None:
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')

            # Sample next token
            probs = torch.nn.functional.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            # Append to sequence
            idx = torch.cat((idx, idx_next), dim=1)

        return idx


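# Illustrative standalone use (hypothetical seed token; the chat session below
# is the real driver):
#   model = CosmicFish(CosmicConfig())
#   idx = torch.tensor([[50256]])  # a single <|endoftext|> token as a seed
#   out = model.generate(idx, max_new_tokens=20, temperature=0.8, top_k=40)

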
class RepetitionPenaltyLogitsProcessor:
    """Apply a repetition penalty to discourage repeating tokens."""

    def __init__(self, penalty=1.2):
        self.penalty = penalty

    def __call__(self, input_ids, scores):
        """Apply the repetition penalty to logits for tokens already seen in input_ids."""
        score = torch.gather(scores, 1, input_ids)
        # If score > 0, penalize by dividing; if score < 0, penalize by multiplying
        score = torch.where(score > 0, score / self.penalty, score * self.penalty)
        scores.scatter_(1, input_ids, score)
        return scores


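# Worked example: with penalty=1.2, a previously seen token with logit 2.0 drops
# to 2.0 / 1.2 ≈ 1.67, and one with logit -1.0 drops to -1.2. Both moves push the
# token away from being sampled again.

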
class CosmicFishChatSession:
    """Chat session for the CosmicFish model from Hugging Face Hub."""

    def __init__(self, model, tokenizer, config):
        """Initialize the chat session with model and configuration."""
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.device = next(model.parameters()).device
        self.history = []
        self.history_tokens = []
        self.max_history_tokens = config.max_history_tokens
        self.prompt_template = config.prompt_template
        self.human_prefix = config.human_prefix
        self.assistant_prefix = config.assistant_prefix
        self.end_of_turn = config.end_of_turn
        self.block_size = config.block_size
        self.debug_mode = config.debug_mode
        self.repetition_penalty = config.repetition_penalty
        self.min_tokens_to_generate = config.min_tokens_to_generate
        self.max_retries = 20

        # Canned clarification requests (currently unused by the generation path)
        self.fallback_responses = [
            "I'd be happy to help with that. Could you provide more details about what specific information you're looking for?",
            "That's a topic I can provide information about. What specific aspects would you like to know?",
            "I understand your question. I can share factual information on this topic if you could specify what aspects you're interested in.",
            "I can help with your question. To give you the most relevant information, could you clarify what specific details you're looking for?",
            "I'd be glad to address your question. To provide the most helpful response, could you specify what particular aspects of this topic interest you?"
        ]

        self.generation_failure_message = "I'm sorry, but I'm having difficulty generating a response to that prompt. Could you try rephrasing your question or asking something else?"

        # For token counting
        self.total_prompt_tokens = 0
        self.total_generated_tokens = 0

        # End markers for live generation
        self.end_markers = [
            f"{self.human_prefix}",
            "Human:",
            "\nHuman:",
            "\nH:",
            "H:",
            "<|endoftext|>",
            "Below is a conversation",
            "\nA:",
            "A:",
            "</s>",
            "User:",
            "\nUser:"
        ]

        # Print welcome message
        if config.display_welcome:
            self._print_welcome_message()

    def _print_welcome_message(self):
        """Print a welcome message to the user."""
        welcome_text = f"""
{'=' * 80}
Welcome to the CosmicFish chat interface (Hugging Face Hub)

This is a {self.model.get_num_params() / 1e6:.1f}M parameter model loaded from HF Hub.
CosmicFish features an advanced architecture with RoPE, GQA, SwiGLU, and RMSNorm.

Model: {DEFAULT_MODEL_REPO}

Type your prompts and CosmicFish will respond.

Special commands:
- /help: Show this help message
- /clear: Clear the conversation history
- /exit or /quit: Exit the chat
- /stats: Show token usage statistics
- /save [filename]: Save the conversation
- /load [filename]: Load a conversation
- /temp [value]: Set temperature (between 0.1 and 2.0)
- /penalty [value]: Set repetition penalty (1.0-2.0)
- /debug: Toggle debug mode
{'=' * 80}
"""
        print(colored(welcome_text, 'cyan'))

    def _format_prompt(self, user_input):
        """Format the complete prompt with history and current input."""
        # Start with the template
        formatted_prompt = self.prompt_template

        # Add conversation history
        for entry in self.history:
            role, text = entry
            if role == "human":
                formatted_prompt += f"{self.human_prefix}{text}{self.end_of_turn}"
            else:  # assistant
                formatted_prompt += f"{self.assistant_prefix}{text}{self.end_of_turn}"

        # Add the current user input
        formatted_prompt += f"{self.human_prefix}{user_input}{self.end_of_turn}{self.assistant_prefix}"

        return formatted_prompt

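    # With the default prefixes, a first-turn prompt looks like (illustrative):
    #   "Below is a conversation ...\n\nHuman: Hi there!\n\nAssistant: "
    # and the model continues from the trailing "Assistant: ".
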
    def _tokenize(self, text):
        """Tokenize text and return token IDs."""
        return self.tokenizer.encode(text)

    def _update_history(self, user_input, response):
        """Update conversation history."""
        # Add to text history
        self.history.append(("human", user_input))
        self.history.append(("assistant", response))

        # Update token history for context window management
        user_tokens = self._tokenize(f"{self.human_prefix}{user_input}{self.end_of_turn}")
        response_tokens = self._tokenize(f"{self.assistant_prefix}{response}{self.end_of_turn}")

        self.history_tokens.extend(user_tokens)
        self.history_tokens.extend(response_tokens)

        # Track token usage
        self.total_prompt_tokens += len(user_tokens)
        self.total_generated_tokens += len(response_tokens)

        # Trim history if it gets too long
        self._trim_history_if_needed()

    def _trim_history_if_needed(self):
        """Trim history to fit within the context window."""
        # Remove the oldest turns until we're under the limit
        while len(self.history_tokens) > self.max_history_tokens and len(self.history) >= 2:
            # Measure the oldest human/assistant pair *before* removing it,
            # so the token count matches the turns actually dropped
            user_turn = self.history[0][1]
            assistant_turn = self.history[1][1]
            user_tokens = len(self._tokenize(f"{self.human_prefix}{user_turn}{self.end_of_turn}"))
            assistant_tokens = len(self._tokenize(f"{self.assistant_prefix}{assistant_turn}{self.end_of_turn}"))

            # Drop the oldest pair from both the text and token history
            self.history = self.history[2:]
            self.history_tokens = self.history_tokens[user_tokens + assistant_tokens:]

    def _should_stop_generation(self, text):
        """Check if generation should stop based on end markers."""
        for marker in self.end_markers:
            if marker in text:
                return True
        return False

    def _clean_token_text(self, text):
        """Clean token text by fixing encoding issues."""
        # Replace a pair of U+FFFD replacement characters (a mis-decoded
        # multi-byte apostrophe) with a plain apostrophe
        text = text.replace('\ufffd\ufffd', "'")
        return text

    def generate_with_repetition_penalty(self, input_ids, max_new_tokens, temperature, top_k, penalty=1.2, live=False):
        """Custom generate function with repetition penalty and optional live generation.

        This is a generator: in live mode it yields (token_text, buffer, should_stop)
        triples; in non-live mode the finished sequence is delivered as the
        generator's return value (i.e., via StopIteration.value).
        """
        model = self.model

        # Ensure model is in eval mode
        model.eval()

        # Initialize sequence with input_ids
        generated = input_ids.clone()

        # Initialize live text buffer
        live_buffer = ""

        # Create repetition penalty processor
        rep_processor = RepetitionPenaltyLogitsProcessor(penalty=penalty)

        # Counter for generated tokens
        tokens_generated = 0
        min_tokens = self.min_tokens_to_generate

        # EOT token ID
        eot_token_id = self.tokenizer.eos_token_id if hasattr(self.tokenizer, 'eos_token_id') else 50256

        # Generate tokens one at a time
        for _ in range(max_new_tokens):
            # Get only the last block_size tokens if the context is too long
            if generated.size(1) > self.block_size:
                context = generated[:, -self.block_size:]
            else:
                context = generated

            # Forward pass for next-token prediction
            with torch.no_grad():
                logits, _ = model(context)

            # Get logits for the next token (last position)
            next_token_logits = logits[:, -1, :]

            # Apply temperature
            next_token_logits = next_token_logits / temperature

            # Apply repetition penalty
            if penalty > 1.0:
                next_token_logits = rep_processor(context, next_token_logits)

            # Optional top-k sampling (top_k=None or 0 disables it)
            if top_k is not None and top_k > 0:
                indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                next_token_logits[indices_to_remove] = float('-inf')

            # Convert logits to probabilities
            probs = torch.nn.functional.softmax(next_token_logits, dim=-1)

            # Sample next token
            next_token = torch.multinomial(probs, num_samples=1)

            # Check if the next token is EOT and break immediately if so
            if next_token.item() == eot_token_id:
                if live:
                    yield "", live_buffer, True
                break

            # Append next token to the generated sequence
            generated = torch.cat((generated, next_token), dim=1)
            tokens_generated += 1

            # If live generation, decode and yield the token
            if live:
                # Decode the next token
                next_token_text = self.tokenizer.decode([next_token.item()])
                # Clean the token text to fix encoding issues
                next_token_text = self._clean_token_text(next_token_text)
                live_buffer += next_token_text

                # Check if we've hit an end marker in the buffer
                eot_marker_pos = live_buffer.find("<|endoftext|>")
                if eot_marker_pos != -1:
                    # Cut off at the EOT marker
                    live_buffer = live_buffer[:eot_marker_pos]
                    yield "", live_buffer, True
                    break

                # Check other end markers
                should_stop = tokens_generated >= min_tokens and self._should_stop_generation(live_buffer)
                yield next_token_text, live_buffer, should_stop

                if should_stop:
                    break

            # For non-live generation, check if we should stop
            elif tokens_generated >= min_tokens:
                # Check for end markers in the recently generated tokens
                recent_text = self.tokenizer.decode(generated[0, -20:].tolist())
                if self._should_stop_generation(recent_text):
                    break

        # Check if we generated any tokens at all
        if tokens_generated == 0 and not live:
            if self.debug_mode:
                print(colored("\n[No tokens generated in this attempt]", "red"))
            return None

        if not live:
            return generated

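    # Yield protocol (for callers): iterate over (token_text, buffer, should_stop)
    # triples. A final ("", buffer, True) signals a clean stop (EOT or an end
    # marker), with any "<|endoftext|>" already trimmed from the buffer.
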
    def generate_response(self, user_input):
        """Generate a response to the user input."""
        # Format the complete prompt
        prompt = self._format_prompt(user_input)

        # Tokenize the prompt
        input_ids = torch.tensor(self._tokenize(prompt), dtype=torch.long).unsqueeze(0).to(self.device)

        # Ensure we don't exceed the model's context length
        if input_ids.size(1) > self.block_size:
            # If too long, keep the beginning part with the instruction template and trim the middle
            instruction_tokens = self._tokenize(self.prompt_template)
            # Keep the instruction and the most recent conversation that will fit
            keep_from_beginning = len(instruction_tokens)
            keep_from_end = self.block_size - keep_from_beginning

            # Combine beginning and end, ensuring we don't exceed array bounds
            # (<= 0 also guards the -0 slice, which would keep the whole tensor)
            if keep_from_end <= 0:
                # If the instruction alone is too long, trim it (shouldn't happen with reasonable templates)
                input_ids = input_ids[:, :self.block_size]
            else:
                # Keep the instruction and the most recent conversation
                input_ids = torch.cat([
                    input_ids[:, :keep_from_beginning],
                    input_ids[:, -keep_from_end:]
                ], dim=1)

        # Track generation start time
        start_time = time.time()

        # Always use live generation
        return self._generate_live_response(input_ids, user_input, start_time)

    def _generate_live_response(self, input_ids, user_input, start_time):
        """Generate a response with live token-by-token output."""
        # Initialize for live generation
        live_text = ""
        tokens_generated = 0
        retry_count = 0
        generation_failed = False

        # Keep trying until we get a valid response or exhaust the retries
        while retry_count <= self.max_retries:
            if retry_count > 0:
                # Calculate the temperature for this retry
                if retry_count % 2 == 0:
                    # Even retries: increase temperature
                    temp_adjustment = min(0.2 * (retry_count // 2), 0.8)
                    current_temp = min(self.config.temperature + temp_adjustment, 1.8)
                else:
                    # Odd retries: decrease temperature
                    temp_adjustment = min(0.2 * ((retry_count + 1) // 2), 0.4)
                    current_temp = max(self.config.temperature - temp_adjustment, 0.2)

                if self.debug_mode:
                    print(colored(f"\n[Live retry {retry_count}: Using temperature {current_temp:.2f}]", "yellow"))
            else:
                current_temp = self.config.temperature

            # Reset for this attempt
            live_text = ""
            tokens_generated = 0
            generation_failed = False

            # Try to generate with the current settings
            try:
                # Generate with live output
                for token_text, live_buffer, should_stop in self.generate_with_repetition_penalty(
                    input_ids,
                    max_new_tokens=self.config.max_new_tokens,
                    temperature=current_temp,
                    top_k=self.config.top_k,
                    penalty=self.repetition_penalty,
                    live=True
                ):
                    # If we should stop, this is the last iteration
                    if should_stop:
                        # Update with the final clean buffer (EOT removed if present)
                        live_text = live_buffer
                        break

                    # Otherwise add the token and continue
                    if token_text:
                        live_text += token_text
                        tokens_generated += 1
                        yield token_text, live_text, False

                # Check if we got a valid response
                if not live_text or len(live_text.strip()) < 10:
                    if self.debug_mode:
                        print(colored("\n[Live generation produced an empty or too-short response, retrying]", "yellow"))
                    generation_failed = True
                    retry_count += 1
                    # Clear any partial output
                    if retry_count <= self.max_retries:
                        print("\r" + " " * 80 + "\r", end="")  # Clear the line
                else:
                    # We got a valid response, stop retrying
                    break

            except Exception as e:
                if self.debug_mode:
                    print(colored(f"\n[Live generation error: {str(e)}, retrying]", "red"))
                generation_failed = True
                retry_count += 1

        # If we still failed after all retries, use the failure message
        if generation_failed or not live_text or len(live_text.strip()) < 10:
            live_text = self.generation_failure_message
            if self.debug_mode:
                print(colored(f"\n[Returning failure message after {retry_count} live retries]", "red"))

        # Calculate time taken and metrics
        time_taken = time.time() - start_time
        tokens_per_second = tokens_generated / time_taken if time_taken > 0 else 0

        # Update history
        self._update_history(user_input, live_text)

        # Log generation stats
        logger.debug(f"Generated {tokens_generated} tokens in {time_taken:.2f}s ({tokens_per_second:.2f} tokens/s)")

        # Final yield of the complete response
        yield "", live_text, True

    def execute_command(self, command):
        """Execute a special command prefixed with /."""
        command = command.strip()

        if command == '/help':
            self._print_welcome_message()
            return True

        elif command == '/clear':
            self.history = []
            self.history_tokens = []
            print(colored("Conversation history cleared.", 'yellow'))
            return True

        elif command in ['/exit', '/quit']:
            print(colored("Goodbye!", 'cyan'))
            return False  # Signal to exit the chat loop

        elif command == '/stats':
            prompt_tokens = self.total_prompt_tokens
            generated_tokens = self.total_generated_tokens
            total_tokens = prompt_tokens + generated_tokens

            stats = f"""
Token usage statistics:
- Prompt tokens: {prompt_tokens}
- Generated tokens: {generated_tokens}
- Total tokens: {total_tokens}
- Current history length: {len(self.history_tokens)} tokens
- Current repetition penalty: {self.repetition_penalty}
- Current temperature: {self.config.temperature}
- Model: CosmicFish ({self.model.get_num_params() / 1e6:.1f}M parameters)
- Source: {DEFAULT_MODEL_REPO}
"""
            print(colored(stats, 'yellow'))
            return True

        elif command == '/debug':
            self.debug_mode = not self.debug_mode
            self.config.debug_mode = self.debug_mode  # Sync with config
            mode = "enabled" if self.debug_mode else "disabled"
            print(colored(f"Debug mode {mode}", 'yellow'))
            return True

        elif command.startswith('/penalty '):
            try:
                penalty = float(command[9:].strip())
                if 1.0 <= penalty <= 2.0:
                    self.repetition_penalty = penalty
                    print(colored(f"Repetition penalty set to {penalty}", 'yellow'))
                else:
                    print(colored("Repetition penalty should be between 1.0 and 2.0", 'red'))
            except ValueError:
                print(colored("Invalid repetition penalty value. Please use a number between 1.0 and 2.0", 'red'))
            return True

        elif command.startswith('/temp '):
            try:
                temp = float(command[6:].strip())
                if 0.1 <= temp <= 2.0:
                    self.config.temperature = temp
                    print(colored(f"Temperature set to {temp}", 'yellow'))
                else:
                    print(colored("Temperature should be between 0.1 and 2.0", 'red'))
            except ValueError:
                print(colored("Invalid temperature value. Please use a number between 0.1 and 2.0", 'red'))
            return True

        elif command.startswith('/save '):
            filename = command[6:].strip()
            if not filename:
                print(colored("Please specify a filename: /save <filename>", 'red'))
                return True

            try:
                # Create the conversations directory if it doesn't exist
                os.makedirs('conversations', exist_ok=True)

                # Add .txt extension if not present
                if not filename.endswith('.txt'):
                    filename += '.txt'

                filepath = os.path.join('conversations', filename)

                with open(filepath, 'w', encoding='utf-8') as f:
                    for entry in self.history:
                        role, text = entry
                        prefix = self.human_prefix if role == "human" else self.assistant_prefix
                        f.write(f"{prefix}{text}{self.end_of_turn}")

                print(colored(f"Conversation saved to {filepath}", 'green'))

            except Exception as e:
                print(colored(f"Error saving conversation: {str(e)}", 'red'))

            return True

        elif command.startswith('/load '):
            filename = command[6:].strip()
            if not filename:
                print(colored("Please specify a filename: /load <filename>", 'red'))
                return True

            try:
                # Add .txt extension if not present
                if not filename.endswith('.txt'):
                    filename += '.txt'

                filepath = os.path.join('conversations', filename)

                if not os.path.exists(filepath):
                    print(colored(f"File not found: {filepath}", 'red'))
                    return True

                with open(filepath, 'r', encoding='utf-8') as f:
                    content = f.read()

                # Parse conversation turns
                self.history = []
                self.history_tokens = []

                # Split by the end-of-turn marker
                turns = content.split(self.end_of_turn)
                for turn in turns:
                    turn = turn.strip()
                    if not turn:
                        continue

                    if turn.startswith(self.human_prefix):
                        text = turn[len(self.human_prefix):].strip()
                        self.history.append(("human", text))
                    elif turn.startswith(self.assistant_prefix):
                        text = turn[len(self.assistant_prefix):].strip()
                        self.history.append(("assistant", text))

                # Recalculate token counts
                self.history_tokens = []
                for entry in self.history:
                    role, text = entry
                    if role == "human":
                        self.history_tokens.extend(self._tokenize(f"{self.human_prefix}{text}{self.end_of_turn}"))
                    else:
                        self.history_tokens.extend(self._tokenize(f"{self.assistant_prefix}{text}{self.end_of_turn}"))

                print(colored(f"Loaded conversation from {filepath} ({len(self.history) // 2} turns)", 'green'))

                # Print the conversation
                for i in range(0, len(self.history), 2):
                    user_text = self.history[i][1]
                    print(colored(f"\nYou: {user_text}", 'green'))

                    if i + 1 < len(self.history):
                        assistant_text = self.history[i + 1][1]
                        print(colored("CosmicFish: ", 'blue'), end="")
                        for line in assistant_text.split('\n'):
                            wrapped_lines = textwrap.wrap(line, width=100) if line.strip() else ['']
                            for wrapped_line in wrapped_lines:
                                print(wrapped_line)

            except Exception as e:
                print(colored(f"Error loading conversation: {str(e)}", 'red'))

            return True

        else:
            print(colored(f"Unknown command: {command}. Type /help for available commands.", 'red'))
            return True


def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
    """Download and load the CosmicFish model from the Hugging Face Hub."""
    print(colored(f"🤗 Downloading CosmicFish from Hugging Face Hub: {model_repo}", "cyan"))

    try:
        # Download the model files to the local cache
        print("📥 Downloading model files...")
        cache_dir = snapshot_download(repo_id=model_repo, cache_dir=None)
        print(f"✅ Model cached at: {cache_dir}")

        # Load config
        config_path = os.path.join(cache_dir, "config.json")
        with open(config_path, "r") as f:
            config_dict = json.load(f)

        # Create CosmicConfig
        config = CosmicConfig(
            vocab_size=config_dict["vocab_size"],
            block_size=config_dict["block_size"],
            n_layer=config_dict["n_layer"],
            n_head=config_dict["n_head"],
            n_embd=config_dict["n_embd"],
            bias=config_dict["bias"],
            dropout=0.0,  # Set to 0 for inference
            eps=config_dict.get("eps", 1e-6),
            use_rotary=config_dict["use_rotary"],
            use_swiglu=config_dict["use_swiglu"],
            use_gqa=config_dict["use_gqa"],
            n_query_groups=config_dict["n_query_groups"],
            use_qk_norm=config_dict.get("use_qk_norm", False)
        )

        # Create model
        print("🧠 Creating model...")
        model = CosmicFish(config)

        # Load weights
        print("⚖️ Loading weights...")
        weights_path = os.path.join(cache_dir, "pytorch_model.bin")
        state_dict = torch.load(weights_path, map_location=device)
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()

        print(f"✅ Model loaded: {model.get_num_params() / 1e6:.1f}M parameters")
        print(f"🎯 Device: {device}")
        return model, config

    except Exception as e:
        print(colored(f"❌ Error downloading/loading model: {str(e)}", "red"))
        print(colored("💡 Make sure you have an internet connection and that the model repo exists", "yellow"))
        sys.exit(1)


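# Note: the loader above assumes the repo ships a plain `config.json` plus a
# `pytorch_model.bin` state dict whose keys match this module's layer names;
# repos using safetensors or a different layout would need the paths adjusted.

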
def load_tokenizer():
    """Load the GPT-2 tokenizer."""
    print("🔤 Loading GPT-2 tokenizer...")
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    print("✅ Tokenizer loaded")
    return tokenizer


def main():
    parser = argparse.ArgumentParser(description="Chat with the CosmicFish model from Hugging Face Hub")

    # Model parameters
    parser.add_argument("--model_repo", type=str, default=DEFAULT_MODEL_REPO,
                        help=f"Hugging Face model repository (default: {DEFAULT_MODEL_REPO})")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to use (cuda or cpu)")

    # Generation parameters
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="Temperature for sampling (default: 0.7)")
    parser.add_argument("--max_tokens", type=int, default=512,
                        help="Maximum number of tokens to generate per response")
    parser.add_argument("--min_tokens", type=int, default=10,
                        help="Minimum number of tokens to generate per response")
    parser.add_argument("--top_k", type=int, default=40,
                        help="Top-k sampling (0 to disable)")
    parser.add_argument("--repetition_penalty", type=float, default=1.2,
                        help="Repetition penalty (1.0 = no penalty, 1.2 = mild, 1.5 = moderate)")

    # Chat parameters
    parser.add_argument("--human_prefix", type=str, default="Human: ",
                        help="Prefix for human messages")
    parser.add_argument("--assistant_prefix", type=str, default="Assistant: ",
                        help="Prefix for assistant messages")
    parser.add_argument("--end_of_turn", type=str, default="\n\n",
                        help="Delimiter between conversation turns")
    parser.add_argument("--instruction", type=str,
                        default=DEFAULT_PROMPT_TEMPLATE,
                        help="Instruction prompt to prepend to the conversation")
    parser.add_argument("--max_history", type=int, default=1024,
                        help="Maximum number of tokens to keep in history")

    # UI parameters
    parser.add_argument("--no_welcome", action="store_true",
                        help="Don't display the welcome message")
    parser.add_argument("--debug", action="store_true",
                        help="Enable debug mode")

    args = parser.parse_args()

    # Configure device
    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print(colored("⚠️ CUDA is not available, falling back to CPU", "yellow"))
        device = "cpu"

    try:
        # Download and load the model from HF Hub
        model, model_config = download_cosmicfish_from_hub(args.model_repo, device)

        # Load tokenizer
        tokenizer = load_tokenizer()

        # Create a config object with all the necessary parameters
        class ChatConfig:
            def __init__(self, args, block_size):
                self.device = device
                self.temperature = args.temperature
                self.max_new_tokens = args.max_tokens
                self.min_tokens_to_generate = args.min_tokens
                self.top_k = args.top_k if args.top_k > 0 else None  # 0 disables top-k
                self.human_prefix = args.human_prefix
                self.assistant_prefix = args.assistant_prefix
                self.end_of_turn = args.end_of_turn
                self.prompt_template = args.instruction
                self.max_history_tokens = args.max_history
                self.display_welcome = not args.no_welcome
                self.block_size = block_size
                self.debug_mode = args.debug
                self.repetition_penalty = args.repetition_penalty

        config = ChatConfig(args, model_config.block_size)

        # Initialize the chat session
        chat = CosmicFishChatSession(model, tokenizer, config)

        # Main chat loop
        print(colored("\n🚀 CosmicFish initialized from Hugging Face Hub. Type your message (or /help for commands).\n", 'cyan'))

        while True:
            try:
                # Get user input
                user_input = input(colored("You: ", 'green'))

                # Check if it's a command
                if user_input.startswith('/'):
                    # Execute the command; continue the loop if True, exit if False
                    if not chat.execute_command(user_input):
                        break
                    continue

                # Skip empty input
                if not user_input.strip():
                    continue

                # Generate a response using live generation
                live_buffer = ""
                final_response = None

                # Use the generator version
                response_generator = chat.generate_response(user_input)

                try:
                    # First print the assistant prefix
                    print(colored("CosmicFish: ", 'blue'), end="")
                    sys.stdout.flush()

                    for token, live_text, is_done in response_generator:
                        # If this is the final clean response
                        if is_done:
                            final_response = live_text
                            # Print the final response directly if we didn't get any tokens yet
                            if not live_buffer:
                                print(final_response, end="")
                            break

                        # If we have a token to display
                        if token:
                            # Strip any <|endoftext|> marker and stop if one was present
                            if "<|endoftext|>" in token:
                                token = token.replace("<|endoftext|>", "")
                                if token:  # Only print if there's anything left
                                    print(token, end="", flush=True)
                                break

                            # Display it
                            print(token, end="", flush=True)
                            live_buffer += token

                except KeyboardInterrupt:
                    # Allow the user to interrupt generation
                    print("\n[Generation interrupted]")
                    final_response = "I was going to respond, but I'll stop here since you interrupted."

                # Add an extra line for readability
                print()

            except KeyboardInterrupt:
                print("\n\nKeyboard interrupt detected. Type /exit to quit or continue chatting.")

            except Exception as e:
                print(colored(f"\nError: {str(e)}", 'red'))
                logger.error(f"Error in chat loop: {str(e)}", exc_info=True)

    except Exception as e:
        print(colored(f"Error setting up chat: {str(e)}", 'red'))
        logger.error(f"Error setting up chat: {str(e)}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        logger.error(f"Fatal error: {str(e)}", exc_info=True)
        sys.exit(1)
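
# Example invocation (illustrative; all flags shown are defined in main()):
#   python chat1.py --temperature 0.8 --top_k 40 --repetition_penalty 1.3 --debug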