akkiisfrommars committed
Commit bae66f3 · verified · 1 Parent(s): f555625

Upload 2 files

Files changed (2):
  1. chat.py +81 -384
  2. modeling_cosmicfish.py +0 -6
chat.py CHANGED
@@ -1,6 +1,6 @@
  """
- Chat interface for CosmicFish model directly from Hugging Face Hub.
- Automatically downloads and caches the model from HF Hub.
  """

  import os
@@ -11,23 +11,30 @@ import torch
  import numpy as np
  from termcolor import colored
  import logging
- import readline
  import re
  import textwrap
  import random
  from collections import defaultdict
  import json

- # Required imports for HF Hub
  try:
      from transformers import GPT2Tokenizer
-     from huggingface_hub import hf_hub_download, snapshot_download
      HF_AVAILABLE = True
  except ImportError:
      HF_AVAILABLE = False
-     print("Required libraries not available.")
-     print("Install with: pip install transformers huggingface-hub")
-     sys.exit(1)

  # Set up logging
  logging.basicConfig(
@@ -37,299 +44,10 @@ logging.basicConfig(
  )
  logger = logging.getLogger(__name__)

- # Default model repository
- DEFAULT_MODEL_REPO = "MistyozAI/CosmicFish-120M"
-
  # Default prompt template
  DEFAULT_PROMPT_TEMPLATE = "Below is a conversation between a helpful AI assistant and a human. The assistant is knowledgeable, friendly, and provides detailed and accurate responses.\n\n"

- class CosmicConfig:
-     """Configuration class for CosmicFish."""
-
-     def __init__(self,
-                  vocab_size=50257,
-                  block_size=512,
-                  n_layer=12,
-                  n_head=16,
-                  n_embd=704,
-                  bias=True,
-                  dropout=0.0,
-                  n_query_groups=4,
-                  eps=1e-6,
-                  use_rotary=True,
-                  use_swiglu=True,
-                  use_qk_norm=False,
-                  use_gqa=True):
-         self.vocab_size = vocab_size
-         self.block_size = block_size
-         self.n_layer = n_layer
-         self.n_head = n_head
-         self.n_embd = n_embd
-         self.bias = bias
-         self.dropout = dropout
-         self.eps = eps
-         self.use_rotary = use_rotary
-         self.use_swiglu = use_swiglu
-         self.use_qk_norm = use_qk_norm
-         self.use_gqa = use_gqa
-         self.n_query_groups = n_query_groups if use_gqa else n_head
-         # Ensure n_head is divisible by n_query_groups
-         assert n_head % self.n_query_groups == 0, "n_head must be divisible by n_query_groups"
-
-
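To make the GQA bookkeeping above concrete, here is a minimal standalone sketch of the quantities the default config implies; the same arithmetic reappears in GroupedQueryAttention below.

```python
# Sketch: head bookkeeping implied by the defaults above (n_head=16,
# n_query_groups=4, n_embd=704). Plain Python, no imports needed.
n_head, n_query_groups, n_embd = 16, 4, 704

assert n_head % n_query_groups == 0, "n_head must be divisible by n_query_groups"
head_dim = n_embd // n_head             # 44 dims per head
kv_heads = n_head // n_query_groups     # 4 K/V heads shared by 16 query heads
qkv_proj_size = (n_head + 2 * kv_heads) * head_dim  # fused QKV width: (16 + 8) * 44
print(head_dim, kv_heads, qkv_proj_size)  # 44 4 1056
```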
- class RMSNorm(torch.nn.Module):
-     """Root Mean Square Normalization"""
-
-     def __init__(self, dim, eps=1e-6):
-         super().__init__()
-         self.eps = eps
-         self.weight = torch.nn.Parameter(torch.ones(dim))
-
-     def forward(self, x):
-         rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
-         return self.weight * (x / rms)
-
-
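A quick standalone check of what the RMSNorm above computes (normalize by the root mean square of the features, then apply a learned per-dimension scale). This sketch assumes the RMSNorm class above is in scope.

```python
# Sketch: verifying RMSNorm against a manual computation.
import torch

norm = RMSNorm(dim=8, eps=1e-6)  # assumes the RMSNorm class above is in scope
x = torch.randn(2, 4, 8)
manual = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
assert torch.allclose(norm(x), manual)  # weight starts at ones, so outputs match
```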
- def precompute_freqs_cis(dim, end, theta=10000.0):
-     """Precompute the frequency tensor for complex exponentials (cis)"""
-     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-     t = torch.arange(end, device=freqs.device)
-     freqs = torch.outer(t, freqs)
-     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
-     return freqs_cis
-
-
- def apply_rotary_emb(xq, xk, freqs_cis):
-     """Apply rotary embeddings to input tensors"""
-     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-
-     seq_len = xq_.size(2)
-     if freqs_cis.size(0) < seq_len:
-         raise ValueError(f"freqs_cis has only {freqs_cis.size(0)} values but sequence length is {seq_len}")
-
-     freqs_cis_seq = freqs_cis[:seq_len]
-     xq_out = torch.view_as_real(xq_ * freqs_cis_seq.unsqueeze(0)).flatten(3)
-     xk_out = torch.view_as_real(xk_ * freqs_cis_seq.unsqueeze(0)).flatten(3)
-
-     return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
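A minimal shape-and-sanity sketch of the two rotary helpers above, assuming both functions are in scope; head_dim must be even for the complex reshape.

```python
# Sketch: q/k are (B, heads, T, head_dim), matching how GroupedQueryAttention
# calls apply_rotary_emb after its transpose.
import torch

head_dim, block_size = 44, 512
freqs_cis = precompute_freqs_cis(head_dim, block_size)  # (512, 22) complex
q = torch.randn(1, 16, 10, head_dim)
k = torch.randn(1, 16, 10, head_dim)
q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
# Each complex pair is multiplied by a unit-modulus number, so the overall
# norm is unchanged up to float rounding.
assert torch.allclose(q_rot.norm(), q.norm(), atol=1e-3)
```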
- class GroupedQueryAttention(torch.nn.Module):
-     """Grouped Query Attention (GQA) implementation"""
-
-     def __init__(self, config):
-         super().__init__()
-         assert config.n_embd % config.n_head == 0
-
-         head_dim = config.n_embd // config.n_head
-         self.head_dim = head_dim
-         self.n_head = config.n_head
-         self.n_embd = config.n_embd
-         self.n_query_groups = config.n_query_groups
-
-         self.kv_heads = config.n_head // config.n_query_groups if config.use_gqa else config.n_head
-         qkv_proj_size = (config.n_head + 2 * self.kv_heads) * head_dim
-
-         self.c_attn = torch.nn.Linear(config.n_embd, qkv_proj_size, bias=config.bias)
-         self.c_proj = torch.nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
-
-         # Flash attention support
-         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
-         if not self.flash:
-             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
-                                  .view(1, 1, config.block_size, config.block_size))
-
-         # Query-key normalization
-         self.qk_norm = getattr(config, 'use_qk_norm', False)
-         if self.qk_norm:
-             self.q_norm = RMSNorm(head_dim, eps=getattr(config, 'eps', 1e-6))
-             self.k_norm = RMSNorm(head_dim, eps=getattr(config, 'eps', 1e-6))
-
-     def forward(self, x, freqs_cis=None):
-         B, T, C = x.size()
-         qkv = self.c_attn(x)
-         head_dim = C // self.n_head
-
-         q_size = self.n_head * head_dim
-         k_size = self.kv_heads * head_dim
-         v_size = self.kv_heads * head_dim
-
-         q, k, v = qkv.split([q_size, k_size, v_size], dim=2)
-
-         q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
-         k = k.view(B, T, self.kv_heads, head_dim).transpose(1, 2)
-         v = v.view(B, T, self.kv_heads, head_dim).transpose(1, 2)
-
-         # Repeat k and v if needed for GQA
-         if self.kv_heads < self.n_head:
-             repeats = self.n_head // self.kv_heads
-             k = k.repeat_interleave(repeats, dim=1)
-             v = v.repeat_interleave(repeats, dim=1)
-
-         # Apply rotary embeddings
-         if freqs_cis is not None:
-             q, k = apply_rotary_emb(q, k, freqs_cis)
-
-         # Apply query-key normalization
-         if self.qk_norm:
-             q = self.q_norm(q)
-             k = self.k_norm(k)
-
-         # Compute attention
-         if self.flash:
-             y = torch.nn.functional.scaled_dot_product_attention(
-                 q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
-             )
-         else:
-             att = (q @ k.transpose(-2, -1)) * (1.0 / torch.sqrt(torch.tensor(k.size(-1), dtype=torch.float32)))
-             att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
-             att = torch.nn.functional.softmax(att, dim=-1)
-             y = att @ v
-
-         y = y.transpose(1, 2).contiguous().view(B, T, C)
-         y = self.c_proj(y)
-         return y
-
-
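The heart of the forward pass above is the asymmetric QKV split plus repeat_interleave, which lets several query heads share one K/V head. A standalone sketch on this model's concrete shapes:

```python
# Sketch: K/V sharing with n_embd=704, n_head=16, kv_heads=4, head_dim=44.
import torch

B, T, n_head, kv_heads, head_dim = 1, 10, 16, 4, 44
qkv = torch.randn(B, T, (n_head + 2 * kv_heads) * head_dim)  # fused projection, width 1056

q, k, v = qkv.split([n_head * head_dim, kv_heads * head_dim, kv_heads * head_dim], dim=2)
q = q.view(B, T, n_head, head_dim).transpose(1, 2)    # (1, 16, 10, 44)
k = k.view(B, T, kv_heads, head_dim).transpose(1, 2)  # (1,  4, 10, 44)

# Each group of 16/4 = 4 query heads shares one K/V head
k = k.repeat_interleave(n_head // kv_heads, dim=1)    # (1, 16, 10, 44)
assert k.shape == q.shape
```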
- class Block(torch.nn.Module):
-     """Transformer block"""
-
-     def __init__(self, config):
-         super().__init__()
-         self.ln_1 = RMSNorm(config.n_embd, eps=config.eps)
-         self.ln_2 = RMSNorm(config.n_embd, eps=config.eps)
-         self.attn = GroupedQueryAttention(config)
-
-         # MLP implementation based on configuration
-         if config.use_swiglu:
-             # SwiGLU MLP
-             self.mlp = torch.nn.ModuleDict(dict(
-                 gate=torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
-                 up=torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
-                 down=torch.nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias),
-                 act=torch.nn.SiLU(),
-             ))
-             m = self.mlp
-             self.mlpf = lambda x: m.down(m.act(m.up(x)) * m.gate(x))
-         else:
-             # Traditional MLP
-             self.mlp = torch.nn.ModuleDict(dict(
-                 c_fc=torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
-                 c_proj=torch.nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias),
-                 act=torch.nn.GELU(),
-             ))
-             m = self.mlp
-             self.mlpf = lambda x: m.c_proj(m.act(m.c_fc(x)))
-
-     def forward(self, x, freqs_cis=None):
-         x = x + self.attn(self.ln_1(x), freqs_cis)
-         x = x + self.mlpf(self.ln_2(x))
-         return x
-
-
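The SwiGLU branch above reduces to down(SiLU(up(x)) * gate(x)). A minimal standalone sketch of that composition:

```python
# Sketch: the SwiGLU MLP composition used in Block.mlpf, with tiny dimensions.
import torch

n_embd = 8
gate = torch.nn.Linear(n_embd, 4 * n_embd)
up = torch.nn.Linear(n_embd, 4 * n_embd)
down = torch.nn.Linear(4 * n_embd, n_embd)
act = torch.nn.SiLU()

x = torch.randn(2, 3, n_embd)
y = down(act(up(x)) * gate(x))  # same composition as Block.mlpf
assert y.shape == x.shape
```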
- class CosmicFish(torch.nn.Module):
-     """
-     CosmicFish model for inference only.
-     Features: Rotary Positional Embeddings, Grouped-Query Attention, SwiGLU, RMSNorm
-     """
-
-     def __init__(self, config):
-         super().__init__()
-         self.config = config
-
-         self.transformer = torch.nn.ModuleDict(dict(
-             wte=torch.nn.Embedding(config.vocab_size, config.n_embd),
-             h=torch.nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
-             ln_f=RMSNorm(config.n_embd, eps=config.eps),
-         ))
-
-         self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False)
-
-         # Share weights between embedding and output
-         self.transformer.wte.weight = self.lm_head.weight
-
-         # Precompute rotary embedding frequencies
-         if config.use_rotary:
-             head_dim = config.n_embd // config.n_head
-             self.freqs_cis = precompute_freqs_cis(head_dim, config.block_size)
-         else:
-             self.freqs_cis = None
-             self.transformer.wpe = torch.nn.Embedding(config.block_size, config.n_embd)
-
-     def get_num_params(self, non_embedding=True):
-         """Return the number of parameters in the model."""
-         n_params = sum(p.numel() for p in self.parameters())
-         if non_embedding and hasattr(self.transformer, 'wpe'):
-             n_params -= self.transformer.wpe.weight.numel()
-         return n_params
-
-     def forward(self, idx, targets=None):
-         """Forward pass through the model."""
-         device = idx.device
-         b, t = idx.size()
-         assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
-
-         # Get token embeddings
-         tok_emb = self.transformer.wte(idx)
-
-         # Handle positional embeddings
-         if self.config.use_rotary:
-             x = tok_emb
-             freqs_cis = self.freqs_cis.to(device) if self.freqs_cis is not None else None
-         else:
-             pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)
-             pos_emb = self.transformer.wpe(pos)
-             x = tok_emb + pos_emb
-             freqs_cis = None
-
-         # Apply transformer blocks
-         for block in self.transformer.h:
-             x = block(x, freqs_cis)
-
-         # Apply final normalization
-         x = self.transformer.ln_f(x)
-
-         # Calculate outputs
-         if targets is not None:
-             logits = self.lm_head(x)
-             loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
-         else:
-             # For inference, only compute logits for the last token
-             logits = self.lm_head(x[:, [-1], :])
-             loss = None
-
-         return logits, loss
-
-     @torch.no_grad()
-     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
-         """
-         Generate text by sampling from the model, token by token.
-         """
-         for _ in range(max_new_tokens):
-             # Crop sequence to block size if needed
-             idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
-
-             # Forward pass
-             logits, _ = self(idx_cond)
-             logits = logits[:, -1, :] / temperature
-
-             # Apply top-k sampling
-             if top_k is not None:
-                 v, _ = torch.topk(logits, top_k)
-                 logits[logits < v[:, [-1]]] = -float('Inf')
-
-             # Sample next token
-             probs = torch.nn.functional.softmax(logits, dim=-1)
-             idx_next = torch.multinomial(probs, num_samples=1)
-
-             # Append to sequence
-             idx = torch.cat((idx, idx_next), dim=1)
-
-         return idx
-
-
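Each step of generate() above applies temperature scaling, then top-k masking, then multinomial sampling. One step of that loop in isolation:

```python
# Sketch: one sampling step, isolated from the loop above.
import torch

logits = torch.randn(1, 50257)  # lm_head output for the last position
temperature, top_k = 0.7, 40

logits = logits / temperature
v, _ = torch.topk(logits, top_k)
logits[logits < v[:, [-1]]] = -float('Inf')  # mask everything below the k-th logit

probs = torch.nn.functional.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # (1, 1) token id
```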
  class RepetitionPenaltyLogitsProcessor:
      """Apply repetition penalty to prevent repeating tokens."""
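The body of this class is collapsed in the diff view. For orientation, a common formulation (the CTRL-style penalty also used by transformers) divides positive logits and multiplies negative logits of already-seen tokens; this is illustrative only, and the implementation here may differ:

```python
# Sketch: a common repetition-penalty formulation, not necessarily the exact
# body of the collapsed class above.
import torch

def apply_repetition_penalty(logits, generated_ids, penalty=1.2):
    for token_id in set(generated_ids):
        score = logits[0, token_id]
        logits[0, token_id] = score / penalty if score > 0 else score * penalty
    return logits

logits = torch.randn(1, 50257)
logits = apply_repetition_penalty(logits, [42, 42, 7], penalty=1.2)
```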
 
@@ -346,7 +64,7 @@ class RepetitionPenaltyLogitsProcessor:

  class CosmicFishChatSession:
-     """Chat session for CosmicFish model from Hugging Face Hub."""

      def __init__(self, model, tokenizer, config):
          """Initialize chat session with model and configuration."""
@@ -405,17 +123,11 @@ class CosmicFishChatSession:
          """Print a welcome message to the user."""
          welcome_text = f"""
  {'=' * 80}
- Welcome to CosmicFish!

- This is a {self.model.get_num_params() / 1e6:.1f}M parameter model made by MistyozAI.
  CosmicFish features advanced architecture with RoPE, GQA, SwiGLU, and RMSNorm.

- ⚠️ DISCLAIMER: Since this {self.model.get_num_params() / 1e6:.1f}M parameter model is relatively
- small, it is more likely to give incorrect answers or hallucinate compared to
- larger models. Please verify important information from reliable sources.
-
- Model: {DEFAULT_MODEL_REPO}
-
  Type your prompts and CosmicFish will respond.

  Special commands:
@@ -499,19 +211,9 @@ Special commands:
          return False

      def _clean_token_text(self, text):
-
          text = text.replace('��', "'")
-
-         text = text.replace('�', "'")
-         text = text.replace('\ufffd', "'")
-         text = text.replace('\uFFFD', "'")
-
-         text = text.replace('’', "'")
-         text = text.replace('“', "'")
-         text = text.replace('�', "'")
-         text = text.replace('â€"', "'")
-         text = text.replace('â€"', "'")
-
          return text

      def generate_with_repetition_penalty(self, input_ids, max_new_tokens, temperature, top_k, penalty=1.2, live=False):
@@ -776,7 +478,6 @@ Token usage statistics:
  - Current repetition penalty: {self.repetition_penalty}
  - Current temperature: {self.config.temperature}
  - Model: CosmicFish ({self.model.get_num_params() / 1e6:.1f}M parameters)
- - Source: {DEFAULT_MODEL_REPO}
  """
          print(colored(stats, 'yellow'))
          return True
@@ -914,80 +615,76 @@ Token usage statistics:
          return True


- def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
-     """Download and load CosmicFish model from Hugging Face Hub"""
-     print(colored(f"Downloading CosmicFish from Hugging Face: {model_repo}", "cyan"))
-
-     try:
-         # Download the model files to local cache
-         print("Downloading model files...")
-         cache_dir = snapshot_download(repo_id=model_repo, cache_dir=None)
-         print(f"Model cached at: {cache_dir}")
-
-         # Load config
-         config_path = os.path.join(cache_dir, "config.json")
-         with open(config_path, "r") as f:
-             config_dict = json.load(f)
-
-         # Create CosmicConfig
-         config = CosmicConfig(
-             vocab_size=config_dict["vocab_size"],
-             block_size=config_dict["block_size"],
-             n_layer=config_dict["n_layer"],
-             n_head=config_dict["n_head"],
-             n_embd=config_dict["n_embd"],
-             bias=config_dict["bias"],
-             dropout=0.0,  # Set to 0 for inference
-             eps=config_dict.get("eps", 1e-6),
-             use_rotary=config_dict["use_rotary"],
-             use_swiglu=config_dict["use_swiglu"],
-             use_gqa=config_dict["use_gqa"],
-             n_query_groups=config_dict["n_query_groups"],
-             use_qk_norm=config_dict.get("use_qk_norm", False)
-         )
-
-         # Create model
-         print("Creating model...")
-         model = CosmicFish(config)
-
-         # Load weights
-         print("Loading weights...")
-         weights_path = os.path.join(cache_dir, "pytorch_model.bin")
-         state_dict = torch.load(weights_path, map_location=device)
-         model.load_state_dict(state_dict)
-         model.to(device)
-         model.eval()
-
-         print(f"Model loaded: {model.get_num_params() / 1e6:.1f}M parameters")
-         print(f"Device: {device}")
-         return model, config
-
-     except Exception as e:
-         print(colored(f"Error downloading/loading model: {str(e)}", "red"))
-         print(colored("Make sure you have internet connection and the model repo exists", "yellow"))
-         sys.exit(1)

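For anyone who still needs the Hub path this commit removes, the download step reduces to a single snapshot_download call, after which the local loader introduced below can be pointed at the returned directory:

```python
# Sketch: the Hub download step in isolation. snapshot_download returns the
# local cache path holding config.json and pytorch_model.bin.
from huggingface_hub import snapshot_download

cache_dir = snapshot_download(repo_id="MistyozAI/CosmicFish-120M")
print(cache_dir)  # local snapshot directory under the HF cache
```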
  def load_tokenizer():
-     print("Loading tokenizer...")
      tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-     print("Tokenizer loaded")
      return tokenizer


  def main():
-     parser = argparse.ArgumentParser(description="Chat with CosmicFish model from Hugging Face Hub")

      # Model parameters
-     parser.add_argument("--model_repo", type=str, default=DEFAULT_MODEL_REPO,
-                         help=f"Hugging Face model repository (default: {DEFAULT_MODEL_REPO})")
      parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                          help="Device to use (cuda or cpu)")

      # Generation parameters
-     parser.add_argument("--temperature", type=float, default=0.7,
                          help="Temperature for sampling (default: 0.7)")
-     parser.add_argument("--max_tokens", type=int, default=512,
                          help="Maximum number of tokens to generate per response")
      parser.add_argument("--min_tokens", type=int, default=10,
                          help="Minimum number of tokens to generate per response")
@@ -1020,12 +717,12 @@ def main():
      # Configure device
      device = args.device
      if device == "cuda" and not torch.cuda.is_available():
-         print(colored("CUDA is not available, falling back to CPU", "yellow"))
          device = "cpu"

      try:
-         # Download and load the model from HF Hub
-         model, model_config = download_cosmicfish_from_hub(args.model_repo, device)

          # Load tokenizer
          tokenizer = load_tokenizer()
@@ -1054,7 +751,7 @@ def main():
      chat = CosmicFishChatSession(model, tokenizer, config)

      # Main chat loop
-     print(colored("\nCosmicFish initialized! Type your message (or /help for commands).\n", 'cyan'))

      while True:
          try:
@@ -1122,8 +819,8 @@ def main():
              logger.error(f"Error in chat loop: {str(e)}", exc_info=True)

      except Exception as e:
-         print(colored(f"Error setting up chat: {str(e)}", 'red'))
-         logger.error(f"Error setting up chat: {str(e)}", exc_info=True)
          sys.exit(1)

  """
+ Chat interface for the released CosmicFish model from Hugging Face.
+ Compatible with the HF-format release while maintaining all original features.
  """

  import os

  import numpy as np
  from termcolor import colored
  import logging
+ import readline  # Enables arrow key history in terminal input
  import re
  import textwrap
  import random
  from collections import defaultdict
  import json

+ # Try to import from transformers, fallback to local
  try:
      from transformers import GPT2Tokenizer
      HF_AVAILABLE = True
  except ImportError:
      HF_AVAILABLE = False
+     print("Transformers not available. Install with: pip install transformers")
+
+ # Import the model classes - try both locations
+ try:
+     from modeling_cosmicfish import CosmicFish, CosmicConfig
+ except ImportError:
+     try:
+         from model import CosmicFish, CosmicConfig
+     except ImportError:
+         print("❌ CosmicFish model classes not found. Make sure modeling_cosmicfish.py or model.py is available.")
+         sys.exit(1)

  # Set up logging
  logging.basicConfig(

  )
  logger = logging.getLogger(__name__)

  # Default prompt template
  DEFAULT_PROMPT_TEMPLATE = "Below is a conversation between a helpful AI assistant and a human. The assistant is knowledgeable, friendly, and provides detailed and accurate responses.\n\n"

  class RepetitionPenaltyLogitsProcessor:
      """Apply repetition penalty to prevent repeating tokens."""


  class CosmicFishChatSession:
+     """Chat session for the released CosmicFish model."""

      def __init__(self, model, tokenizer, config):
          """Initialize chat session with model and configuration."""

          """Print a welcome message to the user."""
          welcome_text = f"""
  {'=' * 80}
+ Welcome to CosmicFish chat interface (Hugging Face Release)

+ This is a {self.model.get_num_params() / 1e6:.1f}M parameter model.
  CosmicFish features advanced architecture with RoPE, GQA, SwiGLU, and RMSNorm.

  Type your prompts and CosmicFish will respond.

  Special commands:

          return False

      def _clean_token_text(self, text):
+         """Clean token text by fixing encoding issues."""
+         # Fix the specific issue with �� -> '
          text = text.replace('��', "'")
          return text

      def generate_with_repetition_penalty(self, input_ids, max_new_tokens, temperature, top_k, penalty=1.2, live=False):

  - Current repetition penalty: {self.repetition_penalty}
  - Current temperature: {self.config.temperature}
  - Model: CosmicFish ({self.model.get_num_params() / 1e6:.1f}M parameters)
  """
          print(colored(stats, 'yellow'))
          return True
          return True


+ def load_cosmicfish_model(model_dir, device='cpu'):
+     """Load CosmicFish model from HF-format directory"""
+     print(f"Loading CosmicFish model from {model_dir}...")
+
+     # Load config
+     config_path = os.path.join(model_dir, "config.json")
+     if not os.path.exists(config_path):
+         raise FileNotFoundError(f"config.json not found in {model_dir}")
+
+     with open(config_path, "r") as f:
+         config_dict = json.load(f)
+
+     # Create CosmicConfig
+     config = CosmicConfig(
+         vocab_size=config_dict["vocab_size"],
+         block_size=config_dict["block_size"],
+         n_layer=config_dict["n_layer"],
+         n_head=config_dict["n_head"],
+         n_embd=config_dict["n_embd"],
+         bias=config_dict["bias"],
+         dropout=0.0,  # Set to 0 for inference
+         eps=config_dict.get("eps", 1e-6),
+         use_rotary=config_dict["use_rotary"],
+         use_swiglu=config_dict["use_swiglu"],
+         use_gqa=config_dict["use_gqa"],
+         n_query_groups=config_dict["n_query_groups"],
+         use_qk_norm=config_dict.get("use_qk_norm", False)
+     )
+
+     # Create model
+     model = CosmicFish(config)
+
+     # Load weights
+     weights_path = os.path.join(model_dir, "pytorch_model.bin")
+     if not os.path.exists(weights_path):
+         raise FileNotFoundError(f"pytorch_model.bin not found in {model_dir}")
+
+     state_dict = torch.load(weights_path, map_location=device)
+     model.load_state_dict(state_dict)
+     model.to(device)
+     model.eval()
+
+     print(f"✅ Model loaded: {model.get_num_params() / 1e6:.1f}M parameters")
+     return model, config
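Hypothetical usage of the loader above, mirroring what main() below does; it assumes a local ./cosmicfish-hf-release directory containing config.json and pytorch_model.bin:

```python
# Sketch: pointing the new loader at a local HF-format directory.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model, config = load_cosmicfish_model("./cosmicfish-hf-release", device=device)
print(f"block_size={config.block_size}, params={model.get_num_params() / 1e6:.1f}M")
```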
 
  def load_tokenizer():
+     """Load GPT-2 tokenizer"""
+     if not HF_AVAILABLE:
+         raise ImportError("transformers library required. Install with: pip install transformers")
+
+     print("Loading GPT-2 tokenizer...")
      tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+     print("Tokenizer loaded")
      return tokenizer
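And the tokenizer round trip the chat session builds on:

```python
# Sketch: GPT-2 tokenizer round trip, assuming load_tokenizer() above.
tokenizer = load_tokenizer()
ids = tokenizer.encode("Hello, CosmicFish!")
print(ids)                    # token ids, fed to model.generate as a tensor
print(tokenizer.decode(ids))  # "Hello, CosmicFish!"
```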
 

  def main():
+     parser = argparse.ArgumentParser(description="Chat with the released CosmicFish model")

      # Model parameters
+     parser.add_argument("--model_dir", type=str, default="./cosmicfish-hf-release",
+                         help="Path to the HF-format model directory")
      parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                          help="Device to use (cuda or cpu)")

      # Generation parameters
+     parser.add_argument("--temperature", type=float, default=0.6,
+                         help="Temperature for sampling (default: 0.6)")
+     parser.add_argument("--max_tokens", type=int, default=1024,
                          help="Maximum number of tokens to generate per response")
      parser.add_argument("--min_tokens", type=int, default=10,
                          help="Minimum number of tokens to generate per response")

      # Configure device
      device = args.device
      if device == "cuda" and not torch.cuda.is_available():
+         print("CUDA is not available, falling back to CPU")
          device = "cpu"

      try:
+         # Load the model
+         model, model_config = load_cosmicfish_model(args.model_dir, device)

          # Load tokenizer
          tokenizer = load_tokenizer()

      chat = CosmicFishChatSession(model, tokenizer, config)

      # Main chat loop
+     print(colored("\nCosmicFish initialized. Type your message (or /help for commands).\n", 'cyan'))

      while True:
          try:

              logger.error(f"Error in chat loop: {str(e)}", exc_info=True)

      except Exception as e:
+         print(colored(f"Error loading model: {str(e)}", 'red'))
+         logger.error(f"Error loading model: {str(e)}", exc_info=True)
          sys.exit(1)

modeling_cosmicfish.py CHANGED
@@ -1,9 +1,3 @@
- """
- CosmicFish Model - Inference Only Version
- Minimal implementation for loading and running inference with CosmicFish.
- Removes all training-specific code and optimizations.
- """
-
  import math
  import torch
  import torch.nn as nn