akkiisfrommars committed on
Commit
2a3f301
·
verified ·
1 Parent(s): 95f1d58

Update chat.py

Browse files
Files changed (1) hide show
  1. chat.py +48 -32
chat.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  import os
2
  import sys
3
  import time
@@ -24,6 +29,15 @@ except ImportError:
24
  print("Install with: pip install transformers huggingface-hub")
25
  sys.exit(1)
26
 
 
 
 
 
 
 
 
 
 
27
  # Set up logging
28
  logging.basicConfig(
29
  level=logging.INFO,
@@ -392,24 +406,16 @@ class CosmicFishChatSession:
392
  "\nUser:"
393
  ]
394
 
395
- # Print welcome message
396
  if config.display_welcome:
397
  self._print_welcome_message()
398
 
399
  def _print_welcome_message(self):
400
- """Print a welcome message to the user."""
401
  welcome_text = f"""
402
  {'=' * 80}
403
- Welcome to CosmicFish!
404
-
405
- This is a {self.model.get_num_params() / 1e6:.1f}M parameter model made by MistyozAI.
406
- CosmicFish features advanced architecture with RoPE, GQA, SwiGLU, and RMSNorm.
407
-
408
- ⚠️ DISCLAIMER: Since this {self.model.get_num_params() / 1e6:.1f}M parameter model is relatively
409
- small, it is more likely to give incorrect answers or hallucinate compared to
410
- larger models. Please verify important information from reliable sources.
411
 
412
- Model: {DEFAULT_MODEL_REPO}
 
413
 
414
  Type your prompts and CosmicFish will respond.
415
 
@@ -423,6 +429,14 @@ Special commands:
423
  - /temp [value]: Set temperature (between 0.1 and 2.0)
424
  - /penalty [value]: Set repetition penalty (1.0-2.0)
425
  - /debug: Toggle debug mode
 
 
 
 
 
 
 
 
426
  {'=' * 80}
427
  """
428
  print(colored(welcome_text, 'cyan'))
@@ -494,19 +508,15 @@ Special commands:
494
  return False
495
 
496
  def _clean_token_text(self, text):
497
-
498
  text = text.replace('��', "'")
499
-
500
  text = text.replace('�', "'")
501
  text = text.replace('\ufffd', "'")
502
  text = text.replace('\uFFFD', "'")
503
-
504
- text = text.replace('’', "'")
505
- text = text.replace('“', "'")
506
- text = text.replace('�', "'")
507
- text = text.replace('â€"', "'")
508
- text = text.replace('â€"', "'")
509
-
510
  return text
511
 
512
  def generate_with_repetition_penalty(self, input_ids, max_new_tokens, temperature, top_k, penalty=1.2, live=False):
@@ -772,6 +782,7 @@ Token usage statistics:
772
  - Current temperature: {self.config.temperature}
773
  - Model: CosmicFish ({self.model.get_num_params() / 1e6:.1f}M parameters)
774
  - Source: {DEFAULT_MODEL_REPO}
 
775
  """
776
  print(colored(stats, 'yellow'))
777
  return True
@@ -910,7 +921,7 @@ Token usage statistics:
910
 
911
 
912
  def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
913
- """Download and load CosmicFish model from Hugging Face Hub"""
914
  print(colored(f"Downloading CosmicFish from Hugging Face: {model_repo}", "cyan"))
915
 
916
  try:
@@ -945,10 +956,19 @@ def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
945
  print("Creating model...")
946
  model = CosmicFish(config)
947
 
948
- # Load weights
949
- print("Loading weights...")
950
- weights_path = os.path.join(cache_dir, "pytorch_model.bin")
951
- state_dict = torch.load(weights_path, map_location=device)
 
 
 
 
 
 
 
 
 
952
  model.load_state_dict(state_dict)
953
  model.to(device)
954
  model.eval()
@@ -964,14 +984,12 @@ def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
964
 
965
 
966
  def load_tokenizer():
967
- print("Loading tokenizer...")
968
  tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
969
- print("Tokenizer loaded")
970
  return tokenizer
971
 
972
 
973
  def main():
974
- parser = argparse.ArgumentParser(description="Chat with CosmicFish model from Hugging Face Hub")
975
 
976
  # Model parameters
977
  parser.add_argument("--model_repo", type=str, default=DEFAULT_MODEL_REPO,
@@ -982,7 +1000,7 @@ def main():
982
  # Generation parameters
983
  parser.add_argument("--temperature", type=float, default=0.7,
984
  help="Temperature for sampling (default: 0.7)")
985
- parser.add_argument("--max_tokens", type=int, default=512,
986
  help="Maximum number of tokens to generate per response")
987
  parser.add_argument("--min_tokens", type=int, default=10,
988
  help="Minimum number of tokens to generate per response")
@@ -1049,7 +1067,7 @@ def main():
1049
  chat = CosmicFishChatSession(model, tokenizer, config)
1050
 
1051
  # Main chat loop
1052
- print(colored("\nCosmicFish initialized! Type your message (or /help for commands).\n", 'cyan'))
1053
 
1054
  while True:
1055
  try:
@@ -1087,8 +1105,6 @@ def main():
1087
  if not live_buffer:
1088
  print(final_response, end="")
1089
  break
1090
-
1091
- # If we have a token to display
1092
  if token:
1093
  # Check if token contains <|endoftext|> and remove it if present
1094
  if "<|endoftext|>" in token:
 
1
+ """
2
+ Chat interface for CosmicFish model downloaded from Hugging Face Hub.
3
+ Uses safetensors format only for secure model loading.
4
+ """
5
+
6
  import os
7
  import sys
8
  import time
 
29
  print("Install with: pip install transformers huggingface-hub")
30
  sys.exit(1)
31
 
32
+ # Required for safetensors
33
+ try:
34
+ from safetensors.torch import load_file
35
+ SAFETENSORS_AVAILABLE = True
36
+ except ImportError:
37
+ SAFETENSORS_AVAILABLE = False
38
+ print("Safetensors not available. Install with: pip install safetensors")
39
+ sys.exit(1)
40
+
41
  # Set up logging
42
  logging.basicConfig(
43
  level=logging.INFO,
 
406
  "\nUser:"
407
  ]
408
 
 
409
  if config.display_welcome:
410
  self._print_welcome_message()
411
 
412
  def _print_welcome_message(self):
 
413
  welcome_text = f"""
