import os
import sys
import time
import math
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from termcolor import colored
import logging
import re
import textwrap
import random
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional
import json

try:
    # Imported only for its side effect (line editing / history in input()).
    # The module may be unavailable on some platforms, so it is optional here.
    import readline  # noqa: F401
except ImportError:
    readline = None

try:
    from safetensors.torch import load_file
except ImportError:
    print("safetensors not installed. Run: pip install safetensors")
    sys.exit(1)

try:
    from huggingface_hub import snapshot_download
except ImportError:
    print("huggingface_hub not installed. Run: pip install huggingface-hub")
    sys.exit(1)

try:
    from transformers import GPT2Tokenizer
except ImportError:
    print("transformers not installed. Run: pip install transformers")
    sys.exit(1)

HF_REPO = "MistyozAI/CosmicFish-HRM"


@dataclass
class HRMCosmicFishConfig:
    """Hyper-parameters for :class:`HRMCosmicFish`.

    ``main`` overrides most fields from the checkpoint's ``config.json``.
    ``hrm_exploration_prob`` and ``forward_dtype`` are stored but not read by
    the code in this script.
    """

    vocab_size: int = 50304
    n_embd: int = 448
    block_size: int = 512          # longest context fed to the model
    n_input_layers: int = 6        # transformer blocks before the HRM core
    n_output_layers: int = 6       # transformer blocks after the HRM core
    n_head: int = 8
    hrm_H_layers: int = 4          # blocks in the high-level reasoning module
    hrm_L_layers: int = 4          # blocks in the low-level reasoning module
    hrm_H_cycles: int = 2          # H updates per HRM step
    hrm_L_cycles: int = 2          # L updates per H update
    hrm_max_steps: int = 16        # upper bound on HRM steps per forward pass
    hrm_exploration_prob: float = 0.1
    dropout: float = 0.1
    bias: bool = False
    use_rotary: bool = True
    use_gqa: bool = True
    use_swiglu: bool = True
    n_kv_head: int = 4             # key/value heads when use_gqa is True
    eps: float = 1e-5
    forward_dtype: str = "bfloat16"


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Return a complex ``(end, dim // 2)`` table of rotary-embedding phases.

    Row ``t`` holds ``exp(i * t * theta ** (-2k / dim))`` for each frequency ``k``.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)


def apply_rotary_emb(xq, xk, freqs_cis):
    """Rotate query/key tensors by the rotary phases in ``freqs_cis``.

    ``xq``/``xk`` are ``(batch, heads, seq, head_dim)`` as produced by
    :class:`GroupedQueryAttention`; consecutive pairs along the last dimension
    are treated as complex numbers. ``freqs_cis`` needs at least ``seq`` rows.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(0)
    freqs_cis = freqs_cis[:, :, :xq_.shape[2], :]
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class RMSNorm(nn.Module):
    """Root-mean-square layer norm, computed in float32 and cast back."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        input_dtype = x.dtype
        x = x.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return (self.weight * x).to(input_dtype)


class GroupedQueryAttention(nn.Module):
    """Causal self-attention with optional grouped key/value heads and RoPE."""

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head if config.use_gqa else config.n_head
        self.head_dim = config.n_embd // config.n_head
        self.n_embd = config.n_embd
        self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=config.bias)
        self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        # Use the fused kernel when this PyTorch build provides it.
        self.flash = hasattr(F, 'scaled_dot_product_attention')

    def forward(self, x, freqs_cis=None):
        B, T, C = x.size()
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        if freqs_cis is not None:
            q, k = apply_rotary_emb(q, k, freqs_cis)
        if self.n_kv_head != self.n_head:
            # Share each key/value head across a group of query heads.
            k = k.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
            v = v.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
        if self.flash:
            y = F.scaled_dot_product_attention(
                q, k, v, attn_mask=None,
                dropout_p=self.attn_dropout.p if self.training else 0.0,
                is_causal=True,
            )
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
            att = att.masked_fill(torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool(), float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))


