Aluode committed
Commit 30fb4f1 · verified · 1 parent: 0ab65da

Upload 3 files

Files changed (3)
  1. app.py +86 -0
  2. moire_chat3.py +205 -0
  3. moire_conv_trainer_v5.py +462 -0
app.py ADDED
@@ -0,0 +1,86 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoTokenizer
+ from huggingface_hub import hf_hub_download
+ import sys
+
+ # Import the custom v5 biological architecture
+ from moire_conv_trainer_v5 import MoireGPT, MoireGPTConfig
+
+ print("Downloading Moiré weights from HF Hub...")
+ # Point at the HuggingFace repo and the Epoch 4 weights
+ weights_path = hf_hub_download(repo_id="Aluode/MoireFormer137MillionP", filename="moire_phase2_ep4.pt")
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Running on device: {device}")
+
+ print("Initializing Moiré wave-field (137.9M)...")
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+ # Use the 'xlarge' config from v5 (12 layers, 12 heads, 768 embd)
+ config = MoireGPTConfig(n_layer=12, n_head=12, n_embd=768)
+ model = MoireGPT(config)
+
+ # Load the weights into the field
+ state_dict = torch.load(weights_path, map_location=device, weights_only=True)
+ if 'model_state_dict' in state_dict:
+     state_dict = state_dict['model_state_dict']
+ elif 'model_state' in state_dict:  # full checkpoints from the v5 trainer use this key
+     state_dict = state_dict['model_state']
+ model.load_state_dict(state_dict)
+ model.to(device)
+ model.eval()
+
+ def generate_text(prompt, max_new_tokens=80, temperature=0.7, top_k=50):
+     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
+
+     with torch.no_grad():
+         for _ in range(max_new_tokens):
+             # Keep the context within the positional-embedding window
+             if input_ids.size(1) > config.max_seq_len:
+                 input_ids = input_ids[:, -config.max_seq_len:]
+
+             logits, _ = model(input_ids)
+             next_token_logits = logits[:, -1, :] / temperature
+
+             if top_k is not None:
+                 v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
+                 next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')
+
+             probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)
+
+             input_ids = torch.cat((input_ids, next_token), dim=1)
+
+             if next_token.item() == tokenizer.eos_token_id:
+                 break
+
+     return tokenizer.decode(input_ids[0], skip_special_tokens=False)
+
+ def chat_interface(message, history):
+     prompt = ""
+     for msg in history:
+         if isinstance(msg, dict):
+             if msg.get("role") == "user":
+                 prompt += f"User: {msg.get('content')}\n"
+             elif msg.get("role") == "assistant":
+                 prompt += f"Bot: {msg.get('content')}\n"
+         elif isinstance(msg, (list, tuple)) and len(msg) == 2:
+             prompt += f"User: {msg[0]}\nBot: {msg[1]}\n"
+
+     prompt += f"User: {message}\nBot:"
+
+     full_response = generate_text(prompt)
+
+     # Strip the prompt out so the UI only shows the Bot's new reply
+     response_only = full_response[len(prompt):].strip()
+     return response_only
+
+ # Build the Gradio web UI
+ demo = gr.ChatInterface(
+     fn=chat_interface,
+     title="MoireFormer (137.9M) - Phase-Interference AI",
+     description="This is a slightly larger MoireFormer that replaces standard QKV dot-product attention with biologically inspired **Moiré wave-interference math**, demonstrating that a language model can run on a continuous geometric phase-space.",
+     examples=["Hi there!", "Can you tell me a story about a bunny and a turtle?", "Write a Python script."],
+     theme="soft"
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
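
For context: `chat_interface` flattens the running Gradio history into a plain `User:`/`Bot:` transcript before calling `generate_text`. A minimal standalone sketch of that flattening, with made-up history values (Gradio may pass either role/content dicts or (user, bot) pairs, which is why both branches exist):

    # Hypothetical history values, mirroring the flattening logic in chat_interface
    history = [
        {"role": "user", "content": "Hi there!"},
        {"role": "assistant", "content": "Hello! How can I help?"},
    ]
    message = "Tell me a story."

    prompt = ""
    for msg in history:
        if isinstance(msg, dict):
            if msg.get("role") == "user":
                prompt += f"User: {msg.get('content')}\n"
            elif msg.get("role") == "assistant":
                prompt += f"Bot: {msg.get('content')}\n"
        elif isinstance(msg, (list, tuple)) and len(msg) == 2:
            prompt += f"User: {msg[0]}\nBot: {msg[1]}\n"
    prompt += f"User: {message}\nBot:"

    print(prompt)
    # User: Hi there!
    # Bot: Hello! How can I help?
    # User: Tell me a story.
    # Bot: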
moire_chat3.py ADDED
@@ -0,0 +1,205 @@
+ """
+ ╔══════════════════════════════════════════════════════════════════════════════╗
+ ║ MOIRÉ CHAT — Interactive inference for any trained Moiré model                ║
+ ║                                                                               ║
+ ║ Auto-detects model config from checkpoint, or specify manually.               ║
+ ║                                                                               ║
+ ║ Usage:                                                                        ║
+ ║   python moire_chat3.py                                     # uses defaults   ║
+ ║   python moire_chat3.py --weights moire_phase2_weights_ep4.pt --size xlarge   ║
+ ╚══════════════════════════════════════════════════════════════════════════════╝
+ """
+
+ import torch
+ import torch.nn.functional as F
+ import sys
+ import os
+ import argparse
+
+ # Import architecture — try both trainer versions
+ try:
+     from moire_conv_trainer_v5 import MoireGPT, MoireGPTConfig
+ except ImportError:
+     try:
+         from moire_conv_trainer_v3 import MoireGPT, MoireGPTConfig
+     except ImportError:
+         print("Error: Could not import MoireGPT.")
+         print("Make sure moire_conv_trainer_v5.py or v3 is in the same folder.")
+         sys.exit(1)
+
+
+ def load_model(args):
+     from transformers import AutoTokenizer
+     print("Loading tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained('gpt2')
+
+     # Try to load config from full checkpoint
+     config_dict = None
+     weights_path = args.weights
+
+     if args.checkpoint and os.path.exists(args.checkpoint):
+         print(f"Loading checkpoint {args.checkpoint}...")
+         ckpt = torch.load(args.checkpoint, map_location=args.device, weights_only=False)
+         if 'config' in ckpt:
+             config_dict = ckpt['config']
+             print(f"  Config from checkpoint: {config_dict}")
+         weights_path = args.checkpoint  # Will extract model_state below
+
+     # Build config
+     if config_dict:
+         config = MoireGPTConfig(
+             vocab_size=tokenizer.vocab_size,
+             n_layer=config_dict.get('n_layer', 4),
+             n_head=config_dict.get('n_head', 8),
+             n_embd=config_dict.get('n_embd', 256),
+             max_seq_len=config_dict.get('max_seq_len', 257),
+             gamma_slots=config_dict.get('gamma_slots', 8),
+             use_theta_gating=True,
+         )
+     else:
+         # Use size preset (added xlarge!)
+         PRESETS = {
+             'small': {'n_layer': 4, 'n_head': 8, 'n_embd': 256, 'max_seq_len': 129},
+             'medium': {'n_layer': 6, 'n_head': 8, 'n_embd': 512, 'max_seq_len': 257},
+             'large': {'n_layer': 8, 'n_head': 8, 'n_embd': 768, 'max_seq_len': 257},
+             'xlarge': {'n_layer': 12, 'n_head': 12, 'n_embd': 768, 'max_seq_len': 257},
+         }
+         p = PRESETS[args.size]
+         config = MoireGPTConfig(
+             vocab_size=tokenizer.vocab_size,
+             n_layer=p['n_layer'], n_head=p['n_head'], n_embd=p['n_embd'],
+             max_seq_len=p['max_seq_len'], gamma_slots=8, use_theta_gating=True,
+         )
+
+     print(f"Initializing Moiré model ({config.n_layer}L, {config.n_head}H, {config.n_embd}E)...")
+     model = MoireGPT(config)
+
+     # Load weights
+     print(f"Loading weights from {weights_path}...")
+     try:
+         state = torch.load(weights_path, map_location=args.device, weights_only=False)
+         if isinstance(state, dict) and 'model_state' in state:
+             model.load_state_dict(state['model_state'])
+         else:
+             model.load_state_dict(state)
+     except FileNotFoundError:
+         print(f"Error: {weights_path} not found!")
+         sys.exit(1)
+
+     model.to(args.device)
+
+     # # Only compress to bfloat16 if we are using the GPU!
+     # if args.device == 'cuda':
+     #     model.bfloat16()
+
+     model.eval()
+
+     return model, tokenizer, config
+
+
+ def generate(model, tokenizer, config, prompt, max_tokens=80, temperature=0.7,
+              top_k=40, top_p=0.9, device='cuda'):
+     """Generate with top-k AND top-p (nucleus) sampling for better quality."""
+     input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
+
+     print("Moiré: ", end="", flush=True)
+
+     for _ in range(max_tokens):
+         idx_cond = input_ids[:, -(config.max_seq_len - 1):]
+
+         with torch.no_grad():
+             logits, _ = model(idx_cond)
+
+         logits = logits[:, -1, :] / temperature
+
+         # Top-k filtering
+         if top_k is not None and top_k > 0:
+             v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+             logits[logits < v[:, [-1]]] = float('-inf')
+
+         # Top-p (nucleus) filtering
+         if top_p is not None and top_p < 1.0:
+             sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+             cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+             sorted_indices_to_remove = cumulative_probs > top_p
+             sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
+             sorted_indices_to_remove[:, 0] = 0
+             indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+             logits[indices_to_remove] = float('-inf')
+
+         probs = F.softmax(logits, dim=-1)
+         next_token = torch.multinomial(probs, num_samples=1)
+         input_ids = torch.cat((input_ids, next_token), dim=1)
+
+         word = tokenizer.decode(next_token[0].tolist())
+         print(word, end="", flush=True)
+
+         # Stop a couple of newlines past the prompt to prevent rambling
+         decoded_so_far = tokenizer.decode(input_ids[0].tolist())
+         if decoded_so_far.count('\n') > prompt.count('\n') + 2:
+             break
+
+     print()
+     return input_ids
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Moiré Chat Interface")
+     parser.add_argument('--weights', type=str, default='moire_conv_weights_final.pt',
+                         help='Path to model weights (.pt)')
+     parser.add_argument('--checkpoint', type=str, default=None,
+                         help='Path to full checkpoint (auto-detects config)')
+     parser.add_argument('--size', type=str, default='medium',
+                         choices=['small', 'medium', 'large', 'xlarge'],
+                         help='Model size if no checkpoint config available')
+     parser.add_argument('--device', type=str,
+                         default='cuda' if torch.cuda.is_available() else 'cpu')
+     parser.add_argument('--temperature', type=float, default=0.7)
+     parser.add_argument('--max_tokens', type=int, default=80)
+     parser.add_argument('--mode', type=str, default='chat',
+                         choices=['chat', 'complete'],
+                         help='chat: formats as User/Bot. complete: raw completion')
+     args = parser.parse_args()
+
+     print("=== Moiré Attention Chat ===")
+     print(f"Device: {args.device.upper()}")
+     print()
+
+     model, tokenizer, config = load_model(args)
+
+     n_params = sum(p.numel() for p in model.parameters()) / 1e6
+     print(f"\n{'='*50}")
+     print(f"Moiré field ready. {n_params:.1f}M parameters.")
+     if args.mode == 'chat':
+         print("Chat mode: your input becomes 'User: ...' and the model generates 'Bot: ...'")
+     else:
+         print("Completion mode: the model continues your text directly.")
+     print(f"Temperature: {args.temperature} | Max tokens: {args.max_tokens}")
+     print("Type 'quit' to exit.")
+     print(f"{'='*50}\n")
+
+     while True:
+         try:
+             user_input = input("You: " if args.mode == 'chat' else "Prompt: ")
+             if user_input.lower().strip() in ['quit', 'exit']:
+                 break
+             if not user_input.strip():
+                 continue
+
+             if args.mode == 'chat':
+                 prompt = f"User: {user_input}\nBot:"
+             else:
+                 prompt = user_input
+
+             generate(model, tokenizer, config, prompt,
+                      max_tokens=args.max_tokens,
+                      temperature=args.temperature,
+                      device=args.device)
+
+         except KeyboardInterrupt:
+             print("\nExiting...")
+             break
+
+
+ if __name__ == "__main__":
+     main()
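
The top-k and top-p filters in `generate()` compose: top-k first caps the candidate set, then nucleus sampling trims the tail of the cumulative distribution. A toy sketch with made-up logits over a vocabulary of 6, showing which tokens survive:

    import torch
    import torch.nn.functional as F

    # Made-up logits, shape (1, vocab); same filtering steps as generate() above
    logits = torch.tensor([[2.0, 1.5, 1.0, 0.0, -1.0, -2.0]])
    top_k, top_p = 3, 0.9

    # Top-k: everything below the 3rd-largest logit is masked out
    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
    logits[logits < v[:, [-1]]] = float('-inf')

    # Top-p: drop the tail of the sorted cumulative distribution,
    # always keeping the token that first crosses the threshold
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits[indices_to_remove] = float('-inf')

    print(F.softmax(logits, dim=-1))
    # ~ tensor([[0.5065, 0.3072, 0.1863, 0.0000, 0.0000, 0.0000]])
    # Only the top-3 tokens keep probability mass; sampling draws from them.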
moire_conv_trainer_v5.py ADDED
@@ -0,0 +1,462 @@
+ """
+ ╔══════════════════════════════════════════════════════════════════════════════╗
+ ║ MOIRÉ CONVERSATIONAL TRAINER v5 (Advanced Curriculums)                        ║
+ ║                                                                               ║
+ ║ Added new high-quality dataset loaders (Guanaco, TinyStories, FineWeb)        ║
+ ║ to expand the semantic phase-space and curb hallucinations.                   ║
+ ╚══════════════════════════════════════════════════════════════════════════════╝
+ """
+ import random
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+ import time
+ import os
+ import json
+ from typing import Optional
+ from dataclasses import dataclass
+
+ # ============================================================================
+ # 1. ARCHITECTURE
+ # ============================================================================
+
+ @dataclass
+ class MoireGPTConfig:
+     vocab_size: int = 50257
+     max_seq_len: int = 257
+     n_layer: int = 6
+     n_head: int = 8
+     n_embd: int = 512
+     gamma_slots: int = 8
+     dropout: float = 0.1
+     bias: bool = False
+     use_theta_gating: bool = True
+
+     @property
+     def head_dim(self):
+         return self.n_embd // self.n_head
+
+ class MoireAttention(nn.Module):
+     def __init__(self, config: MoireGPTConfig):
+         super().__init__()
+         self.config = config
+         self.n_head = config.n_head
+         self.head_dim = config.head_dim
+         self.n_embd = config.n_embd
+         self.gamma_slots = config.gamma_slots
+
+         # q/k project to 2x width: half amplitude, half phase
+         self.q_proj = nn.Linear(config.n_embd, 2 * config.n_embd, bias=config.bias)
+         self.k_proj = nn.Linear(config.n_embd, 2 * config.n_embd, bias=config.bias)
+         self.v_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+         self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+
+         self.attn_dropout = nn.Dropout(config.dropout)
+         self.resid_dropout = nn.Dropout(config.dropout)
+
+         if config.use_theta_gating:
+             self.theta_offset = nn.Parameter(torch.randn(config.n_head) * 0.1)
+
+         self.scale = 1.0 / math.sqrt(config.head_dim)
+
+     def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
+         B, T, C = x.shape
+
+         q_raw = self.q_proj(x)
+         k_raw = self.k_proj(x)
+         v = self.v_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+
+         q_amp, q_phase = q_raw.chunk(2, dim=-1)
+         k_amp, k_phase = k_raw.chunk(2, dim=-1)
+
+         q_amp = F.softplus(q_amp.view(B, T, self.n_head, self.head_dim).transpose(1, 2))
+         q_phase = q_phase.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+         k_amp = F.softplus(k_amp.view(B, T, self.n_head, self.head_dim).transpose(1, 2))
+         k_phase = k_phase.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+
+         # Optimized interference: real and imaginary parts of the complex wave product
+         q_real = q_amp * torch.cos(q_phase)
+         q_imag = q_amp * torch.sin(q_phase)
+         k_real = k_amp * torch.cos(k_phase)
+         k_imag = k_amp * torch.sin(k_phase)
+
+         real_scores = torch.matmul(q_real, k_real.transpose(-1, -2))
+         imag_scores = torch.matmul(q_imag, k_imag.transpose(-1, -2))
+         scores = (real_scores + imag_scores) * self.scale
+
+         if self.config.use_theta_gating and T > self.gamma_slots:
+             positions = torch.arange(T, device=x.device, dtype=torch.float32)
+             cycle_ids = positions / self.gamma_slots
+             cycle_dist = cycle_ids.unsqueeze(0) - cycle_ids.unsqueeze(1)
+             theta_off = self.theta_offset.view(self.n_head, 1, 1)
+             theta_gate = torch.cos(theta_off * cycle_dist.unsqueeze(0))
+             scores = scores * theta_gate.unsqueeze(0)
+
+         causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))
+         scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
+
+         if attention_mask is not None:
+             scores = scores + attention_mask
+
+         attn_weights = self.attn_dropout(F.softmax(scores, dim=-1))
+         out = self.resid_dropout(
+             self.out_proj(
+                 torch.matmul(attn_weights, v).transpose(1, 2).contiguous().view(B, T, C)
+             )
+         )
+         return out
+
+ class MoireBlock(nn.Module):
+     def __init__(self, config: MoireGPTConfig):
+         super().__init__()
+         self.ln1 = nn.LayerNorm(config.n_embd)
+         self.attn = MoireAttention(config)
+         self.ln2 = nn.LayerNorm(config.n_embd)
+         self.mlp = nn.Sequential(
+             nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
+             nn.GELU(),
+             nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias),
+             nn.Dropout(config.dropout),
+         )
+
+     def forward(self, x, attention_mask=None):
+         x = x + self.attn(self.ln1(x), attention_mask)
+         x = x + self.mlp(self.ln2(x))
+         return x
+
+ class MoireGPT(nn.Module):
+     def __init__(self, config: MoireGPTConfig):
+         super().__init__()
+         self.config = config
+         self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
+         self.pos_emb = nn.Embedding(config.max_seq_len, config.n_embd)
+         self.drop = nn.Dropout(config.dropout)
+         self.blocks = nn.ModuleList([MoireBlock(config) for _ in range(config.n_layer)])
+         self.ln_f = nn.LayerNorm(config.n_embd)
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+         self.tok_emb.weight = self.lm_head.weight  # tie input/output embeddings
+         self.apply(self._init_weights)
+         n_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+         print(f"[Moiré GPT] {n_params/1e6:.1f}M parameters")
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, input_ids, targets=None, attention_mask=None):
+         B, T = input_ids.shape
+         pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
+         x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))
+         for block in self.blocks:
+             x = block(x, attention_mask)
+         logits = self.lm_head(self.ln_f(x))
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-100)
+         return logits, loss
+
+     @torch.no_grad()
+     def generate(self, input_ids, max_new_tokens=50, temperature=0.8, top_k=40):
+         for _ in range(max_new_tokens):
+             idx_cond = input_ids[:, -self.config.max_seq_len:]
+             logits, _ = self(idx_cond)
+             logits = logits[:, -1, :] / temperature
+             if top_k is not None:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = float('-inf')
+             probs = F.softmax(logits, dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)
+             input_ids = torch.cat([input_ids, next_token], dim=1)
+         return input_ids
+
+
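
The `real_scores + imag_scores` sum above is the angle-difference identity in disguise: per dimension, a_q·a_k·(cos φ_q·cos φ_k + sin φ_q·sin φ_k) = a_q·a_k·cos(φ_q − φ_k), i.e. the real part of the complex inner product between q = a_q·e^{iφ_q} and the conjugate of k. Aligned phases interfere constructively, opposed phases cancel. A quick numerical check of the identity with random toy tensors (not model weights):

    import torch

    d = 8
    q_amp, k_amp = torch.rand(d), torch.rand(d)
    q_phase, k_phase = torch.randn(d), torch.randn(d)

    # What the forward pass computes (real part + imaginary part)
    real = (q_amp * torch.cos(q_phase) * k_amp * torch.cos(k_phase)).sum()
    imag = (q_amp * torch.sin(q_phase) * k_amp * torch.sin(k_phase)).sum()
    score_code = real + imag

    # Closed form: amplitude product weighted by the cosine of the phase difference
    score_math = (q_amp * k_amp * torch.cos(q_phase - k_phase)).sum()

    print(torch.allclose(score_code, score_math, atol=1e-6))  # True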
+ # ============================================================================
+ # 2. DATASET LOADERS (NEW CURRICULUMS ADDED)
+ # ============================================================================
+
+
+ def load_dataset_ultimate_mix(tokenizer, seq_len: int, max_chars_per=15_000_000):
+     """The Ultimate Curriculum: 1/3 Conversation, 1/3 Logic, 1/3 Facts"""
+     print("Loading Ultimate Mix (Guanaco + TinyStories + FineWeb)...")
+     from datasets import load_dataset
+
+     all_texts = []
+
+     # 1. Guanaco (Conversational / Persona)
+     print("  -> Fetching Guanaco...")
+     ds_g = load_dataset("timdettmers/openassistant-guanaco", split="train")
+     chars = 0
+     for row in ds_g:
+         text = row['text'].replace("### Human:", "User:").replace("### Assistant:", "Bot:")
+         all_texts.append(text)
+         chars += len(text)
+         if chars > max_chars_per: break
+
+     # 2. TinyStories (Grammar / Narrative Logic)
+     print("  -> Fetching TinyStories...")
+     ds_t = load_dataset("roneneldan/TinyStories", split="train")
+     chars = 0
+     for row in ds_t:
+         all_texts.append(row['text'])
+         chars += len(row['text'])
+         if chars > max_chars_per: break
+
+     # 3. FineWeb (Math / Science / Facts)
+     print("  -> Fetching FineWeb-Edu...")
+     ds_f = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT", split="train", streaming=True)
+     chars = 0
+     for row in ds_f:
+         all_texts.append(row['text'])
+         chars += len(row['text'])
+         if chars > max_chars_per: break
+
+     # CRITICAL: Shuffle the documents so the wave-field learns everything simultaneously!
+     print("  -> Shuffling the multiverse...")
+     random.shuffle(all_texts)
+
+     # Join with an end-of-text token so thoughts don't bleed into each other
+     full_text = "\n\n<|endoftext|>\n\n".join(all_texts)
+     print(f"Total Mixed Corpus: {len(full_text):,} chars")
+
+     return _tokenize_text(full_text, tokenizer, seq_len)
+
+ def _tokenize_text(text: str, tokenizer, seq_len: int):
+     old_max = tokenizer.model_max_length
+     tokenizer.model_max_length = int(1e30)
+     chunk_size = 1_000_000
+     tokens = []
+     print("Tokenizing data...")
+     for i in range(0, len(text), chunk_size):
+         chunk = text[i:i + chunk_size]
+         tokens.extend(tokenizer.encode(chunk, add_special_tokens=False))
+     tokenizer.model_max_length = old_max
+     stride = seq_len // 2
+     sequences = []
+     for i in range(0, len(tokens) - seq_len, stride):
+         sequences.append(tokens[i:i + seq_len])
+     print(f"Created {len(sequences):,} training sequences.")
+     return torch.tensor(sequences, dtype=torch.long)
+
+ def load_dataset_guanaco(tokenizer, seq_len: int):
+     """High quality conversational flow."""
+     print("Loading OpenAssistant-Guanaco...")
+     from datasets import load_dataset
+     ds = load_dataset("timdettmers/openassistant-guanaco", split="train")
+     text_chunks = []
+     for row in ds:
+         text = row['text']
+         # Convert tags so the model builds on what it learned in Dolly
+         text = text.replace("### Human:", "User:")
+         text = text.replace("### Assistant:", "Bot:")
+         text_chunks.append(text)
+     full_text = "\n\n".join(text_chunks)
+     print(f"Total: {len(full_text):,} chars")
+     return _tokenize_text(full_text, tokenizer, seq_len)
+
+ def load_dataset_tinystories(tokenizer, seq_len: int, max_chars: int = 15_000_000):
+     """Logic, object permanence, and grammar."""
+     print("Loading TinyStories...")
+     from datasets import load_dataset
+     ds = load_dataset("roneneldan/TinyStories", split="train")
+     texts = []
+     current_chars = 0
+     for row in ds:
+         texts.append(row['text'])
+         current_chars += len(row['text'])
+         if current_chars > max_chars:
+             break
+     full_text = "\n\n<|endoftext|>\n\n".join(texts)
+     print(f"Total: {len(full_text):,} chars")
+     return _tokenize_text(full_text, tokenizer, seq_len)
+
+ def load_dataset_fineweb(tokenizer, seq_len: int, max_chars: int = 15_000_000):
+     """Hard factual data to separate phase-clumps."""
+     print("Loading FineWeb-Edu (Sample)...")
+     from datasets import load_dataset
+     ds = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT", split="train", streaming=True)
+     texts = []
+     current_chars = 0
+     for row in ds:
+         texts.append(row['text'])
+         current_chars += len(row['text'])
+         if current_chars > max_chars:
+             break
+     full_text = "\n\n".join(texts)
+     print(f"Total: {len(full_text):,} chars")
+     return _tokenize_text(full_text, tokenizer, seq_len)
+
+ def load_dataset_mixed(tokenizer, seq_len: int):
+     # Keep the old mixed loader for legacy support
+     print("Loading mixed (Dolly + Wiki)...")
+     from datasets import load_dataset
+     all_text = []
+     ds = load_dataset("databricks/databricks-dolly-15k", split="train")
+     for row in ds:
+         user_text = row['instruction'].strip()
+         if row['context'].strip(): user_text += "\n" + row['context'].strip()
+         all_text.append(f"User: {user_text}\nBot: {row['response'].strip()}\n")
+     wiki = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
+     wiki_text = "\n".join([t for t in wiki['text'] if len(t.strip()) > 50])
+     all_text.append(wiki_text[:5_000_000])
+     return _tokenize_text("\n".join(all_text), tokenizer, seq_len)
+
+
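
`_tokenize_text` windows the token stream with stride = seq_len // 2, so consecutive training sequences overlap by half and every token is seen in two contexts. A toy illustration with made-up token ids:

    tokens = list(range(10))   # pretend token ids
    seq_len = 4
    stride = seq_len // 2      # = 2
    sequences = [tokens[i:i + seq_len] for i in range(0, len(tokens) - seq_len, stride)]
    print(sequences)
    # [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7]]
    # Each window shares its second half with the start of the next one.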
+ # ============================================================================
+ # 3. TRAINING LOOP
+ # ============================================================================
+
+ def train(model, train_data, config, args):
+     device = args.device
+     model = model.to(device)
+     model.train()
+
+     optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
+
+     # Initialize the AMP GradScaler for mixed-precision training
+     scaler = torch.amp.GradScaler('cuda')
+
+     n_batches = len(train_data) // args.batch_size
+     total_steps = args.epochs * n_batches
+     warmup_steps = min(200, total_steps // 10)
+
+     def lr_schedule(step):
+         if step < warmup_steps:
+             return step / warmup_steps
+         progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
+         return 0.5 * (1.0 + math.cos(math.pi * progress))
+
+     scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
+
+     start_epoch = 0
+     global_step = 0
+     if args.resume:
+         if os.path.exists(args.resume):
+             print(f"Resuming weights from {args.resume}...")
+             checkpoint = torch.load(args.resume, map_location=device, weights_only=False)
+
+             # Load the model weights. If the checkpoint also carries optimizer state,
+             # restore it so momentum carries over; epoch/step counters are NOT
+             # restored, so the new data curriculum starts fresh at epoch 1.
+             if isinstance(checkpoint, dict) and 'model_state' in checkpoint:
+                 model.load_state_dict(checkpoint['model_state'])
+                 if 'optimizer_state' in checkpoint:
+                     optimizer.load_state_dict(checkpoint['optimizer_state'])
+                     print("  -> Optimizer momentum restored.")
+             else:
+                 model.load_state_dict(checkpoint)
+
+             print("  Weights loaded. Starting Phase 2 curriculum at Epoch 1.")
+         else:
+             print(f"  Checkpoint {args.resume} not found, starting fresh.")
+
+     loss_history = []
+     t_start = time.time()
+
+     for epoch in range(start_epoch, args.epochs):
+         perm = torch.randperm(len(train_data))
+         train_data_shuffled = train_data[perm]
+
+         epoch_loss = 0.0
+         epoch_steps = 0
+
+         for i in range(0, len(train_data_shuffled) - args.batch_size, args.batch_size):
+             batch = train_data_shuffled[i:i + args.batch_size].to(device)
+
+             optimizer.zero_grad()
+
+             # Run the forward pass under bfloat16 autocast
+             with torch.amp.autocast('cuda', dtype=torch.bfloat16):
+                 logits, loss = model(batch[:, :-1], batch[:, 1:])
+
+             # Scale the loss and backpropagate
+             scaler.scale(loss).backward()
+
+             # Unscale before clipping gradients
+             scaler.unscale_(optimizer)
+             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+
+             # Step the optimizer, then update the scaler and LR schedule
+             scaler.step(optimizer)
+             scaler.update()
+             scheduler.step()
+
+             loss_history.append(loss.item())
+             epoch_loss += loss.item()
+             epoch_steps += 1
+             global_step += 1
+
+             if global_step % args.log_every == 0:
+                 elapsed = time.time() - t_start
+                 print(f"  Epoch {epoch+1}/{args.epochs} | Step {global_step:6d} | "
+                       f"Loss: {loss.item():.4f} | LR: {scheduler.get_last_lr()[0]:.2e} | {elapsed:.0f}s")
+
+         avg_epoch = epoch_loss / max(epoch_steps, 1)
+         print(f"=== Epoch {epoch+1} Complete | Avg Loss: {avg_epoch:.4f} ===")
+
+         # Save checkpoint
+         if (epoch + 1) % args.save_every == 0 or (epoch + 1) == args.epochs:
+             ckpt_path = f'moire_phase2_ep{epoch+1}.pt'
+             torch.save({
+                 'model_state': model.state_dict(),
+                 'optimizer_state': optimizer.state_dict(),
+                 'config': {
+                     'n_layer': config.n_layer, 'n_head': config.n_head,
+                     'n_embd': config.n_embd, 'max_seq_len': config.max_seq_len,
+                 }
+             }, ckpt_path)
+
+             weights_path = f'moire_phase2_weights_ep{epoch+1}.pt'
+             torch.save(model.state_dict(), weights_path)
+             print(f"  Saved: {weights_path}")
+
+     torch.save(model.state_dict(), 'moire_phase2_weights_final.pt')
+     print("Training complete! Final weights saved.")
+
+
+ def main():
+     import argparse
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--size', type=str, default='large', choices=['small', 'medium', 'large', 'xlarge'])
+     parser.add_argument('--epochs', type=int, default=10)
+     parser.add_argument('--batch_size', type=int, default=2)
+     parser.add_argument('--lr', type=float, default=1e-4)  # Lower LR for finetuning
+     parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+     parser.add_argument('--log_every', type=int, default=100)
+     parser.add_argument('--save_every', type=int, default=2)
+     parser.add_argument('--dataset', type=str, default='ultimate',
+                         choices=['mixed', 'guanaco', 'tinystories', 'fineweb', 'ultimate'])
+     parser.add_argument('--resume', type=str, default=None)
+     args = parser.parse_args()
+
+     # Model size presets
+     SIZE_PRESETS = {
+         'small': {'n_layer': 4, 'n_head': 8, 'n_embd': 256},
+         'medium': {'n_layer': 6, 'n_head': 8, 'n_embd': 512},
+         'large': {'n_layer': 8, 'n_head': 8, 'n_embd': 768},    # 104.9M params
+         'xlarge': {'n_layer': 12, 'n_head': 12, 'n_embd': 768},  # ~151M params (a tad bigger!)
+     }
+     p = SIZE_PRESETS[args.size]
+     config = MoireGPTConfig(n_layer=p['n_layer'], n_head=p['n_head'], n_embd=p['n_embd'])
+
+     from transformers import AutoTokenizer
+     tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+     LOADERS = {
+         'mixed': load_dataset_mixed,
+         'guanaco': load_dataset_guanaco,
+         'tinystories': load_dataset_tinystories,
+         'fineweb': load_dataset_fineweb,
+         'ultimate': load_dataset_ultimate_mix,
+     }
+     train_data = LOADERS[args.dataset](tokenizer, config.max_seq_len)
+
+     model = MoireGPT(config)
+     train(model, train_data, config, args)
+
+ if __name__ == "__main__":
+     main()
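
For reference, the `lr_schedule` closure in `train()` is a standard linear-warmup / cosine-decay multiplier for LambdaLR: it ramps from 0 to 1 over the warmup, then follows a half-cosine back down to 0 at the final step. A standalone sketch with toy step counts (the real trainer uses warmup_steps = min(200, total_steps // 10)):

    import math

    total_steps, warmup_steps = 1000, 100  # toy values for illustration

    def lr_schedule(step):
        if step < warmup_steps:
            return step / warmup_steps
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    for step in [0, 50, 100, 550, 1000]:
        print(step, round(lr_schedule(step), 3))
    # 0 0.0 | 50 0.5 | 100 1.0 | 550 0.5 | 1000 0.0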