Update chat.py
Browse files
chat.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
import time
|
|
@@ -24,6 +29,15 @@ except ImportError:
|
|
| 24 |
print("Install with: pip install transformers huggingface-hub")
|
| 25 |
sys.exit(1)
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
# Set up logging
|
| 28 |
logging.basicConfig(
|
| 29 |
level=logging.INFO,
|
|
@@ -392,24 +406,16 @@ class CosmicFishChatSession:
|
|
| 392 |
"\nUser:"
|
| 393 |
]
|
| 394 |
|
| 395 |
-
# Print welcome message
|
| 396 |
if config.display_welcome:
|
| 397 |
self._print_welcome_message()
|
| 398 |
|
| 399 |
def _print_welcome_message(self):
|
| 400 |
-
"""Print a welcome message to the user."""
|
| 401 |
welcome_text = f"""
|
| 402 |
{'=' * 80}
|
| 403 |
-
Welcome to CosmicFish
|
| 404 |
-
|
| 405 |
-
This is a {self.model.get_num_params() / 1e6:.1f}M parameter model made by MistyozAI.
|
| 406 |
-
CosmicFish features advanced architecture with RoPE, GQA, SwiGLU, and RMSNorm.
|
| 407 |
-
|
| 408 |
-
⚠️ DISCLAIMER: Since this {self.model.get_num_params() / 1e6:.1f}M parameter model is relatively
|
| 409 |
-
small, it is more likely to give incorrect answers or hallucinate compared to
|
| 410 |
-
larger models. Please verify important information from reliable sources.
|
| 411 |
|
| 412 |
-
|
|
|
|
| 413 |
|
| 414 |
Type your prompts and CosmicFish will respond.
|
| 415 |
|
|
@@ -423,6 +429,14 @@ Special commands:
|
|
| 423 |
- /temp [value]: Set temperature (between 0.1 and 2.0)
|
| 424 |
- /penalty [value]: Set repetition penalty (1.0-2.0)
|
| 425 |
- /debug: Toggle debug mode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
{'=' * 80}
|
| 427 |
"""
|
| 428 |
print(colored(welcome_text, 'cyan'))
|
|
@@ -494,19 +508,15 @@ Special commands:
|
|
| 494 |
return False
|
| 495 |
|
| 496 |
def _clean_token_text(self, text):
|
| 497 |
-
|
| 498 |
text = text.replace('��', "'")
|
| 499 |
-
|
| 500 |
text = text.replace('�', "'")
|
| 501 |
text = text.replace('\ufffd', "'")
|
| 502 |
text = text.replace('\uFFFD', "'")
|
| 503 |
-
|
| 504 |
-
text = text.replace('
|
| 505 |
-
text = text.replace('
|
| 506 |
-
text = text.replace('
|
| 507 |
-
text = text.replace('
|
| 508 |
-
text = text.replace('â€"', "'")
|
| 509 |
-
|
| 510 |
return text
|
| 511 |
|
| 512 |
def generate_with_repetition_penalty(self, input_ids, max_new_tokens, temperature, top_k, penalty=1.2, live=False):
|
|
@@ -772,6 +782,7 @@ Token usage statistics:
|
|
| 772 |
- Current temperature: {self.config.temperature}
|
| 773 |
- Model: CosmicFish ({self.model.get_num_params() / 1e6:.1f}M parameters)
|
| 774 |
- Source: {DEFAULT_MODEL_REPO}
|
|
|
|
| 775 |
"""
|
| 776 |
print(colored(stats, 'yellow'))
|
| 777 |
return True
|
|
@@ -910,7 +921,7 @@ Token usage statistics:
|
|
| 910 |
|
| 911 |
|
| 912 |
def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
|
| 913 |
-
"""Download and load CosmicFish model from Hugging Face Hub"""
|
| 914 |
print(colored(f"Downloading CosmicFish from Hugging Face: {model_repo}", "cyan"))
|
| 915 |
|
| 916 |
try:
|
|
@@ -945,10 +956,19 @@ def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
|
|
| 945 |
print("Creating model...")
|
| 946 |
model = CosmicFish(config)
|
| 947 |
|
| 948 |
-
# Load weights
|
| 949 |
-
print("Loading weights...")
|
| 950 |
-
|
| 951 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 952 |
model.load_state_dict(state_dict)
|
| 953 |
model.to(device)
|
| 954 |
model.eval()
|
|
@@ -964,14 +984,12 @@ def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
|
|
| 964 |
|
| 965 |
|
| 966 |
def load_tokenizer():
|
| 967 |
-
print("Loading tokenizer...")
|
| 968 |
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
| 969 |
-
print("Tokenizer loaded")
|
| 970 |
return tokenizer
|
| 971 |
|
| 972 |
|
| 973 |
def main():
|
| 974 |
-
parser = argparse.ArgumentParser(description="Chat with CosmicFish
|
| 975 |
|
| 976 |
# Model parameters
|
| 977 |
parser.add_argument("--model_repo", type=str, default=DEFAULT_MODEL_REPO,
|
|
@@ -982,7 +1000,7 @@ def main():
|
|
| 982 |
# Generation parameters
|
| 983 |
parser.add_argument("--temperature", type=float, default=0.7,
|
| 984 |
help="Temperature for sampling (default: 0.7)")
|
| 985 |
-
parser.add_argument("--max_tokens", type=int, default=
|
| 986 |
help="Maximum number of tokens to generate per response")
|
| 987 |
parser.add_argument("--min_tokens", type=int, default=10,
|
| 988 |
help="Minimum number of tokens to generate per response")
|
|
@@ -1049,7 +1067,7 @@ def main():
|
|
| 1049 |
chat = CosmicFishChatSession(model, tokenizer, config)
|
| 1050 |
|
| 1051 |
# Main chat loop
|
| 1052 |
-
print(colored("\nCosmicFish initialized! Type your message (or /help for commands).\n", 'cyan'))
|
| 1053 |
|
| 1054 |
while True:
|
| 1055 |
try:
|
|
@@ -1087,8 +1105,6 @@ def main():
|
|
| 1087 |
if not live_buffer:
|
| 1088 |
print(final_response, end="")
|
| 1089 |
break
|
| 1090 |
-
|
| 1091 |
-
# If we have a token to display
|
| 1092 |
if token:
|
| 1093 |
# Check if token contains <|endoftext|> and remove it if present
|
| 1094 |
if "<|endoftext|>" in token:
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chat interface for CosmicFish model downloaded from Hugging Face Hub.
|
| 3 |
+
Uses safetensors format only for secure model loading.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
import os
|
| 7 |
import sys
|
| 8 |
import time
|
|
|
|
| 29 |
print("Install with: pip install transformers huggingface-hub")
|
| 30 |
sys.exit(1)
|
| 31 |
|
| 32 |
+
# Required for safetensors
|
| 33 |
+
try:
|
| 34 |
+
from safetensors.torch import load_file
|
| 35 |
+
SAFETENSORS_AVAILABLE = True
|
| 36 |
+
except ImportError:
|
| 37 |
+
SAFETENSORS_AVAILABLE = False
|
| 38 |
+
print("Safetensors not available. Install with: pip install safetensors")
|
| 39 |
+
sys.exit(1)
|
| 40 |
+
|
| 41 |
# Set up logging
|
| 42 |
logging.basicConfig(
|
| 43 |
level=logging.INFO,
|
|
|
|
| 406 |
"\nUser:"
|
| 407 |
]
|
| 408 |
|
|
|
|
| 409 |
if config.display_welcome:
|
| 410 |
self._print_welcome_message()
|
| 411 |
|
| 412 |
def _print_welcome_message(self):
|
|
|
|
| 413 |
welcome_text = f"""
|
| 414 |
{'=' * 80}
|
| 415 |
+
Welcome to CosmicFish chat interface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
|
| 417 |
+
This is a {self.model.get_num_params() / 1e6:.1f}M parameter model.
|
| 418 |
+
CosmicFish is an efficient LLM with an advanced architecture.
|
| 419 |
|
| 420 |
Type your prompts and CosmicFish will respond.
|
| 421 |
|
|
|
|
| 429 |
- /temp [value]: Set temperature (between 0.1 and 2.0)
|
| 430 |
- /penalty [value]: Set repetition penalty (1.0-2.0)
|
| 431 |
- /debug: Toggle debug mode
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
Note: CosmicFish may generate incorrect or fictional responses. Verify facts if needed.
|
| 435 |
+
|
| 436 |
+
Visit https://cosmicfish.ai for more info
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
Developed by Mistyoz AI (https://www.mistyoz.com)
|
| 440 |
{'=' * 80}
|
| 441 |
"""
|
| 442 |
print(colored(welcome_text, 'cyan'))
|
|
|
|
| 508 |
return False
|
| 509 |
|
| 510 |
def _clean_token_text(self, text):
|
|
|
|
| 511 |
text = text.replace('��', "'")
|
|
|
|
| 512 |
text = text.replace('�', "'")
|
| 513 |
text = text.replace('\ufffd', "'")
|
| 514 |
text = text.replace('\uFFFD', "'")
|
| 515 |
+
text = text.replace('’', "'")
|
| 516 |
+
text = text.replace('â€Å"', "'")
|
| 517 |
+
text = text.replace('�', "'")
|
| 518 |
+
text = text.replace('â€"', "'")
|
| 519 |
+
text = text.replace('â€"', "'")
|
|
|
|
|
|
|
| 520 |
return text
|
| 521 |
|
| 522 |
def generate_with_repetition_penalty(self, input_ids, max_new_tokens, temperature, top_k, penalty=1.2, live=False):
|
|
|
|
| 782 |
- Current temperature: {self.config.temperature}
|
| 783 |
- Model: CosmicFish ({self.model.get_num_params() / 1e6:.1f}M parameters)
|
| 784 |
- Source: {DEFAULT_MODEL_REPO}
|
| 785 |
+
- Format: Safetensors (secure)
|
| 786 |
"""
|
| 787 |
print(colored(stats, 'yellow'))
|
| 788 |
return True
|
|
|
|
| 921 |
|
| 922 |
|
| 923 |
def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
|
| 924 |
+
"""Download and load CosmicFish model from Hugging Face Hub (safetensors only)"""
|
| 925 |
print(colored(f"Downloading CosmicFish from Hugging Face: {model_repo}", "cyan"))
|
| 926 |
|
| 927 |
try:
|
|
|
|
| 956 |
print("Creating model...")
|
| 957 |
model = CosmicFish(config)
|
| 958 |
|
| 959 |
+
# Load weights from safetensors ONLY
|
| 960 |
+
print("Loading weights from safetensors...")
|
| 961 |
+
safetensors_path = os.path.join(cache_dir, "model.safetensors")
|
| 962 |
+
|
| 963 |
+
if not os.path.exists(safetensors_path):
|
| 964 |
+
raise FileNotFoundError(f"model.safetensors not found in {cache_dir}. This model requires safetensors format.")
|
| 965 |
+
|
| 966 |
+
state_dict = load_file(safetensors_path)
|
| 967 |
+
|
| 968 |
+
# Handle weight sharing: lm_head.weight shares with transformer.wte.weight
|
| 969 |
+
if 'lm_head.weight' not in state_dict and 'transformer.wte.weight' in state_dict:
|
| 970 |
+
state_dict['lm_head.weight'] = state_dict['transformer.wte.weight']
|
| 971 |
+
|
| 972 |
model.load_state_dict(state_dict)
|
| 973 |
model.to(device)
|
| 974 |
model.eval()
|
|
|
|
| 984 |
|
| 985 |
|
| 986 |
def load_tokenizer():
|
|
|
|
| 987 |
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
|
|
|
| 988 |
return tokenizer
|
| 989 |
|
| 990 |
|
| 991 |
def main():
|
| 992 |
+
parser = argparse.ArgumentParser(description="Chat with CosmicFish")
|
| 993 |
|
| 994 |
# Model parameters
|
| 995 |
parser.add_argument("--model_repo", type=str, default=DEFAULT_MODEL_REPO,
|
|
|
|
| 1000 |
# Generation parameters
|
| 1001 |
parser.add_argument("--temperature", type=float, default=0.7,
|
| 1002 |
help="Temperature for sampling (default: 0.7)")
|
| 1003 |
+
parser.add_argument("--max_tokens", type=int, default=1024,
|
| 1004 |
help="Maximum number of tokens to generate per response")
|
| 1005 |
parser.add_argument("--min_tokens", type=int, default=10,
|
| 1006 |
help="Minimum number of tokens to generate per response")
|
|
|
|
| 1067 |
chat = CosmicFishChatSession(model, tokenizer, config)
|
| 1068 |
|
| 1069 |
# Main chat loop
|
| 1070 |
+
print(colored("\nCosmicFish initialized from Hugging Face! Type your message (or /help for commands).\n", 'cyan'))
|
| 1071 |
|
| 1072 |
while True:
|
| 1073 |
try:
|
|
|
|
| 1105 |
if not live_buffer:
|
| 1106 |
print(final_response, end="")
|
| 1107 |
break
|
|
|
|
|
|
|
| 1108 |
if token:
|
| 1109 |
# Check if token contains <|endoftext|> and remove it if present
|
| 1110 |
if "<|endoftext|>" in token:
|