| """ |
| Chat interface for the released CosmicFish model from Hugging Face. |
| Compatible with the HF-format release while maintaining all original features. |
| """ |
|
|
| import os |
| import sys |
| import time |
| import argparse |
| import torch |
| import numpy as np |
| from termcolor import colored |
| import logging |
| import readline |
| import re |
| import textwrap |
| import random |
| from collections import defaultdict |
| import json |
|
|
| |
# Optional dependency: the GPT-2 tokenizer is provided by Hugging Face
# transformers. Record availability so load_tokenizer() can fail with a
# helpful message instead of a bare ImportError.
try:
    from transformers import GPT2Tokenizer
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False
    print("❌ Transformers not available. Install with: pip install transformers")
|
|
| |
# Resolve the model classes: prefer the HF-release module name
# (modeling_cosmicfish), then fall back to the training-repo module (model);
# abort immediately if neither is importable since nothing works without them.
try:
    from modeling_cosmicfish import CosmicFish, CosmicConfig
except ImportError:
    try:
        from model import CosmicFish, CosmicConfig
    except ImportError:
        print("❌ CosmicFish model classes not found. Make sure modeling_cosmicfish.py or model.py is available.")
        sys.exit(1)
|
|
| |
# Module-wide logging to stdout; generation timing is logged at DEBUG level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)


# System instruction prepended to every conversation prompt (overridable via
# the --instruction CLI flag).
DEFAULT_PROMPT_TEMPLATE = "Below is a conversation between a helpful AI assistant and a human. The assistant is knowledgeable, friendly, and provides detailed and accurate responses.\n\n"
|
|
|
|
class RepetitionPenaltyLogitsProcessor:
    """Dampen the logits of tokens that already occur in the context.

    Follows the CTRL-style convention: positive logits of seen tokens are
    divided by the penalty, negative ones multiplied, so repetition is always
    made less likely. Mutates and returns ``scores`` in place.
    """

    def __init__(self, penalty=1.2):
        # penalty > 1.0 discourages repeats; 1.0 is a no-op.
        self.penalty = penalty

    def __call__(self, input_ids, scores):
        """Penalize every token id present in ``input_ids`` (in place)."""
        seen = scores.gather(1, input_ids)
        penalized = torch.where(
            seen > 0,
            seen / self.penalty,
            seen * self.penalty,
        )
        return scores.scatter_(1, input_ids, penalized)
|
|
|
|
class CosmicFishChatSession:
    """Interactive chat session for the released CosmicFish model.

    Tracks the turn history (and a parallel token-id mirror of it), enforces
    the model context budget, and streams responses token by token.
    """
|
|
| def __init__(self, model, tokenizer, config): |
| """Initialize chat session with model and configuration.""" |
| self.model = model |
| self.tokenizer = tokenizer |
| self.config = config |
| self.device = next(model.parameters()).device |
| self.history = [] |
| self.history_tokens = [] |
| self.max_history_tokens = config.max_history_tokens |
| self.prompt_template = config.prompt_template |
| self.human_prefix = config.human_prefix |
| self.assistant_prefix = config.assistant_prefix |
| self.end_of_turn = config.end_of_turn |
| self.block_size = config.block_size |
| self.debug_mode = config.debug_mode |
| self.repetition_penalty = config.repetition_penalty |
| self.min_tokens_to_generate = config.min_tokens_to_generate |
| self.max_retries = 20 |
|
|
| self.fallback_responses = [ |
| "I'd be happy to help with that. Could you provide more details about what specific information you're looking for?", |
| "That's a topic I can provide information about. What specific aspects would you like to know?", |
| "I understand your question. I can share factual information on this topic if you could specify what aspects you're interested in.", |
| "I can help with your question. To give you the most relevant information, could you clarify what specific details you're looking for?", |
| "I'd be glad to address your question. To provide the most helpful response, could you specify what particular aspects of this topic interest you?" |
| ] |
|
|
| self.generation_failure_message = "I'm sorry, but I'm having difficulty generating a response to that prompt. Could you try rephrasing your question or asking something else?" |
|
|
| |
| self.total_prompt_tokens = 0 |
| self.total_generated_tokens = 0 |
|
|
| |
| self.end_markers = [ |
| f"{self.human_prefix}", |
| "Human:", |
| "\nHuman:", |
| "\nH:", |
| "H:", |
| "<|endoftext|>", |
| "Below is a conversation", |
| "\nA:", |
| "A:", |
| "</s>", |
| "User:", |
| "\nUser:" |
| ] |
|
|
| |
| if config.display_welcome: |
| self._print_welcome_message() |
|
|
    def _print_welcome_message(self):
        """Print the banner with model size and the list of slash commands."""
        # NOTE: the banner interpolates the live parameter count, so the model
        # must already be loaded when the session is constructed.
        welcome_text = f"""
{'=' * 80}
Welcome to CosmicFish chat interface (Hugging Face Release)

This is a {self.model.get_num_params() / 1e6:.1f}M parameter model.
CosmicFish features advanced architecture with RoPE, GQA, SwiGLU, and RMSNorm.

Type your prompts and CosmicFish will respond.

Special commands:
- /help: Show this help message
- /clear: Clear the conversation history
- /exit or /quit: Exit the chat
- /stats: Show token usage statistics
- /save [filename]: Save the conversation
- /load [filename]: Load a conversation
- /temp [value]: Set temperature (between 0.1 and 2.0)
- /penalty [value]: Set repetition penalty (1.0-2.0)
- /debug: Toggle debug mode
{'=' * 80}
"""
        print(colored(welcome_text, 'cyan'))
