Spaces:
Paused
Paused
| from abc import ABC | |
| from src.core.engine import ModelEngine | |
class BaseAgent(ABC):
    """Text-generating agent bound to a ``ModelEngine`` and a role.

    ``generate`` asks the engine for the model/tokenizer pair belonging to
    ``role`` on every call (presumably the engine caches loaded models —
    confirm in ``ModelEngine.load_model``).
    """

    def __init__(self, engine: "ModelEngine", role: str):
        """Store the engine and the role used to select a model.

        Args:
            engine: Supplies ``load_model(role)`` and ``config.generation``.
            role: Key passed to ``engine.load_model``.
        """
        self.engine = engine
        self.role = role

    def generate(self, prompt: str, **kwargs) -> str:
        """Generate a completion for ``prompt`` with this agent's model.

        Generation options are layered: ``engine.config.generation`` first,
        then per-call ``kwargs``. The tokenizer's pad/EOS ids are used only
        when neither layer sets ``pad_token_id`` / ``eos_token_id``.

        Args:
            prompt: Input text to tokenize and feed to the model.
            **kwargs: Extra keyword arguments for ``model.generate``; they
                override the engine's generation defaults.

        Returns:
            The decoded generated text (prompt excluded), with special tokens
            removed and surrounding whitespace stripped.
        """
        asset = self.engine.load_model(self.role)
        model, tokenizer = asset["model"], asset["tokenizer"]

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Copy so per-call overrides never mutate the shared engine config.
        gen_kwargs = self.engine.config.generation.copy()
        gen_kwargs.update(kwargs)

        # Passing these ids as explicit keywords next to **gen_kwargs raised
        # "got multiple values for keyword argument" whenever the config or
        # the caller also set them; treat the tokenizer's ids as defaults.
        eos_id = tokenizer.eos_token_id
        pad_id = tokenizer.pad_token_id
        gen_kwargs.setdefault("eos_token_id", eos_id)
        # Many causal-LM tokenizers define no pad token; fall back to EOS,
        # mirroring transformers' own fallback when pad_token_id is None.
        gen_kwargs.setdefault(
            "pad_token_id", pad_id if pad_id is not None else eos_id
        )

        outputs = model.generate(**inputs, **gen_kwargs)

        # Decoder-only models return prompt + continuation, so drop the prompt
        # tokens; encoder-decoder models return only newly generated tokens.
        model_config = getattr(model, "config", None)
        is_encoder_decoder = getattr(model_config, "is_encoder_decoder", False)
        prompt_len = 0 if is_encoder_decoder else inputs["input_ids"].shape[1]
        return tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        ).strip()