# VLM-Lens / src/probe/main.py
# marstin's picture
# [martin-dev] add demo v1 test
# d425e71
"""Probe classes for information analysis in models.
Example command: python -m src.probe.main -c configs/probe/qwen/clevr-boolean-l13-example.yaml
"""
import argparse
import io
import itertools
import json
import logging
import os
import random
import sqlite3
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from sklearn.model_selection import KFold, train_test_split
from statsmodels.stats.proportion import proportions_ztest
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
class ProbeConfig:
    """Configuration for a probe run, assembled from the CLI and a YAML file.

    Every top-level key of the YAML file becomes an attribute of the instance.
    The ``data``, ``model``, ``training`` and ``test`` sections are then
    flattened into one dictionary each and completed with default values.
    """

    # Defaults for keys missing from the `model` section.
    # `input_size` and `output_size` are filled in later, when data is loaded.
    _MODEL_DEFAULTS = {
        'activation': 'ReLU',
        'hidden_size': 256,
        'num_layers': 2,
    }
    # Defaults for keys missing from the `training` and `test` sections.
    _RUN_DEFAULTS = {
        'optimizer': 'AdamW',
        'learning_rate': 1e-3,
        'loss': 'CrossEntropyLoss',
        'num_epochs': 10,
        'batch_size': 32,
    }

    def __init__(self) -> None:
        """Parse the command line, load the YAML file and normalise sections.

        Raises:
            ValueError: If no config path is given, the YAML content is not a
                mapping, the `data` section is missing, or a CUDA device is
                requested on a machine without one.
            TypeError: If a section entry is not a mapping.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument(
            '-c', '--config', type=str, help='Path to the probe configuration file'
        )
        parser.add_argument(
            '--debug',
            default=False,
            action='store_true',
            help='Flag to print out debug statements',
        )
        parser.add_argument(
            '-d',
            '--device',
            type=str,
            default='cuda' if torch.cuda.is_available() else 'cpu',
            help='The device to send the model and tensors to',
        )
        args = parser.parse_args()

        # Explicit raise instead of `assert`, which is stripped under `-O`.
        if args.config is None:
            raise ValueError('Config file must be provided.')

        with open(args.config, 'r') as file:
            loaded = yaml.safe_load(file)
        # An empty YAML document parses to None; treat it as an empty mapping.
        if loaded is None:
            loaded = {}
        if not isinstance(loaded, dict):
            raise ValueError('The config file must contain a top-level mapping.')
        for key, value in loaded.items():
            setattr(self, key, value)

        # Set debug mode based on the command-line flag
        logging.getLogger().setLevel(logging.DEBUG if args.debug else logging.INFO)

        # Load model device
        if 'cuda' in args.device and not torch.cuda.is_available():
            raise ValueError('No GPU found on this machine')
        self.device = args.device
        logging.debug(self.device)

        # Load data mapping
        if getattr(self, 'data', None) is None:
            raise ValueError(
                'The `data` field must be specified in the config, with an input database path.'
            )
        data_mapping = self._merge_section(self.data)
        # No specific layer selected unless the config names one
        data_mapping.setdefault('input_layer', None)
        # Set default database name if not specified
        if 'db_name' not in data_mapping:
            logging.debug(
                'Input database name attribute `db_name` not specified, setting to default `tensors`.')
            data_mapping['db_name'] = 'tensors'
        self.data = data_mapping

        # Load model mapping
        self.model = self._merge_section(
            getattr(self, 'model', None), self._MODEL_DEFAULTS)
        logging.debug(self.model)

        # Load training mapping
        self.training = self._merge_section(
            getattr(self, 'training', None), self._RUN_DEFAULTS)
        logging.debug(self.training)

        # Load test mapping
        self.test = self._merge_section(
            getattr(self, 'test', None), self._RUN_DEFAULTS)

    @staticmethod
    def _merge_section(section: Any,
                       defaults: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Flatten one config section and fill in missing defaults.

        Args:
            section (Any): A list of mappings (later entries override earlier
                ones), a single mapping, or None for an absent/empty section.
            defaults (Optional[Dict[str, Any]]): Values used for keys that the
                section does not define. Appended after the section's own
                keys, in the order given.

        Returns:
            Dict[str, Any]: A new dictionary; neither argument is mutated.

        Raises:
            TypeError: If an entry of the section is not a mapping.
        """
        if section is None:
            entries = []
        elif isinstance(section, dict):
            entries = [section]
        else:
            entries = section

        merged: Dict[str, Any] = {}
        for entry in entries:
            if not isinstance(entry, dict):
                raise TypeError(
                    f'Config section entries must be mappings, got {type(entry).__name__}.')
            merged.update(entry)
        for key, value in (defaults or {}).items():
            merged.setdefault(key, value)
        return merged
