| | |
| | |
| | import math |
| | import re |
| |
|
| | from torch.distributions import Normal |
| | from collections import defaultdict |
| | import torch |
| | from torch.nn.utils import clip_grad_norm_ |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.nn import CrossEntropyLoss |
| | from torch.utils.data import DataLoader |
| | from collections import namedtuple |
| | from transformers.models.gpt2 import GPT2LMHeadModel |
| |
|
| | from modules import grpo |
| | from modules.projector import LatentPolicy |
| | from modules.utils import get_position_ids_from_attention_mask |
| | import copy |
| |
|
# Lightweight containers for the different forward-pass result shapes.
Outputs = namedtuple("Outputs", ["loss", "loss_explain_all", "inputs_embeds", "logits"])
# BUG FIX: the typename previously said "Outputs", so reprs/pickling reported the
# wrong class name; keep the typename in sync with the variable.
Outputs_withkl = namedtuple("Outputs_withkl", ["loss", "loss_explain_all", "loss_kl", "inputs_embeds", "logits"])
Outputs_withmask = namedtuple("Outputs_withmask",
                              ["loss", "attention_mask", "loss_explain_all", "inputs_embeds", "logits"])
# Upper bound on latent forward passes, used when padding forward counts for synced GPUs.
MAX_N_LATENT = 8
| |
|
| | from dataclasses import dataclass |
| | from typing import Optional |
| | from modules.projector import STOPPolicy |
| |
|
| |
|
@dataclass
class LatentForwardResult:
    """Output bundle of `latent_forward_with_cache`."""

    # Full prompt embeddings with the latent placeholder slots overwritten by
    # sampled latent vectors.
    inputs_embeds: torch.Tensor
    # Attention mask, possibly with trailing latent positions zeroed for
    # samples that chose to stop early.
    attention_mask: torch.Tensor
    # Right-padded stack of the sampled latent vectors,
    # shape (batch, max_n_latents, hidden).
    latent_inputs_embeds: torch.Tensor
| |
|
| |
|
@dataclass
class GenerationOutputs:
    """Output bundle of `generate_with_latent`."""

    # Question tokens up to and including the begin-of-latent marker.
    question_input_ids: torch.Tensor
    question_attention_mask: torch.Tensor
    # Sampled latent embeddings and the (possibly pruned) mask over the
    # latent span.
    latent_inputs_embeds: torch.Tensor
    latent_attention_mask: torch.Tensor
    # Generated answer token ids, right-padded with the pad token id.
    pred_ids: torch.Tensor
| |
|
| |
|
| | class Coconut_RL_End_VAE(nn.Module): |
    def __init__(
        self,
        base_causallm,
        expainable_llm,
        latent_token_id,
        start_latent_id,
        end_latent_id,
        eos_token_id,
        tokenizer=None,
        rl_config=None,
        is_training=True,
    ):
        """Coconut-style latent-reasoning policy trained with GRPO.

        Args:
            base_causallm: causal LM that consumes mixed token/latent embeddings.
            expainable_llm: frozen LM used to decode/score natural-language
                explanations of the latent thoughts (may be None to disable).
            latent_token_id: id of the placeholder token replaced by sampled latents.
            start_latent_id: id of the begin-of-latent marker token.
            end_latent_id: id of the end-of-latent marker token.
            eos_token_id: EOS id; also reused as the padding id below.
            tokenizer: tokenizer used to decode rollouts for reward computation.
            rl_config: RL hyperparameter object (group_size, c_thought, reward
                weights, ...). NOTE(review): although it defaults to None,
                `rl_config.latent_sigma` is read unconditionally below — confirm
                callers always pass one.
            is_training: when True, generation samples stochastically; when
                False decoding is greedy.
        """
        super(Coconut_RL_End_VAE, self).__init__()
        self.gen_forward_cnt = 0  # counts model forward calls (used for synced-GPU padding)
        self.base_causallm = base_causallm
        self.latent_token_id = latent_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = eos_token_id  # padding reuses the EOS id
        self.start_latent_id = start_latent_id
        self.end_latent_id = end_latent_id
        self.expainable_llm = expainable_llm
        self.base_causallm.config.use_cache = True
        # NOTE(review): this line dereferences expainable_llm before the None
        # check further down — passing expainable_llm=None would crash here.
        self.expainable_llm.config.use_cache = True
        # Binary stop/continue head over latent states (2 logits; see is_end logic).
        self.end_head = STOPPolicy(
            feature_size=self.base_causallm.config.hidden_size,
            intermediate_size=self.base_causallm.config.hidden_size
        )
        # Gaussian latent policy head producing (mu, logvar) for reparameterized sampling.
        self.latent_head = LatentPolicy(
            feature_size=self.base_causallm.config.hidden_size,
            intermediate_size=self.base_causallm.config.hidden_size,
            deterministic=False
        )
        self.tokenizer = tokenizer
        self.rl_config = rl_config
        self.latent_sigma = rl_config.latent_sigma
        self.rl_step_counter = 0
        self.is_training = is_training
        # The explanation model is only used for scoring, never trained.
        if self.expainable_llm is not None:
            for param in self.expainable_llm.parameters():
                param.requires_grad = False

        # GPT-2 exposes the embedding table under .transformer.
        if isinstance(self.base_causallm, GPT2LMHeadModel):
            self.embedding = self.base_causallm.transformer.get_input_embeddings()
        else:
            self.embedding = self.base_causallm.get_input_embeddings()

        self.replay_buffer = grpo.ReplayBuffer()
        self.grpo_loss = grpo.GRPOLoss(rl_config=self.rl_config)
