import torch
import torch.nn as nn
from torch.optim import AdamW


class AgentLightningLoop:
    """Implement Agent Lightning style training.

    Supports both Supervised Fine-Tuning (SFT) and a basic policy-gradient
    RL update, sharing one AdamW optimizer over the wrapped model.
    """

    def __init__(self, model, lr=1e-4):
        """Set up the optimizer and fallback criterion for *model*.

        Args:
            model: Module whose ``parameters()`` are optimized. ``sft_step``
                calls it as ``model(input_ids, targets=targets)`` and expects
                a ``(logits, loss)`` pair back.
            lr: Learning rate passed to AdamW.
        """
        self.model = model
        self.optimizer = AdamW(model.parameters(), lr=lr)
        # Only used by sft_step when the model does not return its own loss.
        self.criterion = nn.CrossEntropyLoss()

    def sft_step(self, input_ids, targets):
        """Run one Supervised Fine-Tuning step and return the scalar loss.

        Puts the model in training mode, runs a forward pass, backpropagates
        the loss and applies one optimizer step.

        Args:
            input_ids: Model input, forwarded unchanged to the model.
            targets: Training targets, forwarded as the ``targets`` keyword.

        Returns:
            The loss of this step as a Python float.
        """
        self.model.train()
        logits, loss = self.model(input_ids, targets=targets)
        if loss is None:
            # The model did not compute a loss itself; fall back to the
            # cross-entropy criterion. Assumes the last logits dimension is
            # the class/vocabulary axis and targets hold class indices --
            # TODO confirm against the model in use.
            loss = self.criterion(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
            )

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def rl_optimize(self, log_probs, rewards):
        """Run one simple policy-gradient (REINFORCE-style) step.

        Minimizes ``-(log_probs * rewards).mean()``.

        Args:
            log_probs: Tensor of log probabilities of the actions taken; it
                must be attached to the autograd graph of the model.
            rewards: Rewards received, as a tensor or anything
                ``torch.as_tensor`` accepts; must be broadcastable against
                ``log_probs``.

        Returns:
            The loss of this step as a Python float.
        """
        # Rewards are constants in the policy-gradient estimator: detach so
        # that no gradient leaks into whatever produced them, and place them
        # on the same device as the log probabilities.
        rewards = torch.as_tensor(rewards, device=log_probs.device).detach()
        loss = -(log_probs * rewards).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()


def run_training_demo(model, tokenizer):
    """Run a single SFT step on a mock goal -> thought -> action sequence.

    Args:
        model: Model to train; see ``AgentLightningLoop`` for the expected
            call signature.
        tokenizer: Object whose ``encode(text)`` result is wrapped in a list
            and passed to ``torch.tensor``.
    """
    trainer = AgentLightningLoop(model)

    # Mock training data: a simple goal -> thought -> action sequence.
    text = "<|goal|> Find files <|thought|> I should scan <|discover|>"
    # Build the batch on the model's device so the demo also works when the
    # model has been moved off the CPU.
    device = next(iter(model.parameters())).device
    tokens = torch.tensor([tokenizer.encode(text)], device=device)

    # Simple SFT: predict the next token, so targets are inputs shifted by one.
    input_ids = tokens[:, :-1]
    targets = tokens[:, 1:]

    loss = trainer.sft_step(input_ids, targets)
    print(f"Training Step Complete. Loss: {loss:.4f}")