akkiisfrommars committed on
Commit bf1f7b7 · verified · 1 Parent(s): d9ff3f5

Initial Commit

README.md CHANGED
@@ -1,3 +1,110 @@
  ---
  license: apache-2.0
+ tags:
+ - text-generation
+ - causal-lm
+ - cosmicfish
+ - hrm
+ - adaptive-reasoning
+ - custom-architecture
+ language:
+ - en
  ---
+
+ # CosmicFish-HRM
+
+ CosmicFish-HRM is a compact 82.77M-parameter causal language model built around a Hierarchical Reasoning Module (HRM) that dynamically allocates reasoning compute during inference. It was developed at Mistyoz AI.
+
+ Rather than applying a fixed number of transformer layers to every input, CosmicFish-HRM iterates through high-level and low-level reasoning cycles and uses a learned halting head to decide when to stop. Harder inputs trigger deeper reasoning trajectories, while simpler ones halt early.
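The halting behaviour described above can be pictured with a small loop. This is a minimal sketch, not the model's code: `run_adaptive_steps`, `toy_refine`, and `toy_halt_scores` are hypothetical names, and the state is a single float rather than a hidden-state tensor. The only details taken from this commit are the stopping rule (halt once the halt score exceeds the continue score) and the step budget of 16.

```python
from typing import Callable, Tuple


def run_adaptive_steps(
    state: float,
    refine: Callable[[float], float],
    halt_scores: Callable[[float], Tuple[float, float]],
    max_steps: int = 16,
) -> Tuple[float, int]:
    """Refine `state` until the halt score wins or the step budget is spent."""
    steps = 0
    for _ in range(max_steps):
        state = refine(state)          # one reasoning step
        steps += 1
        q_halt, q_continue = halt_scores(state)
        if q_halt > q_continue:        # the halting head votes to stop
            break
    return state, steps


# Hypothetical stand-ins: each step halves the distance to a target of 0,
# and the "halting head" votes to stop once the state is close enough.
def toy_refine(s: float) -> float:
    return s * 0.5


def toy_halt_scores(s: float) -> Tuple[float, float]:
    return (1.0, 0.0) if abs(s) < 0.1 else (0.0, 1.0)


easy_state, easy_steps = run_adaptive_steps(0.15, toy_refine, toy_halt_scores)
hard_state, hard_steps = run_adaptive_steps(10.0, toy_refine, toy_halt_scores)
print(easy_steps, hard_steps)
```

In this toy setting the input that starts near the target stops after fewer steps than the one that starts far away, which is the qualitative behaviour the paragraph describes.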
+
+ ## Architecture
+
+ ```
+ Input Blocks (Transformer) -> HRM Core (H + L levels, variable steps) -> Output Blocks (Transformer) -> LM Head
+ ```
+
+ | Hyperparameter | Value |
+ |---|---|
+ | Parameters | 82.77M |
+ | Embedding dimension | 448 |
+ | Vocabulary size | 50,304 |
+ | Context length | 512 |
+ | Input layers | 6 |
+ | Output layers | 6 |
+ | Attention heads | 8 (4 KV, GQA) |
+ | HRM H-layers | 4 |
+ | HRM L-layers | 4 |
+ | Max HRM steps | 16 |
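The parameter total in the table can be cross-checked from the other hyperparameters. The sketch below assumes the layer layout of the `chat.py` added in this commit: bias-free linear layers, a SwiGLU MLP with hidden size `4 * n_embd`, two RMSNorm weights per block, input embeddings tied to the LM head, and an HRM core that adds two initial-state vectors plus a two-way halting head with bias. `count_parameters` is a hypothetical helper, not part of the repository.

```python
def count_parameters(
    vocab_size: int = 50_304,
    n_embd: int = 448,
    n_head: int = 8,
    n_kv_head: int = 4,
    n_input_layers: int = 6,
    n_output_layers: int = 6,
    hrm_h_layers: int = 4,
    hrm_l_layers: int = 4,
) -> int:
    head_dim = n_embd // n_head
    kv_dim = n_kv_head * head_dim

    # Attention: q_proj and c_proj are n_embd x n_embd, k_proj and v_proj are n_embd x kv_dim.
    attention = 2 * n_embd * n_embd + 2 * n_embd * kv_dim
    # SwiGLU MLP: gate, up, and down projections with hidden size 4 * n_embd.
    mlp = 3 * n_embd * (4 * n_embd)
    # Two RMSNorm weight vectors per block.
    norms = 2 * n_embd
    block = attention + mlp + norms

    n_blocks = n_input_layers + n_output_layers + hrm_h_layers + hrm_l_layers
    embedding = vocab_size * n_embd               # shared with the LM head
    hrm_extras = 2 * n_embd + (2 * n_embd + 2)    # H_init, L_init, halting head weights and bias
    final_norm = n_embd

    return embedding + n_blocks * block + hrm_extras + final_norm


total = count_parameters()
print(f"{total:,} parameters ({total / 1e6:.2f}M)")
```

Under these assumptions the count comes to 82,767,554, which rounds to the 82.77M reported in the table.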
+
+ **Key components:**
+ - Grouped-Query Attention (GQA) with RoPE
+ - SwiGLU feedforward layers
+ - RMSNorm (pre-norm for I/O blocks, post-norm inside HRM)
+ - Learned halt/continue Q-head controlling per-input reasoning depth
+ - Step penalty in training loss encouraging efficient halting
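As a rough illustration of the step penalty in the last bullet, the sketch below adds a term proportional to the mean number of HRM steps to a plain cross-entropy loss. The weight of 0.01 matches the `forward` method in this commit's `chat.py`. The function names and the pure-Python formulation are assumptions made for readability, and the actual training code is not part of this commit.

```python
import math
from typing import List, Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Negative log-likelihood of `target` under softmax(logits)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[target]


def loss_with_step_penalty(
    batch_logits: List[Sequence[float]],
    targets: List[int],
    steps_taken: List[int],
    step_weight: float = 0.01,
) -> float:
    """Mean cross-entropy plus step_weight times the mean number of reasoning steps."""
    ce = sum(cross_entropy(l, t) for l, t in zip(batch_logits, targets)) / len(targets)
    penalty = step_weight * sum(steps_taken) / len(steps_taken)
    return ce + penalty


# Two examples with uniform logits; one used 4 steps and the other 12.
loss = loss_with_step_penalty([[0.0, 0.0], [0.0, 0.0]], [0, 1], [4, 12])
print(round(loss, 4))
```

With identical predictions, a batch that uses more reasoning steps gets a higher loss, which is how the penalty favours halting early when extra steps do not help.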
+
+ ## Usage
+
+ This model uses a custom architecture and requires `trust_remote_code=True` when loaded through `transformers`. The example below instead builds the model directly from `modeling_hrm_cosmicfish`, `config.json`, and `model.safetensors`.
+
+ ```python
+ import torch
+ import json
+ import tiktoken
+ from safetensors.torch import load_file
+ from modeling_hrm_cosmicfish import HRMCosmicFish, HRMCosmicFishConfig
+
+ with open("config.json") as f:
+     cfg = json.load(f)
+
+ config = HRMCosmicFishConfig(
+     vocab_size=cfg["vocab_size"],
+     n_embd=cfg["n_embd"],
+     block_size=cfg["block_size"],
+     n_head=cfg["n_head"],
+     n_kv_head=cfg["n_kv_head"],
+     n_input_layers=cfg["n_input_layers"],
+     n_output_layers=cfg["n_output_layers"],
+     hrm_H_layers=cfg["hrm_H_layers"],
+     hrm_L_layers=cfg["hrm_L_layers"],
+     hrm_H_cycles=cfg["hrm_H_cycles"],
+     hrm_L_cycles=cfg["hrm_L_cycles"],
+     hrm_max_steps=cfg["hrm_max_steps"],
+     dropout=0.0,
+ )
+
+ state_dict = load_file("model.safetensors")
+ model = HRMCosmicFish(config)
+ model.load_state_dict(state_dict)
+ model.eval()
+
+ tokenizer = tiktoken.get_encoding("gpt2")
+
+ prompt = "Artificial intelligence is"
+ tokens = tokenizer.encode(prompt)
+ idx = torch.tensor(tokens, dtype=torch.long).unsqueeze(0)
+
+ with torch.no_grad():
+     output = model.generate(idx, max_new_tokens=50, temperature=0.7, top_k=40)
+
+ print(tokenizer.decode(output[0].tolist()))
+ ```
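The `temperature` and `top_k` arguments passed to `generate` control how the next token is drawn. The following is a minimal pure-Python sketch of that sampling step, written to mirror the logic of `generate` in this commit's `chat.py` (scale the logits, keep the `top_k` largest, sample from the softmax of what remains). `sample_top_k` is a hypothetical helper working on a plain list of floats rather than a tensor.

```python
import math
import random
from typing import Sequence


def sample_top_k(logits: Sequence[float], temperature: float = 0.7, top_k: int = 40, rng=random) -> int:
    """Sample a token index after temperature scaling and top-k filtering."""
    scaled = [v / temperature for v in logits]
    k = min(top_k, len(scaled))
    # Smallest logit that still belongs to the top-k set (ties are kept).
    threshold = sorted(scaled, reverse=True)[k - 1]
    kept = [(i, v) for i, v in enumerate(scaled) if v >= threshold]
    # Softmax weights over the surviving logits; subtracting the max avoids overflow.
    m = max(v for _, v in kept)
    weights = [math.exp(v - m) for _, v in kept]
    return rng.choices([i for i, _ in kept], weights=weights, k=1)[0]


rng = random.Random(0)
toy_logits = [5.0, 1.0, 4.0, -2.0]
print(sample_top_k(toy_logits, temperature=0.7, top_k=2, rng=rng))
```

Lower temperatures sharpen the distribution towards the highest logit, and a smaller `top_k` removes low-probability tokens from consideration entirely.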
+
+ ## Training
+
+ CosmicFish-HRM was trained on the 10B-token CosmicSet dataset spanning web text, Wikipedia, code, mathematics, and research papers. Training used cosine LR decay with linear warmup, bfloat16 mixed precision, and gradient clipping.
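For readers unfamiliar with the schedule named above, here is a minimal sketch of cosine learning-rate decay with linear warmup. The README does not state the peak rate, floor, warmup length, or total step count, so every default below is a placeholder assumption, and `lr_at` is a hypothetical helper rather than the project's training code.

```python
import math


def lr_at(
    step: int,
    max_lr: float = 3e-4,        # assumed peak learning rate
    min_lr: float = 3e-5,        # assumed floor after decay
    warmup_steps: int = 1_000,   # assumed warmup length
    total_steps: int = 100_000,  # assumed training length
) -> float:
    """Learning rate at `step`: linear warmup, then cosine decay to `min_lr`."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    if step >= total_steps:
        return min_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1 down to 0
    return min_lr + cosine * (max_lr - min_lr)


for s in (0, 999, 50_500, 100_000):
    print(s, f"{lr_at(s):.2e}")
```

The rate rises linearly to the peak during warmup, then follows half a cosine period down to the floor.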
+
+ ## Citation
+
+ ```bibtex
+ @misc{cosmicfish-hrm,
+   title={CosmicFish-HRM: Adaptive Reasoning via Hierarchical Recurrent Mechanisms in Compact Language Models},
+   author={Venkat Akhil Lakkapragada},
+   year={2026},
+   howpublished={\url{https://huggingface.co/MistyozAI/CosmicFish-HRM}}
+ }
+ ```
+
+ ---
+
+ Mistyoz AI, Hyderabad
chat.py ADDED
@@ -0,0 +1,890 @@
+ import os
+ import sys
+ import time
+ import math
+ import argparse
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from termcolor import colored
+ import logging
+ import readline
+ import re
+ import textwrap
+ import random
+ from collections import defaultdict
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import json
+
+ try:
+     from safetensors.torch import load_file
+ except ImportError:
+     print("safetensors not installed. Run: pip install safetensors")
+     sys.exit(1)
+
+ try:
+     from huggingface_hub import snapshot_download
+ except ImportError:
+     print("huggingface_hub not installed. Run: pip install huggingface-hub")
+     sys.exit(1)
+
+ try:
+     from transformers import GPT2Tokenizer
+ except ImportError:
+     print("transformers not installed. Run: pip install transformers")
+     sys.exit(1)
+
+ HF_REPO = "MistyozAI/CosmicFish-HRM"
+
+
+ @dataclass
+ class HRMCosmicFishConfig:
+     vocab_size: int = 50304
+     n_embd: int = 448
+     block_size: int = 512
+     n_input_layers: int = 6
+     n_output_layers: int = 6
+     n_head: int = 8
+     hrm_H_layers: int = 4
+     hrm_L_layers: int = 4
+     hrm_H_cycles: int = 2
+     hrm_L_cycles: int = 2
+     hrm_max_steps: int = 16
+     hrm_exploration_prob: float = 0.1
+     dropout: float = 0.1
+     bias: bool = False
+     use_rotary: bool = True
+     use_gqa: bool = True
+     use_swiglu: bool = True
+     n_kv_head: int = 4
+     eps: float = 1e-5
+     forward_dtype: str = "bfloat16"
+
+
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+     t = torch.arange(end, device=freqs.device)
+     freqs = torch.outer(t, freqs).float()
+     return torch.polar(torch.ones_like(freqs), freqs)
+
+
+ def apply_rotary_emb(xq, xk, freqs_cis):
+     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+     freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(0)
+     freqs_cis = freqs_cis[:, :, :xq_.shape[2], :]
+     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+     return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x):
+         input_dtype = x.dtype
+         x = x.to(torch.float32)
+         variance = x.pow(2).mean(-1, keepdim=True)
+         x = x * torch.rsqrt(variance + self.eps)
+         return (self.weight * x).to(input_dtype)
+
+
+ class GroupedQueryAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.n_head = config.n_head
+         self.n_kv_head = config.n_kv_head if config.use_gqa else config.n_head
+         self.head_dim = config.n_embd // config.n_head
+         self.n_embd = config.n_embd
+         self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+         self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=config.bias)
+         self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=config.bias)
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+         self.attn_dropout = nn.Dropout(config.dropout)
+         self.resid_dropout = nn.Dropout(config.dropout)
+         self.flash = hasattr(F, 'scaled_dot_product_attention')
+
+     def forward(self, x, freqs_cis=None):
+         B, T, C = x.size()
+         q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+         k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
+         v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
+         if freqs_cis is not None:
+             q, k = apply_rotary_emb(q, k, freqs_cis)
+         if self.n_kv_head != self.n_head:
+             k = k.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
+             v = v.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
+         if self.flash:
+             y = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
+                 dropout_p=self.attn_dropout.p if self.training else 0.0, is_causal=True)
+         else:
+             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
+             att = att.masked_fill(torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool(), float('-inf'))
+             att = F.softmax(att, dim=-1)
+             att = self.attn_dropout(att)
+             y = att @ v
+         y = y.transpose(1, 2).contiguous().view(B, T, C)
+         return self.resid_dropout(self.c_proj(y))
+
+
+ class MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         hidden_dim = 4 * config.n_embd
+         if config.use_swiglu:
+             self.gate = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
+             self.up = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
+             self.down = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
+             self.act = nn.SiLU()
+         else:
+             self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
+             self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
+             self.act = nn.GELU()
+         self.dropout = nn.Dropout(config.dropout)
+         self.use_swiglu = config.use_swiglu
+
+     def forward(self, x):
+         if self.use_swiglu:
+             return self.dropout(self.down(self.act(self.up(x)) * self.gate(x)))
+         return self.dropout(self.c_proj(self.act(self.c_fc(x))))
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = RMSNorm(config.n_embd, eps=config.eps)
+         self.attn = GroupedQueryAttention(config)
+         self.ln_2 = RMSNorm(config.n_embd, eps=config.eps)
+         self.mlp = MLP(config)
+
+     def forward(self, x, freqs_cis=None):
+         x = x + self.attn(self.ln_1(x), freqs_cis)
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+
+ class HRMReasoningBlock(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = RMSNorm(config.n_embd, eps=config.eps)
+         self.attn = GroupedQueryAttention(config)
+         self.ln_2 = RMSNorm(config.n_embd, eps=config.eps)
+         self.mlp = MLP(config)
+
+     def forward(self, x, freqs_cis=None):
+         x = self.ln_1(x + self.attn(x, freqs_cis))
+         x = self.ln_2(x + self.mlp(x))
+         return x
+
+
+ class HRMReasoningLevel(nn.Module):
+     def __init__(self, config, n_layers):
+         super().__init__()
+         self.layers = nn.ModuleList([HRMReasoningBlock(config) for _ in range(n_layers)])
+
+     def forward(self, hidden_states, input_injection, freqs_cis=None):
+         hidden_states = hidden_states + input_injection
+         for layer in self.layers:
+             hidden_states = layer(hidden_states, freqs_cis)
+         return hidden_states
+
+
+ class HRMCore(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.H_level = HRMReasoningLevel(config, config.hrm_H_layers)
+         self.L_level = HRMReasoningLevel(config, config.hrm_L_layers)
+         self.H_init = nn.Parameter(torch.randn(config.n_embd) * 0.02)
+         self.L_init = nn.Parameter(torch.randn(config.n_embd) * 0.02)
+         self.q_head = nn.Linear(config.n_embd, 2, bias=True)
+         with torch.no_grad():
+             self.q_head.weight.zero_()
+             self.q_head.bias.fill_(-5.0)
+
+     def forward(self, x, freqs_cis=None, training=False):
+         B, T, C = x.size()
+         device = x.device
+         z_H = self.H_init.expand(B, T, C)
+         z_L = self.L_init.expand(B, T, C)
+         steps_taken = torch.zeros(B, dtype=torch.long, device=device)
+         halted = torch.zeros(B, dtype=torch.bool, device=device)
+         q_logits_list = []
+
+         for step in range(self.config.hrm_max_steps):
+             if halted.all():
+                 break
+             with torch.set_grad_enabled(step == self.config.hrm_max_steps - 1):
+                 for _h in range(self.config.hrm_H_cycles):
+                     for _l in range(self.config.hrm_L_cycles):
+                         z_L = self.L_level(z_L, z_H + x, freqs_cis)
+                     z_H = self.H_level(z_H, z_L, freqs_cis)
+             q_input = z_H.mean(dim=1)
+             q_logits = self.q_head(q_input.float())
+             q_logits_list.append(q_logits)
+
+             if self.config.hrm_max_steps > 1:
+                 q_halt = q_logits[:, 0]
+                 q_continue = q_logits[:, 1]
+                 if not training:
+                     q_halt = q_halt + 0.35
+                 should_halt = q_halt > q_continue
+                 halted = halted | should_halt
+
+             steps_taken = torch.where(halted, steps_taken, steps_taken + 1)
+             if step == self.config.hrm_max_steps - 1:
+                 halted = torch.ones_like(halted)
+
+         return z_H, steps_taken, (q_logits_list[-1] if q_logits_list else None)
+
+
+ class HRMCosmicFish(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.wte = nn.Embedding(config.vocab_size, config.n_embd)
+
+         if config.use_rotary:
+             self.freqs_cis = precompute_freqs_cis(config.n_embd // config.n_head, config.block_size)
+         else:
+             self.freqs_cis = None
+             self.wpe = nn.Embedding(config.block_size, config.n_embd)
+
+         self.drop = nn.Dropout(config.dropout)
+         self.input_blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_input_layers)])
+         self.hrm_core = HRMCore(config)
+         self.output_blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_output_layers)])
+         self.ln_f = RMSNorm(config.n_embd, eps=config.eps)
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+         self.wte.weight = self.lm_head.weight
+
+         self.apply(self._init_weights)
+         for pn, p in self.named_parameters():
+             if pn.endswith('c_proj.weight') or pn.endswith('down.weight'):
+                 total = config.n_input_layers + config.n_output_layers + config.hrm_H_layers + config.hrm_L_layers
+                 nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * total))
+
+         print(f"Model initialized with {self.get_num_params() / 1e6:.2f}M parameters")
+         print(f" Input blocks: {config.n_input_layers} layers")
+         print(f" HRM Core: H={config.hrm_H_layers} L={config.hrm_L_layers} (max {config.hrm_max_steps} steps)")
+         print(f" Output blocks: {config.n_output_layers} layers")
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def get_num_params(self, non_embedding=True):
+         n_params = sum(p.numel() for p in self.parameters())
+         if non_embedding and hasattr(self, 'wpe'):
+             n_params -= self.wpe.weight.numel()
+         return n_params
+
+     def forward(self, idx, targets=None):
+         device = idx.device
+         B, T = idx.size()
+         x = self.wte(idx)
+
+         if self.config.use_rotary:
+             freqs_cis = self.freqs_cis.to(device) if self.freqs_cis is not None else None
+         else:
+             pos = torch.arange(0, T, dtype=torch.long, device=device)
+             x = x + self.wpe(pos)
+             freqs_cis = None
+
+         x = self.drop(x)
+         for block in self.input_blocks:
+             x = block(x, freqs_cis)
+         x, steps_taken, q_logits = self.hrm_core(x, freqs_cis, training=self.training)
+         for block in self.output_blocks:
+             x = block(x, freqs_cis)
+         x = self.ln_f(x)
+         logits = self.lm_head(x)
+
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+             loss = loss + 0.01 * steps_taken.float().mean()
+
+         return logits, loss, steps_taken, q_logits
+
+     @torch.no_grad()
+     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+         for _ in range(max_new_tokens):
+             idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
+             logits, _, _, _ = self(idx_cond)
+             logits = logits[:, -1, :] / temperature
+             if top_k is not None:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = -float('Inf')
+             probs = F.softmax(logits, dim=-1)
+             idx_next = torch.multinomial(probs, num_samples=1)
+             idx = torch.cat((idx, idx_next), dim=1)
+         return idx
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s',
+     handlers=[logging.StreamHandler(sys.stdout)]
+ )
+ logger = logging.getLogger(__name__)
+
+ DEFAULT_PROMPT_TEMPLATE = "Below is a conversation between a helpful AI assistant and a human. The assistant is knowledgeable, friendly, and provides detailed and accurate responses.\n\n"
+
+
+ class RepetitionPenaltyLogitsProcessor:
+     def __init__(self, penalty=1.2):
+         self.penalty = penalty
+
+     def __call__(self, input_ids, scores):
+         score = torch.gather(scores, 1, input_ids)
+         score = torch.where(score > 0, score / self.penalty, score * self.penalty)
+         scores.scatter_(1, input_ids, score)
+         return scores
+
+
+ class ChatSession:
+     def __init__(self, model, tokenizer, config):
+         self.model = model
+         self.tokenizer = tokenizer
+         self.config = config
+         self.device = config.device
+         self.history = []
+         self.history_tokens = []
+         self.max_history_tokens = config.max_history_tokens
+         self.prompt_template = config.prompt_template
+         self.human_prefix = config.human_prefix
+         self.assistant_prefix = config.assistant_prefix
+         self.end_of_turn = config.end_of_turn
+         self.block_size = config.block_size
+         self.debug_mode = config.debug_mode
+         self.repetition_penalty = config.repetition_penalty
+         self.min_tokens_to_generate = config.min_tokens_to_generate
+
+         self.hrm_forced_steps = None
+         self.original_hrm_max_steps = self.model.config.hrm_max_steps
+
+         self.max_retries = 20
+
+         self.fallback_responses = [
+             "I'd be happy to help with that. Could you provide more details?",
+             "That's interesting. What specific aspects would you like to know about?",
+             "I can help with that. Could you clarify what you're looking for?",
+             "Let me help you with that. What particular information do you need?",
+             "I understand. Could you be more specific about what you'd like to know?"
+         ]
+
+         self.generation_failure_message = "I'm having difficulty generating a response. Could you try rephrasing?"
+
+         self.total_prompt_tokens = 0
+         self.total_generated_tokens = 0
+         self.total_hrm_steps_used = 0
+
+         self.end_markers = [
+             f"{self.human_prefix}",
+             "Human:",
+             "\nHuman:",
+             "\nH:",
+             "H:",
+             "<|endoftext|>",
+             "Below is a conversation",
+             "\nA:",
+             "A:",
+             "</s>",
+             "User:",
+             "\nUser:"
+         ]
+
+         if config.display_welcome:
+             self._print_welcome_message()
+
+     def _print_welcome_message(self):
+         hrm_mode = f"auto (max {self.original_hrm_max_steps})" if self.hrm_forced_steps is None else str(self.hrm_forced_steps)
+         print(colored(f"""
+ {'=' * 80}
+ Welcome to CosmicFish-HRM
+
+ Model: {self.model.get_num_params() / 1e6:.1f}M parameters
+ Max HRM Steps: {self.original_hrm_max_steps} | Current HRM Mode: {hrm_mode}
+
+ Commands: /help /clear /exit /stats /save /load
+ /temp [val] /penalty [val] /hrm [n|auto] /debug
+ {'=' * 80}
+ """, 'cyan'))
+
+     def _format_prompt(self, user_input):
+         formatted_prompt = self.prompt_template
+         for entry in self.history:
+             role, text = entry
+             if role == "human":
+                 formatted_prompt += f"{self.human_prefix}{text}{self.end_of_turn}"
+             else:
+                 formatted_prompt += f"{self.assistant_prefix}{text}{self.end_of_turn}"
+         formatted_prompt += f"{self.human_prefix}{user_input}{self.end_of_turn}{self.assistant_prefix}"
+         return formatted_prompt
+
+     def _tokenize(self, text):
+         return self.tokenizer.encode(text)
+
+     def _update_history(self, user_input, response):
+         self.history.append(("human", user_input))
+         self.history.append(("assistant", response))
+
+         user_tokens = self._tokenize(f"{self.human_prefix}{user_input}{self.end_of_turn}")
+         response_tokens = self._tokenize(f"{self.assistant_prefix}{response}{self.end_of_turn}")
+
+         self.history_tokens.extend(user_tokens)
+         self.history_tokens.extend(response_tokens)
+
+         self.total_prompt_tokens += len(user_tokens)
+         self.total_generated_tokens += len(response_tokens)
+
+         self._trim_history_if_needed()
+
+     def _trim_history_if_needed(self):
+         # Drop the oldest (human, assistant) pair until the history fits the token budget.
+         while len(self.history_tokens) > self.max_history_tokens and len(self.history) >= 2:
+             # Measure the pair before removing it, so the token list shrinks by the right amount.
+             user_turn = self.history[0][1]
+             assistant_turn = self.history[1][1]
+             user_tokens = len(self._tokenize(f"{self.human_prefix}{user_turn}{self.end_of_turn}"))
+             assistant_tokens = len(self._tokenize(f"{self.assistant_prefix}{assistant_turn}{self.end_of_turn}"))
+             self.history = self.history[2:]
+             self.history_tokens = self.history_tokens[user_tokens + assistant_tokens:]
+
+     def _should_stop_generation(self, text):
+         for marker in self.end_markers:
+             if marker in text:
+                 return True
+         return False
+
+     def _clean_token_text(self, text):
+         return text.replace("<|endoftext|>", "")
+
+     def _is_repetitive(self, tokens, window=10):
+         if len(tokens) < window:
+             return False
+         recent = tokens[-window:]
+         if len(set(recent)) < 3:
+             return True
+         for pattern_len in [2, 3, 4]:
+             if len(recent) >= pattern_len * 2:
+                 pattern = tuple(recent[-pattern_len:])
+                 prev_pattern = tuple(recent[-pattern_len*2:-pattern_len])
+                 if pattern == prev_pattern:
+                     return True
+         return False
+
+     def _set_hrm_steps(self, steps):
+         self.model.config.hrm_max_steps = steps
+         self.model.hrm_core.config.hrm_max_steps = steps
+
+     def _restore_hrm_steps(self):
+         self.model.config.hrm_max_steps = self.original_hrm_max_steps
+         self.model.hrm_core.config.hrm_max_steps = self.original_hrm_max_steps
+
+     def generate_response(self, user_input):
+         if self.hrm_forced_steps is not None:
+             self._set_hrm_steps(self.hrm_forced_steps)
+
+         try:
+             full_prompt = self._format_prompt(user_input)
+             prompt_tokens = self._tokenize(full_prompt)
+             input_ids = torch.tensor(prompt_tokens, dtype=torch.long).unsqueeze(0).to(self.device)
+
+             if self.debug_mode:
+                 print(f"\n[DEBUG] Prompt tokens: {len(prompt_tokens)}")
+                 print(f"[DEBUG] HRM mode: {'auto' if self.hrm_forced_steps is None else self.hrm_forced_steps} (model max: {self.model.config.hrm_max_steps})")
+
+             generated_tokens = []
+             accumulated_text = ""
+             repetition_processor = RepetitionPenaltyLogitsProcessor(self.repetition_penalty)
+             total_hrm_steps = 0
+
+             with torch.no_grad():
+                 for step in range(self.config.max_new_tokens):
+                     context = input_ids[:, -self.block_size:] if input_ids.size(1) > self.block_size else input_ids
+
+                     logits, _, steps_taken, _ = self.model(context)
+                     total_hrm_steps += steps_taken.item()
+
+                     logits = logits[:, -1, :] / self.config.temperature
+                     logits = repetition_processor(context, logits)
+
+                     if self.config.top_k > 0:
+                         v, _ = torch.topk(logits, min(self.config.top_k, logits.size(-1)))
+                         logits[logits < v[:, [-1]]] = float('-inf')
+
+                     probs = torch.nn.functional.softmax(logits, dim=-1)
+                     next_token = torch.multinomial(probs, num_samples=1)
+
+                     if next_token.item() == 50256:
+                         break
+
+                     token_text = self._clean_token_text(self.tokenizer.decode([next_token.item()]))
+                     generated_tokens.append(next_token.item())
+                     accumulated_text += token_text
+
+                     if self._should_stop_generation(accumulated_text):
+                         for marker in self.end_markers:
+                             if marker in accumulated_text:
+                                 accumulated_text = accumulated_text.split(marker)[0]
+                                 break
+                         break
+
+                     if self._is_repetitive(generated_tokens):
+                         if self.debug_mode:
+                             print("\n[DEBUG] Detected repetition, stopping")
+                         break
+
+                     yield (token_text, accumulated_text, False)
+
+                     input_ids = torch.cat([input_ids, next_token], dim=1)
+
+                     if step < self.min_tokens_to_generate:
+                         continue
+
+             final_response = accumulated_text.strip()
+             for marker in self.end_markers:
+                 if final_response.endswith(marker.strip()):
+                     final_response = final_response[:-len(marker.strip())].strip()
+
+             self.total_hrm_steps_used += total_hrm_steps
+
+             if self.debug_mode:
+                 avg_steps = total_hrm_steps / len(generated_tokens) if generated_tokens else 0
+                 print(f"\n[DEBUG] Generated {len(generated_tokens)} tokens | Total HRM steps: {total_hrm_steps} | Avg steps/token: {avg_steps:.1f}")
+
+             self._update_history(user_input, final_response)
+             yield (None, final_response, True)
+
+         finally:
+             if self.hrm_forced_steps is not None:
+                 self._restore_hrm_steps()
+
+     def execute_command(self, command):
+         command_lower = command.lower().strip()
+
+         if command_lower in ['/exit', '/quit', '/q']:
+             print(colored("Goodbye!", 'cyan'))
+             return False
+
+         elif command_lower == '/help':
+             self._print_welcome_message()
+
+         elif command_lower == '/clear':
+             self.history = []
+             self.history_tokens = []
+             print(colored("Conversation history cleared.", 'yellow'))
+
+         elif command_lower == '/stats':
+             self._print_stats()
+
+         elif command_lower == '/debug':
+             self.debug_mode = not self.debug_mode
+             print(colored(f"Debug mode {'enabled' if self.debug_mode else 'disabled'}.", 'yellow'))
+
+         elif command_lower.startswith('/temp '):
+             try:
+                 temp = float(command.split()[1])
+                 if 0.1 <= temp <= 2.0:
+                     self.config.temperature = temp
+                     print(colored(f"Temperature set to {temp}", 'yellow'))
+                 else:
+                     print(colored("Temperature must be between 0.1 and 2.0", 'red'))
+             except (IndexError, ValueError):
+                 print(colored("Usage: /temp [value]", 'red'))
+
+         elif command_lower.startswith('/penalty '):
+             try:
+                 penalty = float(command.split()[1])
+                 if 1.0 <= penalty <= 2.0:
+                     self.repetition_penalty = penalty
+                     print(colored(f"Repetition penalty set to {penalty}", 'yellow'))
+                 else:
+                     print(colored("Penalty must be between 1.0 and 2.0", 'red'))
+             except (IndexError, ValueError):
+                 print(colored("Usage: /penalty [value]", 'red'))
+
+         elif command_lower.startswith('/hrm '):
+             try:
+                 hrm_arg = command.split()[1].lower()
+                 if hrm_arg == 'auto':
+                     # None means "auto" everywhere else in this class (welcome message, stats, generate_response).
+                     self.hrm_forced_steps = None
+                     print(colored(f"HRM mode set to AUTO (model will use up to {self.original_hrm_max_steps} steps)", 'yellow'))
+                 else:
+                     steps = int(hrm_arg)
+                     if 0 <= steps <= 9999:
+                         self.hrm_forced_steps = steps
+                         print(colored(f"HRM forced to {steps} step(s)", 'yellow'))
+                         if steps == 0:
+                             print(colored("Warning: HRM with 0 steps means no iterative reasoning!", 'red'))
+                     else:
+                         print(colored("HRM steps must be between 0 and 9999", 'red'))
+             except (IndexError, ValueError):
+                 print(colored("Usage: /hrm [number] or /hrm auto", 'red'))
+
+         elif command_lower.startswith('/save '):
+             try:
+                 self._save_conversation(command.split(maxsplit=1)[1])
+             except IndexError:
+                 print(colored("Usage: /save [filename]", 'red'))
+
+         elif command_lower.startswith('/load '):
+             try:
+                 self._load_conversation(command.split(maxsplit=1)[1])
+             except IndexError:
+                 print(colored("Usage: /load [filename]", 'red'))
+
+         else:
+             print(colored(f"Unknown command: {command}", 'red'))
+             print(colored("Type /help for available commands", 'yellow'))
+
+         return True
+
+     def _print_stats(self):
+         avg_hrm = self.total_hrm_steps_used / self.total_generated_tokens if self.total_generated_tokens > 0 else 0
+         hrm_mode = "AUTO" if self.hrm_forced_steps is None else f"FORCED ({self.hrm_forced_steps})"
+         print(colored(f"""
+ {'=' * 60}
+ CONVERSATION STATISTICS
+ {'=' * 60}
+ Prompt tokens: {self.total_prompt_tokens:,}
+ Generated tokens: {self.total_generated_tokens:,}
+ Total HRM steps: {self.total_hrm_steps_used:,}
+ Avg HRM steps/tok: {avg_hrm:.2f}
+ Turns: {len(self.history) // 2}
+ History tokens: {len(self.history_tokens):,}
+
+ Temperature: {self.config.temperature}
+ Repetition penalty: {self.repetition_penalty}
+ HRM mode: {hrm_mode}
+ Model max HRM steps:{self.original_hrm_max_steps}
+ Top-k: {self.config.top_k}
+ {'=' * 60}
+ """, 'cyan'))
+
+     def _save_conversation(self, filename):
+         try:
+             with open(filename, 'w', encoding='utf-8') as f:
+                 f.write("HRM-CosmicFish Conversation\n")
+                 f.write(f"{'=' * 80}\n\n")
+                 for role, text in self.history:
+                     prefix = "Human: " if role == "human" else "Assistant: "
+                     f.write(f"{prefix}{text}\n\n")
+             print(colored(f"Conversation saved to {filename}", 'green'))
+         except Exception as e:
+             print(colored(f"Error saving conversation: {e}", 'red'))
+
+     def _load_conversation(self, filename):
+         try:
+             with open(filename, 'r', encoding='utf-8') as f:
+                 lines = f.read().split('\n')
+
+             self.history = []
+             self.history_tokens = []
+
+             current_role = None
+             current_text = []
+
+             for line in lines:
+                 if line.startswith('Human: '):
+                     if current_role and current_text:
+                         self.history.append((current_role, '\n'.join(current_text).strip()))
+                     current_role = 'human'
+                     current_text = [line[7:]]
+                 elif line.startswith('Assistant: '):
+                     if current_role and current_text:
+                         self.history.append((current_role, '\n'.join(current_text).strip()))
+                     current_role = 'assistant'
+                     current_text = [line[11:]]
+                 elif line.strip() and current_role:
+                     current_text.append(line)
+
+             if current_role and current_text:
+                 self.history.append((current_role, '\n'.join(current_text).strip()))
+
+             print(colored(f"Conversation loaded from {filename} ({len(self.history)//2} turns)", 'green'))
+         except Exception as e:
+             print(colored(f"Error loading conversation: {e}", 'red'))
+
+
720
+ def main():
721
+ parser = argparse.ArgumentParser(description="Chat with CosmicFish-HRM model")
722
+
723
+ parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
724
+ parser.add_argument("--temperature", type=float, default=0.5)
725
+ parser.add_argument("--max_tokens", type=int, default=3000)
726
+ parser.add_argument("--min_tokens", type=int, default=10)
727
+ parser.add_argument("--top_k", type=int, default=40)
728
+ parser.add_argument("--repetition_penalty", type=float, default=1.2)
729
+ parser.add_argument("--human_prefix", type=str, default="Human: ")
730
+ parser.add_argument("--assistant_prefix", type=str, default="Assistant: ")
731
+ parser.add_argument("--end_of_turn", type=str, default="\n\n")
732
+ parser.add_argument("--instruction", type=str, default=DEFAULT_PROMPT_TEMPLATE)
733
+ parser.add_argument("--max_history", type=int, default=1024)
734
+ parser.add_argument("--no_welcome", action="store_true")
735
+ parser.add_argument("--debug", action="store_true")
736
+
737
+ args = parser.parse_args()
738
+
739
+ device = args.device
740
+ if device == "cuda" and not torch.cuda.is_available():
741
+ print("CUDA not available, falling back to CPU")
742
+ device = "cpu"
743
+
744
+ print(f"Downloading CosmicFish-HRM from Hugging Face ({HF_REPO})...")
745
+ try:
746
+ cache_dir = snapshot_download(repo_id=HF_REPO)
747
+ logger.info(f"Model cached at: {cache_dir}")
748
+
749
+ config_path = os.path.join(cache_dir, "config.json")
750
+ weights_path = os.path.join(cache_dir, "model.safetensors")
751
+
752
+ if not os.path.exists(config_path):
753
+ raise FileNotFoundError(f"config.json not found in {cache_dir}")
754
+ if not os.path.exists(weights_path):
755
+ raise FileNotFoundError(f"model.safetensors not found in {cache_dir}")
756
+
757
+ with open(config_path) as f:
758
+ cfg = json.load(f)
759
+
760
+ config = HRMCosmicFishConfig(
761
+ vocab_size=cfg["vocab_size"],
762
+ n_embd=cfg["n_embd"],
763
+ block_size=cfg["block_size"],
764
+ n_head=cfg["n_head"],
765
+ n_kv_head=cfg["n_kv_head"],
766
+ n_input_layers=cfg["n_input_layers"],
767
+ n_output_layers=cfg["n_output_layers"],
768
+ hrm_H_layers=cfg["hrm_H_layers"],
769
+ hrm_L_layers=cfg["hrm_L_layers"],
770
+ hrm_H_cycles=cfg["hrm_H_cycles"],
771
+ hrm_L_cycles=cfg["hrm_L_cycles"],
772
+ hrm_max_steps=cfg["hrm_max_steps"],
773
+ hrm_exploration_prob=cfg["hrm_exploration_prob"],
774
+ dropout=0.0,
775
+ bias=cfg["bias"],
776
+ use_rotary=cfg["use_rotary"],
777
+ use_gqa=cfg["use_gqa"],
778
+ use_swiglu=cfg["use_swiglu"],
779
+ eps=cfg["eps"],
780
+ )
781
+
782
+ model = HRMCosmicFish(config)
783
+
784
+ state_dict = load_file(weights_path, device=device)
785
+
786
+ try:
787
+ model.load_state_dict(state_dict)
788
+ except RuntimeError as e:
789
+ logger.warning(f"Strict loading failed: {e}, attempting flexible loading...")
790
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
791
+ if missing:
792
+ logger.warning(f"Missing keys: {len(missing)}")
793
+ if unexpected:
794
+ logger.warning(f"Unexpected keys: {len(unexpected)}")
795
+
796
+ model.to(device)
797
+ model.eval()
798
+
799
+ block_size = config.block_size
800
+
801
+ print(f"Model loaded: {model.get_num_params() / 1e6:.2f}M parameters")
802
+ print(f" Input blocks: {config.n_input_layers} | HRM: H={config.hrm_H_layers} L={config.hrm_L_layers} (max {config.hrm_max_steps} steps) | Output blocks: {config.n_output_layers}")
803
+
804
+ except Exception as e:
805
+ print(f"Error loading model: {str(e)}")
806
+ return
807
+
808
+ try:
809
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
810
+ except Exception as e:
811
+ print(f"Error loading tokenizer: {str(e)}")
812
+ return
813
+
814
+ class ChatConfig:
815
+ def __init__(self, args, block_size, device):
816
+ self.device = device
817
+ self.temperature = args.temperature
818
+ self.max_new_tokens = args.max_tokens
819
+ self.min_tokens_to_generate = args.min_tokens
820
+ self.top_k = args.top_k
821
+ self.human_prefix = args.human_prefix
822
+ self.assistant_prefix = args.assistant_prefix
823
+ self.end_of_turn = args.end_of_turn
824
+ self.prompt_template = args.instruction
825
+ self.max_history_tokens = args.max_history
826
+ self.display_welcome = not args.no_welcome
827
+ self.block_size = block_size
828
+ self.debug_mode = args.debug
829
+ self.repetition_penalty = args.repetition_penalty
830
+
831
+ chat = ChatSession(model, tokenizer, ChatConfig(args, block_size, device))
832
+
833
+ print(colored("\nHRM-CosmicFish initialized. Type your message (or /help for commands).\n", 'cyan'))
834
+
835
+ while True:
836
+ try:
837
+ user_input = input(colored("You: ", 'green'))
838
+
839
+ if user_input.startswith('/'):
840
+ if not chat.execute_command(user_input):
841
+ break
842
+ continue
843
+
844
+ if not user_input.strip():
845
+ continue
846
+
847
+ live_buffer = ""
848
+ final_response = None
849
+
850
+ response_generator = chat.generate_response(user_input)
851
+
852
+ try:
853
+ print(colored("CosmicFish: ", 'blue'), end="")
854
+ sys.stdout.flush()
855
+
856
+ for token, live_text, is_done in response_generator:
857
+ if is_done:
858
+ final_response = live_text
859
+ if not live_buffer:
860
+ print(final_response, end="")
861
+ break
862
+
863
+ if token:
864
+ if "<|endoftext|>" in token:
865
+ token = token.replace("<|endoftext|>", "")
866
+ if token:
867
+ print(token, end="", flush=True)
868
+ break
869
+ print(token, end="", flush=True)
870
+ live_buffer += token
871
+
872
+ except KeyboardInterrupt:
873
+ print("\n[Generation interrupted]")
874
+
875
+ print()
876
+
877
+ except KeyboardInterrupt:
878
+ print("\n\nKeyboard interrupt. Type /exit to quit or continue chatting.")
879
+
880
+ except Exception as e:
881
+ print(colored(f"\nError: {str(e)}", 'red'))
882
+ logger.error(f"Error in chat loop: {str(e)}", exc_info=True)
883
+
884
+
885
+ if __name__ == "__main__":
886
+ try:
887
+ main()
888
+ except Exception as e:
889
+ logger.error(f"Fatal error: {str(e)}", exc_info=True)
890
+ sys.exit(1)
chat_local.py ADDED
@@ -0,0 +1,576 @@
1
+ import os
2
+ import sys
3
+ import time
4
+ import argparse
5
+ import torch
6
+ import numpy as np
7
+ from termcolor import colored
8
+ import logging
9
+ import readline
10
+ import re
11
+ import textwrap
12
+ import random
13
+ from collections import defaultdict
14
+ import tiktoken
15
+
16
+ import json
17
+ from safetensors.torch import load_file
18
+ from modeling_hrm_cosmicfish import HRMCosmicFish, HRMCosmicFishConfig
19
+
20
+ logging.basicConfig(
21
+ level=logging.INFO,
22
+ format='%(asctime)s - %(levelname)s - %(message)s',
23
+ handlers=[logging.StreamHandler(sys.stdout)]
24
+ )
25
+ logger = logging.getLogger(__name__)
26
+
27
+ DEFAULT_PROMPT_TEMPLATE = "Below is a conversation between a helpful AI assistant and a human. The assistant is knowledgeable, friendly, and provides detailed and accurate responses.\n\n"
28
+
29
+
30
+ class RepetitionPenaltyLogitsProcessor:
31
+ def __init__(self, penalty=1.2):
32
+ self.penalty = penalty
33
+
34
+ def __call__(self, input_ids, scores):
35
+ score = torch.gather(scores, 1, input_ids)
36
+ score = torch.where(score > 0, score / self.penalty, score * self.penalty)
37
+ scores.scatter_(1, input_ids, score)
38
+ return scores
39
+
40
+
41
+ class ChatSession:
42
+ def __init__(self, model, tokenizer, config):
43
+ self.model = model
44
+ self.tokenizer = tokenizer
45
+ self.config = config
46
+ self.device = config.device
47
+ self.history = []
48
+ self.history_tokens = []
49
+ self.max_history_tokens = config.max_history_tokens
50
+ self.prompt_template = config.prompt_template
51
+ self.human_prefix = config.human_prefix
52
+ self.assistant_prefix = config.assistant_prefix
53
+ self.end_of_turn = config.end_of_turn
54
+ self.block_size = config.block_size
55
+ self.debug_mode = config.debug_mode
56
+ self.repetition_penalty = config.repetition_penalty
57
+ self.min_tokens_to_generate = config.min_tokens_to_generate
58
+
59
+ self.hrm_forced_steps = None
60
+ self.original_hrm_max_steps = self.model.config.hrm_max_steps
61
+
62
+ self.max_retries = 20
63
+
64
+ self.fallback_responses = [
65
+ "I'd be happy to help with that. Could you provide more details?",
66
+ "That's interesting. What specific aspects would you like to know about?",
67
+ "I can help with that. Could you clarify what you're looking for?",
68
+ "Let me help you with that. What particular information do you need?",
69
+ "I understand. Could you be more specific about what you'd like to know?"
70
+ ]
71
+
72
+ self.generation_failure_message = "I'm having difficulty generating a response. Could you try rephrasing?"
73
+
74
+ self.total_prompt_tokens = 0
75
+ self.total_generated_tokens = 0
76
+ self.total_hrm_steps_used = 0
77
+
78
+ self.end_markers = [
79
+ f"{self.human_prefix}",
80
+ "Human:",
81
+ "\nHuman:",
82
+ "\nH:",
83
+ "H:",
84
+ "<|endoftext|>",
85
+ "Below is a conversation",
86
+ "\nA:",
87
+ "A:",
88
+ "</s>",
89
+ "User:",
90
+ "\nUser:"
91
+ ]
92
+
93
+ if config.display_welcome:
94
+ self._print_welcome_message()
95
+
96
+ def _print_welcome_message(self):
97
+ hrm_mode = f"auto (max {self.original_hrm_max_steps})" if self.hrm_forced_steps is None else str(self.hrm_forced_steps)
98
+ print(colored(f"""
99
+ {'=' * 80}
100
+ Welcome to CosmicFish-HRM
101
+
102
+ Model: {self.model.get_num_params() / 1e6:.1f}M parameters
103
+ Max HRM Steps: {self.original_hrm_max_steps} | Current HRM Mode: {hrm_mode}
104
+
105
+ Commands: /help /clear /exit /stats /save /load
106
+ /temp [val] /penalty [val] /hrm [n|auto] /debug
107
+ {'=' * 80}
108
+ """, 'cyan'))
109
+
110
+ def _format_prompt(self, user_input):
111
+ formatted_prompt = self.prompt_template
112
+ for entry in self.history:
113
+ role, text = entry
114
+ if role == "human":
115
+ formatted_prompt += f"{self.human_prefix}{text}{self.end_of_turn}"
116
+ else:
117
+ formatted_prompt += f"{self.assistant_prefix}{text}{self.end_of_turn}"
118
+ formatted_prompt += f"{self.human_prefix}{user_input}{self.end_of_turn}{self.assistant_prefix}"
119
+ return formatted_prompt
120
+
121
+ def _tokenize(self, text):
122
+ return self.tokenizer.encode(text)
123
+
124
+ def _update_history(self, user_input, response):
125
+ self.history.append(("human", user_input))
126
+ self.history.append(("assistant", response))
127
+
128
+ user_tokens = self._tokenize(f"{self.human_prefix}{user_input}{self.end_of_turn}")
129
+ response_tokens = self._tokenize(f"{self.assistant_prefix}{response}{self.end_of_turn}")
130
+
131
+ self.history_tokens.extend(user_tokens)
132
+ self.history_tokens.extend(response_tokens)
133
+
134
+ self.total_prompt_tokens += len(user_tokens)
135
+ self.total_generated_tokens += len(response_tokens)
136
+
137
+ self._trim_history_if_needed()
138
+
+     def _trim_history_if_needed(self):
+         # Drop the oldest human/assistant pair until the token cache fits the budget.
+         while len(self.history_tokens) > self.max_history_tokens and len(self.history) >= 2:
+             # Measure the pair being dropped before removing it; reading history[0]
+             # after slicing would count the wrong turns and fail on an empty list.
+             user_turn = self.history[0][1]
+             assistant_turn = self.history[1][1]
+             user_tokens = len(self._tokenize(f"{self.human_prefix}{user_turn}{self.end_of_turn}"))
+             assistant_tokens = len(self._tokenize(f"{self.assistant_prefix}{assistant_turn}{self.end_of_turn}"))
+             self.history = self.history[2:]
+             self.history_tokens = self.history_tokens[user_tokens + assistant_tokens:]
148
+
149
+ def _should_stop_generation(self, text):
150
+ for marker in self.end_markers:
151
+ if marker in text:
152
+ return True
153
+ return False
154
+
155
+ def _clean_token_text(self, text):
156
+ return text.replace("<|endoftext|>", "")
157
+
158
+ def _is_repetitive(self, tokens, window=10):
159
+ if len(tokens) < window:
160
+ return False
161
+ recent = tokens[-window:]
162
+ if len(set(recent)) < 3:
163
+ return True
164
+ for pattern_len in [2, 3, 4]:
165
+ if len(recent) >= pattern_len * 2:
166
+ pattern = tuple(recent[-pattern_len:])
167
+ prev_pattern = tuple(recent[-pattern_len*2:-pattern_len])
168
+ if pattern == prev_pattern:
169
+ return True
170
+ return False
171
+
172
+ def _set_hrm_steps(self, steps):
173
+ self.model.config.hrm_max_steps = steps
174
+ self.model.hrm_core.config.hrm_max_steps = steps
175
+
176
+ def _restore_hrm_steps(self):
177
+ self.model.config.hrm_max_steps = self.original_hrm_max_steps
178
+ self.model.hrm_core.config.hrm_max_steps = self.original_hrm_max_steps
179
+
180
+ def generate_response(self, user_input):
181
+ if self.hrm_forced_steps is not None:
182
+ self._set_hrm_steps(self.hrm_forced_steps)
183
+
184
+ try:
185
+ full_prompt = self._format_prompt(user_input)
186
+ prompt_tokens = self._tokenize(full_prompt)
187
+ input_ids = torch.tensor(prompt_tokens, dtype=torch.long).unsqueeze(0).to(self.device)
188
+
189
+ if self.debug_mode:
190
+ print(f"\n[DEBUG] Prompt tokens: {len(prompt_tokens)}")
191
+ print(f"[DEBUG] HRM mode: {'auto' if self.hrm_forced_steps is None else self.hrm_forced_steps} (model max: {self.model.config.hrm_max_steps})")
192
+
193
+ generated_tokens = []
194
+ accumulated_text = ""
195
+ repetition_processor = RepetitionPenaltyLogitsProcessor(self.repetition_penalty)
196
+ total_hrm_steps = 0
197
+
198
+ with torch.no_grad():
199
+ for step in range(self.config.max_new_tokens):
200
+ context = input_ids[:, -self.block_size:] if input_ids.size(1) > self.block_size else input_ids
201
+
202
+ logits, _, steps_taken, _ = self.model(context)
203
+ total_hrm_steps += steps_taken.item()
204
+
205
+ logits = logits[:, -1, :] / self.config.temperature
206
+ logits = repetition_processor(context, logits)
207
+
208
+ if self.config.top_k > 0:
209
+ v, _ = torch.topk(logits, min(self.config.top_k, logits.size(-1)))
210
+ logits[logits < v[:, [-1]]] = float('-inf')
211
+
212
+ probs = torch.nn.functional.softmax(logits, dim=-1)
213
+ next_token = torch.multinomial(probs, num_samples=1)
214
+
215
+ if next_token.item() == 50256:
216
+ break
217
+
218
+ token_text = self._clean_token_text(self.tokenizer.decode([next_token.item()]))
219
+ generated_tokens.append(next_token.item())
220
+ accumulated_text += token_text
221
+
222
+ if self._should_stop_generation(accumulated_text):
223
+ for marker in self.end_markers:
224
+ if marker in accumulated_text:
225
+ accumulated_text = accumulated_text.split(marker)[0]
226
+ break
227
+ break
228
+
229
+ if self._is_repetitive(generated_tokens):
230
+ if self.debug_mode:
231
+ print("\n[DEBUG] Detected repetition, stopping")
232
+ break
233
+
234
+ yield (token_text, accumulated_text, False)
235
+
236
+ input_ids = torch.cat([input_ids, next_token], dim=1)
237
+
238
+ if step < self.min_tokens_to_generate:
239
+ continue
240
+
241
+ final_response = accumulated_text.strip()
242
+ for marker in self.end_markers:
243
+ if final_response.endswith(marker.strip()):
244
+ final_response = final_response[:-len(marker.strip())].strip()
245
+
246
+ self.total_hrm_steps_used += total_hrm_steps
247
+
248
+ if self.debug_mode:
249
+ avg_steps = total_hrm_steps / len(generated_tokens) if generated_tokens else 0
250
+ print(f"\n[DEBUG] Generated {len(generated_tokens)} tokens | Total HRM steps: {total_hrm_steps} | Avg steps/token: {avg_steps:.1f}")
251
+
252
+ self._update_history(user_input, final_response)
253
+ yield (None, final_response, True)
254
+
255
+ finally:
256
+ if self.hrm_forced_steps is not None:
257
+ self._restore_hrm_steps()
258
+
259
+ def execute_command(self, command):
260
+ command_lower = command.lower().strip()
261
+
262
+ if command_lower in ['/exit', '/quit', '/q']:
263
+ print(colored("Goodbye!", 'cyan'))
264
+ return False
265
+
266
+ elif command_lower == '/help':
267
+ self._print_welcome_message()
268
+
269
+ elif command_lower == '/clear':
270
+ self.history = []
271
+ self.history_tokens = []
272
+ print(colored("Conversation history cleared.", 'yellow'))
273
+
274
+ elif command_lower == '/stats':
275
+ self._print_stats()
276
+
277
+ elif command_lower == '/debug':
278
+ self.debug_mode = not self.debug_mode
279
+ print(colored(f"Debug mode {'enabled' if self.debug_mode else 'disabled'}.", 'yellow'))
280
+
+         elif command_lower.startswith('/temp '):
+             try:
+                 temp = float(command.split()[1])
+                 if 0.1 <= temp <= 2.0:
+                     self.config.temperature = temp
+                     print(colored(f"Temperature set to {temp}", 'yellow'))
+                 else:
+                     print(colored("Temperature must be between 0.1 and 2.0", 'red'))
+             except (ValueError, IndexError):
+                 # Only catch a missing or non-numeric argument, not KeyboardInterrupt etc.
+                 print(colored("Usage: /temp [value]", 'red'))
+
+         elif command_lower.startswith('/penalty '):
+             try:
+                 penalty = float(command.split()[1])
+                 if 1.0 <= penalty <= 2.0:
+                     self.repetition_penalty = penalty
+                     print(colored(f"Repetition penalty set to {penalty}", 'yellow'))
+                 else:
+                     print(colored("Penalty must be between 1.0 and 2.0", 'red'))
+             except (ValueError, IndexError):
+                 print(colored("Usage: /penalty [value]", 'red'))
302
+
+         elif command_lower.startswith('/hrm '):
+             try:
+                 hrm_arg = command.split()[1].lower()
+                 if hrm_arg == 'auto':
+                     # None means no override: generate_response leaves hrm_max_steps untouched.
+                     self.hrm_forced_steps = None
+                     print(colored(f"HRM mode set to AUTO (model will use up to {self.original_hrm_max_steps} steps)", 'yellow'))
+                 else:
+                     steps = int(hrm_arg)
+                     if 0 <= steps <= 9999:
+                         self.hrm_forced_steps = steps
+                         print(colored(f"HRM forced to {steps} step(s)", 'yellow'))
+                         if steps == 0:
+                             print(colored("Warning: HRM with 0 steps means no iterative reasoning!", 'red'))
+                     else:
+                         print(colored("HRM steps must be between 0 and 9999", 'red'))
+             except (ValueError, IndexError):
+                 print(colored("Usage: /hrm [number] or /hrm auto", 'red'))
320
+
321
+ elif command_lower.startswith('/save '):
322
+ try:
323
+ self._save_conversation(command.split(maxsplit=1)[1])
324
+ except:
325
+ print(colored("Usage: /save [filename]", 'red'))
326
+
327
+ elif command_lower.startswith('/load '):
328
+ try:
329
+ self._load_conversation(command.split(maxsplit=1)[1])
330
+ except:
331
+ print(colored("Usage: /load [filename]", 'red'))
332
+
333
+ else:
334
+ print(colored(f"Unknown command: {command}", 'red'))
335
+ print(colored("Type /help for available commands", 'yellow'))
336
+
337
+ return True
338
+
339
+ def _print_stats(self):
340
+ avg_hrm = self.total_hrm_steps_used / self.total_generated_tokens if self.total_generated_tokens > 0 else 0
341
+ hrm_mode = "AUTO" if self.hrm_forced_steps is None else f"FORCED ({self.hrm_forced_steps})"
342
+ print(colored(f"""
343
+ {'=' * 60}
344
+ CONVERSATION STATISTICS
345
+ {'=' * 60}
346
+ Prompt tokens: {self.total_prompt_tokens:,}
347
+ Generated tokens: {self.total_generated_tokens:,}
348
+ Total HRM steps: {self.total_hrm_steps_used:,}
349
+ Avg HRM steps/tok: {avg_hrm:.2f}
350
+ Turns: {len(self.history) // 2}
351
+ History tokens: {len(self.history_tokens):,}
352
+
353
+ Temperature: {self.config.temperature}
354
+ Repetition penalty: {self.repetition_penalty}
355
+ HRM mode: {hrm_mode}
356
+ Model max HRM steps:{self.original_hrm_max_steps}
357
+ Top-k: {self.config.top_k}
358
+ {'=' * 60}
359
+ """, 'cyan'))
360
+
361
+ def _save_conversation(self, filename):
362
+ try:
363
+ with open(filename, 'w', encoding='utf-8') as f:
364
+ f.write("HRM-CosmicFish Conversation\n")
365
+ f.write(f"{'=' * 80}\n\n")
366
+ for role, text in self.history:
367
+ prefix = "Human: " if role == "human" else "Assistant: "
368
+ f.write(f"{prefix}{text}\n\n")
369
+ print(colored(f"Conversation saved to {filename}", 'green'))
370
+ except Exception as e:
371
+ print(colored(f"Error saving conversation: {e}", 'red'))
372
+
373
+ def _load_conversation(self, filename):
374
+ try:
375
+ with open(filename, 'r', encoding='utf-8') as f:
376
+ lines = f.read().split('\n')
377
+
378
+ self.history = []
379
+ self.history_tokens = []
380
+
381
+ current_role = None
382
+ current_text = []
383
+
384
+ for line in lines:
385
+ if line.startswith('Human: '):
386
+ if current_role and current_text:
387
+ self.history.append((current_role, '\n'.join(current_text).strip()))
388
+ current_role = 'human'
389
+ current_text = [line[7:]]
390
+ elif line.startswith('Assistant: '):
391
+ if current_role and current_text:
392
+ self.history.append((current_role, '\n'.join(current_text).strip()))
393
+ current_role = 'assistant'
394
+ current_text = [line[11:]]
395
+ elif line.strip() and current_role:
396
+ current_text.append(line)
397
+
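+             if current_role and current_text:
+                 self.history.append((current_role, '\n'.join(current_text).strip()))
+
+             # Rebuild the token cache so history trimming stays in sync with the loaded turns.
+             for role, text in self.history:
+                 prefix = self.human_prefix if role == 'human' else self.assistant_prefix
+                 self.history_tokens.extend(self._tokenize(f"{prefix}{text}{self.end_of_turn}"))
+
+             print(colored(f"Conversation loaded from {filename} ({len(self.history)//2} turns)", 'green'))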
402
+ except Exception as e:
403
+ print(colored(f"Error loading conversation: {e}", 'red'))
404
+
405
+
406
+ def main():
407
+ parser = argparse.ArgumentParser(description="Chat with CosmicFish-HRM model")
408
+
409
+ parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
410
+ parser.add_argument("--temperature", type=float, default=0.5)
411
+ parser.add_argument("--max_tokens", type=int, default=3000)
412
+ parser.add_argument("--min_tokens", type=int, default=10)
413
+ parser.add_argument("--top_k", type=int, default=40)
414
+ parser.add_argument("--repetition_penalty", type=float, default=1.2)
415
+ parser.add_argument("--human_prefix", type=str, default="Human: ")
416
+ parser.add_argument("--assistant_prefix", type=str, default="Assistant: ")
417
+ parser.add_argument("--end_of_turn", type=str, default="\n\n")
418
+ parser.add_argument("--instruction", type=str, default=DEFAULT_PROMPT_TEMPLATE)
419
+ parser.add_argument("--max_history", type=int, default=1024)
420
+ parser.add_argument("--no_welcome", action="store_true")
421
+ parser.add_argument("--debug", action="store_true")
422
+
423
+ args = parser.parse_args()
424
+
425
+ model_dir = os.path.dirname(os.path.abspath(__file__))
426
+
427
+ device = args.device
428
+ if device == "cuda" and not torch.cuda.is_available():
429
+ print("CUDA not available, falling back to CPU")
430
+ device = "cpu"
431
+
432
+ print(f"Loading HRM-CosmicFish model from {model_dir}...")
433
+ try:
434
+
435
+ config_path = os.path.join(model_dir, "config.json")
436
+ weights_path = os.path.join(model_dir, "model.safetensors")
437
+
438
+ if not os.path.exists(config_path):
439
+ raise FileNotFoundError(f"config.json not found in {model_dir}")
440
+ if not os.path.exists(weights_path):
441
+ raise FileNotFoundError(f"model.safetensors not found in {model_dir}")
442
+
443
+ with open(config_path) as f:
444
+ cfg = json.load(f)
445
+
446
+ config = HRMCosmicFishConfig(
447
+ vocab_size=cfg["vocab_size"],
448
+ n_embd=cfg["n_embd"],
449
+ block_size=cfg["block_size"],
450
+ n_head=cfg["n_head"],
451
+ n_kv_head=cfg["n_kv_head"],
452
+ n_input_layers=cfg["n_input_layers"],
453
+ n_output_layers=cfg["n_output_layers"],
454
+ hrm_H_layers=cfg["hrm_H_layers"],
455
+ hrm_L_layers=cfg["hrm_L_layers"],
456
+ hrm_H_cycles=cfg["hrm_H_cycles"],
457
+ hrm_L_cycles=cfg["hrm_L_cycles"],
458
+ hrm_max_steps=cfg["hrm_max_steps"],
459
+ hrm_exploration_prob=cfg["hrm_exploration_prob"],
460
+ dropout=cfg["dropout"],
461
+ bias=cfg["bias"],
462
+ use_rotary=cfg["use_rotary"],
463
+ use_gqa=cfg["use_gqa"],
464
+ use_swiglu=cfg["use_swiglu"],
465
+ eps=cfg["eps"],
466
+ )
467
+
468
+ model = HRMCosmicFish(config)
469
+
470
+ state_dict = load_file(weights_path, device=device)
471
+
472
+ try:
473
+ model.load_state_dict(state_dict)
474
+ except RuntimeError as e:
475
+ logger.warning(f"Strict loading failed: {e}, attempting flexible loading...")
476
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
477
+ if missing:
478
+ logger.warning(f"Missing keys: {len(missing)}")
479
+ if unexpected:
480
+ logger.warning(f"Unexpected keys: {len(unexpected)}")
481
+
482
+ model.to(device)
483
+ model.eval()
484
+
485
+ block_size = config.block_size
486
+
487
+ print(f"Model loaded: {model.get_num_params() / 1e6:.2f}M parameters")
488
+ print(f" Input blocks: {config.n_input_layers} | HRM: H={config.hrm_H_layers} L={config.hrm_L_layers} (max {config.hrm_max_steps} steps) | Output blocks: {config.n_output_layers}")
489
+
490
+ except Exception as e:
491
+ print(f"Error loading model: {str(e)}")
492
+ return
493
+
494
+ try:
495
+ tokenizer = tiktoken.get_encoding("gpt2")
496
+ except Exception as e:
497
+ print(f"Error loading tokenizer: {str(e)}")
498
+ return
499
+
500
+ class ChatConfig:
501
+ def __init__(self, args, block_size, device):
502
+ self.device = device
503
+ self.temperature = args.temperature
504
+ self.max_new_tokens = args.max_tokens
505
+ self.min_tokens_to_generate = args.min_tokens
506
+ self.top_k = args.top_k
507
+ self.human_prefix = args.human_prefix
508
+ self.assistant_prefix = args.assistant_prefix
509
+ self.end_of_turn = args.end_of_turn
510
+ self.prompt_template = args.instruction
511
+ self.max_history_tokens = args.max_history
512
+ self.display_welcome = not args.no_welcome
513
+ self.block_size = block_size
514
+ self.debug_mode = args.debug
515
+ self.repetition_penalty = args.repetition_penalty
516
+
517
+ chat = ChatSession(model, tokenizer, ChatConfig(args, block_size, device))
518
+
519
+ print(colored("\nHRM-CosmicFish initialized. Type your message (or /help for commands).\n", 'cyan'))
520
+
521
+ while True:
522
+ try:
523
+ user_input = input(colored("You: ", 'green'))
524
+
525
+ if user_input.startswith('/'):
526
+ if not chat.execute_command(user_input):
527
+ break
528
+ continue
529
+
530
+ if not user_input.strip():
531
+ continue
532
+
533
+ live_buffer = ""
534
+ final_response = None
535
+
536
+ response_generator = chat.generate_response(user_input)
537
+
538
+ try:
539
+ print(colored("CosmicFish: ", 'blue'), end="")
540
+ sys.stdout.flush()
541
+
542
+ for token, live_text, is_done in response_generator:
543
+ if is_done:
544
+ final_response = live_text
545
+ if not live_buffer:
546
+ print(final_response, end="")
547
+ break
548
+
549
+ if token:
550
+ if "<|endoftext|>" in token:
551
+ token = token.replace("<|endoftext|>", "")
552
+ if token:
553
+ print(token, end="", flush=True)
554
+ break
555
+ print(token, end="", flush=True)
556
+ live_buffer += token
557
+
558
+ except KeyboardInterrupt:
559
+ print("\n[Generation interrupted]")
560
+
561
+ print()
562
+
563
+ except KeyboardInterrupt:
564
+ print("\n\nKeyboard interrupt. Type /exit to quit or continue chatting.")
565
+
566
+ except Exception as e:
567
+ print(colored(f"\nError: {str(e)}", 'red'))
568
+ logger.error(f"Error in chat loop: {str(e)}", exc_info=True)
569
+
570
+
571
+ if __name__ == "__main__":
572
+ try:
573
+ main()
574
+ except Exception as e:
575
+ logger.error(f"Fatal error: {str(e)}", exc_info=True)
576
+ sys.exit(1)
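The history bookkeeping done by `_trim_history_if_needed` in `chat_local.py` can be sketched in isolation. The version below is an illustrative rewrite under stated assumptions, not the code from this commit: `trim_history` and `count_words` are hypothetical helpers, and a whitespace word count stands in for the tiktoken encoder.

```python
def trim_history(history, count_tokens, max_tokens):
    """Drop the oldest (human, assistant) pair until the history fits max_tokens.

    history is a list of (role, text) tuples in alternating human/assistant order.
    count_tokens is any callable that maps a text to its token count.
    """
    history = list(history)  # work on a copy so the caller's list is untouched
    total = sum(count_tokens(text) for _, text in history)
    while total > max_tokens and len(history) >= 2:
        # Measure the pair that is about to be removed, then remove it.
        dropped = history[:2]
        total -= sum(count_tokens(text) for _, text in dropped)
        history = history[2:]
    return history, total


def count_words(text):
    # Whitespace word count as a stand-in for a real tokenizer (assumption for this demo).
    return len(text.split())


turns = [
    ("human", "one two three"),
    ("assistant", "four five six seven"),
    ("human", "eight nine"),
    ("assistant", "ten"),
]
trimmed, total = trim_history(turns, count_words, max_tokens=5)
print(trimmed, total)
```

Counting the dropped pair before slicing keeps the running total consistent with the turns that remain, which is the property the chat session relies on when it decides whether to trim again.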
config.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "model_type": "cosmicfish_hrm",
3
+ "architectures": [
4
+ "HRMCosmicFish"
5
+ ],
6
+ "vocab_size": 50304,
7
+ "n_embd": 448,
8
+ "block_size": 512,
9
+ "n_head": 8,
10
+ "n_kv_head": 4,
11
+ "n_input_layers": 6,
12
+ "n_output_layers": 6,
13
+ "hrm_H_layers": 4,
14
+ "hrm_L_layers": 4,
15
+ "hrm_H_cycles": 2,
16
+ "hrm_L_cycles": 2,
17
+ "hrm_max_steps": 16,
18
+ "hrm_exploration_prob": 0.05,
19
+ "dropout": 0.1,
20
+ "bias": false,
21
+ "use_rotary": true,
22
+ "use_gqa": true,
23
+ "use_swiglu": true,
24
+ "eps": 1e-05,
25
+ "torch_dtype": "float32",
26
+ "transformers_version": "4.41.0",
27
+ "pad_token_id": 50256,
28
+ "bos_token_id": 50256,
29
+ "eos_token_id": 50256
30
+ }
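The loader scripts in this commit copy each `config.json` key into `HRMCosmicFishConfig` by hand, presumably because the file also carries keys such as `model_type` and `torch_dtype` that the dataclass does not declare. A minimal sketch of a less repetitive alternative follows. `DemoConfig` and `config_from_dict` are hypothetical names used only for illustration, and `DemoConfig` is a cut-down stand-in for the real config class, not a copy of it.

```python
import json
from dataclasses import dataclass, fields


@dataclass
class DemoConfig:
    # Hypothetical stand-in holding a few of the fields found in config.json.
    vocab_size: int = 50304
    n_embd: int = 448
    block_size: int = 512
    hrm_max_steps: int = 16
    dropout: float = 0.1


def config_from_dict(cls, raw, **overrides):
    """Build a dataclass from a dict, ignoring keys the dataclass does not declare."""
    known = {f.name for f in fields(cls)}
    kwargs = {k: v for k, v in raw.items() if k in known}
    kwargs.update(overrides)  # e.g. force dropout=0.0 for inference
    return cls(**kwargs)


raw = json.loads("""
{
  "model_type": "cosmicfish_hrm",
  "vocab_size": 50304,
  "n_embd": 448,
  "block_size": 512,
  "hrm_max_steps": 16,
  "dropout": 0.1,
  "torch_dtype": "float32"
}
""")

cfg = config_from_dict(DemoConfig, raw, dropout=0.0)
print(cfg)
```

Filtering on the declared field names means extra metadata keys in the JSON file are skipped instead of raising a `TypeError`, while the keyword overrides leave room for inference-time changes such as disabling dropout.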
example_usage.py ADDED
@@ -0,0 +1,57 @@
1
+ import torch
2
+ import json
3
+ import tiktoken
4
+ from safetensors.torch import load_file
5
+ from modeling_hrm_cosmicfish import HRMCosmicFish, HRMCosmicFishConfig
6
+
7
+
8
+ def load_model(model_dir, device="cpu"):
9
+ with open(f"{model_dir}/config.json") as f:
10
+ cfg = json.load(f)
11
+
12
+ config = HRMCosmicFishConfig(
13
+ vocab_size=cfg["vocab_size"],
14
+ n_embd=cfg["n_embd"],
15
+ block_size=cfg["block_size"],
16
+ n_head=cfg["n_head"],
17
+ n_kv_head=cfg["n_kv_head"],
18
+ n_input_layers=cfg["n_input_layers"],
19
+ n_output_layers=cfg["n_output_layers"],
20
+ hrm_H_layers=cfg["hrm_H_layers"],
21
+ hrm_L_layers=cfg["hrm_L_layers"],
22
+ hrm_H_cycles=cfg["hrm_H_cycles"],
23
+ hrm_L_cycles=cfg["hrm_L_cycles"],
24
+ hrm_max_steps=cfg["hrm_max_steps"],
25
+ dropout=0.0,
26
+ )
27
+
28
+ state_dict = load_file(f"{model_dir}/model.safetensors")
29
+ model = HRMCosmicFish(config)
30
+ model.load_state_dict(state_dict)
31
+ model.to(device)
32
+ model.eval()
33
+
34
+ tokenizer = tiktoken.get_encoding("gpt2")
35
+ return model, tokenizer
36
+
37
+
38
+ def generate(model, tokenizer, prompt, device="cpu", max_new_tokens=100, temperature=0.7, top_k=40):
39
+ tokens = tokenizer.encode(prompt)
40
+ idx = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)
41
+ with torch.no_grad():
42
+ output = model.generate(idx, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k)
43
+ return tokenizer.decode(output[0].tolist())
44
+
45
+
46
+ if __name__ == "__main__":
47
+ model, tokenizer = load_model(".")
48
+ prompts = [
49
+ "What is the capital of France?",
50
+ "What is artificial intelligence?",
51
+ "What does def fibonacci(n): do?",
52
+ ]
53
+ for prompt in prompts:
54
+ result = generate(model, tokenizer, prompt)
55
+ print(f"Prompt: {prompt}")
56
+ print(f"Output: {result}")
57
+ print()
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45c13e18c3d876ed408db1ec62b2f940db4eadde0774905e760ed2c5933825a2
3
+ size 210627308
modeling_hrm_cosmicfish.py ADDED
@@ -0,0 +1,377 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, Dict
+
+
+ @dataclass
+ class HRMCosmicFishConfig:
+     vocab_size: int = 50304
+     n_embd: int = 448
+     block_size: int = 512
+
+     n_input_layers: int = 6
+     n_output_layers: int = 6
+     n_head: int = 8
+
+     hrm_H_layers: int = 4
+     hrm_L_layers: int = 4
+     hrm_H_cycles: int = 2
+     hrm_L_cycles: int = 2
+     hrm_max_steps: int = 16
+     hrm_exploration_prob: float = 0.1
+
+     dropout: float = 0.1
+     bias: bool = False
+
+     use_rotary: bool = True
+     use_gqa: bool = True
+     use_swiglu: bool = True
+     n_kv_head: int = 4
+
+     eps: float = 1e-5
+
+     forward_dtype: str = "bfloat16"
+
+
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+     t = torch.arange(end, device=freqs.device)
+     freqs = torch.outer(t, freqs).float()
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+     return freqs_cis
+
+
+ def apply_rotary_emb(xq, xk, freqs_cis):
+     # xq, xk: [B, n_heads, T, head_dim], freqs_cis: [T, head_dim/2]
+     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+     freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(0)
+     freqs_cis = freqs_cis[:, :, :xq_.shape[2], :]
+     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+     return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x):
+         input_dtype = x.dtype
+         x = x.to(torch.float32)
+         variance = x.pow(2).mean(-1, keepdim=True)
+         x = x * torch.rsqrt(variance + self.eps)
+         return (self.weight * x).to(input_dtype)
+
+
+ class GroupedQueryAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+
+         self.n_head = config.n_head
+         self.n_kv_head = config.n_kv_head if config.use_gqa else config.n_head
+         self.head_dim = config.n_embd // config.n_head
+         self.n_embd = config.n_embd
+
+         self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+         self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=config.bias)
+         self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=config.bias)
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+
+         self.attn_dropout = nn.Dropout(config.dropout)
+         self.resid_dropout = nn.Dropout(config.dropout)
+
+         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+
+     def forward(self, x, freqs_cis=None):
+         B, T, C = x.size()
+
+         q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+         k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
+         v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
+
+         if freqs_cis is not None:
+             q, k = apply_rotary_emb(q, k, freqs_cis)
+
+         if self.n_kv_head != self.n_head:
+             k = k.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
+             v = v.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
+
+         if self.flash:
+             y = torch.nn.functional.scaled_dot_product_attention(
+                 q, k, v,
+                 attn_mask=None,
+                 dropout_p=self.attn_dropout.p if self.training else 0.0,
+                 is_causal=True
+             )
+         else:
+             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
+             att = att.masked_fill(torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool(), float('-inf'))
+             att = F.softmax(att, dim=-1)
+             att = self.attn_dropout(att)
+             y = att @ v
+
+         y = y.transpose(1, 2).contiguous().view(B, T, C)
+         y = self.resid_dropout(self.c_proj(y))
+         return y
+
+
+ class MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         hidden_dim = 4 * config.n_embd
+
+         if config.use_swiglu:
+             self.gate = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
+             self.up = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
+             self.down = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
+             self.act = nn.SiLU()
+         else:
+             self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
+             self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
+             self.act = nn.GELU()
+
+         self.dropout = nn.Dropout(config.dropout)
+         self.use_swiglu = config.use_swiglu
+
+     def forward(self, x):
+         if self.use_swiglu:
+             # SiLU is applied to `up` and `gate` is the linear branch. The names are
+             # swapped relative to the usual SwiGLU convention, but the checkpoint
+             # was saved with these names, so they must stay as they are.
+             return self.dropout(self.down(self.act(self.up(x)) * self.gate(x)))
+         else:
+             return self.dropout(self.c_proj(self.act(self.c_fc(x))))
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = RMSNorm(config.n_embd, eps=config.eps)
+         self.attn = GroupedQueryAttention(config)
+         self.ln_2 = RMSNorm(config.n_embd, eps=config.eps)
+         self.mlp = MLP(config)
+
+     def forward(self, x, freqs_cis=None):
+         x = x + self.attn(self.ln_1(x), freqs_cis)
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+
+ class HRMReasoningBlock(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = RMSNorm(config.n_embd, eps=config.eps)
+         self.attn = GroupedQueryAttention(config)
+         self.ln_2 = RMSNorm(config.n_embd, eps=config.eps)
+         self.mlp = MLP(config)
+
+     def forward(self, x, freqs_cis=None):
+         # Post-norm architecture for HRM
+         x = self.ln_1(x + self.attn(x, freqs_cis))
+         x = self.ln_2(x + self.mlp(x))
+         return x
+
+
+ class HRMReasoningLevel(nn.Module):
+     def __init__(self, config, n_layers):
+         super().__init__()
+         self.layers = nn.ModuleList([HRMReasoningBlock(config) for _ in range(n_layers)])
+
+     def forward(self, hidden_states, input_injection, freqs_cis=None):
+         hidden_states = hidden_states + input_injection
+         for layer in self.layers:
+             hidden_states = layer(hidden_states, freqs_cis)
+         return hidden_states
+
+
+ class HRMCore(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         self.H_level = HRMReasoningLevel(config, config.hrm_H_layers)
+         self.L_level = HRMReasoningLevel(config, config.hrm_L_layers)
+
+         self.H_init = nn.Parameter(torch.randn(config.n_embd) * 0.02)
+         self.L_init = nn.Parameter(torch.randn(config.n_embd) * 0.02)
+
+         self.q_head = nn.Linear(config.n_embd, 2, bias=True)  # [halt, continue]
+
+         with torch.no_grad():
+             self.q_head.weight.zero_()
+             # Both logits start equal at -5.0, so without the inference offset the
+             # strict `>` test in forward() does not halt at init. Note that
+             # HRMCosmicFish.__init__ later applies _init_weights to every nn.Linear,
+             # which overwrites this initialization; loaded checkpoints replace both.
+             self.q_head.bias.fill_(-5.0)
+
+     def forward(self, x, freqs_cis=None, training=False):
+         B, T, C = x.size()
+         device = x.device
+
+         z_H = self.H_init.expand(B, T, C)
+         z_L = self.L_init.expand(B, T, C)
+
+         steps_taken = torch.zeros(B, dtype=torch.long, device=device)
+         halted = torch.zeros(B, dtype=torch.bool, device=device)
+
+         q_logits_list = []
+         last_step = self.config.hrm_max_steps - 1
+
+         for step in range(self.config.hrm_max_steps):
+             if halted.all():
+                 break
+
+             # Gradients flow only through the final permitted step. Respect an
+             # enclosing torch.no_grad() (e.g. generate) instead of re-enabling grad.
+             with torch.set_grad_enabled(torch.is_grad_enabled() and step == last_step):
+                 for _h in range(self.config.hrm_H_cycles):
+                     for _l in range(self.config.hrm_L_cycles):
+                         z_L = self.L_level(z_L, z_H + x, freqs_cis)
+                     z_H = self.H_level(z_H, z_L, freqs_cis)
+
+             q_input = z_H.mean(dim=1)  # [B, n_embd]
+             q_logits = self.q_head(q_input.float())  # [B, 2]
+             q_logits_list.append(q_logits)
+
+             if self.config.hrm_max_steps > 1:
+                 q_halt = q_logits[:, 0]
+                 q_continue = q_logits[:, 1]
+
+                 if not training:
+                     # Inference-time offset that favors halting; larger values
+                     # stop sooner (tunable).
+                     q_halt = q_halt + 0.35
+
+                 should_halt = q_halt > q_continue
+
+                 if training and torch.rand(1).item() < self.config.hrm_exploration_prob:
+                     min_steps = torch.randint(2, self.config.hrm_max_steps + 1, (1,)).item()
+                     should_halt = should_halt & (steps_taken >= min_steps)
+
+                 halted = halted | should_halt
+
+             # Counts the steps after which a sequence chose to continue, so a
+             # sequence that halts after its first pass reports 0.
+             steps_taken = torch.where(halted, steps_taken, steps_taken + 1)
+
+             if step == self.config.hrm_max_steps - 1:
+                 halted = torch.ones_like(halted)
+
+         output_q_logits = q_logits_list[-1] if q_logits_list else None
+         return z_H, steps_taken, output_q_logits
+
+
+ class HRMCosmicFish(nn.Module):
+     """
+     Architecture: Input Blocks → HRM Reasoning Core → Output Blocks → LM Head
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         self.wte = nn.Embedding(config.vocab_size, config.n_embd)
+
+         if config.use_rotary:
+             self.freqs_cis = precompute_freqs_cis(
+                 config.n_embd // config.n_head,
+                 config.block_size
+             )
+         else:
+             self.freqs_cis = None
+             self.wpe = nn.Embedding(config.block_size, config.n_embd)
+
+         self.drop = nn.Dropout(config.dropout)
+
+         self.input_blocks = nn.ModuleList([
+             TransformerBlock(config) for _ in range(config.n_input_layers)
+         ])
+
+         self.hrm_core = HRMCore(config)
+
+         self.output_blocks = nn.ModuleList([
+             TransformerBlock(config) for _ in range(config.n_output_layers)
+         ])
+
+         self.ln_f = RMSNorm(config.n_embd, eps=config.eps)
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         # Weight tying
+         self.wte.weight = self.lm_head.weight
+
+         self.apply(self._init_weights)
+
+         # Scaled init for residual output projections
+         total_layers = config.n_input_layers + config.n_output_layers + config.hrm_H_layers + config.hrm_L_layers
+         for pn, p in self.named_parameters():
+             if pn.endswith('c_proj.weight') or pn.endswith('down.weight'):
+                 torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * total_layers))
+
+         print(f"Model initialized with {self.get_num_params() / 1e6:.2f}M parameters")
+         print(f" Input blocks: {config.n_input_layers} layers")
+         print(f" HRM Core: H={config.hrm_H_layers} L={config.hrm_L_layers} (max {config.hrm_max_steps} steps)")
+         print(f" Output blocks: {config.n_output_layers} layers")
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def get_num_params(self, non_embedding=True):
+         n_params = sum(p.numel() for p in self.parameters())
+         if non_embedding and hasattr(self, 'wpe'):
+             n_params -= self.wpe.weight.numel()
+         return n_params
+
+     def forward(self, idx, targets=None):
+         device = idx.device
+         B, T = idx.size()
+         assert T <= self.config.block_size, f"Sequence length {T} exceeds block size {self.config.block_size}"
+
+         x = self.wte(idx)
+
+         if self.config.use_rotary:
+             freqs_cis = self.freqs_cis.to(device) if self.freqs_cis is not None else None
+         else:
+             pos = torch.arange(0, T, dtype=torch.long, device=device)
+             x = x + self.wpe(pos)
+             freqs_cis = None
+
+         x = self.drop(x)
+
+         for block in self.input_blocks:
+             x = block(x, freqs_cis)
+
+         x, steps_taken, q_logits = self.hrm_core(x, freqs_cis, training=self.training)
+
+         for block in self.output_blocks:
+             x = block(x, freqs_cis)
+
+         x = self.ln_f(x)
+         logits = self.lm_head(x)
+
+         loss = None
+         if targets is not None:
+             task_loss = F.cross_entropy(
+                 logits.view(-1, logits.size(-1)),
+                 targets.view(-1),
+                 ignore_index=-1
+             )
+             # Penalize using more steps. steps_taken is an integer count, so this
+             # term shifts the reported loss but contributes no gradient as written.
+             step_penalty = 0.01 * steps_taken.float().mean()
+             loss = task_loss + step_penalty
+
+         return logits, loss, steps_taken, q_logits
+
+     @torch.no_grad()
+     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+         for _ in range(max_new_tokens):
+             idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
+
+             logits, _, _, _ = self(idx_cond)
+             logits = logits[:, -1, :] / temperature
+
+             if top_k is not None:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = -float('Inf')
+
+             probs = F.softmax(logits, dim=-1)
+             idx_next = torch.multinomial(probs, num_samples=1)
+             idx = torch.cat((idx, idx_next), dim=1)
+
+         return idx
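
To make the halting rule in `HRMCore.forward` easier to follow, here is a minimal pure-Python sketch of how the step counter behaves for a single sequence at inference time. It is an illustration under assumptions, not the model's code: the function name `steps_until_halt` is hypothetical, and the Q-head outputs are supplied as a plain list of `(q_halt, q_continue)` pairs instead of being computed from hidden states. The `0.35` offset and the 16-step cap are taken from the modeling file.

```python
def steps_until_halt(q_logits_per_step, halt_offset=0.35, max_steps=16):
    """Return the step counter for one sequence (hypothetical helper).

    q_logits_per_step: list of (q_halt, q_continue) pairs, one per reasoning step.
    """
    steps_taken = 0
    for q_halt, q_continue in q_logits_per_step[:max_steps]:
        # At inference the halt logit gets a fixed offset before the comparison.
        if q_halt + halt_offset > q_continue:
            return steps_taken
        # The counter only advances when the sequence chooses to continue.
        steps_taken += 1
    # Reached the cap (or ran out of supplied logits) without halting.
    return steps_taken


if __name__ == "__main__":
    # Continue twice, then halt on the third pass.
    print(steps_until_halt([(0.0, 1.0), (0.0, 1.0), (1.0, 0.0)]))
```

Because the counter advances only on a "continue" decision, a reported value of 0 still means one full reasoning pass was executed before halting.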
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "bos_token": "<|endoftext|>",
+   "eos_token": "<|endoftext|>",
+   "unk_token": "<|endoftext|>",
+   "pad_token": "<|endoftext|>"
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "tokenizer_class": "GPT2Tokenizer",
+   "vocab_size": 50257,
+   "model_max_length": 512,
+   "bos_token": "<|endoftext|>",
+   "eos_token": "<|endoftext|>",
+   "unk_token": "<|endoftext|>",
+   "pad_token": "<|endoftext|>",
+   "add_prefix_space": false
+ }
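
The tokenizer config lists 50,257 entries while the model config uses a vocabulary size of 50,304. The commit does not explain the gap. One common reason, assumed here and not confirmed by the source, is rounding the embedding table up to a multiple of 64 for hardware efficiency. The sketch below shows that rounding with a hypothetical helper `pad_vocab`.

```python
def pad_vocab(vocab_size, multiple=64):
    """Round vocab_size up to the next multiple (hypothetical helper)."""
    return ((vocab_size + multiple - 1) // multiple) * multiple


if __name__ == "__main__":
    padded = pad_vocab(50257)
    print(padded, padded - 50257)
```

If that assumption holds, the 47 extra ids (50257 through 50303) have no tokenizer entry, so a decoder would need to guard against them if they were ever sampled.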