import json
import logging
import os
from typing import Any, Dict, Optional, Set, Type

import torch
import torch.nn as nn
import torch.optim as optim

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


|
class SelectiveFineTuningOptimizer:
    """
    A custom optimizer wrapper that selectively fine-tunes a PyTorch model
    based on the condition numbers of its parameter tensors. Parameters with
    lower condition numbers are prioritized for fine-tuning; all remaining
    parameters are frozen.
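
    Example (an illustrative sketch; the model and the Adam hyperparameters
    are placeholders, not prescribed values)::

        optimizer = SelectiveFineTuningOptimizer(
            model,
            base_optimizer_cls=torch.optim.Adam,
            optimizer_args={'lr': 1e-4},
            num_tensors_to_finetune=10,
        )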
    """

    def __init__(self, model: nn.Module, base_optimizer_cls: Type[optim.Optimizer],
                 optimizer_args: Dict[str, Any],
                 condition_file: str = 'condition_numbers.json',
                 num_tensors_to_finetune: int = 100,
                 recompute: bool = False,
                 max_dim_size_to_analyze: Optional[int] = None):
        """
        Initializes the SelectiveFineTuningOptimizer.

        Args:
            model (nn.Module): The PyTorch model to be fine-tuned.
            base_optimizer_cls (Type[optim.Optimizer]): The class of the base optimizer (e.g., torch.optim.Adam).
            optimizer_args (Dict[str, Any]): A dictionary of arguments to pass to the base optimizer constructor.
            condition_file (str): Path to the JSON file for storing/loading condition numbers.
            num_tensors_to_finetune (int): The number of tensors with the lowest condition numbers to fine-tune.
            recompute (bool): If True, recompute condition numbers even if the file exists.
            max_dim_size_to_analyze (Optional[int]): If provided, any parameter tensor with at least one
                dimension larger than this value is skipped during analysis. Useful for ignoring very
                large embedding matrices.

        Raises:
            ValueError: If no parameters end up selected for fine-tuning.
        """
        self.model = model
        self.condition_file = condition_file
        self.num_tensors_to_finetune = num_tensors_to_finetune
        self.recompute = recompute
        self.max_dim_size_to_analyze = max_dim_size_to_analyze

        self.condition_numbers: Dict[str, float] = {}

        if not os.path.exists(condition_file) or recompute:
            self.condition_numbers = self._analyze_model()
            self._save_condition_numbers()
        else:
            self.condition_numbers = self._load_condition_numbers()

        self.trainable_param_names: Set[str] = self._select_trainable_parameters()
        self._unfreeze_selected_parameters()

        # Hand only the selected (unfrozen) parameters to the base optimizer.
        params_to_optimize = [p for n, p in model.named_parameters() if n in self.trainable_param_names]
        if not params_to_optimize:
            # torch.optim optimizers raise on an empty parameter list, so fail with a clear message.
            raise ValueError("No parameters selected for fine-tuning based on the criteria; "
                             "cannot construct an optimizer with no parameters.")
        self.optimizer = base_optimizer_cls(params_to_optimize, **optimizer_args)
        logger.info(f"Optimizer initialized with {len(params_to_optimize)} trainable parameters.")

    def _analyze_model(self) -> Dict[str, float]:
        """
        Analyzes the singular values of model parameters to compute their condition numbers.
        Parameters with fewer than 2 dimensions, or with any dimension larger than
        `max_dim_size_to_analyze`, are ignored.
        SVD is performed on the GPU if the tensor is on CUDA, otherwise on CPU.

        Returns:
            Dict[str, float]: A dictionary mapping parameter names to their condition numbers.
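
        Note: the condition number is the ratio of the largest to the smallest
        singular value. An equivalent standalone computation for a 2-D weight
        tensor `w` (a sketch; `torch.linalg.svdvals` returns the singular
        values in descending order) is::

            s = torch.linalg.svdvals(w)
            cond = (s[0] / s[-1]).item()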
        """
        condition_numbers = {}
        logger.info("Analyzing the model tensors...")

        # Temporarily freeze every parameter during analysis and remember the
        # original flags; they are restored in the `finally` block below.
        initial_requires_grad_state = {}
        for name, param in self.model.named_parameters():
            initial_requires_grad_state[name] = param.requires_grad
            param.requires_grad = False

        analyzed_count = 0
        skipped_ndim_count = 0
        skipped_dim_size_count = 0
        skipped_svd_error_count = 0
        total_params_in_model = 0

        try:
            for name, param in self.model.named_parameters():
                total_params_in_model += 1

                if param.ndim < 2:
                    logger.debug(f"Skipping {name} due to fewer than 2 dimensions (ndim={param.ndim}).")
                    skipped_ndim_count += 1
                    continue

                if self.max_dim_size_to_analyze is not None:
                    if any(dim_size > self.max_dim_size_to_analyze for dim_size in param.shape):
                        logger.debug(f"Skipping {name} due to a dimension larger than {self.max_dim_size_to_analyze} (shape={param.shape}).")
                        skipped_dim_size_count += 1
                        continue

                try:
                    data = param.detach()
                    if data.ndim > 2:
                        # Flatten trailing dimensions (e.g. conv kernels of shape
                        # (out, in, kh, kw) become (out, in*kh*kw)); otherwise the
                        # batched SVD would return one spectrum per 2-D slice and
                        # the indexing below would be wrong.
                        data = data.flatten(1)
                    # torch.linalg.svd runs on whichever device the tensor lives
                    # on: GPU for CUDA tensors, CPU otherwise.
                    _, s, _ = torch.linalg.svd(data, full_matrices=False)

                    # Singular values come back in descending order, so the
                    # condition number is the first divided by the last.
                    cond_number = (s[0] / s[-1]).item() if s[-1] > 0 else float('inf')
                    condition_numbers[name] = cond_number
                    analyzed_count += 1
                    logger.debug(f"Analyzed {name}: condition_number={cond_number:.4f}")
                except torch.linalg.LinAlgError as e:
                    logger.warning(f"Skipping {name} due to SVD Linear Algebra error: {e}")
                    skipped_svd_error_count += 1
                except Exception as e:
                    logger.error(f"Skipping {name} due to unexpected error during SVD: {e}")
                    skipped_svd_error_count += 1
        finally:
            # Restore the original requires_grad flags.
            for name, param in self.model.named_parameters():
                param.requires_grad = initial_requires_grad_state[name]

        logger.info(f"Done analyzing model tensors. Total parameters in model: {total_params_in_model}")
        logger.info(f"Parameters analyzed for condition numbers: {analyzed_count}")
        logger.info(f"Skipped due to ndim < 2: {skipped_ndim_count}")
        logger.info(f"Skipped due to dimension size > {self.max_dim_size_to_analyze}: {skipped_dim_size_count}")
        logger.info(f"Skipped due to SVD errors: {skipped_svd_error_count}")
        return condition_numbers

    def _save_condition_numbers(self):
        """
        Saves the computed condition numbers to a JSON file.
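
        The file maps parameter names to floats, e.g. (names are illustrative)::

            {
              "layer1.weight": 12.3456,
              "layer2.weight": 98.7654
            }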
        """
        try:
            with open(self.condition_file, 'w') as f:
                json.dump(self.condition_numbers, f, indent=2)
            logger.info(f"Condition numbers saved to {self.condition_file}")
        except IOError as e:
            logger.error(f"Failed to save condition numbers to {self.condition_file}: {e}")

|
    def _load_condition_numbers(self) -> Dict[str, float]:
        """
        Loads condition numbers from a JSON file. If the file is corrupted
        or unreadable, the condition numbers are recomputed and re-saved.

        Returns:
            Dict[str, float]: The loaded (or recomputed) condition numbers.
        """
        try:
            with open(self.condition_file, 'r') as f:
                data = json.load(f)
            logger.info(f"Condition numbers loaded from {self.condition_file}")
            return data
        except json.JSONDecodeError as e:
            logger.warning(f"Condition file '{self.condition_file}' is corrupted or invalid. Error: {e}. Recomputing.")
            if os.path.exists(self.condition_file):
                try:
                    os.remove(self.condition_file)
                    logger.info(f"Removed corrupted condition file: {self.condition_file}")
                except OSError as err:
                    logger.error(f"Error removing corrupted file {self.condition_file}: {err}")
        except IOError as e:
            logger.error(f"Failed to load condition numbers from {self.condition_file}: {e}. Recomputing.")
        # Recompute and persist so the next run does not hit the same failure.
        self.condition_numbers = self._analyze_model()
        self._save_condition_numbers()
        return self.condition_numbers

|
    def _select_trainable_parameters(self) -> Set[str]:
        """
        Selects the `num_tensors_to_finetune` parameters with the lowest
        condition numbers (a lower condition number means the tensor is
        better conditioned and is prioritized for fine-tuning).

        Returns:
            Set[str]: A set of names of the parameters chosen for fine-tuning.
        """
        if not self.condition_numbers:
            logger.warning("No condition numbers available to select trainable parameters.")
            return set()

        # Sort ascending by condition number and keep the first N names.
        sorted_params = sorted(self.condition_numbers.items(), key=lambda x: x[1])
        selected = [name for name, _ in sorted_params[:self.num_tensors_to_finetune]]
        logger.info(f"Selected {len(selected)} parameters for fine-tuning out of {len(self.condition_numbers)} analyzed.")
        logger.debug(f"Selected parameters: {selected}")
        return set(selected)

    def _unfreeze_selected_parameters(self):
        """
        Sets `requires_grad=True` for the selected trainable parameters
        and `requires_grad=False` for all other parameters in the model.
        """
        total_params = 0
        frozen_params_count = 0
        unfrozen_params_count = 0

        # Count final states (rather than only flipped flags) so the summary
        # log below is accurate even when a flag was already set correctly.
        for name, param in self.model.named_parameters():
            total_params += 1
            if name in self.trainable_param_names:
                param.requires_grad = True
                unfrozen_params_count += 1
                logger.debug(f"Parameter '{name}' set to requires_grad=True.")
            else:
                param.requires_grad = False
                frozen_params_count += 1
                logger.debug(f"Parameter '{name}' set to requires_grad=False.")

        logger.info(f"Model parameters configured: {unfrozen_params_count} unfrozen, {frozen_params_count} frozen (out of {total_params} total).")

|
    def step(self):
        """
        Performs a single optimization step (parameter update).
        Delegates to the base optimizer's step method.
        """
        self.optimizer.step()

    def zero_grad(self):
        """
        Clears the gradients of all optimized parameters.
        Delegates to the base optimizer's zero_grad method.
        """
        self.optimizer.zero_grad()

    def state_dict(self) -> Dict[str, Any]:
        """
        Returns a serializable dictionary containing the current state of the optimizer.
        Delegates to the base optimizer's state_dict method.
        """
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """
        Loads the optimizer's state from a state_dict.
        Delegates to the base optimizer's load_state_dict method.

        Args:
            state_dict (Dict[str, Any]): A dictionary containing the optimizer's state.
        """
        self.optimizer.load_state_dict(state_dict)
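

# ---------------------------------------------------------------------------
# Minimal usage sketch. Everything below is illustrative: the toy model, the
# synthetic data, and the hyperparameters are assumptions for demonstration,
# not part of the wrapper's API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # A small toy model with several 2-D weight matrices to analyze.
    model = nn.Sequential(
        nn.Linear(32, 64),
        nn.ReLU(),
        nn.Linear(64, 64),
        nn.ReLU(),
        nn.Linear(64, 10),
    )

    optimizer = SelectiveFineTuningOptimizer(
        model,
        base_optimizer_cls=optim.Adam,
        optimizer_args={'lr': 1e-4},
        num_tensors_to_finetune=2,   # fine-tune only the 2 best-conditioned tensors
        recompute=True,              # ignore any cached condition_numbers.json
    )

    # One synthetic training step: only the selected parameters are updated.
    criterion = nn.CrossEntropyLoss()
    inputs = torch.randn(8, 32)
    targets = torch.randint(0, 10, (8,))

    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    logger.info(f"Demo step complete, loss={loss.item():.4f}")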