File size: 24,183 Bytes
feba2ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
"""
Utilities for initializing components of the training process.

Here, we initialize all of the components that are part of the learning process. From logging,
and checkpointing to the optimizer to the dataset and the dataloader, this file contains the
logic for setting up the classes and functions that are used in the training loop.

As always, this code is meant to be basic. We hard-code the obvious defaults, and leave the
more experimental stuff to you.
"""

import logging
import math
import os
import warnings
from dataclasses import fields, is_dataclass
from datetime import datetime
from typing import Dict, Optional, Union

import lightning as L
import torch
import yaml
from datasets import Dataset, DownloadConfig, load_dataset
from datasets import config as datasets_config
from huggingface_hub import add_collection_item, create_branch, create_repo
from lightning.fabric.loggers import Logger as FabricLogger
from lightning.fabric.utilities.rank_zero import rank_zero_only
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

import wandb
from src.config import (
    CheckpointingConfig,
    DataConfig,
    EvaluationConfig,
    ModelConfig,
    MonitoringConfig,
    TrainingConfig,
)
from src.model import PicoDecoder
from src.training.utils.io import use_backoff
from wandb.integration.lightning.fabric import WandbLogger

warnings.filterwarnings(
    "ignore",
    message=".*This integration is tested and supported for lightning Fabric.*",
)
warnings.filterwarnings(
    "ignore",
    message=".*Please report any issues to.*",
)

########################################################
#
# Basic Initialization
#
########################################################


def _apply_config_overrides(config, overrides: dict):
    """Recursively apply configuration overrides to a dataclass config object.

    Args:
        config: Base configuration object (must be a dataclass)
        overrides: Dictionary of override values matching config structure

    Returns:
        Modified config object with overrides to the config.
    """
    for field in fields(config):
        field_value = getattr(config, field.name)
        if is_dataclass(field_value):
            _apply_config_overrides(field_value, overrides.get(field.name, {}))
        else:
            if field.name in overrides:
                setattr(config, field.name, overrides[field.name])
    return config


def initialize_configuration(
    config_path: Optional[str] = None,
) -> Dict[
    str,
    Union[
        DataConfig,
        ModelConfig,
        TrainingConfig,
        EvaluationConfig,
        MonitoringConfig,
        CheckpointingConfig,
    ],
]:
    """Initialize configuration objects with optional overrides from a YAML file.

    This function initializes all of the configuration objects, and then applies
    any overrides from the config_path file. If no config_path is provided,
    the function will use the default configuration objects.

    Args:
        config_path: Path to a YAML file containing configuration overrides.

    Returns:
        A dictionary containing the initialized configuration objects, keyed by
        "data", "model", "training", "evaluation", "monitoring", and
        "checkpointing".
    """
    configs = {
        "data": DataConfig(),
        "model": ModelConfig(),
        "training": TrainingConfig(),
        "evaluation": EvaluationConfig(),
        "monitoring": MonitoringConfig(),
        "checkpointing": CheckpointingConfig(),
    }

    if config_path:
        # Use a context manager so the file handle is closed deterministically
        # (the previous version leaked the handle from a bare open()).
        with open(config_path, "r") as config_file:
            overrides = yaml.safe_load(config_file)
        # safe_load returns None for an empty file; treat that as "no overrides".
        overrides = overrides or {}
        for section_name, config in configs.items():
            configs[section_name] = _apply_config_overrides(
                config, overrides.get(section_name, {})
            )

    return configs


def initialize_run_dir(checkpointing_config: CheckpointingConfig) -> str:
    """Create (if needed) and return the directory for the current run.

    When the config carries no run name, a timestamp-based name is generated
    and written back onto the config so that later components observe the
    same name.

    Args:
        checkpointing_config: Configuration object containing run settings.
            NOTE: Must have a 'run_name' attribute that can be None, in which
            case a timestamp-based name will be generated.

    Returns:
        str: The path to the run directory.
    """
    if checkpointing_config.run_name is None:
        # No explicit name: derive one from the current wall-clock time.
        checkpointing_config.run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    run_dir = os.path.join(
        checkpointing_config.runs_dir, checkpointing_config.run_name
    )
    os.makedirs(run_dir, exist_ok=True)
    return run_dir


