noblebarkrr's picture
Updated to Dzeta
4f175c5
import torch
from multiprocessing import cpu_count
class Config:
"""Конфигурация для VC"""
def __init__(self, device_str: str = "cuda" if torch.cuda.is_available() else "cpu") -> None:
"""
Инициализация конфигурации
Args:
device_str: Строка устройства
"""
self.device_str: str = device_str
self.device_ids = None
self.set_device(self.device_str)
self.is_half: bool = False
self.n_cpu: int = cpu_count()
self.gpu_name = None
self.gpu_mem = None
self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
def set_device(self, device_str: str) -> None:
"""
Установить устройство
Args:
device_str: Строка устройства
"""
if "cuda" in device_str.lower():
if ":" in device_str:
device_spec = device_str.split(":")[1]
self.device_ids = [int(id) for id in device_spec.split(",") if id.isdigit()]
else:
self.device_ids = list(range(torch.cuda.device_count()))
self.device = torch.device("cuda" if not self.device_ids else f"cuda:{self.device_ids[0]}")
elif "mps" in device_str.lower():
self.device_ids = None
self.device = torch.device("mps")
else:
self.device_ids = None
self.device = torch.device("cpu")
def device_config(self):
"""
Настройка параметров для устройства
Returns:
Кортеж (x_pad, x_query, x_center, x_max)
"""
if self.device.type == "cuda":
if self.device_ids:
self.gpu_mem = self._configure_gpu(self.device_ids[0])
x_pad, x_query, x_center, x_max = (
(3, 10, 60, 65) if self.is_half else (1, 6, 38, 41)
)
if self.gpu_mem is not None and self.gpu_mem <= 4:
x_pad, x_query, x_center, x_max = (1, 5, 30, 32)
return x_pad, x_query, x_center, x_max
def _configure_gpu(self, device_id: int) -> int:
"""
Настройка GPU
Args:
device_id: ID устройства
Returns:
Объем памяти GPU в GB
"""
self.gpu_name = torch.cuda.get_device_name(f"cuda:{device_id}")
low_end_gpus = ["16", "P40", "P10", "1060", "1070", "1080"]
if (
any(gpu in self.gpu_name for gpu in low_end_gpus)
and "V100" not in self.gpu_name.upper()
):
self.is_half = False
return int(
torch.cuda.get_device_properties(self.device).total_memory
/ 1024
/ 1024
/ 1024
+ 0.4
)