swarm-chat / tools /train_decoder.py
lk080424
虫巢-200M训练部署: npz+json替代pkl, 三区循环训练, 4454QA数据
358ab64
Raw
History Blame Contribute Delete
2.6 kB
#!/usr/bin/env python3
"""Decoder训练脚本 — 用QA数据训练语言解码层"""
import json
import os
import sys
import time

import numpy as np

# Make the project's ``src`` directory importable before loading ``core``.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from core.brain import Brain  # noqa: E402  (requires the sys.path tweak above)
# Default location of the QA training set. It can be overridden with the
# ``qa_path`` argument of ``main`` or the SWARM_QA_PATH environment variable.
DEFAULT_QA_PATH = '/home/admin/swarm_product/data/qa_training.json'

# Number of motor-output dimensions fed to the decoder. The motor vector is
# sliced to this length, and a zero vector of this length is used when the
# forward pass returns no 'motor' entry.
MOTOR_DIM = 300


def _load_qa(qa_path):
    """Load and return the QA training set from a UTF-8 encoded JSON file."""
    # Explicit encoding: the data is Chinese text, and the platform default
    # encoding is not guaranteed to be UTF-8.
    with open(qa_path, encoding='utf-8') as f:
        return json.load(f)


def _generate(brain, encoder, question, temperature):
    """Encode *question*, run a forward pass, and decode the motor output.

    Returns the decoded text. Returns ``None`` when the brain has no decoder
    attached or when the encoder returns ``None`` for the question.
    """
    decoder = brain._decoder
    if decoder is None:
        return None
    vec = encoder.encode(question)
    if vec is None:
        return None
    motor = brain.forward(vec).get('motor', np.zeros(MOTOR_DIM))
    return decoder.decode(motor[:MOTOR_DIM], temperature=temperature)


def main(qa_path=None, epochs_list=(5, 10, 20, 30, 50), lr=0.01):
    """Train the Brain's language decoder on QA data and print progress.

    Training runs in rounds. ``epochs_list`` holds increasing *cumulative*
    epoch targets. Each round trains only the epochs still needed to reach
    its target, then samples a few generations. A final chat test follows.

    Args:
        qa_path: Path to the JSON QA training data. When omitted, the
            SWARM_QA_PATH environment variable is used, then DEFAULT_QA_PATH.
        epochs_list: Increasing cumulative epoch targets, one per round.
            Non-increasing entries are skipped.
        lr: Learning rate passed to ``Brain.train_decoder``.
    """
    print("=" * 50)
    print("Decoder训练开始")
    print("=" * 50)

    # 1. Initialise the brain.
    print("\n1. 初始化Brain...")
    brain = Brain()
    print(f" 解码层vocab: {brain._decoder.vocab_size if brain._decoder else 'None'}")

    # 2. Load the QA data.
    if qa_path is None:
        qa_path = os.environ.get('SWARM_QA_PATH', DEFAULT_QA_PATH)
    qa_data = _load_qa(qa_path)
    print(f" QA数据: {len(qa_data)}条")

    # 3. Baseline generation before any training.
    print("\n2. 训练前生成测试:")
    from core.semantic_encoder import get_encoder
    encoder = get_encoder()
    test_q = "什么是人工智能"
    text = _generate(brain, encoder, test_q, temperature=0.8)
    if text is not None:
        print(f" Q: {test_q}")
        print(f" 生成: {text}")

    # 4. Train in rounds, sampling generations after each one.
    # NOTE(review): assumes train_decoder continues from the current weights
    # (not re-initialising them), so training only the delta reaches the
    # cumulative target — confirm against Brain.train_decoder.
    test_qs = ["什么是人工智能", "你好", "Python是什么"]
    trained = 0
    for target_epochs in epochs_list:
        step = target_epochs - trained
        if step <= 0:
            continue  # target already reached; list should be increasing
        print(f"\n3. 训练到 {target_epochs} epochs...")
        t0 = time.perf_counter()
        train_result = brain.train_decoder(qa_data, epochs=step, lr=lr)
        elapsed = time.perf_counter() - t0
        trained = target_epochs
        print(f" loss={train_result.get('final_loss', 0):.4f}, "
              f"n={train_result.get('n_trained', 0)}, "
              f"耗时={elapsed:.1f}s")
        print(" 生成测试:")
        for q in test_qs:
            text = _generate(brain, encoder, q, temperature=0.5)
            if text is not None:
                print(f" {q} → {text}")

    # 5. Final save.
    # NOTE(review): nothing is saved here; this relies on the Brain persisting
    # its weights itself during training — confirm.
    print("\n4. 训练完成,权重已自动保存")

    # 6. End-to-end chat test.
    print("\n5. 综合测试:")
    test_cases = [
        "什么是人工智能",
        "你好",
        "Python是什么",
        "怎么学习编程",
        "什么是虫群",
    ]
    for q in test_cases:
        reply = brain.chat(q)
        print(f" Q: {q}")
        # `or ''` guards against a reply whose 'text' is present but None.
        print(f" A: {(reply.get('text') or '')[:80]}")
        print(f" mode: {reply.get('mode', '')}")
import numpy as np
if __name__ == '__main__':
    # Script entry point: run the train/evaluate pipeline only when executed
    # directly, never as a side effect of importing this module.
    main()