Muqeeth's picture
Add files using upload-large-folder tool
9ba32f5 verified
"""
File: mllm/models/scalar_critic.py
Summary: Defines a scalar critic network and helper utilities.
"""
import torch
import torch.nn as nn
import torch.optim as optim
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from mllm.models.adapter_training_wrapper import AdapterWrapper
class ScalarCritic(nn.Module):
    """
    A causal-LM critic_adapter + a scalar value head:
        V_φ(s) = wᵀ h_last + b
    The head is applied at every position of the adapter's last hidden layer,
    so ``forward`` returns one value per token rather than one per sequence.

    Args:
        critic_adapter: the AdapterWrapper whose hidden states feed the head.
            Its ``shared_llm.config.hidden_size``, ``dtype`` and ``device``
            are read to size and place the value head.
    """

    def __init__(self, critic_adapter: "AdapterWrapper"):
        super().__init__()
        # Assigning an nn.Module attribute registers it as a submodule.
        self.critic_adapter = critic_adapter
        hidden_size = self.critic_adapter.shared_llm.config.hidden_size
        # Single linear unit (wᵀ h + b), created directly on the adapter's
        # dtype/device so no transfer is needed in forward.
        self.value_head = nn.Linear(hidden_size, 1).to(
            dtype=critic_adapter.dtype, device=critic_adapter.device
        )

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Return per-position scalar values for ``input_ids``.

        Args:
            input_ids: token ids forwarded to the adapter.
            attention_mask: optional mask forwarded to the adapter.
            **kwargs: extra keyword arguments forwarded to the adapter. A
                caller-supplied ``output_hidden_states`` is ignored, because
                the value head always needs the hidden states.

        Returns:
            The adapter's last hidden layer projected through ``value_head``,
            with the trailing size-1 dimension squeezed (expected (B, S)).
        """
        # Hidden states are mandatory here. Drop any caller-supplied flag so it
        # neither disables them nor collides with the explicit keyword below
        # (which would raise "got multiple values for keyword argument").
        kwargs.pop("output_hidden_states", None)
        # AdapterWrapper activates its own adapter internally
        outputs = self.critic_adapter(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs,
        )
        h_last = outputs.hidden_states[-1]  # expected (B, S, H)
        # nn.Linear raises on an input/weight dtype mismatch; align the input
        # with the head. This is a no-op when the dtypes already agree.
        h_last = h_last.to(self.value_head.weight.dtype)
        values = self.value_head(h_last).squeeze(-1)  # (B, S)
        return values

    def parameters(self, recurse: bool = True):
        """Yield the parameters that belong to this critic.

        Yields, in order, everything ``critic_adapter.parameters()`` returns,
        then the value head's weight and bias. The intent is that the adapter
        yields only the LoRA weights of *this* adapter; that depends on
        AdapterWrapper and is not checked here (TODO confirm). No
        ``requires_grad`` filtering is applied.

        ``recurse`` is accepted for ``nn.Module`` signature compatibility but
        is ignored: the full set is always yielded.
        """
        # 1) parameters exposed by the adapter (expected: its LoRA params)
        yield from self.critic_adapter.parameters()
        # 2) scalar head
        yield from self.value_head.parameters()

    def gradient_checkpointing_enable(self, *args, **kwargs):
        """Forward to the adapter's ``gradient_checkpointing_enable``."""
        self.critic_adapter.gradient_checkpointing_enable(*args, **kwargs)

    @property
    def dtype(self):
        """dtype reported by the wrapped adapter."""
        return self.critic_adapter.dtype

    @property
    def device(self):
        """Device reported by the wrapped adapter."""
        return self.critic_adapter.device