Spaces:
Paused
Paused
File size: 833 Bytes
2e91995 1804a7a 2e91995 1804a7a 2e91995 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
from abc import ABC
from src.core.engine import ModelEngine
class BaseAgent(ABC):
    """Base for role-bound agents that generate text through a shared engine.

    Each agent holds an engine and a ``role`` string. On every ``generate``
    call it asks the engine for the model/tokenizer pair for that role and
    runs generation on a single prompt.
    """

    def __init__(self, engine: "ModelEngine", role: str):
        """Store the engine and the role used to select a model.

        Args:
            engine: Object providing ``load_model(role)`` (returning a mapping
                with ``"model"`` and ``"tokenizer"``) and
                ``config.generation`` (default keyword arguments for
                ``model.generate``).
            role: Key passed to ``engine.load_model``.
        """
        self.engine = engine
        self.role = role

    def generate(self, prompt: str, **kwargs) -> str:
        """Generate a continuation of ``prompt`` and return only the new text.

        Args:
            prompt: Input text, tokenized as a single sequence.
            **kwargs: Extra arguments for ``model.generate``. They take
                precedence over ``engine.config.generation``, which in turn
                takes precedence over the tokenizer-derived ``pad_token_id``
                and ``eos_token_id`` defaults.

        Returns:
            The decoded tokens that follow the prompt, with special tokens
            skipped and surrounding whitespace stripped.
        """
        asset = self.engine.load_model(self.role)
        model, tokenizer = asset["model"], asset["tokenizer"]

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Tokenizers without a pad token report None here; use EOS as the pad
        # id explicitly (transformers falls back the same way, but warns).
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        # Build a single kwargs dict, lowest to highest precedence. Passing
        # pad/eos as separate keywords alongside **gen_kwargs would raise
        # "got multiple values for keyword argument" whenever the config or
        # the caller also sets them.
        gen_kwargs = {
            "pad_token_id": pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }
        gen_kwargs.update(self.engine.config.generation)
        gen_kwargs.update(kwargs)

        outputs = model.generate(**inputs, **gen_kwargs)

        # Keep only the tokens generated after the prompt in the first
        # sequence.
        # NOTE(review): assumes a decoder-only model whose output begins with
        # the prompt tokens; encoder-decoder outputs do not — confirm.
        prompt_len = inputs["input_ids"].shape[1]
        return tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        ).strip()
|