File size: 1,948 Bytes
9ba32f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
File: mllm/models/scalar_critic.py
Summary: Defines a scalar critic network and helper utilities.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

from mllm.models.adapter_training_wrapper import AdapterWrapper


class ScalarCritic(nn.Module):
    """
    Value critic built from an adapter-backed causal LM plus a linear head.

    Each final-layer hidden vector is mapped to one scalar:
        V_φ(s) = wᵀ h_last + b

    Per the original author's note, only the LoRA adapter weights (inside
    ``critic_adapter``) and the value head are meant to be trainable. Which
    tensors the adapter actually exposes is decided by
    ``AdapterWrapper.parameters()``, which is not visible in this file.
    """

    def __init__(self, critic_adapter: AdapterWrapper):
        """
        Store the adapter and build the scalar head.

        The head's input width is read from
        ``critic_adapter.shared_llm.config.hidden_size``; the head is then
        moved to the adapter's reported ``dtype`` and ``device``.
        """
        super().__init__()
        self.critic_adapter = critic_adapter
        llm_config = self.critic_adapter.shared_llm.config
        head = nn.Linear(llm_config.hidden_size, 1)
        # Place the head alongside the adapter so both share dtype/device.
        self.value_head = head.to(
            dtype=critic_adapter.dtype,
            device=critic_adapter.device,
        )

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """
        Return one scalar value per position of the adapter's last hidden state.

        ``output_hidden_states=True`` is always requested from the adapter;
        any extra keyword arguments are forwarded to it unchanged. The value
        head is applied to ``hidden_states[-1]`` and the trailing size-1
        dimension is squeezed away.
        """
        # Original note: AdapterWrapper activates its own adapter internally
        # (not verifiable from this file).
        adapter_out = self.critic_adapter(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs,
        )
        final_hidden = adapter_out.hidden_states[-1]  # expected (B, S, H)
        per_token_value = self.value_head(final_hidden)  # (..., 1)
        return per_token_value.squeeze(-1)  # expected (B, S)

    def parameters(self, recurse: bool = True):
        """
        Yield the adapter's parameters, then the value head's parameters.

        ``recurse`` is accepted to match ``nn.Module.parameters`` but is not
        consulted here.
        """
        # Adapter first (intended to be the LoRA weights of this adapter) ...
        yield from self.critic_adapter.parameters()
        # ... followed by the scalar head's weight and bias.
        for head_param in self.value_head.parameters():
            yield head_param

    def gradient_checkpointing_enable(self, *args, **kwargs):
        """Forward the call, with all arguments, to the wrapped adapter."""
        self.critic_adapter.gradient_checkpointing_enable(*args, **kwargs)

    @property
    def dtype(self):
        """dtype reported by the wrapped adapter."""
        adapter = self.critic_adapter
        return adapter.dtype

    @property
    def device(self):
        """Device reported by the wrapped adapter."""
        adapter = self.critic_adapter
        return adapter.device