# ATP-Latent-Master / methods / sim_cot_rl_vae_end.py
# (uploaded via huggingface_hub by zz1358m, revision 7a92ec5 — web-page residue
# converted to comments so the module parses)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
import math
import re
from torch.distributions import Normal
from collections import defaultdict
import torch
from torch.nn.utils import clip_grad_norm_
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from collections import namedtuple
from transformers.models.gpt2 import GPT2LMHeadModel
from modules import grpo
from modules.projector import LatentPolicy
from modules.utils import get_position_ids_from_attention_mask
import copy
# Lightweight result containers for the different forward variants.
Outputs = namedtuple("Outputs", ["loss", "loss_explain_all", "inputs_embeds", "logits"])
# BUG FIX: the typename previously said "Outputs", which made repr()/pickling
# of this type misleading; fields and usage are unchanged.
Outputs_withkl = namedtuple("Outputs_withkl", ["loss", "loss_explain_all", "loss_kl", "inputs_embeds", "logits"])
Outputs_withmask = namedtuple(
    "Outputs_withmask",
    ["loss", "attention_mask", "loss_explain_all", "inputs_embeds", "logits"],
)
# Upper bound on latent "thought" forwards; also used to pad forward counts
# across ranks in synced-GPU generation.
MAX_N_LATENT = 8
from dataclasses import dataclass
from typing import Optional
from modules.projector import STOPPolicy
@dataclass
class LatentForwardResult:
    """Result of `latent_forward_with_cache`: prompt embeddings with sampled
    latent vectors spliced in, the (possibly stop-truncated) attention mask,
    and the raw latent samples alone."""
    inputs_embeds: torch.Tensor  # (B, L_prompt, H)
    attention_mask: torch.Tensor  # (B, L_prompt)
    latent_inputs_embeds: torch.Tensor  # (B, L_latent, H)
@dataclass
class GenerationOutputs:
    """Result of `generate_with_latent`: the question prefix (up to the
    begin-of-latent marker), the sampled latent embeddings with their mask,
    and the right-padded predicted answer token ids."""
    question_input_ids: torch.Tensor
    question_attention_mask: torch.Tensor
    latent_inputs_embeds: torch.Tensor
    latent_attention_mask: torch.Tensor
    pred_ids: torch.Tensor
class Coconut_RL_End_VAE(nn.Module):
def __init__(
    self,
    base_causallm,
    expainable_llm,
    latent_token_id,
    start_latent_id,
    end_latent_id,
    eos_token_id,
    tokenizer=None,
    rl_config=None,
    is_training=True,
):
    """Latent chain-of-thought policy trained with GRPO using a VAE-style
    Gaussian latent head plus a learned stop head.

    Args:
        base_causallm: trainable causal LM that consumes latent embeddings.
        expainable_llm: frozen LM used only to score natural-language
            explanations of the latent thoughts (may be None to disable).
        latent_token_id: token id marking latent-thought slots in the prompt.
        start_latent_id: id of the begin-of-latent delimiter token.
        end_latent_id: id of the end-of-latent delimiter token.
        eos_token_id: EOS id; also reused as the pad id (see below).
        tokenizer: tokenizer used to decode predictions and ground truth.
        rl_config: GRPO hyper-parameter namespace (group_size, c_thought,
            reward weights, ...).
        is_training: when True, stop/answer decisions are sampled; when
            False they are taken greedily.
    """
    super(Coconut_RL_End_VAE, self).__init__()
    self.gen_forward_cnt = 0  # counts LM forwards; used for synced-GPU padding
    self.base_causallm = base_causallm
    self.latent_token_id = latent_token_id
    self.eos_token_id = eos_token_id
    self.pad_token_id = eos_token_id  # no dedicated pad token: pad with EOS
    self.start_latent_id = start_latent_id
    self.end_latent_id = end_latent_id
    self.expainable_llm = expainable_llm
    self.base_causallm.config.use_cache = True
    # NOTE(review): this line assumes expainable_llm is not None, unlike the
    # guard further below — confirm callers never pass None here.
    self.expainable_llm.config.use_cache = True
    # Binary "stop thinking" head applied to the current latent embedding.
    self.end_head = STOPPolicy(
        feature_size=self.base_causallm.config.hidden_size,
        intermediate_size=self.base_causallm.config.hidden_size
    )
    # Gaussian head producing (mu, logvar) for reparameterized latent sampling.
    self.latent_head = LatentPolicy(
        feature_size=self.base_causallm.config.hidden_size,
        intermediate_size=self.base_causallm.config.hidden_size,
        deterministic=False
    )
    self.tokenizer = tokenizer
    self.rl_config = rl_config
    self.latent_sigma = rl_config.latent_sigma
    self.rl_step_counter = 0
    self.is_training = is_training
    # The explanation scorer is frozen; only the base LM and heads train.
    if self.expainable_llm is not None:
        for param in self.expainable_llm.parameters():
            param.requires_grad = False
    # GPT-2 exposes its embedding via .transformer; other HF models directly.
    if isinstance(self.base_causallm, GPT2LMHeadModel):
        self.embedding = self.base_causallm.transformer.get_input_embeddings()
    else:
        self.embedding = self.base_causallm.get_input_embeddings()
    self.replay_buffer = grpo.ReplayBuffer()
    self.grpo_loss = grpo.GRPOLoss(rl_config=self.rl_config)
