| | |
| | |
| | import math |
| | from torch.distributions import Normal |
| | from collections import defaultdict |
| | import torch |
| | from torch.nn.utils import clip_grad_norm_ |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.nn import CrossEntropyLoss |
| | from torch.utils.data import DataLoader |
| | from collections import namedtuple |
| | from transformers.models.gpt2 import GPT2LMHeadModel |
| |
|
| | from modules import grpo |
| | from modules.projector import LatentPolicy |
| | from modules.utils import get_position_ids_from_attention_mask |
| | import copy |
| |
|
# Result containers returned by the model wrappers in this module.
# Each namedtuple's typename matches the variable it is bound to, so that
# repr() and __name__ identify the right container.
Outputs = namedtuple(
    "Outputs",
    ["loss", "loss_explain_all", "inputs_embeds", "logits"],
)
# Fix: this container was previously created with the typename "Outputs",
# which made its repr indistinguishable from the plain Outputs tuple.
Outputs_withkl = namedtuple(
    "Outputs_withkl",
    ["loss", "loss_explain_all", "loss_kl", "inputs_embeds", "logits"],
)
Outputs_withmask = namedtuple(
    "Outputs_withmask",
    ["loss", "attention_mask", "loss_explain_all", "inputs_embeds", "logits"],
)
# Used by generate() as the latent-pass budget when padding forward calls
# under synced_gpus.
MAX_N_LATENT = 8
| |
|
| |
|
| | class CoconutGPT_Same_Word_Embedding(nn.Module): |
| | def __init__( |
| | self, |
| | base_causallm, |
| | expainable_llm, |
| | |
| | tokenizer, |
| | latent_token_id, |
| | start_latent_id, |
| | end_latent_id, |
| | eos_token_id, |
| | step_start_id, |
| | c_thought, |
| | configs, |
| | ): |
| |
|
| | super(CoconutGPT_Same_Word_Embedding, self).__init__() |
| | self.gen_forward_cnt = 0 |
| | self.base_causallm = base_causallm |
| | self.base_causallm.config.use_cache = True |
| | self.expainable_llm = expainable_llm |
| | |
| | self.tokenizer = tokenizer |
| | self.latent_token_id = latent_token_id |
| | self.eos_token_id = eos_token_id |
| | self.start_latent_id = start_latent_id |
| | self.end_latent_id = end_latent_id |
| | self.step_start_id = step_start_id |
| | self.c_thought = c_thought |
| | self.config = configs |
| |
|
| | if hasattr(self.config, "training_method"): |
| | if self.config.training_method == 'only_expainable_llm': |
| | for param in self.base_causallm.parameters(): |
| | param.requires_grad = False |
| | elif self.config.training_method == 'only_base_causallm': |
| | for param in self.expainable_llm.parameters(): |
| | param.requires_grad = False |
| | elif self.config.training_method == 'full': |
| | pass |
| | elif self.config.training_method == 'freeze_backbone': |
| | for param in self.base_causallm.parameters(): |
| | param.requires_grad = False |
| |
|
| | for param in self.expainable_llm.parameters(): |
| | param.requires_grad = False |
| | else: |
| | raise ValueError(f"not this training_method {self.config.training_method=}") |
| |
|
| | if isinstance(self.base_causallm, GPT2LMHeadModel): |
| | self.embedding = self.base_causallm.transformer.get_input_embeddings() |
| | print("is GPT") |
| | else: |
| | self.embedding = self.base_causallm.get_input_embeddings() |
| | print("is not GPT") |
| |
|
| | def forward(self, input_ids, attention_mask, labels, position_ids, **kwargs): |
| | logits = [] |
| | loss = 0.0 |
| | loss_explain_all = torch.tensor(0.0, device=input_ids.device) |
| | c_thought_num = 1 |
| | latent_indices = ( |
| | input_ids == self.latent_token_id |
| | ).nonzero() |
| |
|
| | latent_lists = [ |
| | [idx[1].item() for idx in latent_indices if idx[0] == i] |
| | for i in range(input_ids.shape[0]) |
| | ] |
| |
|
| | max_n_latents = max([len(l) for l in latent_lists]) |
| |
|
| | next_compute_range = (0, input_ids.shape[1]) |
| | inputs_embeds = self.embedding(input_ids) |
| |
|
| | if max_n_latents > 0: |
| | next_compute_range = (0, latent_indices[:, 1].min().item()) |
| | |
| |
|
| | kv_cache = None |
| |
|
| | for pass_idx in range(max_n_latents): |
| |
|
| | if kv_cache == None: |
| | |
| | outputs = self.base_causallm( |
| | inputs_embeds=inputs_embeds[ |
| | :, next_compute_range[0]: next_compute_range[1], : |
| | ], |
| | attention_mask=attention_mask[ |
| | :, next_compute_range[0]: next_compute_range[1] |
| | ], |
| | position_ids=position_ids[ |
| | :, next_compute_range[0]: next_compute_range[1] |
| | ], |
| | output_hidden_states=True, |
| | ) |
| | hidden_states_offset = 0 |
| |
|
| | else: |
| | |
| | past_key_values = [ |
| | ( |
| | k[:, :, : next_compute_range[0], :], |
| | v[:, :, : next_compute_range[0], :], |
| | ) |
| | for k, v in kv_cache |
| | ] |
| |
|
| | outputs = self.base_causallm( |
| | inputs_embeds=inputs_embeds[ |
| | :, next_compute_range[0]: next_compute_range[1], : |
| | ], |
| | attention_mask=attention_mask[:, : next_compute_range[1]], |
| | position_ids=position_ids[ |
| | :, next_compute_range[0]: next_compute_range[1] |
| | ], |
| | past_key_values=past_key_values, |
| | output_hidden_states=True, |
| | ) |
| |
|
| | hidden_states_offset = next_compute_range[0] |
| | |
| | |
| | |
| |
|
| | logits.append(outputs.logits) |
| |
|
| | next_compute_range = ( |
| | next_compute_range[1], |
| | ( |
| | input_ids.shape[1] |
| | if pass_idx + 1 >= max_n_latents |
| | else next_compute_range[1] + 1 |
| | ), |
| | ) |
| |
|
| | hidden_states = outputs.hidden_states[ |
| | -1 |
| | ] |
| | kv_cache = outputs.past_key_values |
| |
|
| | |
| |
|
| | |
| | filling_indices = [ |
| | (instance_idx, mask_list[pass_idx]) |
| | for instance_idx, mask_list in enumerate(latent_lists) |
| | if len(mask_list) > pass_idx |
| | ] |
| |
|
| | |
| | |
| | tensor_list = [ |
| | [ |
| | inputs_embeds[batch_idx, pos, :] |
| | for pos in range(inputs_embeds.shape[1]) |
| | ] |
| | for batch_idx in range(inputs_embeds.shape[0]) |
| | ] |
| |
|
| | |
| | for idx_pair in filling_indices: |
| | batch_idx, token_idx = idx_pair |
| |
|
| | |
| | tensor_list[batch_idx][token_idx] = hidden_states[ |
| | batch_idx, token_idx - 1 - hidden_states_offset, : |
| | ] |
| |
|
| | |
| | inputs_embeds = torch.stack( |
| | [ |
| | torch.stack(tensor_list[batch_idx]) |
| | for batch_idx in range(inputs_embeds.shape[0]) |
| | ] |
| | ) |
| |
|
| | |
| | outputs = self.base_causallm( |
| | inputs_embeds=inputs_embeds[ |
| | :, next_compute_range[0]: next_compute_range[1], : |
| | ], |
| | attention_mask=attention_mask[:, : next_compute_range[1]], |
| | position_ids=position_ids[:, next_compute_range[0]: next_compute_range[1]], |
| | past_key_values=( |
| | [ |
| | ( |
| | k[:, :, : next_compute_range[0], :], |
| | v[:, :, : next_compute_range[0], :], |
| | ) |
| | for k, v in kv_cache |
| | ] |
| | if kv_cache |
| | else None |
| | ), |
| | output_hidden_states=True, |
| | ) |
| |
|
| | logits.append(outputs.logits) |
| |
|
| | self.gen_forward_cnt += max_n_latents + 1 |
| |
|
| | logits = torch.cat(logits, dim=-2) |
| | shift_logits = logits[..., :-1, :].contiguous() |
| | shift_labels = labels[..., 1:].contiguous() |
| | loss_fct = CrossEntropyLoss() |
| | if self.config.training_method == 'only_base_causallm' or self.config.training_method == 'full': |
| | loss = loss_fct( |
| | shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) |
| | ) |
| |
|
| | if hasattr(self.config, 'visualize') and self.config.visualize: |
| | debug_predictions = [] |
| |
|
| | for debug_idx in range(0, len(latent_lists[0]), self.config.c_thought): |
| |
|
| | continuous_embeds = inputs_embeds[:, latent_lists[0][debug_idx: debug_idx + self.c_thought], :].to( |
| | self.expainable_llm.device) |
| |
|
| | if hasattr(self.config, 'w_prompt') and self.config.w_prompt: |
| | if hasattr(self.config, 'explain_mode') and self.config.explain_mode == 'v1_aug': |
| | thought_idx = debug_idx // 2 |
| | if thought_idx != 2: |
| | input_explain_input_embeds_pre_order_prompt_ids = self.tokenizer( |
| | f'Step {thought_idx + 1} of the solution', add_special_tokens=False).input_ids |
| | else: |
| | input_explain_input_embeds_pre_order_prompt_ids = self.tokenizer( |
| | f'Step 3 and all the remaining steps of the solution', |
| | add_special_tokens=False).input_ids |
| | bz = continuous_embeds.shape[0] |
| | input_explain_input_embeds_pre_order_prompt_embeds = self.embedding( |
| | torch.tensor(input_explain_input_embeds_pre_order_prompt_ids).to( |
| | self.expainable_llm.device))[None, ...].repeat(bz, 1, 1) |
| | continuous_embeds = torch.cat( |
| | [input_explain_input_embeds_pre_order_prompt_embeds, continuous_embeds], dim=1) |
| | debug_ids = torch.empty((1, 0), dtype=torch.long, device=self.expainable_llm.device) |
| | while True: |
| | if debug_ids.shape[0] != 0: |
| | debug_embeds = torch.cat([continuous_embeds, self.embedding(debug_ids)], dim=1) |
| | else: |
| | debug_embeds = continuous_embeds |
| | explainable_outputs = self.expainable_llm( |
| | inputs_embeds=debug_embeds, |
| | attention_mask=torch.ones(debug_embeds.shape[:2]).to(self.expainable_llm.device), |
| | position_ids=torch.arange(1, debug_embeds.shape[1] + 1).unsqueeze(dim=0).to( |
| | self.expainable_llm.device), |
| | output_hidden_states=True, |
| | ) |
| | debug_logits = explainable_outputs.logits[:, -1, :] / .98 |
| | probs = torch.softmax(debug_logits, dim=-1) |
| | next_token = torch.multinomial(probs, num_samples=1) |
| | debug_ids = torch.cat([debug_ids, next_token], dim=1) |
| |
|
| | if torch.all(next_token == self.eos_token_id) or debug_ids.shape[-1] > 512: |
| | break |
| |
|
| | print(self.tokenizer.decode(debug_ids[0])) |
| | debug_predictions.append(self.tokenizer.decode(debug_ids[0])) |
| |
|
| | if hasattr(self.config, 'visualize_jsonl') and self.config.visualize_jsonl != '': |
| | save_jsonl_line(self.config.visualize_jsonl, {"predictiion": debug_predictions}) |
| | if hasattr(self.config, 'explain_mode') and self.config.explain_mode == 'v1_aug': |
| | |
| | if 'explainable_ids_list' in kwargs: |
| | c_thought_num = len(latent_lists[0]) // self.c_thought |
| |
|
| | input_united_tokens = [] |
| |
|
| | def safe_token_id(x): |
| | return x[0] if isinstance(x, list) else x |
| |
|
| | start_token = safe_token_id(self.tokenizer.encode('<<', add_special_tokens=False)) |
| | end_token = safe_token_id(self.tokenizer.encode('>>', add_special_tokens=False)) |
| | separator_token = safe_token_id(self.tokenizer.encode('\n', add_special_tokens=False)) |
| |
|
| | def trim_trailing_zeros(group): |
| | while group and group[-1] == 0: |
| | group.pop() |
| | return group |
| |
|
| | def replace_llama_special_tokens(x, merged_token, end_token, separator_token): |
| | out = [] |
| | for seq in x: |
| | new_seq = [] |
| | for t in seq: |
| | if t.item() == merged_token: |
| | new_seq.extend([end_token, separator_token]) |
| | elif t.item() != 0 or len(new_seq) > 0: |
| | new_seq.append(t.item()) |
| | out.append(torch.tensor(new_seq, device=x.device)) |
| | return out |
| |
|
| | if len(self.tokenizer.encode('>>\n', add_special_tokens=False)) == 1: |
| | merge_token = self.tokenizer.encode('>>\n', add_special_tokens=False)[0] |
| | kwargs['explainable_ids_list'] = copy.deepcopy( |
| | replace_llama_special_tokens(kwargs['explainable_ids_list'], merge_token, end_token, |
| | separator_token)) |
| |
|
| | for j, seq in enumerate(kwargs['explainable_ids_list']): |
| | i = 0 |
| | groups = [] |
| | while i < len(seq): |
| | if seq[i] == start_token: |
| |
|
| | group = [start_token] |
| | i += 1 |
| | while i < len(seq): |
| | group.append(seq[i]) |
| | if seq[i] == end_token: |
| | break |
| | i += 1 |
| | group = trim_trailing_zeros(group) |
| | groups.append(group) |
| | else: |
| | i += 1 |
| | print(len(groups)) |
| | if len(groups) < self.config.max_latent_stage: |
| | input_ids_j = input_ids[j].tolist() |
| |
|
| | try: |
| | start_idx = len(input_ids_j) - 1 - input_ids_j[::-1].index(self.end_latent_id) |
| | except ValueError: |
| | continue |
| |
|
| | try: |
| | end_idx = input_ids_j.index(self.eos_token_id, start_idx + 1) |
| | except ValueError: |
| | end_idx = len(input_ids_j) |
| |
|
| | pseudo_thought = input_ids_j[start_idx + 1:end_idx] |
| |
|
| | if not pseudo_thought: |
| | continue |
| |
|
| | if hasattr(self.config, 'format_pseudo_thought') and self.config.format_pseudo_thought: |
| | tmp_num = self.tokenizer.decode(pseudo_thought).replace('### ', '') |
| | pseudo_thought = self.tokenizer.encode(f'<<{tmp_num}={tmp_num}>>', add_special_tokens=False) |
| |
|
| | while len(groups) < c_thought_num: |
| | groups.append(pseudo_thought) |
| |
|
| | input_united_groups = [] |
| | combined_group = [] |
| | group_count = 0 |
| |
|
| | for group in groups: |
| | group_count += 1 |
| | if group_count <= self.config.max_latent_stage - 1: |
| | group = [-570] * self.c_thought + group + [self.eos_token_id] |
| | cleaned_group = [int(x) if torch.is_tensor(x) else x for x in group] |
| | input_united_groups.append(cleaned_group) |
| | else: |
| | if combined_group and combined_group[-1] == end_token and group[0] == start_token: |
| | combined_group.append(separator_token) |
| | combined_group.extend(group) |
| |
|
| | if combined_group: |
| | final_group = [-570] * self.c_thought + combined_group + [self.eos_token_id] |
| | cleaned_group = [int(x) if torch.is_tensor(x) else x for x in final_group] |
| | input_united_groups.append(cleaned_group) |
| |
|
| | input_united_tokens.append(copy.deepcopy(input_united_groups)) |
| |
|
| | |
| | bz = len(input_united_tokens) |
| |
|
| | if hasattr(self.config, 'packing') and self.config.packing == True: |
| | pass |
| | else: |
| | for thought_idx in range(c_thought_num): |
| |
|
| | max_pad_len = max(len(input_united_tokens[bz_idx][thought_idx]) for bz_idx in range(bz)) |
| | max_pad_len += 1 |
| | for bz_idx in range(bz): |
| | token_seq = input_united_tokens[bz_idx][thought_idx] |
| | pad_len = max_pad_len - len(token_seq) |
| | if pad_len > 0: |
| | token_seq += [self.eos_token_id] * pad_len |
| | input_united_tokens[bz_idx][thought_idx] = token_seq |
| |
|
| | print("there") |
| | if hasattr(self.config, 'packing') and self.config.packing == True: |
| | print("there1") |
| | max_pad_len = 0 |
| | for bz_idx in range(bz): |
| | for thought_idx in range(c_thought_num): |
| | continuous_embeds = inputs_embeds[bz_idx, |
| | latent_lists[bz_idx][self.c_thought * thought_idx]:latent_lists[bz_idx][ |
| | self.c_thought * thought_idx + 1] + 1, |
| | :] |
| | other_embeds = self.embedding( |
| | torch.tensor(input_united_tokens[bz_idx][thought_idx][self.c_thought:]).to( |
| | self.expainable_llm.device)) |
| | max_pad_len = max(max_pad_len, continuous_embeds.size(0) + other_embeds.size( |
| | 0)) |
| |
|
| | input_explain_input_embeds_batch = [[] for _ in range(c_thought_num)] |
| | input_explain_attention_mask_batch = [[] for _ in range(c_thought_num)] |
| | input_explain_position_ids_batch = [[] for _ in range(c_thought_num)] |
| | input_explain_labels_batch = [[] for _ in range(c_thought_num)] |
| |
|
| | |
| | for thought_idx in range(c_thought_num): |
| | for bz_idx in range(bz): |
| | continuous_embeds = inputs_embeds[bz_idx, |
| | latent_lists[bz_idx][self.c_thought * thought_idx]:latent_lists[bz_idx][ |
| | self.c_thought * thought_idx + 1] + 1, |
| | :] |
| | other_embeds = self.embedding( |
| | torch.tensor(input_united_tokens[bz_idx][thought_idx][self.c_thought:]).to( |
| | self.expainable_llm.device)) |
| |
|
| | input_explain_input_embeds_batch[thought_idx].append( |
| | torch.cat([continuous_embeds, other_embeds], dim=0)) |
| |
|
| | attention_eos_index = input_united_tokens[bz_idx][thought_idx].index(self.eos_token_id) |
| | attention_explain_mask = torch.zeros(len(input_united_tokens[bz_idx][thought_idx]), |
| | dtype=int) |
| | attention_explain_mask[:attention_eos_index + 1] = 1 |
| | input_explain_attention_mask_batch[thought_idx].append(attention_explain_mask) |
| |
|
| | input_explain_position_ids_batch[thought_idx].append( |
| | torch.arange(1, len(input_united_tokens[bz_idx][thought_idx]) + 1, dtype=int)) |
| |
|
| | explain_labels = torch.tensor(input_united_tokens[bz_idx][thought_idx], dtype=int) |
| | explain_labels_mask = (explain_labels != -570) & (explain_labels != self.eos_token_id) |
| | explain_labels_mask[attention_eos_index] = True |
| | explain_labels[~explain_labels_mask] = -100 |
| | input_explain_labels_batch[thought_idx].append(explain_labels) |
| |
|
| | |
| | input_explain_input_embeds_batch_tensor = torch.zeros(bz, c_thought_num, max_pad_len, |
| | continuous_embeds.size(-1), |
| | device=self.expainable_llm.device) |
| | input_explain_attention_mask_batch_tensor = torch.zeros(bz, c_thought_num, max_pad_len, |
| | device=self.expainable_llm.device) |
| | input_explain_position_ids_batch_tensor = torch.zeros(bz, c_thought_num, max_pad_len, |
| | device=self.expainable_llm.device) |
| | input_explain_labels_batch_tensor = torch.full((bz, c_thought_num, max_pad_len), -100, |
| | device=self.expainable_llm.device) |
| |
|
| | |
| | for bz_idx in range(bz): |
| | for thought_idx in range(c_thought_num): |
| | input_explain_input_embeds_batch_tensor[bz_idx, thought_idx, |
| | :input_explain_input_embeds_batch[thought_idx][bz_idx].size(0)] = \ |
| | input_explain_input_embeds_batch[thought_idx][bz_idx] |
| | input_explain_attention_mask_batch_tensor[bz_idx, thought_idx, |
| | :input_explain_attention_mask_batch[thought_idx][bz_idx].size(0)] = \ |
| | input_explain_attention_mask_batch[thought_idx][bz_idx] |
| | input_explain_position_ids_batch_tensor[bz_idx, thought_idx, |
| | :input_explain_position_ids_batch[thought_idx][bz_idx].size(0)] = \ |
| | input_explain_position_ids_batch[thought_idx][bz_idx] |
| | input_explain_labels_batch_tensor[bz_idx, thought_idx, |
| | :input_explain_labels_batch[thought_idx][bz_idx].size(0)] = \ |
| | input_explain_labels_batch[thought_idx][bz_idx] |
| |
|
| | |
| | input_explain_input_embeds_batch_tensor = input_explain_input_embeds_batch_tensor.view(bz, -1, |
| | input_explain_input_embeds_batch_tensor.size( |
| | -1)) |
| | input_explain_attention_mask_batch_tensor = input_explain_attention_mask_batch_tensor.view(bz, -1) |
| | input_explain_position_ids_batch_tensor = input_explain_position_ids_batch_tensor.view(bz, -1) |
| | input_explain_labels_batch_tensor = input_explain_labels_batch_tensor.view(bz, -1) |
| |
|
| | |
| | input_explain_attention_mask_batch_tensor = prepare_4d_attention_mask( |
| | input_explain_attention_mask_batch_tensor, dtype=self.expainable_llm.dtype) |
| |
|
| | |
| | explainable_outputs = self.expainable_llm( |
| | inputs_embeds=input_explain_input_embeds_batch_tensor, |
| | attention_mask=input_explain_attention_mask_batch_tensor, |
| | position_ids=input_explain_position_ids_batch_tensor.to(torch.long), |
| | output_hidden_states=True, |
| | ) |
| |
|
| | explainable_logits = explainable_outputs.logits |
| | effective_loss_num = float( |
| | (input_explain_labels_batch_tensor != -100).sum(dim=1).bool().sum().item()) |
| |
|
| | shift_explain_logits = explainable_logits[..., :-1, :].contiguous() |
| | shift_explain_labels = input_explain_labels_batch_tensor[..., 1:].to(torch.long).contiguous() |
| | loss_explain_fct = CrossEntropyLoss(reduction='sum') |
| | loss_explain = loss_explain_fct( |
| | shift_explain_logits.view(-1, shift_explain_logits.size(-1)), shift_explain_labels.view(-1) |
| | ) |
| |
|
| | loss_explain /= effective_loss_num |
| | loss_explain_all += loss_explain |
| |
|
| | else: |
| |
|
| | print("there2") |
| | for thought_idx in range(c_thought_num): |
| | input_explain_input_embeds = [] |
| | input_explain_attention_mask, input_explain_position_ids, input_explain_labels = [], [], [] |
| | max_pad_len = -1 |
| |
|
| | def extract_token_range(tensor, start_id=128000, end_id=128256): |
| | try: |
| | start_idx = (tensor == start_id).nonzero(as_tuple=True)[0][0].item() |
| | end_idx = (tensor == end_id).nonzero(as_tuple=True)[0][0].item() |
| | return tensor[start_idx:end_idx] |
| | except IndexError: |
| | print("start_id or end_id not in tensor") |
| | return None |
| |
|
| | for bz_idx in range(bz): |
| | latent_len = len(latent_lists[bz_idx]) |
| | start_idx = thought_idx * self.c_thought |
| | end_idx = min(start_idx + self.c_thought, latent_len) |
| | continuous_embeds = inputs_embeds[bz_idx, latent_lists[bz_idx][start_idx:end_idx], :] |
| |
|
| | other_embeds = self.embedding( |
| | torch.tensor(input_united_tokens[bz_idx][thought_idx][self.c_thought:]).to( |
| | self.expainable_llm.device)) |
| | input_explain_input_embeds.append(torch.cat([continuous_embeds, other_embeds], dim=0)) |
| | attention_eos_index = input_united_tokens[bz_idx][thought_idx].index(self.eos_token_id) |
| | attention_explain_mask = torch.zeros(len(input_united_tokens[bz_idx][thought_idx]), |
| | dtype=int) |
| | attention_explain_mask[:attention_eos_index + 1] = 1 |
| | input_explain_attention_mask.append(attention_explain_mask) |
| | input_explain_position_ids.append( |
| | torch.arange(1, len(input_united_tokens[bz_idx][thought_idx]) + 1, dtype=int)) |
| | explain_labels = torch.tensor(input_united_tokens[bz_idx][thought_idx], dtype=int) |
| | explain_labels_mask = (explain_labels != -570) & (explain_labels != self.eos_token_id) |
| | explain_labels_mask[attention_eos_index] = True |
| | explain_labels[~explain_labels_mask] = -100 |
| | input_explain_labels.append(explain_labels) |
| |
|
| | input_explain_input_embeds = torch.stack(input_explain_input_embeds) |
| | input_explain_attention_mask = torch.stack(input_explain_attention_mask) |
| | input_explain_position_ids = torch.stack(input_explain_position_ids) |
| | input_explain_labels = torch.stack(input_explain_labels) |
| | |
| |
|
| | explainable_outputs = self.expainable_llm( |
| | inputs_embeds=input_explain_input_embeds.to(self.expainable_llm.device), |
| | attention_mask=input_explain_attention_mask.to(self.expainable_llm.device), |
| | position_ids=input_explain_position_ids.to(self.expainable_llm.device), |
| | output_hidden_states=True, |
| | ) |
| | if hasattr(self.config, "use_prj") and self.config.use_prj: |
| | explainable_logits = self.base_causallm.lm_head( |
| | self.projector2(explainable_outputs.hidden_states[-1])) |
| | else: |
| | explainable_logits = explainable_outputs.logits |
| |
|
| | if hasattr(self.config, "loss_level") and self.config.loss_level == 'token_level': |
| | effective_token_count = (input_explain_labels != -100).sum() |
| | else: |
| | effective_token_count = float((input_explain_labels != -100).sum(dim=1).bool().sum().item()) |
| |
|
| | shift_explain_logits = explainable_logits[..., :-1, :].contiguous() |
| | shift_explain_labels = input_explain_labels[..., 1:].contiguous() |
| | loss_explain_fct = CrossEntropyLoss(reduction='sum') |
| | loss_explain = loss_explain_fct( |
| | shift_explain_logits.view(-1, shift_explain_logits.size(-1)).to(self.expainable_llm.device), |
| | shift_explain_labels.view(-1).to(self.expainable_llm.device) |
| | ) |
| | loss_explain /= effective_token_count |
| | loss_explain_all += loss_explain |
| |
|
| | if 'explainable_ids_list' in kwargs: |
| | if loss is None: |
| | loss = 0.0 |
| | |
| | loss += 1.0 * loss_explain_all / c_thought_num |
| |
|
| | return Outputs(loss=loss, loss_explain_all=loss_explain_all / c_thought_num, inputs_embeds=inputs_embeds, |
| | logits=logits) |
| |
|
| | def train(self, mode: bool = True): |
| | super().train(mode) |
| | self.base_causallm.train(mode) |
| | return self |
| |
|
| | def eval(self): |
| | return self.train(False) |
| |
|
| | def generate( |
| | self, |
| | input_ids, |
| | attention_mask, |
| | max_new_tokens=16, |
| | output_embedding=False, |
| | synced_gpus=False, |
| | **kwargs |
| | ): |
| |
|
| | self.gen_forward_cnt = 0 |
| |
|
| | assert input_ids.shape[0] == 1, "only support batch_size == 1 now" |
| |
|
| | tokens = input_ids[0].detach().tolist() |
| |
|
| | labels = input_ids.clone() |
| | outputs = self.forward( |
| | input_ids, |
| | torch.ones_like(input_ids, device=input_ids.device), |
| | labels, |
| | torch.arange( |
| | 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device |
| | ).reshape(1, -1), |
| | ) |
| | inputs_embeds = outputs.inputs_embeds |
| |
|
| | |
| | next_token = torch.argmax(outputs.logits[0, -1]).item() |
| | tokens.append(next_token) |
| | new_token_embed = self.embedding( |
| | torch.tensor(next_token, device=input_ids.device) |
| | ).view(1, 1, -1) |
| | new_inputs_embeds = torch.cat((inputs_embeds, new_token_embed), dim=1) |
| | |
| | |
| | |
| | |
| | |
| | for _ in range(max_new_tokens - 1): |
| | outputs = self.base_causallm(inputs_embeds=new_inputs_embeds) |
| | self.gen_forward_cnt += 1 |
| | next_token = torch.argmax(outputs.logits[0, -1]).item() |
| | if next_token == self.eos_token_id: |
| | break |
| | tokens.append(next_token) |
| | new_token_embed = self.embedding( |
| | torch.tensor(next_token, device=input_ids.device) |
| | ).view(1, 1, -1) |
| | new_inputs_embeds = torch.cat((new_inputs_embeds, new_token_embed), dim=1) |
| |
|
| | if synced_gpus: |
| | |
| | while ( |
| | self.gen_forward_cnt < max_new_tokens + MAX_N_LATENT |
| | ): |
| | self.gen_forward_cnt += 1 |
| | _ = self.base_causallm(inputs_embeds=new_inputs_embeds) |
| |
|
| | if output_embedding: |
| | |
| | return torch.tensor(tokens).view(1, -1), new_inputs_embeds |
| |
|
| | else: |
| | return torch.tensor(tokens).view(1, -1), 0 |