414
  {'=' * 80}
415
+ Welcome to CosmicFish chat interface
 
 
 
 
 
 
 
416
 
417
+ This is a {self.model.get_num_params() / 1e6:.1f}M parameter model.
418
+ CosmicFish is an efficient LLM with an advanced architecture.
419
 
420
  Type your prompts and CosmicFish will respond.
421
 
 
429
  - /temp [value]: Set temperature (between 0.1 and 2.0)
430
  - /penalty [value]: Set repetition penalty (1.0-2.0)
431
  - /debug: Toggle debug mode
432
+
433
+
434
 + Note: CosmicFish may generate incorrect or fictional responses. Verify facts if needed.
435
+
436
+ Visit https://cosmicfish.ai for more info
437
+
438
+
439
+ Developed by Mistyoz AI (https://www.mistyoz.com)
440
  {'=' * 80}
441
  """
442
  print(colored(welcome_text, 'cyan'))
 
508
  return False
509
 
510
  def _clean_token_text(self, text):
 
511
  text = text.replace('��', "'")
 
512
  text = text.replace('�', "'")
513
  text = text.replace('\ufffd', "'")
514
  text = text.replace('\uFFFD', "'")
515
+ text = text.replace('’', "'")
516
+ text = text.replace('â€Å"', "'")
517
+ text = text.replace('�', "'")
518
+ text = text.replace('â€"', "'")
519
+ text = text.replace('â€"', "'")
 
 
520
  return text
521
 
522
  def generate_with_repetition_penalty(self, input_ids, max_new_tokens, temperature, top_k, penalty=1.2, live=False):
 
782
  - Current temperature: {self.config.temperature}
783
  - Model: CosmicFish ({self.model.get_num_params() / 1e6:.1f}M parameters)
784
  - Source: {DEFAULT_MODEL_REPO}
785
+ - Format: Safetensors (secure)
786
  """
787
  print(colored(stats, 'yellow'))
788
  return True
 
921
 
922
 
923
  def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
924
+ """Download and load CosmicFish model from Hugging Face Hub (safetensors only)"""
925
  print(colored(f"Downloading CosmicFish from Hugging Face: {model_repo}", "cyan"))
926
 
927
  try:
 
956
  print("Creating model...")
957
  model = CosmicFish(config)
958
 
959
+ # Load weights from safetensors ONLY
960
+ print("Loading weights from safetensors...")
961
+ safetensors_path = os.path.join(cache_dir, "model.safetensors")
962
+
963
+ if not os.path.exists(safetensors_path):
964
+ raise FileNotFoundError(f"model.safetensors not found in {cache_dir}. This model requires safetensors format.")
965
+
966
+ state_dict = load_file(safetensors_path)
967
+
968
+ # Handle weight sharing: lm_head.weight shares with transformer.wte.weight
969
+ if 'lm_head.weight' not in state_dict and 'transformer.wte.weight' in state_dict:
970
+ state_dict['lm_head.weight'] = state_dict['transformer.wte.weight']
971
+
972
  model.load_state_dict(state_dict)
973
  model.to(device)
974
  model.eval()
 
984
 
985
 
986
def load_tokenizer():
    """Return the stock GPT-2 tokenizer that CosmicFish shares its vocabulary with."""
    return GPT2Tokenizer.from_pretrained("gpt2")
989
 
990
 
991
  def main():
992
+ parser = argparse.ArgumentParser(description="Chat with CosmicFish")
993
 
994
  # Model parameters
995
  parser.add_argument("--model_repo", type=str, default=DEFAULT_MODEL_REPO,
 
1000
  # Generation parameters
1001
  parser.add_argument("--temperature", type=float, default=0.7,
1002
  help="Temperature for sampling (default: 0.7)")
1003
+ parser.add_argument("--max_tokens", type=int, default=1024,
1004
  help="Maximum number of tokens to generate per response")
1005
  parser.add_argument("--min_tokens", type=int, default=10,
1006
  help="Minimum number of tokens to generate per response")
 
1067
  chat = CosmicFishChatSession(model, tokenizer, config)
1068
 
1069
  # Main chat loop
1070
+ print(colored("\nCosmicFish initialized from Hugging Face! Type your message (or /help for commands).\n", 'cyan'))
1071
 
1072
  while True:
1073
  try:
 
1105
  if not live_buffer:
1106
  print(final_response, end="")
1107
  break
 
 
1108
  if token:
1109
  # Check if token contains <|endoftext|> and remove it if present
1110
  if "<|endoftext|>" in token: