ZhongRen11 committed on
Commit
c405313
·
verified ·
1 Parent(s): e948e6b

Upload 4 files

Browse files
VGT_Pro_Conv_Logic_Emergence/Logic-Orchestrator.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import random
5
+
6
# --- 1. Model architecture (unchanged) ---
class VGTProModel(nn.Module):
    """Convolutional adder over reversed digit strings.

    The input tensor packs both operands back to back (first half = operand a,
    second half = operand b, each digit 0-9). A 1x1 conv fuses the two
    embeddings, then a single shared conv layer is applied repeatedly as a
    residual refinement step with a growing dilation schedule.
    """

    def __init__(self, hidden_size=128):
        super().__init__()
        self.embedding = nn.Embedding(10, hidden_size)
        self.reducer = nn.Conv1d(2 * hidden_size, hidden_size, kernel_size=1)
        self.conv_process = nn.Conv1d(hidden_size, hidden_size, kernel_size=3, padding=1)
        self.output_proj = nn.Conv1d(hidden_size, 10, kernel_size=1)

    def forward(self, x):
        half = x.shape[1] // 2
        emb = self.embedding(x).transpose(1, 2)
        fused = torch.cat([emb[:, :, :half], emb[:, :, half:]], dim=1)
        state = torch.relu(self.reducer(fused))
        # One extra position on the right for the final carry digit.
        state = F.pad(state, (0, 1))
        for step in range(state.size(2) + 4):
            # Dilation schedule: 1 for steps 0-3, 2 for 4-7, 4 afterwards.
            dilation = min(4, 2 ** (step // 4))
            refined = F.conv1d(state, self.conv_process.weight, self.conv_process.bias,
                               padding=dilation, dilation=dilation)
            state = torch.relu(refined) + state
        return self.output_proj(state).transpose(1, 2)
26
+
27
# --- 2. Logic-arithmetic unit built on top of VGT-Pro ---
class VGTLogicEngine:
    """Wraps a trained VGTProModel and exposes integer arithmetic on top of it."""

    def __init__(self, model_path, device):
        self.device = device
        checkpoint = torch.load(model_path, map_location=device)
        net = VGTProModel().to(device)
        net.load_state_dict(checkpoint['model_state_dict'])
        net.eval()
        self.model = net

    @staticmethod
    def _le_digits(value, width):
        """Little-endian digit list of *value*, zero-padded to *width* digits."""
        return [int(ch) for ch in reversed(str(value).zfill(width))]

    def add(self, a, b):
        """Core adder: compute a + b with the VGT-Pro network."""
        # One spare digit so the final carry always fits.
        width = max(len(str(a)), len(str(b))) + 1
        packed = self._le_digits(a, width) + self._le_digits(b, width)
        x_in = torch.tensor([packed], dtype=torch.long).to(self.device)

        with torch.no_grad():
            logits = self.model(x_in)
        decoded = logits[0].argmax(dim=-1).cpu().tolist()

        return sum(d * 10 ** k for k, d in enumerate(decoded))

    def multiply(self, a, b):
        """Logic extrapolation: multiplication via repeated model-driven addition."""
        total = 0
        for place, ch in enumerate(reversed(str(b))):
            # a * single digit, realised purely by repeated addition.
            partial = 0
            for _ in range(int(ch)):
                partial = self.add(partial, a)
            # Shift left by `place` decimal positions, then accumulate.
            total = self.add(total, partial * 10 ** place)
        return total
63
+
64
# --- 3. Test script ---
def verify_multiplication(model_path, samples=20, max_digits=4):
    """Benchmark extrapolated multiplication built from model-driven addition.

    Loads a trained VGT-Pro checkpoint, multiplies `samples` random pairs of
    up-to-`max_digits`-digit integers via VGTLogicEngine.multiply, prints a
    per-case report, and returns the overall accuracy.

    Args:
        model_path: path of the checkpoint consumed by VGTLogicEngine.
        samples: number of random multiplication problems to run.
        max_digits: operands are drawn uniformly from [1, 10**max_digits - 1].

    Returns:
        float: accuracy in percent over the sampled problems. (Previously the
        value was only printed and discarded; returning it lets callers use
        the result programmatically without changing existing behavior.)
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    engine = VGTLogicEngine(model_path, device)

    print(f"✅ 加载逻辑引擎,开始【逻辑外推乘法】测试...")
    print(f"{'题目':<25} | {'预期结果':<15} | {'模型结果':<15} | {'状态'}")
    print("-" * 80)

    correct = 0
    for _ in range(samples):
        # Products grow fast: 4-digit x 4-digit can reach 8 digits.
        a = random.randint(1, 10 ** max_digits - 1)
        b = random.randint(1, 10 ** max_digits - 1)

        true_prod = a * b
        pred_prod = engine.multiply(a, b)

        status = "✅" if true_prod == pred_prod else "❌"
        if true_prod == pred_prod:
            correct += 1

        print(f"{f'{a} x {b}':<25} | {true_prod:<15} | {pred_prod:<15} | {status}")

    print("-" * 80)
    accuracy = (correct / samples) * 100
    print(f"测试完成!准确率: {accuracy:.2f}%")
    return accuracy
89
+
90
# --- Script entry point ---
if __name__ == "__main__":
    # 4-digit multiplication: this drives dozens of high-precision
    # addition calls through the trained model.
    verify_multiplication("vgt_pro_logic_machine.pth", samples=20, max_digits=4)
VGT_Pro_Conv_Logic_Emergence/vgt_vs_base_benchmark.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ import random
6
+ import pandas as pd # 用于生成清晰的对比表格
7
+
8
# Compute device: prefer CUDA when available, otherwise fall back to CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Experiment hyperparameters (strictly aligned with the paper) ---
MAX_DIGITS = 6  # number of digits per operand during training
HIDDEN_SIZE = 128  # embedding size / conv channel width
LR = 3e-4  # AdamW learning rate
TRAIN_STEPS = 10000  # optimization steps per trained model
BATCH_SIZE = 64  # samples per training batch
16
+
17
# --- 1. Model architecture ---
class AdditionModel(nn.Module):
    """Convolutional digit-adder.

    Embeds two reversed digit strings (packed back to back in one tensor),
    fuses them with a 1x1 conv, then refines the hidden state with a shared
    residual conv layer, once per output position.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(10, hidden_size)
        self.reducer = nn.Conv1d(2 * hidden_size, hidden_size, kernel_size=1)
        self.conv_process = nn.Conv1d(hidden_size, hidden_size, kernel_size=3, padding=1)
        self.output_proj = nn.Conv1d(hidden_size, 10, kernel_size=1)

    def forward(self, x):
        # First half of the sequence = operand a, second half = operand b.
        half = x.shape[1] // 2
        emb = self.embedding(x).transpose(1, 2)
        fused = torch.cat([emb[:, :, :half], emb[:, :, half:]], dim=1)
        state = torch.relu(self.reducer(fused))
        # One extra slot on the right for the carry digit.
        state = nn.functional.pad(state, (0, 1))
        # Recursive iteration: one residual conv pass per position.
        for _ in range(state.size(2)):
            state = torch.relu(self.conv_process(state)) + state
        return self.output_proj(state).transpose(1, 2), state
36
+
37
# --- 2. Training routine (vgt_mode is the single controlled variable) ---
def train_model(vgt_mode=True):
    """Train one AdditionModel from scratch.

    With vgt_mode=True an annealed L2 penalty on the hidden state (the
    paper's "geometric pressure") is added to the cross-entropy loss;
    with vgt_mode=False the model trains on cross-entropy alone.
    Returns the trained model.
    """
    model = AdditionModel(HIDDEN_SIZE).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=LR)

    mode_name = "VGT (With L2 Pressure)" if vgt_mode else "Base (No Constraint)"
    print(f"\n--- Training {mode_name} ---")

    for step in range(TRAIN_STEPS + 1):
        model.train()
        x, y = generate_batch(BATCH_SIZE)
        optimizer.zero_grad()
        logits, h_states = model(x)

        # Plain cross-entropy over the 10 digit classes.
        loss_ce = nn.functional.cross_entropy(logits.reshape(-1, 10), y.reshape(-1))

        if not vgt_mode:
            loss = loss_ce  # control group: no geometric constraint
        else:
            # Alpha decays linearly from 50 down to 1 over training.
            alpha = max(1.0, 50.0 - (50.0 - 1.0) * (step / TRAIN_STEPS))
            # Core ingredient: L2 penalty on the hidden state.
            loss = loss_ce + alpha * 1e-4 * torch.norm(h_states, p=2, dim=1).mean()

        loss.backward()
        optimizer.step()

        if step % 2000 == 0:
            std = model.output_proj.weight.std().item()
            print(f"Step {step:5d} | CE Loss: {loss_ce.item():.4f} | Head Weight Std: {std:.4f}")

    return model
71
+
72
# --- 3. Data generation and evaluation ---
def generate_batch(batch_size, digits=MAX_DIGITS):
    """Build one random addition batch.

    Returns (x, y): each row of x concatenates the reversed, zero-padded
    digits of both operands; each row of y holds the reversed digits of
    their sum, with one extra slot for the final carry.
    """
    def little_endian(value, width):
        # Reversed digit list, zero-padded to `width` characters.
        return [int(ch) for ch in str(value).zfill(width)][::-1]

    inputs, targets = [], []
    for _ in range(batch_size):
        a = random.randint(0, 10 ** digits - 1)
        b = random.randint(0, 10 ** digits - 1)
        inputs.append(little_endian(a, digits) + little_endian(b, digits))
        targets.append(little_endian(a + b, digits + 1))
    x_t = torch.tensor(inputs, dtype=torch.long).to(DEVICE)
    y_t = torch.tensor(targets, dtype=torch.long).to(DEVICE)
    return x_t, y_t
83
+
84
def evaluate(model, digits):
    """Exact-match accuracy (%) of `model` on 500 random `digits`-digit sums."""
    model.eval()
    num_tests = 500
    hits = 0

    def rev_digits(value):
        # Reversed digit list, zero-padded to `digits` characters.
        return [int(ch) for ch in str(value).zfill(digits)][::-1]

    with torch.no_grad():
        for _ in range(num_tests):
            a = random.randint(0, 10 ** digits - 1)
            b = random.randint(0, 10 ** digits - 1)
            x_in = torch.tensor([rev_digits(a) + rev_digits(b)], dtype=torch.long).to(DEVICE)
            logits, _ = model(x_in)
            decoded = logits[0].argmax(dim=-1).cpu().tolist()
            predicted = sum(d * (10 ** i) for i, d in enumerate(decoded))
            if predicted == a + b:
                hits += 1
    return (hits / num_tests) * 100
100
+
101
# --- 4. Run the controlled comparison ---
if __name__ == "__main__":
    # Train both arms: control (no constraint) and treatment (geometric pressure).
    base_model = train_model(vgt_mode=False)
    vgt_model = train_model(vgt_mode=True)

    # Probe length generalisation at the training length and beyond it.
    results = [
        {
            "Digits": d,
            "Base Accuracy (%)": evaluate(base_model, d),
            "VGT Accuracy (%)": evaluate(vgt_model, d),
        }
        for d in [6, 12, 20]
    ]

    # Render the side-by-side comparison table.
    table = pd.DataFrame(results)
    divider = "=" * 50
    print("\n" + divider)
    print("FINAL COMPARISON: BASE vs VGT (GEOMETRIC PRESSURE)")
    print(divider)
    print(table.to_string(index=False))
    print(divider)