File size: 1,940 Bytes
e34b94f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
from typing import Dict
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from larm.data.utils.tensor_utils import TensorHelper, TensorConfig
@dataclass
class InteractionConfig:
    """Settings handed to an ``InteractionManager``.

    ``max_start_length``, ``max_prompt_length`` and ``max_obs_length`` are
    forwarded to ``TensorConfig`` by ``InteractionManager.__init__``. The other
    fields are not read by the code in this module and are presumably consumed
    by subclasses or the rollout worker (confirm against callers).
    """

    # Presumably the turn budget for the agent loop; not read in this module.
    max_turns: int

    # Length limits (presumably in tokens). start/prompt/obs are forwarded to
    # TensorConfig; the response limit is not forwarded in this module.
    max_start_length: int
    max_prompt_length: int
    max_response_length: int
    max_obs_length: int

    # Presumably generation settings (sampling on/off and sampling temperature);
    # not read in this module.
    do_sample: bool
    temperature: float
@dataclass
class InteractionDataProto:
    """Data container taken and returned by ``InteractionManager.run_agent_loop``.

    Each field defaults to a fresh empty dict per instance, so instances never
    share their mappings.
    """

    # Batch payload; presumably tensor-valued entries given the sibling
    # field's name (TODO confirm against producers).
    batch: Dict = field(default_factory=dict)
    # Payload entries kept apart from ``batch``; by name, non-tensor data
    # (TODO confirm expected contents).
    no_tensor_batch: Dict = field(default_factory=dict)
class InteractionManager(ABC):
    """Abstract base holding shared state for an agent interaction loop.

    Subclasses implement :meth:`run_agent_loop`, which maps an input
    ``InteractionDataProto`` to an output one. The constructor:

    * stores the tokenizer, the rollout worker group and the config;
    * builds a ``TensorHelper`` from the config's length limits;
    * resolves the token id to use as the chat end-of-sequence marker.
    """

    def __init__(
        self,
        tokenizer,
        actor_rollout_wg,
        config: "InteractionConfig",
        is_validation: bool = False,
    ):
        """Initialize shared interaction state.

        Args:
            tokenizer: Either a processor-like wrapper exposing a
                ``.tokenizer`` attribute (which is unwrapped) or a tokenizer
                passed directly. The resulting tokenizer must define a
                non-None ``pad_token_id``.
            actor_rollout_wg: Stored as ``self.actor_rollout_wg``; not
                otherwise used in this base class.
            config: Interaction settings. ``max_prompt_length``,
                ``max_obs_length`` and ``max_start_length`` are forwarded to
                ``TensorConfig``.
            is_validation: Stored as ``self.is_validation``.

        Raises:
            ValueError: If the (unwrapped) tokenizer's ``pad_token_id`` is None.
        """
        tokenizer = self._unwrap_tokenizer(tokenizer)
        # Validate before touching the tokenizer so a rejected one is left
        # unmodified; raise instead of assert because asserts vanish under -O.
        if tokenizer.pad_token_id is None:
            raise ValueError("tokenizer.pad_token_id must not be None")

        self.tokenizer = tokenizer
        # NOTE: this mutates the caller's tokenizer object in place, so any
        # other holder of the same object also sees left padding.
        self.tokenizer.padding_side = "left"
        self.actor_rollout_wg = actor_rollout_wg
        self.config = config
        self.is_validation = is_validation

        self.tensor_fn = TensorHelper(TensorConfig(
            pad_token_id=tokenizer.pad_token_id,
            max_prompt_length=config.max_prompt_length,
            max_obs_length=config.max_obs_length,
            max_start_length=config.max_start_length,
        ))

        # Prefer the chat end token (<|im_end|>) as EOS for decoding
        # termination when the tokenizer knows it as a single token.
        self.chat_eos_token_id = self._resolve_chat_eos_token_id(self.tokenizer)

    @staticmethod
    def _unwrap_tokenizer(tokenizer):
        """Return ``tokenizer.tokenizer`` if that attribute exists, else ``tokenizer``.

        Callers may pass a processor wrapping the tokenizer. A bare tokenizer
        without a ``.tokenizer`` attribute is used as-is instead of raising
        ``AttributeError``.
        """
        return getattr(tokenizer, "tokenizer", tokenizer)

    @staticmethod
    def _resolve_chat_eos_token_id(tokenizer):
        """Return the id of ``<|im_end|>`` when it encodes to one known token.

        Falls back to ``tokenizer.eos_token_id`` in three cases:

        * the string encodes to zero or several ids;
        * it encodes to the tokenizer's unknown-token id;
        * encoding raises.
        """
        try:
            ids = tokenizer.encode("<|im_end|>", add_special_tokens=False)
        except Exception:
            # Deliberate best-effort lookup: any tokenizer failure just means
            # "use the regular EOS token".
            return tokenizer.eos_token_id
        if isinstance(ids, list) and len(ids) == 1:
            # A lone unknown-token id means the marker is not in the
            # vocabulary, so it cannot serve as a real end-of-turn token.
            # (unk_token_id may be absent or None, e.g. for byte-level BPE.)
            unk_id = getattr(tokenizer, "unk_token_id", None)
            if unk_id is None or ids[0] != unk_id:
                return ids[0]
        return tokenizer.eos_token_id

    @abstractmethod
    def run_agent_loop(self, gen_batch: "InteractionDataProto") -> "InteractionDataProto":
        """Run the interaction loop on ``gen_batch`` and return the result.

        Must be implemented by subclasses.
        """
        ...