class MLP(nn.Module):
    """Feed-forward block: gated SiLU (``use_swiglu``) or plain GELU."""

    def __init__(self, config):
        super().__init__()
        hidden_dim = 4 * config.n_embd
        if config.use_swiglu:
            self.gate = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
            self.up = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
            self.down = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
            self.act = nn.SiLU()
        else:
            self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
            self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
            self.act = nn.GELU()
        self.dropout = nn.Dropout(config.dropout)
        self.use_swiglu = config.use_swiglu

    def forward(self, x):
        if self.use_swiglu:
            # NOTE(review): SiLU is applied to ``up`` and multiplied by ``gate``;
            # presumably the checkpoint was trained this way, so do not swap them.
            return self.dropout(self.down(self.act(self.up(x)) * self.gate(x)))
        return self.dropout(self.c_proj(self.act(self.c_fc(x))))


class TransformerBlock(nn.Module):
    """Pre-norm transformer block used for the input and output stacks."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = RMSNorm(config.n_embd, eps=config.eps)
        self.attn = GroupedQueryAttention(config)
        self.ln_2 = RMSNorm(config.n_embd, eps=config.eps)
        self.mlp = MLP(config)

    def forward(self, x, freqs_cis=None):
        x = x + self.attn(self.ln_1(x), freqs_cis)
        x = x + self.mlp(self.ln_2(x))
        return x


class HRMReasoningBlock(nn.Module):
    """Post-norm transformer block used inside the HRM reasoning levels."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = RMSNorm(config.n_embd, eps=config.eps)
        self.attn = GroupedQueryAttention(config)
        self.ln_2 = RMSNorm(config.n_embd, eps=config.eps)
        self.mlp = MLP(config)

    def forward(self, x, freqs_cis=None):
        x = self.ln_1(x + self.attn(x, freqs_cis))
        x = self.ln_2(x + self.mlp(x))
        return x


class HRMReasoningLevel(nn.Module):
    """A stack of reasoning blocks; the injection is added once before the stack."""

    def __init__(self, config, n_layers):
        super().__init__()
        self.layers = nn.ModuleList([HRMReasoningBlock(config) for _ in range(n_layers)])

    def forward(self, hidden_states, input_injection, freqs_cis=None):
        hidden_states = hidden_states + input_injection
        for layer in self.layers:
            hidden_states = layer(hidden_states, freqs_cis)
        return hidden_states


