File size: 3,470 Bytes
b74998d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

"""

Miscellaneous utility functions.

"""

import logging
import os
import random

import numpy as np
import torch


class StreamToLogger:
    """
    File-like adapter that forwards stream writes to a logger.

    This class can be used to redirect stdout or stderr to a logger: it
    exposes the ``write``/``flush``/``isatty`` subset of the text-stream
    interface and turns every written line into a log record.

    Parameters:
    - logger: A logger instance that will receive the log messages
    - log_level: The logging level to use (default: logging.INFO)
    """

    def __init__(self, logger, log_level=logging.INFO):
        self.logger = logger
        self.log_level = log_level
        # Not used by write(); kept so code reading this attribute keeps working.
        self.linebuf = ""

    def write(self, buf):
        """
        Write the buffer content to the logger, one record per line.

        Trailing whitespace is stripped from the buffer and from every line,
        so a buffer holding only whitespace/newlines produces no record.
        Nothing is buffered between calls: each call logs what it receives
        immediately, so a line delivered in several fragments yields several
        records.

        Parameters:
        - buf: The string buffer to write

        Returns:
        - Number of characters in ``buf``, as text streams report from write()
        """
        for line in buf.rstrip().splitlines():
            self.logger.log(self.log_level, line.rstrip())
        # Text-stream writers report how many characters they consumed.
        return len(buf)

    def flush(self):
        """
        Flush method to comply with file-like object interface.

        Nothing is buffered by this class, so there is nothing to flush.
        """

    def isatty(self):
        """
        Report that this stream is not an interactive terminal.

        Returns:
        - Always False
        """
        return False


def seed_everything(seed: int = 42):
    """
    Seed Python's ``random``, NumPy and PyTorch with one shared value.

    In addition to the three generators, the ``PYTHONHASHSEED`` environment
    variable is set to the string form of the seed, cuDNN is switched to
    deterministic mode with its ``benchmark`` flag turned off, and a
    confirmation message is printed.

    Parameters:
    - seed: Seed value handed to every generator (default: 42)
    """
    # Random number generators, plus the hash-seed environment variable.
    random.seed(seed)
    os.environ.update(PYTHONHASHSEED=str(seed))
    np.random.seed(seed)
    torch.manual_seed(seed)

    # cuDNN: deterministic mode on, benchmark flag off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

    print(f"Seed set to: {seed}")


def invalid_to_nans(arr, valid_mask, ndim=999):
    """
    Return ``arr`` with every entry outside ``valid_mask`` overwritten by NaN.

    When a mask is given, the replacement is done on a clone, so the input
    is left untouched. Without a mask the input is passed through as-is
    (apart from the optional flattening below).

    Parameters:
    - arr: Input array (typically a PyTorch tensor)
    - valid_mask: Boolean mask indicating valid elements (True) and invalid
      elements (False); None skips the masking step
    - ndim: Maximum number of dimensions to keep; flattens dimensions if
      arr.ndim > ndim

    Returns:
    - Array with invalid values replaced by NaN
    """
    result = arr
    if valid_mask is not None:
        result = arr.clone()
        result[~valid_mask] = float("nan")

    surplus = result.ndim - ndim
    if surplus > 0:
        # Fold the surplus dimensions into the second-to-last one; the last
        # dimension is left as it is.
        result = result.flatten(-2 - surplus, -2)
    return result


def invalid_to_zeros(arr, valid_mask, ndim=999):
    """
    Replace invalid values in an array with zeros based on a validity mask.

    When a mask is given, the replacement is done on a clone, so the input
    is left untouched.

    Parameters:
    - arr: Input array (typically a PyTorch tensor)
    - valid_mask: Boolean mask indicating valid elements (True) and invalid
      elements (False); None treats every element as valid
    - ndim: Maximum number of dimensions to keep; flattens dimensions if
      arr.ndim > ndim

    Returns:
    - Tuple containing:
      - Modified array with invalid values replaced by zeros
      - nnz: Number of valid elements per sample. With a mask this is a
        tensor holding one count per entry along the first dimension of
        the mask; without a mask it is a single Python int (the element
        count of ``arr[..., 0]`` divided by ``len(arr)``, or 0 when
        ``arr`` is empty)
    """
    if valid_mask is not None:
        arr = arr.clone()
        arr[~valid_mask] = 0
        batch = len(valid_mask)
        # Spell out the per-sample size instead of using -1: -1 is ambiguous
        # for zero-element masks, and reshape (unlike view) also accepts
        # non-contiguous masks.
        per_sample = valid_mask.numel() // batch if batch else 0
        nnz = valid_mask.reshape(batch, per_sample).sum(1)
    else:
        nnz = (
            arr[..., 0].numel() // len(arr) if len(arr) else 0
        )  # Number of pixels per image
    if arr.ndim > ndim:
        # Fold the surplus dimensions into the second-to-last one; the last
        # dimension is left as it is.
        arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
    return arr, nnz