|
|
|
|
|
import torch |
|
|
from typing import List, Dict |
|
|
|
|
|
class GPUManager: |
|
|
""" |
|
|
Classe singleton para gerenciar e alocar recursos de GPU de forma centralizada. |
|
|
|
|
|
Esta classe detecta as GPUs disponíveis e as aloca para diferentes |
|
|
tipos de tarefas com base em uma estratégia predefinida. Ela foi projetada |
|
|
para ser extensível a outros modelos e tarefas no futuro. |
|
|
|
|
|
Estratégia Padrão: |
|
|
- LTX_PRIMARY: Tarefas pesadas e sequenciais (Transformer, Text Encoder). |
|
|
Alocado para a primeira GPU (cuda:0) para maximizar o desempenho. |
|
|
- VAE_POOL: Tarefas mais leves e paralelizáveis (decodificação VAE). |
|
|
Alocado para as GPUs restantes (cuda:1, cuda:2, ...) em um pool de workers. |
|
|
|
|
|
Em caso de GPU única, todos os modelos compartilham o mesmo dispositivo. |
|
|
Em caso de ausência de GPUs, opera em modo CPU. |
|
|
""" |
|
|
def __init__(self): |
|
|
self.num_gpus = 0 |
|
|
self.devices: List[str] = [] |
|
|
self._allocations: Dict[str, List[str]] = {} |
|
|
self._worker_indices: Dict[str, int] = {} |
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
self.num_gpus = torch.cuda.device_count() |
|
|
|
|
|
self._initialize_allocations() |
|
|
|
|
|
def _initialize_allocations(self): |
|
|
"""Define a estratégia de alocação de dispositivos com base nas GPUs disponíveis.""" |
|
|
print("[GPUManager] Inicializando alocação de dispositivos...") |
|
|
|
|
|
if self.num_gpus == 0: |
|
|
print("[GPUManager] Nenhuma GPU CUDA detectada. Operando em modo CPU.") |
|
|
self.devices = ["cpu"] |
|
|
self._allocations['LTX_PRIMARY'] = ["cpu"] |
|
|
self._allocations['VAE_POOL'] = ["cpu"] |
|
|
elif self.num_gpus == 1: |
|
|
print("[GPUManager] Detectada 1 GPU. Todos os modelos compartilharão cuda:0.") |
|
|
self.devices = ["cuda:0"] |
|
|
self._allocations['LTX_PRIMARY'] = ["cuda:0"] |
|
|
self._allocations['VAE_POOL'] = ["cuda:0"] |
|
|
else: |
|
|
print(f"[GPUManager] Detectadas {self.num_gpus} GPUs. Ativando modo Multi-GPU.") |
|
|
self.devices = [f"cuda:{i}" for i in range(self.num_gpus)] |
|
|
|
|
|
|
|
|
primary_device = self.devices[0] |
|
|
self._allocations['LTX_PRIMARY'] = [primary_device] |
|
|
print(f" - Tarefas LTX_PRIMARY (Transformer/TextEncoder) alocadas para: {primary_device}") |
|
|
|
|
|
|
|
|
worker_devices = self.devices[1:] |
|
|
self._allocations['VAE_POOL'] = worker_devices |
|
|
self._worker_indices['VAE_POOL'] = 0 |
|
|
print(f" - Tarefas VAE_POOL (VAE Decode) alocadas para: {worker_devices}") |
|
|
|
|
|
print("[GPUManager] Alocação concluída.") |
|
|
|
|
|
def get_device_for(self, task_key: str) -> str: |
|
|
""" |
|
|
Retorna o dispositivo principal alocado para uma chave de tarefa. |
|
|
Use para modelos que residem em um único dispositivo. |
|
|
|
|
|
Args: |
|
|
task_key (str): A chave da tarefa (ex: 'LTX_PRIMARY'). |
|
|
|
|
|
Returns: |
|
|
str: O nome do dispositivo (ex: 'cuda:0'). |
|
|
""" |
|
|
if task_key not in self._allocations: |
|
|
raise ValueError(f"Chave de tarefa '{task_key}' não registrada no GPUManager.") |
|
|
|
|
|
|
|
|
return self._allocations[task_key][0] |
|
|
|
|
|
def get_next_worker_for(self, pool_key: str) -> str: |
|
|
""" |
|
|
Retorna o próximo dispositivo disponível de um pool de workers em rodízio. |
|
|
Use para tarefas que podem ser paralelizadas em várias GPUs. |
|
|
|
|
|
Args: |
|
|
pool_key (str): A chave do pool de workers (ex: 'VAE_POOL'). |
|
|
|
|
|
Returns: |
|
|
str: O nome do próximo dispositivo worker (ex: 'cuda:1'). |
|
|
""" |
|
|
if pool_key not in self._allocations or pool_key not in self._worker_indices: |
|
|
raise ValueError(f"Pool de workers '{pool_key}' não registrado no GPUManager.") |
|
|
|
|
|
worker_pool = self._allocations[pool_key] |
|
|
if not worker_pool: |
|
|
raise RuntimeError(f"O pool de workers '{pool_key}' está vazio.") |
|
|
|
|
|
|
|
|
current_idx = self._worker_indices[pool_key] |
|
|
device = worker_pool[current_idx] |
|
|
|
|
|
|
|
|
self._worker_indices[pool_key] = (current_idx + 1) % len(worker_pool) |
|
|
|
|
|
return device |
|
|
|
|
|
|
|
|
|
|
|
gpu_manager = GPUManager() |