import logging
from typing import Dict, List, Optional, Tuple

import torch
from transformers import GenerationConfig

from larm.data.interactions.base_interaction import (
    InteractionConfig,
    InteractionManager,
    InteractionDataProto,
)


class SingleTurnInteractionManager(InteractionManager):
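    """Interaction manager that performs a single generation turn per batch:
    one call to the rollout worker, EOS trimming, then final output assembly
    (no intermediate tool/observation steps)."""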

    def __init__(
        self,
        tokenizer,
        actor_rollout_wg,
        config: InteractionConfig,
        is_validation: bool = False,
    ):
        super().__init__(tokenizer, actor_rollout_wg, config, is_validation)

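        # Resolve the EOS token for generation: prefer the chat end-of-turn
        # token <|im_end|> (used by Qwen-style chat templates) when it encodes
        # to a single id; otherwise fall back to the tokenizer's default EOS.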
        try:
            im_end_ids = self.tokenizer.encode("<|im_end|>", add_special_tokens=False)
            if isinstance(im_end_ids, list) and len(im_end_ids) == 1:
                eos_id = im_end_ids[0]
            else:
                eos_id = self.tokenizer.eos_token_id
        except Exception:
            eos_id = self.tokenizer.eos_token_id

        # Expose the resolved chat EOS so the getattr lookup of
        # "chat_eos_token_id" in run_agent_loop finds it instead of silently
        # falling back to the tokenizer default.
        self.chat_eos_token_id = eos_id

        self.generation_config = GenerationConfig(
            do_sample=self.config.do_sample,
            max_new_tokens=self.config.max_response_length,
            temperature=self.config.temperature,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=eos_id,
        )

    def _batch_tokenize(self, responses: List[str]) -> torch.Tensor:
        """Tokenize a batch of responses."""
        return self.tokenizer(
            responses,
            add_special_tokens=False,
            return_tensors='pt',
            padding="longest",
        )['input_ids']

    def _info_masked_concatenate_with_padding(
        self,
        prompt: torch.Tensor,
        prompt_with_mask: torch.Tensor,
        response: torch.Tensor,
        info: Optional[torch.Tensor] = None,
        pad_to_left: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Concatenate prompt and response (plus an optional info block) and
        re-pack padding to one side. Returns both the packed ids and a variant
        in which the info block is replaced by pad ids, so the information
        tokens can be masked out downstream."""
        pad_id = self.tokenizer.pad_token_id
        tensors = [prompt, response]
        tensors_with_mask = [prompt_with_mask, response]
        if info is not None:
            tensors.append(info)
            # Pad-id-filled stand-in for the info block in the masked variant.
            info_mask = torch.full(info.size(), pad_id, dtype=info.dtype, device=info.device)
            tensors_with_mask.append(info_mask)

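        # A stable argsort on the pad mask re-packs each row so pad tokens move
        # to one side while the relative order of real tokens is preserved;
        # gather applies the same permutation to both variants so they stay
        # aligned position by position.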
        concatenated = torch.cat(tensors, dim=1)
        concatenated_with_info = torch.cat(tensors_with_mask, dim=1)
        mask = concatenated != pad_id if pad_to_left else concatenated == pad_id
        sorted_indices = mask.to(torch.int64).argsort(dim=1, stable=True)
        padded_tensor = concatenated.gather(1, sorted_indices)
        padded_tensor_with_info = concatenated_with_info.gather(1, sorted_indices)

        return padded_tensor, padded_tensor_with_info

    def _update_right_side(
        self,
        right_side: Dict,
        cur_responses: torch.Tensor,
        next_obs_ids: Optional[torch.Tensor] = None,
    ) -> Dict:
        """Append the current responses (and optional observation ids) to the
        right-side state, then truncate to the effective length."""
        # _info_masked_concatenate_with_padding accepts info=None, so a single
        # call covers both the with- and without-observation cases.
        responses, responses_with_info_mask = self._info_masked_concatenate_with_padding(
            right_side['responses'],
            right_side['responses_with_info_mask'],
            cur_responses,
            next_obs_ids,
            pad_to_left=False,
        )
        effective_len = self.tensor_fn.create_attention_mask(responses).sum(dim=1).max()
        max_len = min(self.config.max_prompt_length, int(effective_len))

        return {
            'responses': responses[:, :max_len],
            'responses_with_info_mask': responses_with_info_mask[:, :max_len],
        }

    def _log_generation_input(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, sample_idx: int = 0):
        """Log the actual input tokens passed to model.generate()."""
        try:
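            # Known Qwen2-VL-style vision special-token ids, extended below
            # with whatever ids this tokenizer actually assigns to the named
            # vision tokens.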
            vision_token_ids = [151652, 151653, 151654, 151655]
            vision_token_names = ["<|vision_start|>", "<|vision_end|>", "<|image_pad|>", "<|video_pad|>", "<|vision_pad|>"]
            for token_name in vision_token_names:
                try:
                    token_id = self.tokenizer.encode(token_name, add_special_tokens=False)
                    if isinstance(token_id, list) and len(token_id) > 0:
                        if token_id[0] not in vision_token_ids:
                            vision_token_ids.append(token_id[0])
                except Exception:
                    pass

            # Restrict to the tokens that are actually attended to.
            sample_ids = input_ids[sample_idx]
            sample_mask = attention_mask[sample_idx]
            valid_tokens = sample_ids[sample_mask.bool()].tolist()

            vision_tokens_present = set(valid_tokens) & set(vision_token_ids)
            has_vision = len(vision_tokens_present) > 0

logging.info("=" * 80) |
|
|
logging.info(f"[MODEL.GENERATE INPUT] Sample {sample_idx}") |
|
|
logging.info(f"Input length: {len(valid_tokens)} tokens") |
|
|
logging.info(f"Batch shape: {input_ids.shape}") |
|
|
|
|
|
if has_vision: |
|
|
logging.info(f"✓ Contains vision tokens: {vision_tokens_present}") |
|
|
else: |
|
|
logging.info("ℹ️ No vision tokens (text-only)") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logging.info("=" * 80) |
|
|
except Exception as e: |
|
|
logging.warning(f"Failed to log generation input: {e}") |
|
|
|
|
|
    def run_agent_loop(self, gen_batch: InteractionDataProto) -> InteractionDataProto:
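        """Run one generation turn: prepare inputs (text-only or multimodal),
        call the rollout worker's generate(), trim after the first EOS, and
        compose the final interaction output."""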
        initial_input_ids = gen_batch.batch["input_ids"]
        has_pixels = "pixel_values" in gen_batch.batch

        if has_pixels:
            # Multimodal prompts are kept intact: truncating input_ids could
            # break the alignment between image placeholder tokens and
            # pixel_values.
            original_left_side = {'input_ids': initial_input_ids}
        else:
            original_left_side = {'input_ids': initial_input_ids[:, -self.config.max_start_length:]}
        # The right side starts as empty (batch, 0) tensors and grows as
        # responses are appended.
        original_right_side = {
            'responses': initial_input_ids[:, []],
            'responses_with_info_mask': initial_input_ids[:, []],
        }

        rollings = gen_batch
        if not has_pixels:
            # Text-only: trim the rolling state to the effective sequence
            # length before generation.
            rollings.batch = self.tensor_fn.cut_to_effective_len(
                rollings.batch,
                keys=['input_ids', 'attention_mask'],
            )
        rollings_active = {k: v for k, v in rollings.batch.items()}

        self._log_generation_input(rollings_active["input_ids"], rollings_active["attention_mask"])

if "image_paths" in gen_batch.no_tensor_batch: |
|
|
image_paths = gen_batch.no_tensor_batch["image_paths"] |
|
|
if any(path is not None for path in image_paths): |
|
|
logging.info(f"[Image Paths in Batch] {image_paths}") |
|
|
|
|
|
|
|
|
        gen_kwargs = {
            "input_ids": rollings_active["input_ids"],
            "attention_mask": rollings_active["attention_mask"],
            "generation_config": self.generation_config,
        }
        if "pixel_values" in rollings_active:
            gen_kwargs["pixel_values"] = rollings_active["pixel_values"]
        if "image_grid_thw" in rollings_active:
            gen_kwargs["image_grid_thw"] = rollings_active["image_grid_thw"]

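        # generate() returns prompt + completion; slice off the prompt so only
        # the newly generated tokens remain, then drop everything after the
        # first EOS.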
        gen_output = self.actor_rollout_wg.generate(**gen_kwargs)
        responses_ids = gen_output[:, rollings_active["input_ids"].size(1):]

        eos_id = getattr(self, "chat_eos_token_id", self.tokenizer.eos_token_id)
        responses_ids = self.tensor_fn.erase_after_first_eos(responses_ids, eos_id)

        original_right_side = self._update_right_side(original_right_side, responses_ids, next_obs_ids=None)

        return self._compose_final_output(original_left_side, original_right_side)

    def _compose_final_output(
        self,
        left_side: Dict,
        right_side: Dict,
    ) -> InteractionDataProto:
        """Compose the final generation output."""
        final_output_batch = right_side.copy()
        final_output_batch['prompts'] = left_side['input_ids']
        final_output_batch['responses'] = right_side['responses']

        # Full sequence = prompt followed by response.
        final_output_batch['input_ids'] = torch.cat([
            left_side['input_ids'],
            right_side['responses'],
        ], dim=1)

        final_output_batch['attention_mask'] = torch.cat([
            self.tensor_fn.create_attention_mask(left_side['input_ids']),
            self.tensor_fn.create_attention_mask(final_output_batch['responses']),
        ], dim=1)

        # info_mask additionally zeroes out positions belonging to info
        # blocks, which were pad-filled in responses_with_info_mask.
        final_output_batch['info_mask'] = torch.cat([
            self.tensor_fn.create_attention_mask(left_side['input_ids']),
            self.tensor_fn.create_attention_mask(final_output_batch['responses_with_info_mask']),
        ], dim=1)

        return InteractionDataProto(batch=final_output_batch)
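

# ---------------------------------------------------------------------------
# Usage sketch (hypothetical wiring; the tokenizer name, worker group, and
# config fields below are illustrative assumptions, not values from this
# repo):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
#   config = InteractionConfig(...)  # do_sample, temperature, max_response_length, ...
#   manager = SingleTurnInteractionManager(tokenizer, actor_rollout_wg, config)
#   output = manager.run_agent_loop(gen_batch)  # gen_batch: InteractionDataProto
#   # output.batch holds 'prompts', 'responses', 'input_ids',
#   # 'attention_mask', and 'info_mask'.
# ---------------------------------------------------------------------------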