akkiisfrommars committed on
Commit
33b3450
·
verified ·
1 Parent(s): 73a9e06

updated chat.py for better inference

Browse files
Files changed (1) hide show
  1. chat.py +35 -22
chat.py CHANGED
@@ -11,7 +11,7 @@ import torch
11
  import numpy as np
12
  from termcolor import colored
13
  import logging
14
- import readline # Enables arrow key history in terminal input
15
  import re
16
  import textwrap
17
  import random
@@ -25,7 +25,7 @@ try:
25
  HF_AVAILABLE = True
26
  except ImportError:
27
  HF_AVAILABLE = False
28
- print("Required libraries not available.")
29
  print("Install with: pip install transformers huggingface-hub")
30
  sys.exit(1)
31
 
@@ -38,7 +38,7 @@ logging.basicConfig(
38
  logger = logging.getLogger(__name__)
39
 
40
  # Default model repository
41
- DEFAULT_MODEL_REPO = "Mistyoz-AI/CosmicFish-120M"
42
 
43
  # Default prompt template
44
  DEFAULT_PROMPT_TEMPLATE = "Below is a conversation between a helpful AI assistant and a human. The assistant is knowledgeable, friendly, and provides detailed and accurate responses.\n\n"
@@ -54,7 +54,7 @@ class CosmicConfig:
54
  n_head=16,
55
  n_embd=704,
56
  bias=True,
57
- dropout=0.0, # Always 0 for inference
58
  n_query_groups=4,
59
  eps=1e-6,
60
  use_rotary=True,
@@ -405,11 +405,15 @@ class CosmicFishChatSession:
405
  """Print a welcome message to the user."""
406
  welcome_text = f"""
407
  {'=' * 80}
408
- Welcome to CosmicFish chat interface (Hugging Face Hub)
409
 
410
- This is a {self.model.get_num_params() / 1e6:.1f}M parameter model loaded from HF Hub.
411
  CosmicFish features advanced architecture with RoPE, GQA, SwiGLU, and RMSNorm.
412
 
 
 
 
 
413
  Model: {DEFAULT_MODEL_REPO}
414
 
415
  Type your prompts and CosmicFish will respond.
@@ -495,9 +499,19 @@ Special commands:
495
  return False
496
 
497
  def _clean_token_text(self, text):
498
- """Clean token text by fixing encoding issues."""
499
- # Fix the specific issue with �� -> '
500
  text = text.replace('��', "'")
 
 
 
 
 
 
 
 
 
 
 
501
  return text
502
 
503
  def generate_with_repetition_penalty(self, input_ids, max_new_tokens, temperature, top_k, penalty=1.2, live=False):
@@ -902,13 +916,13 @@ Token usage statistics:
902
 
903
  def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
904
  """Download and load CosmicFish model from Hugging Face Hub"""
905
- print(colored(f"🤗 Downloading CosmicFish from Hugging Face Hub: {model_repo}", "cyan"))
906
 
907
  try:
908
  # Download the model files to local cache
909
- print("📥 Downloading model files...")
910
  cache_dir = snapshot_download(repo_id=model_repo, cache_dir=None)
911
- print(f"Model cached at: {cache_dir}")
912
 
913
  # Load config
914
  config_path = os.path.join(cache_dir, "config.json")
@@ -933,32 +947,31 @@ def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
933
  )
934
 
935
  # Create model
936
- print("🧠 Creating model...")
937
  model = CosmicFish(config)
938
 
939
  # Load weights
940
- print("⚖️ Loading weights...")
941
  weights_path = os.path.join(cache_dir, "pytorch_model.bin")
942
  state_dict = torch.load(weights_path, map_location=device)
943
  model.load_state_dict(state_dict)
944
  model.to(device)
945
  model.eval()
946
 
947
- print(f"Model loaded: {model.get_num_params() / 1e6:.1f}M parameters")
948
- print(f"🎯 Device: {device}")
949
  return model, config
950
 
951
  except Exception as e:
952
- print(colored(f"Error downloading/loading model: {str(e)}", "red"))
953
- print(colored("💡 Make sure you have internet connection and the model repo exists", "yellow"))
954
  sys.exit(1)
955
 
956
 
957
  def load_tokenizer():
958
- """Load GPT-2 tokenizer"""
959
- print("🔤 Loading GPT-2 tokenizer...")
960
  tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
961
- print("Tokenizer loaded")
962
  return tokenizer
963
 
964
 
@@ -1007,7 +1020,7 @@ def main():
1007
  # Configure device
1008
  device = args.device
1009
  if device == "cuda" and not torch.cuda.is_available():
1010
- print(colored("⚠️ CUDA is not available, falling back to CPU", "yellow"))
1011
  device = "cpu"
1012
 
1013
  try:
@@ -1041,7 +1054,7 @@ def main():
1041
  chat = CosmicFishChatSession(model, tokenizer, config)
1042
 
1043
  # Main chat loop
1044
- print(colored("\n🚀 CosmicFish initialized from Hugging Face Hub. Type your message (or /help for commands).\n", 'cyan'))
1045
 
1046
  while True:
1047
  try:
 
11
  import numpy as np
12
  from termcolor import colored
13
  import logging
14
+ import readline
15
  import re
16
  import textwrap
17
  import random
 
25
  HF_AVAILABLE = True
26
  except ImportError:
27
  HF_AVAILABLE = False
28
+ print("Required libraries not available.")
29
  print("Install with: pip install transformers huggingface-hub")
30
  sys.exit(1)
31
 
 
38
  logger = logging.getLogger(__name__)
39
 
40
  # Default model repository
41
+ DEFAULT_MODEL_REPO = "MistyozAI/CosmicFish-120M"
42
 
43
  # Default prompt template
44
  DEFAULT_PROMPT_TEMPLATE = "Below is a conversation between a helpful AI assistant and a human. The assistant is knowledgeable, friendly, and provides detailed and accurate responses.\n\n"
 
54
  n_head=16,
55
  n_embd=704,
56
  bias=True,
57
+ dropout=0.0,
58
  n_query_groups=4,
59
  eps=1e-6,
60
  use_rotary=True,
 
405
  """Print a welcome message to the user."""
406
  welcome_text = f"""
407
  {'=' * 80}
408
+ Welcome to CosmicFish!
409
 
410
+ This is a {self.model.get_num_params() / 1e6:.1f}M parameter model made by MistyozAI.
411
  CosmicFish features advanced architecture with RoPE, GQA, SwiGLU, and RMSNorm.
412
 
413
+ ⚠️ DISCLAIMER: Since this {self.model.get_num_params() / 1e6:.1f}M parameter model is relatively
414
+ small, it is more likely to give incorrect answers or hallucinate compared to
415
+ larger models. Please verify important information from reliable sources.
416
+
417
  Model: {DEFAULT_MODEL_REPO}
418
 
419
  Type your prompts and CosmicFish will respond.
 
499
  return False
500
 
501
  def _clean_token_text(self, text):
502
+
 
503
  text = text.replace('��', "'")
504
+
505
+ text = text.replace('�', "'")
506
+ text = text.replace('\ufffd', "'")
507
+ text = text.replace('\uFFFD', "'")
508
+
509
+ text = text.replace('’', "'")
510
+ text = text.replace('“', "'")
511
+ text = text.replace('�', "'")
512
+ text = text.replace('â€"', "'")
513
+ text = text.replace('â€"', "'")
514
+
515
  return text
516
 
517
  def generate_with_repetition_penalty(self, input_ids, max_new_tokens, temperature, top_k, penalty=1.2, live=False):
 
916
 
917
  def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
918
  """Download and load CosmicFish model from Hugging Face Hub"""
919
+ print(colored(f"Downloading CosmicFish from Hugging Face: {model_repo}", "cyan"))
920
 
921
  try:
922
  # Download the model files to local cache
923
+ print("Downloading model files...")
924
  cache_dir = snapshot_download(repo_id=model_repo, cache_dir=None)
925
+ print(f"Model cached at: {cache_dir}")
926
 
927
  # Load config
928
  config_path = os.path.join(cache_dir, "config.json")
 
947
  )