| |
|
| | @property |
| | def device(self): |
| | return next(self.parameters()).device |
| |
|
    def latent_forward_with_cache(self, input_ids, attention_mask, position_ids=None):
        """Sample every latent placeholder in the batch, one pass per slot,
        reusing the KV cache between passes, and return the updated embeddings
        together with the sampled latent vectors.

        (Translated from the original Chinese docstring: "Only responsible for
        parallel sampling of all latent tokens in the batch, reusing the KV
        cache, and returning latent embeddings/log-probs etc.")

        NOTE(review): `attention_mask` is modified in place (latent positions
        of early-stopped samples are zeroed) — callers see the mutation.
        """
        bs, seq_len = input_ids.size()
        device = input_ids.device
        latent_mask = (input_ids == self.latent_token_id)
        latent_positions = latent_mask.nonzero(as_tuple=False)
        # Per-sample column indices of the latent placeholders.
        latent_lists = [
            [idx[1].item() for idx in latent_positions if idx[0] == b]
            for b in range(bs)
        ]
        # Begin/end-of-latent markers, taken from sample 0 — assumes every
        # sample in the batch shares the same latent layout. TODO confirm.
        bol_position = latent_lists[0][0] - 1
        eol_position = latent_lists[0][-1] + 1
        max_n_latents = max((len(lst) for lst in latent_lists), default=0)

        next_compute_range = (0, seq_len)
        inputs_embeds = self.embedding(input_ids)

        # First pass only needs to run up to the earliest latent position.
        if max_n_latents > 0:
            next_compute_range = (0, latent_positions[:, 1].min().item())

        kv_cache = None
        latent_samples_all = [[] for _ in range(bs)]

        # One pass per latent slot; each pass extends the computed prefix by
        # one position (the final pass covers the remainder of the sequence).
        for pass_idx in range(max_n_latents if max_n_latents > 0 else 1):
            slice_start, slice_end = next_compute_range

            if kv_cache is None:
                outputs = self.base_causallm(
                    inputs_embeds=inputs_embeds[:, slice_start:slice_end, :],
                    attention_mask=attention_mask[:, slice_start:slice_end],
                    position_ids=(
                        position_ids[:, slice_start:slice_end]
                        if position_ids is not None else None
                    ),
                    output_hidden_states=True,
                )
                hidden_offset = 0
            else:
                # Reuse cached keys/values for the already-computed prefix.
                past_key_values = [
                    (k[:, :, :slice_start, :], v[:, :, :slice_start, :])
                    for k, v in kv_cache
                ]
                outputs = self.base_causallm(
                    inputs_embeds=inputs_embeds[:, slice_start:slice_end, :],
                    attention_mask=attention_mask[:, :slice_end],
                    position_ids=(
                        position_ids[:, slice_start:slice_end]
                        if position_ids is not None else None
                    ),
                    past_key_values=past_key_values,
                    output_hidden_states=True,
                )
                # With a cache, hidden_states only cover the new slice.
                hidden_offset = slice_start

            kv_cache = outputs.past_key_values
            hidden_states = outputs.hidden_states[-1]

            next_compute_range = (
                slice_end,
                seq_len if pass_idx + 1 >= max_n_latents else slice_end + 1,
            )

            # Rebuild as per-position lists so individual slots can be replaced.
            tensor_list = [
                [inputs_embeds[b, pos, :] for pos in range(seq_len)]
                for b in range(bs)
            ]

            # NOTE: the loop variable deliberately shadows the earlier
            # `latent_positions` tensor with the per-sample index list.
            for b, latent_positions in enumerate(latent_lists):
                if pass_idx >= len(latent_positions):
                    continue
                token_idx = latent_positions[pass_idx]
                # Hidden state of the position *preceding* the latent slot.
                vec = hidden_states[b, token_idx - 1 - hidden_offset, :]

                mu, logvar = self.latent_head.forward(vec)

                # Reparameterized Gaussian sample: z = mu + sigma * eps.
                std = torch.exp(0.5 * logvar)
                eps = torch.randn_like(std)
                z = mu + std * eps

                tensor_list[b][token_idx] = z
                latent_samples_all[b].append(z)

            inputs_embeds = torch.stack([torch.stack(tensor_list[b]) for b in range(bs)], dim=0)

            # Stop/continue decision from the embedding of the next slot.
            pred_token_id = self.end_head(inputs_embeds[:, next_compute_range[0], :])

            # Training: sample the stop decision; eval: greedy (logit 1 = stop).
            if self.is_training == True:
                is_end = torch.multinomial(torch.softmax(pred_token_id, dim=-1), num_samples=1).squeeze(
                    -1).bool()
            else:
                is_end = (pred_token_id[:, 1] > pred_token_id[:, 0])

            # Only allow stopping at thought boundaries; zero out the remaining
            # latent span for samples that stopped (in-place mask mutation).
            if pass_idx % self.rl_config.c_thought == 0:
                attention_mask[is_end, bol_position + 1 + pass_idx:eol_position] = 0
            # Only reachable when max_n_latents == 0 (the placeholder pass).
            if pass_idx >= max_n_latents:
                break

        self.gen_forward_cnt += max_n_latents + 1

        # Right-pad the per-sample latent samples into one dense tensor.
        latent_dim = inputs_embeds.size(-1)
        latent_lengths = torch.tensor([len(lst) for lst in latent_lists],
                                      device=device, dtype=torch.long)
        max_latent_len = latent_lengths.max().item() if latent_lengths.numel() > 0 else 0

        if max_latent_len > 0:
            latent_inputs_embeds = torch.zeros(bs, max_latent_len, latent_dim, device=device)
            for b in range(bs):
                if latent_samples_all[b]:
                    latent_inputs_embeds[b, :len(latent_samples_all[b])] = torch.stack(latent_samples_all[b])
        else:
            latent_inputs_embeds = torch.zeros(bs, 0, latent_dim, device=device)

        return LatentForwardResult(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            latent_inputs_embeds=latent_inputs_embeds,
        )