class Probe(nn.Module):
    """MLP probe trained on tensors stored in a SQLite database."""

    def __init__(self, config: 'ProbeConfig') -> None:
        """Initialize the probe with the given configuration.

        Args:
            config (ProbeConfig): Configuration object. The probe reads its
                `device` attribute and its `data`, `model` and `test`
                dictionaries via attribute access.
        """
        super().__init__()
        self.config = config
        # Load input data first: it fills in the model input_size/output_size
        self.data = self.load_data()
        # Initialize the model
        self.build_model()

    def build_model(self) -> None:
        """Build a freshly initialised probe model and store it on `self.model`.

        The network is Linear -> activation, followed by `num_layers - 2`
        hidden Linear -> activation pairs and a final Linear output layer.
        Values of `num_layers` below 2 therefore still yield two linear layers.
        """
        model_cfg = self.config.model
        hidden_size = model_cfg['hidden_size']
        activation = getattr(nn, model_cfg['activation'])

        layers = [nn.Linear(model_cfg['input_size'], hidden_size), activation()]
        # Intermediate layers based on config
        for _ in range(model_cfg['num_layers'] - 2):
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(activation())
        # Final layer to output the desired size
        layers.append(nn.Linear(hidden_size, model_cfg['output_size']))

        # Combine all layers to construct the model
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the probe model.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        logging.debug('Forward pass with input: %s', x.shape)
        return self.model(x)

    def load_data(self, shuffle: bool = False) -> TensorDataset:
        """Load tensors and labels from the database.

        Also fills `output_size` and `input_size` in the model config when
        they are not already set, and validates them otherwise.

        Args:
            shuffle (bool): Whether to randomly permute the labels relative to
                the features (the features keep their order).

        Returns:
            TensorDataset: A dataset of (feature, label index) pairs.

        Raises:
            ValueError: If a configured `output_size`/`input_size` disagrees
                with the data, or if no row matches the requested layer.
        """
        logging.debug('Loading tensors from the database...')
        connection = sqlite3.connect(self.config.data['input_db'])
        try:
            cursor = connection.cursor()
            # NOTE(review): the table name comes from the config and is
            # interpolated as-is (identifiers cannot be bound as parameters);
            # only use trusted configuration files.
            cursor.execute(
                f"SELECT layer, tensor, label FROM {self.config.data['db_name']}"
            )
            results = cursor.fetchall()
        finally:
            # Close the connection even when the query fails
            connection.close()

        # Gather unique class labels over all rows
        all_labels = {row[2] for row in results}
        # Sort so that the label -> index mapping is identical across runs
        # (set iteration order of strings changes with hash randomisation).
        try:
            ordered_labels = sorted(all_labels)
        except TypeError:
            # Labels of mixed types are not mutually comparable
            ordered_labels = sorted(all_labels, key=repr)

        self.config.model.setdefault('output_size', len(ordered_labels))
        if len(ordered_labels) != self.config.model['output_size']:
            raise ValueError(
                'Input attribute `output_size` does not match number of classes in dataset. Leave blank to assign automatically.'
            )
        label_to_idx = {label: i for i, label in enumerate(ordered_labels)}

        probe_layer = self.config.data.get('input_layer', None)
        if not probe_layer:
            logging.debug(
                'No `input_layer` attribute provided for database loading, extracting all tensors...')

        input_size = self.config.data.get('input_size', None)
        if input_size:
            # Inference below is skipped in this case; still expose the value
            # to the model section so build_model can find it.
            self.config.model.setdefault('input_size', input_size)

        features, targets = [], []
        for layer, tensor_bytes, label in results:
            if probe_layer and layer != probe_layer:
                continue
            # NOTE(review): torch.load deserialises stored blobs; only read
            # databases from a trusted source.
            tensor = torch.load(io.BytesIO(tensor_bytes),
                                map_location=self.config.device)
            if tensor.ndim > 2:
                # Apply mean pooling if tensor is not already pooled
                tensor = tensor.mean(dim=1)
            tensor = tensor.squeeze()
            if not input_size:
                # Set model config input_size once, from the first tensor
                input_size = tensor.shape[0]
                self.config.model.setdefault('input_size', input_size)
                if input_size != self.config.model['input_size']:
                    raise ValueError(
                        'Input attribute `input_size` does not match input tensor dimension. Leave blank to assign automatically.'
                    )
            features.append(tensor)
            targets.append(label_to_idx[label])

        if not features:
            raise ValueError(
                f'No tensors found in table `{self.config.data["db_name"]}` '
                f'for input_layer={probe_layer!r}.')

        if shuffle:
            random.shuffle(targets)

        # Stack lists into batched tensors
        X = torch.stack(features)
        Y = torch.tensor(targets)
        logging.debug(f'Features shape {X.shape}, Targets shape {Y.shape}')
        # Move tensors to same device as model
        X, Y = X.to(self.config.device), Y.to(self.config.device)
        return TensorDataset(X, Y)

    def cross_validate(self, config: dict, data: Dataset, nfolds: Optional[int] = 5) -> float:
        """Train the model with the given hyperparameters across k folds.

        Args:
            config (dict): The training configuration passed to `train`.
            data (Dataset): The dataset to split into folds.
            nfolds (Optional[int]): The number of folds for cross-validation.

        Returns:
            float: The validation loss averaged over all samples, i.e. each
                fold's mean loss weighted by its validation-set size.
        """
        kf = KFold(n_splits=nfolds, shuffle=True, random_state=42)
        weighted_loss = 0.0
        for fold, (train_idx, val_idx) in enumerate(kf.split(range(len(data)))):
            logging.debug(f'===Starting fold {fold + 1}/{nfolds}===')
            train_set, val_set = Subset(data, train_idx), Subset(data, val_idx)
            # Reinitialize model for each fold to prevent contamination
            self.build_model()
            result = self.train(config, train_set, val_set)
            weighted_loss += result['val_loss'] * len(val_set)
        return weighted_loss / len(data)

    def train(self, train_config: dict, train_set: Dataset, val_set: Optional[Dataset] = None) -> dict:
        """Train the probe model.

        Args:
            train_config (dict): The training configuration; reads the keys
                `optimizer`, `learning_rate`, `loss`, `batch_size` and
                `num_epochs`.
            train_set (Dataset): The training dataset.
            val_set (Dataset, optional): The validation dataset.

        Returns:
            dict: With a non-empty `val_set`, a dictionary holding `preds`
                (raw model outputs), `labels`, `val_loss` and `val_acc`.
                Otherwise an empty dictionary.
        """
        logging.debug(
            f'Training the probe model with config {train_config}...')
        device = torch.device(self.config.device)
        self.model.to(device)

        optimizer_class = getattr(optim, train_config['optimizer'])
        optimizer = optimizer_class(
            self.parameters(), lr=train_config['learning_rate'])
        loss_fn = getattr(nn, train_config['loss'])()
        train_loader = DataLoader(
            train_set, batch_size=train_config['batch_size'], shuffle=True)

        for epoch in range(train_config['num_epochs']):
            self.model.train()
            total_loss = 0
            for X, Y in train_loader:
                optimizer.zero_grad()
                outputs = self.model(X.float())
                loss = loss_fn(outputs, Y)
                loss.backward()
                optimizer.step()
                # Weight by batch size to account for incomplete batches
                total_loss += loss.item() * X.size(0)
            mean_train_loss = total_loss / len(train_set)
            logging.debug(
                f"--Epoch {epoch + 1}/{train_config['num_epochs']}: Train loss: {mean_train_loss:.4f}")

        # Nothing to validate against
        if val_set is None or len(val_set) == 0:
            # TODO: Return train details here
            return {}

        val_loader = DataLoader(
            val_set, batch_size=train_config['batch_size'])
        # Set model to eval mode and calculate validation loss
        self.model.eval()
        val_loss = 0
        preds, labels = [], []
        with torch.no_grad():
            for X_val, Y_val in val_loader:
                outputs = self.model(X_val.float())
                loss = loss_fn(outputs, Y_val)
                val_loss += loss.item() * X_val.size(0)
                preds.append(outputs)
                labels.append(Y_val)
        preds = torch.cat(preds, dim=0)
        labels = torch.cat(labels, dim=0)
        val_loss = val_loss / len(val_set)
        val_acc = (preds.argmax(dim=1) == labels).float().mean().item()
        logging.debug(
            f'Validation accuracy: {val_acc}, Validation mean loss: {val_loss}')
        return {'preds': preds, 'labels': labels, 'val_loss': val_loss, 'val_acc': val_acc}

    def evaluate(self, test_set: Dataset) -> dict:
        """Evaluate the probe model on the input test set.

        Args:
            test_set (Dataset): The test dataset.

        Returns:
            dict: `accuracy` and `loss` as floats, plus `labels` and `preds`
                (predicted class indices) as NumPy arrays.
        """
        self.model.eval()
        device = torch.device(self.config.device)
        self.model.to(device)

        test_config = self.config.test
        test_loader = DataLoader(
            test_set, batch_size=test_config['batch_size'])
        loss_fn = getattr(nn, test_config['loss'])()

        total_loss = 0.0
        num_correct, num_samples = 0, 0
        all_preds, all_labels = [], []
        with torch.no_grad():
            for X, Y in test_loader:
                outputs = self.model(X.float())
                loss = loss_fn(outputs, Y)
                total_loss += loss.item() * X.size(0)  # to account for incomplete batches
                preds = outputs.argmax(dim=1)
                # Keep the running count as a plain Python int
                num_correct += (preds == Y).sum().item()
                num_samples += Y.size(0)
                all_preds.append(preds)
                all_labels.append(Y)

        mean_loss = float(total_loss / len(test_set))
        accuracy = float(num_correct / num_samples)
        all_preds = torch.cat(all_preds, dim=0).cpu().numpy()
        all_labels = torch.cat(all_labels, dim=0).cpu().numpy()
        logging.debug(
            f'Test accuracy: {accuracy}, Test mean loss: {mean_loss}')
        return {'accuracy': accuracy,
                'loss': mean_loss,
                'labels': all_labels,
                'preds': all_preds}

    def save_model(self, metadata: Optional[dict] = None) -> None:
        """Save the trained model weights, and optionally metadata, to disk.

        Files are written to the `save_dir` entry of the model config, or to
        `probe_output` when it is unset. Write failures are logged, not raised.

        Args:
            metadata (Optional[dict]): JSON-serialisable metadata to save
                alongside the model.
        """
        save_dir = self.config.model.get('save_dir') or 'probe_output'
        os.makedirs(save_dir, exist_ok=True)

        save_path = os.path.join(save_dir, 'probe.pth')
        try:
            torch.save(self.model.state_dict(), save_path)
            logging.debug(f'Model saved to {save_path}')
        except Exception as e:
            # Best effort: still attempt to write the metadata below
            logging.error(f'Failed to save probe model: {e}')

        if metadata:
            try:
                data_path = os.path.join(save_dir, 'probe_data.json')
                with open(data_path, 'w') as f:
                    json.dump(metadata, f, indent=2)
                logging.debug(f'Probe metadata saved to {data_path}')
            except Exception as e:
                logging.error(f'Failed to save metadata: {e}')