def initialize_fabric(
    training_config: TrainingConfig, wandb_logger: Optional[FabricLogger] = None
):
    """Set up and launch a Lightning Fabric instance.

    Picks a distributed-training strategy based on the total device count:
    DeepSpeed ZeRO Stage 2 when more than one device participates, otherwise
    Fabric's "auto" strategy (which resolves to a single-device setup).

    Args:
        training_config: Configuration object containing fabric settings
            (accelerator, precision, devices, etc.).
        wandb_logger: Optional weights and biases logger instance for experiment tracking

    Returns:
        L.Fabric: Initialized (and launched) Lightning Fabric instance.

    Example:
        >>> fabric = initialize_fabric(training_config, wandb_logger)
    """
    fabric_settings = training_config.fabric
    device_count = fabric_settings.num_devices * fabric_settings.num_nodes

    # NOTE: DeepSpeed (ZeRO Stage 2) is used on multi-GPU runs and the default
    # single-device strategy otherwise. You can swap in a different strategy
    # here, but be aware that doing so may interact badly with checkpointing,
    # evaluation, etc.
    strategy = "deepspeed_stage_2" if device_count > 1 else "auto"

    loggers = None if wandb_logger is None else [wandb_logger]

    fabric = L.Fabric(
        accelerator=fabric_settings.accelerator,
        precision=fabric_settings.precision,
        devices=fabric_settings.num_devices,
        num_nodes=fabric_settings.num_nodes,
        loggers=loggers,
        strategy=strategy,
    )
    fabric.launch()

    return fabric


########################################################
#
# Dataset and Tokenization Initialization
#
########################################################


@use_backoff(max_retries=20)
def initialize_dataset(
    data_config: DataConfig,
    fabric: L.Fabric,
    initial_batch_step: Optional[int] = 0,
    return_fast_forward_steps: bool = False,
):
    """Initialize dataset based on the given config.

    This function will return a dataset object, and optionally a fast_forward_steps value.

    The fast_forward_steps value is the number of steps that we need to fast-forward an iterator by,
    so that we can continue from a certain batch of data we would have seen had training not previously
    stopped. Depending on how the dataset is loaded, the amount of steps to fast-forward may be
    different from the initial_batch_step value.

    NOTE: This functionality is primarily useful for streaming datasets (which for large
    datasets is most of the time).

    Args:
        data_config: Configuration object containing dataset settings.
        fabric: A Lightning Fabric instance.
        initial_batch_step: The initial batch step to fast-forward to.
        return_fast_forward_steps: Whether to return the fast-forward steps value.

    Returns:
        Dataset: Initialized dataset object.
        Optional[int]: Number of steps to fast-forward the iterator by, if return_fast_forward_steps is True.
    """

    # Make streaming reads more tolerant of transient network failures than
    # the library defaults.
    datasets_config.STREAMING_READ_MAX_RETRIES = 40  # default is 20
    datasets_config.STREAMING_READ_RETRY_INTERVAL = 10  # default is 5
    download_config = DownloadConfig(
        max_retries=20,  # default is 1 and can lead to premature HTTPS errors
    )

    fast_forward_steps = 0

    if data_config.dataset.name == "pico-lm/pretokenized-dolma":
        # NOTE: We know that the dataset is sharded into 10,000 shards, so we can easily compute
        # the data file that we need to load in that contains the batch of data at
        # initial_batch_step.

        if initial_batch_step is not None:
            examples_per_shard = 20_480
            total_shards = 10_000
            batches_per_shard = examples_per_shard // data_config.dataloader.batch_size
            # Whole shards consumed before initial_batch_step are skipped
            # entirely by starting the data_files list at shard_idx...
            shard_idx = initial_batch_step // batches_per_shard

            data_files = [
                f"data/train-{str(_shard_idx).zfill(5)}-of-{total_shards}.parquet"
                for _shard_idx in range(shard_idx, total_shards)
            ]

            # ...and only the remainder within the first loaded shard must be
            # fast-forwarded by the caller.
            fast_forward_steps = initial_batch_step % batches_per_shard
        else:
            data_files = None

        base_dataset = load_dataset(
            data_config.dataset.name,
            split="train",
            streaming=True,
            data_files=data_files,
            download_config=download_config,
        )
    else:
        # NOTE: For other datasets, you might want to add some custom loading logic, especially
        # to help with loading or fast-forwarding to the correct batch.

        base_dataset = load_dataset(
            data_config.dataset.name,
            split="train",
            streaming=True,
            download_config=download_config,
        )

    if data_config.dataset.name == "pico-lm/pretokenized-dolma":
        from .data import ShardedIterableDataset

        # NOTE: We wrap the dataset in a ShardedIterableDataset, which is a custom class that
        # allows us to shard an iterable dataset across multiple processes. This is useful for
        # distributed training, where we want data-parallelism.
        dataset = ShardedIterableDataset(
            base_dataset, fabric.global_rank, fabric.world_size
        )
    else:
        dataset = base_dataset

    if return_fast_forward_steps:
        return dataset, fast_forward_steps
    else:
        return dataset


