# webshop-hsl-seed123 / hsl_code_snapshot / trainer_model_agent.py
# Uploaded via huggingface_hub by heendung (commit d1c897a, verified)
import torch
class SingleSampleAgent:
    """Agent that generates text one sample at a time without padding (FSDP compatible)."""

    def __init__(self, model, tokenizer, **gen_kwargs):
        """Store the model, tokenizer, and generation settings.

        Args:
            model: HF-style model exposing ``.device`` and ``.generate``.
            tokenizer: HF-style tokenizer; assigned a pad token if it lacks one.
            **gen_kwargs: Forwarded verbatim to ``model.generate``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.gen_kwargs = gen_kwargs
        # Many causal-LM tokenizers ship without a pad token; fall back to EOS
        # so any downstream padded tokenization does not fail.
        if getattr(self.tokenizer, "pad_token", None) is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, messages):
        """Generate a response for one or many conversations.

        Args:
            messages: Either a single conversation (a list of message dicts)
                or a list of conversations.

        Returns:
            A single decoded string when given one conversation, otherwise a
            list of decoded strings (one per conversation).
        """
        # A single conversation is detected by its first element being a dict.
        if messages and isinstance(messages[0], dict):
            conversations = [messages]
            single = True
        else:
            conversations = messages
            single = False
        outputs = []
        for conversation in conversations:
            prompt = self.tokenizer.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=True
            )
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            # NOTE: a previous revision wrapped this in FSDP.summon_full_params;
            # that path was disabled (left as a dead string literal) and has
            # been removed here.
            with torch.no_grad():
                output_ids = self.model.generate(**inputs, **self.gen_kwargs)
            # Decode only the newly generated tokens, skipping the prompt.
            output = self.tokenizer.decode(
                output_ids[0, inputs["input_ids"].size(1):], skip_special_tokens=True
            )
            outputs.append(output)
        return outputs[0] if single else outputs
class TrainerModelAgent:
    """Simple agent that generates text using a model and tokenizer (batched, left-padded)."""

    def __init__(self, model, tokenizer, **gen_kwargs):
        """Store the model, tokenizer, and generation settings.

        Args:
            model: HF-style model exposing ``.device`` and ``.generate``.
            tokenizer: HF-style tokenizer; assigned a pad token if it lacks one.
            **gen_kwargs: Forwarded verbatim to ``model.generate``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.gen_kwargs = gen_kwargs
        # Padding a batch of prompts requires a pad token; fall back to EOS.
        if getattr(self.tokenizer, "pad_token", None) is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, messages):
        """Generate a response for one or many conversations.

        Args:
            messages: Either a single conversation (a list of message dicts)
                or a list of conversations.

        Returns:
            A single decoded string when given one conversation, otherwise a
            list of decoded strings (one per conversation).
        """
        # A single conversation is detected by its first element being a dict.
        if messages and isinstance(messages[0], dict):
            conversations = [messages]
            single = True
        else:
            conversations = messages
            single = False
        prompts = [
            self.tokenizer.apply_chat_template(c, tokenize=False, add_generation_prompt=True)
            for c in conversations
        ]
        # Left-pad so every prompt ends at the same column; the generated
        # continuation then starts at a common offset for the whole batch.
        previous_side = getattr(self.tokenizer, "padding_side", "right")
        self.tokenizer.padding_side = "left"
        try:
            inputs = self.tokenizer(prompts, return_tensors="pt", padding=True).to(
                self.model.device
            )
            output_ids = self.model.generate(**inputs, **self.gen_kwargs)
            # With left padding all rows share input_ids.size(1); slice it off
            # to keep only the newly generated tokens.
            outputs = self.tokenizer.batch_decode(
                output_ids[:, inputs["input_ids"].size(1):], skip_special_tokens=True
            )
        finally:
            # Restore the caller's padding side even if generation raises
            # (the original forced "right" unconditionally and only on success).
            self.tokenizer.padding_side = previous_side
        return outputs[0] if single else outputs