amanuelbyte committed
Commit e92b429 · verified · 1 parent: 110bf70

Upload folder using huggingface_hub

config.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "vocab_size": 250100,
+   "block_size": 516,
+   "d_model": 512,
+   "n_heads": 8,
+   "d_ff": 2048,
+   "dropout": 0.1,
+   "halt_max_steps": 8,
+   "ponder_loss_weight": 0.01,
+   "halt_bias_init": -2.2
+ }
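
These are the hyperparameters consumed by the HRMText1 class defined in hrm_model.py below. A minimal loading sketch (assuming the file is read from the working directory; the variable name config is arbitrary):

import json

with open("config.json") as f:
    config = json.load(f)

# Plain dict, e.g. config["d_model"] == 512; HRMText1(config) consumes it directly.
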
hrm_model.py ADDED
@@ -0,0 +1,122 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.checkpoint import checkpoint
+
+ # Complete definition of the HRM-Text1 model architecture.
+ # These classes must be defined before the checkpoint in pytorch_model.bin can be loaded.
+
+ class RMSNorm(nn.Module):
+     """Root-mean-square layer normalization (no mean subtraction, no bias)."""
+     def __init__(self, d_model, eps=1e-8):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(d_model))
+
+     def forward(self, x):
+         return self.weight * (x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps))
+
+ class SwiGLUMuchPelu(nn.Module):
+     """SwiGLU feed-forward block: silu(w1(x)) * w2(x), projected back to d_model by w3."""
+     def __init__(self, d_model, d_ff, dropout=0.1):
+         super().__init__()
+         self.w1 = nn.Linear(d_model, d_ff, bias=False)
+         self.w2 = nn.Linear(d_model, d_ff, bias=False)
+         self.w3 = nn.Linear(d_ff, d_model, bias=False)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         return self.dropout(self.w3(F.silu(self.w1(x)) * self.w2(x)))
+
+ class HRMBlock(nn.Module):
+     """Pre-norm transformer block: self-attention followed by a SwiGLU MLP."""
+     def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
+         super().__init__()
+         self.norm1 = RMSNorm(d_model)
+         self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
+         self.norm2 = RMSNorm(d_model)
+         self.mlp = SwiGLUMuchPelu(d_model, d_ff, dropout)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, attn_mask=None, key_padding_mask=None):
+         x_norm = self.norm1(x)
+         attn_out, _ = self.attn(x_norm, x_norm, x_norm, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
+         x = x + self.dropout(attn_out)
+         x = x + self.dropout(self.mlp(self.norm2(x)))
+         return x
+
+ class HRMInner(nn.Module):
+     """One reasoning step: the low-level module refines z_L, then the high-level module refines z_H."""
+     def __init__(self, config):
+         super().__init__()
+         self.H_module = HRMBlock(config["d_model"], config["n_heads"], config["d_ff"], config["dropout"])
+         self.L_module = HRMBlock(config["d_model"], config["n_heads"], config["d_ff"], config["dropout"])
+
+     def forward(self, z_H, z_L, attn_mask=None, key_padding_mask=None):
+         z_L_input = z_L + z_H
+         z_L_new = self.L_module(z_L_input, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
+         z_H_input = z_H + z_L_new
+         z_H_new = self.H_module(z_H_input, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
+         return z_H_new, z_L_new
+
+ class HRMText1(nn.Module):
+     """Causal language model that iterates HRMInner with ACT-style adaptive halting."""
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.token_embeddings = nn.Embedding(config["vocab_size"], config["d_model"])
+         self.pos_embeddings = nn.Embedding(config["block_size"], config["d_model"])  # learned positional embeddings
+         self.register_buffer("pos_ids", torch.arange(config["block_size"]).unsqueeze(0))
+         self.inner_model = HRMInner(config)
+         self.lm_head = nn.Linear(config["d_model"], config["vocab_size"], bias=False)
+         self.halt_head = nn.Sequential(nn.Linear(config["d_model"], 1), nn.Sigmoid())
+         self.max_steps = config["halt_max_steps"]
+         self.ponder_loss_weight = config["ponder_loss_weight"]
+         self.gradient_checkpointing = False  # default to False for inference
+
+         # Bias the halting head toward "do not halt yet" at initialization.
+         with torch.no_grad():
+             self.halt_head[0].bias.fill_(config.get("halt_bias_init", -2.0))
+
+     def forward(self, input_ids, labels=None, attention_mask=None):
+         batch_size, seq_len = input_ids.shape
+         device = input_ids.device
+
+         z_L = self.token_embeddings(input_ids) + self.pos_embeddings(self.pos_ids[:, :seq_len])
+         z_H = torch.zeros_like(z_L)
+
+         key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
+         causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1)
+
+         halting_probs = torch.zeros((batch_size, seq_len, self.max_steps), device=device)
+         remainders = torch.ones((batch_size, seq_len), device=device)
+         total_z_H = torch.zeros_like(z_H)
+         n_updates = torch.zeros((batch_size, seq_len), device=device)
+
+         # ACT-style loop: each step emits a per-position halting probability; the output
+         # representation is the halting-weighted mixture of the successive z_H states.
+         eps = 1e-6
+         for step in range(self.max_steps):
+             p_halt = self.halt_head(z_H).squeeze(-1)
+             p_halt = p_halt.clamp(eps, 1 - eps)
+             is_last_step = (step == self.max_steps - 1)
+
+             # On the last step every position is forced to halt with its remaining mass.
+             halt_now_prob = torch.ones_like(p_halt) if is_last_step else p_halt
+             contrib = remainders * halt_now_prob
+
+             halting_probs[:, :, step] = contrib
+             total_z_H += contrib.unsqueeze(-1) * z_H
+
+             remainders = remainders * (1 - p_halt) if not is_last_step else torch.zeros_like(remainders)
+
+             if not is_last_step:
+                 n_updates += remainders  # accumulated "ponder" cost
+
+             if torch.all(remainders < eps):
+                 break
+
+             if self.training and self.gradient_checkpointing:
+                 z_H, z_L = checkpoint(self.inner_model, z_H, z_L, attn_mask=causal_mask, key_padding_mask=key_padding_mask, use_reentrant=False)
+             else:
+                 z_H, z_L = self.inner_model(z_H, z_L, attn_mask=causal_mask, key_padding_mask=key_padding_mask)
+
+         logits = self.lm_head(total_z_H)
+         loss, ponder_loss, lm_loss = None, None, None
+         if labels is not None:
+             # Standard next-token prediction: shift logits/labels by one position.
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss_fct = nn.CrossEntropyLoss()
+             lm_loss = loss_fct(shift_logits.view(-1, self.config["vocab_size"]), shift_labels.view(-1))
+             ponder_loss = torch.mean(n_updates)
+             loss = lm_loss + self.ponder_loss_weight * ponder_loss
+
+         return {"loss": loss, "logits": logits, "ponder_loss": ponder_loss, "lm_loss": lm_loss}
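
A minimal usage sketch for the class above, assuming config is the dict loaded from config.json; the batch below is a made-up example, not part of the upload:

import torch

model = HRMText1(config)
model.eval()

input_ids = torch.randint(0, config["vocab_size"], (1, 16))   # dummy batch of 16 token ids
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    out = model(input_ids, attention_mask=attention_mask)

print(out["logits"].shape)  # torch.Size([1, 16, 250100])
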
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9f4fb7b97d938b5bca0e01750a4f3535cb01caaa867779beb0f5b597bbb3cbfd
+ size 1059061122
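
The three lines above are a Git LFS pointer; the actual ~1.06 GB checkpoint has to be fetched (e.g. via git lfs pull) before it can be loaded. A hedged loading sketch, assuming the file holds a plain state_dict for HRMText1:

import torch

state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model = HRMText1(config)                      # classes and config from the files above
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(missing, unexpected)                    # sanity check for key mismatches
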
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "eos_token": {
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<pad>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
spiece.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ef78f86560d809067d12bac6c09f19a462cb3af3f54d2b8acbba26e1433125d6
+ size 4309802
tokenizer_config.json ADDED
@@ -0,0 +1,41 @@
+ {
+   "add_prefix_space": true,
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<pad>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "</s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "<unk>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "additional_special_tokens": [],
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "</s>",
+   "extra_ids": 0,
+   "extra_special_tokens": {},
+   "legacy": true,
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "<pad>",
+   "sp_model_kwargs": {},
+   "tokenizer_class": "T5Tokenizer",
+   "unk_token": "<unk>",
+   "use_fast": false
+ }
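
The tokenizer is declared as a slow T5Tokenizer backed by the spiece.model SentencePiece file above. A minimal loading sketch, assuming all uploaded files sit in one local folder (the path is illustrative):

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("./")   # folder with spiece.model and the two JSON files
print(tokenizer.pad_token, tokenizer.eos_token, tokenizer.unk_token)   # <pad> </s> <unk>
ids = tokenizer("hello world", return_tensors="pt").input_ids
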