| |
| |
|
|
| import logging |
| import torch |
| import gc |
| from typing import Any |
|
|
| |
| logger = logging.getLogger("VibeVoice") |
|
|
| class VibeVoiceFreeMemoryNode: |
| """Node to explicitly free VibeVoice model memory""" |
| |
| |
| _single_speaker_instances = [] |
| _multi_speaker_instances = [] |
| |
| @classmethod |
| def INPUT_TYPES(cls): |
| return { |
| "required": { |
| "audio": ("AUDIO", {"tooltip": "Audio input that triggers memory cleanup and gets passed through"}), |
| } |
| } |
|
|
| RETURN_TYPES = ("AUDIO",) |
| RETURN_NAMES = ("audio",) |
| FUNCTION = "free_vibevoice_memory" |
| CATEGORY = "VibeVoiceWrapper" |
| DESCRIPTION = "Free all loaded VibeVoice models from memory when audio passes through" |
| |
| @classmethod |
| def register_single_speaker(cls, node_instance): |
| """Register a single speaker node instance""" |
| if node_instance not in cls._single_speaker_instances: |
| cls._single_speaker_instances.append(node_instance) |
| |
| @classmethod |
| def register_multi_speaker(cls, node_instance): |
| """Register a multi speaker node instance""" |
| if node_instance not in cls._multi_speaker_instances: |
| cls._multi_speaker_instances.append(node_instance) |
| |
| def free_vibevoice_memory(self, audio): |
| """Free memory from all VibeVoice nodes and pass through the audio""" |
| |
| try: |
| freed_count = 0 |
| |
| |
| |
| try: |
| import sys |
| from .base_vibevoice import BaseVibeVoiceNode |
| |
| |
| for module_name, module in sys.modules.items(): |
| if module and 'vibevoice' in module_name.lower(): |
| for attr_name in dir(module): |
| if not attr_name.startswith('_'): |
| try: |
| attr = getattr(module, attr_name) |
| if isinstance(attr, type) and issubclass(attr, BaseVibeVoiceNode): |
| |
| for instance_attr in dir(attr): |
| instance = getattr(attr, instance_attr) |
| if isinstance(instance, BaseVibeVoiceNode) and hasattr(instance, 'free_memory'): |
| instance.free_memory() |
| freed_count += 1 |
| except: |
| pass |
| except: |
| pass |
| |
| |
| for node in self._single_speaker_instances: |
| if hasattr(node, 'free_memory'): |
| node.free_memory() |
| freed_count += 1 |
| |
| |
| for node in self._multi_speaker_instances: |
| if hasattr(node, 'free_memory'): |
| node.free_memory() |
| freed_count += 1 |
| |
| |
| gc.collect() |
| |
| |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| torch.cuda.synchronize() |
| logger.info(f"Freed VibeVoice memory from {freed_count} nodes and cleared CUDA cache") |
| else: |
| logger.info(f"Freed VibeVoice memory from {freed_count} nodes") |
| |
| |
| return (audio,) |
| |
| except Exception as e: |
| logger.error(f"Error freeing VibeVoice memory: {str(e)}") |
| |
| return (audio,) |
| |
| @classmethod |
| def IS_CHANGED(cls, **kwargs): |
| """Always execute this node""" |
| return float("nan") |