Spaces:
Sleeping
Sleeping
| import os | |
| import logging | |
| import json | |
| import random | |
| import torch | |
| import torchaudio | |
| import re | |
| from diffusers import AutoencoderOobleck, FluxTransformer2DModel | |
| from huggingface_hub import snapshot_download | |
| from comfy.utils import load_torch_file, ProgressBar | |
| import folder_paths | |
| from tangoflux.model import TangoFlux | |
| from .teacache import teacache_forward | |
# Module-wide logger for all TangoFlux nodes.
log = logging.getLogger("TangoFlux")

# Directory under the ComfyUI models folder that holds the TangoFlux files.
TANGOFLUX_DIR = os.path.join(folder_paths.models_dir, "tangoflux")

# Register the "tangoflux" model folder. If an entry already exists, keep its
# search paths; either way the extension set is (re)assigned to the
# supported PyTorch extensions.
if "tangoflux" in folder_paths.folder_names_and_paths:
    current_paths, _ = folder_paths.folder_names_and_paths["tangoflux"]
else:
    current_paths = [TANGOFLUX_DIR]

folder_paths.folder_names_and_paths["tangoflux"] = (
    current_paths,
    folder_paths.supported_pt_extensions,
)

# Parent directory for downloaded text encoders.
TEXT_ENCODER_DIR = os.path.join(folder_paths.models_dir, "text_encoders")
class TangoFluxLoader:
    """ComfyUI node that downloads and loads the TangoFlux model and its VAE.

    The loaded model and VAE are cached on the node instance. The model is
    rebuilt only when no model is cached yet or the TeaCache toggle changes;
    the VAE is loaded once.

    Note that TeaCache is enabled by replacing ``FluxTransformer2DModel.forward``
    and by setting attributes on the transformer's *class*, so that state is
    shared by every instance of that class in the process.
    """

    # Transformer forward as it was when this class was defined. Used only as
    # a fallback in ``__init__``. Wrapped in ``staticmethod`` so that attribute
    # access returns the plain function instead of binding it to a loader.
    _PRISTINE_FORWARD = staticmethod(FluxTransformer2DModel.forward)

    @classmethod
    def INPUT_TYPES(cls):
        """Return the ComfyUI input specification for this node."""
        return {
            "required": {
                "enable_teacache": ("BOOLEAN", {"default": False}),
                "rel_l1_thresh": (
                    "FLOAT",
                    {"default": 0.25, "min": 0.0, "max": 10.0, "step": 0.01},
                ),
            },
        }

    RETURN_TYPES = ("TANGOFLUX_MODEL", "TANGOFLUX_VAE")
    RETURN_NAMES = ("model", "vae")
    OUTPUT_TOOLTIPS = ("TangoFlux Model", "TangoFlux Vae")

    CATEGORY = "TangoFlux"
    FUNCTION = "load_tangoflux"
    DESCRIPTION = "Load TangoFlux model"

    def __init__(self):
        self.model = None
        self.vae = None
        self.enable_teacache = False
        self.rel_l1_thresh = 0.25

        # Remember the un-patched forward so TeaCache can be switched off
        # again. If another loader already installed the TeaCache forward,
        # the current attribute is the patch itself; fall back to the forward
        # captured at class-definition time in that case.
        current_forward = FluxTransformer2DModel.forward
        if current_forward is teacache_forward:
            current_forward = type(self)._PRISTINE_FORWARD
        self.original_forward = current_forward

    def load_tangoflux(
        self,
        enable_teacache=False,
        rel_l1_thresh=0.25,
        tangoflux_path=TANGOFLUX_DIR,
        text_encoder_path=TEXT_ENCODER_DIR,
        device="cuda",
    ):
        """Load (or reuse) the TangoFlux model and VAE.

        Args:
            enable_teacache: Whether to install the TeaCache forward on
                ``FluxTransformer2DModel``.
            rel_l1_thresh: TeaCache threshold stored on the transformer class.
            tangoflux_path: Directory the TangoFlux snapshot is downloaded to
                and loaded from.
            text_encoder_path: Parent directory for the text encoder snapshot.
            device: Torch device string the model and VAE are moved to.

        Returns:
            A ``(model, vae)`` tuple.
        """
        # Created unconditionally: the VAE branch reports progress too, and
        # it can run without the model branch having run in this call.
        pbar = ProgressBar(6)

        if self.model is None or self.enable_teacache != enable_teacache:
            snapshot_download(
                repo_id="declare-lab/TangoFlux",
                allow_patterns=["*.json", "*.safetensors"],
                local_dir=tangoflux_path,
                local_dir_use_symlinks=False,
            )
            pbar.update(1)

            log.info("Loading config")
            with open(os.path.join(tangoflux_path, "config.json"), "r") as f:
                config = json.load(f)
            pbar.update(1)

            # The repo id doubles as the local folder name once characters
            # that are invalid in file names are replaced with "-".
            text_encoder_repo = config.get(
                "text_encoder_name", "google/flan-t5-large"
            )
            text_encoder = re.sub(r'[<>:"/\\|?*]', "-", text_encoder_repo)
            text_encoder_path = os.path.join(text_encoder_path, text_encoder)

            snapshot_download(
                repo_id=text_encoder_repo,
                allow_patterns=["*.json", "*.safetensors", "*.model"],
                local_dir=text_encoder_path,
                local_dir_use_symlinks=False,
            )
            pbar.update(1)

            log.info("Loading TangoFlux models")

            # Drop the previous model before loading the new weights.
            del self.model
            self.model = None

            model_weights = load_torch_file(
                os.path.join(tangoflux_path, "tangoflux.safetensors"),
                device=torch.device(device),
            )
            pbar.update(1)

            # The forward must be swapped before the model is constructed and
            # used; the patch is applied at class level.
            if enable_teacache:
                log.info("Enabling TeaCache")
                FluxTransformer2DModel.forward = teacache_forward
            else:
                log.info("Disabling TeaCache")
                FluxTransformer2DModel.forward = self.original_forward

            model = TangoFlux(config=config, text_encoder_dir=text_encoder_path)
            model.load_state_dict(model_weights, strict=False)
            model.to(device)

            if enable_teacache:
                # TeaCache bookkeeping lives on the transformer class.
                transformer_cls = model.transformer.__class__
                transformer_cls.enable_teacache = True
                transformer_cls.cnt = 0
                transformer_cls.rel_l1_thresh = rel_l1_thresh
                transformer_cls.accumulated_rel_l1_distance = 0
                transformer_cls.previous_modulated_input = None
                transformer_cls.previous_residual = None
            pbar.update(1)

            self.model = model
            del model
            self.enable_teacache = enable_teacache
            self.rel_l1_thresh = rel_l1_thresh

        if self.vae is None:
            log.info("Loading TangoFlux VAE")
            vae_weights = load_torch_file(
                os.path.join(tangoflux_path, "vae.safetensors")
            )
            self.vae = AutoencoderOobleck()
            self.vae.load_state_dict(vae_weights)
            self.vae.to(device)
            pbar.update(1)

        # Threshold-only change: update it in place without reloading.
        if self.enable_teacache and self.rel_l1_thresh != rel_l1_thresh:
            self.model.transformer.__class__.rel_l1_thresh = rel_l1_thresh
            self.rel_l1_thresh = rel_l1_thresh

        return (self.model, self.vae)
