Spaces:
Sleeping
Sleeping
File size: 32,592 Bytes
c3d0544 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 | # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Configuration dataclasses for the active learning driver.
This module provides structured configuration classes that separate different
concerns in the active learning workflow: optimization, training, strategies,
and driver orchestration.
"""
from __future__ import annotations
import math
import uuid
from collections import defaultdict
from dataclasses import dataclass, field
from json import dumps
from pathlib import Path
from typing import Any
from warnings import warn
import torch
from torch import distributed as dist
from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from physicsnemo.active_learning import protocols as p
from physicsnemo.active_learning._registry import registry
from physicsnemo.active_learning.loop import DefaultTrainingLoop
from physicsnemo.distributed import DistributedManager
@dataclass
class OptimizerConfig:
    """
    Configuration for optimizer and learning rate scheduler.

    This encapsulates all training optimization parameters, keeping
    them separate from the active learning orchestration logic.

    Attributes
    ----------
    optimizer_cls: type[Optimizer]
        The optimizer class to use. Defaults to AdamW.
    optimizer_kwargs: dict[str, Any]
        Keyword arguments to pass to the optimizer constructor.
        Defaults to {"lr": 1e-4}. If an ``"lr"`` entry is present it must be
        a finite, positive ``int`` or ``float`` (``bool`` is rejected).
    scheduler_cls: type[_LRScheduler] | None
        The learning rate scheduler class to use. If None, no
        scheduler will be configured.
    scheduler_kwargs: dict[str, Any]
        Keyword arguments to pass to the scheduler constructor. Must be
        empty when ``scheduler_cls`` is None.
    """

    optimizer_cls: type[Optimizer] = AdamW
    optimizer_kwargs: dict[str, Any] = field(default_factory=lambda: {"lr": 1e-4})
    scheduler_cls: type[_LRScheduler] | None = None
    scheduler_kwargs: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """
        Validate optimizer configuration.

        Raises
        ------
        ValueError
            If ``optimizer_kwargs["lr"]`` is present but is not a finite,
            positive number, or if ``scheduler_kwargs`` is non-empty while
            ``scheduler_cls`` is None.
        """
        if "lr" in self.optimizer_kwargs:
            lr = self.optimizer_kwargs["lr"]
            # ``bool`` subclasses ``int`` and ``nan <= 0`` is False, so both
            # would otherwise slip through a plain sign check.
            is_real_number = isinstance(lr, (int, float)) and not isinstance(lr, bool)
            if not is_real_number or not math.isfinite(lr) or lr <= 0:
                raise ValueError(f"Learning rate must be positive, got {lr}")
        # scheduler_kwargs are meaningless without a scheduler to receive them
        if self.scheduler_kwargs and self.scheduler_cls is None:
            raise ValueError(
                "scheduler_kwargs provided but scheduler_cls is None. "
                "Provide a scheduler_cls or remove scheduler_kwargs."
            )

    @staticmethod
    def _class_reference(klass: type) -> dict[str, str]:
        """Return the ``__name__``/``__module__`` pair used to re-locate ``klass``."""
        return {"__name__": klass.__name__, "__module__": klass.__module__}

    def to_dict(self) -> dict[str, Any]:
        """
        Returns a JSON-serializable dictionary representation of the OptimizerConfig.

        For round-tripping, the registry is used to de-serialize the optimizer and scheduler
        classes. The keyword-argument dictionaries are shallow-copied so that
        mutating the returned dictionary does not alter this config.

        Returns
        -------
        dict[str, Any]
            A dictionary that can be JSON serialized.
        """
        scheduler = (
            self._class_reference(self.scheduler_cls)
            if self.scheduler_cls is not None
            else None
        )
        return {
            "optimizer_cls": self._class_reference(self.optimizer_cls),
            "optimizer_kwargs": dict(self.optimizer_kwargs),
            "scheduler_cls": scheduler,
            "scheduler_kwargs": dict(self.scheduler_kwargs),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> OptimizerConfig:
        """
        Creates an OptimizerConfig instance from a dictionary.

        This method assumes that the optimizer and scheduler classes are
        included in the ``physicsnemo.active_learning.registry``, or
        a module path is specified to import the class from.

        Parameters
        ----------
        data: dict[str, Any]
            A dictionary that was previously serialized using the ``to_dict`` method.
            A missing or null ``scheduler_kwargs`` entry is treated as empty.

        Returns
        -------
        OptimizerConfig
            A new ``OptimizerConfig`` instance.
        """
        optimizer_ref = data["optimizer_cls"]
        optimizer_cls = registry.get_class(
            optimizer_ref["__name__"], optimizer_ref["__module__"]
        )
        if (s := data.get("scheduler_cls")) is not None:
            scheduler_cls = registry.get_class(s["__name__"], s["__module__"])
        else:
            scheduler_cls = None
        return cls(
            optimizer_cls=optimizer_cls,
            optimizer_kwargs=dict(data["optimizer_kwargs"]),
            scheduler_cls=scheduler_cls,
            scheduler_kwargs=dict(data.get("scheduler_kwargs") or {}),
        )
@dataclass
class TrainingConfig:
    """
    Configuration for the training phase of active learning.

    This groups all training-related components together, making it
    clear when training is or isn't being used in the AL workflow.

    Attributes
    ----------
    train_datapool: p.DataPool
        The pool of labeled data to use for training.
    max_training_epochs: int
        The maximum number of epochs to train for. If ``max_fine_tuning_epochs``
        isn't specified, this value is used for all active learning steps.
    val_datapool: p.DataPool | None
        Optional pool of data to use for validation during training.
    optimizer_config: OptimizerConfig
        Configuration for the optimizer and scheduler. Defaults to
        AdamW with lr=1e-4, no scheduler.
    max_fine_tuning_epochs: int | None
        The maximum number of epochs used during fine-tuning steps, i.e. after
        the first active learning step. If ``None``, it is set to
        ``max_training_epochs`` during initialization.
    train_loop_fn: p.TrainingLoop
        The training loop function that orchestrates the training process.
        This defaults to a concrete implementation, ``DefaultTrainingLoop``,
        which provides a very typical loop that includes the use of static
        capture, etc.
    """

    train_datapool: p.DataPool
    max_training_epochs: int
    val_datapool: p.DataPool | None = None
    optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig)
    max_fine_tuning_epochs: int | None = None
    train_loop_fn: p.TrainingLoop = field(default_factory=DefaultTrainingLoop)

    def __post_init__(self) -> None:
        """
        Validate training configuration and fill in derived defaults.

        Raises
        ------
        ValueError
            If a datapool lacks ``__len__`` or ``train_loop_fn`` is not callable.
        """
        # Validate datapools have consistent interface
        if not hasattr(self.train_datapool, "__len__"):
            raise ValueError("train_datapool must implement __len__")
        if self.val_datapool is not None and not hasattr(self.val_datapool, "__len__"):
            raise ValueError("val_datapool must implement __len__")
        # Validate training loop is callable
        if not callable(self.train_loop_fn):
            raise ValueError("train_loop_fn must be callable")
        # set the same value for fine tuning epochs if not provided
        if self.max_fine_tuning_epochs is None:
            self.max_fine_tuning_epochs = self.max_training_epochs

    def to_dict(self) -> dict[str, Any]:
        """
        Returns a JSON-serializable dictionary representation of the TrainingConfig.

        For round-tripping, the registry is used to de-serialize the training loop
        and optimizer configuration. Note that datapools (train_datapool and val_datapool)
        are NOT serialized as they typically contain large datasets, file handles, or other
        non-serializable state.

        Returns
        -------
        dict[str, Any]
            A dictionary that can be JSON serialized. Excludes datapools.

        Raises
        ------
        ValueError
            If ``train_loop_fn`` has no ``_args`` attribute.

        Warnings
        --------
        This method will issue a warning about the exclusion of datapools.
        """
        # Warn about datapool exclusion
        warn(
            "The `train_datapool` and `val_datapool` attributes are not supported for "
            "serialization and will be excluded from the ``TrainingConfig`` dictionary. "
            "You must re-provide these datapools when deserializing."
        )
        # Serialize training loop function via its captured constructor arguments
        if not hasattr(self.train_loop_fn, "_args"):
            raise ValueError(
                f"Training loop {self.train_loop_fn} does not have an `_args` attribute "
                "which is required for serialization. Make sure your training loop "
                "either subclasses `ActiveLearningProtocol` or implements the `__new__` "
                "method to capture object arguments."
            )
        return {
            "max_training_epochs": self.max_training_epochs,
            "max_fine_tuning_epochs": self.max_fine_tuning_epochs,
            "optimizer_config": self.optimizer_config.to_dict(),
            "train_loop_fn": self.train_loop_fn._args,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any], **kwargs: Any) -> TrainingConfig:
        """
        Creates a TrainingConfig instance from a dictionary.

        This method assumes that the training loop class is included in the
        ``physicsnemo.active_learning.registry``, or a module path is specified
        to import the class from. Note that datapools must be provided via
        kwargs as they are not serialized.

        Parameters
        ----------
        data: dict[str, Any]
            A dictionary that was previously serialized using the ``to_dict`` method.
        **kwargs: Any
            Additional keyword arguments to pass to the constructor. This is where
            you must provide ``train_datapool`` and optionally ``val_datapool``.
            Values given here take precedence over those reconstructed from
            ``data``.

        Returns
        -------
        TrainingConfig
            A new ``TrainingConfig`` instance.

        Raises
        ------
        ValueError
            If required datapools are not provided in kwargs, if the data contains
            unexpected keys, or if object construction fails.
        """
        # Ensure required datapools are provided
        if "train_datapool" not in kwargs:
            raise ValueError(
                "``train_datapool`` must be provided in kwargs when deserializing "
                "TrainingConfig, as datapools are not serialized."
            )
        # Reject keys that do not correspond to any dataclass field
        expected_keys = set(cls.__dataclass_fields__)
        extra_keys = set(data) - expected_keys
        if extra_keys:
            raise ValueError(
                f"Unexpected keys in data: {extra_keys}. Expected keys are {expected_keys}."
            )
        # Reconstruct optimizer config
        optimizer_config = OptimizerConfig.from_dict(data["optimizer_config"])
        # Reconstruct training loop function
        train_loop_data = data["train_loop_fn"]
        train_loop_fn = registry.construct(
            train_loop_data["__name__"],
            module_path=train_loop_data["__module__"],
            **train_loop_data["__args__"],
        )
        # Build the config
        try:
            init_kwargs = {
                "max_training_epochs": data["max_training_epochs"],
                "max_fine_tuning_epochs": data.get("max_fine_tuning_epochs"),
                "optimizer_config": optimizer_config,
                "train_loop_fn": train_loop_fn,
            }
            # caller-supplied values win, mirroring StrategiesConfig.from_dict
            init_kwargs.update(kwargs)
            config = cls(**init_kwargs)
        except Exception as e:
            raise ValueError(
                "Failed to construct ``TrainingConfig`` from dictionary."
            ) from e
        return config
@dataclass
class StrategiesConfig:
    """
    Configuration for active learning strategies and data acquisition.

    This encapsulates the query-label-metrology cycle that is at the
    heart of active learning: strategies for selecting data, labeling it,
    and measuring model uncertainty/performance.

    Attributes
    ----------
    query_strategies: list[p.QueryStrategy]
        The query strategies to use for selecting data to label.
    queue_cls: type[p.AbstractQueue]
        The queue implementation to use for passing data between
        query and labeling phases.
    label_strategy: p.LabelStrategy | None
        The strategy to use for labeling queried data. If None,
        labeling will be skipped.
    metrology_strategies: list[p.MetrologyStrategy] | None
        Strategies for measuring model performance and uncertainty.
        If None, metrology will be skipped.
    unlabeled_datapool: p.DataPool | None
        Pool of unlabeled data that query strategies can sample from.
        Not all strategies require this (some may generate synthetic data).
    """

    query_strategies: list[p.QueryStrategy]
    queue_cls: type[p.AbstractQueue]
    label_strategy: p.LabelStrategy | None = None
    metrology_strategies: list[p.MetrologyStrategy] | None = None
    unlabeled_datapool: p.DataPool | None = None

    def __post_init__(self) -> None:
        """
        Validate strategies configuration.

        Raises
        ------
        ValueError
            If no query strategy is given, any strategy is not callable,
            ``metrology_strategies`` is an empty list, or ``queue_cls`` is
            not callable.
        """
        # Must have at least one query strategy
        if not self.query_strategies:
            raise ValueError(
                "At least one query strategy must be provided. "
                "Active learning requires a mechanism to select data."
            )
        # All query strategies must be callable
        for strategy in self.query_strategies:
            if not callable(strategy):
                raise ValueError(f"Query strategy {strategy} must be callable")
        # Label strategy must be callable if provided
        if self.label_strategy is not None and not callable(self.label_strategy):
            raise ValueError("label_strategy must be callable")
        # Metrology strategies must be callable if provided; an empty list is
        # rejected because None is the explicit "skip metrology" signal
        if self.metrology_strategies is not None:
            if not self.metrology_strategies:
                raise ValueError(
                    "metrology_strategies is an empty list. "
                    "Either provide strategies or set to None to skip metrology."
                )
            for strategy in self.metrology_strategies:
                if not callable(strategy):
                    raise ValueError(f"Metrology strategy {strategy} must be callable")
        # Validate queue class has basic queue interface
        if not hasattr(self.queue_cls, "__call__"):
            raise ValueError("queue_cls must be a callable class")

    @staticmethod
    def _serialized_args(strategy: Any, label: str) -> dict[str, Any]:
        """Return ``strategy._args``, raising ``ValueError`` if it is absent."""
        if not hasattr(strategy, "_args"):
            raise ValueError(
                f"{label} {strategy} does not have an `_args` attribute"
                " which is required for serialization. Make sure your strategy"
                " either subclasses `ActiveLearningProtocol` or implements"
                " the `__new__` method to capture object arguments."
            )
        return strategy._args

    @staticmethod
    def _construct(entry: dict[str, Any]) -> Any:
        """Instantiate an object from a serialized ``__name__``/``__module__``/``__args__`` entry."""
        return registry.construct(
            entry["__name__"],
            module_path=entry["__module__"],
            **entry["__args__"],
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Method that converts the present ``StrategiesConfig`` instance into a dictionary
        that can be JSON serialized.

        This method, for the most part, assumes that strategies are subclasses of
        ``ActiveLearningProtocol`` and/or they have an ``_args`` attribute that
        captures the arguments to the constructor.

        One issue is the inability to reliably serialize the ``unlabeled_datapool``,
        which for the most part, likely does not need serialization as a dataset.
        Regardless, this method will trigger a warning if ``unlabeled_datapool`` is
        not None.

        Returns
        -------
        dict[str, Any]
            A plain dictionary that can be JSON serialized. ``label_strategy``
            and ``metrology_strategies`` keys are present only when set.

        Raises
        ------
        ValueError
            If any strategy lacks an ``_args`` attribute.
        """
        output: dict[str, Any] = {
            "query_strategies": [
                self._serialized_args(strategy, "Query strategy")
                for strategy in self.query_strategies
            ]
        }
        if self.label_strategy is not None:
            output["label_strategy"] = self._serialized_args(
                self.label_strategy, "Label strategy"
            )
        output["queue_cls"] = {
            "__name__": self.queue_cls.__name__,
            "__module__": self.queue_cls.__module__,
        }
        if self.metrology_strategies is not None:
            output["metrology_strategies"] = [
                self._serialized_args(strategy, "Metrology strategy")
                for strategy in self.metrology_strategies
            ]
        if self.unlabeled_datapool is not None:
            warn(
                "The `unlabeled_datapool` attribute is not supported for serialization"
                " and will be excluded from the ``StrategiesConfig`` dictionary."
            )
        return output

    @classmethod
    def from_dict(cls, data: dict[str, Any], **kwargs: Any) -> StrategiesConfig:
        """
        Create a ``StrategiesConfig`` instance from a dictionary.

        This method heavily relies on classes being added to the
        ``physicsnemo.active_learning.registry``, which is used to instantiate
        all strategies and custom types used in active learning. As a fall
        back, the `registry.construct` method will try and import the class
        from the module path if it is not found in the registry.

        Parameters
        ----------
        data: dict[str, Any]
            A dictionary that was previously serialized using the ``to_dict`` method.
            ``label_strategy`` / ``metrology_strategies`` may be absent or null.
        **kwargs: Any
            Additional keyword arguments to pass to the constructor; these
            override values reconstructed from ``data``.

        Returns
        -------
        StrategiesConfig
            A new ``StrategiesConfig`` instance.

        Raises
        ------
        ValueError:
            If the data contains unexpected keys or if the object construction fails.
        NameError:
            If a class is not found in the registry and no module path is provided.
        ModuleNotFoundError:
            If a module is not found with the specified module path.
        """
        # ensure that the data contains no unexpected keys
        expected_keys = set(cls.__dataclass_fields__)
        extra_keys = set(data) - expected_keys
        if extra_keys:
            raise ValueError(
                f"Unexpected keys in data: {extra_keys}. Expected keys are {expected_keys}."
            )
        # `registry.construct` resolves each class in the registry first and
        # falls back to importing it from the recorded module path.
        init_kwargs: dict[str, Any] = {
            "query_strategies": [cls._construct(entry) for entry in data["query_strategies"]],
        }
        if data.get("metrology_strategies"):
            init_kwargs["metrology_strategies"] = [
                cls._construct(entry) for entry in data["metrology_strategies"]
            ]
        if data.get("label_strategy") is not None:
            init_kwargs["label_strategy"] = cls._construct(data["label_strategy"])
        init_kwargs["queue_cls"] = registry.get_class(
            data["queue_cls"]["__name__"], data["queue_cls"]["__module__"]
        )
        # potentially override with keyword arguments
        init_kwargs.update(kwargs)
        try:
            config = cls(**init_kwargs)
        except Exception as e:
            raise ValueError(
                "Failed to construct ``StrategiesConfig`` from dictionary."
            ) from e
        return config
