# [web-page residue] ZhongRen11's picture | Upload 3 files | 8263663 verified
# -*- coding: utf-8 -*-
"""
Model: VGT-Conv (Vector-Gravity Transformer logic in Conv1D)
Theory: "Logic is squeezed out by geometric pressure" (Wang, 2024)
Task: 6-digit addition training -> 12-digit zero-shot generalization
"""
import torch
import torch.nn as nn
import torch.optim as optim
import random
import os
import json
# --- Environment setup ---
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Directory that receives the exported weights and config; created at import
# time if it does not exist yet.
SAVE_DIR = "VGT_Conv_Logic_Emergence"
os.makedirs(SAVE_DIR, exist_ok=True)
# --- Core hyperparameters from the paper ---
MAX_DIGITS = 6  # digits per operand during training
HIDDEN_SIZE = 128  # d_model in the paper [cite: 18]
BATCH_SIZE = 64  # batch size in the paper [cite: 20]
LR = 3e-4  # learning rate in the paper [cite: 20]
TRAIN_STEPS = 10000  # pressure-decay period: steps over which alpha is annealed during training
def generate_batch(batch_size, max_digits=MAX_DIGITS):
    """Sample a batch of random addition problems encoded as digit tensors.

    Both operands are drawn uniformly from ``[0, 10**max_digits - 1]`` and
    written least-significant digit first, zero-padded to ``max_digits``
    digits. Their sum is encoded the same way with one extra digit so the
    final carry always fits.

    Args:
        batch_size: Number of problems to sample.
        max_digits: Number of decimal digits per operand.

    Returns:
        A pair ``(x, y)`` of ``torch.long`` tensors moved to ``DEVICE``. Each
        row of ``x`` holds the digits of the first operand followed by the
        digits of the second (``2 * max_digits`` values); each row of ``y``
        holds the ``max_digits + 1`` digits of the sum.
    """

    def little_endian(value, width):
        # Zero-pad to `width` decimal digits, then reverse so that index 0
        # is the units digit.
        return [int(ch) for ch in reversed(str(value).zfill(width))]

    upper = 10 ** max_digits - 1
    inputs, targets = [], []
    for _ in range(batch_size):
        # The draw order (first operand, then second) is kept so that a
        # seeded `random` produces the same stream of problems.
        lhs = random.randint(0, upper)
        rhs = random.randint(0, upper)
        inputs.append(
            little_endian(lhs, max_digits) + little_endian(rhs, max_digits)
        )
        targets.append(little_endian(lhs + rhs, max_digits + 1))

    x = torch.tensor(inputs, dtype=torch.long).to(DEVICE)
    y = torch.tensor(targets, dtype=torch.long).to(DEVICE)
    return x, y
class VGTConv(nn.Module):
    """Convolutional digit-wise adder ("VGT-Conv").

    The two operands are embedded digit by digit, stacked along the channel
    axis, fused by a 1x1 convolution and then refined by repeatedly applying
    one shared kernel-size-3 residual convolution. A final 1x1 convolution
    maps every position to 10 digit logits. The network contains no
    normalisation layers.
    """

    def __init__(self, hidden_size):
        """Create the layers.

        Args:
            hidden_size: Embedding width and channel count of the hidden
                state.
        """
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=10, embedding_dim=hidden_size)
        # Norm layers are deliberately left out to preserve the
        # "geometric pressure" [cite: 16].
        self.reducer = nn.Conv1d(
            in_channels=2 * hidden_size,
            out_channels=hidden_size,
            kernel_size=1,
        )
        self.conv_process = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=3,
            padding=1,
        )
        self.output_proj = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=10,
            kernel_size=1,
        )

    def forward(self, x):
        """Predict the digits of the sum.

        Args:
            x: 2-D integer tensor of shape ``(batch, 2 * digits)`` with
                values in ``0..9``; the first half of each row is one
                operand and the second half is the other.

        Returns:
            A tuple ``(logits, h)`` where ``logits`` has shape
            ``(batch, digits + 1, 10)`` and ``h`` is the final hidden state
            of shape ``(batch, hidden_size, digits + 1)``.
        """
        _, seq_len = x.shape
        half = seq_len // 2

        # (batch, seq, hidden) -> (batch, hidden, seq) for Conv1d.
        embedded = self.embedding(x).transpose(1, 2)
        first_operand = embedded[:, :, :half]
        second_operand = embedded[:, :, half:]

        # Align the operands position by position on the channel axis.
        stacked = torch.cat([first_operand, second_operand], dim=1)
        state = torch.relu(self.reducer(stacked))

        # Append one zero position that will hold the predicted carry-out.
        state = nn.functional.pad(state, (0, 1))

        # Apply the shared residual convolution once per position, meant to
        # emulate long-range carry propagation.
        for _ in range(state.size(2)):
            state = torch.relu(self.conv_process(state)) + state

        logits = self.output_proj(state).transpose(1, 2)
        return logits, state
