Spaces:
Sleeping
Sleeping
| import torch | |
| from model.nano_gpt import AgentGPT, Config | |
| from training.agent_lightning_loop import run_training_demo | |
| import tiktoken | |
def _load_traces(path):
    """Load a JSON file of training traces.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The decoded JSON content. ``main`` iterates over it and reads the
        ``"input"`` and ``"reasoning"`` keys of every element.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    import json  # Kept function-local, as in the original script.

    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _make_trainer(model):
    """Build an ``AgentLightningLoop`` around ``model``.

    The import stays lazy so the fallback demo path never needs the class.
    """
    from training.agent_lightning_loop import AgentLightningLoop

    return AgentLightningLoop(model)


def _sft_on_traces(trainer, tokenizer, traces, block_size, log_label=None):
    """Run one supervised fine-tuning step per trace.

    Args:
        trainer: Object exposing ``sft_step(input_ids, targets)``.
        tokenizer: Object exposing ``encode(text)`` returning a list of ids.
        traces: Iterable of mappings with ``"input"`` and ``"reasoning"``.
        block_size: Maximum number of input tokens per example.
        log_label: When given, the loss of every step is printed as
            ``"<log_label> Step Complete. Loss: <loss>"``.

    Returns:
        The number of training steps that were actually executed.
    """
    steps = 0
    for trace in traces:
        text = f"{trace['input']} {trace['reasoning']}"
        # Keep block_size + 1 tokens: the extra one supplies the last target.
        token_ids = tokenizer.encode(text)[: block_size + 1]
        if len(token_ids) < 2:
            # A next-token pair needs at least one input and one target.
            continue
        tokens = torch.tensor([token_ids])
        loss = trainer.sft_step(tokens[:, :-1], tokens[:, 1:])
        steps += 1
        if log_label is not None:
            print(f"{log_label} Step Complete. Loss: {loss:.4f}")
    return steps


def main():
    """Train a reduced-size AgentGPT on the bundled reasoning datasets.

    The distilled traces are used first; when that file is missing the basic
    ``run_training_demo`` runs instead. The RRM trajectories are then trained
    on if present. A missing data file is reported and never aborts the run.
    """
    # 1. Use a slightly smaller config for CPU training demo
    config = Config()
    config.n_layer = 4  # Reduced from 10
    config.n_head = 8  # Balanced for 256 embd
    config.n_embd = 256  # Reduced from 640

    print("Initializing Mini-EAM for CPU training...")
    model = AgentGPT(config)

    # 2. Tokenizer
    enc = tiktoken.get_encoding("cl100k_base")

    class TiktokenWrapper:
        """Expose only ``encode``/``decode`` of the tiktoken encoding."""

        def __init__(self, e):
            self.e = e

        def encode(self, t):
            return self.e.encode(t)

        def decode(self, i):
            return self.e.decode(i)

    tokenizer = TiktokenWrapper(enc)

    # Created on first use so that either dataset can be trained on even when
    # the other one is missing.
    trainer = None

    # 3. Run training demo with SIMULA-distilled data
    print("Loading SIMULA-distilled reasoning traces...")
    try:
        distilled_data = _load_traces("training/distilled_reasoning_v1.json")
    except FileNotFoundError:
        print("Distilled data not found. Running basic demo...")
        run_training_demo(model, tokenizer)
    else:
        print(f"Loaded {len(distilled_data)} high-quality traces from Qwen3.5-0.8B.")
        trainer = _make_trainer(model)
        _sft_on_traces(
            trainer,
            tokenizer,
            distilled_data,
            config.block_size,
            log_label="Distillation",
        )

    # 4. Load RRM-RL90K Distilled Data
    print("Loading RRM-RL90K reasoning trajectories...")
    try:
        rrm_data = _load_traces("training/rrm_rl_trajectories.json")
    except FileNotFoundError:
        print("RRM data not found.")
    else:
        print(f"Loaded {len(rrm_data)} high-fidelity RL trajectories.")
        if trainer is None:
            # The distilled file was missing, so no trainer exists yet.
            trainer = _make_trainer(model)
        # Per-step losses are not printed here: too much noise.
        _sft_on_traces(trainer, tokenizer, rrm_data, config.block_size)
        print("RL Trajectory Training Complete.")

    print("Training Script Verified.")
# Start training only when this file is executed as a script; importing the
# module must not trigger a run.
if __name__ == "__main__":
    main()