class TangoFluxSampler:
    """ComfyUI node that generates TangoFlux latents from a text prompt."""

    @classmethod
    def INPUT_TYPES(cls):
        """Return the ComfyUI input specification for this node."""
        return {
            "required": {
                "model": ("TANGOFLUX_MODEL",),
                "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "steps": ("INT", {"default": 50, "min": 1, "max": 10000, "step": 1}),
                "guidance_scale": (
                    "FLOAT",
                    {"default": 3, "min": 1, "max": 100, "step": 1},
                ),
                "duration": ("INT", {"default": 10, "min": 1, "max": 30, "step": 1}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
            },
        }

    RETURN_TYPES = ("TANGOFLUX_LATENTS",)
    RETURN_NAMES = ("latents",)
    # One tooltip per output, as a tuple (a bare string would be indexed
    # character by character).
    OUTPUT_TOOLTIPS = ("TangoFlux Sample",)

    CATEGORY = "TangoFlux"
    FUNCTION = "sample"
    DESCRIPTION = "Sampler for TangoFlux"

    def sample(
        self,
        model,
        prompt,
        steps=50,
        guidance_scale=3,
        duration=10,
        seed=0,
        batch_size=1,
        device="cuda",
    ):
        """Run ``model.inference_flow`` and return the resulting latents.

        Args:
            model: TangoFlux model object providing ``inference_flow``.
            prompt: Text prompt passed to the model.
            steps: Number of inference steps; also the progress-bar total.
            guidance_scale: Guidance scale passed to the model.
            duration: Duration value passed to the model and forwarded in
                the result.
            seed: Seed passed to the model.
            batch_size: Passed as ``num_samples_per_prompt``.
            device: Torch device string the model is moved to.

        Returns:
            A one-element tuple holding
            ``{"latents": <model output>, "duration": duration}``.
        """
        pbar = ProgressBar(steps)

        with torch.no_grad():
            model.to(device)

            # Best effort: tell TeaCache how many steps this run has. The
            # attributes exist only if TeaCache was configured on the
            # transformer class, so a missing attribute is not an error.
            try:
                transformer_cls = model.transformer.__class__
                if transformer_cls.enable_teacache:
                    transformer_cls.num_steps = steps
            except AttributeError:
                pass

            log.info("Generating latents with TangoFlux")
            latents = model.inference_flow(
                prompt,
                duration=duration,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                seed=seed,
                num_samples_per_prompt=batch_size,
                callback_on_step_end=lambda: pbar.update(1),
            )

        return ({"latents": latents, "duration": duration},)
class TangoFluxVAEDecodeAndPlay:
    """ComfyUI output node that decodes TangoFlux latents and saves audio files."""

    @classmethod
    def INPUT_TYPES(cls):
        """Return the ComfyUI input specification for this node."""
        return {
            "required": {
                "vae": ("TANGOFLUX_VAE",),
                "latents": ("TANGOFLUX_LATENTS",),
                "filename_prefix": ("STRING", {"default": "TangoFlux"}),
                "format": (
                    ["wav", "mp3", "flac", "aac", "wma"],
                    {"default": "wav"},
                ),
                "save_output": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ()
    OUTPUT_NODE = True

    CATEGORY = "TangoFlux"
    FUNCTION = "play"
    DESCRIPTION = "Decoder and Player for TangoFlux"

    def decode(self, vae, latents):
        """Decode a batch of latents one item at a time.

        Each latent gets a leading batch axis, has axes 1 and 2 swapped, is
        passed to ``vae.decode`` and the ``.sample`` of the result is moved
        to the CPU.

        Args:
            vae: Object whose ``decode(x)`` returns something with ``.sample``.
            latents: Iterable of latent tensors (presumably one per generated
                clip; confirm against the sampler output).

        Returns:
            The decoded samples concatenated along dimension 0.
        """
        # Decoding is inference only; skip autograd bookkeeping, as the
        # sampler does.
        with torch.no_grad():
            decoded = [
                vae.decode(latent.unsqueeze(0).transpose(2, 1)).sample.cpu()
                for latent in latents
            ]
        return torch.cat(decoded, dim=0)

    def play(
        self,
        vae,
        latents,
        filename_prefix="TangoFlux",
        format="wav",
        save_output=True,
        device="cuda",
    ):
        """Decode ``latents`` and write one audio file per decoded waveform.

        Args:
            vae: VAE used for decoding; its ``config.sampling_rate`` is used
                both to trim the waveform and as the saved sample rate.
            latents: Dict with keys ``"latents"`` and ``"duration"``.
            filename_prefix: Prefix for the generated file names.
            format: File extension used for the saved files.
            save_output: Save to the output directory when true, otherwise to
                the temp directory with a random ``_temp_`` suffix.
            device: Torch device string the VAE is moved to.

        Returns:
            ``{"ui": {"audios": [...]}}`` with one entry per saved file.
        """
        duration = latents["duration"]
        latent_batch = latents["latents"]

        # Total = one step for path setup, one for decoding, one per clip.
        # Sized from the latent batch, not from the two-key wrapper dict.
        pbar = ProgressBar(len(latent_batch) + 2)

        if save_output:
            output_dir = folder_paths.get_output_directory()
            prefix_append = ""
            output_type = "output"
        else:
            output_dir = folder_paths.get_temp_directory()
            prefix_append = "_temp_" + "".join(
                random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)
            )
            output_type = "temp"

        filename_prefix += prefix_append
        full_output_folder, filename, counter, subfolder, _ = (
            folder_paths.get_save_image_path(filename_prefix, output_dir)
        )
        os.makedirs(full_output_folder, exist_ok=True)
        pbar.update(1)

        vae.to(device)

        log.info("Decoding Tangoflux latents")
        waves = self.decode(vae, latent_batch)
        pbar.update(1)

        # Use the VAE's own sampling rate for both trimming and saving so the
        # two can never disagree.
        sample_rate = vae.config.sampling_rate
        waveform_end = int(duration * sample_rate)

        audios = []
        for wave in waves:
            # Trim the decoded waveform to the requested duration.
            wave = wave[:, :waveform_end]

            file = f"{filename}_{counter:05}_.{format}"
            torchaudio.save(
                os.path.join(full_output_folder, file),
                wave,
                sample_rate=sample_rate,
            )
            counter += 1

            audios.append(
                {"filename": file, "subfolder": subfolder, "type": output_type}
            )
            pbar.update(1)

        return {
            "ui": {"audios": audios},
        }
# Node-name -> node-class registry. Each node is registered under its own
# class name.
NODE_CLASS_MAPPINGS = {
    node_cls.__name__: node_cls
    for node_cls in (
        TangoFluxLoader,
        TangoFluxSampler,
        TangoFluxVAEDecodeAndPlay,
    )
}