|
|
| def _format_prompt(self, user_input): |
| """Format the complete prompt with history and current input.""" |
| |
| formatted_prompt = self.prompt_template |
|
|
| |
| for entry in self.history: |
| role, text = entry |
| if role == "human": |
| formatted_prompt += f"{self.human_prefix}{text}{self.end_of_turn}" |
| else: |
| formatted_prompt += f"{self.assistant_prefix}{text}{self.end_of_turn}" |
|
|
| |
| formatted_prompt += f"{self.human_prefix}{user_input}{self.end_of_turn}{self.assistant_prefix}" |
|
|
| return formatted_prompt |
|
|
| def _tokenize(self, text): |
| """Tokenize text and return token IDs.""" |
| return self.tokenizer.encode(text) |
|
|
| def _update_history(self, user_input, response): |
| """Update conversation history.""" |
| |
| self.history.append(("human", user_input)) |
| self.history.append(("assistant", response)) |
|
|
| |
| user_tokens = self._tokenize(f"{self.human_prefix}{user_input}{self.end_of_turn}") |
| response_tokens = self._tokenize(f"{self.assistant_prefix}{response}{self.end_of_turn}") |
|
|
| self.history_tokens.extend(user_tokens) |
| self.history_tokens.extend(response_tokens) |
|
|
| |
| self.total_prompt_tokens += len(user_tokens) |
| self.total_generated_tokens += len(response_tokens) |
|
|
| |
| self._trim_history_if_needed() |
|
|
| def _trim_history_if_needed(self): |
| """Trim history to fit within the context window.""" |
| if len(self.history_tokens) > self.max_history_tokens: |
| |
| while len(self.history_tokens) > self.max_history_tokens and len(self.history) >= 2: |
| |
| self.history = self.history[2:] |
|
|
| |
| user_turn = self.history[0][1] |
| assistant_turn = self.history[1][1] |
| user_tokens = len(self._tokenize(f"{self.human_prefix}{user_turn}{self.end_of_turn}")) |
| assistant_tokens = len(self._tokenize(f"{self.assistant_prefix}{assistant_turn}{self.end_of_turn}")) |
|
|
| |
| self.history_tokens = self.history_tokens[user_tokens + assistant_tokens:] |
|
|
| def _should_stop_generation(self, text): |
| """Check if generation should stop based on end markers.""" |
| for marker in self.end_markers: |
| if marker in text: |
| return True |
| return False |
|
|
| def _clean_token_text(self, text): |
| """Clean token text by fixing encoding issues.""" |
| |
| text = text.replace('��', "'") |
| return text |
|
|
    def generate_with_repetition_penalty(self, input_ids, max_new_tokens, temperature, top_k, penalty=1.2, live=False):
        """Sample tokens autoregressively with repetition penalty.

        In live mode this is consumed as a generator yielding
        ``(token_text, accumulated_text, should_stop)`` triples after each
        sampled token. In non-live mode nothing is yielded and the full
        tensor is produced via ``return``.

        NOTE(review): because this function contains ``yield``, the trailing
        ``return generated`` only populates ``StopIteration.value`` — a
        non-live caller cannot obtain the tensor by calling it normally.
        Within this file only ``live=True`` is used (_generate_live_response),
        so the non-live path looks dead; confirm before relying on it.
        """
        model = self.model
        device = self.device  # NOTE(review): unused local, kept as-is.

        model.eval()

        # Work on a copy so the caller's prompt tensor is never mutated.
        generated = input_ids.clone()

        # Text accumulated so far (live mode only).
        live_buffer = ""

        rep_processor = RepetitionPenaltyLogitsProcessor(penalty=penalty)

        tokens_generated = 0
        min_tokens = self.min_tokens_to_generate

        # Fall back to GPT-2's end-of-text id when the tokenizer has none.
        eot_token_id = self.tokenizer.eos_token_id if hasattr(self.tokenizer, 'eos_token_id') else 50256

        for _ in range(max_new_tokens):
            # Crop the context to the model's window.
            if generated.size(1) > self.block_size:
                context = generated[:, -self.block_size:]
            else:
                context = generated

            with torch.no_grad():
                logits, _ = model(context)

            # Only the last position's logits matter for the next token.
            next_token_logits = logits[:, -1, :]

            next_token_logits = next_token_logits / temperature

            # Penalize every token already present in the (cropped) context.
            if penalty > 1.0:
                next_token_logits = rep_processor(context, next_token_logits)

            # Top-k filtering: everything below the k-th logit goes to -inf.
            if top_k is not None:
                indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                next_token_logits[indices_to_remove] = float('-inf')

            probs = torch.nn.functional.softmax(next_token_logits, dim=-1)

            next_token = torch.multinomial(probs, num_samples=1)

            # Hard stop on the end-of-text token id.
            if next_token.item() == eot_token_id:
                if live:
                    yield "", live_buffer, True
                break

            generated = torch.cat((generated, next_token), dim=1)
            tokens_generated += 1

            if live:
                # Decode just the new token for streaming.
                next_token_text = self.tokenizer.decode([next_token.item()])

                next_token_text = self._clean_token_text(next_token_text)
                live_buffer += next_token_text

                # The marker may also appear as decoded text, not only as an id.
                eot_marker_pos = live_buffer.find("<|endoftext|>")
                if eot_marker_pos != -1:
                    # Drop the marker and everything after it.
                    live_buffer = live_buffer[:eot_marker_pos]
                    yield "", live_buffer, True
                    break

                # Only honor end markers once a minimum length is reached.
                should_stop = tokens_generated >= min_tokens and self._should_stop_generation(live_buffer)
                yield next_token_text, live_buffer, should_stop

                if should_stop:
                    break

            elif tokens_generated >= min_tokens:
                # Non-live mode: periodically check the tail for end markers.
                recent_text = self.tokenizer.decode(generated[0, -20:].tolist())
                if self._should_stop_generation(recent_text):
                    break

        if tokens_generated == 0 and not live:
            if self.debug_mode:
                print(colored("\n[No tokens generated in this attempt]", "red"))
            return None

        if not live:
            return generated
