import argparse
import json
import logging
import os
import random
import re
import readline  # noqa: F401  -- not referenced below; presumably imported so input() gets line editing
import sys
import textwrap
import time
from collections import defaultdict

import numpy as np
import tiktoken
import torch
from safetensors.torch import load_file
from termcolor import colored

from modeling_hrm_cosmicfish import HRMCosmicFish, HRMCosmicFishConfig

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

DEFAULT_PROMPT_TEMPLATE = "Below is a conversation between a helpful AI assistant and a human. The assistant is knowledgeable, friendly, and provides detailed and accurate responses.\n\n"

# GPT-2's <|endoftext|> id; only used if the tokenizer does not expose `eot_token`.
GPT2_EOT_TOKEN_ID = 50256

# Keys copied verbatim from config.json into HRMCosmicFishConfig (a missing key raises KeyError).
MODEL_CONFIG_KEYS = (
    "vocab_size", "n_embd", "block_size", "n_head", "n_kv_head",
    "n_input_layers", "n_output_layers",
    "hrm_H_layers", "hrm_L_layers", "hrm_H_cycles", "hrm_L_cycles",
    "hrm_max_steps", "hrm_exploration_prob",
    "dropout", "bias", "use_rotary", "use_gqa", "use_swiglu", "eps",
)


class RepetitionPenaltyLogitsProcessor:
    """Penalise next-token logits of tokens that already appear in the context.

    Positive logits are divided by ``penalty`` and non-positive logits are
    multiplied by it, so with ``penalty > 1`` a previously seen token never
    becomes more likely.
    """

    def __init__(self, penalty=1.2):
        self.penalty = penalty

    def __call__(self, input_ids, scores):
        """Apply the penalty to ``scores`` in place and return it.

        Args:
            input_ids: (batch, seq) LongTensor of token ids already in the context.
            scores: (batch, vocab) next-token logits; modified in place.
        """
        score = torch.gather(scores, 1, input_ids)
        score = torch.where(score > 0, score / self.penalty, score * self.penalty)
        scores.scatter_(1, input_ids, score)
        return scores


class ChatConfig:
    """Generation/chat settings consumed by ChatSession, built from parsed CLI args."""

    def __init__(self, args, block_size, device):
        self.device = device
        self.temperature = args.temperature
        self.max_new_tokens = args.max_tokens
        self.min_tokens_to_generate = args.min_tokens
        self.top_k = args.top_k
        self.human_prefix = args.human_prefix
        self.assistant_prefix = args.assistant_prefix
        self.end_of_turn = args.end_of_turn
        self.prompt_template = args.instruction
        self.max_history_tokens = args.max_history
        self.display_welcome = not args.no_welcome
        self.block_size = block_size
        self.debug_mode = args.debug
        self.repetition_penalty = args.repetition_penalty


