| | import torch |
| | import re |
| | from collections import defaultdict |
| | import os |
| | from typing import List, Dict, Any, Tuple |
| | from dataclasses import dataclass |
| | from .tensor_helper import TensorHelper, TensorConfig |
| | from verl import DataProto |
| | from verl.utils.tracking import Tracking |
| | import shutil |
| | import requests |
| |
|
@dataclass
class GenerationConfig:
    """Settings consumed by ``LLMGenerationManager`` during multi-turn rollout."""

    # Upper bound on the number of generate/search iterations in run_llm_loop.
    max_turns: int
    # The initial prompt is cut to its last `max_start_length` columns.
    max_start_length: int
    # Cap on the rolling context width (also applied to the accumulated
    # response side in _update_right_side).
    max_prompt_length: int
    # Not read by the code visible in this file; presumably consumed by the
    # rollout worker -- TODO confirm.
    max_response_length: int
    # Tokenized observations wider than this are truncated.
    max_obs_length: int
    # Active batches are padded to a multiple of this before generation.
    num_gpus: int
    # When true, response post-processing raises ValueError.
    no_think_rl: bool = False
    # Endpoint that receives the search POST request.
    search_url: str = None
    # Sent as "topk" in the search request payload.
    topk: int = 3
| |
|
class LLMGenerationManager:
    """Run multi-turn generation in which the model may issue search calls.

    On each turn the actor generates a response, the response is cut at the
    first ``</search>`` or ``</answer>`` tag, search queries are posted to
    ``config.search_url`` and the formatted passages are appended to the
    context inside an ``<information>`` block.
    """

    # First well-formed <search>...</search> or <answer>...</answer> pair.
    _ACTION_PATTERN = re.compile(r'<(search|answer)>(.*?)</\1>', re.DOTALL)

    # Closing tags at which a response is truncated; '</search>' is checked
    # first, even when '</answer>' occurs earlier in the text.
    _STOP_TAGS = ('</search>', '</answer>')

    # Observation shown to the model when no valid action could be parsed.
    _INVALID_ACTION_OBS = (
        '\nMy previous action is invalid. '
        'If I want to search, I should put the query between <search> and </search>. '
        'If I want to give the final answer, I should put the answer between '
        '<answer> and </answer>. Let me try again.\n'
    )

    def __init__(
        self,
        tokenizer,
        actor_rollout_wg,
        config: "GenerationConfig",
        is_validation: bool = False,
    ):
        """Store collaborators and build the tensor helper.

        Args:
            tokenizer: Tokenizer used for encoding/decoding; must expose
                ``pad_token_id``, ``pad_token`` and ``batch_decode``.
            actor_rollout_wg: Object exposing ``generate_sequences``.
            config: Generation settings.
            is_validation: Stored on the instance; not read in this class.
        """
        self.tokenizer = tokenizer
        self.actor_rollout_wg = actor_rollout_wg
        self.config = config
        self.is_validation = is_validation

        self.tensor_fn = TensorHelper(TensorConfig(
            pad_token_id=tokenizer.pad_token_id,
            max_prompt_length=config.max_prompt_length,
            max_obs_length=config.max_obs_length,
            max_start_length=config.max_start_length,
        ))

    def _batch_tokenize(self, responses: List[str]) -> torch.Tensor:
        """Tokenize a batch of strings (no special tokens, padded to longest)."""
        encoded = self.tokenizer(
            responses,
            add_special_tokens=False,
            return_tensors='pt',
            padding="longest",
        )
        return encoded['input_ids']

    @classmethod
    def _truncate_at_stop_tag(cls, text: str) -> str:
        """Cut ``text`` right after the first occurrence of a stop tag."""
        for tag in cls._STOP_TAGS:
            if tag in text:
                return text.split(tag)[0] + tag
        return text

    def _postprocess_responses(self, responses: torch.Tensor) -> Tuple[torch.Tensor, List[str]]:
        """Truncate generated responses at the search/answer closing tag.

        Args:
            responses: Generated token ids, one row per sequence.

        Returns:
            ``(token_ids, strings)`` for the truncated responses; the ids are
            obtained by re-tokenizing the truncated strings.

        Raises:
            ValueError: If ``config.no_think_rl`` is set (not supported).
        """
        if self.config.no_think_rl:
            # The former code path after this raise was unreachable and
            # referenced attributes that do not exist on this class.
            raise ValueError(
                'no_think_rl is not supported by LLMGenerationManager'
            )

        decoded = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
        responses_str = [self._truncate_at_stop_tag(text) for text in decoded]
        return self._batch_tokenize(responses_str), responses_str

    def _process_next_obs(self, next_obs: List[str]) -> torch.Tensor:
        """Tokenize observations and truncate them to ``max_obs_length`` columns."""
        next_obs_ids = self.tokenizer(
            next_obs,
            padding='longest',
            return_tensors='pt',
            add_special_tokens=False,
        )['input_ids']

        limit = self.config.max_obs_length
        if next_obs_ids.shape[1] > limit:
            print(f"[WARNING] OBSERVATION TOO LONG, CONSIDER CHANGING YOUR CONFIG, {next_obs_ids.shape[1]} & {self.config.max_obs_length}")
            next_obs_ids = next_obs_ids[:, :limit]

        return next_obs_ids

    def _update_rolling_state(self, rollings: "DataProto", cur_responses: torch.Tensor,
                              next_obs_ids: torch.Tensor) -> "DataProto":
        """Append the latest responses and observations to the rolling context.

        The result keeps at most ``max_prompt_length`` trailing columns and
        inherits ``rollings.meta_info``.
        """
        new_input_ids = self.tensor_fn.concatenate_with_padding([
            rollings.batch['input_ids'],
            cur_responses,
            next_obs_ids,
        ])
        new_attention_mask = self.tensor_fn.create_attention_mask(new_input_ids)
        new_position_ids = self.tensor_fn.create_position_ids(new_attention_mask)

        # Keep only as many trailing columns as the longest row needs,
        # bounded by the configured prompt length.
        effective_len = new_attention_mask.sum(dim=1).max()
        max_len = min(self.config.max_prompt_length, effective_len)

        new_rollings = DataProto.from_dict({
            'input_ids': new_input_ids[:, -max_len:],
            'position_ids': new_position_ids[:, -max_len:],
            'attention_mask': new_attention_mask[:, -max_len:],
        })
        new_rollings.meta_info.update(rollings.meta_info)
        return new_rollings

    def _info_masked_concatenate_with_padding(self,
                prompt: torch.Tensor,
                prompt_with_mask: torch.Tensor,
                response: torch.Tensor,
                info: torch.Tensor = None,
                pad_to_left: bool = True
            ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Concatenate along dim 1 and move pad tokens to one side.

        A second tensor is built in parallel in which the ``info`` block (if
        given) is replaced by pad ids, so that it can later be masked out.

        Args:
            prompt: Tokens accumulated so far.
            prompt_with_mask: Same layout as ``prompt`` with earlier info
                blocks already replaced by pad ids.
            response: Newly generated tokens.
            info: Optional observation tokens to append and mask.
            pad_to_left: If true pads end up on the left, otherwise on the right.

        Returns:
            ``(padded, padded_with_info_masked)``; both are reordered with
            the same permutation.
        """
        pad_id = self.tokenizer.pad_token_id
        tensors = [prompt, response]
        tensors_with_mask = [prompt_with_mask, response]
        if info is not None:
            tensors.append(info)
            tensors_with_mask.append(
                torch.full(info.size(), pad_id, dtype=info.dtype, device=info.device)
            )

        concatenated = torch.cat(tensors, dim=1)
        concatenated_with_info = torch.cat(tensors_with_mask, dim=1)

        # A stable argsort on a 0/1 key moves all pads to one side while
        # preserving the relative order of the remaining tokens.
        mask = concatenated != pad_id if pad_to_left else concatenated == pad_id
        sorted_indices = mask.to(torch.int64).argsort(dim=1, stable=True)
        padded_tensor = concatenated.gather(1, sorted_indices)
        padded_tensor_with_info = concatenated_with_info.gather(1, sorted_indices)

        return padded_tensor, padded_tensor_with_info

    def _update_right_side(self, right_side: Dict,
                           cur_responses: torch.Tensor,
                           next_obs_ids: torch.Tensor = None) -> Dict:
        """Append responses (and optional observations) to the response side.

        Returns:
            Dict with ``'responses'`` and ``'responses_with_info_mask'``,
            both right-padded and cut to at most ``max_prompt_length`` columns.
        """
        # The helper treats ``info=None`` as "nothing to append", so a single
        # call covers both the with- and without-observation cases.
        responses, responses_with_info_mask = self._info_masked_concatenate_with_padding(
            right_side['responses'],
            right_side['responses_with_info_mask'],
            cur_responses,
            next_obs_ids,
            pad_to_left=False,
        )
        effective_len = self.tensor_fn.create_attention_mask(responses).sum(dim=1).max()
        max_len = min(self.config.max_prompt_length, effective_len)

        return {
            'responses': responses[:, :max_len],
            'responses_with_info_mask': responses_with_info_mask[:, :max_len],
        }

    def _generate_with_gpu_padding(self, active_batch: "DataProto") -> "DataProto":
        """Generate sequences, padding the batch to a multiple of ``num_gpus``.

        With ``num_gpus <= 1`` the batch is forwarded unchanged. Otherwise
        all tensors are cast to ``long`` (in place on ``active_batch``); if
        the batch size is not divisible by ``num_gpus`` it is padded with
        copies of the first row, and the padded rows are removed from the
        output again.
        """
        num_gpus = self.config.num_gpus
        if num_gpus <= 1:
            return self.actor_rollout_wg.generate_sequences(active_batch)

        batch_size = active_batch.batch['input_ids'].shape[0]
        remainder = batch_size % num_gpus

        for key in active_batch.batch.keys():
            active_batch.batch[key] = active_batch.batch[key].long()
        if remainder == 0:
            return self.actor_rollout_wg.generate_sequences(active_batch)

        padding_size = num_gpus - remainder
        padded_batch = {}
        for k, v in active_batch.batch.items():
            # Repeat the first row along dim 0 only.
            pad_sequence = v[0:1].repeat(padding_size, *[1] * (len(v.shape) - 1))
            padded_batch[k] = torch.cat([v, pad_sequence], dim=0)

        padded_active_batch = DataProto.from_dict(padded_batch)
        for key in padded_active_batch.batch.keys():
            padded_active_batch.batch[key] = padded_active_batch.batch[key].long()

        padded_output = self.actor_rollout_wg.generate_sequences(padded_active_batch)

        # Drop the rows that were only added for divisibility.
        trimmed_batch = {k: v[:-padding_size] for k, v in padded_output.batch.items()}

        if hasattr(padded_output, 'meta_info') and padded_output.meta_info:
            padded_output.meta_info = {
                k: (v[:-padding_size] if isinstance(v, torch.Tensor) else v)
                for k, v in padded_output.meta_info.items()
            }

        padded_output.batch = trimmed_batch
        return padded_output

    def _generate_for_active(self, rollings: "DataProto", active_mask: torch.Tensor):
        """Generate and post-process responses for the active rows only.

        ``rollings.batch`` is replaced by the result of
        ``cut_to_effective_len`` as a side effect.

        Returns:
            ``(responses_ids, responses_str, meta_info)`` where the first two
            have been passed through ``tensor_fn._example_level_pad``
            (presumably re-expanding them to the full batch -- confirm in
            TensorHelper) and ``meta_info`` is the generation output's.
        """
        rollings.batch = self.tensor_fn.cut_to_effective_len(
            rollings.batch,
            keys=['input_ids', 'attention_mask', 'position_ids'],
        )
        rollings_active = DataProto.from_dict({
            k: v[active_mask] for k, v in rollings.batch.items()
        })
        gen_output = self._generate_with_gpu_padding(rollings_active)

        responses_ids, responses_str = self._postprocess_responses(gen_output.batch['responses'])
        responses_ids, responses_str = self.tensor_fn._example_level_pad(
            responses_ids, responses_str, active_mask
        )
        return responses_ids, responses_str, gen_output.meta_info

    def run_llm_loop(self, gen_batch, initial_input_ids: torch.Tensor) -> "DataProto":
        """Run the main generate -> act -> observe loop.

        Up to ``config.max_turns`` turns are executed with search enabled,
        followed by one final generation (search disabled) for rows that are
        still active.

        Args:
            gen_batch: Initial batch holding ``input_ids``,
                ``attention_mask`` and ``position_ids``.
            initial_input_ids: Prompt ids; the last ``max_start_length``
                columns become the ``prompts`` of the output.

        Returns:
            The composed output, whose ``meta_info`` additionally carries
            ``turns_stats``, ``active_mask``, ``valid_action_stats`` and
            ``valid_search_stats``.
        """
        batch_size = gen_batch.batch['input_ids'].shape[0]

        original_left_side = {'input_ids': initial_input_ids[:, -self.config.max_start_length:]}
        # Indexing with an empty list yields zero-width tensors to append to.
        original_right_side = {
            'responses': initial_input_ids[:, []],
            'responses_with_info_mask': initial_input_ids[:, []],
        }

        active_mask = torch.ones(batch_size, dtype=torch.bool)
        turns_stats = torch.ones(batch_size, dtype=torch.int)
        valid_action_stats = torch.zeros(batch_size, dtype=torch.int)
        valid_search_stats = torch.zeros(batch_size, dtype=torch.int)
        active_num_list = [active_mask.sum().item()]
        rollings = gen_batch
        # Stays empty only if no generation happens at all (e.g. empty batch).
        meta_info = {}

        for _ in range(self.config.max_turns):
            if not active_mask.sum():
                break

            responses_ids, responses_str, meta_info = self._generate_for_active(
                rollings, active_mask
            )

            next_obs, dones, valid_action, is_search = self.execute_predictions(
                responses_str, self.tokenizer.pad_token, active_mask
            )

            curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
            active_mask = active_mask * curr_active_mask
            active_num_list.append(active_mask.sum().item())
            turns_stats[curr_active_mask] += 1
            valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
            valid_search_stats += torch.tensor(is_search, dtype=torch.int)

            next_obs_ids = self._process_next_obs(next_obs)

            rollings = self._update_rolling_state(rollings, responses_ids, next_obs_ids)
            original_right_side = self._update_right_side(
                original_right_side, responses_ids, next_obs_ids
            )

        # Final rollout: rows that are still active get one more generation,
        # but no search is executed and no observation is appended.
        if active_mask.sum():
            responses_ids, responses_str, meta_info = self._generate_for_active(
                rollings, active_mask
            )

            _, dones, valid_action, is_search = self.execute_predictions(
                responses_str, self.tokenizer.pad_token, active_mask, do_search=False
            )

            curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
            active_mask = active_mask * curr_active_mask
            active_num_list.append(active_mask.sum().item())
            valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
            valid_search_stats += torch.tensor(is_search, dtype=torch.int)

            original_right_side = self._update_right_side(
                original_right_side, responses_ids
            )

        meta_info['turns_stats'] = turns_stats.tolist()
        meta_info['active_mask'] = active_mask.tolist()
        meta_info['valid_action_stats'] = valid_action_stats.tolist()
        meta_info['valid_search_stats'] = valid_search_stats.tolist()

        print("ACTIVE_TRAJ_NUM:", active_num_list)

        return self._compose_final_output(original_left_side, original_right_side, meta_info)

    def _compose_final_output(self, left_side: Dict,
                              right_side: Dict,
                              meta_info: Dict) -> "DataProto":
        """Assemble prompts, responses, masks and position ids into one output.

        Returns:
            A ``DataProto`` holding the entries of ``right_side`` plus
            ``prompts``, ``input_ids``, ``attention_mask``, ``info_mask`` and
            ``position_ids``; ``meta_info`` is merged into its meta info.
        """
        prompt_ids = left_side['input_ids']

        final_output = right_side.copy()
        final_output['prompts'] = prompt_ids
        final_output['input_ids'] = torch.cat([prompt_ids, right_side['responses']], dim=1)

        prompt_mask = self.tensor_fn.create_attention_mask(prompt_ids)
        final_output['attention_mask'] = torch.cat([
            prompt_mask,
            self.tensor_fn.create_attention_mask(final_output['responses']),
        ], dim=1)
        # Same as attention_mask except that it is derived from the response
        # tensor in which information blocks were replaced by pad ids.
        final_output['info_mask'] = torch.cat([
            prompt_mask,
            self.tensor_fn.create_attention_mask(final_output['responses_with_info_mask']),
        ], dim=1)

        final_output['position_ids'] = self.tensor_fn.create_position_ids(
            final_output['attention_mask']
        )

        final_output = DataProto.from_dict(final_output)
        final_output.meta_info.update(meta_info)

        return final_output

    def execute_predictions(self, predictions: List[str], pad_token: str, active_mask=None,
                            do_search=True) -> Tuple[List[str], List[int], List[int], List[int]]:
        """Turn model predictions into observations and per-row flags.

        This acts as the environment ``step``: answers finish a row, searches
        yield an ``<information>`` observation, anything else yields a
        corrective message.

        Args:
            predictions: One response string per row.
            pad_token: Unused here; kept for interface compatibility.
            active_mask: Per-row flags; inactive rows are reported as done
                with an empty observation. ``None`` means all rows are active.
            do_search: If false, no search request is sent and search
                actions receive an empty information block.

        Returns:
            ``(next_obs, dones, valid_action, is_search)``: four lists with
            one entry per row; the last three contain 0/1 integers.
        """
        cur_actions, contents = self.postprocess_predictions(predictions)

        if active_mask is None:
            active_flags = [True] * len(cur_actions)
        else:
            active_flags = [bool(flag) for flag in active_mask]

        # Only active rows consume a search result, so only they may
        # contribute a query; otherwise results and rows would misalign.
        search_queries = [
            content
            for action, content, active in zip(cur_actions, contents, active_flags)
            if active and action == 'search'
        ]
        if do_search:
            search_results = self.batch_search(search_queries)
            assert len(search_results) == len(search_queries), (
                'search backend returned a different number of results than queries'
            )
        else:
            search_results = [''] * len(search_queries)
        results_iter = iter(search_results)

        next_obs, dones, valid_action, is_search = [], [], [], []
        for action, active in zip(cur_actions, active_flags):
            if not active:
                obs, done, valid, searched = '', 1, 0, 0
            elif action == 'answer':
                obs, done, valid, searched = '', 1, 1, 0
            elif action == 'search':
                obs = f'\n\n<information>{next(results_iter).strip()}</information>\n\n'
                done, valid, searched = 0, 1, 1
            else:
                obs, done, valid, searched = self._INVALID_ACTION_OBS, 0, 0, 0

            next_obs.append(obs)
            dones.append(done)
            valid_action.append(valid)
            is_search.append(searched)

        return next_obs, dones, valid_action, is_search

    def postprocess_predictions(self, predictions: List[Any]) -> Tuple[List[Any], List[str]]:
        """Extract the action and its content from each prediction.

        Args:
            predictions: Raw response strings.

        Returns:
            ``(actions, contents)``. ``actions[i]`` is ``'search'``,
            ``'answer'`` or ``None`` when no well-formed tag pair is found;
            ``contents[i]`` is the stripped text between the tags (``''``
            when there is no match).

        Raises:
            ValueError: If a prediction is not a string.
        """
        actions = []
        contents = []

        for prediction in predictions:
            if not isinstance(prediction, str):
                raise ValueError(f"Invalid prediction type: {type(prediction)}")

            match = self._ACTION_PATTERN.search(prediction)
            if match:
                actions.append(match.group(1))
                contents.append(match.group(2).strip())
            else:
                actions.append(None)
                contents.append('')

        return actions, contents

    def batch_search(self, queries: List[str] = None) -> List[str]:
        """Search all queries in one request and format the passages.

        Args:
            queries: Queries to send to the search endpoint.

        Returns:
            One formatted passage string per query. An empty or missing
            query list yields ``[]`` without contacting the endpoint.
        """
        if not queries:
            return []

        results = self._batch_search(queries)['result']
        return [self._passages2string(result) for result in results]

    def _batch_search(self, queries):
        """POST the queries to ``config.search_url`` and return the decoded JSON.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.
        """
        payload = {
            "queries": queries,
            "topk": self.config.topk,
            "return_scores": True,
        }
        response = requests.post(self.config.search_url, json=payload)
        # Fail with the HTTP status instead of an opaque JSON/Key error later.
        response.raise_for_status()
        return response.json()

    def _passages2string(self, retrieval_result):
        """Format retrieved documents as numbered ``Doc i(Title: ...) text`` lines.

        The first line of each document's ``contents`` is used as the title,
        the remainder as the text.
        """
        formatted = []
        for idx, doc_item in enumerate(retrieval_result, start=1):
            title, _, text = doc_item['document']['contents'].partition("\n")
            formatted.append(f"Doc {idx}(Title: {title}) {text}\n")
        return "".join(formatted)
| |
|