ArthurY's picture
update source
c3d0544
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Configuration dataclasses for the active learning driver.
This module provides structured configuration classes that separate different
concerns in the active learning workflow: optimization, training, strategies,
and driver orchestration.
"""
from __future__ import annotations
import math
import uuid
from collections import defaultdict
from dataclasses import dataclass, field
from json import dumps
from pathlib import Path
from typing import Any
from warnings import warn
import torch
from torch import distributed as dist
from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from physicsnemo.active_learning import protocols as p
from physicsnemo.active_learning._registry import registry
from physicsnemo.active_learning.loop import DefaultTrainingLoop
from physicsnemo.distributed import DistributedManager
@dataclass
class OptimizerConfig:
    """
    Configuration for optimizer and learning rate scheduler.

    This encapsulates all training optimization parameters, keeping
    them separate from the active learning orchestration logic.

    Attributes
    ----------
    optimizer_cls: type[Optimizer]
        The optimizer class to use. Defaults to AdamW.
    optimizer_kwargs: dict[str, Any]
        Keyword arguments to pass to the optimizer constructor.
        Defaults to {"lr": 1e-4}.
    scheduler_cls: type[_LRScheduler] | None
        The learning rate scheduler class to use. If None, no
        scheduler will be configured.
    scheduler_kwargs: dict[str, Any]
        Keyword arguments to pass to the scheduler constructor.
    """

    optimizer_cls: type[Optimizer] = AdamW
    optimizer_kwargs: dict[str, Any] = field(default_factory=lambda: {"lr": 1e-4})
    scheduler_cls: type[_LRScheduler] | None = None
    scheduler_kwargs: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """
        Validate optimizer configuration.

        Raises
        ------
        ValueError
            If ``optimizer_kwargs["lr"]`` is present but is not a finite,
            strictly positive real number, or if ``scheduler_kwargs`` is
            non-empty while ``scheduler_cls`` is None.
        """
        if "lr" in self.optimizer_kwargs:
            lr = self.optimizer_kwargs["lr"]
            # ``bool`` is a subclass of ``int`` and NaN compares False against
            # everything, so both would otherwise slip past a bare ``lr <= 0``.
            is_real = isinstance(lr, (int, float)) and not isinstance(lr, bool)
            if not is_real or not math.isfinite(lr) or lr <= 0:
                raise ValueError(f"Learning rate must be positive, got {lr}")
        # scheduler_kwargs only make sense alongside a scheduler class
        if self.scheduler_kwargs and self.scheduler_cls is None:
            raise ValueError(
                "scheduler_kwargs provided but scheduler_cls is None. "
                "Provide a scheduler_cls or remove scheduler_kwargs."
            )

    def to_dict(self) -> dict[str, Any]:
        """
        Returns a JSON-serializable dictionary representation of the OptimizerConfig.

        For round-tripping, the registry is used to de-serialize the optimizer and
        scheduler classes. The keyword-argument dictionaries are shallow copies, so
        mutating the returned dictionary does not modify this config.

        Returns
        -------
        dict[str, Any]
            A dictionary that can be JSON serialized.
        """
        opt = {
            "__name__": self.optimizer_cls.__name__,
            "__module__": self.optimizer_cls.__module__,
        }
        if self.scheduler_cls:
            scheduler = {
                "__name__": self.scheduler_cls.__name__,
                "__module__": self.scheduler_cls.__module__,
            }
        else:
            scheduler = None
        return {
            "optimizer_cls": opt,
            "optimizer_kwargs": dict(self.optimizer_kwargs),
            "scheduler_cls": scheduler,
            "scheduler_kwargs": dict(self.scheduler_kwargs),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> OptimizerConfig:
        """
        Creates an OptimizerConfig instance from a dictionary.

        This method assumes that the optimizer and scheduler classes are
        included in the ``physicsnemo.active_learning.registry``, or
        a module path is specified to import the class from.

        Parameters
        ----------
        data: dict[str, Any]
            A dictionary that was previously serialized using the ``to_dict`` method.
            ``scheduler_cls`` and ``scheduler_kwargs`` may be omitted.

        Returns
        -------
        OptimizerConfig
            A new ``OptimizerConfig`` instance.
        """
        optimizer_cls = registry.get_class(
            data["optimizer_cls"]["__name__"], data["optimizer_cls"]["__module__"]
        )
        if (s := data.get("scheduler_cls")) is not None:
            scheduler_cls = registry.get_class(s["__name__"], s["__module__"])
        else:
            scheduler_cls = None
        return cls(
            optimizer_cls=optimizer_cls,
            optimizer_kwargs=data["optimizer_kwargs"],
            scheduler_cls=scheduler_cls,
            # the field defaults to an empty dict, so tolerate its absence
            scheduler_kwargs=data.get("scheduler_kwargs", {}),
        )
@dataclass
class TrainingConfig:
    """
    Configuration for the training phase of active learning.

    This groups all training-related components together, making it
    clear when training is or isn't being used in the AL workflow.

    Attributes
    ----------
    train_datapool: p.DataPool
        The pool of labeled data to use for training.
    max_training_epochs: int
        The maximum number of epochs to train for. If ``max_fine_tuning_epochs``
        isn't specified, this value is used for all active learning steps.
    val_datapool: p.DataPool | None
        Optional pool of data to use for validation during training.
    optimizer_config: OptimizerConfig
        Configuration for the optimizer and scheduler. Defaults to
        AdamW with lr=1e-4, no scheduler.
    max_fine_tuning_epochs: int | None
        The maximum number of epochs used during fine-tuning steps, i.e. after
        the first active learning step. If ``None``, it is set to
        ``max_training_epochs`` during initialization.
    train_loop_fn: p.TrainingLoop
        The training loop function that orchestrates the training process.
        This defaults to a concrete implementation, ``DefaultTrainingLoop``,
        which provides a very typical loop that includes the use of static
        capture, etc.
    """

    train_datapool: p.DataPool
    max_training_epochs: int
    val_datapool: p.DataPool | None = None
    optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig)
    max_fine_tuning_epochs: int | None = None
    train_loop_fn: p.TrainingLoop = field(default_factory=DefaultTrainingLoop)

    def __post_init__(self) -> None:
        """
        Validate training configuration.

        Raises
        ------
        ValueError
            If a datapool lacks ``__len__`` or ``train_loop_fn`` is not callable.
        """
        # Validate datapools have consistent interface
        if not hasattr(self.train_datapool, "__len__"):
            raise ValueError("train_datapool must implement __len__")
        if self.val_datapool is not None and not hasattr(self.val_datapool, "__len__"):
            raise ValueError("val_datapool must implement __len__")
        # Validate training loop is callable
        if not callable(self.train_loop_fn):
            raise ValueError("train_loop_fn must be callable")
        # fine-tuning falls back to the initial training epoch budget
        if self.max_fine_tuning_epochs is None:
            self.max_fine_tuning_epochs = self.max_training_epochs

    def to_dict(self) -> dict[str, Any]:
        """
        Returns a JSON-serializable dictionary representation of the TrainingConfig.

        For round-tripping, the registry is used to de-serialize the training loop
        and optimizer configuration. Note that datapools (train_datapool and val_datapool)
        are NOT serialized as they typically contain large datasets, file handles, or other
        non-serializable state.

        Returns
        -------
        dict[str, Any]
            A dictionary that can be JSON serialized. Excludes datapools.

        Raises
        ------
        ValueError
            If ``train_loop_fn`` has no ``_args`` attribute.

        Warnings
        --------
        This method will issue a warning about the exclusion of datapools.
        """
        # Warn about datapool exclusion
        warn(
            "The `train_datapool` and `val_datapool` attributes are not supported for "
            "serialization and will be excluded from the ``TrainingConfig`` dictionary. "
            "You must re-provide these datapools when deserializing."
        )
        optimizer_dict = self.optimizer_config.to_dict()
        # the training loop is serialized via the constructor arguments it captured
        if not hasattr(self.train_loop_fn, "_args"):
            raise ValueError(
                f"Training loop {self.train_loop_fn} does not have an `_args` attribute "
                "which is required for serialization. Make sure your training loop "
                "either subclasses `ActiveLearningProtocol` or implements the `__new__` "
                "method to capture object arguments."
            )
        train_loop_dict = self.train_loop_fn._args
        return {
            "max_training_epochs": self.max_training_epochs,
            "max_fine_tuning_epochs": self.max_fine_tuning_epochs,
            "optimizer_config": optimizer_dict,
            "train_loop_fn": train_loop_dict,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any], **kwargs: Any) -> TrainingConfig:
        """
        Creates a TrainingConfig instance from a dictionary.

        This method assumes that the training loop class is included in the
        ``physicsnemo.active_learning.registry``, or a module path is specified
        to import the class from. Note that datapools must be provided via
        kwargs as they are not serialized.

        Parameters
        ----------
        data: dict[str, Any]
            A dictionary that was previously serialized using the ``to_dict`` method.
        **kwargs: Any
            Additional keyword arguments to pass to the constructor. This is where
            you must provide ``train_datapool`` and optionally ``val_datapool``.

        Returns
        -------
        TrainingConfig
            A new ``TrainingConfig`` instance.

        Raises
        ------
        ValueError
            If required datapools are not provided in kwargs, if the data contains
            unexpected keys, or if object construction fails.
        """
        # Ensure required datapools are provided
        if "train_datapool" not in kwargs:
            raise ValueError(
                "``train_datapool`` must be provided in kwargs when deserializing "
                "TrainingConfig, as datapools are not serialized."
            )
        # ensure that the data contains only the keys produced by ``to_dict``
        expected_keys = {
            "max_training_epochs",
            "max_fine_tuning_epochs",
            "optimizer_config",
            "train_loop_fn",
        }
        extra_keys = set(data.keys()) - expected_keys
        if extra_keys:
            raise ValueError(
                f"Unexpected keys in data: {extra_keys}. Expected keys are {expected_keys}."
            )
        optimizer_config = OptimizerConfig.from_dict(data["optimizer_config"])
        train_loop_data = data["train_loop_fn"]
        train_loop_fn = registry.construct(
            train_loop_data["__name__"],
            module_path=train_loop_data["__module__"],
            **train_loop_data["__args__"],
        )
        try:
            config = cls(
                max_training_epochs=data["max_training_epochs"],
                max_fine_tuning_epochs=data.get("max_fine_tuning_epochs"),
                optimizer_config=optimizer_config,
                train_loop_fn=train_loop_fn,
                **kwargs,
            )
        except Exception as e:
            raise ValueError(
                "Failed to construct ``TrainingConfig`` from dictionary."
            ) from e
        return config
@dataclass
class StrategiesConfig:
    """
    Configuration for active learning strategies and data acquisition.

    This encapsulates the query-label-metrology cycle that is at the
    heart of active learning: strategies for selecting data, labeling it,
    and measuring model uncertainty/performance.

    Attributes
    ----------
    query_strategies: list[p.QueryStrategy]
        The query strategies to use for selecting data to label.
    queue_cls: type[p.AbstractQueue]
        The queue implementation to use for passing data between
        query and labeling phases.
    label_strategy: p.LabelStrategy | None
        The strategy to use for labeling queried data. If None,
        labeling will be skipped.
    metrology_strategies: list[p.MetrologyStrategy] | None
        Strategies for measuring model performance and uncertainty.
        If None, metrology will be skipped.
    unlabeled_datapool: p.DataPool | None
        Pool of unlabeled data that query strategies can sample from.
        Not all strategies require this (some may generate synthetic data).
    """

    query_strategies: list[p.QueryStrategy]
    queue_cls: type[p.AbstractQueue]
    label_strategy: p.LabelStrategy | None = None
    metrology_strategies: list[p.MetrologyStrategy] | None = None
    unlabeled_datapool: p.DataPool | None = None

    def __post_init__(self) -> None:
        """
        Validate strategies configuration.

        Raises
        ------
        ValueError
            If no query strategy is given, any strategy is not callable,
            ``metrology_strategies`` is an empty list, or ``queue_cls`` is
            not callable.
        """
        # Must have at least one query strategy
        if not self.query_strategies:
            raise ValueError(
                "At least one query strategy must be provided. "
                "Active learning requires a mechanism to select data."
            )
        for strategy in self.query_strategies:
            if not callable(strategy):
                raise ValueError(f"Query strategy {strategy} must be callable")
        if self.label_strategy is not None and not callable(self.label_strategy):
            raise ValueError("label_strategy must be callable")
        if self.metrology_strategies is not None:
            # an empty list is ambiguous; None is the explicit "skip" value
            if not self.metrology_strategies:
                raise ValueError(
                    "metrology_strategies is an empty list. "
                    "Either provide strategies or set to None to skip metrology."
                )
            for strategy in self.metrology_strategies:
                if not callable(strategy):
                    raise ValueError(f"Metrology strategy {strategy} must be callable")
        # Validate queue class has basic queue interface
        if not hasattr(self.queue_cls, "__call__"):
            raise ValueError("queue_cls must be a callable class")

    @staticmethod
    def _strategy_args(kind: str, strategy: Any) -> Any:
        """Return ``strategy._args``, raising ``ValueError`` if the strategy lacks it."""
        if not hasattr(strategy, "_args"):
            raise ValueError(
                f"{kind} strategy {strategy} does not have an `_args` attribute"
                " which is required for serialization. Make sure your strategy"
                " either subclasses `ActiveLearningProtocol` or implements"
                " the `__new__` method to capture object arguments."
            )
        return strategy._args

    def to_dict(self) -> dict[str, Any]:
        """
        Method that converts the present ``StrategiesConfig`` instance into a dictionary
        that can be JSON serialized.

        This method, for the most part, assumes that strategies are subclasses of
        ``ActiveLearningProtocol`` and/or they have an ``_args`` attribute that
        captures the arguments to the constructor.

        One issue is the inability to reliably serialize the ``unlabeled_datapool``,
        which for the most part, likely does not need serialization as a dataset.
        Regardless, this method will trigger a warning if ``unlabeled_datapool`` is
        not None.

        Returns
        -------
        dict[str, Any]
            A plain dictionary that can be JSON serialized. ``label_strategy`` and
            ``metrology_strategies`` are only present when they are set.

        Raises
        ------
        ValueError
            If any strategy lacks the ``_args`` attribute.
        """
        # a plain dict (not a defaultdict) so that reading an absent key such
        # as ``label_strategy`` raises instead of silently inserting ``[]``
        output: dict[str, Any] = {
            "query_strategies": [
                self._strategy_args("Query", strategy)
                for strategy in self.query_strategies
            ]
        }
        if self.label_strategy is not None:
            output["label_strategy"] = self._strategy_args("Label", self.label_strategy)
        output["queue_cls"] = {
            "__name__": self.queue_cls.__name__,
            "__module__": self.queue_cls.__module__,
        }
        if self.metrology_strategies is not None:
            output["metrology_strategies"] = [
                self._strategy_args("Metrology", strategy)
                for strategy in self.metrology_strategies
            ]
        if self.unlabeled_datapool is not None:
            warn(
                "The `unlabeled_datapool` attribute is not supported for serialization"
                " and will be excluded from the ``StrategiesConfig`` dictionary."
            )
        return output

    @staticmethod
    def _construct_entry(entry: dict[str, Any]) -> Any:
        """Instantiate one serialized ``{"__name__", "__module__", "__args__"}`` entry."""
        return registry.construct(
            entry["__name__"],
            module_path=entry["__module__"],
            **entry["__args__"],
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any], **kwargs: Any) -> StrategiesConfig:
        """
        Create a ``StrategiesConfig`` instance from a dictionary.

        This method heavily relies on classes being added to the
        ``physicsnemo.active_learning.registry``, which is used to instantiate
        all strategies and custom types used in active learning. As a fall
        back, the `registry.construct` method will try and import the class
        from the module path if it is not found in the registry.

        Parameters
        ----------
        data: dict[str, Any]
            A dictionary that was previously serialized using the ``to_dict`` method.
            ``label_strategy`` and ``metrology_strategies`` may be absent or null.
        **kwargs: Any
            Additional keyword arguments to pass to the constructor.

        Returns
        -------
        StrategiesConfig
            A new ``StrategiesConfig`` instance.

        Raises
        ------
        ValueError:
            If the data contains unexpected keys or if the object construction fails.
        NameError:
            If a class is not found in the registry and no module path is provided.
        ModuleNotFoundError:
            If a module is not found with the specified module path.
        """
        # ensure that the data contains no unexpected keys
        data_keys = set(data.keys())
        expected_keys = set(cls.__dataclass_fields__.keys())
        extra_keys = data_keys - expected_keys
        if extra_keys:
            raise ValueError(
                f"Unexpected keys in data: {extra_keys}. Expected keys are {expected_keys}."
            )
        # Objects are constructed in the same order as before (query, metrology,
        # label, queue) in case their constructors have side effects.
        output_dict: dict[str, Any] = {}
        query_strategies = [cls._construct_entry(e) for e in data["query_strategies"]]
        if query_strategies:
            output_dict["query_strategies"] = query_strategies
        # ``null`` or an empty list both mean "no metrology"
        metrology_entries = data.get("metrology_strategies") or []
        if metrology_entries:
            output_dict["metrology_strategies"] = [
                cls._construct_entry(e) for e in metrology_entries
            ]
        if data.get("label_strategy") is not None:
            output_dict["label_strategy"] = cls._construct_entry(data["label_strategy"])
        output_dict["queue_cls"] = registry.get_class(
            data["queue_cls"]["__name__"], data["queue_cls"]["__module__"]
        )
        # potentially override with keyword arguments
        output_dict.update(kwargs)
        try:
            config = cls(**output_dict)
        except Exception as e:
            raise ValueError(
                "Failed to construct ``StrategiesConfig`` from dictionary."
            ) from e
        return config
@dataclass
class DriverConfig:
"""
Configuration for driver orchestration and infrastructure.
This contains parameters that control the overall loop execution,
logging, checkpointing, and distributed training setup - orthogonal
to the specific AL or training logic.
Attributes
----------
batch_size: int
The batch size to use for data loaders.
max_active_learning_steps: int | None, default None
Maximum number of AL iterations to perform. None means infinite.
run_id: str, default auto-generated UUID
Unique identifier for this run. Auto-generated if not provided.
fine_tuning_lr: float | None, default None
Learning rate to switch to after the first AL step for fine-tuning.
reset_optim_states: bool, default True
Whether to reset optimizer states between AL steps.
skip_training: bool, default False
If True, skip the training phase entirely.
skip_metrology: bool, default False
If True, skip the metrology phase entirely.
skip_labeling: bool, default False
If True, skip the labeling phase entirely.
checkpoint_interval: int, default 1
Save model checkpoint every N AL steps. 0 disables checkpointing.
checkpoint_on_training: bool, default False
If True, save checkpoint at the start of the training phase.
checkpoint_on_metrology: bool, default False
If True, save checkpoint at the start of the metrology phase.
checkpoint_on_query: bool, default False
If True, save checkpoint at the start of the query phase.
checkpoint_on_labeling: bool, default True
If True, save checkpoint at the start of the labeling phase.
model_checkpoint_frequency: int, default 0
Save model weights every N epochs during training. 0 means only save
between active learning phases. Useful for mid-training restarts.
root_log_dir: str | Path, default Path.cwd() / "active_learning_logs"
Directory to save logs and checkpoints to. Defaults to
an 'active_learning_logs' directory in the current working directory.
dist_manager: DistributedManager | None, default None
Manager for distributed training configuration.
collate_fn: callable | None, default None
Custom collate function for batching data.
num_dataloader_workers: int, default 0
Number of worker processes for data loading.
device: str | torch.device | None, default None
Device to use for model and data. This is intended for single process
workflows; for distributed workflows, the device should be set in
``DistributedManager`` instead. If not specified, then the device
will default to ``torch.get_default_device()``.
dtype: torch.dtype | None, default None
The dtype to use for model and data, and AMP contexts. If not provided,
then the dtype will default to ``torch.get_default_dtype()``.
"""
batch_size: int
max_active_learning_steps: int | None = None
run_id: str = field(default_factory=lambda: str(uuid.uuid4()))
fine_tuning_lr: float | None = None # TODO: move to TrainingConfig
reset_optim_states: bool = True
skip_training: bool = False
skip_metrology: bool = False
skip_labeling: bool = False
checkpoint_interval: int = 1
checkpoint_on_training: bool = False
checkpoint_on_metrology: bool = False
checkpoint_on_query: bool = False
checkpoint_on_labeling: bool = True
model_checkpoint_frequency: int = 0
root_log_dir: str | Path = field(default=Path.cwd() / "active_learning_logs")
dist_manager: DistributedManager | None = None
collate_fn: callable | None = None
num_dataloader_workers: int = 0
device: str | torch.device | None = None
dtype: torch.dtype | None = None
def __post_init__(self) -> None:
"""Validate driver configuration."""
if self.max_active_learning_steps is None:
self.max_active_learning_steps = float("inf")
if (
self.max_active_learning_steps is not None
and self.max_active_learning_steps <= 0
):
raise ValueError(
"`max_active_learning_steps` must be a positive integer or None."
)
if not math.isfinite(self.batch_size) or self.batch_size <= 0:
raise ValueError("`batch_size` must be a positive integer.")
if not math.isfinite(self.checkpoint_interval) or self.checkpoint_interval < 0:
raise ValueError(
"`checkpoint_interval` must be a non-negative integer. "
"Use 0 to disable checkpointing."
)
if self.fine_tuning_lr is not None and self.fine_tuning_lr <= 0:
raise ValueError("`fine_tuning_lr` must be positive if provided.")
if self.num_dataloader_workers < 0:
raise ValueError("`num_dataloader_workers` must be non-negative.")
if self.model_checkpoint_frequency < 0:
raise ValueError("`model_checkpoint_frequency` must be non-negative.")
if isinstance(self.root_log_dir, str):
self.root_log_dir = Path(self.root_log_dir)
# Validate collate_fn if provided
if self.collate_fn is not None and not callable(self.collate_fn):
raise ValueError("`collate_fn` must be callable if provided.")
# device and dtype setup when not using DistributedManager
if self.device is None and not self.dist_manager:
self.device = torch.get_default_device()
if self.dtype is None:
self.dtype = torch.get_default_dtype()
def to_json(self) -> str:
"""
Returns a JSON string representation of the ``DriverConfig``.
Note that certain fields are not serialized and must be provided when
deserializing: ``dist_manager``, ``collate_fn``.
Returns
-------
str
A JSON string representation of the config.
"""
# base dict representation skips Python objects
dict_repr = {
key: self.__dict__[key]
for key in self.__dict__
if key
not in ["dist_manager", "collate_fn", "root_log_dir", "device", "dtype"]
}
# Note: checkpoint flags are included in dict_repr automatically
dict_repr["default_dtype"] = str(torch.get_default_dtype())
dict_repr["log_dir"] = str(self.root_log_dir)
# Convert dtype to string for JSON serialization
if self.dtype is not None:
dict_repr["dtype"] = str(self.dtype)
else:
dict_repr["dtype"] = None
if self.dist_manager is not None:
dict_repr["world_size"] = self.dist_manager.world_size
dict_repr["device"] = self.dist_manager.device.type
dict_repr["dist_manager_init_method"] = (
self.dist_manager._initialization_method
)
else:
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
dict_repr["world_size"] = world_size
if self.device is not None:
dict_repr["device"] = (
str(self.device)
if hasattr(self.device, "type")
else str(self.device)
)
else:
dict_repr["device"] = torch.get_default_device().type
dict_repr["dist_manager_init_method"] = None
if self.collate_fn is not None:
dict_repr["collate_fn"] = self.collate_fn.__name__
else:
dict_repr["collate_fn"] = None
return dumps(dict_repr, indent=2)
@classmethod
def from_json(cls, json_str: str, **kwargs: Any) -> DriverConfig:
"""
Creates a DriverConfig instance from a JSON string.
This method reconstructs a DriverConfig from JSON. Note that certain
fields cannot be serialized and must be provided via kwargs:
- ``dist_manager``: DistributedManager instance (optional)
- ``collate_fn``: Custom collate function (optional)
Parameters
----------
json_str: str
A JSON string that was previously serialized using ``to_json()``.
**kwargs: Any
Additional keyword arguments to override or provide non-serializable
fields like ``dist_manager`` and ``collate_fn``.
Returns
-------
DriverConfig
A new ``DriverConfig`` instance.
Raises
------
ValueError
If the JSON cannot be parsed or required fields are missing.
Notes
-----
The device and dtype fields are reconstructed from their string
representations. The ``log_dir`` field in JSON is mapped to
``root_log_dir`` in the config.
"""
import json
try:
data = json.loads(json_str)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON string: {e}") from e
# Define fields that are not actual DriverConfig constructor parameters
metadata_fields = [
"default_dtype",
"world_size",
"dist_manager_init_method",
"log_dir", # handled separately as root_log_dir
]
non_serializable_fields = [
"dist_manager",
"collate_fn",
"root_log_dir",
"device",
"dtype",
]
# Extract serializable fields that map directly
config_fields = {
key: value
for key, value in data.items()
if key not in metadata_fields + non_serializable_fields
}
# Handle root_log_dir (stored as "log_dir" in JSON)
if "log_dir" in data:
config_fields["root_log_dir"] = Path(data["log_dir"])
# Handle device reconstruction from string
if "device" in data and data["device"] is not None:
device_str = data["device"]
# Parse device strings like "cuda:0", "cpu", "cuda", etc.
config_fields["device"] = torch.device(device_str)
# Handle dtype reconstruction from string
if "dtype" in data and data["dtype"] is not None:
dtype_str = data["dtype"]
# Map string representations to torch dtypes
dtype_map = {
"torch.float32": torch.float32,
"torch.float64": torch.float64,
"torch.float16": torch.float16,
"torch.bfloat16": torch.bfloat16,
"torch.int32": torch.int32,
"torch.int64": torch.int64,
"torch.int8": torch.int8,
"torch.uint8": torch.uint8,
}
if dtype_str in dtype_map:
config_fields["dtype"] = dtype_map[dtype_str]
else:
warn(
f"Unknown dtype string '{dtype_str}' in JSON. "
"Using default dtype instead."
)
# Merge with provided kwargs (allows overriding and adding non-serializable fields)
config_fields.update(kwargs)
# Create the config
try:
config = cls(**config_fields)
except Exception as e:
raise ValueError(
"Failed to construct ``DriverConfig`` from JSON string."
) from e
return config