class HRMCore(nn.Module):
    """Hierarchical reasoning core with adaptive halting.

    Each step runs ``hrm_H_cycles`` high-level updates, each preceded by
    ``hrm_L_cycles`` low-level updates, then a Q-head decides per batch row
    whether to halt. The loop stops once every row has halted or
    ``hrm_max_steps`` steps have run.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.H_level = HRMReasoningLevel(config, config.hrm_H_layers)
        self.L_level = HRMReasoningLevel(config, config.hrm_L_layers)
        self.H_init = nn.Parameter(torch.randn(config.n_embd) * 0.02)
        self.L_init = nn.Parameter(torch.randn(config.n_embd) * 0.02)
        # Outputs (halt, continue) logits; starts with zero weights and bias -5.
        self.q_head = nn.Linear(config.n_embd, 2, bias=True)
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5.0)

    def forward(self, x, freqs_cis=None, training=False):
        """Run the reasoning loop on ``x`` of shape ``(B, T, C)``.

        Returns:
            ``(z_H, steps_taken, q_logits)``:
            - ``z_H``: final high-level state, ``(B, T, C)``.
            - ``steps_taken``: ``(B,)`` long, the number of steps each row ran
              before it halted.
            - ``q_logits``: the last step's ``(B, 2)`` logits, or ``None`` if no
              step ran (``hrm_max_steps == 0``).
        """
        B, T, C = x.size()
        device = x.device
        z_H = self.H_init.expand(B, T, C)
        z_L = self.L_init.expand(B, T, C)
        steps_taken = torch.zeros(B, dtype=torch.long, device=device)
        halted = torch.zeros(B, dtype=torch.bool, device=device)
        q_logits = None
        max_steps = self.config.hrm_max_steps
        # Only the final step may build an autograd graph, and only if the
        # caller has grad enabled (never re-enable grad inside torch.no_grad()).
        grad_allowed = torch.is_grad_enabled()

        for step in range(max_steps):
            if halted.all():
                break
            with torch.set_grad_enabled(grad_allowed and step == max_steps - 1):
                for _h in range(self.config.hrm_H_cycles):
                    for _l in range(self.config.hrm_L_cycles):
                        z_L = self.L_level(z_L, z_H + x, freqs_cis)
                    z_H = self.H_level(z_H, z_L, freqs_cis)
                q_logits = self.q_head(z_H.mean(dim=1).float())
                # Every row that had not halted yet has just executed one more step.
                steps_taken = steps_taken + (~halted).long()
                if max_steps > 1:
                    q_halt = q_logits[:, 0]
                    q_continue = q_logits[:, 1]
                    if not training:
                        # Inference-time bias toward halting earlier.
                        q_halt = q_halt + 0.35
                    halted = halted | (q_halt > q_continue)

        return z_H, steps_taken, q_logits


class HRMCosmicFish(nn.Module):
    """Language model: input blocks -> HRM core -> output blocks -> LM head."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        if config.use_rotary:
            # Plain attribute (not a buffer) so it never appears in the state dict.
            self.freqs_cis = precompute_freqs_cis(config.n_embd // config.n_head, config.block_size)
        else:
            self.freqs_cis = None
        self.wpe = nn.Embedding(config.block_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        self.input_blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_input_layers)])
        self.hrm_core = HRMCore(config)
        self.output_blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_output_layers)])
        self.ln_f = RMSNorm(config.n_embd, eps=config.eps)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Tie input embedding and output projection weights.
        self.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)
        total = config.n_input_layers + config.n_output_layers + config.hrm_H_layers + config.hrm_L_layers
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight') or pn.endswith('down.weight'):
                # Scale residual-output projections down with total depth.
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * total))

        print(f"Model initialized with {self.get_num_params() / 1e6:.2f}M parameters")
        print(f"  Input blocks: {config.n_input_layers} layers")
        print(f"  HRM Core: H={config.hrm_H_layers} L={config.hrm_L_layers} (max {config.hrm_max_steps} steps)")
        print(f"  Output blocks: {config.n_output_layers} layers")

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def get_num_params(self, non_embedding=True):
        """Count parameters, excluding the learned position table when ``non_embedding``."""
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding and hasattr(self, 'wpe'):
            n_params -= self.wpe.weight.numel()
        return n_params

    def forward(self, idx, targets=None):
        """Return ``(logits, loss, steps_taken, q_logits)`` for token ids ``idx`` of shape ``(B, T)``.

        ``loss`` is ``None`` unless ``targets`` is given.
        """
        device = idx.device
        _, T = idx.size()
        x = self.wte(idx)
        if self.config.use_rotary:
            # Move the rotary table once and reuse it instead of copying it every call.
            if self.freqs_cis is not None and self.freqs_cis.device != device:
                self.freqs_cis = self.freqs_cis.to(device)
            freqs_cis = self.freqs_cis
        else:
            pos = torch.arange(0, T, dtype=torch.long, device=device)
            x = x + self.wpe(pos)
            freqs_cis = None
        x = self.drop(x)

        for block in self.input_blocks:
            x = block(x, freqs_cis)
        x, steps_taken, q_logits = self.hrm_core(x, freqs_cis, training=self.training)
        for block in self.output_blocks:
            x = block(x, freqs_cis)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            loss = loss + 0.01 * steps_taken.float().mean()
        return logits, loss, steps_taken, q_logits

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``."""
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _, _, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

DEFAULT_PROMPT_TEMPLATE = "Below is a conversation between a helpful AI assistant and a human. The assistant is knowledgeable, friendly, and provides detailed and accurate responses.\n\n"


class RepetitionPenaltyLogitsProcessor:
    """Penalize scores of tokens that already appear in ``input_ids`` (in place)."""

    def __init__(self, penalty=1.2):
        self.penalty = penalty

    def __call__(self, input_ids, scores):
        score = torch.gather(scores, 1, input_ids)
        # Shrink positive scores and push negative ones further down.
        score = torch.where(score > 0, score / self.penalty, score * self.penalty)
        scores.scatter_(1, input_ids, score)
        return scores


class ChatSession:
    """Interactive chat around :class:`HRMCosmicFish`.

    Keeps the conversation history, builds prompts from it, streams sampled
    tokens, and implements the slash commands in :meth:`execute_command`.
    """

    def __init__(self, model, tokenizer, config):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.device = config.device
        self.history = []          # list of (role, text); role is "human" or "assistant"
        self.history_tokens = []   # token ids of the formatted turns in self.history (budgeting only)
        self.max_history_tokens = config.max_history_tokens
        self.prompt_template = config.prompt_template
        self.human_prefix = config.human_prefix
        self.assistant_prefix = config.assistant_prefix
        self.end_of_turn = config.end_of_turn
        self.block_size = config.block_size
        self.debug_mode = config.debug_mode
        self.repetition_penalty = config.repetition_penalty
        self.min_tokens_to_generate = config.min_tokens_to_generate
        self.hrm_forced_steps = None   # None = adaptive halting up to the checkpoint's max
        self.original_hrm_max_steps = self.model.config.hrm_max_steps
        self.max_retries = 20
        self.fallback_responses = [
            "I'd be happy to help with that. Could you provide more details?",
            "That's interesting. What specific aspects would you like to know about?",
            "I can help with that. Could you clarify what you're looking for?",
            "Let me help you with that. What particular information do you need?",
            "I understand. Could you be more specific about what you'd like to know?"
        ]
        self.generation_failure_message = "I'm having difficulty generating a response. Could you try rephrasing?"
        self.total_prompt_tokens = 0
        self.total_generated_tokens = 0
        self.total_hrm_steps_used = 0
        # Substrings that mean the model has started a new turn. Empty markers are
        # dropped: "" matches every text, and str.split("") raises ValueError.
        markers = [
            f"{self.human_prefix}", "Human:", "\nHuman:", "\nH:", "H:", "<|endoftext|>",
            "Below is a conversation", "\nA:", "A:", "User:", "\nUser:"
        ]
        self.end_markers = [m for m in markers if m]
        if config.display_welcome:
            self._print_welcome_message()

    def _print_welcome_message(self):
        hrm_mode = f"auto (max {self.original_hrm_max_steps})" if self.hrm_forced_steps is None else str(self.hrm_forced_steps)
        print(colored(f"""
{'=' * 80}
Welcome to CosmicFish-HRM
Model: {self.model.get_num_params() / 1e6:.1f}M parameters
Max HRM Steps: {self.original_hrm_max_steps} | Current HRM Mode: {hrm_mode}
Commands: /help /clear /exit /stats /save /load /temp [val] /penalty [val] /hrm [n|auto] /debug
{'=' * 80}
""", 'cyan'))

    def _format_prompt(self, user_input):
        """Build the full prompt: template, past turns, new user turn, assistant prefix."""
        parts = [self.prompt_template]
        for role, text in self.history:
            prefix = self.human_prefix if role == "human" else self.assistant_prefix
            parts.append(f"{prefix}{text}{self.end_of_turn}")
        parts.append(f"{self.human_prefix}{user_input}{self.end_of_turn}{self.assistant_prefix}")
        return "".join(parts)

    def _tokenize(self, text):
        return self.tokenizer.encode(text)

    def _encode_turn(self, role, text):
        """Tokenize one history turn exactly as it is formatted into the prompt."""
        prefix = self.human_prefix if role == "human" else self.assistant_prefix
        return self._tokenize(f"{prefix}{text}{self.end_of_turn}")

    def _update_history(self, user_input, response):
        self.history.append(("human", user_input))
        self.history.append(("assistant", response))
        user_tokens = self._encode_turn("human", user_input)
        response_tokens = self._encode_turn("assistant", response)
        self.history_tokens.extend(user_tokens)
        self.history_tokens.extend(response_tokens)
        self.total_prompt_tokens += len(user_tokens)
        self.total_generated_tokens += len(response_tokens)
        self._trim_history_if_needed()

    def _trim_history_if_needed(self):
        """Drop the oldest (human, assistant) pairs until the history fits the token budget."""
        while len(self.history_tokens) > self.max_history_tokens and len(self.history) >= 2:
            removed = self.history[:2]
            self.history = self.history[2:]
            # Measure the pair that was actually removed (not the new first pair).
            n_removed = sum(len(self._encode_turn(role, text)) for role, text in removed)
            self.history_tokens = self.history_tokens[n_removed:]

    def _rebuild_history_tokens(self):
        """Recompute ``history_tokens`` from ``history`` (e.g. after loading a file)."""
        self.history_tokens = [tok for role, text in self.history for tok in self._encode_turn(role, text)]

    def _should_stop_generation(self, text):
        return any(marker in text for marker in self.end_markers)

    def _truncate_at_marker(self, text):
        """Cut ``text`` at the earliest end marker it contains."""
        positions = [text.find(marker) for marker in self.end_markers if marker in text]
        return text[:min(positions)] if positions else text

    def _clean_token_text(self, text):
        return text.replace("<|endoftext|>", "")

    def _is_repetitive(self, tokens, window=10):
        """True if the last ``window`` tokens are near-constant or end in a repeated 2-4 gram."""
        if len(tokens) < window:
            return False
        recent = tokens[-window:]
        if len(set(recent)) < 3:
            return True
        for pattern_len in [2, 3, 4]:
            if len(recent) >= pattern_len * 2:
                pattern = tuple(recent[-pattern_len:])
                prev_pattern = tuple(recent[-pattern_len * 2:-pattern_len])
                if pattern == prev_pattern:
                    return True
        return False

    def _set_hrm_steps(self, steps):
        self.model.config.hrm_max_steps = steps
        self.model.hrm_core.config.hrm_max_steps = steps

    def _restore_hrm_steps(self):
        self.model.config.hrm_max_steps = self.original_hrm_max_steps
        self.model.hrm_core.config.hrm_max_steps = self.original_hrm_max_steps

    def _sample_next_token(self, logits, context, repetition_processor, n_generated, eos_id):
        """Pick the next token id, shape ``(1, 1)``, from last-position ``logits``.

        Until ``min_tokens_to_generate`` tokens exist, EOS is masked out. A
        temperature <= 0 selects the argmax instead of sampling.
        """
        temperature = self.config.temperature
        if temperature > 0:
            logits = logits / temperature
        logits = repetition_processor(context, logits)
        if n_generated < self.min_tokens_to_generate:
            logits[:, eos_id] = float('-inf')
        if temperature <= 0:
            return torch.argmax(logits, dim=-1, keepdim=True)
        if self.config.top_k > 0:
            v, _ = torch.topk(logits, min(self.config.top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float('-inf')
        probs = torch.nn.functional.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def generate_response(self, user_input):
        """Stream a reply to ``user_input``.

        Yields ``(token_text, accumulated_text, False)`` while streaming, then
        ``(None, final_response, True)`` once. The final response is also added
        to the history. A forced HRM step count applies only for the duration
        of this generator and is restored afterwards.
        """
        forced_steps = self.hrm_forced_steps
        if forced_steps is not None:
            self._set_hrm_steps(forced_steps)
        try:
            full_prompt = self._format_prompt(user_input)
            prompt_tokens = self._tokenize(full_prompt)
            input_ids = torch.tensor(prompt_tokens, dtype=torch.long).unsqueeze(0).to(self.device)

            if self.debug_mode:
                print(f"\n[DEBUG] Prompt tokens: {len(prompt_tokens)}")
                print(f"[DEBUG] HRM mode: {'auto' if self.hrm_forced_steps is None else self.hrm_forced_steps} (model max: {self.model.config.hrm_max_steps})")

            eos_id = getattr(self.tokenizer, "eos_token_id", None)
            if eos_id is None:
                eos_id = 50256
            generated_tokens = []
            pending = []   # token ids whose bytes do not yet decode to complete characters
            accumulated_text = ""
            repetition_processor = RepetitionPenaltyLogitsProcessor(self.repetition_penalty)
            total_hrm_steps = 0

            with torch.no_grad():
                for _ in range(self.config.max_new_tokens):
                    context = input_ids[:, -self.block_size:]
                    logits, _, steps_taken, _ = self.model(context)
                    total_hrm_steps += steps_taken.item()

                    next_token = self._sample_next_token(
                        logits[:, -1, :], context, repetition_processor, len(generated_tokens), eos_id
                    )
                    token_id = next_token.item()
                    if token_id == eos_id:
                        break

                    generated_tokens.append(token_id)
                    # Byte-level BPE can split one character over several tokens;
                    # decode them together so multi-byte text is not mangled.
                    pending.append(token_id)
                    piece = self.tokenizer.decode(pending)
                    if piece.endswith("\ufffd") and len(pending) < 4:
                        token_text = ""
                    else:
                        pending = []
                        token_text = self._clean_token_text(piece)
                    accumulated_text += token_text

                    if self._should_stop_generation(accumulated_text):
                        accumulated_text = self._truncate_at_marker(accumulated_text)
                        break

                    if self._is_repetitive(generated_tokens):
                        if self.debug_mode:
                            print("\n[DEBUG] Detected repetition, stopping")
                        break

                    if token_text:
                        yield (token_text, accumulated_text, False)
                    input_ids = torch.cat([input_ids, next_token], dim=1)

            final_response = accumulated_text.strip()
            for marker in self.end_markers:
                tail = marker.strip()
                if tail and final_response.endswith(tail):
                    final_response = final_response[:-len(tail)].strip()

            self.total_hrm_steps_used += total_hrm_steps
            if self.debug_mode:
                avg_steps = total_hrm_steps / len(generated_tokens) if generated_tokens else 0
                print(f"\n[DEBUG] Generated {len(generated_tokens)} tokens | Total HRM steps: {total_hrm_steps} | Avg steps/token: {avg_steps:.1f}")

            self._update_history(user_input, final_response)
            yield (None, final_response, True)
        finally:
            if forced_steps is not None:
                self._restore_hrm_steps()

    def execute_command(self, command):
        """Handle a slash command. Returns False when the session should end."""
        command_lower = command.lower().strip()

        if command_lower in ['/exit', '/quit', '/q']:
            print(colored("Goodbye!", 'cyan'))
            return False
        elif command_lower == '/help':
            self._print_welcome_message()
        elif command_lower == '/clear':
            self.history = []
            self.history_tokens = []
            print(colored("Conversation history cleared.", 'yellow'))
        elif command_lower == '/stats':
            self._print_stats()
        elif command_lower == '/debug':
            self.debug_mode = not self.debug_mode
            print(colored(f"Debug mode {'enabled' if self.debug_mode else 'disabled'}.", 'yellow'))
        elif command_lower.startswith('/temp '):
            try:
                temp = float(command.split()[1])
            except (IndexError, ValueError):
                print(colored("Usage: /temp [value]", 'red'))
            else:
                if 0.1 <= temp <= 2.0:
                    self.config.temperature = temp
                    print(colored(f"Temperature set to {temp}", 'yellow'))
                else:
                    print(colored("Temperature must be between 0.1 and 2.0", 'red'))
        elif command_lower.startswith('/penalty '):
            try:
                penalty = float(command.split()[1])
            except (IndexError, ValueError):
                print(colored("Usage: /penalty [value]", 'red'))
            else:
                if 1.0 <= penalty <= 2.0:
                    self.repetition_penalty = penalty
                    print(colored(f"Repetition penalty set to {penalty}", 'yellow'))
                else:
                    print(colored("Penalty must be between 1.0 and 2.0", 'red'))
        elif command_lower.startswith('/hrm '):
            try:
                hrm_arg = command.split()[1].lower()
                steps = None if hrm_arg == 'auto' else int(hrm_arg)
            except (IndexError, ValueError):
                print(colored("Usage: /hrm [number] or /hrm auto", 'red'))
            else:
                if steps is None:
                    # Auto = no forced count; the checkpoint's own max applies.
                    self.hrm_forced_steps = None
                    print(colored(f"HRM mode set to AUTO (model will use up to {self.original_hrm_max_steps} steps)", 'yellow'))
                elif 0 <= steps <= 9999:
                    self.hrm_forced_steps = steps
                    print(colored(f"HRM forced to {steps} step(s)", 'yellow'))
                    if steps == 0:
                        print(colored("Warning: HRM with 0 steps means no iterative reasoning!", 'red'))
                else:
                    print(colored("HRM steps must be between 0 and 9999", 'red'))
        elif command_lower.startswith('/save '):
            try:
                filename = command.split(maxsplit=1)[1]
            except IndexError:
                print(colored("Usage: /save [filename]", 'red'))
            else:
                self._save_conversation(filename)
        elif command_lower.startswith('/load '):
            try:
                filename = command.split(maxsplit=1)[1]
            except IndexError:
                print(colored("Usage: /load [filename]", 'red'))
            else:
                self._load_conversation(filename)
        else:
            print(colored(f"Unknown command: {command}", 'red'))
            print(colored("Type /help for available commands", 'yellow'))
        return True

    def _print_stats(self):
        avg_hrm = self.total_hrm_steps_used / self.total_generated_tokens if self.total_generated_tokens > 0 else 0
        hrm_mode = "AUTO" if self.hrm_forced_steps is None else f"FORCED ({self.hrm_forced_steps})"
        print(colored(f"""
{'=' * 60}
CONVERSATION STATISTICS
{'=' * 60}
Prompt tokens:      {self.total_prompt_tokens:,}
Generated tokens:   {self.total_generated_tokens:,}
Total HRM steps:    {self.total_hrm_steps_used:,}
Avg HRM steps/tok:  {avg_hrm:.2f}
Turns:              {len(self.history) // 2}
History tokens:     {len(self.history_tokens):,}
Temperature:        {self.config.temperature}
Repetition penalty: {self.repetition_penalty}
HRM mode:           {hrm_mode}
Model max HRM steps:{self.original_hrm_max_steps}
Top-k:              {self.config.top_k}
{'=' * 60}
""", 'cyan'))

    def _save_conversation(self, filename):
        """Write the history as ``Human: ...`` / ``Assistant: ...`` paragraphs."""
        try:
            with open(filename, 'w', encoding='utf-8') as f:
                f.write("HRM-CosmicFish Conversation\n")
                f.write(f"{'=' * 80}\n\n")
                for role, text in self.history:
                    prefix = "Human: " if role == "human" else "Assistant: "
                    f.write(f"{prefix}{text}\n\n")
            print(colored(f"Conversation saved to {filename}", 'green'))
        except Exception as e:
            print(colored(f"Error saving conversation: {e}", 'red'))

    def _load_conversation(self, filename):
        """Replace the history with turns parsed from a file written by ``/save``."""
        try:
            with open(filename, 'r', encoding='utf-8') as f:
                lines = f.read().split('\n')
            history = []
            current_role = None
            current_text = []
            for line in lines:
                if line.startswith('Human: '):
                    role, first = 'human', line[7:]
                elif line.startswith('Assistant: '):
                    role, first = 'assistant', line[11:]
                else:
                    if current_role:
                        # Keep inner blank lines so multi-paragraph turns round-trip;
                        # trailing ones are removed by strip() below.
                        current_text.append(line)
                    continue
                if current_role and current_text:
                    history.append((current_role, '\n'.join(current_text).strip()))
                current_role, current_text = role, [first]
            if current_role and current_text:
                history.append((current_role, '\n'.join(current_text).strip()))
            self.history = history
            self._rebuild_history_tokens()
            print(colored(f"Conversation loaded from {filename} ({len(self.history)//2} turns)", 'green'))
        except Exception as e:
            print(colored(f"Error loading conversation: {e}", 'red'))


def main():
    """Download the checkpoint, build the model and run the interactive chat loop."""
    parser = argparse.ArgumentParser(description="Chat with CosmicFish-HRM model")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--temperature", type=float, default=0.5)
    parser.add_argument("--max_tokens", type=int, default=3000)
    parser.add_argument("--min_tokens", type=int, default=10)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--repetition_penalty", type=float, default=1.2)
    parser.add_argument("--human_prefix", type=str, default="Human: ")
    parser.add_argument("--assistant_prefix", type=str, default="Assistant: ")
    parser.add_argument("--end_of_turn", type=str, default="\n\n")
    parser.add_argument("--instruction", type=str, default=DEFAULT_PROMPT_TEMPLATE)
    parser.add_argument("--max_history", type=int, default=1024)
    parser.add_argument("--no_welcome", action="store_true")
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        device = "cpu"

    print(f"Downloading CosmicFish-HRM from Hugging Face ({HF_REPO})...")
    try:
        cache_dir = snapshot_download(repo_id=HF_REPO)
        logger.info(f"Model cached at: {cache_dir}")
        config_path = os.path.join(cache_dir, "config.json")
        weights_path = os.path.join(cache_dir, "model.safetensors")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"config.json not found in {cache_dir}")
        if not os.path.exists(weights_path):
            raise FileNotFoundError(f"model.safetensors not found in {cache_dir}")

        with open(config_path) as f:
            cfg = json.load(f)

        config = HRMCosmicFishConfig(
            vocab_size=cfg["vocab_size"],
            n_embd=cfg["n_embd"],
            block_size=cfg["block_size"],
            n_head=cfg["n_head"],
            n_kv_head=cfg["n_kv_head"],
            n_input_layers=cfg["n_input_layers"],
            n_output_layers=cfg["n_output_layers"],
            hrm_H_layers=cfg["hrm_H_layers"],
            hrm_L_layers=cfg["hrm_L_layers"],
            hrm_H_cycles=cfg["hrm_H_cycles"],
            hrm_L_cycles=cfg["hrm_L_cycles"],
            hrm_max_steps=cfg["hrm_max_steps"],
            hrm_exploration_prob=cfg["hrm_exploration_prob"],
            dropout=0.0,
            bias=cfg["bias"],
            use_rotary=cfg["use_rotary"],
            use_gqa=cfg["use_gqa"],
            use_swiglu=cfg["use_swiglu"],
            eps=cfg["eps"],
        )

        model = HRMCosmicFish(config)
        state_dict = load_file(weights_path, device=device)
        try:
            model.load_state_dict(state_dict)
        except RuntimeError as e:
            logger.warning(f"Strict loading failed: {e}, attempting flexible loading...")
            missing, unexpected = model.load_state_dict(state_dict, strict=False)
            if missing:
                logger.warning(f"Missing keys: {len(missing)}")
            if unexpected:
                logger.warning(f"Unexpected keys: {len(unexpected)}")

        model.to(device)
        model.eval()
        block_size = config.block_size
        print(f"Model loaded: {model.get_num_params() / 1e6:.2f}M parameters")
        print(f"  Input blocks: {config.n_input_layers} | HRM: H={config.hrm_H_layers} L={config.hrm_L_layers} (max {config.hrm_max_steps} steps) | Output blocks: {config.n_output_layers}")
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return

    try:
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    except Exception as e:
        print(f"Error loading tokenizer: {str(e)}")
        return

    class ChatConfig:
        """Runtime chat settings derived from the command-line arguments."""

        def __init__(self, args, block_size, device):
            self.device = device
            self.temperature = args.temperature
            self.max_new_tokens = args.max_tokens
            self.min_tokens_to_generate = args.min_tokens
            self.top_k = args.top_k
            self.human_prefix = args.human_prefix
            self.assistant_prefix = args.assistant_prefix
            self.end_of_turn = args.end_of_turn
            self.prompt_template = args.instruction
            self.max_history_tokens = args.max_history
            self.display_welcome = not args.no_welcome
            self.block_size = block_size
            self.debug_mode = args.debug
            self.repetition_penalty = args.repetition_penalty

    chat = ChatSession(model, tokenizer, ChatConfig(args, block_size, device))
    print(colored("\nHRM-CosmicFish initialized. Type your message (or /help for commands).\n", 'cyan'))

    while True:
        try:
            user_input = input(colored("You: ", 'green'))
            if user_input.startswith('/'):
                if not chat.execute_command(user_input):
                    break
                continue
            if not user_input.strip():
                continue

            live_buffer = ""
            response_generator = chat.generate_response(user_input)
            try:
                print(colored("CosmicFish: ", 'blue'), end="")
                sys.stdout.flush()
                for token, live_text, is_done in response_generator:
                    if is_done:
                        # Nothing was streamed (e.g. stopped on the first token): show the final text.
                        if not live_buffer:
                            print(live_text, end="")
                        break
                    if token:
                        print(token, end="", flush=True)
                        live_buffer += token
            except KeyboardInterrupt:
                # Close the generator now so its cleanup (restoring HRM steps) runs immediately.
                response_generator.close()
                print("\n[Generation interrupted]")
            print()
        except KeyboardInterrupt:
            print("\n\nKeyboard interrupt. Type /exit to quit or continue chatting.")
        except Exception as e:
            print(colored(f"\nError: {str(e)}", 'red'))
            logger.error(f"Error in chat loop: {str(e)}", exc_info=True)


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        logger.error(f"Fatal error: {str(e)}", exc_info=True)
        sys.exit(1)