|
|
| def generate_response(self, user_input): |
| """Generate a response to the user input.""" |
| |
| prompt = self._format_prompt(user_input) |
|
|
| |
| input_ids = torch.tensor(self._tokenize(prompt), dtype=torch.long).unsqueeze(0).to(self.device) |
|
|
| |
| if input_ids.size(1) > self.block_size: |
| |
| instruction_tokens = self._tokenize(self.prompt_template) |
| |
| keep_from_beginning = len(instruction_tokens) |
| keep_from_end = self.block_size - keep_from_beginning |
|
|
| |
| if keep_from_end < 0: |
| |
| input_ids = input_ids[:, :self.block_size] |
| else: |
| |
| input_ids = torch.cat([ |
| input_ids[:, :keep_from_beginning], |
| input_ids[:, -(keep_from_end):] |
| ], dim=1) |
|
|
| |
| start_time = time.time() |
|
|
| |
| return self._generate_live_response(input_ids, user_input, start_time) |
|
|
| def _generate_live_response(self, input_ids, user_input, start_time): |
| """Generate response with live token-by-token output.""" |
| |
| live_text = "" |
| tokens_generated = 0 |
| retry_count = 0 |
|
|
| |
| while retry_count <= self.max_retries: |
| if retry_count > 0: |
| |
| if retry_count % 2 == 0: |
| |
| temp_adjustment = min(0.2 * (retry_count // 2), 0.8) |
| current_temp = min(self.config.temperature + temp_adjustment, 1.8) |
| else: |
| |
| temp_adjustment = min(0.2 * ((retry_count + 1) // 2), 0.4) |
| current_temp = max(self.config.temperature - temp_adjustment, 0.2) |
|
|
| if self.debug_mode: |
| print(colored(f"\n[Live retry {retry_count}: Using temperature {current_temp:.2f}]", "yellow")) |
| else: |
| current_temp = self.config.temperature |
|
|
| |
| live_text = "" |
| tokens_generated = 0 |
| generation_failed = False |
|
|
| |
| try: |
| |
| for token_text, live_buffer, should_stop in self.generate_with_repetition_penalty( |
| input_ids, |
| max_new_tokens=self.config.max_new_tokens, |
| temperature=current_temp, |
| top_k=self.config.top_k, |
| penalty=self.repetition_penalty, |
| live=True |
| ): |
| |
| if should_stop: |
| |
| live_text = live_buffer |
| break |
|
|
| |
| if token_text: |
| live_text += token_text |
| tokens_generated += 1 |
| yield token_text, live_text, False |
|
|
| |
| if not live_text or len(live_text.strip()) < 10: |
| if self.debug_mode: |
| print(colored("\n[Live generation produced empty or too short response, retrying]", "yellow")) |
| generation_failed = True |
| retry_count += 1 |
| |
| if retry_count <= self.max_retries: |
| print("\r" + " " * 80 + "\r", end="") |
| else: |
| |
| break |
|
|
| except Exception as e: |
| if self.debug_mode: |
| print(colored(f"\n[Live generation error: {str(e)}, retrying]", "red")) |
| generation_failed = True |
| retry_count += 1 |
|
|
| |
| if generation_failed or not live_text or len(live_text.strip()) < 10: |
| live_text = self.generation_failure_message |
| if self.debug_mode: |
| print(colored(f"\n[Returning failure message after {retry_count} live retries]", "red")) |
|
|
| |
| time_taken = time.time() - start_time |
| tokens_per_second = tokens_generated / time_taken if time_taken > 0 else 0 |
|
|
| |
| self._update_history(user_input, live_text) |
|
|
| |
| logger.debug(f"Generated {tokens_generated} tokens in {time_taken:.2f}s ({tokens_per_second:.2f} tokens/s)") |
|
|
| |
| yield "", live_text, True |
|
|
    def execute_command(self, command):
        """Execute a special command prefixed with /.

        Returns True to keep the chat loop running; only /exit and /quit
        return False. Unknown commands print an error and return True.

        NOTE(review): bare '/save' or '/load' (no trailing space) falls
        through to the unknown-command branch, because dispatch matches on
        the trailing-space prefix; only '/save <name>' reaches the handler.
        """
        command = command.strip()

        if command == '/help':
            # The welcome banner doubles as the help text.
            self._print_welcome_message()
            return True

        elif command == '/clear':
            # Drop both the text history and its token mirror.
            self.history = []
            self.history_tokens = []
            print(colored("Conversation history cleared.", 'yellow'))
            return True

        elif command in ['/exit', '/quit']:
            print(colored("Goodbye!", 'cyan'))
            return False

        elif command == '/stats':
            prompt_tokens = self.total_prompt_tokens
            generated_tokens = self.total_generated_tokens
            total_tokens = prompt_tokens + generated_tokens

            stats = f"""
Token usage statistics:
- Prompt tokens: {prompt_tokens}
- Generated tokens: {generated_tokens}
- Total tokens: {total_tokens}
- Current history length: {len(self.history_tokens)} tokens
- Current repetition penalty: {self.repetition_penalty}
- Current temperature: {self.config.temperature}
- Model: CosmicFish ({self.model.get_num_params() / 1e6:.1f}M parameters)
"""
            print(colored(stats, 'yellow'))
            return True

        elif command == '/debug':
            # Keep the session flag and the shared config object in sync.
            self.debug_mode = not self.debug_mode
            self.config.debug_mode = self.debug_mode
            mode = "enabled" if self.debug_mode else "disabled"
            print(colored(f"Debug mode {mode}", 'yellow'))
            return True

        elif command.startswith('/penalty '):
            try:
                # len('/penalty ') == 9 — slice off the command prefix.
                penalty = float(command[9:].strip())
                if 1.0 <= penalty <= 2.0:
                    self.repetition_penalty = penalty
                    print(colored(f"Repetition penalty set to {penalty}", 'yellow'))
                else:
                    print(colored("Repetition penalty should be between 1.0 and 2.0", 'red'))
            except ValueError:
                print(colored("Invalid repetition penalty value. Please use a number between 1.0 and 2.0", 'red'))
            return True

        elif command.startswith('/temp '):
            try:
                # len('/temp ') == 6 — slice off the command prefix.
                temp = float(command[6:].strip())
                if 0.1 <= temp <= 2.0:
                    self.config.temperature = temp
                    print(colored(f"Temperature set to {temp}", 'yellow'))
                else:
                    print(colored("Temperature should be between 0.1 and 2.0", 'red'))
            except ValueError:
                print(colored("Invalid temperature value. Please use a number between 0.1 and 2.0", 'red'))
            return True

        elif command.startswith('/save '):
            filename = command[6:].strip()
            if not filename:
                print(colored("Please specify a filename: /save <filename>", 'red'))
                return True

            try:
                # Conversations live under ./conversations as plain text.
                os.makedirs('conversations', exist_ok=True)

                if not filename.endswith('.txt'):
                    filename += '.txt'

                filepath = os.path.join('conversations', filename)

                # Serialize turns in the same prefix/end-of-turn format used
                # for prompting, so /load can parse them back.
                with open(filepath, 'w', encoding='utf-8') as f:
                    for entry in self.history:
                        role, text = entry
                        prefix = self.human_prefix if role == "human" else self.assistant_prefix
                        f.write(f"{prefix}{text}{self.end_of_turn}")

                print(colored(f"Conversation saved to {filepath}", 'green'))

            except Exception as e:
                print(colored(f"Error saving conversation: {str(e)}", 'red'))

            return True

        elif command.startswith('/load '):
            filename = command[6:].strip()
            if not filename:
                print(colored("Please specify a filename: /load <filename>", 'red'))
                return True

            try:
                if not filename.endswith('.txt'):
                    filename += '.txt'

                filepath = os.path.join('conversations', filename)

                if not os.path.exists(filepath):
                    print(colored(f"File not found: {filepath}", 'red'))
                    return True

                with open(filepath, 'r', encoding='utf-8') as f:
                    content = f.read()

                # Reset state before re-populating it from the file.
                self.history = []
                self.history_tokens = []

                # Parse turns back out of the saved prefix/end-of-turn format.
                turns = content.split(self.end_of_turn)
                for turn in turns:
                    turn = turn.strip()
                    if not turn:
                        continue

                    if turn.startswith(self.human_prefix):
                        text = turn[len(self.human_prefix):].strip()
                        self.history.append(("human", text))
                    elif turn.startswith(self.assistant_prefix):
                        text = turn[len(self.assistant_prefix):].strip()
                        self.history.append(("assistant", text))

                # Rebuild the token mirror of the restored history.
                self.history_tokens = []
                for entry in self.history:
                    role, text = entry
                    if role == "human":
                        self.history_tokens.extend(self._tokenize(f"{self.human_prefix}{text}{self.end_of_turn}"))
                    else:
                        self.history_tokens.extend(self._tokenize(f"{self.assistant_prefix}{text}{self.end_of_turn}"))

                print(colored(f"Loaded conversation from {filepath} ({len(self.history) // 2} turns)", 'green'))

                # Replay the restored conversation to the terminal.
                for i in range(0, len(self.history), 2):
                    if i < len(self.history):
                        user_text = self.history[i][1]
                        print(colored(f"\nYou: {user_text}", 'green'))

                    if i + 1 < len(self.history):
                        assistant_text = self.history[i + 1][1]
                        print(colored("CosmicFish: ", 'blue'), end="")
                        for line in assistant_text.split('\n'):
                            wrapped_lines = textwrap.wrap(line, width=100) if line.strip() else ['']
                            for wrapped_line in wrapped_lines:
                                print(wrapped_line)

            except Exception as e:
                print(colored(f"Error loading conversation: {str(e)}", 'red'))

            return True

        else:
            print(colored(f"Unknown command: {command}. Type /help for available commands.", 'red'))
            return True
|
|
|
|
def load_cosmicfish_model(model_dir, device='cpu'):
    """Load a CosmicFish checkpoint stored in Hugging Face layout.

    Reads config.json and pytorch_model.bin from ``model_dir`` and returns
    the ``(model, config)`` pair with the model in eval mode on ``device``.
    Raises FileNotFoundError when either file is missing.
    """
    print(f"Loading CosmicFish model from {model_dir}...")

    config_path = os.path.join(model_dir, "config.json")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"config.json not found in {model_dir}")

    with open(config_path, "r") as f:
        raw_config = json.load(f)

    # Required keys are read directly; dropout is pinned to 0.0 for
    # inference, and eps / use_qk_norm are optional with defaults.
    required = ("vocab_size", "block_size", "n_layer", "n_head", "n_embd",
                "bias", "use_rotary", "use_swiglu", "use_gqa", "n_query_groups")
    kwargs = {key: raw_config[key] for key in required}
    kwargs.update(
        dropout=0.0,
        eps=raw_config.get("eps", 1e-6),
        use_qk_norm=raw_config.get("use_qk_norm", False),
    )
    config = CosmicConfig(**kwargs)

    model = CosmicFish(config)

    weights_path = os.path.join(model_dir, "pytorch_model.bin")
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"pytorch_model.bin not found in {model_dir}")

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from a trusted source.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    print(f"✅ Model loaded: {model.get_num_params() / 1e6:.1f}M parameters")
    return model, config
|
|
|
|
def load_tokenizer():
    """Fetch the stock GPT-2 tokenizer via Hugging Face transformers.

    Raises ImportError when the transformers package was not importable at
    module load time.
    """
    if not HF_AVAILABLE:
        raise ImportError("transformers library required. Install with: pip install transformers")

    print("Loading GPT-2 tokenizer...")
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    print("✅ Tokenizer loaded")
    return gpt2_tokenizer
|
|
|
|
def main():
    """CLI entry point: parse args, load model + tokenizer, run the chat REPL."""
    parser = argparse.ArgumentParser(description="Chat with the released CosmicFish model")

    # Model / device arguments.
    parser.add_argument("--model_dir", type=str, default="./cosmicfish-hf-release",
                        help="Path to the HF-format model directory")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to use (cuda or cpu)")

    # Sampling arguments.
    parser.add_argument("--temperature", type=float, default=0.6,
                        help="Temperature for sampling (default: 0.6)")  # BUGFIX: help text said 0.7
    parser.add_argument("--max_tokens", type=int, default=1024,
                        help="Maximum number of tokens to generate per response")
    parser.add_argument("--min_tokens", type=int, default=10,
                        help="Minimum number of tokens to generate per response")
    parser.add_argument("--top_k", type=int, default=40,
                        help="Top-k sampling (0 to disable)")
    parser.add_argument("--repetition_penalty", type=float, default=1.2,
                        help="Repetition penalty (1.0 = no penalty, 1.2 = mild, 1.5 = moderate)")

    # Prompt-format arguments.
    parser.add_argument("--human_prefix", type=str, default="Human: ",
                        help="Prefix for human messages")
    parser.add_argument("--assistant_prefix", type=str, default="Assistant: ",
                        help="Prefix for assistant messages")
    parser.add_argument("--end_of_turn", type=str, default="\n\n",
                        help="Delimiter between conversation turns")
    parser.add_argument("--instruction", type=str,
                        default=DEFAULT_PROMPT_TEMPLATE,
                        help="Instruction prompt to prepend to the conversation")
    parser.add_argument("--max_history", type=int, default=1024,
                        help="Maximum number of tokens to keep in history")

    # UI arguments.
    parser.add_argument("--no_welcome", action="store_true",
                        help="Don't display the welcome message")
    parser.add_argument("--debug", action="store_true",
                        help="Enable debug mode")

    args = parser.parse_args()

    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available, falling back to CPU")
        device = "cpu"

    try:
        model, model_config = load_cosmicfish_model(args.model_dir, device)

        tokenizer = load_tokenizer()

        class ChatConfig:
            """Flat settings object consumed by CosmicFishChatSession."""

            def __init__(self, args, block_size):
                self.device = device
                self.temperature = args.temperature
                self.max_new_tokens = args.max_tokens
                self.min_tokens_to_generate = args.min_tokens
                # BUGFIX: --top_k 0 is documented as "disable", but the
                # sampler only skips top-k filtering when top_k is None
                # (torch.topk with k=0 would not disable anything). Map
                # non-positive values to None here.
                self.top_k = args.top_k if args.top_k > 0 else None
                self.human_prefix = args.human_prefix
                self.assistant_prefix = args.assistant_prefix
                self.end_of_turn = args.end_of_turn
                self.prompt_template = args.instruction
                self.max_history_tokens = args.max_history
                self.display_welcome = not args.no_welcome
                self.block_size = block_size
                self.debug_mode = args.debug
                self.repetition_penalty = args.repetition_penalty

        config = ChatConfig(args, model_config.block_size)

        chat = CosmicFishChatSession(model, tokenizer, config)

        print(colored("\nCosmicFish initialized. Type your message (or /help for commands).\n", 'cyan'))

        # Main REPL loop: read input, dispatch slash commands, or stream a reply.
        while True:
            try:
                user_input = input(colored("You: ", 'green'))

                # Slash commands; execute_command returns False only on exit.
                if user_input.startswith('/'):
                    if not chat.execute_command(user_input):
                        break
                    continue

                # Skip empty lines.
                if not user_input.strip():
                    continue

                live_buffer = ""
                final_response = None

                response_generator = chat.generate_response(user_input)

                try:
                    print(colored("CosmicFish: ", 'blue'), end="")
                    sys.stdout.flush()

                    for token, live_text, is_done in response_generator:
                        if is_done:
                            final_response = live_text
                            # Nothing was streamed (e.g. fallback message):
                            # print the whole response at once.
                            if not live_buffer:
                                print(final_response, end="")
                            break

                        if token:
                            # Strip a trailing end-of-text marker and stop.
                            if "<|endoftext|>" in token:
                                token = token.replace("<|endoftext|>", "")
                                if token:
                                    print(token, end="", flush=True)
                                break

                            print(token, end="", flush=True)
                            live_buffer += token

                except KeyboardInterrupt:
                    # Ctrl-C during generation aborts this response only.
                    print("\n[Generation interrupted]")
                    final_response = "I was going to respond, but I'll stop here since you interrupted."

                print()

            except KeyboardInterrupt:
                print("\n\nKeyboard interrupt detected. Type /exit to quit or continue chatting.")

            except Exception as e:
                print(colored(f"\nError: {str(e)}", 'red'))
                logger.error(f"Error in chat loop: {str(e)}", exc_info=True)

    except Exception as e:
        print(colored(f"Error loading model: {str(e)}", 'red'))
        logger.error(f"Error loading model: {str(e)}", exc_info=True)
        sys.exit(1)
|
|
|
|
if __name__ == "__main__":
    try:
        main()
    except Exception as exc:
        # Log the full traceback, then exit non-zero.
        logger.error(f"Fatal error: {str(exc)}", exc_info=True)
        sys.exit(1)
|
|