@property
def device(self):
return next(self.parameters()).device
def latent_forward_with_cache(self, input_ids, attention_mask, position_ids=None):
    """Sample all latent-thought embeddings for the batch, one latent slot per
    pass, reusing the KV cache between passes.

    (Original Chinese docstring, translated:) only responsible for sampling
    every latent token in the batch in parallel, reusing the KV cache, and
    returning the latent embeddings and related tensors.

    Side effect: `attention_mask` is modified in place — when the stop head
    fires, the remaining latent span of that row is zeroed out.

    Returns:
        LatentForwardResult(inputs_embeds, attention_mask, latent_inputs_embeds).
    """
    bs, seq_len = input_ids.size()
    device = input_ids.device
    # Locate every latent placeholder per batch row.
    latent_mask = (input_ids == self.latent_token_id)
    latent_positions = latent_mask.nonzero(as_tuple=False)
    latent_lists = [
        [idx[1].item() for idx in latent_positions if idx[0] == b]
        for b in range(bs)
    ]
    # Assumes all rows share the same latent span; row 0 defines the
    # begin-of-latent / end-of-latent delimiter positions.
    bol_position = latent_lists[0][0] - 1
    eol_position = latent_lists[0][-1] + 1
    max_n_latents = max((len(lst) for lst in latent_lists), default=0)
    next_compute_range = (0, seq_len)
    inputs_embeds = self.embedding(input_ids)
    if max_n_latents > 0:
        # First pass runs only up to the earliest latent position.
        next_compute_range = (0, latent_positions[:, 1].min().item())
    # stop_logprob = []
    kv_cache = None
    latent_samples_all = [[] for _ in range(bs)]
    for pass_idx in range(max_n_latents if max_n_latents > 0 else 1):
        slice_start, slice_end = next_compute_range
        if kv_cache is None:
            outputs = self.base_causallm(
                inputs_embeds=inputs_embeds[:, slice_start:slice_end, :],
                attention_mask=attention_mask[:, slice_start:slice_end],
                position_ids=(
                    position_ids[:, slice_start:slice_end]
                    if position_ids is not None else None
                ),
                output_hidden_states=True,
            )
            hidden_offset = 0
        else:
            # Reuse only the cached keys/values for tokens before slice_start.
            past_key_values = [
                (k[:, :, :slice_start, :], v[:, :, :slice_start, :])
                for k, v in kv_cache
            ]
            outputs = self.base_causallm(
                inputs_embeds=inputs_embeds[:, slice_start:slice_end, :],
                attention_mask=attention_mask[:, :slice_end],
                position_ids=(
                    position_ids[:, slice_start:slice_end]
                    if position_ids is not None else None
                ),
                past_key_values=past_key_values,
                output_hidden_states=True,
            )
            # Hidden states are offset because only a slice was fed this pass.
            hidden_offset = slice_start
        kv_cache = outputs.past_key_values
        hidden_states = outputs.hidden_states[-1]
        next_compute_range = (
            slice_end,
            seq_len if pass_idx + 1 >= max_n_latents else slice_end + 1,
        )
        tensor_list = [
            [inputs_embeds[b, pos, :] for pos in range(seq_len)]
            for b in range(bs)
        ]
        # NOTE(review): this loop variable shadows the outer tensor
        # `latent_positions`; the tensor is not used again afterwards.
        for b, latent_positions in enumerate(latent_lists):
            if pass_idx >= len(latent_positions):
                continue
            token_idx = latent_positions[pass_idx]
            # Hidden state of the token right before the latent slot
            # parameterizes that slot's Gaussian.
            vec = hidden_states[b, token_idx - 1 - hidden_offset, :]
            mu, logvar = self.latent_head.forward(vec)  # [dim*2]
            # print(mu, logvar)
            # Reparameterization trick: z = mu + sigma * eps.
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            z = mu + std * eps
            # print(self.is_training)
            # dist = torch.distributions.Normal(mu, self.latent_sigma)
            # logprob = dist.log_prob(sample).sum()
            tensor_list[b][token_idx] = z
            latent_samples_all[b].append(z)
        inputs_embeds = torch.stack([torch.stack(tensor_list[b]) for b in range(bs)], dim=0)
        # Stop head decides whether each row terminates latent thinking here.
        pred_token_id = self.end_head(inputs_embeds[:, next_compute_range[0], :])
        # print(torch.softmax(pred_token_id, dim=-1))
        if self.is_training == True:
            is_end = torch.multinomial(torch.softmax(pred_token_id, dim=-1), num_samples=1).squeeze(
                -1).bool()  # [batch_size]
        else:
            is_end = (pred_token_id[:, 1] > pred_token_id[:, 0])
        # stop_logprob.append(torch.log_softmax(pred_token_id, dim=-1))
        # Stopping only takes effect on thought boundaries (c_thought multiples):
        # zero out the rest of the latent span for rows that stopped.
        if pass_idx % self.rl_config.c_thought == 0:
            attention_mask[is_end, bol_position + 1 + pass_idx:eol_position] = 0
        if pass_idx >= max_n_latents:  # or (attention_mask[:, bol_position + 1 + pass_idx] == 0).all():
            break
    self.gen_forward_cnt += max_n_latents + 1
    # Pad the per-row latent samples into one (B, max_latent, H) tensor.
    latent_dim = inputs_embeds.size(-1)
    latent_lengths = torch.tensor([len(lst) for lst in latent_lists],
                                  device=device, dtype=torch.long)
    max_latent_len = latent_lengths.max().item() if latent_lengths.numel() > 0 else 0
    if max_latent_len > 0:
        latent_inputs_embeds = torch.zeros(bs, max_latent_len, latent_dim, device=device)
        for b in range(bs):
            if latent_samples_all[b]:
                latent_inputs_embeds[b, :len(latent_samples_all[b])] = torch.stack(latent_samples_all[b])
    else:
        latent_inputs_embeds = torch.zeros(bs, 0, latent_dim, device=device)
    return LatentForwardResult(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        latent_inputs_embeds=latent_inputs_embeds,
    )
