from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import LlamaModel, LlamaPreTrainedModel
from transformers.utils import ModelOutput

@dataclass
class MultiAspectRewardOutput(ModelOutput):
    """
    Output class returning the five per-aspect predictions plus the aggregated
    final reward.

    Args:
        aspect_scores (torch.FloatTensor): shape (batch, 5)
        final_reward (torch.FloatTensor): shape (batch,)
        logits (torch.FloatTensor): shape (batch,), identical to final_reward,
            kept so downstream code that reads a `.logits` field still works
        loss (torch.FloatTensor): optional scalar MSE loss
        hidden_states (tuple(torch.FloatTensor)): optional backbone hidden states
        attentions (tuple(torch.FloatTensor)): optional backbone attentions
    """
    aspect_scores: Optional[torch.FloatTensor] = None
    final_reward: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    loss: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

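# Note on the aggregation performed by LlamaFixedWeightReward below: the final
# reward is a fixed weighted sum of the sigmoid-squashed aspect scores,
#
#     reward_i = sum_k  w_k * sigmoid(aspect_head(h_i))_k ,
#
# so with the default uniform weights (w_k = 1/5, summing to 1) it is simply
# the mean of the five aspect scores. The weights live in a buffer, not an
# nn.Parameter, so they are never updated during training.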

class LlamaFixedWeightReward(LlamaPreTrainedModel):
    """
    A reward model that:
      1) optionally wraps an already loaded Llama backbone (base_llama),
      2) predicts 5 aspect scores, computing an MSE loss when 5-dim labels
         are provided,
      3) aggregates the 5 aspect scores via fixed weights into one scalar
         reward,
      4) returns a MultiAspectRewardOutput whose 'final_reward' and 'logits'
         both have shape (batch,).
    """
    def __init__(self, config, base_llama=None, rule_weights=None):
        """
        Args:
            config: LlamaConfig with num_labels=5 for multi-aspect predictions.
            base_llama: (optional) an already loaded LlamaModel; if None, a
                freshly initialized LlamaModel is built from config.
            rule_weights: (optional) a list or torch.Tensor of shape (5,) used
                for aggregation. If None, defaults to uniform weights
                [0.2, 0.2, 0.2, 0.2, 0.2].
        """
        super().__init__(config)

        # Reuse a pretrained backbone when one is supplied; otherwise build a
        # randomly initialized one from the config.
        if base_llama is not None:
            self.llama = base_llama
        else:
            self.llama = LlamaModel(config)

        # Single linear head mapping the pooled hidden state to the aspect logits.
        self.aspect_head = nn.Linear(config.hidden_size, config.num_labels)

        # Fixed aggregation weights, registered as a (1, num_labels) buffer so
        # they follow the model across devices and are saved in the state dict,
        # but are never touched by the optimizer.
        if rule_weights is not None:
            w = torch.as_tensor(rule_weights, dtype=torch.float)
        else:
            w = torch.full((config.num_labels,), 1.0 / config.num_labels)
        self.register_buffer("rule_weights", w.view(1, -1), persistent=True)

        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        **kwargs,
    ):
        outputs = self.llama(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        last_hidden = outputs.last_hidden_state

        # Pool on the last *non-padding* token of each sequence. Indexing
        # [:, -1, :] unconditionally would read a pad embedding for
        # right-padded batches, so use the attention mask to locate the true
        # last token whenever the mask is available.
        if attention_mask is not None:
            last_token_idx = attention_mask.long().sum(dim=-1) - 1
            batch_idx = torch.arange(last_hidden.size(0), device=last_hidden.device)
            pooled = last_hidden[batch_idx, last_token_idx]
        else:
            pooled = last_hidden[:, -1, :]

        # Squash each aspect score into (0, 1).
        aspect_scores = torch.sigmoid(self.aspect_head(pooled))

        # MSE against the 5-dim labels (expected in [0, 1]) when provided.
        loss = None
        if labels is not None:
            loss = nn.functional.mse_loss(aspect_scores, labels.float())

        # Fixed-weight aggregation: one scalar reward per example.
        reward = (aspect_scores * self.rule_weights).sum(dim=-1)

        return MultiAspectRewardOutput(
            loss=loss,
            aspect_scores=aspect_scores,
            final_reward=reward,
            logits=reward,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
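

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the model definition). The
# checkpoint name below is an assumption -- substitute any Llama checkpoint
# you actually have access to.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer, LlamaConfig

    checkpoint = "meta-llama/Llama-2-7b-hf"  # assumed checkpoint name

    config = LlamaConfig.from_pretrained(checkpoint)
    config.num_labels = 5  # five aspect scores, as the model expects

    backbone = LlamaModel.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Example fixed weighting: emphasize the first aspect over the other four.
    model = LlamaFixedWeightReward(
        config,
        base_llama=backbone,
        rule_weights=[0.4, 0.15, 0.15, 0.15, 0.15],
    )
    model.eval()

    batch = tokenizer(
        ["The assistant's answer was accurate and well sourced."],
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        out = model(**batch)
    print("aspect scores:", out.aspect_scores)  # shape (1, 5), each in (0, 1)
    print("final reward:", out.final_reward)    # shape (1,)

    # With 5-dim labels in [0, 1], the forward pass also returns an MSE loss.
    labels = torch.tensor([[0.9, 0.7, 0.8, 0.6, 1.0]])
    out = model(**batch, labels=labels)
    print("MSE loss:", out.loss)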