| |
|
| | @torch.no_grad() |
| | def generate_with_latent( |
| | self, |
| | questions, |
| | questions_mask, |
| | max_new_tokens=16, |
| | temperature=1.0, |
| | synced_gpus=False, |
| | ): |
| | self.eval() |
| | device = questions.device |
| | batch_size = questions.size(0) |
| |
|
| | latent_mask = (questions == self.latent_token_id) |
| | latent_positions = latent_mask.nonzero(as_tuple=False) |
| | latent_lists = [ |
| | [idx[1].item() for idx in latent_positions if idx[0] == b] |
| | for b in range(batch_size) |
| | ] |
| | bol_position = latent_lists[0][0] - 1 |
| | eol_position = latent_lists[0][-1] + 1 |
| | questions = questions[:, :eol_position + 1] |
| | questions_mask = questions_mask[:, :eol_position + 1] |
| | position_ids = get_position_ids_from_attention_mask(questions_mask) |
| |
|
| | latent_out = self.latent_forward_with_cache( |
| | input_ids=questions, |
| | attention_mask=questions_mask, |
| | position_ids=position_ids, |
| | ) |
| | attention_mask = latent_out.attention_mask |
| | prompt_embeds = latent_out.inputs_embeds |
| | latent_inputs_embeds = latent_out.latent_inputs_embeds |
| | latent_attention_mask = attention_mask[:, bol_position + 1:eol_position] |
| |
|
| | |
| | position_ids = get_position_ids_from_attention_mask(attention_mask) |
| | outputs = self.base_causallm( |
| | inputs_embeds=prompt_embeds, |
| | attention_mask=attention_mask, |
| | position_ids=position_ids, |
| | use_cache=False, |
| | ) |
| | logits = outputs.logits[:, -1, :] |
| |
|
| | tokens = [[questions[i, -1].item()] for i in range(batch_size)] |
| | unfinished = torch.ones(batch_size, dtype=torch.bool, device=device) |
| | cur_inputs_embeds = prompt_embeds |
| | cur_attention_mask = attention_mask |
| |
|
| | for _ in range(max_new_tokens): |
| | if temperature != 1.0: |
| | logits = logits / temperature |
| | if self.is_training: |
| | probs = torch.softmax(logits, dim=-1) |
| | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1) |
| | else: |
| | next_tokens = torch.argmax(logits, dim=-1) |
| |
|
| | next_tokens = torch.where( |
| | unfinished, next_tokens, torch.full_like(next_tokens, self.eos_token_id) |
| | ) |
| | for i, tok in enumerate(next_tokens.tolist()): |
| | tokens[i].append(tok) |
| |
|
| | next_token_embeds = self.embedding(next_tokens).unsqueeze(1) |
| | cur_inputs_embeds = torch.cat([cur_inputs_embeds, next_token_embeds], dim=1) |
| | cur_attention_mask = torch.cat( |
| | [cur_attention_mask, torch.ones(batch_size, 1, dtype=cur_attention_mask.dtype, device=device)], |
| | dim=1, |
| | ) |
| |
|
| | position_ids = get_position_ids_from_attention_mask(cur_attention_mask) |
| | outputs = self.base_causallm( |
| | inputs_embeds=cur_inputs_embeds, |
| | attention_mask=cur_attention_mask, |
| | position_ids=position_ids, |
| | use_cache=False, |
| | ) |
| | self.gen_forward_cnt += 1 |
| | logits = outputs.logits[:, -1, :] |
| |
|
| | unfinished &= next_tokens.ne(self.eos_token_id) |
| | if not unfinished.any(): |
| | break |
| |
|
| | if synced_gpus: |
| | while self.gen_forward_cnt < max_new_tokens + MAX_N_LATENT: |
| | self.gen_forward_cnt += 1 |
| | _ = self.base_causallm( |
| | inputs_embeds=cur_inputs_embeds[:, cur_pos - 1:cur_pos, :], |
| | attention_mask=cur_attention_mask[:, :cur_pos], |
| | past_key_values=cur_kv_cache, |
| | use_cache=True, |
| | ) |
| |
|
| | max_len = max(len(seq) for seq in tokens) |
| | pred_ids = torch.full( |
| | (batch_size, max_len), |
| | fill_value=self.pad_token_id, |
| | dtype=torch.long, |
| | device=device, |
| | ) |
| | for i, seq in enumerate(tokens): |
| | pred_ids[i, : len(seq)] = torch.tensor(seq, device=device) |
| |
|
| | return GenerationOutputs( |
| | question_input_ids=questions[:, :bol_position + 1], |
| | question_attention_mask=questions_mask[:, :bol_position + 1], |
| | latent_inputs_embeds=latent_inputs_embeds, |
| | latent_attention_mask=latent_attention_mask, |
| | pred_ids=pred_ids, |
| | ) |
| |
|
    def rl_step(self, batch, optimizer, writer=None, global_step=None):
        """Run one GRPO update: rollout -> replay buffer -> minibatch updates.

        Args:
            batch: dict with "input_ids", "attention_mask", "labels" and
                "explainable_ids_list".
            optimizer: optimizer over this module's parameters.
            writer: accepted for API compatibility; unused here.
            global_step: accepted for API compatibility; unused here.

        Returns:
            dict of scalar training metrics (rewards, accuracies, per-loss
            averages, grad norm).
        """
        rl_config = self.rl_config
        device = self.device
        questions = batch["input_ids"]
        questions_mask = batch["attention_mask"]
        answers = batch["labels"]
        explainable_ids_list = batch["explainable_ids_list"]

        # Fully on-policy: drop any previous experience.
        self.replay_buffer.clear()

        experience = self.rollout(
            questions=questions,
            questions_mask=questions_mask,
            group_size=rl_config.group_size,
            gt_answers=answers,
            explainable_ids_list=explainable_ids_list,
        )
        # Buffer on CPU; minibatches are moved back to `device` below.
        self.replay_buffer.append(experience.to("cpu"))

        log_dict = {
            "train/rewards": experience.rewards.mean().item(),
            "train/accuracies": experience.accuracies.mean().item(),
            "train/n_latent_forward": experience.n_latent_forward.float().mean().item(),
        }
        if getattr(experience, "explain_logprob", None) is not None:
            log_dict["train/explain_logprob"] = experience.explain_logprob.mean().item()

        torch.cuda.empty_cache()

        dataloader = DataLoader(
            dataset=self.replay_buffer,
            batch_size=rl_config.exp_batch_size,
            shuffle=True,
            collate_fn=grpo.join_experience_batch,
        )

        gradient_accumulation_steps = max(1, rl_config.gradient_accumulation_steps)
        total_inner_steps = len(dataloader)

        loop_sums = defaultdict(float)
        inner_steps = 0
        accum_counter = 0
        optimizer.zero_grad()
        last_grad_norm = 0.0

        for exp in dataloader:
            inner_steps += 1
            accum_counter += 1
            exp: grpo.Experience = exp.to(device)

            latent_logprobs, answer_logprobs = self.get_logprobs(e=exp)
            loss_dict = self.grpo_loss(
                latent_logprobs=latent_logprobs,
                answer_logprobs=answer_logprobs,
                experience=exp,
            )

            raw_total_loss = loss_dict["total_loss"]
            # Scale so gradients average correctly across accumulation steps.
            (raw_total_loss / gradient_accumulation_steps).backward()

            for k, v in loss_dict.items():
                value = v.item() if torch.is_tensor(v) else float(v)
                loop_sums[f"train/{k}"] += value

            # Step on a full accumulation window or on the final minibatch.
            need_update = (accum_counter % gradient_accumulation_steps == 0) or (inner_steps == total_inner_steps)
            if need_update:
                grad_norm = clip_grad_norm_(self.parameters(), max_norm=1.0)
                optimizer.step()
                optimizer.zero_grad()
                last_grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else float(grad_norm)

        # Average the accumulated per-loss sums over the inner loop.
        if inner_steps > 0:
            for key in list(loop_sums.keys()):
                log_dict[key] = loop_sums[key] / inner_steps
            log_dict["train/grad_norm"] = last_grad_norm

        return log_dict
