File size: 1,213 Bytes
528efee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# -*- coding: utf-8 -*-
# Time      :2025/3/29 10:27
# Author    :Hui Huang
from omegaconf import OmegaConf, DictConfig
import torch


def load_config(config_path: str) -> DictConfig:
    """Load a configuration file, layering it over its base configuration if any.

    The file at ``config_path`` is read with ``OmegaConf.load``. When the
    loaded configuration has a non-``None`` ``base_config`` entry, that entry
    is itself passed to ``OmegaConf.load`` (unchanged, with no path rewriting
    done here) and the two are merged with the base first, so values from
    ``config_path`` take precedence over the base.

    Args:
        config_path (str): Path to the configuration file.

    Returns:
        The loaded configuration, merged over its base configuration when
        ``base_config`` is set.
    """
    cfg = OmegaConf.load(config_path)

    # Nothing to inherit from: hand back the file's own contents.
    base_path = cfg.get("base_config", None)
    if base_path is None:
        return cfg

    # Base goes first so that the more specific file overrides it.
    return OmegaConf.merge(OmegaConf.load(base_path), cfg)


def gpu_supports_fp16(device: "int | str | torch.device | None" = None) -> bool:
    """Report whether a CUDA device meets this module's fp16 capability threshold.

    Args:
        device: Device to inspect, forwarded as-is to
            ``torch.cuda.get_device_capability`` (e.g. an index, ``"cuda:1"``
            or a ``torch.device``). ``None`` (the default) inspects the
            current CUDA device, which is what the zero-argument call has
            always done.

    Returns:
        ``False`` when CUDA is not available; otherwise ``True`` if the
        device's compute capability is at least 5.3, else ``False``.
    """
    # 1. Without CUDA there is no device to query.
    if not torch.cuda.is_available():
        return False

    # 2. Fetch the (major, minor) compute capability of the device.
    capability = tuple(torch.cuda.get_device_capability(device))

    # 3. Lexicographic tuple comparison is exactly
    #    "major > 5 or (major == 5 and minor >= 3)".
    return capability >= (5, 3)


def get_dtype(device: str):
    """Choose the torch dtype to use for ``device``.

    Args:
        device: Device string; only names starting with ``'cuda'`` are
            considered for half precision.

    Returns:
        ``torch.float16`` when ``device`` starts with ``'cuda'`` and
        ``gpu_supports_fp16()`` is true; ``torch.float32`` otherwise.
    """
    # Keep the string check first: it short-circuits, so the GPU is only
    # queried when a CUDA device was actually requested.
    use_half = device.startswith('cuda') and gpu_supports_fp16()
    return torch.float16 if use_half else torch.float32