@dataclass
class DriverConfig:
    """
    Configuration for driver orchestration and infrastructure.

    This contains parameters that control the overall loop execution,
    logging, checkpointing, and distributed training setup - orthogonal
    to the specific AL or training logic.

    Attributes
    ----------
    batch_size: int
        The batch size to use for data loaders.
    max_active_learning_steps: int | None, default None
        Maximum number of AL iterations to perform. None means infinite and is
        stored as ``float("inf")`` after initialization.
    run_id: str, default auto-generated UUID
        Unique identifier for this run. Auto-generated if not provided.
    fine_tuning_lr: float | None, default None
        Learning rate to switch to after the first AL step for fine-tuning.
    reset_optim_states: bool, default True
        Whether to reset optimizer states between AL steps.
    skip_training: bool, default False
        If True, skip the training phase entirely.
    skip_metrology: bool, default False
        If True, skip the metrology phase entirely.
    skip_labeling: bool, default False
        If True, skip the labeling phase entirely.
    checkpoint_interval: int, default 1
        Save model checkpoint every N AL steps. 0 disables checkpointing.
    checkpoint_on_training: bool, default False
        If True, save checkpoint at the start of the training phase.
    checkpoint_on_metrology: bool, default False
        If True, save checkpoint at the start of the metrology phase.
    checkpoint_on_query: bool, default False
        If True, save checkpoint at the start of the query phase.
    checkpoint_on_labeling: bool, default True
        If True, save checkpoint at the start of the labeling phase.
    model_checkpoint_frequency: int, default 0
        Save model weights every N epochs during training. 0 means only save
        between active learning phases. Useful for mid-training restarts.
    root_log_dir: str | Path, default Path.cwd() / "active_learning_logs"
        Directory to save logs and checkpoints to. Defaults to an
        'active_learning_logs' directory in the current working directory at
        the time the config is created. Strings are converted to ``Path``.
    dist_manager: DistributedManager | None, default None
        Manager for distributed training configuration.
    collate_fn: callable | None, default None
        Custom collate function for batching data.
    num_dataloader_workers: int, default 0
        Number of worker processes for data loading.
    device: str | torch.device | None, default None
        Device to use for model and data. This is intended for single process
        workflows; for distributed workflows, the device should be set in
        ``DistributedManager`` instead. If not specified, then the device
        will default to ``torch.get_default_device()``.
    dtype: torch.dtype | None, default None
        The dtype to use for model and data, and AMP contexts. If not provided,
        then the dtype will default to ``torch.get_default_dtype()``.
    """

    batch_size: int
    max_active_learning_steps: int | None = None
    run_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    fine_tuning_lr: float | None = None  # TODO: move to TrainingConfig
    reset_optim_states: bool = True
    skip_training: bool = False
    skip_metrology: bool = False
    skip_labeling: bool = False
    checkpoint_interval: int = 1
    checkpoint_on_training: bool = False
    checkpoint_on_metrology: bool = False
    checkpoint_on_query: bool = False
    checkpoint_on_labeling: bool = True
    model_checkpoint_frequency: int = 0
    # default_factory so the cwd is read per instance, not once at import time
    root_log_dir: str | Path = field(
        default_factory=lambda: Path.cwd() / "active_learning_logs"
    )
    dist_manager: DistributedManager | None = None
    collate_fn: callable | None = None
    num_dataloader_workers: int = 0
    device: str | torch.device | None = None
    dtype: torch.dtype | None = None

    def __post_init__(self) -> None:
        """
        Validate driver configuration and resolve defaults.

        Raises
        ------
        ValueError
            If any numeric setting is out of range or ``collate_fn`` is not callable.
        """
        # None means "no limit"; represent it numerically so comparisons work
        if self.max_active_learning_steps is None:
            self.max_active_learning_steps = float("inf")
        if self.max_active_learning_steps <= 0:
            raise ValueError(
                "`max_active_learning_steps` must be a positive integer or None."
            )
        if not math.isfinite(self.batch_size) or self.batch_size <= 0:
            raise ValueError("`batch_size` must be a positive integer.")
        if not math.isfinite(self.checkpoint_interval) or self.checkpoint_interval < 0:
            raise ValueError(
                "`checkpoint_interval` must be a non-negative integer. "
                "Use 0 to disable checkpointing."
            )
        if self.fine_tuning_lr is not None and self.fine_tuning_lr <= 0:
            raise ValueError("`fine_tuning_lr` must be positive if provided.")
        if self.num_dataloader_workers < 0:
            raise ValueError("`num_dataloader_workers` must be non-negative.")
        if self.model_checkpoint_frequency < 0:
            raise ValueError("`model_checkpoint_frequency` must be non-negative.")
        if isinstance(self.root_log_dir, str):
            self.root_log_dir = Path(self.root_log_dir)
        # Validate collate_fn if provided
        if self.collate_fn is not None and not callable(self.collate_fn):
            raise ValueError("`collate_fn` must be callable if provided.")
        # device and dtype setup when not using DistributedManager
        if self.device is None and not self.dist_manager:
            self.device = torch.get_default_device()
        if self.dtype is None:
            self.dtype = torch.get_default_dtype()

    @staticmethod
    def _parse_dtype(dtype_str: Any) -> torch.dtype | None:
        """Map ``str(torch.<dtype>)`` (e.g. ``"torch.float32"``) back to a dtype; None if unrecognized."""
        prefix = "torch."
        if not isinstance(dtype_str, str) or not dtype_str.startswith(prefix):
            return None
        candidate = getattr(torch, dtype_str[len(prefix):], None)
        return candidate if isinstance(candidate, torch.dtype) else None

    def to_json(self) -> str:
        """
        Returns a JSON string representation of the ``DriverConfig``.

        Note that certain fields are not serialized and must be provided when
        deserializing: ``dist_manager``, ``collate_fn`` (only its name is
        recorded, for information). An unbounded ``max_active_learning_steps``
        is written as ``null`` so the output is standard JSON.

        Returns
        -------
        str
            A JSON string representation of the config.
        """
        excluded = {"dist_manager", "collate_fn", "root_log_dir", "device", "dtype"}
        # base dict representation skips Python objects; checkpoint flags are
        # included automatically
        dict_repr = {
            key: value for key, value in self.__dict__.items() if key not in excluded
        }
        # float("inf") would be emitted as the non-standard token `Infinity`
        if dict_repr.get("max_active_learning_steps") == float("inf"):
            dict_repr["max_active_learning_steps"] = None
        dict_repr["default_dtype"] = str(torch.get_default_dtype())
        dict_repr["log_dir"] = str(self.root_log_dir)
        # Convert dtype to string for JSON serialization
        dict_repr["dtype"] = str(self.dtype) if self.dtype is not None else None
        if self.dist_manager is not None:
            dict_repr["world_size"] = self.dist_manager.world_size
            dict_repr["device"] = self.dist_manager.device.type
            dict_repr["dist_manager_init_method"] = (
                self.dist_manager._initialization_method
            )
        else:
            dict_repr["world_size"] = dist.get_world_size() if dist.is_initialized() else 1
            if self.device is not None:
                dict_repr["device"] = str(self.device)
            else:
                dict_repr["device"] = torch.get_default_device().type
            dict_repr["dist_manager_init_method"] = None
        if self.collate_fn is not None:
            # callables such as functools.partial objects have no __name__
            dict_repr["collate_fn"] = getattr(
                self.collate_fn, "__name__", type(self.collate_fn).__name__
            )
        else:
            dict_repr["collate_fn"] = None
        return dumps(dict_repr, indent=2)

    @classmethod
    def from_json(cls, json_str: str, **kwargs: Any) -> DriverConfig:
        """
        Creates a DriverConfig instance from a JSON string.

        This method reconstructs a DriverConfig from JSON. Note that certain
        fields cannot be serialized and must be provided via kwargs:

        - ``dist_manager``: DistributedManager instance (optional)
        - ``collate_fn``: Custom collate function (optional)

        Parameters
        ----------
        json_str: str
            A JSON string that was previously serialized using ``to_json()``.
        **kwargs: Any
            Additional keyword arguments to override or provide non-serializable
            fields like ``dist_manager`` and ``collate_fn``.

        Returns
        -------
        DriverConfig
            A new ``DriverConfig`` instance.

        Raises
        ------
        ValueError
            If the JSON cannot be parsed, is not a JSON object, or required
            fields are missing.

        Notes
        -----
        The device and dtype fields are reconstructed from their string
        representations; any ``torch.<dtype>`` name is accepted and unknown
        names fall back to the default dtype with a warning. The ``log_dir``
        field in JSON is mapped to ``root_log_dir`` in the config.
        """
        import json

        try:
            data = json.loads(json_str)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON string: {e}") from e
        if not isinstance(data, dict):
            raise ValueError(
                f"Expected a JSON object for ``DriverConfig``, got {type(data).__name__}."
            )
        # Fields that are not DriverConfig constructor parameters
        metadata_fields = {
            "default_dtype",
            "world_size",
            "dist_manager_init_method",
            "log_dir",  # handled separately as root_log_dir
        }
        non_serializable_fields = {
            "dist_manager",
            "collate_fn",
            "root_log_dir",
            "device",
            "dtype",
        }
        skipped = metadata_fields | non_serializable_fields
        # Extract serializable fields that map directly
        config_fields = {
            key: value for key, value in data.items() if key not in skipped
        }
        # Handle root_log_dir (stored as "log_dir" in JSON)
        if "log_dir" in data:
            config_fields["root_log_dir"] = Path(data["log_dir"])
        # Parse device strings like "cuda:0", "cpu", "cuda", etc.
        if data.get("device") is not None:
            config_fields["device"] = torch.device(data["device"])
        # Handle dtype reconstruction from string
        if data.get("dtype") is not None:
            dtype_str = data["dtype"]
            parsed_dtype = cls._parse_dtype(dtype_str)
            if parsed_dtype is not None:
                config_fields["dtype"] = parsed_dtype
            else:
                warn(
                    f"Unknown dtype string '{dtype_str}' in JSON. "
                    "Using default dtype instead."
                )
        # Merge with provided kwargs (allows overriding and adding non-serializable fields)
        config_fields.update(kwargs)
        try:
            config = cls(**config_fields)
        except Exception as e:
            raise ValueError(
                "Failed to construct ``DriverConfig`` from JSON string."
            ) from e
        return config
|