import torch


class SingleSampleAgent:
    """Agent that generates text one sample at a time without padding (FSDP compatible)."""

    def __init__(self, model, tokenizer, **gen_kwargs):
        self.model = model
        self.tokenizer = tokenizer
        self.gen_kwargs = gen_kwargs
        # Decoder-only tokenizers often ship without a pad token; fall back to EOS.
        if getattr(self.tokenizer, "pad_token", None) is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, messages):
        """Generate a response for one or many conversations."""
        # A single conversation is a list of message dicts; a batch is a list of conversations.
        if messages and isinstance(messages[0], dict):
            conversations = [messages]
            single = True
        else:
            conversations = messages
            single = False

        outputs = []
        for conversation in conversations:
            prompt = self.tokenizer.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=True
            )
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

            with torch.no_grad():
                try:
                    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

                    if isinstance(self.model, FSDP):
                        # Gather the full (unsharded) parameters before generating.
                        with FSDP.summon_full_params(self.model, recurse=False):
                            output_ids = self.model.generate(**inputs, **self.gen_kwargs)
                    else:
                        output_ids = self.model.generate(**inputs, **self.gen_kwargs)
                except ImportError:
                    # torch.distributed.fsdp is unavailable; generate on the plain model.
                    output_ids = self.model.generate(**inputs, **self.gen_kwargs)

            # Decode only the newly generated tokens, skipping the echoed prompt.
            output = self.tokenizer.decode(
                output_ids[0, inputs["input_ids"].size(1):], skip_special_tokens=True
            )
            outputs.append(output)

        return outputs[0] if single else outputs
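
# A minimal usage sketch (assumptions: a Hugging Face `transformers` model/tokenizer
# pair; the checkpoint name below is illustrative, not part of this module):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
#     model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
#     agent = SingleSampleAgent(model, tokenizer, max_new_tokens=64)
#     reply = agent([{"role": "user", "content": "Hello!"}])    # single conversation -> str
#     replies = agent([
#         [{"role": "user", "content": "Hi"}],
#         [{"role": "user", "content": "Bye"}],
#     ])                                                        # batch -> list[str]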


class TrainerModelAgent:
    """Simple agent that generates text using a model and tokenizer."""

    def __init__(self, model, tokenizer, **gen_kwargs):
        self.model = model
        self.tokenizer = tokenizer
        self.gen_kwargs = gen_kwargs
        # Decoder-only tokenizers often ship without a pad token; fall back to EOS.
        if getattr(self.tokenizer, "pad_token", None) is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, messages):
        """Generate a response for one or many conversations."""
        if messages and isinstance(messages[0], dict):
            conversations = [messages]
            single = True
        else:
            conversations = messages
            single = False

        prompts = [
            self.tokenizer.apply_chat_template(c, tokenize=False, add_generation_prompt=True)
            for c in conversations
        ]
        # Left-pad so every prompt ends at the same position: decoder-only models
        # continue from the final token, so right padding would corrupt generation.
        prev_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        inputs = self.tokenizer(prompts, return_tensors="pt", padding=True).to(self.model.device)
        output_ids = self.model.generate(**inputs, **self.gen_kwargs)

        # With left padding, every row shares the same prompt length, so slicing
        # off the prompt tokens works uniformly across the batch.
        outputs = self.tokenizer.batch_decode(
            output_ids[:, inputs["input_ids"].size(1):], skip_special_tokens=True
        )
        # Restore whatever padding side the tokenizer had before this call.
        self.tokenizer.padding_side = prev_padding_side
        return outputs[0] if single else outputs
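
# A minimal batched-usage sketch (assumptions as above: the checkpoint and the
# generation settings are illustrative, not part of this module):
#
#     agent = TrainerModelAgent(model, tokenizer, max_new_tokens=64, do_sample=False)
#     replies = agent([
#         [{"role": "user", "content": "Name a prime number."}],
#         [{"role": "user", "content": "Name an even number."}],
#     ])  # one decoded reply per conversation, generated in a single padded batch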