class ChatSession:
    """Interactive chat state around an HRMCosmicFish model.

    Keeps the conversation as (role, text) pairs plus a parallel token cache
    (``history_tokens``) used to cap history length, streams replies token by
    token, handles slash commands, and saves/loads plain-text transcripts.
    """

    def __init__(self, model, tokenizer, config):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.device = config.device
        self.history = []
        self.history_tokens = []
        self.max_history_tokens = config.max_history_tokens
        self.prompt_template = config.prompt_template
        self.human_prefix = config.human_prefix
        self.assistant_prefix = config.assistant_prefix
        self.end_of_turn = config.end_of_turn
        self.block_size = config.block_size
        self.debug_mode = config.debug_mode
        self.repetition_penalty = config.repetition_penalty
        self.min_tokens_to_generate = config.min_tokens_to_generate
        # None = "auto": the model keeps its own configured hrm_max_steps.
        self.hrm_forced_steps = None
        self.original_hrm_max_steps = self.model.config.hrm_max_steps
        self.eot_token_id = getattr(tokenizer, "eot_token", GPT2_EOT_TOKEN_ID)
        # max_retries / fallback_responses / generation_failure_message are not
        # referenced in this file; kept for any external callers.
        self.max_retries = 20
        self.fallback_responses = [
            "I'd be happy to help with that. Could you provide more details?",
            "That's interesting. What specific aspects would you like to know about?",
            "I can help with that. Could you clarify what you're looking for?",
            "Let me help you with that. What particular information do you need?",
            "I understand. Could you be more specific about what you'd like to know?"
        ]
        self.generation_failure_message = "I'm having difficulty generating a response. Could you try rephrasing?"
        self.total_prompt_tokens = 0
        self.total_generated_tokens = 0
        self.total_hrm_steps_used = 0
        # Substrings that end a reply. Empty markers are filtered out: "" is
        # contained in every string (stopping after one token) and str.split("")
        # raises ValueError.
        self.end_markers = [
            marker
            for marker in (
                f"{self.human_prefix}", "Human:", "\nHuman:", "\nH:", "H:",
                "<|endoftext|>", "Below is a conversation", "\nA:", "A:",
                "User:", "\nUser:",
            )
            if marker
        ]
        if config.display_welcome:
            self._print_welcome_message()

    def _print_welcome_message(self):
        """Print the banner with model size, HRM settings, and the command list."""
        hrm_mode = f"auto (max {self.original_hrm_max_steps})" if self.hrm_forced_steps is None else str(self.hrm_forced_steps)
        print(colored(f"""
{'=' * 80}
Welcome to CosmicFish-HRM
Model: {self.model.get_num_params() / 1e6:.1f}M parameters
Max HRM Steps: {self.original_hrm_max_steps} | Current HRM Mode: {hrm_mode}
Commands: /help /clear /exit /stats /save /load /temp [val] /penalty [val] /hrm [n|auto] /debug
{'=' * 80}
""", 'cyan'))

    def _format_turn(self, role, text):
        """Render one history entry; any role other than "human" renders as the assistant."""
        prefix = self.human_prefix if role == "human" else self.assistant_prefix
        return f"{prefix}{text}{self.end_of_turn}"

    def _format_prompt(self, user_input):
        """Build the model prompt: template, past turns, the new user turn, then the assistant prefix."""
        past_turns = "".join(self._format_turn(role, text) for role, text in self.history)
        return (
            self.prompt_template
            + past_turns
            + f"{self.human_prefix}{user_input}{self.end_of_turn}{self.assistant_prefix}"
        )

    def _tokenize(self, text):
        return self.tokenizer.encode(text)

    def _update_history(self, user_input, response):
        """Record a completed exchange, extend the token cache, and trim if over budget."""
        self.history.append(("human", user_input))
        self.history.append(("assistant", response))
        user_tokens = self._tokenize(self._format_turn("human", user_input))
        response_tokens = self._tokenize(self._format_turn("assistant", response))
        self.history_tokens.extend(user_tokens)
        self.history_tokens.extend(response_tokens)
        self.total_prompt_tokens += len(user_tokens)
        self._trim_history_if_needed()

    def _trim_history_if_needed(self):
        """Drop the oldest entries, two at a time, until history_tokens fits max_history_tokens.

        history_tokens is the concatenation of each entry tokenized on its own,
        so exactly the tokens of the removed entries are sliced off the front,
        keeping the cache in sync with ``history``. Stops when fewer than two
        entries remain.
        """
        while len(self.history_tokens) > self.max_history_tokens and len(self.history) >= 2:
            removed, self.history = self.history[:2], self.history[2:]
            n_removed = sum(
                len(self._tokenize(self._format_turn(role, text))) for role, text in removed
            )
            self.history_tokens = self.history_tokens[n_removed:]

    def _should_stop_generation(self, text):
        return any(marker in text for marker in self.end_markers)

    def _truncate_at_end_marker(self, text):
        """Return ``text`` cut at the earliest end marker it contains (unchanged if none)."""
        positions = [text.find(marker) for marker in self.end_markers if marker in text]
        return text[:min(positions)] if positions else text

    def _clean_token_text(self, text):
        return text.replace("<|endoftext|>", "")

    def _is_repetitive(self, tokens, window=10):
        """True if the last ``window`` tokens have < 3 distinct ids or end in a
        directly repeated 2-, 3- or 4-token pattern."""
        if len(tokens) < window:
            return False
        recent = tokens[-window:]
        if len(set(recent)) < 3:
            return True
        for pattern_len in [2, 3, 4]:
            if len(recent) >= pattern_len * 2:
                pattern = tuple(recent[-pattern_len:])
                prev_pattern = tuple(recent[-pattern_len*2:-pattern_len])
                if pattern == prev_pattern:
                    return True
        return False

    def _set_hrm_steps(self, steps):
        self.model.config.hrm_max_steps = steps
        self.model.hrm_core.config.hrm_max_steps = steps

    def _restore_hrm_steps(self):
        self.model.config.hrm_max_steps = self.original_hrm_max_steps
        self.model.hrm_core.config.hrm_max_steps = self.original_hrm_max_steps

    def generate_response(self, user_input):
        """Stream a reply to ``user_input``.

        Yields ``(token_text, accumulated_text, False)`` for each streamed token,
        then ``(None, final_response, True)`` once generation ends; the final
        response is also appended to the history.

        Generation stops on the tokenizer's end-of-text id, when an end marker
        appears (the text is cut before it), on the repetition heuristic (only
        after ``min_tokens_to_generate`` tokens), or after ``max_new_tokens``.
        A forced HRM step count is applied for the duration of the call and
        restored in ``finally``, including when the caller closes the
        generator early. A temperature <= 0 selects greedy decoding.
        """
        forced = self.hrm_forced_steps is not None
        if forced:
            self._set_hrm_steps(self.hrm_forced_steps)
        try:
            full_prompt = self._format_prompt(user_input)
            prompt_tokens = self._tokenize(full_prompt)
            input_ids = torch.tensor(prompt_tokens, dtype=torch.long).unsqueeze(0).to(self.device)

            if self.debug_mode:
                print(f"\n[DEBUG] Prompt tokens: {len(prompt_tokens)}")
                print(f"[DEBUG] HRM mode: {'auto' if self.hrm_forced_steps is None else self.hrm_forced_steps} (model max: {self.model.config.hrm_max_steps})")

            generated_tokens = []
            accumulated_text = ""
            repetition_processor = RepetitionPenaltyLogitsProcessor(self.repetition_penalty)
            temperature = self.config.temperature
            total_hrm_steps = 0

            with torch.no_grad():
                for _step in range(self.config.max_new_tokens):
                    # Feed at most the last block_size tokens (the model's context window).
                    context = input_ids[:, -self.block_size:]
                    logits, _, steps_taken, _ = self.model(context)
                    total_hrm_steps += steps_taken.item()

                    logits = logits[:, -1, :]
                    if temperature > 0:
                        logits = logits / temperature
                    logits = repetition_processor(context, logits)
                    if self.config.top_k > 0:
                        top_values, _ = torch.topk(logits, min(self.config.top_k, logits.size(-1)))
                        logits[logits < top_values[:, [-1]]] = float('-inf')

                    if temperature > 0:
                        probs = torch.nn.functional.softmax(logits, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                    else:
                        # Avoid dividing by zero: temperature <= 0 means greedy decoding.
                        next_token = torch.argmax(logits, dim=-1, keepdim=True)

                    token_id = next_token.item()
                    if token_id == self.eot_token_id:
                        break

                    token_text = self._clean_token_text(self.tokenizer.decode([token_id]))
                    generated_tokens.append(token_id)
                    accumulated_text += token_text

                    if self._should_stop_generation(accumulated_text):
                        accumulated_text = self._truncate_at_end_marker(accumulated_text)
                        break

                    if (len(generated_tokens) >= self.min_tokens_to_generate
                            and self._is_repetitive(generated_tokens)):
                        if self.debug_mode:
                            print("\n[DEBUG] Detected repetition, stopping")
                        break

                    yield (token_text, accumulated_text, False)
                    input_ids = torch.cat([input_ids, next_token], dim=1)

            final_response = accumulated_text.strip()
            for marker in self.end_markers:
                stripped_marker = marker.strip()
                if stripped_marker and final_response.endswith(stripped_marker):
                    final_response = final_response[:-len(stripped_marker)].strip()

            self.total_hrm_steps_used += total_hrm_steps
            self.total_generated_tokens += len(generated_tokens)
            if self.debug_mode:
                avg_steps = total_hrm_steps / len(generated_tokens) if generated_tokens else 0
                print(f"\n[DEBUG] Generated {len(generated_tokens)} tokens | Total HRM steps: {total_hrm_steps} | Avg steps/token: {avg_steps:.1f}")

            self._update_history(user_input, final_response)
            yield (None, final_response, True)
        finally:
            if forced:
                self._restore_hrm_steps()

    def execute_command(self, command):
        """Run a slash command typed by the user.

        The first word selects the command (case-insensitive); the rest is its
        argument. Returns False when the session should end (/exit, /quit, /q)
        and True otherwise.
        """
        pieces = command.strip().split(maxsplit=1)
        name = pieces[0].lower() if pieces else ""
        rest = pieces[1].strip() if len(pieces) > 1 else ""
        args = rest.split()

        if name in ('/exit', '/quit', '/q'):
            print(colored("Goodbye!", 'cyan'))
            return False

        if name == '/help':
            self._print_welcome_message()
        elif name == '/clear':
            self.history = []
            self.history_tokens = []
            print(colored("Conversation history cleared.", 'yellow'))
        elif name == '/stats':
            self._print_stats()
        elif name == '/debug':
            self.debug_mode = not self.debug_mode
            print(colored(f"Debug mode {'enabled' if self.debug_mode else 'disabled'}.", 'yellow'))
        elif name == '/temp':
            self._command_temperature(args)
        elif name == '/penalty':
            self._command_penalty(args)
        elif name == '/hrm':
            self._command_hrm(args)
        elif name == '/save':
            if rest:
                self._save_conversation(rest)
            else:
                print(colored("Usage: /save [filename]", 'red'))
        elif name == '/load':
            if rest:
                self._load_conversation(rest)
            else:
                print(colored("Usage: /load [filename]", 'red'))
        else:
            print(colored(f"Unknown command: {command}", 'red'))
            print(colored("Type /help for available commands", 'yellow'))
        return True

    def _command_temperature(self, args):
        """Handle `/temp [value]`: accept a sampling temperature in [0.1, 2.0]."""
        try:
            temp = float(args[0])
        except (IndexError, ValueError):
            print(colored("Usage: /temp [value]", 'red'))
            return
        if 0.1 <= temp <= 2.0:
            self.config.temperature = temp
            print(colored(f"Temperature set to {temp}", 'yellow'))
        else:
            print(colored("Temperature must be between 0.1 and 2.0", 'red'))

    def _command_penalty(self, args):
        """Handle `/penalty [value]`: accept a repetition penalty in [1.0, 2.0]."""
        try:
            penalty = float(args[0])
        except (IndexError, ValueError):
            print(colored("Usage: /penalty [value]", 'red'))
            return
        if 1.0 <= penalty <= 2.0:
            self.repetition_penalty = penalty
            print(colored(f"Repetition penalty set to {penalty}", 'yellow'))
        else:
            print(colored("Penalty must be between 1.0 and 2.0", 'red'))

    def _command_hrm(self, args):
        """Handle `/hrm auto` (use the model's own step limit) or `/hrm N` (force N steps, 0-9999)."""
        hrm_arg = args[0].lower() if args else ""
        if hrm_arg == 'auto':
            # None means "no override": generate_response leaves the model's maximum in place.
            self.hrm_forced_steps = None
            print(colored(f"HRM mode set to AUTO (model will use up to {self.original_hrm_max_steps} steps)", 'yellow'))
            return
        try:
            steps = int(hrm_arg)
        except ValueError:
            print(colored("Usage: /hrm [number] or /hrm auto", 'red'))
            return
        if 0 <= steps <= 9999:
            self.hrm_forced_steps = steps
            print(colored(f"HRM forced to {steps} step(s)", 'yellow'))
            if steps == 0:
                print(colored("Warning: HRM with 0 steps means no iterative reasoning!", 'red'))
        else:
            print(colored("HRM steps must be between 0 and 9999", 'red'))

    def _print_stats(self):
        """Print cumulative token/HRM counters and the current generation settings."""
        avg_hrm = self.total_hrm_steps_used / self.total_generated_tokens if self.total_generated_tokens > 0 else 0
        hrm_mode = "AUTO" if self.hrm_forced_steps is None else f"FORCED ({self.hrm_forced_steps})"
        print(colored(f"""
{'=' * 60}
CONVERSATION STATISTICS
{'=' * 60}
Prompt tokens:       {self.total_prompt_tokens:,}
Generated tokens:    {self.total_generated_tokens:,}
Total HRM steps:     {self.total_hrm_steps_used:,}
Avg HRM steps/tok:   {avg_hrm:.2f}
Turns:               {len(self.history) // 2}
History tokens:      {len(self.history_tokens):,}
Temperature:         {self.config.temperature}
Repetition penalty:  {self.repetition_penalty}
HRM mode:            {hrm_mode}
Model max HRM steps: {self.original_hrm_max_steps}
Top-k:               {self.config.top_k}
{'=' * 60}
""", 'cyan'))

    def _save_conversation(self, filename):
        """Write the history as a "Human: "/"Assistant: " transcript; errors are printed, not raised."""
        try:
            with open(filename, 'w', encoding='utf-8') as f:
                f.write("HRM-CosmicFish Conversation\n")
                f.write(f"{'=' * 80}\n\n")
                for role, text in self.history:
                    prefix = "Human: " if role == "human" else "Assistant: "
                    f.write(f"{prefix}{text}\n\n")
            print(colored(f"Conversation saved to {filename}", 'green'))
        except Exception as e:
            print(colored(f"Error saving conversation: {e}", 'red'))

    def _load_conversation(self, filename):
        """Replace the history with a transcript in the _save_conversation format.

        Lines starting with "Human: " / "Assistant: " open a new entry; every
        other line (blank ones included, so multi-paragraph replies survive)
        extends the current entry, and anything before the first entry (the
        header) is ignored. The token cache is rebuilt and trimmed to
        max_history_tokens. Errors are printed, not raised; on error the
        existing history is left untouched.
        """
        try:
            with open(filename, 'r', encoding='utf-8') as f:
                lines = f.read().split('\n')

            history = []
            current_role = None
            current_text = []
            for line in lines:
                if line.startswith('Human: '):
                    new_role, first_line = 'human', line[len('Human: '):]
                elif line.startswith('Assistant: '):
                    new_role, first_line = 'assistant', line[len('Assistant: '):]
                else:
                    if current_role:
                        current_text.append(line)
                    continue
                if current_role and current_text:
                    history.append((current_role, '\n'.join(current_text).strip()))
                current_role, current_text = new_role, [first_line]
            if current_role and current_text:
                history.append((current_role, '\n'.join(current_text).strip()))

            self.history = history
            # Rebuild the cache the same way _update_history builds it, so trimming stays exact.
            self.history_tokens = [
                token
                for role, text in history
                for token in self._tokenize(self._format_turn(role, text))
            ]
            self._trim_history_if_needed()
            print(colored(f"Conversation loaded from {filename} ({len(self.history)//2} turns)", 'green'))
        except Exception as e:
            print(colored(f"Error loading conversation: {e}", 'red'))


def _load_model(model_dir, device):
    """Build HRMCosmicFish from config.json and model.safetensors in ``model_dir``.

    Falls back to non-strict state-dict loading (logging missing/unexpected
    key counts) if strict loading raises RuntimeError.

    Returns:
        (model, config) with the model moved to ``device`` and set to eval mode.

    Raises:
        FileNotFoundError: if either file is missing.
        KeyError: if config.json lacks one of MODEL_CONFIG_KEYS.
    """
    config_path = os.path.join(model_dir, "config.json")
    weights_path = os.path.join(model_dir, "model.safetensors")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"config.json not found in {model_dir}")
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"model.safetensors not found in {model_dir}")

    with open(config_path, encoding="utf-8") as f:
        cfg = json.load(f)
    config = HRMCosmicFishConfig(**{key: cfg[key] for key in MODEL_CONFIG_KEYS})

    model = HRMCosmicFish(config)
    state_dict = load_file(weights_path, device=device)
    try:
        model.load_state_dict(state_dict)
    except RuntimeError as e:
        logger.warning(f"Strict loading failed: {e}, attempting flexible loading...")
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            logger.warning(f"Missing keys: {len(missing)}")
        if unexpected:
            logger.warning(f"Unexpected keys: {len(unexpected)}")
    model.to(device)
    model.eval()

    print(f"Model loaded: {model.get_num_params() / 1e6:.2f}M parameters")
    print(f" Input blocks: {config.n_input_layers} | HRM: H={config.hrm_H_layers} L={config.hrm_L_layers} (max {config.hrm_max_steps} steps) | Output blocks: {config.n_output_layers}")
    return model, config