def train_and_export():
    """Train a ``VGTConv`` on random additions and export the artifacts.

    Runs ``TRAIN_STEPS + 1`` AdamW updates on batches from
    ``generate_batch(BATCH_SIZE)``. The objective is cross-entropy on the sum
    digits plus ``alpha * 1e-4`` times the mean per-position L2 norm of the
    hidden state, with ``alpha`` annealed linearly from 50 to 1 over the run.
    Every 500 steps the cross-entropy and the standard deviation of the
    output projection weights are printed and recorded.

    Files written to ``SAVE_DIR``:
        * ``pytorch_model.bin`` - the model ``state_dict``.
        * ``config.json`` - architecture metadata.
        * ``training_history.json`` - the metrics recorded every 500 steps.

    Returns:
        The trained model (left in training mode, on ``DEVICE``).
    """
    model = VGTConv(HIDDEN_SIZE).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=LR)
    history = []
    print(f"Starting VGT training on {DEVICE}...")

    # Nothing inside the loop switches modes, so set training mode once.
    model.train()
    # Guard the schedule against a zero-length run (division by zero).
    decay_span = max(TRAIN_STEPS, 1)

    for step in range(TRAIN_STEPS + 1):
        x, y = generate_batch(BATCH_SIZE)

        # Geometric scaffold: alpha decays linearly from 50 to 1.
        alpha = max(1.0, 50.0 - (50.0 - 1.0) * (step / decay_span))

        optimizer.zero_grad()
        logits, h_states = model(x)

        # 1. Task loss (cross-entropy over every output digit).
        loss_ce = nn.functional.cross_entropy(
            logits.reshape(-1, 10), y.reshape(-1)
        )

        # 2. Geometric constraint loss (L2-norm penalty), intended to force
        #    the model to build its logic on a compressed manifold
        #    [cite: 118]. The norm is taken over the channel axis and then
        #    averaged. `torch.linalg.vector_norm` replaces the deprecated
        #    `torch.norm` and yields the same value.
        loss_l2 = torch.linalg.vector_norm(h_states, ord=2, dim=1).mean()

        total_loss = loss_ce + alpha * 1e-4 * loss_l2
        total_loss.backward()
        optimizer.step()

        if step % 500 == 0:
            # Monitor weight polarisation, which the paper treats as the
            # macroscopic sign of logic emerging [cite: 122].
            ce_value = loss_ce.item()
            std = model.output_proj.weight.std().item()
            history.append({"step": step, "ce": ce_value, "weight_std": std})
            print(f"Step {step:5d} | CE: {ce_value:.4f} | Weight Std: {std:.4f} | Alpha: {alpha:.1f}")

    # --- Save model and configuration ---
    print("\nExporting model to Hugging Face format...")
    torch.save(model.state_dict(), os.path.join(SAVE_DIR, "pytorch_model.bin"))
    config = {
        "architecture": "VGT-Conv1D",
        "hidden_size": HIDDEN_SIZE,
        "train_max_digits": MAX_DIGITS,
        "final_weight_polarization_std": model.output_proj.weight.std().item(),
        "theory_reference": "The Geometric Origin of Logic (Wang, 2024)"
    }
    with open(os.path.join(SAVE_DIR, "config.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)

    # Persist the recorded metrics; previously they were collected and then
    # thrown away when the function returned.
    with open(
        os.path.join(SAVE_DIR, "training_history.json"), "w", encoding="utf-8"
    ) as f:
        json.dump(history, f, indent=4)

    return model
def run_final_test(model, digit_lengths=(6, 12, 20), num_tests=500):
    """Measure exact-match addition accuracy at several operand lengths.

    For every length in ``digit_lengths``, ``num_tests`` random problems are
    drawn, encoded least-significant digit first, and sent through the model
    in a single batched forward pass. A prediction counts as correct only
    when every one of its ``digits + 1`` output digits matches the true sum,
    which is equivalent to comparing the reconstructed integers.

    Args:
        model: Module whose call returns ``(logits, hidden)`` with ``logits``
            of shape ``(batch, digits + 1, 10)``. It is switched to eval
            mode and left there.
        digit_lengths: Operand lengths (in digits) to evaluate.
        num_tests: Number of random problems per operand length.

    Returns:
        Dict mapping each operand length to its accuracy in percent.

    Raises:
        ValueError: If ``num_tests`` is not positive.
    """
    if num_tests <= 0:
        raise ValueError("num_tests must be a positive integer")

    def little_endian(value, width):
        # Zero-pad to `width` decimal digits, units digit first.
        return [int(ch) for ch in reversed(str(value).zfill(width))]

    model.eval()
    # Build inputs on the model's own device so evaluation keeps working
    # even if the model was moved after training.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    print("\n--- Final Generalization Test ---")
    results = {}
    for digits in digit_lengths:
        upper = 10 ** digits - 1
        inputs, targets = [], []
        for _ in range(num_tests):
            a = random.randint(0, upper)
            b = random.randint(0, upper)
            inputs.append(little_endian(a, digits) + little_endian(b, digits))
            targets.append(little_endian(a + b, digits + 1))

        x_in = torch.tensor(inputs, dtype=torch.long, device=device)
        y_true = torch.tensor(targets, dtype=torch.long, device=device)

        # One batched pass instead of `num_tests` passes of batch size 1.
        with torch.no_grad():
            logits, _ = model(x_in)
        pred_digits = logits.argmax(dim=-1)

        correct = int((pred_digits == y_true).all(dim=1).sum().item())
        accuracy = correct / num_tests * 100
        results[digits] = accuracy
        print(f"Accuracy on {digits:2d}-digit: {accuracy:6.2f}%")
    return results
if __name__ == "__main__":
    # Script entry point: train and export the model, then evaluate how well
    # it adds at the operand lengths checked by run_final_test.
    vgt_model = train_and_export()
    run_final_test(vgt_model)