def initialize_tokenizer(data_config: DataConfig):
    """Load the HuggingFace tokenizer named in the data config.

    This function can be extended to include custom tokenization logic.

    Args:
        data_config: Configuration object containing tokenizer settings.

    Returns:
        AutoTokenizer: A HuggingFace tokenizer instance.
    """
    tokenizer_name = data_config.tokenizer.name
    return AutoTokenizer.from_pretrained(tokenizer_name)


def initialize_dataloader(
    data_config: "DataConfig",
    training_config: "TrainingConfig",
    fabric: "L.Fabric",
    dataset: "Dataset",
):
    """Initialize the DataLoader for efficient batch processing.

    Creates a PyTorch DataLoader that handles batching and data loading for training.
    Configured specifically for streaming tokenized text datasets.

    The per-process (sub-)batch size is the configured batch size divided by
    world_size * gradient_accumulation_steps, so that the effective batch size
    across all processes and accumulation steps matches the config.

    You might also want to extend this function to add a sampler, or some sort of custom
    collate function. For the default dataset, we don't need any of this, because the data are
    pre-shuffled, and pre-tokenized.

    Args:
        data_config: Configuration object containing dataloader settings.
        training_config: Configuration object containing training settings.
        fabric: A Lightning Fabric instance (only world_size is read here).
        dataset: A HuggingFace Dataset object containing tokenized text data.
            Expected to have 'input_ids' field in its items.

    Returns:
        DataLoader: PyTorch DataLoader instance configured for the dataset.

    Raises:
        ValueError: If the configured batch size is too small to be split
            across processes and gradient-accumulation steps.
    """

    def _collate_fn(batch):
        # Keep raw token id lists; tensor conversion happens downstream.
        return {"input_ids": [entry["input_ids"] for entry in batch]}

    divisor = (
        fabric.world_size * training_config.optimization.gradient_accumulation_steps
    )
    sub_batch_size = data_config.dataloader.batch_size // divisor

    # Guard against a silent zero batch size: integer division truncates, and
    # DataLoader would otherwise fail later with a far less actionable error.
    if sub_batch_size < 1:
        raise ValueError(
            f"batch_size ({data_config.dataloader.batch_size}) must be at least "
            f"world_size * gradient_accumulation_steps ({divisor})."
        )

    return DataLoader(
        dataset,
        batch_size=sub_batch_size,
        shuffle=False,  # Keep sequential for streaming datasets
        pin_memory=True,  # Speeds up transfer to GPU
        collate_fn=_collate_fn,
    )


########################################################
#
# Model Initialization
#
########################################################


def initialize_model(model_config: ModelConfig):
    """Instantiate the model named in the model config.

    Loads in a given model implemented in the `src.model` package and returns it.

    NOTE: out of the box we currently only support the PicoDecoder model (a causal
    transformer language model). If you'd like to implement your own model, you can
    do so by adding a new model class in the `src.model` package, and then adding a
    new branch here.

    Args:
        model_config: Configuration object containing model settings.

    Returns:
        PyTorch model instance.

    Raises:
        ValueError: If 'model_type' does not name a supported model.
    """
    if model_config.model_type != "pico_decoder":
        raise ValueError(f"Invalid model type: {model_config.model_type}")
    return PicoDecoder(model_config)


########################################################
#
# Optimizer and Scheduler
#
########################################################


def initialize_optimizer(training_config: TrainingConfig, model: torch.nn.Module):
    """Build the optimizer described by the training config.

    Currently only AdamW is supported; add whatever other optimizers you want here.

    Args:
        training_config: Configuration object containing optimizer settings.
            Must have:
            - optimization.optimizer (str): Name of the optimizer ("adamw")
            - optimization.lr (float): Learning rate for the optimizer
        model: PyTorch model whose parameters will be optimized.

    Returns:
        torch.optim.Optimizer: Configured optimizer instance.

    Raises:
        ValueError: If the configured optimizer name is not recognized.
    """
    optimizer_name = training_config.optimization.optimizer
    if optimizer_name != "adamw":
        raise ValueError(f"Invalid optimizer: {optimizer_name}")

    return torch.optim.AdamW(
        model.parameters(), lr=training_config.optimization.lr
    )