def _stream_reply(chat, user_input):
    """Print the model's reply to ``user_input`` to stdout as it is generated."""
    response_generator = chat.generate_response(user_input)
    printed_any = False
    try:
        print(colored("CosmicFish: ", 'blue'), end="")
        sys.stdout.flush()
        for token, live_text, is_done in response_generator:
            if is_done:
                # Nothing was streamed (e.g. first token hit a stop marker): show the final text.
                if not printed_any:
                    print(live_text, end="")
                break
            if token:
                print(token, end="", flush=True)
                printed_any = True
    except KeyboardInterrupt:
        print("\n[Generation interrupted]")
    finally:
        # Run the generator's cleanup now (restores forced HRM steps) instead of at GC time.
        response_generator.close()
    print()


def _run_chat_loop(chat):
    """Read user input until /exit or end of input, streaming each model reply."""
    while True:
        try:
            user_input = input(colored("You: ", 'green'))
            if user_input.startswith('/'):
                if not chat.execute_command(user_input):
                    break
                continue
            if not user_input.strip():
                continue
            _stream_reply(chat, user_input)
        except EOFError:
            # stdin closed (Ctrl-D / end of piped input): retrying input() would fail forever.
            print()
            break
        except KeyboardInterrupt:
            print("\n\nKeyboard interrupt. Type /exit to quit or continue chatting.")
        except Exception as e:
            print(colored(f"\nError: {str(e)}", 'red'))
            logger.error(f"Error in chat loop: {str(e)}", exc_info=True)


