"""
File: mllm/models/scalar_critic.py
Summary: Defines a scalar critic network and helper utilities.
"""
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from peft import LoraConfig, get_peft_model |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| from mllm.models.adapter_training_wrapper import AdapterWrapper |
|
|
|
|
class ScalarCritic(nn.Module):
    """
    Causal-LM critic backbone plus a scalar value head.

    The value head is a single linear layer applied to the backbone's
    last-layer hidden states:

        V_φ(s) = wᵀ h_last + b

    Per the original author's note, only the LoRA adapters (inside
    ``critic_adapter``) and the value head are meant to be trainable.
    This class does not freeze or unfreeze anything itself.
    """

    def __init__(self, critic_adapter: "AdapterWrapper"):
        """
        Store the backbone wrapper and attach a new linear value head.

        Args:
            critic_adapter: Backbone wrapper. This class reads
                ``shared_llm.config.hidden_size``, ``dtype`` and ``device``
                from it, and calls it with ``output_hidden_states=True``.
        """
        super().__init__()
        self.critic_adapter = critic_adapter
        hidden_size = self.critic_adapter.shared_llm.config.hidden_size
        # The head is created first and then moved, so it ends up with the
        # same dtype and device the backbone wrapper reports.
        self.value_head = nn.Linear(hidden_size, 1).to(
            dtype=critic_adapter.dtype, device=critic_adapter.device
        )

    def forward(self, input_ids, attention_mask=None, **kwargs) -> torch.Tensor:
        """
        Run the backbone and project its last hidden states to scalars.

        Args:
            input_ids: Forwarded unchanged to ``critic_adapter``.
            attention_mask: Forwarded unchanged to ``critic_adapter``
                (``None`` is passed through as-is).
            **kwargs: Extra keyword arguments forwarded to ``critic_adapter``.

        Returns:
            The value-head output with its trailing size-1 dimension
            squeezed away, i.e. one scalar per hidden-state vector.
            Presumably ``(batch, seq_len)`` for a standard causal LM;
            TODO confirm against ``AdapterWrapper``.
        """
        outputs = self.critic_adapter(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,  # needed to read hidden_states below
            **kwargs,
        )
        # hidden_states[-1] is the final entry of the returned sequence.
        h_last = outputs.hidden_states[-1]
        return self.value_head(h_last).squeeze(-1)

    def parameters(self, recurse: bool = True):
        """
        Iterate over this critic's parameters.

        Overrides ``nn.Module.parameters``: yields whatever
        ``critic_adapter.parameters()`` yields, followed by the value head's
        parameters. Whether the backbone part is restricted to trainable
        parameters depends on ``critic_adapter``; no filtering happens here.

        Args:
            recurse: Accepted for signature compatibility with
                ``nn.Module.parameters``; it is not consulted.
        """
        yield from self.critic_adapter.parameters()
        yield from self.value_head.parameters()

    def gradient_checkpointing_enable(self, *args, **kwargs):
        """Forward the call, with all arguments, to ``critic_adapter``."""
        self.critic_adapter.gradient_checkpointing_enable(*args, **kwargs)

    @property
    def dtype(self):
        """The dtype reported by ``critic_adapter``."""
        return self.critic_adapter.dtype

    @property
    def device(self):
        """The device reported by ``critic_adapter``."""
        return self.critic_adapter.device
|
|