def initialize_lr_scheduler(
    training_config: "TrainingConfig", optimizer: torch.optim.Optimizer
):
    """Initialize a learning rate scheduler with warmup and decay.

    Two schedules are supported, both starting with a linear warmup from 0 to
    the initial lr over `optimization.lr_warmup_steps`:

    - "linear_with_warmup": linear decay from the initial lr to 0 at max_steps.
    - "cosine": cosine decay, floored at 0.1 * initial lr (never reaches 0),
      which provides sustained learning over long training runs.

    Add other types of learning rate schedulers here.

    Args:
        training_config: Configuration object containing optimizer and scheduler
            settings (optimization.lr_scheduler, optimization.lr_warmup_steps,
            max_steps).
        optimizer: PyTorch optimizer whose learning rate will be scheduled.

    Returns:
        torch.optim.lr_scheduler.LambdaLR: Learning rate scheduler instance.

    Raises:
        ValueError: If the configured scheduler name is not recognized.
    """
    scheduler_name = training_config.optimization.lr_scheduler
    num_warmup_steps = training_config.optimization.lr_warmup_steps
    max_steps = training_config.max_steps

    def _warmup_factor(curr_step: int) -> float:
        # Linear ramp 0 -> 1 over the warmup window (shared by both schedules).
        return float(curr_step) / float(max(1, num_warmup_steps))

    if scheduler_name == "linear_with_warmup":
        # Credit where credit is due:
        # https://github.com/huggingface/transformers/blob/e71a01a104dd663c730e494eb0b6467bb51df357/src/transformers/optimization.py#L102
        def _lr_lambda(curr_step: int) -> float:
            if curr_step < num_warmup_steps:
                return _warmup_factor(curr_step)
            # Linear decay to 0 at max_steps.
            return max(
                0.0,
                float(max_steps - curr_step)
                / float(max(1, max_steps - num_warmup_steps)),
            )

    elif scheduler_name == "cosine":
        def _lr_lambda(curr_step: int) -> float:
            if curr_step < num_warmup_steps:
                return _warmup_factor(curr_step)
            # Cosine decay to 0.1 * initial_lr (not to 0).
            progress = float(curr_step - num_warmup_steps) / float(
                max(1, max_steps - num_warmup_steps)
            )
            return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))

    else:
        raise ValueError(
            f"Invalid learning rate scheduler: {scheduler_name}"
        )

    # Single construction point: the previous version duplicated the LambdaLR
    # call (and the lambda plumbing) in each branch.
    return torch.optim.lr_scheduler.LambdaLR(optimizer, _lr_lambda)


########################################################
#
# Experiment Monitoring (Logging, Experiment Tracking, etc.)
#
########################################################


def _initialize_log_file(checkpointing_config: CheckpointingConfig) -> str:
    """Create and initialize a timestamped log file in the run's log directory.

    Sets up a log file with a unique timestamp in the run's logging directory.
    Creates the necessary directory structure if it doesn't exist.

    Directory Structure:
        {checkpointing_config.runs_dir}/
        └── {checkpointing_config.run_name}/
            └── {checkpointing_config.logs_dir}/
                └── log_YYYYMMDD_HHMMSS.txt

    Args:
        checkpointing_config: Configuration object containing checkpointing settings.

    Returns:
        str: Absolute path to the created log file.

    """

    run_dir = os.path.join(checkpointing_config.runs_dir, checkpointing_config.run_name)
    logs_dir = os.path.join(run_dir, checkpointing_config.logs_dir)
    os.makedirs(logs_dir, exist_ok=True)

    # datetime stamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file_name = f"log_{timestamp}.log"
    log_file_path = os.path.join(logs_dir, log_file_name)

    open(log_file_path, "w").close()  # Create an empty log file

    return log_file_path