@torch.no_grad()
def generate_with_latent(
    self,
    questions,
    questions_mask,
    max_new_tokens=16,
    temperature=1.0,
    synced_gpus=False,
):
    """Sample latent thoughts, then autoregressively decode answer tokens.

    Args:
        questions: (B, L) prompt ids containing latent placeholder tokens.
        questions_mask: (B, L) attention mask for `questions`.
        max_new_tokens: maximum number of answer tokens to decode.
        temperature: logits temperature; tokens are sampled while
            self.is_training is True, argmax-decoded otherwise.
        synced_gpus: when True, pad the number of forward passes so every
            rank runs the same count (avoids distributed deadlocks).

    Returns:
        GenerationOutputs; `pred_ids` is right-padded with the pad token and
        each row starts with the last prompt token (decode alignment seed).
    """
    self.eval()
    device = questions.device
    batch_size = questions.size(0)
    latent_mask = (questions == self.latent_token_id)
    latent_positions = latent_mask.nonzero(as_tuple=False)
    latent_lists = [
        [idx[1].item() for idx in latent_positions if idx[0] == b]
        for b in range(batch_size)
    ]
    # All rows are assumed to share the same latent span (row 0 defines it).
    bol_position = latent_lists[0][0] - 1
    eol_position = latent_lists[0][-1] + 1
    # Drop everything after the end-of-latent marker; the answer is decoded.
    questions = questions[:, :eol_position + 1]
    questions_mask = questions_mask[:, :eol_position + 1]
    position_ids = get_position_ids_from_attention_mask(questions_mask)
    latent_out = self.latent_forward_with_cache(
        input_ids=questions,
        attention_mask=questions_mask,
        position_ids=position_ids,
    )
    attention_mask = latent_out.attention_mask
    prompt_embeds = latent_out.inputs_embeds
    latent_inputs_embeds = latent_out.latent_inputs_embeds
    latent_attention_mask = attention_mask[:, bol_position + 1:eol_position]
    # One full forward over the latent-filled prompt for the first logits.
    position_ids = get_position_ids_from_attention_mask(attention_mask)
    outputs = self.base_causallm(
        inputs_embeds=prompt_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        use_cache=False,
    )
    logits = outputs.logits[:, -1, :]
    # Seed each row with the last prompt token so outputs stay aligned.
    tokens = [[questions[i, -1].item()] for i in range(batch_size)]
    unfinished = torch.ones(batch_size, dtype=torch.bool, device=device)
    cur_inputs_embeds = prompt_embeds
    cur_attention_mask = attention_mask
    for _ in range(max_new_tokens):
        if temperature != 1.0:
            logits = logits / temperature
        if self.is_training:
            probs = torch.softmax(logits, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
        else:
            next_tokens = torch.argmax(logits, dim=-1)
        # Finished rows keep emitting EOS.
        next_tokens = torch.where(
            unfinished, next_tokens, torch.full_like(next_tokens, self.eos_token_id)
        )
        for i, tok in enumerate(next_tokens.tolist()):
            tokens[i].append(tok)
        next_token_embeds = self.embedding(next_tokens).unsqueeze(1)
        cur_inputs_embeds = torch.cat([cur_inputs_embeds, next_token_embeds], dim=1)
        cur_attention_mask = torch.cat(
            [cur_attention_mask, torch.ones(batch_size, 1, dtype=cur_attention_mask.dtype, device=device)],
            dim=1,
        )
        position_ids = get_position_ids_from_attention_mask(cur_attention_mask)
        # NOTE: full re-forward per step (no KV cache) — correct but O(L^2).
        outputs = self.base_causallm(
            inputs_embeds=cur_inputs_embeds,
            attention_mask=cur_attention_mask,
            position_ids=position_ids,
            use_cache=False,
        )
        self.gen_forward_cnt += 1
        logits = outputs.logits[:, -1, :]
        unfinished &= next_tokens.ne(self.eos_token_id)
        if not unfinished.any():
            break
    if synced_gpus:
        # BUG FIX: the original referenced undefined `cur_pos` / `cur_kv_cache`
        # here, raising NameError whenever synced_gpus=True. Run cheap dummy
        # forwards over the last position only; the results are discarded —
        # this exists purely to equalize forward counts across ranks.
        while self.gen_forward_cnt < max_new_tokens + MAX_N_LATENT:
            self.gen_forward_cnt += 1
            _ = self.base_causallm(
                inputs_embeds=cur_inputs_embeds[:, -1:, :],
                attention_mask=cur_attention_mask[:, -1:],
                use_cache=False,
            )
    # Right-pad the variable-length token sequences.
    max_len = max(len(seq) for seq in tokens)
    pred_ids = torch.full(
        (batch_size, max_len),
        fill_value=self.pad_token_id,
        dtype=torch.long,
        device=device,
    )
    for i, seq in enumerate(tokens):
        pred_ids[i, : len(seq)] = torch.tensor(seq, device=device)
    return GenerationOutputs(
        question_input_ids=questions[:, :bol_position + 1],
        question_attention_mask=questions_mask[:, :bol_position + 1],
        latent_inputs_embeds=latent_inputs_embeds,
        latent_attention_mask=latent_attention_mask,
        pred_ids=pred_ids,
    )
def rl_step(self, batch, optimizer, writer=None, global_step=None):
    """Run one GRPO update: roll out a group of trajectories, fill the replay
    buffer, then iterate experience minibatches with gradient accumulation.

    Args:
        batch: dict with "input_ids", "attention_mask", "labels" and
            "explainable_ids_list".
        optimizer: optimizer over this module's trainable parameters.
        writer: unused here; kept for caller compatibility.
        global_step: unused here; kept for caller compatibility.

    Returns:
        dict of averaged training metrics for logging.
    """
    rl_config = self.rl_config
    device = self.device
    questions = batch["input_ids"]
    questions_mask = batch["attention_mask"]
    answers = batch["labels"]
    explainable_ids_list = batch["explainable_ids_list"]
    self.replay_buffer.clear()
    # Collect on-policy experience (rollout itself runs under no_grad).
    experience = self.rollout(
        questions=questions,
        questions_mask=questions_mask,
        group_size=rl_config.group_size,
        gt_answers=answers,
        explainable_ids_list=explainable_ids_list,
    )
    self.replay_buffer.append(experience.to("cpu"))
    log_dict = {
        "train/rewards": experience.rewards.mean().item(),
        "train/accuracies": experience.accuracies.mean().item(),
        "train/n_latent_forward": experience.n_latent_forward.float().mean().item(),
    }
    if getattr(experience, "explain_logprob", None) is not None:
        log_dict["train/explain_logprob"] = experience.explain_logprob.mean().item()
    torch.cuda.empty_cache()
    dataloader = DataLoader(
        dataset=self.replay_buffer,
        batch_size=rl_config.exp_batch_size,
        shuffle=True,
        collate_fn=grpo.join_experience_batch,
    )
    gradient_accumulation_steps = max(1, rl_config.gradient_accumulation_steps)
    total_inner_steps = len(dataloader)
    loop_sums = defaultdict(float)
    inner_steps = 0
    accum_counter = 0
    optimizer.zero_grad()
    last_grad_norm = 0.0
    for exp in dataloader:
        inner_steps += 1
        accum_counter += 1
        exp: grpo.Experience = exp.to(device)
        # Recompute logprobs under the current policy for the GRPO ratio.
        latent_logprobs, answer_logprobs = self.get_logprobs(e=exp)
        loss_dict = self.grpo_loss(
            latent_logprobs=latent_logprobs,
            answer_logprobs=answer_logprobs,
            experience=exp,
        )
        raw_total_loss = loss_dict["total_loss"]
        # Scale so accumulated gradients average over the accumulation window.
        (raw_total_loss / gradient_accumulation_steps).backward()
        for k, v in loss_dict.items():
            value = v.item() if torch.is_tensor(v) else float(v)
            loop_sums[f"train/{k}"] += value
        # Step on every full window and on the final (possibly partial) one.
        need_update = (accum_counter % gradient_accumulation_steps == 0) or (inner_steps == total_inner_steps)
        if need_update:
            grad_norm = clip_grad_norm_(self.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
            last_grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else float(grad_norm)
    if inner_steps > 0:
        for key in list(loop_sums.keys()):
            log_dict[key] = loop_sums[key] / inner_steps
        log_dict["train/grad_norm"] = last_grad_norm
    return log_dict
def extract_equal_result(self, expression):
match = re.search(r'=(.*?)>>', expression)
if match:
return match.group(1).strip()
else:
return ''
@torch.no_grad()
def compute_explain_logprob(
    self,
    latent_inputs_embeds: torch.Tensor,  # (B, L_latent, H)
    n_latent_forward: torch.Tensor,  # (B,)
    pred_strings,
    max_gen_len=16
):
    """Score latent thoughts with the frozen "explainable" LM.

    Each completed thought (a c_thought-sized slice of latent embeddings) is
    fed to the frozen LM as an embedding prefix and greedily decoded for up
    to `max_gen_len` tokens. Two per-sample scores are produced:
      * mean decode logprob over the sample's prefixes, and
      * a coherence score in [0, 1]: fraction of steps whose "=...>>" result
        reappears in a later step's text (last step checked against the
        final predicted answer string).

    Returns:
        (explain_logprob_b, explain_coherence_b), both (B,), or None when no
        explainable LM is configured.
    """
    if self.expainable_llm is None:
        return None
    device = self.expainable_llm.device
    eos_token_id = self.eos_token_id
    thought = self.rl_config.c_thought
    max_stage = self.rl_config.max_latent_stage
    tok = self.tokenizer  # (or the explainable LLM's own tokenizer)
    B, _, hidden = latent_inputs_embeds.size()
    # 1) Collect every usable latent prefix (one per completed thought stage).
    prefix_list, owner = [], []
    for b in range(B):
        latent_valid = int(n_latent_forward[b].item())
        stage_num = min(max_stage, latent_valid // thought)
        for s in range(stage_num):
            start = s * thought
            end = start + thought
            prefix = latent_inputs_embeds[b, start:end, :].to(device)
            prefix_list.append(prefix)
            owner.append(b)
    N = len(prefix_list)
    if N == 0:
        # No completed thoughts anywhere: neutral logprob / full coherence.
        return torch.zeros(B, device=device), torch.ones(B, device=device)
    seq_embeds = torch.stack(prefix_list, dim=0).clone()  # (N, thought, H)
    attention_mask = torch.ones(N, thought, dtype=torch.long, device=device)
    position_ids = torch.arange(thought, dtype=torch.long, device=device).unsqueeze(0).expand(N, -1)
    total_logprob = torch.zeros(N, device=device)
    total_tokens = torch.zeros(N, device=device)
    gen_token_ids = [[] for _ in range(N)]
    active_idx = torch.arange(N, device=device)
    # embed_layer = self.expainable_llm.get_input_embeddings()
    # 2) Batched greedy decoding; rows that emit EOS are dropped each step.
    for _ in range(max_gen_len):
        outputs = self.expainable_llm(
            inputs_embeds=seq_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        logits = outputs.logits[:, -1, :]
        log_probs = torch.log_softmax(logits, dim=-1)
        next_token = torch.argmax(logits, dim=-1)
        step_logprob = log_probs.gather(-1, next_token.unsqueeze(-1)).squeeze(-1)
        total_logprob[active_idx] += step_logprob
        total_tokens[active_idx] += 1
        for aidx, tok_id in zip(active_idx.tolist(), next_token.tolist()):
            gen_token_ids[aidx].append(tok_id)
        eos_mask = next_token.eq(eos_token_id)
        keep_mask = ~eos_mask
        if keep_mask.any():
            kept_idx = active_idx[keep_mask]
            # NOTE(review): embeddings come from the policy LM's table, which
            # assumes both LMs share an embedding space — confirm.
            next_embed = self.embedding(next_token[keep_mask]).unsqueeze(1)
            seq_embeds = torch.cat([seq_embeds[keep_mask], next_embed], dim=1)
            attention_mask = torch.cat(
                [attention_mask[keep_mask],
                 torch.ones(len(kept_idx), 1, dtype=torch.long, device=device)],
                dim=1
            )
            next_pos = position_ids[keep_mask][:, -1:] + 1
            position_ids = torch.cat([position_ids[keep_mask], next_pos], dim=1)
            active_idx = kept_idx
        else:
            break
    # 3) Aggregate per-prefix statistics back to the batch dimension.
    explain_logprob_b = torch.zeros(B, device=device)
    explain_coherence_b = torch.ones(B, device=device)
    explain_count_b = torch.zeros(B, dtype=torch.long, device=device)
    owner_t = torch.tensor(owner, dtype=torch.long, device=device)
    valid_mask = total_tokens > 0
    avg_logprob = torch.zeros_like(total_logprob)
    # NOTE(review): this keeps the summed logprob (not a per-token mean)
    # despite the variable name — confirm that is intended.
    avg_logprob[valid_mask] = total_logprob[valid_mask]
    for idx in range(N):
        b = owner_t[idx].item()
        explain_logprob_b[b] += avg_logprob[idx]
        explain_count_b[b] += 1
    explain_logprob_b = torch.where(
        explain_count_b > 0,
        explain_logprob_b / explain_count_b.clamp_min(1),
        torch.zeros_like(explain_logprob_b)
    )
    # 4) Decode each generation and score step-to-step coherence.
    preview_by_batch = [[] for _ in range(B)]
    result_by_batch = [[] for _ in range(B)]
    for idx, tlist in enumerate(gen_token_ids):
        txt = tok.decode(tlist, skip_special_tokens=True) if tlist else ""
        b = owner_t[idx].item()
        result_of_eq = self.extract_equal_result(txt)
        preview_by_batch[b].append(txt)
        result_by_batch[b].append(result_of_eq)
    for b in range(B):
        total_steps = len(result_by_batch[b])  # e.g. ["4", "3", "3"]
        if total_steps == 0:
            explain_coherence_b[b] = 1.0
            continue
        matched = 0
        # A step is "coherent" when its equation result shows up in any
        # later step's decoded text.
        for i in range(len(result_by_batch[b]) - 1):
            for j in range(i + 1, len(preview_by_batch[b])):
                if result_by_batch[b][i] in preview_by_batch[b][j]:
                    matched += 1
                    break
        # The final step must surface in the final predicted answer text.
        last_result = result_by_batch[b][-1]
        if last_result in pred_strings[b]:
            matched += 1
        explain_coherence_b[b] = matched / total_steps
    # (debug: per-sample previews/results can be printed here if needed)
    # Return the per-sample logprob and coherence scores.
    return explain_logprob_b, explain_coherence_b
@torch.no_grad()
def rollout(self, questions, questions_mask, gt_answers, group_size, explainable_ids_list=None):
    """Generate `group_size` trajectories per question and package them as a
    grpo.Experience with rewards, accuracies and group-relative advantages.

    Args:
        questions: (B, L) prompt ids with latent placeholders.
        questions_mask: (B, L) attention mask.
        gt_answers: per-question label id tensors (may contain -100 padding).
        group_size: rollouts per question (GRPO group).
        explainable_ids_list: optional per-question explanation ids; only its
            presence gates the explanation reward here.

    Returns:
        grpo.Experience for the whole (B * group_size) batch, with cached
        behavior-policy logprobs attached.
    """
    rl_cfg = self.rl_config
    group_size = group_size
    # Expand every question into `group_size` independent rollouts.
    group_questions = questions.repeat_interleave(group_size, dim=0)
    group_masks = questions_mask.repeat_interleave(group_size, dim=0)
    if explainable_ids_list is not None:
        group_explain_ids = [copy.deepcopy(ids) for ids in explainable_ids_list for _ in range(group_size)]
    else:
        group_explain_ids = None
    gen_out = self.generate_with_latent(
        questions=group_questions.to(self.device),
        questions_mask=group_masks.to(self.device),
        max_new_tokens=rl_cfg.max_new_tokens,
        temperature=rl_cfg.temperature,
    )
    # print(gen_out.pred_ids.tolist())
    pred_strings = self.tokenizer.batch_decode(gen_out.pred_ids, skip_special_tokens=True)
    # Number of latent steps actually used per rollout.
    n_latent_forward = gen_out.latent_attention_mask.sum(dim=1)
    explain_logprob, explain_coherence = None, None
    # Optional explanation reward from the frozen scorer LM.
    if self.expainable_llm is not None and group_explain_ids and rl_cfg.reason_reward_weight != 0:
        explain_logprob, explain_coherence = self.compute_explain_logprob(
            latent_inputs_embeds=gen_out.latent_inputs_embeds,
            n_latent_forward=n_latent_forward,
            pred_strings=pred_strings,
        )
    # print(explain_logprob.size())
    # Per-question rewards and group-normalized advantages.
    rewards_list, acc_list, adv_list = [], [], []
    for i in range(len(questions)):
        start = i * group_size
        end = (i + 1) * group_size
        rewards, accuracies = self.get_group_rewards_and_acc(
            pred_answers=pred_strings[start:end],
            gt_answer=gt_answers[i],
            explain_coherence=None if explain_coherence is None else explain_coherence[start:end],
            n_latent_forward=n_latent_forward[start:end],
            explain_logprob_chunk=None if explain_logprob is None else explain_logprob[start:end],
        )
        rewards_list.append(rewards)
        acc_list.append(accuracies)
        adv_list.append(grpo.group_advantages(rewards))
    rewards = torch.cat(rewards_list, dim=0)
    accuracies = torch.cat(acc_list, dim=0)
    advantages = torch.cat(adv_list, dim=0)
    experience = grpo.Experience(
        question_input_ids=gen_out.question_input_ids,
        question_attention_mask=gen_out.question_attention_mask,
        latent_inputs_embeds=gen_out.latent_inputs_embeds,
        latent_attention_mask=gen_out.latent_attention_mask,
        answer_input_ids=gen_out.pred_ids,
        answer_attention_mask=gen_out.pred_ids.ne(self.pad_token_id).long(),
        n_latent_forward=n_latent_forward.unsqueeze(1),
        rewards=rewards,
        accuracies=accuracies,
        advantages=advantages,
    )
    if explain_logprob is not None:
        experience.explain_logprob = explain_logprob
    # Cache behavior-policy logprobs for the GRPO ratio at update time.
    latent_logprobs, answer_logprobs = self.get_logprobs(experience)
    experience.latent_logprobs = latent_logprobs
    experience.answer_logprobs = answer_logprobs
    return experience
def get_logprobs(self, e: grpo.Experience):
    """Recompute log-probabilities of an experience under the current policy.

    Returns:
        latent_logprobs: (B, L_latent) Gaussian logprob of each latent sample
            plus the stop-head "continue" logprob at that step.
        answer_logprobs: (B, L_answer) token logprobs of the generated
            answer; position 0 is overwritten with the stop-head "stop"
            logprob taken at the step where latent thinking ended.
    """
    question_length = e.question_input_ids.shape[1]
    latent_length = e.latent_inputs_embeds.shape[1]
    answer_length = e.answer_input_ids.shape[1]
    question_inputs_embeds = self.embedding(e.question_input_ids)
    answer_inputs_embeds = self.embedding(e.answer_input_ids)
    # One forward over [question | latents | answer] concatenated.
    all_inputs_embeds = torch.cat(
        [question_inputs_embeds, e.latent_inputs_embeds, answer_inputs_embeds], dim=1
    )
    all_attention_mask = torch.cat(
        [e.question_attention_mask, e.latent_attention_mask, e.answer_attention_mask], dim=1
    )
    all_position_ids = get_position_ids_from_attention_mask(all_attention_mask)
    all_outputs = self.base_causallm.forward(
        inputs_embeds=all_inputs_embeds,
        attention_mask=all_attention_mask,
        position_ids=all_position_ids,
        output_hidden_states=True,
    )
    # ===== Latent log_prob under the Gaussian head =====
    hidden = all_outputs.hidden_states[-1]
    # Hidden state preceding each latent slot parameterizes its Gaussian.
    vec = hidden[:, question_length - 1: question_length + latent_length - 1, :]
    mu, logvar = self.latent_head.forward(vec)  # [dim*2]
    std = torch.exp(0.5 * logvar)
    dist = torch.distributions.Normal(mu, std)
    latent_samples = e.latent_inputs_embeds[:, :latent_length, :]
    pred_token_id = self.end_head(latent_samples)
    # Stopping is only allowed on thought boundaries (c_thought multiples);
    # forbid the "stop" class elsewhere.
    for pass_idx in range(e.latent_attention_mask.size(1)):
        if pass_idx % self.rl_config.c_thought:
            pred_token_id[:, pass_idx, 1] = float('-inf')
    pred_logprob = torch.log_softmax(pred_token_id, dim=-1)
    # Index of the first masked latent position, i.e. where the policy stopped.
    eos_place = e.latent_attention_mask.sum(-1)[:, None]
    # Append a zero column so eos_place == latent_length indexes safely
    # (row used the full latent budget and never emitted "stop").
    log_prob_fill = torch.cat((pred_logprob[:, :, 1], torch.zeros_like(pred_logprob[:, 0:1, 1])), dim=1)
    latent_logprobs = dist.log_prob(latent_samples).sum(dim=-1) + pred_logprob[:, :, 0]
    logits = all_outputs.logits
    answer_logits = logits[:, question_length + latent_length - 1: -1, :]
    answer_logprobs = F.log_softmax(answer_logits, dim=-1)
    answer_logprobs = answer_logprobs.gather(dim=-1, index=e.answer_input_ids.unsqueeze(-1)).squeeze(-1)
    # First answer slot carries the stop decision's logprob instead.
    answer_logprobs[:, 0] = log_prob_fill.gather(dim=1, index=eos_place).squeeze(-1)
    return latent_logprobs, answer_logprobs
def get_group_rewards_and_acc(
    self,
    pred_answers,
    gt_answer,
    explain_coherence,
    n_latent_forward,
    explain_logprob_chunk=None,
):
    """Compute rewards and accuracies for one group of rollouts.

    Args:
        pred_answers: list[str] decoded predictions for this group.
        gt_answer: 1-D label id tensor (may contain -100 padding); -100 is
            mapped to the pad id in place, as before.
        explain_coherence: optional (G,) coherence scores, or None.
        n_latent_forward: (G,) latent step counts (unused; kept for API).
        explain_logprob_chunk: optional (G,) explanation logprobs (log only).

    Returns:
        (rewards, accuracies), both (G, 1) tensors on self.device.
    """
    rl_config = self.rl_config
    group_size = len(pred_answers)
    # `format_flags` renamed from `format` (shadowed the builtin).
    format_flags = torch.zeros(group_size, 1, device=self.device)
    accuracies = torch.zeros(group_size, 1, device=self.device)
    # FIX: decode the ground truth once — previously the -100 replacement and
    # decode were redundantly re-run for every prediction in the group.
    gt_answer[gt_answer == -100] = self.pad_token_id
    gt_string = self.tokenizer.decode(gt_answer, skip_special_tokens=True)
    gt_answer_string = gt_string[gt_string.rfind("###") + len("###"):]
    marker = "<|end-latent|>###"
    for i, ans in enumerate(pred_answers):
        idx = ans.rfind(marker)
        if idx == -1:
            pred_string = ""
        else:
            pred_string = ans[idx + len(marker):]
        # A non-empty answer after the marker counts as well-formatted.
        if len(pred_string.strip()) != 0:
            format_flags[i] = 1
        accuracies[i] = self.verify_answer(gt_answer=gt_answer_string, pred_answer=pred_string)
        print(
            f"Ground_truth: {gt_answer_string} Answer {ans} Accuracy: {accuracies[i]} Format: {format_flags[i]}")
        if explain_logprob_chunk is not None:
            print(f"Explain_prob: {explain_logprob_chunk[i]}")
        if explain_coherence is not None:
            print(f"Explain_coherence: {explain_coherence[i]}")
    rewards = rl_config.accuracy_reward_weight * accuracies.detach().clone() + rl_config.format_reward_weight * format_flags.detach().clone()
    if explain_coherence is not None:
        if rl_config.accuracy_reward_weight == 0:
            # Pure coherence shaping when the accuracy reward is disabled.
            rewards = rewards + rl_config.reason_reward_weight * explain_coherence[:, None]
        else:
            # Otherwise the coherence bonus is gated by correctness.
            rewards = rewards + rl_config.reason_reward_weight * explain_coherence[:, None] * (accuracies.detach().clone())
    return rewards, accuracies
def verify_answer(self, gt_answer: str, pred_answer: str) -> float:
# gt_answer = gt_answer.strip("\n ").rstrip(".").replace(",", "")
# pred_answer = pred_answer.strip("\n ").rstrip(".").rstrip(".\n").replace(",", "")
try: # some answers may be like '10.0' but predicted as '10'
gt_answer = float(gt_answer)
pred_answer = float(pred_answer)
return float(gt_answer == pred_answer)
except ValueError:
return float(self.is_equiv(gt_answer, pred_answer))
@torch.no_grad()
def eval_generation(self, dataloader, val='val'):
    """Evaluate greedily over a dataloader (group_size=1).

    Temporarily flips self.is_training to False so stop/token decisions are
    deterministic, and restores it afterwards. Returns averaged metrics
    keyed with the `val` prefix.
    """
    loop_sums = defaultdict(float)
    inner_steps = 0
    self.is_training = False
    for step, batch in enumerate(dataloader):
        inner_steps += 1
        questions = batch["input_ids"]
        questions_mask = batch["attention_mask"]
        answers = batch["labels"]
        explainable_ids_list = batch["explainable_ids_list"]
        experience = self.rollout(
            questions=questions,
            questions_mask=questions_mask,
            gt_answers=answers,
            group_size=1,
            explainable_ids_list=explainable_ids_list,
        )
        loop_sums[f"{val}/rewards"] += experience.rewards.mean().item()
        loop_sums[f"{val}/accuracies"] += experience.accuracies.mean().item()
        loop_sums[f"{val}/n_latent_forward"] += experience.n_latent_forward.float().mean().item()
        if getattr(experience, "explain_logprob", None) is not None:
            loop_sums[f"{val}/explain_logprob"] += experience.explain_logprob.mean().item()
        torch.cuda.empty_cache()
    print(f"Accuracy in Validation {loop_sums[f'{val}/accuracies'] * 100 / inner_steps}")
    log_dict = {}
    if inner_steps > 0:
        for key in loop_sums:
            log_dict[key] = loop_sums[key] / inner_steps
    self.is_training = True
    return log_dict
# -- evaluation ends --#
def is_equiv(self, str1, str2, verbose=False):
if str1 is None and str2 is None:
print("WARNING: Both None")
return True
if str1 is None or str2 is None:
return False
try:
ss1 = self.strip_string(str1)
ss2 = self.strip_string(str2)
if verbose:
print(ss1, ss2)
return ss1 == ss2
except Exception:
return str1 == str2
def remove_boxed(self, s):
if "\\boxed " in s:
left = "\\boxed "
assert s[:len(left)] == left
return s[len(left):]
left = "\\boxed{"
assert s[:len(left)] == left
assert s[-1] == "}"
return s[len(left):-1]
def last_boxed_only_string(self, string):
idx = string.rfind("\\boxed")
if "\\boxed " in string:
return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
if idx < 0:
idx = string.rfind("\\fbox")
if idx < 0:
return None
i = idx
right_brace_idx = None
num_left_braces_open = 0
while i < len(string):
if string[i] == "{":
num_left_braces_open += 1
if string[i] == "}":
num_left_braces_open -= 1
if num_left_braces_open == 0:
right_brace_idx = i
break
i += 1
if right_brace_idx is None:
retval = None
else:
retval = string[idx:right_brace_idx + 1]
return retval
def fix_fracs(self, string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except AssertionError:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def fix_a_slash_b(self, string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
a = int(a)
b = int(b)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except AssertionError:
return string
def remove_right_units(self, string):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if "\\text{ " in string:
splits = string.split("\\text{ ")
assert len(splits) == 2
return splits[0]
else:
return string
def fix_sqrt(self, string):
if "\\sqrt" not in string:
return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if split[0] != "{":
a = split[0]
new_substr = "\\sqrt{" + a + "}" + split[1:]
else:
new_substr = "\\sqrt" + split
new_string += new_substr
return new_string
def strip_string(self, string):
    """Normalize a LaTeX answer string for equivalence comparison
    (MATH-benchmark style): strip formatting noise, units, spacing, and
    canonicalize fractions/square roots. Order of the replacements matters."""
    # linebreaks
    string = string.replace("\n", "")
    # remove inverse spaces
    string = string.replace("\\!", "")
    # replace \\ with \
    string = string.replace("\\\\", "\\")
    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")
    # remove dollar signs
    string = string.replace("\\$", "")
    # remove units (on the right)
    string = self.remove_right_units(string)
    # remove percentage
    string = string.replace("\\%", "")
    string = string.replace("\%", "")  # noqa: W605
    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively,
    # add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string
    # get rid of e.g. "k = " or "q = " at the beginning (short LHS only)
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]
    # fix sqrt3 --> sqrt{3}
    string = self.fix_sqrt(string)
    # remove spaces
    string = string.replace(" ", "")
    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc.
    # Even works with \frac1{72} (but not \frac{72}1).
    string = self.fix_fracs(string)
    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"
    # X/Y was changed to \frac{X}{Y} in the dataset; fix simple model
    # outputs of the form X/Y the same way
    string = self.fix_a_slash_b(string)
    return string