updated chat.py for better inference
Browse files
chat.py
CHANGED
|
@@ -11,7 +11,7 @@ import torch
|
|
| 11 |
import numpy as np
|
| 12 |
from termcolor import colored
|
| 13 |
import logging
|
| 14 |
-
import readline
|
| 15 |
import re
|
| 16 |
import textwrap
|
| 17 |
import random
|
|
@@ -25,7 +25,7 @@ try:
|
|
| 25 |
HF_AVAILABLE = True
|
| 26 |
except ImportError:
|
| 27 |
HF_AVAILABLE = False
|
| 28 |
-
print("
|
| 29 |
print("Install with: pip install transformers huggingface-hub")
|
| 30 |
sys.exit(1)
|
| 31 |
|
|
@@ -38,7 +38,7 @@ logging.basicConfig(
|
|
| 38 |
logger = logging.getLogger(__name__)
|
| 39 |
|
| 40 |
# Default model repository
|
| 41 |
-
DEFAULT_MODEL_REPO = "
|
| 42 |
|
| 43 |
# Default prompt template
|
| 44 |
DEFAULT_PROMPT_TEMPLATE = "Below is a conversation between a helpful AI assistant and a human. The assistant is knowledgeable, friendly, and provides detailed and accurate responses.\n\n"
|
|
@@ -54,7 +54,7 @@ class CosmicConfig:
|
|
| 54 |
n_head=16,
|
| 55 |
n_embd=704,
|
| 56 |
bias=True,
|
| 57 |
-
dropout=0.0,
|
| 58 |
n_query_groups=4,
|
| 59 |
eps=1e-6,
|
| 60 |
use_rotary=True,
|
|
@@ -405,11 +405,15 @@ class CosmicFishChatSession:
|
|
| 405 |
"""Print a welcome message to the user."""
|
| 406 |
welcome_text = f"""
|
| 407 |
{'=' * 80}
|
| 408 |
-
Welcome to CosmicFish
|
| 409 |
|
| 410 |
-
This is a {self.model.get_num_params() / 1e6:.1f}M parameter model
|
| 411 |
CosmicFish features advanced architecture with RoPE, GQA, SwiGLU, and RMSNorm.
|
| 412 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
Model: {DEFAULT_MODEL_REPO}
|
| 414 |
|
| 415 |
Type your prompts and CosmicFish will respond.
|
|
@@ -495,9 +499,19 @@ Special commands:
|
|
| 495 |
return False
|
| 496 |
|
| 497 |
def _clean_token_text(self, text):
|
| 498 |
-
|
| 499 |
-
# Fix the specific issue with �� -> '
|
| 500 |
text = text.replace('��', "'")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 501 |
return text
|
| 502 |
|
| 503 |
def generate_with_repetition_penalty(self, input_ids, max_new_tokens, temperature, top_k, penalty=1.2, live=False):
|
|
@@ -902,13 +916,13 @@ Token usage statistics:
|
|
| 902 |
|
| 903 |
def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
|
| 904 |
"""Download and load CosmicFish model from Hugging Face Hub"""
|
| 905 |
-
print(colored(f"
|
| 906 |
|
| 907 |
try:
|
| 908 |
# Download the model files to local cache
|
| 909 |
-
print("
|
| 910 |
cache_dir = snapshot_download(repo_id=model_repo, cache_dir=None)
|
| 911 |
-
print(f"
|
| 912 |
|
| 913 |
# Load config
|
| 914 |
config_path = os.path.join(cache_dir, "config.json")
|
|
@@ -933,32 +947,31 @@ def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
|
|
| 933 |
)
|
| 934 |
|
| 935 |
# Create model
|
| 936 |
-
print("
|
| 937 |
model = CosmicFish(config)
|
| 938 |
|
| 939 |
# Load weights
|
| 940 |
-
print("
|
| 941 |
weights_path = os.path.join(cache_dir, "pytorch_model.bin")
|
| 942 |
state_dict = torch.load(weights_path, map_location=device)
|
| 943 |
model.load_state_dict(state_dict)
|
| 944 |
model.to(device)
|
| 945 |
model.eval()
|
| 946 |
|
| 947 |
-
print(f"
|
| 948 |
-
print(f"
|
| 949 |
return model, config
|
| 950 |
|
| 951 |
except Exception as e:
|
| 952 |
-
print(colored(f"
|
| 953 |
-
print(colored("
|
| 954 |
sys.exit(1)
|
| 955 |
|
| 956 |
|
| 957 |
def load_tokenizer():
|
| 958 |
-
"
|
| 959 |
-
print("🔤 Loading GPT-2 tokenizer...")
|
| 960 |
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
| 961 |
-
print("
|
| 962 |
return tokenizer
|
| 963 |
|
| 964 |
|
|
@@ -1007,7 +1020,7 @@ def main():
|
|
| 1007 |
# Configure device
|
| 1008 |
device = args.device
|
| 1009 |
if device == "cuda" and not torch.cuda.is_available():
|
| 1010 |
-
print(colored("
|
| 1011 |
device = "cpu"
|
| 1012 |
|
| 1013 |
try:
|
|
@@ -1041,7 +1054,7 @@ def main():
|
|
| 1041 |
chat = CosmicFishChatSession(model, tokenizer, config)
|
| 1042 |
|
| 1043 |
# Main chat loop
|
| 1044 |
-
print(colored("\
|
| 1045 |
|
| 1046 |
while True:
|
| 1047 |
try:
|
|
|
|
| 11 |
import numpy as np
|
| 12 |
from termcolor import colored
|
| 13 |
import logging
|
| 14 |
+
import readline
|
| 15 |
import re
|
| 16 |
import textwrap
|
| 17 |
import random
|
|
|
|
| 25 |
HF_AVAILABLE = True
|
| 26 |
except ImportError:
|
| 27 |
HF_AVAILABLE = False
|
| 28 |
+
print("Required libraries not available.")
|
| 29 |
print("Install with: pip install transformers huggingface-hub")
|
| 30 |
sys.exit(1)
|
| 31 |
|
|
|
|
| 38 |
logger = logging.getLogger(__name__)
|
| 39 |
|
| 40 |
# Default model repository
|
| 41 |
+
DEFAULT_MODEL_REPO = "MistyozAI/CosmicFish-120M"
|
| 42 |
|
| 43 |
# Default prompt template
|
| 44 |
DEFAULT_PROMPT_TEMPLATE = "Below is a conversation between a helpful AI assistant and a human. The assistant is knowledgeable, friendly, and provides detailed and accurate responses.\n\n"
|
|
|
|
| 54 |
n_head=16,
|
| 55 |
n_embd=704,
|
| 56 |
bias=True,
|
| 57 |
+
dropout=0.0,
|
| 58 |
n_query_groups=4,
|
| 59 |
eps=1e-6,
|
| 60 |
use_rotary=True,
|
|
|
|
| 405 |
"""Print a welcome message to the user."""
|
| 406 |
welcome_text = f"""
|
| 407 |
{'=' * 80}
|
| 408 |
+
Welcome to CosmicFish!
|
| 409 |
|
| 410 |
+
This is a {self.model.get_num_params() / 1e6:.1f}M parameter model made by MistyozAI.
|
| 411 |
CosmicFish features advanced architecture with RoPE, GQA, SwiGLU, and RMSNorm.
|
| 412 |
|
| 413 |
+
⚠️ DISCLAIMER: Since this {self.model.get_num_params() / 1e6:.1f}M parameter model is relatively
|
| 414 |
+
small, it is more likely to give incorrect answers or hallucinate compared to
|
| 415 |
+
larger models. Please verify important information from reliable sources.
|
| 416 |
+
|
| 417 |
Model: {DEFAULT_MODEL_REPO}
|
| 418 |
|
| 419 |
Type your prompts and CosmicFish will respond.
|
|
|
|
| 499 |
return False
|
| 500 |
|
| 501 |
def _clean_token_text(self, text):
|
| 502 |
+
|
|
|
|
| 503 |
text = text.replace('��', "'")
|
| 504 |
+
|
| 505 |
+
text = text.replace('�', "'")
|
| 506 |
+
text = text.replace('\ufffd', "'")
|
| 507 |
+
text = text.replace('\uFFFD', "'")
|
| 508 |
+
|
| 509 |
+
text = text.replace('’', "'")
|
| 510 |
+
text = text.replace('“', "'")
|
| 511 |
+
text = text.replace('�', "'")
|
| 512 |
+
text = text.replace('â€"', "'")
|
| 513 |
+
text = text.replace('â€"', "'")
|
| 514 |
+
|
| 515 |
return text
|
| 516 |
|
| 517 |
def generate_with_repetition_penalty(self, input_ids, max_new_tokens, temperature, top_k, penalty=1.2, live=False):
|
|
|
|
| 916 |
|
| 917 |
def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
|
| 918 |
"""Download and load CosmicFish model from Hugging Face Hub"""
|
| 919 |
+
print(colored(f"Downloading CosmicFish from Hugging Face: {model_repo}", "cyan"))
|
| 920 |
|
| 921 |
try:
|
| 922 |
# Download the model files to local cache
|
| 923 |
+
print("Downloading model files...")
|
| 924 |
cache_dir = snapshot_download(repo_id=model_repo, cache_dir=None)
|
| 925 |
+
print(f"Model cached at: {cache_dir}")
|
| 926 |
|
| 927 |
# Load config
|
| 928 |
config_path = os.path.join(cache_dir, "config.json")
|
|
|
|
| 947 |
)
|
| 948 |
|
| 949 |
# Create model
|
| 950 |
+
print("Creating model...")
|
| 951 |
model = CosmicFish(config)
|
| 952 |
|
| 953 |
# Load weights
|
| 954 |
+
print("Loading weights...")
|
| 955 |
weights_path = os.path.join(cache_dir, "pytorch_model.bin")
|
| 956 |
state_dict = torch.load(weights_path, map_location=device)
|
| 957 |
model.load_state_dict(state_dict)
|
| 958 |
model.to(device)
|
| 959 |
model.eval()
|
| 960 |
|
| 961 |
+
print(f"Model loaded: {model.get_num_params() / 1e6:.1f}M parameters")
|
| 962 |
+
print(f"Device: {device}")
|
| 963 |
return model, config
|
| 964 |
|
| 965 |
except Exception as e:
|
| 966 |
+
print(colored(f"Error downloading/loading model: {str(e)}", "red"))
|
| 967 |
+
print(colored("Make sure you have internet connection and the model repo exists", "yellow"))
|
| 968 |
sys.exit(1)
|
| 969 |
|
| 970 |
|
| 971 |
def load_tokenizer():
|
| 972 |
+
print("Loading tokenizer...")
|
|
|
|
| 973 |
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
| 974 |
+
print("Tokenizer loaded")
|
| 975 |
return tokenizer
|
| 976 |
|
| 977 |
|
|
|
|
| 1020 |
# Configure device
|
| 1021 |
device = args.device
|
| 1022 |
if device == "cuda" and not torch.cuda.is_available():
|
| 1023 |
+
print(colored("CUDA is not available, falling back to CPU", "yellow"))
|
| 1024 |
device = "cpu"
|
| 1025 |
|
| 1026 |
try:
|
|
|
|
| 1054 |
chat = CosmicFishChatSession(model, tokenizer, config)
|
| 1055 |
|
| 1056 |
# Main chat loop
|
| 1057 |
+
print(colored("\nCosmicFish initialized! Type your message (or /help for commands).\n", 'cyan'))
|
| 1058 |
|
| 1059 |
while True:
|
| 1060 |
try:
|