akkiisfrommars committed on
Commit 44b04f6 · verified · 1 Parent(s): bae66f3

Upload chat.py

Files changed (1)
  1. chat.py +382 -84
chat.py CHANGED
@@ -1,8 +1,3 @@
- """
- Chat interface for the released CosmicFish model from Hugging Face.
- Compatible with the HF-format release while maintaining all original features.
- """
-
  import os
  import sys
  import time
@@ -11,30 +6,23 @@ import torch
  import numpy as np
  from termcolor import colored
  import logging
- import readline  # Enables arrow key history in terminal input
+ import readline
  import re
  import textwrap
  import random
  from collections import defaultdict
  import json
 
- # Try to import from transformers, fallback to local
+ # Required imports for HF Hub
  try:
      from transformers import GPT2Tokenizer
+     from huggingface_hub import hf_hub_download, snapshot_download
      HF_AVAILABLE = True
  except ImportError:
      HF_AVAILABLE = False
-     print("Transformers not available. Install with: pip install transformers")
-
- # Import the model classes - try both locations
- try:
-     from modeling_cosmicfish import CosmicFish, CosmicConfig
- except ImportError:
-     try:
-         from model import CosmicFish, CosmicConfig
-     except ImportError:
-         print("❌ CosmicFish model classes not found. Make sure modeling_cosmicfish.py or model.py is available.")
-         sys.exit(1)
+     print("Required libraries not available.")
+     print("Install with: pip install transformers huggingface-hub")
+     sys.exit(1)
 
  # Set up logging
  logging.basicConfig(
@@ -44,10 +32,299 @@ logging.basicConfig(
  )
  logger = logging.getLogger(__name__)
 
+ # Default model repository
+ DEFAULT_MODEL_REPO = "MistyozAI/CosmicFish-120M"
+
  # Default prompt template
  DEFAULT_PROMPT_TEMPLATE = "Below is a conversation between a helpful AI assistant and a human. The assistant is knowledgeable, friendly, and provides detailed and accurate responses.\n\n"
 
 
+ class CosmicConfig:
+     """Configuration class for CosmicFish."""
+
+     def __init__(self,
+                  vocab_size=50257,
+                  block_size=512,
+                  n_layer=12,
+                  n_head=16,
+                  n_embd=704,
+                  bias=True,
+                  dropout=0.0,
+                  n_query_groups=4,
+                  eps=1e-6,
+                  use_rotary=True,
+                  use_swiglu=True,
+                  use_qk_norm=False,
+                  use_gqa=True):
+         self.vocab_size = vocab_size
+         self.block_size = block_size
+         self.n_layer = n_layer
+         self.n_head = n_head
+         self.n_embd = n_embd
+         self.bias = bias
+         self.dropout = dropout
+         self.eps = eps
+         self.use_rotary = use_rotary
+         self.use_swiglu = use_swiglu
+         self.use_qk_norm = use_qk_norm
+         self.use_gqa = use_gqa
+         self.n_query_groups = n_query_groups if use_gqa else n_head
+         # Ensure n_head is divisible by n_query_groups
+         assert n_head % self.n_query_groups == 0, "n_head must be divisible by n_query_groups"
+
+
+ class RMSNorm(torch.nn.Module):
+     """Root Mean Square Normalization"""
+
+     def __init__(self, dim, eps=1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = torch.nn.Parameter(torch.ones(dim))
+
+     def forward(self, x):
+         rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+         return self.weight * (x / rms)
+
+
+ def precompute_freqs_cis(dim, end, theta=10000.0):
+     """Precompute the frequency tensor for complex exponentials (cis)"""
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+     t = torch.arange(end, device=freqs.device)
+     freqs = torch.outer(t, freqs)
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+     return freqs_cis
+
+
+ def apply_rotary_emb(xq, xk, freqs_cis):
+     """Apply rotary embeddings to input tensors"""
+     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+
+     seq_len = xq_.size(2)
+     if freqs_cis.size(0) < seq_len:
+         raise ValueError(f"freqs_cis has only {freqs_cis.size(0)} values but sequence length is {seq_len}")
+
+     freqs_cis_seq = freqs_cis[:seq_len]
+     xq_out = torch.view_as_real(xq_ * freqs_cis_seq.unsqueeze(0)).flatten(3)
+     xk_out = torch.view_as_real(xk_ * freqs_cis_seq.unsqueeze(0)).flatten(3)
+
+     return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+ class GroupedQueryAttention(torch.nn.Module):
+     """Grouped Query Attention (GQA) implementation"""
+
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+
+         head_dim = config.n_embd // config.n_head
+         self.head_dim = head_dim
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.n_query_groups = config.n_query_groups
+
+         self.kv_heads = config.n_head // config.n_query_groups if config.use_gqa else config.n_head
+         qkv_proj_size = (config.n_head + 2 * self.kv_heads) * head_dim
+
+         self.c_attn = torch.nn.Linear(config.n_embd, qkv_proj_size, bias=config.bias)
+         self.c_proj = torch.nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+
+         # Flash attention support
+         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+         if not self.flash:
+             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
+                                  .view(1, 1, config.block_size, config.block_size))
+
+         # Query-key normalization
+         self.qk_norm = getattr(config, 'use_qk_norm', False)
+         if self.qk_norm:
+             self.q_norm = RMSNorm(head_dim, eps=getattr(config, 'eps', 1e-6))
+             self.k_norm = RMSNorm(head_dim, eps=getattr(config, 'eps', 1e-6))
+
+     def forward(self, x, freqs_cis=None):
+         B, T, C = x.size()
+         qkv = self.c_attn(x)
+         head_dim = C // self.n_head
+
+         q_size = self.n_head * head_dim
+         k_size = self.kv_heads * head_dim
+         v_size = self.kv_heads * head_dim
+
+         q, k, v = qkv.split([q_size, k_size, v_size], dim=2)
+
+         q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
+         k = k.view(B, T, self.kv_heads, head_dim).transpose(1, 2)
+         v = v.view(B, T, self.kv_heads, head_dim).transpose(1, 2)
+
+         # Repeat k and v if needed for GQA
+         if self.kv_heads < self.n_head:
+             repeats = self.n_head // self.kv_heads
+             k = k.repeat_interleave(repeats, dim=1)
+             v = v.repeat_interleave(repeats, dim=1)
+
+         # Apply rotary embeddings
+         if freqs_cis is not None:
+             q, k = apply_rotary_emb(q, k, freqs_cis)
+
+         # Apply query-key normalization
+         if self.qk_norm:
+             q = self.q_norm(q)
+             k = self.k_norm(k)
+
+         # Compute attention
+         if self.flash:
+             y = torch.nn.functional.scaled_dot_product_attention(
+                 q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
+             )
+         else:
+             att = (q @ k.transpose(-2, -1)) * (1.0 / torch.sqrt(torch.tensor(k.size(-1), dtype=torch.float32)))
+             att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
+             att = torch.nn.functional.softmax(att, dim=-1)
+             y = att @ v
+
+         y = y.transpose(1, 2).contiguous().view(B, T, C)
+         y = self.c_proj(y)
+         return y
+
+
+ class Block(torch.nn.Module):
+     """Transformer block"""
+
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = RMSNorm(config.n_embd, eps=config.eps)
+         self.ln_2 = RMSNorm(config.n_embd, eps=config.eps)
+         self.attn = GroupedQueryAttention(config)
+
+         # MLP implementation based on configuration
+         if config.use_swiglu:
+             # SwiGLU MLP
+             self.mlp = torch.nn.ModuleDict(dict(
+                 gate=torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
+                 up=torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
+                 down=torch.nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias),
+                 act=torch.nn.SiLU(),
+             ))
+             m = self.mlp
+             self.mlpf = lambda x: m.down(m.act(m.up(x)) * m.gate(x))
+         else:
+             # Traditional MLP
+             self.mlp = torch.nn.ModuleDict(dict(
+                 c_fc=torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
+                 c_proj=torch.nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias),
+                 act=torch.nn.GELU(),
+             ))
+             m = self.mlp
+             self.mlpf = lambda x: m.c_proj(m.act(m.c_fc(x)))
+
+     def forward(self, x, freqs_cis=None):
+         x = x + self.attn(self.ln_1(x), freqs_cis)
+         x = x + self.mlpf(self.ln_2(x))
+         return x
+
+
+ class CosmicFish(torch.nn.Module):
+     """
+     CosmicFish model for inference only.
+     Features: Rotary Positional Embeddings, Grouped-Query Attention, SwiGLU, RMSNorm
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         self.transformer = torch.nn.ModuleDict(dict(
+             wte=torch.nn.Embedding(config.vocab_size, config.n_embd),
+             h=torch.nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+             ln_f=RMSNorm(config.n_embd, eps=config.eps),
+         ))
+
+         self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         # Share weights between embedding and output
+         self.transformer.wte.weight = self.lm_head.weight
+
+         # Precompute rotary embedding frequencies
+         if config.use_rotary:
+             head_dim = config.n_embd // config.n_head
+             self.freqs_cis = precompute_freqs_cis(head_dim, config.block_size)
+         else:
+             self.freqs_cis = None
+             self.transformer.wpe = torch.nn.Embedding(config.block_size, config.n_embd)
+
+     def get_num_params(self, non_embedding=True):
+         """Return the number of parameters in the model."""
+         n_params = sum(p.numel() for p in self.parameters())
+         if non_embedding and hasattr(self.transformer, 'wpe'):
+             n_params -= self.transformer.wpe.weight.numel()
+         return n_params
+
+     def forward(self, idx, targets=None):
+         """Forward pass through the model."""
+         device = idx.device
+         b, t = idx.size()
+         assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+
+         # Get token embeddings
+         tok_emb = self.transformer.wte(idx)
+
+         # Handle positional embeddings
+         if self.config.use_rotary:
+             x = tok_emb
+             freqs_cis = self.freqs_cis.to(device) if self.freqs_cis is not None else None
+         else:
+             pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)
+             pos_emb = self.transformer.wpe(pos)
+             x = tok_emb + pos_emb
+             freqs_cis = None
+
+         # Apply transformer blocks
+         for block in self.transformer.h:
+             x = block(x, freqs_cis)
+
+         # Apply final normalization
+         x = self.transformer.ln_f(x)
+
+         # Calculate outputs
+         if targets is not None:
+             logits = self.lm_head(x)
+             loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+         else:
+             # For inference, only compute logits for the last token
+             logits = self.lm_head(x[:, [-1], :])
+             loss = None
+
+         return logits, loss
+
+     @torch.no_grad()
+     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+         """
+         Generate text by sampling from the model, token by token.
+         """
+         for _ in range(max_new_tokens):
+             # Crop sequence to block size if needed
+             idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
+
+             # Forward pass
+             logits, _ = self(idx_cond)
+             logits = logits[:, -1, :] / temperature
+
+             # Apply top-k sampling
+             if top_k is not None:
+                 v, _ = torch.topk(logits, top_k)
+                 logits[logits < v[:, [-1]]] = -float('Inf')
+
+             # Sample next token
+             probs = torch.nn.functional.softmax(logits, dim=-1)
+             idx_next = torch.multinomial(probs, num_samples=1)
+
+             # Append to sequence
+             idx = torch.cat((idx, idx_next), dim=1)
+
+         return idx
+
+
  class RepetitionPenaltyLogitsProcessor:
      """Apply repetition penalty to prevent repeating tokens."""
 
@@ -64,7 +341,7 @@ class RepetitionPenaltyLogitsProcessor:
 
 
  class CosmicFishChatSession:
-     """Chat session for the released CosmicFish model."""
+     """Chat session for CosmicFish model from Hugging Face Hub."""
 
      def __init__(self, model, tokenizer, config):
          """Initialize chat session with model and configuration."""
@@ -123,11 +400,17 @@ class CosmicFishChatSession:
          """Print a welcome message to the user."""
          welcome_text = f"""
  {'=' * 80}
- Welcome to CosmicFish chat interface (Hugging Face Release)
+ Welcome to CosmicFish!
 
- This is a {self.model.get_num_params() / 1e6:.1f}M parameter model.
+ This is a {self.model.get_num_params() / 1e6:.1f}M parameter model made by MistyozAI.
  CosmicFish features advanced architecture with RoPE, GQA, SwiGLU, and RMSNorm.
 
+ ⚠️ DISCLAIMER: Since this {self.model.get_num_params() / 1e6:.1f}M parameter model is relatively
+ small, it is more likely to give incorrect answers or hallucinate compared to
+ larger models. Please verify important information from reliable sources.
+
+ Model: {DEFAULT_MODEL_REPO}
+
  Type your prompts and CosmicFish will respond.
 
  Special commands:
@@ -211,9 +494,19 @@ Special commands:
          return False
 
      def _clean_token_text(self, text):
-         """Clean token text by fixing encoding issues."""
-         # Fix the specific issue with �� -> '
+
          text = text.replace('��', "'")
+
+         text = text.replace('�', "'")
+         text = text.replace('\ufffd', "'")
+         text = text.replace('\uFFFD', "'")
+
+         text = text.replace('’', "'")
+         text = text.replace('“', "'")
+         text = text.replace('�', "'")
+         text = text.replace('â€"', "'")
+         text = text.replace('â€"', "'")
+
          return text
 
      def generate_with_repetition_penalty(self, input_ids, max_new_tokens, temperature, top_k, penalty=1.2, live=False):
@@ -478,6 +771,7 @@ Token usage statistics:
  - Current repetition penalty: {self.repetition_penalty}
  - Current temperature: {self.config.temperature}
  - Model: CosmicFish ({self.model.get_num_params() / 1e6:.1f}M parameters)
+ - Source: {DEFAULT_MODEL_REPO}
  """
          print(colored(stats, 'yellow'))
          return True
@@ -615,76 +909,80 @@ Token usage statistics:
          return True
 
 
- def load_cosmicfish_model(model_dir, device='cpu'):
-     """Load CosmicFish model from HF-format directory"""
-     print(f"Loading CosmicFish model from {model_dir}...")
-
-     # Load config
-     config_path = os.path.join(model_dir, "config.json")
-     if not os.path.exists(config_path):
-         raise FileNotFoundError(f"config.json not found in {model_dir}")
-
-     with open(config_path, "r") as f:
-         config_dict = json.load(f)
-
-     # Create CosmicConfig
-     config = CosmicConfig(
-         vocab_size=config_dict["vocab_size"],
-         block_size=config_dict["block_size"],
-         n_layer=config_dict["n_layer"],
-         n_head=config_dict["n_head"],
-         n_embd=config_dict["n_embd"],
-         bias=config_dict["bias"],
-         dropout=0.0,  # Set to 0 for inference
-         eps=config_dict.get("eps", 1e-6),
-         use_rotary=config_dict["use_rotary"],
-         use_swiglu=config_dict["use_swiglu"],
-         use_gqa=config_dict["use_gqa"],
-         n_query_groups=config_dict["n_query_groups"],
-         use_qk_norm=config_dict.get("use_qk_norm", False)
-     )
-
-     # Create model
-     model = CosmicFish(config)
-
-     # Load weights
-     weights_path = os.path.join(model_dir, "pytorch_model.bin")
-     if not os.path.exists(weights_path):
-         raise FileNotFoundError(f"pytorch_model.bin not found in {model_dir}")
-
-     state_dict = torch.load(weights_path, map_location=device)
-     model.load_state_dict(state_dict)
-     model.to(device)
-     model.eval()
-
-     print(f"✅ Model loaded: {model.get_num_params() / 1e6:.1f}M parameters")
-     return model, config
-
-
- def load_tokenizer():
-     """Load GPT-2 tokenizer"""
-     if not HF_AVAILABLE:
-         raise ImportError("transformers library required. Install with: pip install transformers")
-
-     print("Loading GPT-2 tokenizer...")
+ def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
+     """Download and load CosmicFish model from Hugging Face Hub"""
+     print(colored(f"Downloading CosmicFish from Hugging Face: {model_repo}", "cyan"))
+
+     try:
+         # Download the model files to local cache
+         print("Downloading model files...")
+         cache_dir = snapshot_download(repo_id=model_repo, cache_dir=None)
+         print(f"Model cached at: {cache_dir}")
+
+         # Load config
+         config_path = os.path.join(cache_dir, "config.json")
+         with open(config_path, "r") as f:
+             config_dict = json.load(f)
+
+         # Create CosmicConfig
+         config = CosmicConfig(
+             vocab_size=config_dict["vocab_size"],
+             block_size=config_dict["block_size"],
+             n_layer=config_dict["n_layer"],
+             n_head=config_dict["n_head"],
+             n_embd=config_dict["n_embd"],
+             bias=config_dict["bias"],
+             dropout=0.0,  # Set to 0 for inference
+             eps=config_dict.get("eps", 1e-6),
+             use_rotary=config_dict["use_rotary"],
+             use_swiglu=config_dict["use_swiglu"],
+             use_gqa=config_dict["use_gqa"],
+             n_query_groups=config_dict["n_query_groups"],
+             use_qk_norm=config_dict.get("use_qk_norm", False)
+         )
+
+         # Create model
+         print("Creating model...")
+         model = CosmicFish(config)
+
+         # Load weights
+         print("Loading weights...")
+         weights_path = os.path.join(cache_dir, "pytorch_model.bin")
+         state_dict = torch.load(weights_path, map_location=device)
+         model.load_state_dict(state_dict)
+         model.to(device)
+         model.eval()
+
+         print(f"Model loaded: {model.get_num_params() / 1e6:.1f}M parameters")
+         print(f"Device: {device}")
+         return model, config
+
+     except Exception as e:
+         print(colored(f"Error downloading/loading model: {str(e)}", "red"))
+         print(colored("Make sure you have internet connection and the model repo exists", "yellow"))
+         sys.exit(1)
+
+
+ def load_tokenizer():
+     print("Loading tokenizer...")
      tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-     print("Tokenizer loaded")
+     print("Tokenizer loaded")
      return tokenizer
 
 
  def main():
-     parser = argparse.ArgumentParser(description="Chat with the released CosmicFish model")
+     parser = argparse.ArgumentParser(description="Chat with CosmicFish model from Hugging Face Hub")
 
      # Model parameters
-     parser.add_argument("--model_dir", type=str, default="./cosmicfish-hf-release",
-                         help="Path to the HF-format model directory")
+     parser.add_argument("--model_repo", type=str, default=DEFAULT_MODEL_REPO,
+                         help=f"Hugging Face model repository (default: {DEFAULT_MODEL_REPO})")
      parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                          help="Device to use (cuda or cpu)")
 
      # Generation parameters
-     parser.add_argument("--temperature", type=float, default=0.6,
+     parser.add_argument("--temperature", type=float, default=0.7,
                          help="Temperature for sampling (default: 0.7)")
-     parser.add_argument("--max_tokens", type=int, default=1024,
+     parser.add_argument("--max_tokens", type=int, default=512,
                          help="Maximum number of tokens to generate per response")
      parser.add_argument("--min_tokens", type=int, default=10,
                          help="Minimum number of tokens to generate per response")
@@ -717,12 +1015,12 @@ def main():
      # Configure device
      device = args.device
      if device == "cuda" and not torch.cuda.is_available():
-         print("CUDA is not available, falling back to CPU")
+         print(colored("CUDA is not available, falling back to CPU", "yellow"))
          device = "cpu"
 
      try:
-         # Load the model
-         model, model_config = load_cosmicfish_model(args.model_dir, device)
+         # Download and load the model from HF Hub
+         model, model_config = download_cosmicfish_from_hub(args.model_repo, device)
 
          # Load tokenizer
          tokenizer = load_tokenizer()
@@ -751,7 +1049,7 @@ def main():
          chat = CosmicFishChatSession(model, tokenizer, config)
 
          # Main chat loop
-         print(colored("\nCosmicFish initialized. Type your message (or /help for commands).\n", 'cyan'))
+         print(colored("\nCosmicFish initialized! Type your message (or /help for commands).\n", 'cyan'))
 
          while True:
              try:
@@ -819,8 +1117,8 @@ def main():
              logger.error(f"Error in chat loop: {str(e)}", exc_info=True)
 
      except Exception as e:
-         print(colored(f"Error loading model: {str(e)}", 'red'))
-         logger.error(f"Error loading model: {str(e)}", exc_info=True)
+         print(colored(f"Error setting up chat: {str(e)}", 'red'))
+         logger.error(f"Error setting up chat: {str(e)}", exc_info=True)
          sys.exit(1)
 
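For reference, a minimal sketch of how the Hub-loading path added in this commit might be driven outside the interactive loop. It assumes this chat.py is importable (i.e., its main() sits behind the usual __name__ == "__main__" guard, which is not shown in this diff) and uses only functions the commit defines; the prompt string is illustrative, not the script's exact template.

    # Hypothetical driver for the functions added in this commit (not part of chat.py).
    # CLI equivalent: python chat.py --model_repo MistyozAI/CosmicFish-120M --temperature 0.7
    import torch
    from chat import download_cosmicfish_from_hub, load_tokenizer

    # Fetch config + weights from the Hub and build the model on CPU.
    model, config = download_cosmicfish_from_hub("MistyozAI/CosmicFish-120M", device="cpu")
    tokenizer = load_tokenizer()

    # Encode a prompt, sample up to 64 new tokens, and decode the result.
    prompt = "Human: Hello, CosmicFish!\nAssistant:"  # illustrative format only
    input_ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long)
    output = model.generate(input_ids, max_new_tokens=64, temperature=0.7, top_k=50)
    print(tokenizer.decode(output[0].tolist()))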