@use_backoff()
def initialize_wandb(
    monitoring_config: MonitoringConfig, checkpointing_config: CheckpointingConfig
):
    """Initialize Weights and Biases.

    This function initializes Weights and Biases based on the configuration settings.
    When auto-resume is enabled, it looks for a single previous run with the same
    display name and, if found, reuses its id so the logger resumes that run
    instead of starting a new one.

    Args:
        monitoring_config: Configuration object containing monitoring settings.
        checkpointing_config: Configuration object containing checkpointing settings.

    Returns:
        Optional[WandbLogger]: An experiment tracker instance.
    """

    # Fail fast on incomplete wandb configuration.
    assert (
        monitoring_config.wandb.project is not None
        and monitoring_config.wandb.project != ""
    ), "Wandb project must be provided if wandb is to be used."
    assert (
        monitoring_config.wandb.entity is not None
        and monitoring_config.wandb.entity != ""
    ), "Wandb entity must be provided if wandb is to be used."

    _run_id = None
    if checkpointing_config.training.auto_resume:
        # If we are loading a checkpoint, we can try to find the run id of the previous run
        previous_runs = wandb.Api().runs(
            path=f"{monitoring_config.wandb.entity}/{monitoring_config.wandb.project}",
            filters={"display_name": checkpointing_config.run_name},
        )
        try:
            # Only resume when the name maps to exactly one previous run;
            # ambiguous matches fall through to a fresh run.
            if len(previous_runs) == 1:
                _run_id = previous_runs[0].id
        except ValueError:
            # NOTE(review): presumably len() on the lazy Runs collection can
            # raise ValueError when the query fails — confirm against the
            # wandb API before relying on this.
            pass

    wandb_logger = WandbLogger(
        project=monitoring_config.wandb.project,
        entity=monitoring_config.wandb.entity,
        id=_run_id,  # None starts a fresh run; a found id resumes the old one
        name=checkpointing_config.run_name,
    )

    return wandb_logger


@rank_zero_only
def initialize_logging(
    monitoring_config: MonitoringConfig,
    checkpointing_config: CheckpointingConfig,
    fabric: L.Fabric,
):
    """Initialize logging system with default logging, to file and console.

    Builds the "pico-train" logger at INFO level with two handlers — one
    writing to a timestamped file in the run's log directory and one streaming
    to the console — both using the same timestamped formatter and the
    configured log level.

    NOTE: this function is only called on rank 0.

    Args:
        monitoring_config: Configuration object containing monitoring settings.
        checkpointing_config: Configuration object containing checkpointing settings.
        fabric: A Lightning Fabric instance (not read in this function).

    Returns:
        logger: Standard Python logger configured for file and console output
    """

    # ---- Standard Local Logger ---- #
    logger = logging.getLogger("pico-train")
    logger.setLevel(logging.INFO)

    # Shared formatter for both handlers.
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # File handler writes into the run's logs directory (file is created here).
    file_handler = logging.FileHandler(
        _initialize_log_file(checkpointing_config), encoding="utf-8"
    )
    # Stream handler mirrors the same records to the console.
    stream_handler = logging.StreamHandler()

    for handler in (file_handler, stream_handler):
        handler.setLevel(monitoring_config.logging.log_level)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


########################################################
#
# HuggingFace/Remote Checkpointing
#
########################################################


@rank_zero_only
@use_backoff()
def initialize_hf_checkpointing(
    checkpointing_config: CheckpointingConfig, fabric: L.Fabric
):
    """Initialize HuggingFace Checkpointing.

    Creates a HuggingFace repository if it doesn't exist, optionally adds it to a
    collection, and creates a branch named after the run for this run's checkpoints.

    NOTE: this function is only called on rank 0.

    Args:
        checkpointing_config: Configuration object containing checkpointing settings; must have
            a 'hf_checkpoint' attribute that specifies the HuggingFace repository id and
            collection slug (if applicable) to save the checkpoint to.
        fabric: A Lightning Fabric instance (not read in this function).

    Raises:
        RuntimeError: If unable to create HuggingFace repository after multiple attempts.
    """

    huggingface_repo_id = checkpointing_config.hf_checkpoint.repo_id
    assert (
        huggingface_repo_id is not None and huggingface_repo_id != ""
    ), "hf_checkpoint.repo_id must be provided."

    # exist_ok makes repeated initialization (e.g. on resume) a no-op.
    repo = create_repo(huggingface_repo_id, exist_ok=True)

    # can create a repo without a specified namespace (will default to username)
    # however the rest of the HF calls need the fully qualified name
    # this is returned by create repo, so we update the config for later calls
    checkpointing_config.hf_checkpoint.repo_id = repo.repo_id
    huggingface_repo_id = repo.repo_id

    if checkpointing_config.hf_checkpoint.collection_slug:
        # Optionally register the repo in a HuggingFace collection.
        add_collection_item(
            checkpointing_config.hf_checkpoint.collection_slug,
            huggingface_repo_id,
            repo.repo_type,
            exists_ok=True,
        )

    # One branch per run, so checkpoints of different runs don't collide.
    create_branch(
        repo_id=huggingface_repo_id,
        branch=checkpointing_config.run_name,
        exist_ok=True,
    )