948
 
949
  # Create model
950
+ print("Creating model...")
951
  model = CosmicFish(config)
952
 
953
  # Load weights
954
+ print("Loading weights...")
955
  weights_path = os.path.join(cache_dir, "pytorch_model.bin")
956
  state_dict = torch.load(weights_path, map_location=device)
957
  model.load_state_dict(state_dict)
958
  model.to(device)
959
  model.eval()
960
 
961
+ print(f"Model loaded: {model.get_num_params() / 1e6:.1f}M parameters")
962
+ print(f"Device: {device}")
963
  return model, config
964
 
965
  except Exception as e:
966
+ print(colored(f"Error downloading/loading model: {str(e)}", "red"))
967
+ print(colored("Make sure you have internet connection and the model repo exists", "yellow"))
968
  sys.exit(1)
969
 
970
 
971
  def load_tokenizer():
972
+ print("Loading tokenizer...")
 
973
  tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
974
+ print("Tokenizer loaded")
975
  return tokenizer
976
 
977
 
 
1020
  # Configure device
1021
  device = args.device
1022
  if device == "cuda" and not torch.cuda.is_available():
1023
+ print(colored("CUDA is not available, falling back to CPU", "yellow"))
1024
  device = "cpu"
1025
 
1026
  try:
 
1054
  chat = CosmicFishChatSession(model, tokenizer, config)
1055
 
1056
  # Main chat loop
1057
+ print(colored("\nCosmicFish initialized! Type your message (or /help for commands).\n", 'cyan'))
1058
 
1059
  while True:
1060
  try: