| | import torch |
| | import torch.nn as nn |
| | from diffusers import AutoencoderOobleck |
| | from diffusers import FluxTransformer2DModel |
| | from tangoflux import TangoFluxInference |
| | from tangoflux.model import DurationEmbedder, TangoFlux |
| |
|
| | def export_vae_encoder(vae, save_path, batch_size=1, audio_length=441000): |
| | """导出VAE编码器到ONNX格式 |
| | |
| | Args: |
| | vae: AutoencoderOobleck实例 |
| | save_path: 保存路径 |
| | batch_size: batch大小 |
| | audio_length: 音频长度(默认10秒,44100Hz采样率) |
| | """ |
| | vae.eval() |
| | |
| | |
| | dummy_input = torch.randn(batch_size, 2, audio_length) |
| | |
| | |
| | class VAEEncoderWrapper(nn.Module): |
| | def __init__(self, vae): |
| | super().__init__() |
| | self.vae = vae |
| | |
| | def forward(self, audio): |
| | return self.vae.encode(audio).latent_dist.sample() |
| | |
| | wrapper = VAEEncoderWrapper(vae) |
| | |
| | |
| | torch.onnx.export( |
| | wrapper, |
| | dummy_input, |
| | save_path, |
| | input_names=['audio'], |
| | output_names=['latent'], |
| | dynamic_axes={ |
| | 'audio': {0: 'batch_size', 2: 'audio_length'}, |
| | 'latent': {0: 'batch_size', 2: 'latent_length'} |
| | }, |
| | opset_version=17 |
| | ) |
| |
|
| | def export_vae_decoder(vae, save_path, batch_size=1, latent_length=645): |
| | """导出VAE解码器到ONNX格式 |
| | |
| | Args: |
| | vae: AutoencoderOobleck实例 |
| | save_path: 保存路径 |
| | batch_size: batch大小 |
| | latent_length: 潜在向量长度 |
| | """ |
| | vae.eval() |
| | |
| | |
| | dummy_input = torch.randn(batch_size, 64, latent_length) |
| | |
| | |
| | class VAEDecoderWrapper(nn.Module): |
| | def __init__(self, vae): |
| | super().__init__() |
| | self.vae = vae |
| | |
| | def forward(self, latent): |
| | return self.vae.decode(latent).sample |
| | |
| | wrapper = VAEDecoderWrapper(vae) |
| | |
| | |
| | torch.onnx.export( |
| | wrapper, |
| | dummy_input, |
| | save_path, |
| | input_names=['latent'], |
| | output_names=['audio'], |
| | dynamic_axes={ |
| | 'latent': {0: 'batch_size', 2: 'latent_length'}, |
| | 'audio': {0: 'batch_size', 2: 'audio_length'} |
| | }, |
| | opset_version=17 |
| | ) |
| |
|
| | def export_duration_embedder(duration_embedder, save_path, batch_size=1): |
| | """导出Duration Embedder到ONNX格式 |
| | |
| | Args: |
| | duration_embedder: DurationEmbedder实例 |
| | save_path: 保存路径 |
| | batch_size: batch大小 |
| | """ |
| | duration_embedder.eval() |
| | |
| | |
| | dummy_input = torch.tensor([[10.0]], dtype=torch.float32) |
| | |
| | |
| | torch.onnx.export( |
| | duration_embedder, |
| | dummy_input, |
| | save_path, |
| | input_names=['duration'], |
| | output_names=['embedding'], |
| | dynamic_axes={ |
| | 'duration': {0: 'batch_size'}, |
| | 'embedding': {0: 'batch_size'} |
| | }, |
| | opset_version=17 |
| | ) |
| |
|
| | def export_flux_transformer(transformer, save_path, batch_size=1, seq_length=645): |
| | """导出FluxTransformer2D到ONNX格式 |
| | |
| | Args: |
| | transformer: FluxTransformer2DModel实例 |
| | save_path: 保存路径 |
| | batch_size: batch大小 |
| | seq_length: 序列长度 |
| | """ |
| | transformer.eval() |
| | |
| | |
| | hidden_states = torch.randn(batch_size, seq_length, 64) |
| | timestep = torch.tensor([0.5]) |
| | pooled_text = torch.randn(batch_size, 1024) |
| | encoder_hidden_states = torch.randn(batch_size, 64, 1024) |
| | txt_ids = torch.zeros(batch_size, 64, 3).to(torch.int64) |
| | img_ids = torch.arange(seq_length).unsqueeze(0).unsqueeze(-1).repeat(batch_size, 1, 3).to(torch.int64) |
| | |
| | |
| | class TransformerWrapper(nn.Module): |
| | def __init__(self, transformer): |
| | super().__init__() |
| | self.transformer = transformer |
| | |
| | def forward(self, hidden_states, timestep, pooled_text, encoder_hidden_states, txt_ids, img_ids): |
| | return self.transformer( |
| | hidden_states=hidden_states, |
| | timestep=timestep, |
| | guidance=None, |
| | pooled_projections=pooled_text, |
| | encoder_hidden_states=encoder_hidden_states, |
| | txt_ids=txt_ids, |
| | img_ids=img_ids, |
| | return_dict=False |
| | )[0] |
| | |
| | wrapper = TransformerWrapper(transformer) |
| | |
| | |
| | torch.onnx.export( |
| | wrapper, |
| | (hidden_states, timestep, pooled_text, encoder_hidden_states, txt_ids, img_ids), |
| | save_path, |
| | input_names=['hidden_states', 'timestep', 'pooled_text', 'encoder_hidden_states', 'txt_ids', 'img_ids'], |
| | output_names=['output'], |
| | dynamic_axes={ |
| | 'hidden_states': {0: 'batch_size', 1: 'sequence_length'}, |
| | 'pooled_text': {0: 'batch_size'}, |
| | 'encoder_hidden_states': {0: 'batch_size', 1: 'text_length'}, |
| | 'txt_ids': {0: 'batch_size', 1: 'text_length'}, |
| | 'img_ids': {0: 'batch_size', 1: 'sequence_length'} |
| | }, |
| | opset_version=17 |
| | ) |
| |
|
| | def export_proj_layer(proj_layer, save_path, batch_size=1): |
| | """导出projection层到ONNX格式 |
| | |
| | Args: |
| | proj_layer: 投影层(fc层)实例 |
| | save_path: 保存路径 |
| | batch_size: batch大小 |
| | """ |
| | proj_layer.eval() |
| | |
| | |
| | dummy_input = torch.randn(batch_size, 1024) |
| | |
| | |
| | torch.onnx.export( |
| | proj_layer, |
| | dummy_input, |
| | save_path, |
| | input_names=['text_embedding'], |
| | output_names=['projected'], |
| | dynamic_axes={ |
| | 'text_embedding': {0: 'batch_size'}, |
| | 'projected': {0: 'batch_size'} |
| | }, |
| | opset_version=17 |
| | ) |
| |
|
def export_all(model_path, output_dir):
    """Export every TangoFlux component to ONNX format.

    Args:
        model_path: Path or hub name of the TangoFlux checkpoint.
        output_dir: Directory that receives the .onnx files (created if missing).
    """
    import os

    # Load the full inference pipeline on CPU; export does not need a GPU.
    model = TangoFluxInference(name=model_path, device="cpu")

    os.makedirs(output_dir, exist_ok=True)

    # Encoder and decoder are exported from the same AutoencoderOobleck.
    export_vae_encoder(model.vae, f"{output_dir}/vae_encoder.onnx")
    export_vae_decoder(model.vae, f"{output_dir}/vae_decoder.onnx")

    # NOTE: 'duration_emebdder' (sic) matches the attribute name as spelled
    # in the upstream tangoflux model — do not "correct" it here.
    export_duration_embedder(model.model.duration_emebdder, f"{output_dir}/duration_embedder.onnx")

    export_flux_transformer(model.model.transformer, f"{output_dir}/transformer.onnx")

    export_proj_layer(model.model.fc, f"{output_dir}/proj.onnx")

    print(f"所有模型已导出到: {output_dir}")
| |
|
if __name__ == "__main__":
    import argparse

    # CLI entry point: python <script> --model_path ... --output_dir ...
    arg_parser = argparse.ArgumentParser(description="导出TangoFlux模型到ONNX格式")
    arg_parser.add_argument("--model_path", type=str, required=True, help="TangoFlux模型路径")
    arg_parser.add_argument("--output_dir", type=str, required=True, help="输出目录")
    cli_args = arg_parser.parse_args()

    export_all(cli_args.model_path, cli_args.output_dir)