Spaces:
Sleeping
Sleeping
File size: 1,619 Bytes
c61a185 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 | import torch
import torch.nn as nn
from torch.optim import AdamW
class AgentLightningLoop:
    """Minimal "Agent Lightning"-style trainer around a model and one AdamW optimizer.

    Two update rules share the same optimizer:

    * :meth:`sft_step` -- supervised fine-tuning (next-token cross-entropy).
    * :meth:`rl_optimize` -- a simple REINFORCE-style policy-gradient update.

    Args:
        model: A ``torch.nn.Module`` called as ``model(input_ids, targets=...)``
            that returns a ``(logits, loss)`` pair (see :meth:`sft_step`).
        lr: Learning rate passed to ``AdamW``.
    """

    def __init__(self, model, lr=1e-4):
        self.model = model
        self.optimizer = AdamW(model.parameters(), lr=lr)
        # Fallback loss for sft_step when the model itself returns no loss.
        self.criterion = nn.CrossEntropyLoss()

    def sft_step(self, input_ids, targets):
        """Run one supervised fine-tuning update and return the scalar loss.

        The model is switched to training mode and called as
        ``model(input_ids, targets=targets)``; it must return ``(logits, loss)``.
        If ``loss`` is ``None`` (the model did not compute one), the loss is
        computed here with ``self.criterion`` over ``logits`` flattened to
        ``(N, vocab)`` and ``targets`` flattened to ``(N,)``.

        Returns:
            The loss value (before the optimizer step) as a Python float.
        """
        self.model.train()
        logits, loss = self.model(input_ids, targets=targets)
        if loss is None:
            # CrossEntropyLoss expects (N, C) logits and (N,) class indices.
            loss = self.criterion(
                logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
            )
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def rl_optimize(self, log_probs, rewards):
        """Apply one policy-gradient step: minimise ``-(log_probs * rewards).mean()``.

        Args:
            log_probs: Tensor of log-probabilities of the actions taken; must be
                connected to the model's parameters for the step to update them.
            rewards: Rewards for those actions -- a tensor, Python number, or
                sequence of numbers. They are treated as constants (detached),
                so no gradient flows into whatever produced them.

        Returns:
            The loss value as a Python float.

        Raises:
            ValueError: If ``rewards`` would broadcast ``log_probs`` to a larger
                shape (e.g. ``(N, 1)`` rewards against ``(N,)`` log-probs, which
                would silently form an ``N x N`` product).
        """
        # REINFORCE treats the return as a fixed weight; detaching prevents
        # gradients from leaking into a reward/value network.
        rewards = torch.as_tensor(rewards, device=log_probs.device).detach()
        if torch.broadcast_shapes(log_probs.shape, rewards.shape) != log_probs.shape:
            raise ValueError(
                f"rewards shape {tuple(rewards.shape)} is not broadcastable to "
                f"log_probs shape {tuple(log_probs.shape)}"
            )
        loss = -(log_probs * rewards).mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()
def run_training_demo(model, tokenizer):
    """Run a single SFT step on a hard-coded goal/thought/action string.

    Encodes a mock trajectory with ``tokenizer.encode``, shifts it by one token
    to form (input, next-token target) pairs, performs one
    ``AgentLightningLoop.sft_step`` update, prints the loss and returns it.

    Args:
        model: Model accepted by ``AgentLightningLoop`` (must have parameters).
        tokenizer: Object whose ``encode(str)`` returns a sequence of int token ids.

    Returns:
        The training loss as a float.

    Raises:
        ValueError: If the text encodes to fewer than two tokens, so no
            next-token pair can be formed.
    """
    trainer = AgentLightningLoop(model)
    # Mock training data: a simple goal -> thought -> action sequence.
    text = "<|goal|> Find files <|thought|> I should scan <|discover|>"
    token_ids = tokenizer.encode(text)
    if len(token_ids) < 2:
        raise ValueError(
            f"demo text encoded to {len(token_ids)} token(s); need at least 2"
        )
    # Build the batch on the model's device so a GPU-resident model works too.
    # (AgentLightningLoop's AdamW already rejects a model with no parameters.)
    device = next(model.parameters()).device
    tokens = torch.tensor([token_ids], dtype=torch.long, device=device)
    # Next-token prediction: position i predicts token i + 1.
    input_ids = tokens[:, :-1]
    targets = tokens[:, 1:]
    loss = trainer.sft_step(input_ids, targets)
    print(f"Training Step Complete. Loss: {loss:.4f}")
    return loss
|