Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """Decoder训练脚本 — 用QA数据训练语言解码层""" | |
| import sys, json, os, time | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) | |
| from core.brain import Brain | |
def _epoch_increments(milestones):
    """Turn cumulative epoch milestones into per-round epoch counts.

    Example: ``[5, 10, 20]`` becomes ``[5, 5, 10]``. A milestone that does
    not exceed the number of epochs already scheduled contributes ``0``.

    Args:
        milestones: Iterable of cumulative epoch targets.

    Returns:
        A list with one non-negative epoch count per milestone.
    """
    increments = []
    scheduled = 0
    for target in milestones:
        step = max(target - scheduled, 0)
        increments.append(step)
        scheduled += step
    return increments


def _generate_text(brain, encoder, question, temperature, motor_dim=300):
    """Encode *question*, run it through *brain* and decode the motor output.

    Args:
        brain: Object exposing ``forward(vec)`` and a ``_decoder`` attribute.
        encoder: Object exposing ``encode(text)``.
        question: Prompt text to encode.
        temperature: Sampling temperature forwarded to ``decode``.
        motor_dim: Number of leading motor components handed to the decoder;
            also the size of the zero vector used when ``forward`` returns
            no ``'motor'`` entry.

    Returns:
        The decoded text, or ``None`` when the encoder yields no vector or
        the brain has no decoder, so callers can skip the sample instead of
        crashing on a ``None`` attribute access.
    """
    # Function-scope import: keeps this helper independent of where (or
    # whether) the file-level numpy import appears.
    import numpy as np

    vec = encoder.encode(question)
    decoder = brain._decoder
    if vec is None or decoder is None:
        return None
    result = brain.forward(vec)
    motor = result.get('motor', np.zeros(motor_dim))
    return decoder.decode(motor[:motor_dim], temperature=temperature)


def main(qa_path='/home/admin/swarm_product/data/qa_training.json'):
    """Train the language decoder on QA data and print sample generations.

    Steps: build a ``Brain``, load the QA JSON, print one generation before
    training, train in rounds up to the cumulative milestones 5/10/20/30/50
    epochs with sample generations after each round, then run a few prompts
    through ``brain.chat``.

    Args:
        qa_path: Path to the QA training JSON file (read as UTF-8).

    Raises:
        OSError: If *qa_path* cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    print("=" * 50)
    print("Decoder训练开始")
    print("=" * 50)

    # 1. Initialise the Brain.
    print("\n1. 初始化Brain...")
    brain = Brain()
    print(f" 解码层vocab: {brain._decoder.vocab_size if brain._decoder else 'None'}")

    # 2. Load the QA data. The encoding is explicit because the data is
    #    Chinese text and the locale default is not UTF-8 on every platform.
    with open(qa_path, encoding='utf-8') as f:
        qa_data = json.load(f)
    print(f" QA数据: {len(qa_data)}条")

    # 3. Generation test before any training.
    print("\n2. 训练前生成测试:")
    from core.semantic_encoder import get_encoder
    encoder = get_encoder()
    test_q = "什么是人工智能"
    text = _generate_text(brain, encoder, test_q, temperature=0.8)
    if text is not None:
        print(f" Q: {test_q}")
        print(f" 生成: {text}")

    # 4. Train in several rounds and test after each one.
    #    Each round runs only the epochs still missing to reach its
    #    milestone, so the "train to N epochs" message matches what has
    #    actually been trained.
    #    NOTE(review): assumes train_decoder continues from the current
    #    weights rather than re-initialising them - confirm.
    epochs_list = [5, 10, 20, 30, 50]
    test_qs = ["什么是人工智能", "你好", "Python是什么"]
    for target_epochs, extra_epochs in zip(epochs_list,
                                           _epoch_increments(epochs_list)):
        if extra_epochs == 0:
            # Milestone already reached; nothing new to train or report.
            continue
        print(f"\n3. 训练到 {target_epochs} epochs...")
        t0 = time.time()
        result = brain.train_decoder(qa_data, epochs=extra_epochs, lr=0.01)
        elapsed = time.time() - t0
        print(f" loss={result.get('final_loss', 0):.4f}, "
              f"n={result.get('n_trained', 0)}, "
              f"耗时={elapsed:.1f}s")

        # Sample generations after this round.
        print(f" 生成测试:")
        for q in test_qs:
            text = _generate_text(brain, encoder, q, temperature=0.5)
            if text is not None:
                print(f" {q} → {text}")

    # 5. Final save (the message reports that weights are saved automatically).
    print("\n4. 训练完成,权重已自动保存")

    # 6. End-to-end check through the chat interface.
    print("\n5. 综合测试:")
    test_cases = [
        "什么是人工智能",
        "你好",
        "Python是什么",
        "怎么学习编程",
        "什么是虫群",
    ]
    for q in test_cases:
        reply = brain.chat(q)
        print(f" Q: {q}")
        print(f" A: {reply.get('text', '')[:80]}")
        print(f" mode: {reply.get('mode', '')}")
| import numpy as np | |
# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()