def main():
    """Parse CLI arguments, load the model and GPT-2 tokenizer, and run the chat REPL."""
    parser = argparse.ArgumentParser(description="Chat with CosmicFish-HRM model")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--temperature", type=float, default=0.5)
    parser.add_argument("--max_tokens", type=int, default=3000)
    parser.add_argument("--min_tokens", type=int, default=10)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--repetition_penalty", type=float, default=1.2)
    parser.add_argument("--human_prefix", type=str, default="Human: ")
    parser.add_argument("--assistant_prefix", type=str, default="Assistant: ")
    parser.add_argument("--end_of_turn", type=str, default="\n\n")
    parser.add_argument("--instruction", type=str, default=DEFAULT_PROMPT_TEMPLATE)
    parser.add_argument("--max_history", type=int, default=1024)
    parser.add_argument("--no_welcome", action="store_true")
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    # Model files are expected next to this script.
    model_dir = os.path.dirname(os.path.abspath(__file__))
    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        device = "cpu"

    print(f"Loading HRM-CosmicFish model from {model_dir}...")
    try:
        model, config = _load_model(model_dir, device)
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return

    try:
        tokenizer = tiktoken.get_encoding("gpt2")
    except Exception as e:
        print(f"Error loading tokenizer: {str(e)}")
        return

    chat = ChatSession(model, tokenizer, ChatConfig(args, config.block_size, device))
    print(colored("\nHRM-CosmicFish initialized. Type your message (or /help for commands).\n", 'cyan'))
    _run_chat_loop(chat)


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        logger.error(f"Fatal error: {str(e)}", exc_info=True)
        sys.exit(1)