| |
|
| | def extract_equal_result(self, expression): |
| | match = re.search(r'=(.*?)>>', expression) |
| | if match: |
| | return match.group(1).strip() |
| | else: |
| | return '' |
| |
|
    @torch.no_grad()
    def compute_explain_logprob(
        self,
        latent_inputs_embeds: torch.Tensor,
        n_latent_forward: torch.Tensor,
        pred_strings,
        max_gen_len=16
    ):
        """Greedily decode a short explanation for each latent "thought" stage
        with the frozen explainable LM, and score the decodes.

        Args:
            latent_inputs_embeds: (B, max_latents, hidden) sampled latents.
            n_latent_forward: (B,) number of valid latents per sample.
            pred_strings: decoded final answers, one per sample.
            max_gen_len: decoding budget per explanation.

        Returns:
            (explain_logprob_b, explain_coherence_b) per-sample tensors, or
            None when no explainable model is configured.
            NOTE(review): the no-model case returns None while the no-stage
            case returns a tuple — callers must guard accordingly.
        """
        if self.expainable_llm is None:
            return None

        device = self.expainable_llm.device
        eos_token_id = self.eos_token_id
        thought = self.rl_config.c_thought
        max_stage = self.rl_config.max_latent_stage
        tok = self.tokenizer

        B, _, hidden = latent_inputs_embeds.size()

        # One decode job per complete thought stage (chunk of `c_thought`
        # latents); `owner` maps each job back to its batch index.
        prefix_list, owner = [], []
        for b in range(B):
            latent_valid = int(n_latent_forward[b].item())
            stage_num = min(max_stage, latent_valid // thought)
            for s in range(stage_num):
                start = s * thought
                end = start + thought
                prefix = latent_inputs_embeds[b, start:end, :].to(device)
                prefix_list.append(prefix)
                owner.append(b)

        N = len(prefix_list)
        if N == 0:
            # No complete stage anywhere: neutral scores.
            return torch.zeros(B, device=device), torch.ones(B, device=device)

        seq_embeds = torch.stack(prefix_list, dim=0).clone()
        attention_mask = torch.ones(N, thought, dtype=torch.long, device=device)
        position_ids = torch.arange(thought, dtype=torch.long, device=device).unsqueeze(0).expand(N, -1)

        total_logprob = torch.zeros(N, device=device)
        total_tokens = torch.zeros(N, device=device)

        gen_token_ids = [[] for _ in range(N)]
        # Indices (into the N jobs) still generating; finished rows are dropped
        # from the batch each step.
        active_idx = torch.arange(N, device=device)

        for _ in range(max_gen_len):
            outputs = self.expainable_llm(
                inputs_embeds=seq_embeds,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
            logits = outputs.logits[:, -1, :]
            log_probs = torch.log_softmax(logits, dim=-1)

            # Greedy decoding; accumulate the chosen token's log-prob.
            next_token = torch.argmax(logits, dim=-1)
            step_logprob = log_probs.gather(-1, next_token.unsqueeze(-1)).squeeze(-1)

            total_logprob[active_idx] += step_logprob
            total_tokens[active_idx] += 1

            for aidx, tok_id in zip(active_idx.tolist(), next_token.tolist()):
                gen_token_ids[aidx].append(tok_id)

            # Drop rows that just emitted EOS; shrink all running tensors.
            eos_mask = next_token.eq(eos_token_id)
            keep_mask = ~eos_mask
            if keep_mask.any():
                kept_idx = active_idx[keep_mask]
                next_embed = self.embedding(next_token[keep_mask]).unsqueeze(1)
                seq_embeds = torch.cat([seq_embeds[keep_mask], next_embed], dim=1)

                attention_mask = torch.cat(
                    [attention_mask[keep_mask],
                     torch.ones(len(kept_idx), 1, dtype=torch.long, device=device)],
                    dim=1
                )
                next_pos = position_ids[keep_mask][:, -1:] + 1
                position_ids = torch.cat([position_ids[keep_mask], next_pos], dim=1)

                active_idx = kept_idx
            else:
                break

        # Aggregate per-job scores back to per-sample scores.
        explain_logprob_b = torch.zeros(B, device=device)
        explain_coherence_b = torch.ones(B, device=device)
        explain_count_b = torch.zeros(B, dtype=torch.long, device=device)
        owner_t = torch.tensor(owner, dtype=torch.long, device=device)

        valid_mask = total_tokens > 0
        # NOTE(review): named "avg" but no division by total_tokens happens —
        # this is the per-sequence *sum* of log-probs. Confirm intended.
        avg_logprob = torch.zeros_like(total_logprob)
        avg_logprob[valid_mask] = total_logprob[valid_mask]

        for idx in range(N):
            b = owner_t[idx].item()
            explain_logprob_b[b] += avg_logprob[idx]
            explain_count_b[b] += 1

        # Mean over the sample's stages (0 when the sample had no stage).
        explain_logprob_b = torch.where(
            explain_count_b > 0,
            explain_logprob_b / explain_count_b.clamp_min(1),
            torch.zeros_like(explain_logprob_b)
        )

        # Decode texts and pull out the '=...>>' result of each stage.
        preview_by_batch = [[] for _ in range(B)]
        result_by_batch = [[] for _ in range(B)]
        for idx, tlist in enumerate(gen_token_ids):
            txt = tok.decode(tlist, skip_special_tokens=True) if tlist else ""
            b = owner_t[idx].item()
            result_of_eq = self.extract_equal_result(txt)
            preview_by_batch[b].append(txt)
            result_by_batch[b].append(result_of_eq)

        # Coherence: fraction of stages whose intermediate result re-appears in
        # a later stage's text (last stage: in the final predicted answer).
        for b in range(B):
            total_steps = len(result_by_batch[b])
            if total_steps == 0:
                explain_coherence_b[b] = 1.0
                continue

            matched = 0
            for i in range(len(result_by_batch[b]) - 1):
                for j in range(i + 1, len(preview_by_batch[b])):
                    if result_by_batch[b][i] in preview_by_batch[b][j]:
                        matched += 1
                        break
            last_result = result_by_batch[b][-1]
            if last_result in pred_strings[b]:
                matched += 1
            explain_coherence_b[b] = matched / total_steps

        return explain_logprob_b, explain_coherence_b
| |
|
    @torch.no_grad()
    def rollout(self, questions, questions_mask, gt_answers, group_size, explainable_ids_list=None):
        """Generate `group_size` rollouts per question and package them as a
        `grpo.Experience` with rewards, group advantages and behavior log-probs.

        Args:
            questions / questions_mask: (B, seq) prompts with latent placeholders.
            gt_answers: label tensors (one per question) used for reward checks.
            group_size: number of samples per question (GRPO group).
            explainable_ids_list: optional per-question ids enabling the
                explanation-based reward.

        Returns:
            grpo.Experience for the whole (B * group_size) rollout batch.
        """
        rl_cfg = self.rl_config
        group_size = group_size  # kept as-is; no-op rebinding

        # Duplicate each question `group_size` times to form the groups.
        group_questions = questions.repeat_interleave(group_size, dim=0)
        group_masks = questions_mask.repeat_interleave(group_size, dim=0)
        if explainable_ids_list is not None:
            group_explain_ids = [copy.deepcopy(ids) for ids in explainable_ids_list for _ in range(group_size)]
        else:
            group_explain_ids = None

        gen_out = self.generate_with_latent(
            questions=group_questions.to(self.device),
            questions_mask=group_masks.to(self.device),
            max_new_tokens=rl_cfg.max_new_tokens,
            temperature=rl_cfg.temperature,
        )

        pred_strings = self.tokenizer.batch_decode(gen_out.pred_ids, skip_special_tokens=True)
        # Number of latent positions actually attended per rollout.
        n_latent_forward = gen_out.latent_attention_mask.sum(dim=1)

        # Optional explanation scoring (only when a frozen explain LM exists
        # and the reward is enabled).
        explain_logprob, explain_coherence = None, None
        if self.expainable_llm is not None and group_explain_ids and rl_cfg.reason_reward_weight != 0:
            explain_logprob, explain_coherence = self.compute_explain_logprob(
                latent_inputs_embeds=gen_out.latent_inputs_embeds,
                n_latent_forward=n_latent_forward,
                pred_strings=pred_strings,
            )

        # Per-question group rewards, accuracies and normalized advantages.
        rewards_list, acc_list, adv_list = [], [], []
        for i in range(len(questions)):
            start = i * group_size
            end = (i + 1) * group_size
            rewards, accuracies = self.get_group_rewards_and_acc(
                pred_answers=pred_strings[start:end],
                gt_answer=gt_answers[i],
                explain_coherence=None if explain_coherence is None else explain_coherence[start:end],
                n_latent_forward=n_latent_forward[start:end],
                explain_logprob_chunk=None if explain_logprob is None else explain_logprob[start:end],
            )
            rewards_list.append(rewards)
            acc_list.append(accuracies)
            adv_list.append(grpo.group_advantages(rewards))

        rewards = torch.cat(rewards_list, dim=0)
        accuracies = torch.cat(acc_list, dim=0)
        advantages = torch.cat(adv_list, dim=0)

        experience = grpo.Experience(
            question_input_ids=gen_out.question_input_ids,
            question_attention_mask=gen_out.question_attention_mask,
            latent_inputs_embeds=gen_out.latent_inputs_embeds,
            latent_attention_mask=gen_out.latent_attention_mask,
            answer_input_ids=gen_out.pred_ids,
            answer_attention_mask=gen_out.pred_ids.ne(self.pad_token_id).long(),
            n_latent_forward=n_latent_forward.unsqueeze(1),
            rewards=rewards,
            accuracies=accuracies,
            advantages=advantages,
        )

        if explain_logprob is not None:
            experience.explain_logprob = explain_logprob

        # Behavior-policy log-probs, recorded once at rollout time.
        latent_logprobs, answer_logprobs = self.get_logprobs(experience)

        experience.latent_logprobs = latent_logprobs
        experience.answer_logprobs = answer_logprobs
        return experience
| |
|
    def get_logprobs(self, e: grpo.Experience):
        """Recompute latent and answer log-probabilities for an experience
        batch under the current policy (one teacher-forced forward pass).

        Returns:
            (latent_logprobs, answer_logprobs): per-position log-probs for the
            latent span and the answer span respectively.
        """
        question_length = e.question_input_ids.shape[1]
        latent_length = e.latent_inputs_embeds.shape[1]
        answer_length = e.answer_input_ids.shape[1]  # unused; kept for symmetry

        question_inputs_embeds = self.embedding(e.question_input_ids)
        answer_inputs_embeds = self.embedding(e.answer_input_ids)

        # One concatenated sequence: question | latents | answer.
        all_inputs_embeds = torch.cat(
            [question_inputs_embeds, e.latent_inputs_embeds, answer_inputs_embeds], dim=1
        )
        all_attention_mask = torch.cat(
            [e.question_attention_mask, e.latent_attention_mask, e.answer_attention_mask], dim=1
        )
        all_position_ids = get_position_ids_from_attention_mask(all_attention_mask)

        all_outputs = self.base_causallm.forward(
            inputs_embeds=all_inputs_embeds,
            attention_mask=all_attention_mask,
            position_ids=all_position_ids,
            output_hidden_states=True,
        )

        # Hidden states at the positions that *predict* each latent slot
        # (shifted left by one).
        hidden = all_outputs.hidden_states[-1]
        vec = hidden[:, question_length - 1: question_length + latent_length - 1, :]

        # Gaussian latent policy under the current parameters.
        mu, logvar = self.latent_head.forward(vec)
        std = torch.exp(0.5 * logvar)
        dist = torch.distributions.Normal(mu, std)
        latent_samples = e.latent_inputs_embeds[:, :latent_length, :]
        # Stop/continue logits per latent slot.
        pred_token_id = self.end_head(latent_samples)

        # Forbid stopping mid-thought: only multiples of c_thought may stop
        # (class 1 = stop is masked to -inf elsewhere).
        for pass_idx in range(e.latent_attention_mask.size(1)):
            if pass_idx % self.rl_config.c_thought:
                pred_token_id[:, pass_idx, 1] = float('-inf')
        pred_logprob = torch.log_softmax(pred_token_id, dim=-1)
        # Index of the stop decision = number of latents actually used.
        eos_place = e.latent_attention_mask.sum(-1)[:, None]

        # Stop log-probs with a zero column appended, presumably so rollouts
        # that used every latent slot contribute 0 — TODO confirm.
        log_prob_fill = torch.cat((pred_logprob[:, :, 1], torch.zeros_like(pred_logprob[:, 0:1, 1])), dim=1)
        # Latent log-prob = Gaussian density of the sample + "continue" log-prob.
        latent_logprobs = dist.log_prob(latent_samples).sum(dim=-1) + pred_logprob[:, :, 0]

        # Token log-probs over the answer span (positions shifted left by one).
        logits = all_outputs.logits
        answer_logits = logits[:, question_length + latent_length - 1: -1, :]
        answer_logprobs = F.log_softmax(answer_logits, dim=-1)
        answer_logprobs = answer_logprobs.gather(dim=-1, index=e.answer_input_ids.unsqueeze(-1)).squeeze(-1)
        # NOTE(review): the first answer position is overwritten with the stop
        # decision's log-prob (selected via eos_place) — confirm intended.
        answer_logprobs[:, 0] = log_prob_fill.gather(dim=1, index=eos_place).squeeze(-1)

        return latent_logprobs, answer_logprobs
| |
|
| | def get_group_rewards_and_acc( |
| | self, |
| | pred_answers, |
| | gt_answer, |
| | explain_coherence, |
| | n_latent_forward, |
| | explain_logprob_chunk=None, |
| | ): |
| | rl_config = self.rl_config |
| | group_size = len(pred_answers) |
| | format = torch.zeros(group_size, 1, device=self.device) |
| | accuracies = torch.zeros(group_size, 1, device=self.device) |
| | for i, ans in enumerate(pred_answers): |
| | gt_answer[gt_answer == -100] = self.pad_token_id |
| | gt_string = self.tokenizer.decode(gt_answer, skip_special_tokens=True) |
| | marker = "<|end-latent|>###" |
| | gt_answer_string = gt_string[gt_string.rfind("###") + len("###"):] |
| | idx = ans.rfind(marker) |
| | if idx == -1: |
| | pred_string = "" |
| | else: |
| | pred_string = ans[idx + len(marker):] |
| | if len(pred_string.strip()) != 0: |
| | format[i] = 1 |
| | accuracies[i] = self.verify_answer(gt_answer=gt_answer_string, pred_answer=pred_string) |
| | print( |
| | f"Ground_truth: {gt_answer_string} Answer {ans} Accuracy: {accuracies[i]} Format: {format[i]}") |
| | if explain_logprob_chunk is not None: |
| | print(f"Explain_prob: {explain_logprob_chunk[i]}") |
| | if explain_coherence is not None: |
| | print(f"Explain_coherence: {explain_coherence[i]}") |
| |
|
| | rewards = rl_config.accuracy_reward_weight * accuracies.detach().clone() + rl_config.format_reward_weight * format.detach().clone() |
| |
|
| | |
| | |
| | if explain_coherence is not None: |
| | if rl_config.accuracy_reward_weight ==0: |
| | rewards = rewards + rl_config.reason_reward_weight * explain_coherence[:, None] |
| | else: |
| | rewards = rewards + rl_config.reason_reward_weight * explain_coherence[:, None] * (accuracies.detach().clone()) |
| |
|
| | return rewards, accuracies |
| |
|
| | def verify_answer(self, gt_answer: str, pred_answer: str) -> float: |
| | |
| | |
| | try: |
| | gt_answer = float(gt_answer) |
| | pred_answer = float(pred_answer) |
| | return float(gt_answer == pred_answer) |
| | except ValueError: |
| | return float(self.is_equiv(gt_answer, pred_answer)) |
| |
|
| | @torch.no_grad() |
| | def eval_generation(self, dataloader, val='val'): |
| | loop_sums = defaultdict(float) |
| | inner_steps = 0 |
| | self.is_training = False |
| | for step, batch in enumerate(dataloader): |
| | inner_steps += 1 |
| | questions = batch["input_ids"] |
| | questions_mask = batch["attention_mask"] |
| | answers = batch["labels"] |
| | explainable_ids_list = batch["explainable_ids_list"] |
| |
|
| | experience = self.rollout( |
| | questions=questions, |
| | questions_mask=questions_mask, |
| | gt_answers=answers, |
| | group_size=1, |
| | explainable_ids_list=explainable_ids_list, |
| | ) |
| |
|
| | loop_sums[f"{val}/rewards"] += experience.rewards.mean().item() |
| | loop_sums[f"{val}/accuracies"] += experience.accuracies.mean().item() |
| | loop_sums[f"{val}/n_latent_forward"] += experience.n_latent_forward.float().mean().item() |
| | if getattr(experience, "explain_logprob", None) is not None: |
| | loop_sums[f"{val}/explain_logprob"] += experience.explain_logprob.mean().item() |
| |
|
| | torch.cuda.empty_cache() |
| | print(f"Accuracy in Validation {loop_sums[f'{val}/accuracies'] * 100 / inner_steps}") |
| | log_dict = {} |
| | if inner_steps > 0: |
| | for key in loop_sums: |
| | log_dict[key] = loop_sums[key] / inner_steps |
| |
|
| | self.is_training = True |
| | return log_dict |
| |
|
| | |
| |
|
| | def is_equiv(self, str1, str2, verbose=False): |
| | if str1 is None and str2 is None: |
| | print("WARNING: Both None") |
| | return True |
| | if str1 is None or str2 is None: |
| | return False |
| |
|
| | try: |
| | ss1 = self.strip_string(str1) |
| | ss2 = self.strip_string(str2) |
| | if verbose: |
| | print(ss1, ss2) |
| | return ss1 == ss2 |
| | except Exception: |
| | return str1 == str2 |
| |
|
| | def remove_boxed(self, s): |
| | if "\\boxed " in s: |
| | left = "\\boxed " |
| | assert s[:len(left)] == left |
| | return s[len(left):] |
| |
|
| | left = "\\boxed{" |
| |
|
| | assert s[:len(left)] == left |
| | assert s[-1] == "}" |
| |
|
| | return s[len(left):-1] |
| |
|
| | def last_boxed_only_string(self, string): |
| | idx = string.rfind("\\boxed") |
| | if "\\boxed " in string: |
| | return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] |
| | if idx < 0: |
| | idx = string.rfind("\\fbox") |
| | if idx < 0: |
| | return None |
| |
|
| | i = idx |
| | right_brace_idx = None |
| | num_left_braces_open = 0 |
| | while i < len(string): |
| | if string[i] == "{": |
| | num_left_braces_open += 1 |
| | if string[i] == "}": |
| | num_left_braces_open -= 1 |
| | if num_left_braces_open == 0: |
| | right_brace_idx = i |
| | break |
| | i += 1 |
| |
|
| | if right_brace_idx is None: |
| | retval = None |
| | else: |
| | retval = string[idx:right_brace_idx + 1] |
| |
|
| | return retval |
| |
|
| | def fix_fracs(self, string): |
| | substrs = string.split("\\frac") |
| | new_str = substrs[0] |
| | if len(substrs) > 1: |
| | substrs = substrs[1:] |
| | for substr in substrs: |
| | new_str += "\\frac" |
| | if substr[0] == "{": |
| | new_str += substr |
| | else: |
| | try: |
| | assert len(substr) >= 2 |
| | except AssertionError: |
| | return string |
| | a = substr[0] |
| | b = substr[1] |
| | if b != "{": |
| | if len(substr) > 2: |
| | post_substr = substr[2:] |
| | new_str += "{" + a + "}{" + b + "}" + post_substr |
| | else: |
| | new_str += "{" + a + "}{" + b + "}" |
| | else: |
| | if len(substr) > 2: |
| | post_substr = substr[2:] |
| | new_str += "{" + a + "}" + b + post_substr |
| | else: |
| | new_str += "{" + a + "}" + b |
| | string = new_str |
| | return string |
| |
|
| | def fix_a_slash_b(self, string): |
| | if len(string.split("/")) != 2: |
| | return string |
| | a = string.split("/")[0] |
| | b = string.split("/")[1] |
| | try: |
| | a = int(a) |
| | b = int(b) |
| | assert string == "{}/{}".format(a, b) |
| | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" |
| | return new_string |
| | except AssertionError: |
| | return string |
| |
|
| | def remove_right_units(self, string): |
| | |
| | if "\\text{ " in string: |
| | splits = string.split("\\text{ ") |
| | assert len(splits) == 2 |
| | return splits[0] |
| | else: |
| | return string |
| |
|
| | def fix_sqrt(self, string): |
| | if "\\sqrt" not in string: |
| | return string |
| | splits = string.split("\\sqrt") |
| | new_string = splits[0] |
| | for split in splits[1:]: |
| | if split[0] != "{": |
| | a = split[0] |
| | new_substr = "\\sqrt{" + a + "}" + split[1:] |
| | else: |
| | new_substr = "\\sqrt" + split |
| | new_string += new_substr |
| | return new_string |
| |
|
    def strip_string(self, string):
        """Normalize a LaTeX answer string for equality comparison.

        Applies a fixed sequence of textual cleanups (whitespace, common LaTeX
        spacing/format commands, units, fractions, roots). The order of the
        replacements is significant — do not reorder.
        """
        # linebreaks
        string = string.replace("\n", "")

        # remove inverse spaces ("\!")
        string = string.replace("\\!", "")

        # collapse double backslashes
        string = string.replace("\\\\", "\\")

        # tfrac/dfrac -> frac
        string = string.replace("tfrac", "frac")
        string = string.replace("dfrac", "frac")

        # remove \left and \right sizing commands
        string = string.replace("\\left", "")
        string = string.replace("\\right", "")

        # remove degree markers
        string = string.replace("^{\\circ}", "")
        string = string.replace("^\\circ", "")

        # remove escaped dollar signs
        string = string.replace("\\$", "")

        # strip trailing "\text{ ...}" units
        string = self.remove_right_units(string)

        # remove percent signs (escaped and stray)
        string = string.replace("\\%", "")
        string = string.replace("\%", "")

        # add a leading 0 to bare decimals: " .5" -> " 0.5", "{.5" -> "{0.5"
        string = string.replace(" .", " 0.")
        string = string.replace("{.", "{0.")

        if len(string) == 0:
            return string
        if string[0] == ".":
            string = "0" + string

        # drop a short left-hand side such as "x=" and keep only the rhs
        if len(string.split("=")) == 2:
            if len(string.split("=")[0]) <= 2:
                string = string.split("=")[1]

        # brace bare sqrt arguments: sqrt3 -> sqrt{3}
        string = self.fix_sqrt(string)

        # remove all remaining spaces
        string = string.replace(" ", "")

        # brace shorthand frac arguments: \fracab -> \frac{a}{b}
        string = self.fix_fracs(string)

        # canonicalize the common special case 0.5
        if string == "0.5":
            string = "\\frac{1}{2}"

        # plain integer fractions a/b -> \frac{a}{b}
        string = self.fix_a_slash_b(string)

        return string
| |
|