# The class in this script implements a selective fine-tuning method based on the condition number
# Author: Oswaldo Ludwig (now with AI support)
# Date: 03/07/2025
# In case of publication using this script or ideas in this script, cite:
# Ludwig, Oswaldo. "The Condition Number as a Scale-Invariant Proxy for Information Encoding in Neural Units." arXiv preprint arXiv:2506.16289 (2025).
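#
# Method sketch (as implemented below): for every weight matrix W the singular
# values are computed and the condition number kappa(W) = sigma_max / sigma_min
# is recorded; only the `num_tensors_to_finetune` tensors with the smallest
# kappa are unfrozen for fine-tuning, while all other parameters stay frozen.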
import torch
import torch.nn as nn
import torch.optim as optim
import os
import json
import numpy as np
import logging
from typing import Type, Dict, Any, Set, Optional
# Configure logging (ensure this is at the top level or configured once)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class SelectiveFineTuningOptimizer:
    """
    A custom optimizer wrapper that selectively fine-tunes a PyTorch model
    based on the condition numbers of its parameter matrices. For a matrix W
    with singular values sorted in descending order, the condition number is
    kappa(W) = s_max / s_min. Parameters with lower condition numbers are
    prioritized for fine-tuning.
    """
    def __init__(self, model: nn.Module, base_optimizer_cls: Type[optim.Optimizer], optimizer_args: Dict[str, Any],
                 condition_file: str = 'condition_numbers.json',
                 num_tensors_to_finetune: int = 100,
                 recompute: bool = False,
                 max_dim_size_to_analyze: Optional[int] = None):
"""
Initializes the SelectiveFineTuningOptimizer.
Args:
model (nn.Module): The PyTorch model to be fine-tuned.
base_optimizer_cls (Type[optim.Optimizer]): The class of the base optimizer (e.g., torch.optim.Adam).
optimizer_args (Dict[str, Any]): A dictionary of arguments to pass to the base optimizer constructor.
condition_file (str): Path to the JSON file for storing/loading condition numbers.
num_tensors_to_finetune (int): The number of top tensors (based on condition number) to fine-tune.
recompute (bool): If True, recompute condition numbers even if the file exists.
            max_dim_size_to_analyze (Optional[int]): If provided, any parameter tensor with at least one
                                                     dimension larger than this value is skipped during
                                                     analysis. Useful for ignoring very large embedding
                                                     matrices, etc.
"""
self.model = model
self.condition_file = condition_file
self.num_tensors_to_finetune = num_tensors_to_finetune
self.recompute = recompute
        self.max_dim_size_to_analyze = max_dim_size_to_analyze
self.condition_numbers: Dict[str, float] = {}
if not os.path.exists(condition_file) or recompute:
self.condition_numbers = self._analyze_model()
self._save_condition_numbers()
else:
self.condition_numbers = self._load_condition_numbers()
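        # First decide which tensors should train, then flip requires_grad so
        # that only those tensors receive gradients.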
self.trainable_param_names: Set[str] = self._select_trainable_parameters()
self._unfreeze_selected_parameters()
# Initialize the base optimizer with selected parameters
params_to_optimize = [p for n, p in model.named_parameters() if n in self.trainable_param_names]
        if not params_to_optimize:
            # torch optimizers raise on an empty parameter list, so fail early with a clearer message.
            raise ValueError("No parameters were selected for fine-tuning based on the criteria; "
                             "the base optimizer cannot be constructed with an empty parameter list.")
self.optimizer = base_optimizer_cls(params_to_optimize, **optimizer_args)
logger.info(f"Optimizer initialized with {len(params_to_optimize)} trainable parameters.")

    def _analyze_model(self) -> Dict[str, float]:
        """
        Analyzes the singular values of model parameters to compute their condition numbers.
        Parameters with fewer than 2 dimensions, or with any dimension larger
        than `max_dim_size_to_analyze`, are ignored. Tensors with more than
        two dimensions (e.g. convolution kernels) are flattened to 2-D before analysis.
        SVD is performed on the GPU if the tensor is on CUDA, otherwise on CPU.
        Returns:
            Dict[str, float]: A dictionary mapping parameter names to their condition numbers.
        """
condition_numbers = {}
logger.info("Analyzing the model tensors...")
initial_requires_grad_state = {}
for name, param in self.model.named_parameters():
initial_requires_grad_state[name] = param.requires_grad
param.requires_grad = False # Temporarily disable for analysis
analyzed_count = 0
skipped_ndim_count = 0
        skipped_dim_size_count = 0
skipped_svd_error_count = 0
total_params_in_model = 0
try:
for name, param in self.model.named_parameters():
total_params_in_model += 1
# Filter 1: Skip by number of dimensions
if param.ndim < 2:
logger.debug(f"Skipping {name} due to less than 2 dimensions (ndim={param.ndim}).")
skipped_ndim_count += 1
continue
# Filter 2: Skip by any dimension size exceeding threshold
if self.max_dim_size_to_analyze is not None:
if any(dim_size > self.max_dim_size_to_analyze for dim_size in param.shape):
logger.debug(f"Skipping {name} due to a dimension larger than {self.max_dim_size_to_analyze} (shape={param.shape}).")
skipped_dim_size_count += 1
continue
                try:
                    data = param.detach()  # Keep on GPU if already there
                    # Flatten tensors with more than two dimensions (e.g. 4-D conv
                    # kernels) into 2-D matrices; otherwise torch.linalg would return
                    # a batch of spectra and the scalar condition number below would
                    # be ill-defined.
                    if data.ndim > 2:
                        data = data.reshape(data.shape[0], -1)
                    if not data.is_cuda:
                        data = data.cpu()  # Fall back to CPU if not on CUDA
                    # Only the singular values are needed, so svdvals is used instead
                    # of a full SVD (singular values are returned in descending order).
                    s = torch.linalg.svdvals(data)
                    cond_number = (s[0] / s[-1]).item() if s[-1] > 0 else float('inf')
                    condition_numbers[name] = cond_number
                    analyzed_count += 1
                    logger.debug(f"Analyzed {name}: condition_number={cond_number:.4f}")
except torch.linalg.LinAlgError as e:
logger.warning(f"Skipping {name} due to SVD Linear Algebra error: {e}")
skipped_svd_error_count += 1
except Exception as e:
logger.error(f"Skipping {name} due to unexpected error during SVD: {e}")
skipped_svd_error_count += 1
finally:
# Restore initial requires_grad state (though _unfreeze_selected_parameters will override this)
for name, param in self.model.named_parameters():
param.requires_grad = initial_requires_grad_state[name]
logger.info(f"Done analyzing model tensors. Total parameters in model: {total_params_in_model}")
logger.info(f"Parameters analyzed for condition numbers: {analyzed_count}")
logger.info(f"Skipped due to ndim < 2: {skipped_ndim_count}")
logger.info(f"Skipped due to dimension size > {self.max_dim_size_to_analyze}: {skipped_dim_size_count}") # New log
logger.info(f"Skipped due to SVD errors: {skipped_svd_error_count}")
return condition_numbers

    def _save_condition_numbers(self):
"""
Saves the computed condition numbers to a JSON file.
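        The output is a flat JSON object mapping parameter names to floats, e.g.
        {"layers.0.weight": 17.42} (the name and value here are illustrative).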
"""
try:
with open(self.condition_file, 'w') as f:
json.dump(self.condition_numbers, f, indent=2)
logger.info(f"Condition numbers saved to {self.condition_file}")
except IOError as e:
logger.error(f"Failed to save condition numbers to {self.condition_file}: {e}")

    def _load_condition_numbers(self) -> Dict[str, float]:
"""
Loads condition numbers from a JSON file. If the file is corrupted,
it triggers a recomputation.
Returns:
Dict[str, float]: The loaded condition numbers.
"""
try:
with open(self.condition_file, 'r') as f:
data = json.load(f)
logger.info(f"Condition numbers loaded from {self.condition_file}")
return data
except json.JSONDecodeError as e:
logger.warning(f"Condition file '{self.condition_file}' is corrupted or invalid. Error: {e}. Recomputing.")
if os.path.exists(self.condition_file):
try:
os.remove(self.condition_file) # Remove corrupted file
logger.info(f"Removed corrupted condition file: {self.condition_file}")
except OSError as err:
logger.error(f"Error removing corrupted file {self.condition_file}: {err}")
return self._analyze_model() # Recompute if loading fails
except IOError as e:
logger.error(f"Failed to load condition numbers from {self.condition_file}: {e}. Recomputing.")
return self._analyze_model() # Recompute if file not found or other IO error

    def _select_trainable_parameters(self) -> Set[str]:
"""
Selects the top `num_tensors_to_finetune` parameters based on their condition numbers
(lower condition number is better).
Returns:
Set[str]: A set of names of the parameters chosen for fine-tuning.
"""
if not self.condition_numbers:
logger.warning("No condition numbers available to select trainable parameters.")
return set()
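        # Ascending sort: tensors with the smallest (best) condition numbers
        # come first and are selected for fine-tuning.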
sorted_params = sorted(self.condition_numbers.items(), key=lambda x: x[1])
selected = [name for name, _ in sorted_params[:self.num_tensors_to_finetune]]
logger.info(f"Selected {len(selected)} parameters for fine-tuning out of {len(self.condition_numbers)} analyzed.")
logger.debug(f"Selected parameters: {selected}")
return set(selected)

    def _unfreeze_selected_parameters(self):
"""
Sets `requires_grad=True` for the selected trainable parameters
and `requires_grad=False` for all other parameters in the model.
"""
total_params = 0
frozen_params_count = 0
unfrozen_params_count = 0
        # Counters reflect the final requires_grad state of every parameter,
        # so that unfrozen + frozen == total in the summary log below.
        for name, param in self.model.named_parameters():
            total_params += 1
            if name in self.trainable_param_names:
                param.requires_grad = True
                unfrozen_params_count += 1
                logger.debug(f"Parameter '{name}' set to requires_grad=True.")
            else:
                param.requires_grad = False
                frozen_params_count += 1
                logger.debug(f"Parameter '{name}' set to requires_grad=False.")
        logger.info(f"Model parameters configured: {unfrozen_params_count} unfrozen, {frozen_params_count} frozen (out of {total_params} total).")

    def step(self):
"""
Performs a single optimization step (parameter update).
Delegates to the base optimizer's step method.
"""
self.optimizer.step()

    def zero_grad(self):
"""
Clears the gradients of all optimized parameters.
Delegates to the base optimizer's zero_grad method.
"""
self.optimizer.zero_grad()

    def state_dict(self) -> Dict[str, Any]:
"""
Returns a serializable dictionary containing the current state of the optimizer.
Delegates to the base optimizer's state_dict method.
"""
return self.optimizer.state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]):
"""
Loads the optimizer's state from a state_dict.
Delegates to the base optimizer's load_state_dict method.
Args:
state_dict (Dict[str, Any]): A dictionary containing the optimizer's state.
"""
self.optimizer.load_state_dict(state_dict)
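

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): wrap a toy MLP
    # in the selective optimizer and run a single training step on random
    # data. The model, layer sizes, and hyperparameters are illustrative only.
    torch.manual_seed(0)
    toy_model = nn.Sequential(
        nn.Linear(16, 32),
        nn.ReLU(),
        nn.Linear(32, 4),
    )
    selective_opt = SelectiveFineTuningOptimizer(
        model=toy_model,
        base_optimizer_cls=optim.Adam,
        optimizer_args={'lr': 1e-3},
        condition_file='condition_numbers.json',
        num_tensors_to_finetune=1,  # fine-tune only the best-conditioned tensor
        recompute=True,             # force a fresh analysis for the demo
    )
    # Standard PyTorch training step: the wrapper delegates zero_grad()/step()
    # to the underlying Adam instance, which only holds the selected tensors.
    inputs, targets = torch.randn(8, 16), torch.randn(8, 4)
    selective_opt.zero_grad()
    loss = nn.MSELoss()(toy_model(inputs), targets)
    loss.backward()
    selective_opt.step()
    logger.info(f"Demo step complete, loss={loss.item():.4f}")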