# Uploaded by zz1358m with huggingface_hub (commit 7a92ec5, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
import math
from torch.distributions import Normal
from collections import defaultdict
import torch
from torch.nn.utils import clip_grad_norm_
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from collections import namedtuple
from transformers.models.gpt2 import GPT2LMHeadModel
from modules import grpo
from modules.projector import LatentPolicy
from modules.utils import get_position_ids_from_attention_mask
import copy
# Return containers for the model's forward pass. Each namedtuple's typename
# matches the module-level name it is bound to; otherwise repr() is misleading
# and pickling an instance resolves to the wrong class.
Outputs = namedtuple("Outputs", ["loss", "loss_explain_all", "inputs_embeds", "logits"])
Outputs_withkl = namedtuple(
    "Outputs_withkl", ["loss", "loss_explain_all", "loss_kl", "inputs_embeds", "logits"]
)
Outputs_withmask = namedtuple(
    "Outputs_withmask", ["loss", "attention_mask", "loss_explain_all", "inputs_embeds", "logits"]
)
# Headroom for latent passes: with synced_gpus, generate() pads its number of
# base-model forward calls up to max_new_tokens + MAX_N_LATENT.
MAX_N_LATENT = 8
class CoconutGPT_Same_Word_Embedding(nn.Module):
    """Coconut continuous-thought model trained jointly with an explainer LM.

    ``base_causallm`` runs the Coconut loop. Every ``latent_token_id`` position
    in the input is overwritten with the last-layer hidden state of the position
    just before it, computed one pass per latent token with a KV cache.
    ``expainable_llm`` is a second causal LM. It is trained when
    ``config.explain_mode == 'v1_aug'`` and sampled when ``config.visualize``
    is set, to verbalize those continuous thoughts. Both models are fed through
    ``self.embedding``, the input-embedding table of ``base_causallm``.

    The ``expainable_llm`` spelling is kept on purpose, because attribute names
    become ``state_dict`` keys.

    NOTE(review): ``save_jsonl_line``, ``prepare_4d_attention_mask`` and
    ``self.projector2`` are used below but are not defined or imported in this
    file. The jsonl-visualization, packing and ``use_prj`` paths therefore fail
    unless these are provided elsewhere.
    """

    # Id written into the latent slots at the start of every explainer target.
    # Those slots are filled with continuous-thought embeddings, are never
    # embedded as tokens, and are masked out of the labels.
    _LATENT_SLOT = -570

    def __init__(
        self,
        base_causallm,
        expainable_llm,
        # for debug
        tokenizer,
        latent_token_id,
        start_latent_id,
        end_latent_id,
        eos_token_id,
        step_start_id,
        c_thought,
        configs,
    ):
        """Store the models and special token ids, then apply the freezing policy.

        Args:
            base_causallm: Causal LM that runs the latent-reasoning loop.
            expainable_llm: Causal LM that verbalizes continuous thoughts.
            tokenizer: Tokenizer used to build and decode explainer text.
            latent_token_id: Id of the placeholder replaced by a continuous thought.
            start_latent_id: Id stored for callers; not read in this class.
            end_latent_id: Id marking the end of the latent span in ``input_ids``.
            eos_token_id: End-of-sequence id. Also used to pad explainer targets.
            step_start_id: Stored but not read in this class.
            c_thought: Number of latent tokens per reasoning step.
            configs: Run config. The optional ``training_method`` selects what is
                frozen: 'only_expainable_llm', 'only_base_causallm', 'full' or
                'freeze_backbone'. If absent, nothing is frozen.

        Raises:
            ValueError: if ``configs.training_method`` has an unknown value.
        """
        super(CoconutGPT_Same_Word_Embedding, self).__init__()
        self.gen_forward_cnt = 0
        self.base_causallm = base_causallm
        self.base_causallm.config.use_cache = True
        self.expainable_llm = expainable_llm
        self.tokenizer = tokenizer
        self.latent_token_id = latent_token_id
        self.eos_token_id = eos_token_id
        self.start_latent_id = start_latent_id
        self.end_latent_id = end_latent_id
        self.step_start_id = step_start_id
        self.c_thought = c_thought
        self.config = configs

        if hasattr(self.config, "training_method"):
            if self.config.training_method == 'only_expainable_llm':
                for param in self.base_causallm.parameters():
                    param.requires_grad = False
            elif self.config.training_method == 'only_base_causallm':
                for param in self.expainable_llm.parameters():
                    param.requires_grad = False
            elif self.config.training_method == 'full':
                pass
            elif self.config.training_method == 'freeze_backbone':
                for param in self.base_causallm.parameters():
                    param.requires_grad = False
                for param in self.expainable_llm.parameters():
                    param.requires_grad = False
            else:
                raise ValueError(f"not this training_method {self.config.training_method=}")

        if isinstance(self.base_causallm, GPT2LMHeadModel):
            self.embedding = self.base_causallm.transformer.get_input_embeddings()
            print("is GPT")
        else:
            self.embedding = self.base_causallm.get_input_embeddings()
            print("is not GPT")

    def forward(self, input_ids, attention_mask, labels, position_ids, **kwargs):
        """Run the continuous-thought loop and compute the training losses.

        Args:
            input_ids: (batch, seq_len) ids. Positions equal to ``latent_token_id``
                are replaced by continuous thoughts.
            attention_mask: Mask aligned with ``input_ids``.
            labels: LM targets aligned with ``input_ids`` (shifted by one here).
            position_ids: Position ids aligned with ``input_ids``.
            **kwargs: Optional ``explainable_ids_list``, holding per-sample id
                sequences whose ``<<`` ... ``>>`` spans become explainer targets.
                It is consumed only when ``config.explain_mode == 'v1_aug'``.

        Returns:
            ``Outputs(loss, loss_explain_all, inputs_embeds, logits)``:
            - ``inputs_embeds`` has the latent positions filled in.
            - ``logits`` are the base-model logits for the whole sequence.
            - ``loss_explain_all`` is averaged over the number of thoughts.
        """
        logits = []
        loss = 0.0
        loss_explain_all = torch.tensor(0.0, device=input_ids.device)
        c_thought_num = 1
        # __init__ freezes nothing when training_method is missing, i.e. it
        # behaves like 'full'. Default to that here instead of raising
        # AttributeError.
        training_method = getattr(self.config, "training_method", "full")

        # (num_latent_tokens_in_the_batch, 2) rows of (batch index, position)
        latent_indices = (input_ids == self.latent_token_id).nonzero()
        # Converting once costs a single device sync instead of one .item() per
        # entry. nonzero() is row-major, so each per-sample list comes out in
        # increasing position order.
        latent_pairs = latent_indices.tolist()
        latent_lists = [
            [pos for row, pos in latent_pairs if row == batch_idx]
            for batch_idx in range(input_ids.shape[0])
        ]  # bs, num_latent_tokens_in_the_instance (difference across the batch)
        max_n_latents = max(len(latents) for latents in latent_lists)

        next_compute_range = (0, input_ids.shape[1])
        inputs_embeds = self.embedding(input_ids)

        if max_n_latents > 0:
            # before the earliest latent token position
            next_compute_range = (0, latent_indices[:, 1].min().item())

        kv_cache = None

        for pass_idx in range(max_n_latents):
            if kv_cache is None:
                # first forward pass
                outputs = self.base_causallm(
                    inputs_embeds=inputs_embeds[:, next_compute_range[0]: next_compute_range[1], :],
                    attention_mask=attention_mask[:, next_compute_range[0]: next_compute_range[1]],
                    position_ids=position_ids[:, next_compute_range[0]: next_compute_range[1]],
                    output_hidden_states=True,
                )
                hidden_states_offset = 0
            else:
                # extract kv cache to reuse
                past_key_values = [
                    (
                        k[:, :, : next_compute_range[0], :],
                        v[:, :, : next_compute_range[0], :],
                    )
                    for k, v in kv_cache
                ]
                outputs = self.base_causallm(
                    inputs_embeds=inputs_embeds[:, next_compute_range[0]: next_compute_range[1], :],
                    attention_mask=attention_mask[:, : next_compute_range[1]],
                    position_ids=position_ids[:, next_compute_range[0]: next_compute_range[1]],
                    past_key_values=past_key_values,
                    output_hidden_states=True,
                )
                # When the kv_cache is used for the first k tokens, positions
                # [0, k) are skipped in `outputs.hidden_states`. Keep this offset
                # so the correct last hidden states are used.
                hidden_states_offset = next_compute_range[0]

            logits.append(outputs.logits)

            next_compute_range = (
                next_compute_range[1],
                (
                    input_ids.shape[1]
                    if pass_idx + 1 >= max_n_latents
                    else next_compute_range[1] + 1
                ),
            )

            hidden_states = outputs.hidden_states[-1]  # Get the last layer hidden states
            kv_cache = outputs.past_key_values

            # Feed the continuous thoughts back into inputs_embeds.
            # First decide which positions to fill in this pass.
            filling_indices = [
                (instance_idx, mask_list[pass_idx])
                for instance_idx, mask_list in enumerate(latent_lists)
                if len(mask_list) > pass_idx
            ]

            # To avoid in-place operations, break inputs_embeds
            # (bs, len, hidden_size) down into a list of lists of 1-d tensors.
            tensor_list = [
                [inputs_embeds[batch_idx, pos, :] for pos in range(inputs_embeds.shape[1])]
                for batch_idx in range(inputs_embeds.shape[0])
            ]

            # Replace each latent slot with the preceding last hidden state.
            for batch_idx, token_idx in filling_indices:
                tensor_list[batch_idx][token_idx] = hidden_states[
                    batch_idx, token_idx - 1 - hidden_states_offset, :
                ]

            # assemble the new inputs_embeds
            inputs_embeds = torch.stack(
                [torch.stack(tensor_list[batch_idx]) for batch_idx in range(inputs_embeds.shape[0])]
            )

        # final pass
        outputs = self.base_causallm(
            inputs_embeds=inputs_embeds[:, next_compute_range[0]: next_compute_range[1], :],
            attention_mask=attention_mask[:, : next_compute_range[1]],
            position_ids=position_ids[:, next_compute_range[0]: next_compute_range[1]],
            past_key_values=(
                [
                    (
                        k[:, :, : next_compute_range[0], :],
                        v[:, :, : next_compute_range[0], :],
                    )
                    for k, v in kv_cache
                ]
                if kv_cache
                else None
            ),
            output_hidden_states=True,
        )

        logits.append(outputs.logits)
        self.gen_forward_cnt += max_n_latents + 1

        logits = torch.cat(logits, dim=-2)
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_fct = CrossEntropyLoss()
        # The LM loss only trains base_causallm, so skip it when that model is frozen.
        if training_method in ("only_base_causallm", "full"):
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if getattr(self.config, "visualize", False):
            self._visualize_thoughts(inputs_embeds, latent_lists)

        if (
            getattr(self.config, "explain_mode", None) == "v1_aug"
            and "explainable_ids_list" in kwargs
        ):
            # NOTE(review): the thought count is taken from sample 0 only. This
            # assumes every sample carries the same number of latent tokens.
            c_thought_num = len(latent_lists[0]) // self.c_thought
            input_united_tokens, sample_indices = self._build_explain_targets(
                input_ids, kwargs["explainable_ids_list"], c_thought_num
            )
            # `== True` (not plain truthiness) keeps e.g. packing="False" off.
            if getattr(self.config, "packing", False) == True:  # noqa: E712
                loss_explain_all += self._explain_loss_packed(
                    inputs_embeds, latent_lists, input_united_tokens, sample_indices, c_thought_num
                )
            else:
                for thought_loss in self._explain_losses_per_thought(
                    inputs_embeds, latent_lists, input_united_tokens, sample_indices, c_thought_num
                ):
                    loss_explain_all += thought_loss

        if "explainable_ids_list" in kwargs:
            loss = loss + loss_explain_all / c_thought_num

        return Outputs(
            loss=loss,
            loss_explain_all=loss_explain_all / c_thought_num,
            inputs_embeds=inputs_embeds,
            logits=logits,
        )

    def _visualize_thoughts(self, inputs_embeds, latent_lists):
        """Sample a verbalization of each continuous thought and print it.

        Thought windows come from sample 0's latent positions and are applied to
        every batch row. The text of row 0 is printed and, if
        ``config.visualize_jsonl`` is non-empty, appended there via
        ``save_jsonl_line`` (not defined in this file).
        """
        device = self.expainable_llm.device
        predictions = []
        first_latents = latent_lists[0]
        for debug_idx in range(0, len(first_latents), self.c_thought):
            continuous_embeds = inputs_embeds[
                :, first_latents[debug_idx: debug_idx + self.c_thought], :
            ].to(device)
            if (
                getattr(self.config, "w_prompt", False)
                and getattr(self.config, "explain_mode", None) == "v1_aug"
            ):
                # Prepend a textual step prompt. Windows advance by c_thought,
                # so this is the zero-based thought index.
                thought_idx = debug_idx // self.c_thought
                if thought_idx != 2:
                    prompt = f'Step {thought_idx + 1} of the solution'
                else:
                    prompt = 'Step 3 and all the remaining steps of the solution'
                prompt_ids = self.tokenizer(prompt, add_special_tokens=False).input_ids
                batch_size = continuous_embeds.shape[0]
                prompt_embeds = self.embedding(
                    torch.tensor(prompt_ids).to(device)
                )[None, ...].repeat(batch_size, 1, 1)
                continuous_embeds = torch.cat([prompt_embeds, continuous_embeds], dim=1)

            # One row of generated ids per batch row, so torch.cat below lines up.
            debug_ids = torch.empty(
                (continuous_embeds.shape[0], 0), dtype=torch.long, device=device
            )
            while True:
                if debug_ids.shape[-1] != 0:
                    debug_embeds = torch.cat([continuous_embeds, self.embedding(debug_ids)], dim=1)
                else:
                    debug_embeds = continuous_embeds
                explainable_outputs = self.expainable_llm(
                    inputs_embeds=debug_embeds,
                    attention_mask=torch.ones(debug_embeds.shape[:2]).to(device),
                    position_ids=torch.arange(1, debug_embeds.shape[1] + 1).unsqueeze(dim=0).to(device),
                    output_hidden_states=True,
                )
                # Temperature-0.98 sampling of the next token.
                probs = torch.softmax(explainable_outputs.logits[:, -1, :] / .98, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                debug_ids = torch.cat([debug_ids, next_token], dim=1)
                # Stop once every row emits eos, or after more than 512 tokens.
                if torch.all(next_token == self.eos_token_id) or debug_ids.shape[-1] > 512:
                    break
            text = self.tokenizer.decode(debug_ids[0])
            print(text)
            predictions.append(text)

        if getattr(self.config, "visualize_jsonl", '') != '':
            save_jsonl_line(self.config.visualize_jsonl, {"predictiion": predictions})

    @staticmethod
    def _first_token_id(ids):
        """Return ``ids[0]`` when ``ids`` is a list, else ``ids`` unchanged."""
        return ids[0] if isinstance(ids, list) else ids

    @staticmethod
    def _trim_trailing_zeros(group):
        """Pop trailing ``0`` ids from ``group`` in place and return it."""
        while group and group[-1] == 0:
            group.pop()
        return group

    @staticmethod
    def _split_merged_token(sequences, merged_token, end_token, separator_token):
        """Rewrite each row of a tensor of ids into a list of 1-d tensors.

        ``merged_token`` is expanded to ``[end_token, separator_token]``. ``0``
        ids that appear before the first emitted id are dropped.
        """
        out = []
        for seq in sequences:
            new_seq = []
            for token in seq:
                value = token.item()
                if value == merged_token:
                    new_seq.extend([end_token, separator_token])
                elif value != 0 or new_seq:
                    new_seq.append(value)
            out.append(torch.tensor(new_seq, device=sequences.device))
        return out

    def _extract_step_groups(self, seq, start_token, end_token):
        """Collect every ``start_token`` ... ``end_token`` span of ``seq``.

        A span that never closes runs to the end of ``seq``. Trailing ``0`` ids
        are stripped from each span.
        """
        groups = []
        i = 0
        while i < len(seq):
            if seq[i] == start_token:
                group = [start_token]
                i += 1
                while i < len(seq):
                    group.append(seq[i])
                    if seq[i] == end_token:
                        break
                    i += 1
                groups.append(self._trim_trailing_zeros(group))
            else:
                i += 1
        return groups

    def _pseudo_thought(self, ids):
        """Return the ids between the last ``end_latent_id`` and the next eos.

        Returns ``[]`` when ``end_latent_id`` is absent or the span is empty.
        With ``config.format_pseudo_thought`` the decoded span (minus ``'### '``)
        is re-encoded as ``<<x=x>>``.
        """
        try:
            start_idx = len(ids) - 1 - ids[::-1].index(self.end_latent_id)
        except ValueError:
            return []
        try:
            end_idx = ids.index(self.eos_token_id, start_idx + 1)
        except ValueError:
            end_idx = len(ids)
        pseudo_thought = ids[start_idx + 1:end_idx]
        if pseudo_thought and getattr(self.config, "format_pseudo_thought", False):
            tmp_num = self.tokenizer.decode(pseudo_thought).replace('### ', '')
            pseudo_thought = self.tokenizer.encode(f'<<{tmp_num}={tmp_num}>>', add_special_tokens=False)
        return pseudo_thought

    def _as_explain_target(self, group):
        """Return ``c_thought`` latent slots + ``group`` + eos, with tensor ids as ints."""
        target = [self._LATENT_SLOT] * self.c_thought + group + [self.eos_token_id]
        return [int(x) if torch.is_tensor(x) else x for x in target]

    def _build_explain_targets(self, input_ids, explainable_ids_list, c_thought_num):
        """Turn each sample's explanation ids into one explainer target per thought.

        Every ``<<`` ... ``>>`` span of a sample is one step:
        - The first ``config.max_latent_stage - 1`` steps each get their own target.
        - The remaining steps are joined into one final target, with a newline
          inserted between a ``>>`` and the next ``<<``.
        - Samples with fewer than ``max_latent_stage`` steps are padded to
          ``c_thought_num`` steps using ``_pseudo_thought`` of their ``input_ids``.

        Returns:
            ``(input_united_tokens, sample_indices)``, where
            ``input_united_tokens[k]`` holds the targets of batch row
            ``sample_indices[k]``. Rows whose pseudo thought cannot be built are
            skipped, so callers must index batch tensors through
            ``sample_indices``.
        """
        start_token = self._first_token_id(self.tokenizer.encode('<<', add_special_tokens=False))  # "<<"
        end_token = self._first_token_id(self.tokenizer.encode('>>', add_special_tokens=False))  # ">>"
        separator_token = self._first_token_id(self.tokenizer.encode('\n', add_special_tokens=False))  # "\n"
        merged_ids = self.tokenizer.encode('>>\n', add_special_tokens=False)
        if len(merged_ids) == 1:
            # The tokenizer fuses ">>\n" into a single id; split it back so the
            # step boundaries can be found.
            explainable_ids_list = self._split_merged_token(
                explainable_ids_list, merged_ids[0], end_token, separator_token
            )

        max_latent_stage = self.config.max_latent_stage
        input_united_tokens = []
        sample_indices = []
        for j, seq in enumerate(explainable_ids_list):
            groups = self._extract_step_groups(seq, start_token, end_token)
            if len(groups) < max_latent_stage:
                pseudo_thought = self._pseudo_thought(input_ids[j].tolist())
                if not pseudo_thought:
                    continue
                while len(groups) < c_thought_num:
                    groups.append(pseudo_thought)

            targets = []
            combined_group = []
            for group_count, group in enumerate(groups, start=1):
                if group_count <= max_latent_stage - 1:
                    targets.append(self._as_explain_target(group))
                else:
                    if combined_group and combined_group[-1] == end_token and group[0] == start_token:
                        combined_group.append(separator_token)
                    combined_group.extend(group)
            if combined_group:
                targets.append(self._as_explain_target(combined_group))
            input_united_tokens.append(targets)
            sample_indices.append(j)
        return input_united_tokens, sample_indices

    def _explain_mask_positions_labels(self, tokens):
        """Build the attention mask, position ids and labels for one explainer target.

        - The mask covers everything up to and including the first eos.
        - Positions start at 1.
        - Labels keep every real token plus that first eos; latent slots and all
          other eos ids become -100.
        """
        eos_index = tokens.index(self.eos_token_id)
        mask = torch.zeros(len(tokens), dtype=int)
        mask[:eos_index + 1] = 1
        positions = torch.arange(1, len(tokens) + 1, dtype=int)
        labels = torch.tensor(tokens, dtype=int)
        keep = (labels != self._LATENT_SLOT) & (labels != self.eos_token_id)
        keep[eos_index] = True
        labels[~keep] = -100
        return mask, positions, labels

    def _thought_embeds(self, inputs_embeds, latent_lists, sample, thought_idx, tokens):
        """Concatenate a thought's continuous embeddings with its target-token embeddings."""
        start = thought_idx * self.c_thought
        continuous_embeds = inputs_embeds[sample, latent_lists[sample][start:start + self.c_thought], :]
        other_embeds = self.embedding(
            torch.tensor(tokens[self.c_thought:]).to(self.expainable_llm.device)
        )
        return torch.cat([continuous_embeds, other_embeds], dim=0)

    def _explain_loss_packed(self, inputs_embeds, latent_lists, input_united_tokens, sample_indices, c_thought_num):
        """Explainer loss with all thoughts of a sample packed into one sequence.

        Each thought segment is zero-padded to the longest segment in the batch.
        The segments of a row are laid end to end, and one explainer forward runs
        over the whole batch. The summed token loss is divided by the number of
        rows that have at least one supervised token.
        """
        device = self.expainable_llm.device
        bz = len(input_united_tokens)
        segments = []  # segments[row][thought] = (embeds, mask, positions, labels)
        for row, sample in enumerate(sample_indices):
            row_segments = []
            for thought_idx in range(c_thought_num):
                tokens = input_united_tokens[row][thought_idx]
                row_segments.append(
                    (self._thought_embeds(inputs_embeds, latent_lists, sample, thought_idx, tokens),)
                    + self._explain_mask_positions_labels(tokens)
                )
            segments.append(row_segments)
        max_pad_len = max((seg[0].size(0) for row in segments for seg in row), default=0)

        hidden_size = inputs_embeds.size(-1)
        embeds = torch.zeros(bz, c_thought_num, max_pad_len, hidden_size, device=device)
        masks = torch.zeros(bz, c_thought_num, max_pad_len, device=device)
        positions = torch.zeros(bz, c_thought_num, max_pad_len, device=device)
        labels = torch.full((bz, c_thought_num, max_pad_len), -100, device=device)
        for row, row_segments in enumerate(segments):
            for thought_idx, (seg_embeds, seg_mask, seg_positions, seg_labels) in enumerate(row_segments):
                embeds[row, thought_idx, :seg_embeds.size(0)] = seg_embeds
                masks[row, thought_idx, :seg_mask.size(0)] = seg_mask
                positions[row, thought_idx, :seg_positions.size(0)] = seg_positions
                labels[row, thought_idx, :seg_labels.size(0)] = seg_labels

        # Flatten the thought axis into the sequence axis.
        embeds = embeds.view(bz, -1, embeds.size(-1))
        masks = masks.view(bz, -1)
        positions = positions.view(bz, -1)
        labels = labels.view(bz, -1)

        # NOTE(review): prepare_4d_attention_mask is neither defined nor imported
        # in this file; confirm which helper is intended.
        masks = prepare_4d_attention_mask(masks, dtype=self.expainable_llm.dtype)
        explainable_outputs = self.expainable_llm(
            inputs_embeds=embeds,
            attention_mask=masks,
            position_ids=positions.to(torch.long),
            output_hidden_states=True,
        )
        explainable_logits = explainable_outputs.logits
        effective_loss_num = float((labels != -100).sum(dim=1).bool().sum().item())
        shift_logits = explainable_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].to(torch.long).contiguous()
        loss = CrossEntropyLoss(reduction='sum')(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
        return loss / effective_loss_num

    def _explain_losses_per_thought(self, inputs_embeds, latent_lists, input_united_tokens, sample_indices, c_thought_num):
        """Explainer losses with one explainer forward per thought index.

        For each thought, every target is right-padded in place with eos to the
        longest target of that thought plus one, so the batch stacks. Returns one
        loss per thought: the summed token loss divided either by the number of
        supervised tokens (``config.loss_level == 'token_level'``) or by the
        number of samples with any supervised token.
        """
        device = self.expainable_llm.device
        for thought_idx in range(c_thought_num):
            max_pad_len = max(len(tokens[thought_idx]) for tokens in input_united_tokens) + 1
            for tokens in input_united_tokens:
                tokens[thought_idx].extend(
                    [self.eos_token_id] * (max_pad_len - len(tokens[thought_idx]))
                )

        losses = []
        for thought_idx in range(c_thought_num):
            embeds, masks, positions, labels = [], [], [], []
            for row, sample in enumerate(sample_indices):
                tokens = input_united_tokens[row][thought_idx]
                embeds.append(self._thought_embeds(inputs_embeds, latent_lists, sample, thought_idx, tokens))
                mask, position_ids, explain_labels = self._explain_mask_positions_labels(tokens)
                masks.append(mask)
                positions.append(position_ids)
                labels.append(explain_labels)
            explain_labels = torch.stack(labels)

            explainable_outputs = self.expainable_llm(
                inputs_embeds=torch.stack(embeds).to(device),
                attention_mask=torch.stack(masks).to(device),
                position_ids=torch.stack(positions).to(device),
                output_hidden_states=True,
            )
            if getattr(self.config, "use_prj", False):
                # NOTE(review): self.projector2 is never assigned in this class.
                explainable_logits = self.base_causallm.lm_head(
                    self.projector2(explainable_outputs.hidden_states[-1])
                )
            else:
                explainable_logits = explainable_outputs.logits

            if getattr(self.config, "loss_level", None) == 'token_level':
                effective_token_count = (explain_labels != -100).sum()
            else:
                effective_token_count = float((explain_labels != -100).sum(dim=1).bool().sum().item())

            shift_logits = explainable_logits[..., :-1, :].contiguous()
            shift_labels = explain_labels[..., 1:].contiguous()
            loss = CrossEntropyLoss(reduction='sum')(
                shift_logits.view(-1, shift_logits.size(-1)).to(device),
                shift_labels.view(-1).to(device),
            )
            losses.append(loss / effective_token_count)
        return losses

    def train(self, mode: bool = True):
        """Set training mode on this module and all submodules; return ``self``."""
        super().train(mode)
        # super().train already recurses into registered submodules; the explicit
        # call on base_causallm is redundant but harmless.
        self.base_causallm.train(mode)
        return self

    def eval(self):
        """Switch to evaluation mode; equivalent to ``train(False)``."""
        return self.train(False)

    def generate(
        self,
        input_ids,
        attention_mask,  # attention_mask is not used
        max_new_tokens=16,
        output_embedding=False,
        synced_gpus=False,
        **kwargs
    ):
        """Greedy decoding after the continuous-thought prefix (batch size 1 only).

        Args:
            input_ids: (1, seq_len) prompt, including latent placeholder tokens.
            attention_mask: Ignored; an all-ones mask is used instead.
            max_new_tokens: Maximum number of generated tokens.
            output_embedding: Also return the final input embeddings.
            synced_gpus: Pad the number of base-model forward calls up to
                ``max_new_tokens + MAX_N_LATENT`` so every rank runs the same
                count (needed under FSDP).
            **kwargs: Ignored.

        Returns:
            ``(tokens, new_inputs_embeds)`` if ``output_embedding``, otherwise
            ``(tokens, 0)``. ``tokens`` is a (1, n) tensor of prompt plus
            generated ids. Decoding stops at eos, and that eos is not appended,
            except when it is the very first generated token.
        """
        self.gen_forward_cnt = 0
        assert input_ids.shape[0] == 1, "only support batch_size == 1 now"
        tokens = input_ids[0].detach().tolist()
        labels = input_ids.clone()  # placeholder. not used.
        outputs = self.forward(
            input_ids,
            torch.ones_like(input_ids, device=input_ids.device),
            labels,
            torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device).reshape(1, -1),
        )
        inputs_embeds = outputs.inputs_embeds

        # get the first token using the current hidden state
        next_token = torch.argmax(outputs.logits[0, -1]).item()
        tokens.append(next_token)
        new_token_embed = self.embedding(torch.tensor(next_token, device=input_ids.device)).view(1, 1, -1)
        new_inputs_embeds = torch.cat((inputs_embeds, new_token_embed), dim=1)

        # Generate the remaining tokens. There is no KV cache, so each step
        # re-runs the whole sequence.
        for _ in range(max_new_tokens - 1):
            outputs = self.base_causallm(inputs_embeds=new_inputs_embeds)
            self.gen_forward_cnt += 1
            next_token = torch.argmax(outputs.logits[0, -1]).item()
            if next_token == self.eos_token_id:
                break
            tokens.append(next_token)
            new_token_embed = self.embedding(torch.tensor(next_token, device=input_ids.device)).view(1, 1, -1)
            new_inputs_embeds = torch.cat((new_inputs_embeds, new_token_embed), dim=1)

        if synced_gpus:
            # In FSDP, the number of forward passes must be the same across devices.
            while self.gen_forward_cnt < max_new_tokens + MAX_N_LATENT:  # leave some room for latent tokens
                self.gen_forward_cnt += 1
                _ = self.base_causallm(inputs_embeds=new_inputs_embeds)

        if output_embedding:
            # for analysis purpose
            return torch.tensor(tokens).view(1, -1), new_inputs_embeds
        return torch.tensor(tokens).view(1, -1), 0