# EAM-100M-Agentic-Kernel/training/agent_lightning_loop.py
# Uploaded by saur7764 using huggingface_hub (commit c61a185, verified).
import torch
import torch.nn as nn
from torch.optim import AdamW
class AgentLightningLoop:
    """Minimal Agent Lightning style trainer.

    Wraps a model with an AdamW optimizer and exposes two update rules:
    a Supervised Fine-Tuning (SFT) step and a basic policy-gradient (RL)
    step.

    Attributes:
        model: The module being trained.
        optimizer: ``AdamW`` built over ``model.parameters()``.
        criterion: An ``nn.CrossEntropyLoss`` instance. It is not used by
            the methods of this class (``sft_step`` relies on the loss the
            model returns).
    """

    def __init__(self, model, lr=1e-4):
        """Store the model and build its optimizer.

        Args:
            model: Object exposing ``parameters()``, ``train()`` and a call
                signature ``model(input_ids, targets=...)``.
            lr: Learning rate handed to ``AdamW``.
        """
        self.model = model
        self.optimizer = AdamW(model.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()

    def sft_step(self, input_ids, targets):
        """Run one Supervised Fine-Tuning update.

        The model is switched to training mode and called as
        ``model(input_ids, targets=targets)``; it must return a
        ``(logits, loss)`` pair. Only the loss is used here.

        Args:
            input_ids: Model inputs, forwarded unchanged.
            targets: Training targets, forwarded as the ``targets`` keyword.

        Returns:
            The loss of this step as a Python float.
        """
        self.model.train()
        _, loss = self.model(input_ids, targets=targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def rl_optimize(self, log_probs, rewards):
        """Run one simple policy-gradient update from agent rewards.

        Minimises ``-(log_probs * rewards).mean()``.

        Args:
            log_probs: Tensor of log probabilities of the actions taken;
                it must be connected to the model's parameters for the
                update to have any effect.
            rewards: Rewards received, either a tensor or a plain Python
                number. NOTE(review): assumed to broadcast element-wise
                against ``log_probs`` -- confirm shapes at the call site.

        Returns:
            The loss of this step as a Python float.
        """
        # Rewards act as fixed weights in the policy-gradient estimator.
        # Detaching keeps backward() from pushing gradients into whatever
        # produced them (e.g. a reward model), which this optimizer would
        # never zero.
        if torch.is_tensor(rewards):
            rewards = rewards.detach()
        loss = -(log_probs * rewards).mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()
def run_training_demo(model, tokenizer):
    """Perform a single SFT step on a hard-coded sample and print its loss.

    Args:
        model: Model handed to ``AgentLightningLoop``.
        tokenizer: Object whose ``encode(text)`` result can be wrapped in a
            list and passed to ``torch.tensor``.
    """
    loop = AgentLightningLoop(model)

    # Mock sample: one goal -> thought -> action sequence, encoded as a
    # single-row batch.
    sample = "<|goal|> Find files <|thought|> I should scan <|discover|>"
    encoded = torch.tensor([tokenizer.encode(sample)])

    # Next-token objective: inputs drop the last token, targets drop the
    # first, so each position is trained to predict its successor.
    step_loss = loop.sft_step(encoded[:, :-1], encoded[:, 1:])
    print(f"Training Step Complete. Loss: {step_loss:.4f}")