# larm/data/interactions/multiturn_interaction.py
import copy
from typing import Dict, List, Tuple

import torch
from transformers import GenerationConfig

from larm.data.interactions.base_interaction import (
    InteractionDataProto,
    InteractionConfig,
    InteractionManager,
)

class MultiTurnInteractionManager(InteractionManager):
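    """Multi-turn agent/environment interaction loop in chat format.

    Repeatedly generates an assistant turn with the rollout worker, steps each
    example's environment on the decoded action, and appends the resulting
    (assistant, user) message pair to that example's history until the env
    reports done or ``config.max_turns`` is reached.
    """
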
def __init__(
self,
tokenizer,
actor_rollout_wg,
config: InteractionConfig,
is_validation: bool = False,
):
super().__init__(
tokenizer, actor_rollout_wg, config, is_validation
)
# generation configs for agent
# Prefer chat end token (<|im_end|>) if available for EOS
try:
im_end_ids = self.tokenizer.encode("<|im_end|>", add_special_tokens=False)
if isinstance(im_end_ids, list) and len(im_end_ids) == 1:
eos_id = im_end_ids[0]
else:
eos_id = self.tokenizer.eos_token_id
except Exception:
eos_id = self.tokenizer.eos_token_id
self.generation_config = GenerationConfig(
do_sample=self.config.do_sample,
max_new_tokens=self.config.max_response_length,
temperature=self.config.temperature,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=eos_id
)

    def _batch_tokenize(self, responses: List[str]) -> torch.Tensor:
"""Tokenize a batch of responses."""
return self.tokenizer(
responses,
add_special_tokens=False,
return_tensors='pt',
padding="longest"
)['input_ids']

    def _build_chat_history(self, rollings: Dict) -> List[Dict]:
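        """Concatenate each example's initial prompt with its interaction history."""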
init_prompts = rollings.get("init_prompts")
if init_prompts is None:
            raise ValueError("rollings is missing 'init_prompts'")
inter_histories = rollings.get("inter_histories")
if inter_histories is None:
            raise ValueError("rollings is missing 'inter_histories'")
chat_histories: List[List[Dict]] = []
for init_prompt, inter_history in zip(init_prompts, inter_histories):
chat_histories.append(init_prompt + inter_history)
return chat_histories

    def _update_interaction_history(
        self, rollings: InteractionDataProto, responses: List[str],
        observations: List[str], active_mask: torch.Tensor,
    ) -> List[List[Dict]]:
        """Append this turn's (assistant, user) pair to each active example's history."""
        inter_histories = copy.deepcopy(rollings.no_tensor_batch.get("inter_histories"))
        assert len(inter_histories) == len(responses) == len(observations) == len(active_mask)
        for inter_history, response, observation, was_active in zip(
            inter_histories, responses, observations, active_mask
        ):
            if not was_active:
                continue  # example already finished; skip its empty placeholder turn
            inter_history.append({"role": "assistant", "content": response})
            inter_history.append({"role": "user", "content": observation})
        return inter_histories

    def _postprocess_responses(
        self, responses: torch.Tensor, envs: List
    ) -> Tuple[torch.Tensor, List[str]]:
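        """Decode responses, let each env normalize its action string, then re-tokenize."""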
responses_str = self.tokenizer.batch_decode(
responses,
skip_special_tokens=True
)
processed_responses_str = []
for r, env in zip(responses_str, envs):
processed_r = env.preprocess_action(r)
processed_responses_str.append(processed_r)
        responses_ids = self._batch_tokenize(processed_responses_str)
        return responses_ids, processed_responses_str

    def _example_level_pad(
self, responses_ids: torch.Tensor, responses_str: List[str], active_mask: torch.Tensor
) -> Tuple[torch.Tensor, List[str]]:
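        """Scatter active-example responses back into full-batch order.

        Finished examples receive pad tokens (ids) and empty strings.
        """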
assert active_mask.sum() == responses_ids.shape[0]
# Create masked responses tensor
batch_size = active_mask.shape[0]
seq_len = responses_ids.shape[1]
padded_responses = torch.full(
(batch_size, seq_len), self.tokenizer.pad_token_id,
dtype=responses_ids.dtype, device=responses_ids.device
)
padded_responses[active_mask] = responses_ids
# Create masked response strings
padded_responses_str = [""] * batch_size
s = 0
for i, is_active in enumerate(active_mask):
if is_active:
padded_responses_str[i] = responses_str[s]
s += 1
return padded_responses, padded_responses_str

    def run_agent_loop(self, gen_batch: InteractionDataProto) -> InteractionDataProto:
"""Run main LLM generation loop (conversation format)."""
assert "init_prompts" in gen_batch.no_tensor_batch
assert "envs" in gen_batch.no_tensor_batch
batch_size = len(gen_batch.no_tensor_batch["init_prompts"])
rollings = gen_batch
rollings.no_tensor_batch["inter_histories"] = [[] for _ in range(batch_size)]
active_mask = torch.ones(batch_size, dtype=torch.bool)
active_num_list = [active_mask.sum().item()]
for step in range(self.config.max_turns):
if not active_mask.sum():
break
mask_list = active_mask.tolist()
rollings_active = {
k: [item for item, keep in zip(v, mask_list) if keep]
for k, v in rollings.no_tensor_batch.items()
}
            # Apply the chat template and encode to input_ids / attention_mask
            # (left-padded so generation starts flush with the prompt end)
messages = self._build_chat_history(rollings_active)
self.tokenizer.padding_side = "left"
inputs = self.tokenizer.apply_chat_template(
messages, tokenize=True,
add_generation_prompt=True,
padding=True, return_tensors="pt", return_dict=True
)
# agent rollout
gen_output = self.actor_rollout_wg.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
generation_config=self.generation_config,
).to("cpu")
# postprocess
prompt_len = inputs["input_ids"].size(1)
responses = gen_output[:, prompt_len:]
responses = self.tensor_fn.erase_after_first_eos(responses, self.tokenizer.eos_token_id)
responses_ids, responses_str = self._postprocess_responses(responses, rollings_active["envs"])
all_responses_ids, all_responses_str = self._example_level_pad(responses_ids, responses_str, active_mask)
            next_obs, dones = self._execute_predictions(rollings, all_responses_str, active_mask)
            processed_obs = self._postprocess_observations(next_obs)
            # Record this turn for the examples that were active when it was
            # generated, then update the active mask from the env done flags.
            interaction_histories = self._update_interaction_history(
                rollings, all_responses_str, processed_obs, active_mask
            )
            rollings.no_tensor_batch["inter_histories"] = interaction_histories
            curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
            active_mask = active_mask * curr_active_mask
            active_num_list.append(active_mask.sum().item())
# build final outputs
final_outputs = self._build_final_outputs(rollings)
return final_outputs

    def _execute_predictions(
        self, rollings: InteractionDataProto, responses: List[str], active_mask: torch.Tensor
    ) -> Tuple[List[str], List[bool]]:
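        """Step each active env with its response; finished examples get ("", done=True)."""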
observations = []
dones = []
for response, env, is_active in zip(responses, rollings.no_tensor_batch["envs"], active_mask):
if is_active:
observation, _, done = env.step(response)
else:
observation = ""
done = True
observations.append(observation)
dones.append(done)
return observations, dones

    def _postprocess_observations(self, observations: List[str]) -> List[str]:
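        """Truncate observations that exceed ``config.max_obs_length`` tokens, marking the cut with an ellipsis."""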
self.tokenizer.padding_side = "right"
next_obs_ids = self._batch_tokenize(observations)
max_len = self.config.max_obs_length
if next_obs_ids.shape[1] > max_len:
print(f"[WARNING] OBSERVATION TOO LONG, CONSIDER CHANGING YOUR CONFIG, {next_obs_ids.shape[1]} & {max_len}")
extra_text = "..."
extra_ids = self.tokenizer.encode(
extra_text, add_special_tokens=False, return_tensors="pt"
).to(next_obs_ids.device)
extra_len = extra_ids.shape[1]
new_obs_ids = []
for row in next_obs_ids:
valid_len = (row != self.tokenizer.pad_token_id).sum().item()
if valid_len > max_len:
truncated = row[: max_len - extra_len]
new_row = torch.cat([truncated, extra_ids.squeeze(0)], dim=0)
else:
new_row = row[:max_len]
new_obs_ids.append(new_row.unsqueeze(0))
next_obs_ids = torch.cat(new_obs_ids, dim=0)
observations = self.tokenizer.batch_decode(next_obs_ids, skip_special_tokens=True)
return observations

    def _build_final_outputs(self, rollings: InteractionDataProto) -> InteractionDataProto:
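        """Assemble final outputs: left-padded prompts, right-padded full
        conversations as responses, their concatenation as input_ids and
        attention_mask, and an info_mask marking assistant-generated tokens."""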
init_prompts: List[List[Dict]] = rollings.no_tensor_batch["init_prompts"]
inter_histories: List[List[Dict]] = rollings.no_tensor_batch["inter_histories"]
output = InteractionDataProto()
output.no_tensor_batch["inter_histories"] = [
prompt + inter for prompt, inter in zip(init_prompts, inter_histories)
]
# ---------- prompts ----------
self.tokenizer.padding_side = "left"
prompt_ids = self.tokenizer.apply_chat_template(
init_prompts, tokenize=True,
add_generation_prompt=False,
padding=True, return_tensors="pt", return_dict=True
)
output.batch["prompts"] = prompt_ids["input_ids"]
prompt_attn_mask = prompt_ids["attention_mask"]
# ---------- responses ----------
self.tokenizer.padding_side = "right"
response_ids = self.tokenizer.apply_chat_template(
inter_histories,
tokenize=True,
padding=True,
return_assistant_tokens_mask=True,
add_generation_prompt=False,
return_tensors="pt", return_dict=True
)
output.batch["responses"] = response_ids["input_ids"]
response_attn_mask = response_ids["attention_mask"]
completion_info_mask = response_ids["assistant_masks"]
# ---------- input_ids ----------
output.batch["input_ids"] = torch.cat(
[prompt_ids["input_ids"], response_ids["input_ids"]], dim=1
)
output.batch["attention_mask"] = torch.cat(
[prompt_attn_mask, response_attn_mask], dim=1
)
# ---------- info_mask ----------
prompt_info_mask = torch.zeros(
prompt_ids["input_ids"].shape,
dtype=completion_info_mask.dtype,
device=completion_info_mask.device
)
output.batch["info_mask"] = torch.cat(
[prompt_info_mask, completion_info_mask], dim=1
)
self.tokenizer.padding_side = "left"
return output
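

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the library). It assumes a
# HuggingFace tokenizer, a rollout worker exposing `.generate(...)` as called
# above, and an env implementing `preprocess_action(str) -> str` and
# `step(str) -> (observation, reward, done)`. `my_rollout_wg`, `my_env`, and
# the InteractionConfig field values are hypothetical stand-ins.
#
# if __name__ == "__main__":
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
#     config = InteractionConfig(...)  # do_sample, temperature, max_response_length,
#                                      # max_turns, max_obs_length, ...
#     manager = MultiTurnInteractionManager(tokenizer, my_rollout_wg, config)
#
#     batch = InteractionDataProto()
#     batch.no_tensor_batch["init_prompts"] = [
#         [{"role": "user", "content": "Solve the task step by step."}]
#     ]
#     batch.no_tensor_batch["envs"] = [my_env]  # one env per example
#     out = manager.run_agent_loop(batch)
#     print(out.no_tensor_batch["inter_histories"][0])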