def main() -> None:
    """Run the full probe pipeline.

    Steps: load the config and data, hold out 20% of the data as a test set,
    pick the training hyperparameters with k-fold cross-validation, train a
    control probe on label-shuffled data, train the final probe on the real
    labels, compare both with a proportions z-test and save the results.
    """
    config = ProbeConfig()
    probe = Probe(config)

    # Load data and split into train/val and test
    data = probe.data
    indices = list(range(len(data)))
    train_idx, test_idx = train_test_split(
        indices, test_size=0.2, random_state=42)
    train_set, test_set = Subset(data, train_idx), Subset(data, test_idx)

    # Load all combinations of hyperparameters; scalar values count as a
    # single-option list.
    train_keys = list(config.training)
    value_options = [
        value if isinstance(value, list) else [value]
        for value in (config.training[key] for key in train_keys)
    ]
    train_configs = list(itertools.product(*value_options))
    logging.debug(
        f'Hyperparamer tuning using {len(train_configs)} config combinations...')

    # Cross-validate every combination and record its validation loss.
    # The loop variable is deliberately not called `config`, so the
    # ProbeConfig instance above stays accessible.
    val_losses = [
        probe.cross_validate(dict(zip(train_keys, candidate)), train_set)
        for candidate in train_configs
    ]

    # Keep the combination with the lowest validation loss
    min_idx = val_losses.index(min(val_losses))
    final_config = dict(zip(train_keys, train_configs[min_idx]))
    logging.debug(
        f'Model config results after hyperparameter tuning: {final_config}')

    # Control run: reload the data with shuffled labels and train on the same
    # train/test indices to get a baseline.
    shffl_data = probe.load_data(shuffle=True)
    shuffl_train, shuffl_test = Subset(
        shffl_data, train_idx), Subset(shffl_data, test_idx)
    probe.build_model()
    probe.train(final_config, shuffl_train)
    shffl_results = probe.evaluate(shuffl_test)

    # Reinitialize model to finally train with best config on the real labels
    probe.build_model()
    probe.train(final_config, train_set)
    test_results = probe.evaluate(test_set)

    # Calculate p-value using proportions z-test (real vs. shuffled accuracy)
    shffl_correct = (shffl_results['preds'] == shffl_results['labels']).sum()
    test_correct = (test_results['preds'] == test_results['labels']).sum()
    pvalue = proportions_ztest([test_correct, shffl_correct],
                               [len(test_results['preds']), len(shffl_results['preds'])])[1]

    # Save the non-shuffled model together with the results of both runs
    probe.save_model({'train_config': final_config,
                      'shuffle_accuracy': shffl_results['accuracy'],
                      'shuffle_loss': shffl_results['loss'],
                      'shuffle_preds': shffl_results['preds'].tolist(),
                      'shuffle_labels': shffl_results['labels'].tolist(),
                      'test_accuracy': test_results['accuracy'],
                      'test_loss': test_results['loss'],
                      'test_preds': test_results['preds'].tolist(),
                      'test_labels': test_results['labels'].tolist(),
                      # Plain Python float for JSON serialisation
                      'pvalue': float(pvalue)})
    # TODO: implement a demo
# Run the probe pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()