"""
Actor and Critic models for offline RL with QLoRA.

This file contains the Actor and Critic model implementations using QLoRA
for efficient finetuning of LLMs.
"""

import platform

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, PeftModel, TaskType


def get_target_modules_for_model(model_id, model):
    """
    Get the appropriate target modules for LoRA based on the model architecture.

    Args:
        model_id: The model identifier string
        model: The loaded model

    Returns:
        List of target module names
    """
    # Known architectures, matched by model ID first.
    if "llama" in model_id.lower() or "mistral" in model_id.lower():
        return ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    elif "gpt-j" in model_id.lower():
        return ["q_proj", "k_proj", "v_proj", "out_proj", "fc_in", "fc_out"]
    elif "gpt-neox" in model_id.lower():
        return ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]
    elif "gpt2" in model_id.lower():
        return ["c_attn", "c_proj", "c_fc"]
    elif hasattr(model, "config") and hasattr(model.config, "architectures"):
        # Fall back to the architecture name recorded in the model config.
        arch = model.config.architectures[0] if model.config.architectures else ""
        if "GPT2" in arch:
            return ["c_attn", "c_proj", "c_fc"]
        elif "Llama" in arch or "Mistral" in arch:
            return ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
        elif "GPTNeoX" in arch:
            return ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]

    # Unknown architecture: scan module names for attention/MLP submodules.
    module_names = []
    for name, _ in model.named_modules():
        if any(substr in name for substr in ["attn", "mlp", "attention"]):
            parts = name.split(".")
            if len(parts) > 1:
                module_names.append(parts[-1])

    if module_names:
        # Keep only names that look like attention or feed-forward layers.
        attn_patterns = ["attn", "attention", "self", "q", "k", "v", "query", "key", "value"]
        mlp_patterns = ["mlp", "feed_forward", "fc", "dense", "linear", "ffn"]

        attn_modules = [name for name in module_names if any(p in name.lower() for p in attn_patterns)]
        mlp_modules = [name for name in module_names if any(p in name.lower() for p in mlp_patterns)]

        if attn_modules or mlp_modules:
            return list(set(attn_modules + mlp_modules))

    # Last resort: generic module names shared by several architectures.
    print(f"Warning: Could not determine target modules for {model_id}. Using default modules.")
    return ["query", "key", "value", "dense"]


class LLMActorLora:
    """Actor model with QLoRA for LLMs."""

    def __init__(self, device, model_id="meta-llama/Meta-Llama-3-8B-Instruct", lora_r=8, disable_quantization=False):
        """
        Initialize the actor model with QLoRA.

        Args:
            device: Device to run the model on
            model_id: HuggingFace model ID
            lora_r: LoRA rank parameter
            disable_quantization: If True, disable 4-bit quantization (useful for Mac/CPU)
        """
        self.device = device
        self.model_id = model_id

        # bitsandbytes 4-bit quantization is unavailable on Mac/CPU setups.
        is_mac = platform.system() == 'Darwin'
        running_on_cpu = 'cpu' in str(device).lower()

        if disable_quantization or (is_mac and running_on_cpu):
            if disable_quantization:
                print("4-bit quantization disabled by user request")
            else:
                print("4-bit quantization automatically disabled for Mac/CPU")

            # Load in full or half precision without quantization.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32 if running_on_cpu else torch.bfloat16,
            ).to(device)
        else:
            # NF4 double quantization with bfloat16 compute, as in the QLoRA paper.
            self.bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )

            print(f"Loading {model_id} with QLoRA (4-bit quantization + LoRA)")

            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                quantization_config=self.bnb_config,
                device_map="auto"
            )

        target_modules = get_target_modules_for_model(model_id, self.model)
        print(f"Using target modules for LoRA: {target_modules}")

        # LoRA config; alpha = 2 * r is a common scaling heuristic.
        self.lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=2 * lora_r,
            target_modules=target_modules,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
            lora_dropout=0.05,
        )

        self.model = get_peft_model(self.model, self.lora_config)
        self.model.print_trainable_parameters()

    def parameters(self):
        """Return the model parameters."""
        return self.model.parameters()

    def forward(self, input_ids, attention_mask=None):
        """
        Forward pass through the model.

        Args:
            input_ids: Tokenized input IDs
            attention_mask: Attention mask

        Returns:
            Model outputs
        """
        return self.model(input_ids, attention_mask=attention_mask)

    def get_log_probs(self, input_ids, action_ids, attention_mask=None):
        """
        Calculate log probabilities for given actions.

        Args:
            input_ids: Tokenized input IDs [batch_size, seq_len]
            action_ids: Tokenized action IDs [batch_size, act_len]
            attention_mask: Attention mask [batch_size, seq_len]

        Returns:
            log_probs: Log probabilities of actions [batch_size]
            entropy: Entropy of the policy [batch_size]
        """
        outputs = self.model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        # Next-token distribution conditioned on the full prompt.
        last_token_logits = logits[:, -1, :]
        log_probs = F.log_softmax(last_token_logits, dim=-1)

        # Score each action by its first token only; the remaining action
        # tokens are not rolled out here.
        first_action_tokens = action_ids[:, 0]
        selected_log_probs = log_probs.gather(1, first_action_tokens.unsqueeze(-1)).squeeze(-1)

        # Entropy of the next-token distribution.
        probs = torch.exp(log_probs)
        entropy = -(probs * log_probs).sum(dim=-1)

        return selected_log_probs, entropy

    def generate(self, input_ids, attention_mask=None, **kwargs):
        """
        Generate text from the model.

        Args:
            input_ids: Tokenized input IDs
            attention_mask: Attention mask
            kwargs: Additional generation arguments

        Returns:
            Generated token IDs
        """
        return self.model.generate(input_ids, attention_mask=attention_mask, **kwargs)

    def save_pretrained(self, path):
        """
        Save the LoRA adapter weights to the given path.

        Args:
            path: Path to save the model to
        """
        self.model.save_pretrained(path)

    def load_pretrained(self, path):
        """
        Load LoRA adapter weights from the given path.

        Args:
            path: Path to load the model from
        """
        self.model = PeftModel.from_pretrained(self.model, path)


class LLMCriticLora:
    """Critic (value function) model with QLoRA for LLMs."""

    def __init__(self, device, model_id="meta-llama/Meta-Llama-3-8B-Instruct", lora_r=8, disable_quantization=False):
        """
        Initialize the critic model with QLoRA.

        Args:
            device: Device to run the model on
            model_id: HuggingFace model ID
            lora_r: LoRA rank parameter
            disable_quantization: If True, disable 4-bit quantization (useful for Mac/CPU)
        """
        self.device = device
        self.model_id = model_id

        # bitsandbytes 4-bit quantization is unavailable on Mac/CPU setups.
        is_mac = platform.system() == 'Darwin'
        running_on_cpu = 'cpu' in str(device).lower()

        if disable_quantization or (is_mac and running_on_cpu):
            if disable_quantization:
                print("Critic: 4-bit quantization disabled by user request")
            else:
                print("Critic: 4-bit quantization automatically disabled for Mac/CPU")

            # Load in full or half precision without quantization.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32 if running_on_cpu else torch.bfloat16,
            ).to(device)
        else:
            # NF4 double quantization with bfloat16 compute, as in the QLoRA paper.
            self.bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )

            print(f"Loading critic {model_id} with QLoRA (4-bit quantization + LoRA)")

            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                quantization_config=self.bnb_config,
                device_map="auto"
            )

        target_modules = get_target_modules_for_model(model_id, self.model)
        print(f"Critic: Using target modules for LoRA: {target_modules}")

        # LoRA config; alpha = 2 * r is a common scaling heuristic.
        self.lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=2 * lora_r,
            target_modules=target_modules,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
            lora_dropout=0.05,
        )

        self.model = get_peft_model(self.model, self.lora_config)

        # Scalar value head on top of the backbone's last hidden state.
        hidden_size = self.model.config.hidden_size
        self.value_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, 1)
        ).to(device)

        self.model.print_trainable_parameters()

    def parameters(self):
        """Return all trainable parameters."""
        return list(self.model.parameters()) + list(self.value_head.parameters())

    def forward(self, input_ids, attention_mask=None):
        """
        Forward pass to compute the value function.

        Args:
            input_ids: Tokenized input IDs
            attention_mask: Attention mask

        Returns:
            Value predictions [batch_size, 1]
        """
        # Request hidden states so the value head can read the final layer.
        outputs = self.model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]

        # Pool the hidden state of the last non-padding token in each sequence.
        batch_size = hidden_states.shape[0]
        if attention_mask is not None:
            last_token_positions = attention_mask.sum(dim=1) - 1
            last_token_hidden = hidden_states[torch.arange(batch_size), last_token_positions]
        else:
            last_token_hidden = hidden_states[:, -1]

        # Match the value head's dtype (e.g. bfloat16 backbone, float32 head).
        if last_token_hidden.dtype != next(self.value_head.parameters()).dtype:
            last_token_hidden = last_token_hidden.to(next(self.value_head.parameters()).dtype)
        values = self.value_head(last_token_hidden)

        return values

    def save_pretrained(self, path):
        """
        Save the LoRA adapter and value head to the given path.

        Args:
            path: Path to save the model to
        """
        self.model.save_pretrained(f"{path}/lora")
        torch.save(self.value_head.state_dict(), f"{path}/value_head.pt")

    def load_pretrained(self, path):
        """
        Load the LoRA adapter and value head from the given path.

        Args:
            path: Path to load the model from
        """
        self.model = PeftModel.from_pretrained(self.model, f"{path}/lora")
        self.value_head.load_state_dict(torch.load(f"{path}/value_head.pt", map_location=self.device))
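

if __name__ == "__main__":
    # Minimal usage sketch, not part of the training pipeline: instantiate a
    # small actor/critic pair and run one scoring pass. "gpt2" is used here
    # purely as a cheap smoke-test backbone (the prompt and action strings are
    # arbitrary examples); swap in the model you actually train with.
    # Quantization is disabled so this also runs on CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_id = "gpt2"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    actor = LLMActorLora(device, model_id=model_id, disable_quantization=True)
    critic = LLMCriticLora(device, model_id=model_id, disable_quantization=True)

    prompt = tokenizer(["The robot should"], return_tensors="pt").to(device)
    action = tokenizer([" move"], return_tensors="pt", add_special_tokens=False).to(device)

    # Log-prob and entropy from the actor; scalar state value from the critic.
    log_probs, entropy = actor.get_log_probs(
        prompt["input_ids"], action["input_ids"], prompt["attention_mask"]
    )
    values = critic.forward(prompt["input_ids"], prompt["attention_mask"])
    print(f"log_prob={log_probs.item():.4f}  entropy={entropy.item():.4f}  value={values.item():.4f}")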