File size: 1,940 Bytes
e34b94f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60

from typing import Dict
from dataclasses import dataclass, field
from abc import ABC, abstractmethod

from larm.data.utils.tensor_utils import TensorHelper, TensorConfig


@dataclass
class InteractionConfig:
    """Settings bundle handed to ``InteractionManager``.

    All fields are required and positional, in declaration order.  Within
    this module only the three ``max_*_length`` fields noted below are read
    (they are forwarded to ``TensorConfig``); the remaining fields are
    carried for other consumers of the config.
    """

    # Not read in this module; presumably the cap on interaction rounds —
    # TODO confirm against the concrete manager implementations.
    max_turns: int

    # Forwarded to TensorConfig(max_start_length=...).
    max_start_length: int

    # Forwarded to TensorConfig(max_prompt_length=...).
    max_prompt_length: int

    # Not read in this module; presumably the per-generation token budget —
    # TODO confirm.
    max_response_length: int

    # Forwarded to TensorConfig(max_obs_length=...).
    max_obs_length: int

    # Not read in this module; presumably sampling controls passed on to the
    # generation backend — TODO confirm.
    do_sample: bool
    temperature: float

@dataclass
class InteractionDataProto:
    """Container with two dictionaries, exchanged with ``run_agent_loop``.

    Both fields default to a fresh empty ``dict`` per instance, so separate
    instances never share state through their defaults.
    """

    # Judging by the field names, tensor-valued entries live here and
    # everything else in ``no_tensor_batch`` — TODO confirm with the callers.
    batch: Dict = field(default_factory=dict)

    no_tensor_batch: Dict = field(default_factory=dict)

class InteractionManager(ABC):
    """Abstract base class that prepares shared state for interaction loops.

    The constructor normalises the tokenizer, builds a ``TensorHelper`` from
    the length limits in ``config`` and selects the token id used as
    end-of-sequence for chat decoding.  Subclasses provide the loop itself by
    implementing :meth:`run_agent_loop`.

    Attributes set by ``__init__``:
        tokenizer: The underlying tokenizer, with ``padding_side`` set to
            ``"left"``.
        actor_rollout_wg: Stored exactly as passed in.
        config: The ``InteractionConfig`` passed in.
        is_validation: The flag passed in.
        tensor_fn: ``TensorHelper`` built from the pad token id and the
            ``max_prompt_length`` / ``max_obs_length`` / ``max_start_length``
            limits of ``config``.
        chat_eos_token_id: Id of ``<|im_end|>`` when it encodes to exactly
            one token, otherwise the tokenizer's ``eos_token_id``.
    """

    # Chat end-of-turn marker preferred over the tokenizer's own EOS token.
    _CHAT_END_TOKEN = "<|im_end|>"

    def __init__(
        self,
        tokenizer,
        actor_rollout_wg,
        config: InteractionConfig,
        is_validation: bool = False,
    ):
        """Store collaborators and derive tokenizer-dependent helpers.

        Args:
            tokenizer: Either a processor exposing the real tokenizer as
                ``.tokenizer`` or a tokenizer itself.  Note that the
                underlying tokenizer object is mutated: its ``padding_side``
                is set to ``"left"``.
            actor_rollout_wg: Kept on the instance unchanged; not used by
                this base class.
            config: Supplies the length limits forwarded to ``TensorConfig``.
            is_validation: Kept on the instance unchanged.

        Raises:
            ValueError: If the tokenizer has no pad token id.
        """
        tokenizer = self._unwrap_tokenizer(tokenizer)

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        # Done before touching the tokenizer so a rejected one is not mutated.
        if tokenizer.pad_token_id is None:
            raise ValueError(
                "tokenizer.pad_token_id must be set before constructing an "
                "InteractionManager"
            )

        tokenizer.padding_side = "left"
        self.tokenizer = tokenizer
        self.actor_rollout_wg = actor_rollout_wg
        self.config = config
        self.is_validation = is_validation

        self.tensor_fn = TensorHelper(TensorConfig(
            pad_token_id=tokenizer.pad_token_id,
            max_prompt_length=config.max_prompt_length,
            max_obs_length=config.max_obs_length,
            max_start_length=config.max_start_length,
        ))

        # Prefer chat end token (<|im_end|>) as EOS for decoding termination
        # if available.
        self.chat_eos_token_id = self._resolve_chat_eos_token_id(tokenizer)

    @staticmethod
    def _unwrap_tokenizer(tokenizer):
        """Return ``tokenizer.tokenizer`` if present, else ``tokenizer`` itself."""
        # Processors wrap the real tokenizer; plain tokenizers pass through.
        return getattr(tokenizer, "tokenizer", tokenizer)

    @staticmethod
    def _resolve_chat_eos_token_id(tokenizer):
        """Return the single-token id of ``<|im_end|>`` or fall back to EOS."""
        try:
            im_end_ids = tokenizer.encode(
                InteractionManager._CHAT_END_TOKEN, add_special_tokens=False
            )
        except Exception:
            # Deliberate best-effort: any tokenizer failure means "no usable
            # chat end token", so the regular EOS id is used instead.
            return tokenizer.eos_token_id

        if isinstance(im_end_ids, list) and len(im_end_ids) == 1:
            return im_end_ids[0]
        # The marker is unknown or split into several tokens.
        return tokenizer.eos_token_id

    @abstractmethod
    def run_agent_loop(self, gen_batch: InteractionDataProto) -> InteractionDataProto:
        """Run the interaction loop on ``gen_batch``; implemented by subclasses."""
        ...