# Uploaded by saur7764 via huggingface_hub ("Upload folder using huggingface_hub"),
# commit c61a185 (verified).
import json

import tiktoken
import torch

from model.nano_gpt import AgentGPT, Config
from training.agent_lightning_loop import AgentLightningLoop, run_training_demo
class TiktokenWrapper:
    """Minimal tokenizer adapter exposing ``encode``/``decode`` over a tiktoken encoding.

    Defined at module level (rather than inside ``main``) so instances can be
    pickled if a downstream consumer needs to ship the tokenizer elsewhere.
    """

    def __init__(self, e):
        self.e = e

    def encode(self, text):
        """Return the token ids for *text*."""
        return self.e.encode(text)

    def decode(self, ids):
        """Return the text for the token ids *ids*."""
        return self.e.decode(ids)


def _load_traces(path):
    """Return the JSON document parsed from *path*, or ``None`` if the file is absent.

    Only the missing-file case is treated as "no data"; malformed JSON and
    other I/O errors propagate to the caller.
    """
    try:
        # JSON interchange text is UTF-8; don't depend on the platform locale.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return None


def _make_example(tokenizer, trace, block_size):
    """Encode one trace into an ``(input_ids, targets)`` pair for next-token SFT.

    *trace* must provide ``'input'`` and ``'reasoning'`` entries; they are
    joined with a single space before encoding.  Both returned tensors have
    shape ``(1, L - 1)``, where ``L`` is the clipped token count.

    Returns ``None`` when the trace encodes to fewer than two tokens, because
    the one-token shift would otherwise leave empty sequences.
    """
    text = f"{trace['input']} {trace['reasoning']}"
    # Keep block_size + 1 tokens so that after the one-token shift both
    # input_ids and targets are at most block_size long.
    encoded = tokenizer.encode(text)[: block_size + 1]
    if len(encoded) < 2:
        return None
    tokens = torch.tensor([encoded])
    return tokens[:, :-1], tokens[:, 1:]


def _sft_steps(trainer, tokenizer, traces, block_size):
    """Run ``trainer.sft_step`` once per usable trace, yielding each step's loss.

    Traces too short to form an example (see ``_make_example``) are skipped.
    The steps run lazily, as the generator is consumed.
    """
    for trace in traces:
        example = _make_example(tokenizer, trace, block_size)
        if example is None:
            continue
        input_ids, targets = example
        yield trainer.sft_step(input_ids, targets)


def main():
    """Build a CPU-sized AgentGPT and fine-tune it on distilled reasoning data.

    Phases:
      1. Shrink the default ``Config`` for CPU training and build the model.
      2. Wrap the tiktoken ``cl100k_base`` encoding as the tokenizer.
      3. SFT on SIMULA-distilled traces from
         ``training/distilled_reasoning_v1.json``.  If that file is missing,
         ``run_training_demo`` is run instead.
      4. SFT on RRM-RL90K trajectories from
         ``training/rrm_rl_trajectories.json``, if that file exists.
    """
    # 1. Use a slightly smaller config for CPU training demo
    config = Config()
    config.n_layer = 4  # Reduced from 10
    config.n_head = 8  # 256 / 8 = 32 dims per head
    config.n_embd = 256  # Reduced from 640
    print("Initializing Mini-EAM for CPU training...")
    model = AgentGPT(config)

    # 2. Tokenizer
    tokenizer = TiktokenWrapper(tiktoken.get_encoding("cl100k_base"))

    # One trainer is shared by both SFT phases.  It is created on first use, so
    # the fallback-demo path does not build one, but phase 4 can still train
    # even when phase 3 fell back (previously an UnboundLocalError).
    trainer = None

    # 3. Run training demo with SIMULA-distilled data
    print("Loading SIMULA-distilled reasoning traces...")
    distilled_data = _load_traces("training/distilled_reasoning_v1.json")
    if distilled_data is None:
        print("Distilled data not found. Running basic demo...")
        run_training_demo(model, tokenizer)
    else:
        print(f"Loaded {len(distilled_data)} high-quality traces from Qwen3.5-0.8B.")
        trainer = AgentLightningLoop(model)
        for loss in _sft_steps(trainer, tokenizer, distilled_data, config.block_size):
            print(f"Distillation Step Complete. Loss: {loss:.4f}")

    # 4. Load RRM-RL90K Distilled Data
    print("Loading RRM-RL90K reasoning trajectories...")
    rrm_data = _load_traces("training/rrm_rl_trajectories.json")
    if rrm_data is None:
        print("RRM data not found.")
    else:
        print(f"Loaded {len(rrm_data)} high-fidelity RL trajectories.")
        if trainer is None:
            trainer = AgentLightningLoop(model)
        for _loss in _sft_steps(trainer, tokenizer, rrm_data, config.block_size):
            pass  # Per-step loss printing disabled here: too much noise.
        print("RL Trajectory Training Complete.")

    print("Training Script Verified.")
if __name__ == "__main__":
    # Only train when run as a script; importing this module has no side effects.
    main()