diff --git a/ThinkSound/__init__.py b/PrismAudio/__init__.py similarity index 100% rename from ThinkSound/__init__.py rename to PrismAudio/__init__.py diff --git a/ThinkSound/configs/model_configs/prismaudio.json b/PrismAudio/configs/model_configs/prismaudio.json similarity index 100% rename from ThinkSound/configs/model_configs/prismaudio.json rename to PrismAudio/configs/model_configs/prismaudio.json diff --git a/ThinkSound/configs/model_configs/stable_audio_2_0_vae.json b/PrismAudio/configs/model_configs/stable_audio_2_0_vae.json similarity index 100% rename from ThinkSound/configs/model_configs/stable_audio_2_0_vae.json rename to PrismAudio/configs/model_configs/stable_audio_2_0_vae.json diff --git a/ThinkSound/configs/model_configs/thinksound.json b/PrismAudio/configs/model_configs/thinksound.json similarity index 100% rename from ThinkSound/configs/model_configs/thinksound.json rename to PrismAudio/configs/model_configs/thinksound.json diff --git a/ThinkSound/configs/multimodal_dataset_demo.json b/PrismAudio/configs/multimodal_dataset_demo.json similarity index 100% rename from ThinkSound/configs/multimodal_dataset_demo.json rename to PrismAudio/configs/multimodal_dataset_demo.json diff --git a/ThinkSound/configs/multimodal_dataset_demo_prismaudio.json b/PrismAudio/configs/multimodal_dataset_demo_prismaudio.json similarity index 59% rename from ThinkSound/configs/multimodal_dataset_demo_prismaudio.json rename to PrismAudio/configs/multimodal_dataset_demo_prismaudio.json index 8a3e2320175dd91bd176d4bded1847a46e865e62..8ebd638d2e9bfe2b7934d7fe1f2cf0d49e9e0a7e 100644 --- a/ThinkSound/configs/multimodal_dataset_demo_prismaudio.json +++ b/PrismAudio/configs/multimodal_dataset_demo_prismaudio.json @@ -3,22 +3,22 @@ "datasets": [ { "id": "vggsound", - "path": "test", - "split_path": "test/test.txt" + "path": "data/train", + "split_path": "split/train.txt" } ], "val_datasets": [ { "id": "vggsound", - "path": "test", - "split_path": "test/test.txt" + "path": "data/test", + "split_path": "split/test.txt" } ], "test_datasets": [ { "id": "vggsound", - "path": "test", - "split_path": "test/test.txt" + "path": "data/test", + "split_path": "split/test.txt" } ], "random_crop": false, diff --git a/ThinkSound/data/__init__.py b/PrismAudio/data/__init__.py similarity index 100% rename from ThinkSound/data/__init__.py rename to PrismAudio/data/__init__.py diff --git a/ThinkSound/data/datamodule.py b/PrismAudio/data/datamodule.py similarity index 100% rename from ThinkSound/data/datamodule.py rename to PrismAudio/data/datamodule.py diff --git a/ThinkSound/data/dataset.py b/PrismAudio/data/dataset.py similarity index 100% rename from ThinkSound/data/dataset.py rename to PrismAudio/data/dataset.py diff --git a/ThinkSound/data/utils.py b/PrismAudio/data/utils.py similarity index 100% rename from ThinkSound/data/utils.py rename to PrismAudio/data/utils.py diff --git a/ThinkSound/inference/__init__.py b/PrismAudio/inference/__init__.py similarity index 100% rename from ThinkSound/inference/__init__.py rename to PrismAudio/inference/__init__.py diff --git a/ThinkSound/inference/generation.py b/PrismAudio/inference/generation.py similarity index 100% rename from ThinkSound/inference/generation.py rename to PrismAudio/inference/generation.py diff --git a/ThinkSound/inference/sampling.py b/PrismAudio/inference/sampling.py similarity index 100% rename from ThinkSound/inference/sampling.py rename to PrismAudio/inference/sampling.py diff --git a/ThinkSound/inference/utils.py b/PrismAudio/inference/utils.py similarity index 100% rename from ThinkSound/inference/utils.py rename to PrismAudio/inference/utils.py diff --git a/ThinkSound/interface/__init__.py b/PrismAudio/interface/__init__.py similarity index 100% rename from ThinkSound/interface/__init__.py rename to PrismAudio/interface/__init__.py diff --git a/ThinkSound/interface/aeiou.py b/PrismAudio/interface/aeiou.py similarity index 100% rename from ThinkSound/interface/aeiou.py rename to PrismAudio/interface/aeiou.py diff --git a/ThinkSound/interface/gradio.py b/PrismAudio/interface/gradio.py similarity index 100% rename from ThinkSound/interface/gradio.py rename to PrismAudio/interface/gradio.py diff --git a/ThinkSound/models/__init__.py b/PrismAudio/models/__init__.py similarity index 100% rename from ThinkSound/models/__init__.py rename to PrismAudio/models/__init__.py diff --git a/ThinkSound/models/adp.py b/PrismAudio/models/adp.py similarity index 100% rename from ThinkSound/models/adp.py rename to PrismAudio/models/adp.py diff --git a/ThinkSound/models/autoencoders.py b/PrismAudio/models/autoencoders.py similarity index 100% rename from ThinkSound/models/autoencoders.py rename to PrismAudio/models/autoencoders.py diff --git a/ThinkSound/models/blocks.py b/PrismAudio/models/blocks.py similarity index 100% rename from ThinkSound/models/blocks.py rename to PrismAudio/models/blocks.py diff --git a/ThinkSound/models/bottleneck.py b/PrismAudio/models/bottleneck.py similarity index 100% rename from ThinkSound/models/bottleneck.py rename to PrismAudio/models/bottleneck.py diff --git a/ThinkSound/models/codebook_patterns.py b/PrismAudio/models/codebook_patterns.py similarity index 100% rename from ThinkSound/models/codebook_patterns.py rename to PrismAudio/models/codebook_patterns.py diff --git a/ThinkSound/models/conditioners.py b/PrismAudio/models/conditioners.py similarity index 100% rename from ThinkSound/models/conditioners.py rename to PrismAudio/models/conditioners.py diff --git a/ThinkSound/models/diffusion.py b/PrismAudio/models/diffusion.py similarity index 100% rename from ThinkSound/models/diffusion.py rename to PrismAudio/models/diffusion.py diff --git a/ThinkSound/models/diffusion_prior.py b/PrismAudio/models/diffusion_prior.py similarity index 100% rename from ThinkSound/models/diffusion_prior.py rename to PrismAudio/models/diffusion_prior.py diff --git a/ThinkSound/models/discriminators.py b/PrismAudio/models/discriminators.py similarity index 100% rename from ThinkSound/models/discriminators.py rename to PrismAudio/models/discriminators.py diff --git a/ThinkSound/models/dit (1).py b/PrismAudio/models/dit (1).py similarity index 100% rename from ThinkSound/models/dit (1).py rename to PrismAudio/models/dit (1).py diff --git a/ThinkSound/models/dit.py b/PrismAudio/models/dit.py similarity index 100% rename from ThinkSound/models/dit.py rename to PrismAudio/models/dit.py diff --git a/ThinkSound/models/factory.py b/PrismAudio/models/factory.py similarity index 100% rename from ThinkSound/models/factory.py rename to PrismAudio/models/factory.py diff --git a/ThinkSound/models/lm.py b/PrismAudio/models/lm.py similarity index 100% rename from ThinkSound/models/lm.py rename to PrismAudio/models/lm.py diff --git a/ThinkSound/models/lm_backbone.py b/PrismAudio/models/lm_backbone.py similarity index 100% rename from ThinkSound/models/lm_backbone.py rename to PrismAudio/models/lm_backbone.py diff --git a/ThinkSound/models/lm_continuous.py b/PrismAudio/models/lm_continuous.py similarity index 100% rename from ThinkSound/models/lm_continuous.py rename to PrismAudio/models/lm_continuous.py diff --git a/ThinkSound/models/local_attention.py b/PrismAudio/models/local_attention.py similarity index 100% rename from ThinkSound/models/local_attention.py rename to PrismAudio/models/local_attention.py diff --git a/ThinkSound/models/meta_queries/__init__.py b/PrismAudio/models/meta_queries/__init__.py similarity index 100% rename from ThinkSound/models/meta_queries/__init__.py rename to PrismAudio/models/meta_queries/__init__.py diff --git a/ThinkSound/models/meta_queries/metaquery.py b/PrismAudio/models/meta_queries/metaquery.py similarity index 100% rename from ThinkSound/models/meta_queries/metaquery.py rename to PrismAudio/models/meta_queries/metaquery.py diff --git a/ThinkSound/models/meta_queries/model.py b/PrismAudio/models/meta_queries/model.py similarity index 100% rename from ThinkSound/models/meta_queries/model.py rename to PrismAudio/models/meta_queries/model.py diff --git a/ThinkSound/models/meta_queries/models/__init__.py b/PrismAudio/models/meta_queries/models/__init__.py similarity index 100% rename from ThinkSound/models/meta_queries/models/__init__.py rename to PrismAudio/models/meta_queries/models/__init__.py diff --git a/ThinkSound/models/meta_queries/models/process_audio_info.py b/PrismAudio/models/meta_queries/models/process_audio_info.py similarity index 100% rename from ThinkSound/models/meta_queries/models/process_audio_info.py rename to PrismAudio/models/meta_queries/models/process_audio_info.py diff --git a/ThinkSound/models/meta_queries/models/qwen25VL.py b/PrismAudio/models/meta_queries/models/qwen25VL.py similarity index 100% rename from ThinkSound/models/meta_queries/models/qwen25VL.py rename to PrismAudio/models/meta_queries/models/qwen25VL.py diff --git a/ThinkSound/models/meta_queries/models/qwen25omni.py b/PrismAudio/models/meta_queries/models/qwen25omni.py similarity index 100% rename from ThinkSound/models/meta_queries/models/qwen25omni.py rename to PrismAudio/models/meta_queries/models/qwen25omni.py diff --git a/ThinkSound/models/meta_queries/transformer_encoder.py b/PrismAudio/models/meta_queries/transformer_encoder.py similarity index 100% rename from ThinkSound/models/meta_queries/transformer_encoder.py rename to PrismAudio/models/meta_queries/transformer_encoder.py diff --git a/ThinkSound/models/mmdit.py b/PrismAudio/models/mmdit.py similarity index 100% rename from ThinkSound/models/mmdit.py rename to PrismAudio/models/mmdit.py diff --git a/ThinkSound/models/mmmodules/__init__.py b/PrismAudio/models/mmmodules/__init__.py similarity index 100% rename from ThinkSound/models/mmmodules/__init__.py rename to PrismAudio/models/mmmodules/__init__.py diff --git a/ThinkSound/models/mmmodules/ext/__init__.py b/PrismAudio/models/mmmodules/ext/__init__.py similarity index 100% rename from ThinkSound/models/mmmodules/ext/__init__.py rename to PrismAudio/models/mmmodules/ext/__init__.py diff --git a/ThinkSound/models/mmmodules/ext/rotary_embeddings.py b/PrismAudio/models/mmmodules/ext/rotary_embeddings.py similarity index 100% rename from ThinkSound/models/mmmodules/ext/rotary_embeddings.py rename to PrismAudio/models/mmmodules/ext/rotary_embeddings.py diff --git a/ThinkSound/models/mmmodules/ext/stft_converter.py b/PrismAudio/models/mmmodules/ext/stft_converter.py similarity index 100% rename from ThinkSound/models/mmmodules/ext/stft_converter.py rename to PrismAudio/models/mmmodules/ext/stft_converter.py diff --git a/ThinkSound/models/mmmodules/ext/stft_converter_mel.py b/PrismAudio/models/mmmodules/ext/stft_converter_mel.py similarity index 100% rename from ThinkSound/models/mmmodules/ext/stft_converter_mel.py rename to PrismAudio/models/mmmodules/ext/stft_converter_mel.py diff --git a/ThinkSound/models/mmmodules/model/__init__.py b/PrismAudio/models/mmmodules/model/__init__.py similarity index 100% rename from ThinkSound/models/mmmodules/model/__init__.py rename to PrismAudio/models/mmmodules/model/__init__.py diff --git a/ThinkSound/models/mmmodules/model/embeddings.py b/PrismAudio/models/mmmodules/model/embeddings.py similarity index 100% rename from ThinkSound/models/mmmodules/model/embeddings.py rename to PrismAudio/models/mmmodules/model/embeddings.py diff --git a/ThinkSound/models/mmmodules/model/flow_matching.py b/PrismAudio/models/mmmodules/model/flow_matching.py similarity index 100% rename from ThinkSound/models/mmmodules/model/flow_matching.py rename to PrismAudio/models/mmmodules/model/flow_matching.py diff --git a/ThinkSound/models/mmmodules/model/low_level.py b/PrismAudio/models/mmmodules/model/low_level.py similarity index 100% rename from ThinkSound/models/mmmodules/model/low_level.py rename to PrismAudio/models/mmmodules/model/low_level.py diff --git a/ThinkSound/models/mmmodules/model/networks.py b/PrismAudio/models/mmmodules/model/networks.py similarity index 100% rename from ThinkSound/models/mmmodules/model/networks.py rename to PrismAudio/models/mmmodules/model/networks.py diff --git a/ThinkSound/models/mmmodules/model/sequence_config.py b/PrismAudio/models/mmmodules/model/sequence_config.py similarity index 100% rename from ThinkSound/models/mmmodules/model/sequence_config.py rename to PrismAudio/models/mmmodules/model/sequence_config.py diff --git a/ThinkSound/models/mmmodules/model/transformer_layers.py b/PrismAudio/models/mmmodules/model/transformer_layers.py similarity index 100% rename from ThinkSound/models/mmmodules/model/transformer_layers.py rename to PrismAudio/models/mmmodules/model/transformer_layers.py diff --git a/ThinkSound/models/mmmodules/runner.py b/PrismAudio/models/mmmodules/runner.py similarity index 100% rename from ThinkSound/models/mmmodules/runner.py rename to PrismAudio/models/mmmodules/runner.py diff --git a/ThinkSound/models/mmmodules/sample.py b/PrismAudio/models/mmmodules/sample.py similarity index 100% rename from ThinkSound/models/mmmodules/sample.py rename to PrismAudio/models/mmmodules/sample.py diff --git a/ThinkSound/models/pqmf.py b/PrismAudio/models/pqmf.py similarity index 100% rename from ThinkSound/models/pqmf.py rename to PrismAudio/models/pqmf.py diff --git a/ThinkSound/models/pretrained.py b/PrismAudio/models/pretrained.py similarity index 100% rename from ThinkSound/models/pretrained.py rename to PrismAudio/models/pretrained.py diff --git a/ThinkSound/models/pretransforms.py b/PrismAudio/models/pretransforms.py similarity index 100% rename from ThinkSound/models/pretransforms.py rename to PrismAudio/models/pretransforms.py diff --git a/ThinkSound/models/transformer (1).py b/PrismAudio/models/transformer (1).py similarity index 100% rename from ThinkSound/models/transformer (1).py rename to PrismAudio/models/transformer (1).py diff --git a/ThinkSound/models/transformer.py b/PrismAudio/models/transformer.py similarity index 100% rename from ThinkSound/models/transformer.py rename to PrismAudio/models/transformer.py diff --git a/ThinkSound/models/utils.py b/PrismAudio/models/utils.py similarity index 100% rename from ThinkSound/models/utils.py rename to PrismAudio/models/utils.py diff --git a/ThinkSound/models/wavelets.py b/PrismAudio/models/wavelets.py similarity index 100% rename from ThinkSound/models/wavelets.py rename to PrismAudio/models/wavelets.py diff --git a/ThinkSound/training/__init__.py b/PrismAudio/training/__init__.py similarity index 100% rename from ThinkSound/training/__init__.py rename to PrismAudio/training/__init__.py diff --git a/ThinkSound/training/autoencoders.py b/PrismAudio/training/autoencoders.py similarity index 100% rename from ThinkSound/training/autoencoders.py rename to PrismAudio/training/autoencoders.py diff --git a/ThinkSound/training/autoencoders_1.py b/PrismAudio/training/autoencoders_1.py similarity index 100% rename from ThinkSound/training/autoencoders_1.py rename to PrismAudio/training/autoencoders_1.py diff --git a/ThinkSound/training/diffusion.py b/PrismAudio/training/diffusion.py similarity index 100% rename from ThinkSound/training/diffusion.py rename to PrismAudio/training/diffusion.py diff --git a/ThinkSound/training/factory.py b/PrismAudio/training/factory.py similarity index 100% rename from ThinkSound/training/factory.py rename to PrismAudio/training/factory.py diff --git a/ThinkSound/training/lm.py b/PrismAudio/training/lm.py similarity index 100% rename from ThinkSound/training/lm.py rename to PrismAudio/training/lm.py diff --git a/ThinkSound/training/lm_continuous.py b/PrismAudio/training/lm_continuous.py similarity index 100% rename from ThinkSound/training/lm_continuous.py rename to PrismAudio/training/lm_continuous.py diff --git a/ThinkSound/training/losses/__init__.py b/PrismAudio/training/losses/__init__.py similarity index 100% rename from ThinkSound/training/losses/__init__.py rename to PrismAudio/training/losses/__init__.py diff --git a/ThinkSound/training/losses/auraloss.py b/PrismAudio/training/losses/auraloss.py similarity index 100% rename from ThinkSound/training/losses/auraloss.py rename to PrismAudio/training/losses/auraloss.py diff --git a/ThinkSound/training/losses/losses.py b/PrismAudio/training/losses/losses.py similarity index 100% rename from ThinkSound/training/losses/losses.py rename to PrismAudio/training/losses/losses.py diff --git a/ThinkSound/training/utils.py b/PrismAudio/training/utils.py similarity index 100% rename from ThinkSound/training/utils.py rename to PrismAudio/training/utils.py diff --git a/app.py b/app.py index 81b9011b066a970b993bc713d45565e3a18ee503..0212220cc4097872ec9dc036110e57b3bd4f6ad9 100644 --- a/app.py +++ b/app.py @@ -4,9 +4,6 @@ import subprocess import sys subprocess.run(["bash", "setup.sh"], check=True) - -os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".gradio_tmp") -os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True) os.environ["JAX_PLATFORMS"] = "cpu" import gradio as gr import logging @@ -52,10 +49,10 @@ SAMPLE_RATE = 44100 from huggingface_hub import snapshot_download snapshot_download(repo_id="FunAudioLLM/PrismAudio", local_dir="./ckpts") -MODEL_CONFIG_PATH = "ThinkSound/configs/model_configs/prismaudio.json" +MODEL_CONFIG_PATH = "PrismAudio/configs/model_configs/prismaudio.json" CKPT_PATH = "ckpts/prismaudio.ckpt" VAE_CKPT_PATH = "ckpts/vae.ckpt" -VAE_CONFIG_PATH = "ThinkSound/configs/model_configs/stable_audio_2_0_vae.json" +VAE_CONFIG_PATH = "PrismAudio/configs/model_configs/stable_audio_2_0_vae.json" SYNCHFORMER_CKPT_PATH = "ckpts/synchformer_state_dict.pth" DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu' @@ -99,8 +96,8 @@ def load_all_models(): log.info("βœ… FeaturesUtils loaded") # ---- 3. Diffusion model ---- - from ThinkSound.models import create_model_from_config - from ThinkSound.models.utils import load_ckpt_state_dict + from PrismAudio.models import create_model_from_config + from PrismAudio.models.utils import load_ckpt_state_dict with open(MODEL_CONFIG_PATH) as f: model_config = json.load(f) @@ -288,7 +285,7 @@ def build_meta(info: dict, duration: float, caption: str): def run_diffusion(audio_latent: torch.Tensor, meta: dict, duration: float) -> torch.Tensor: """Reuses globally loaded diffusion model β€” no reload per call.""" - from ThinkSound.inference.sampling import sample, sample_discrete_euler + from PrismAudio.inference.sampling import sample, sample_discrete_euler import time diffusion = _MODELS["diffusion"] @@ -374,7 +371,7 @@ def generate_audio(video_file, caption: str): return "\n".join(logs) # ---- Working directory (auto-cleaned on exit) ---- - work_dir = tempfile.mkdtemp(dir=os.environ["GRADIO_TEMP_DIR"], prefix="thinksound_") + work_dir = tempfile.mkdtemp(dir=os.environ["GRADIO_TEMP_DIR"], prefix="PrismAudio_") try: # ---- Step 1: Convert / copy to mp4 ---- @@ -476,7 +473,7 @@ def generate_audio(video_file, caption: str): def build_ui() -> gr.Blocks: with gr.Blocks( - title="ThinkSound - Video to Audio Generation", + title="PrismAudio - Video to Audio Generation", theme=gr.themes.Soft(), css=""" .title { text-align:center; font-size:2em; font-weight:bold; margin-bottom:.2em; } @@ -485,7 +482,7 @@ def build_ui() -> gr.Blocks: """, ) as demo: - gr.HTML('
🎡 ThinkSound
') + gr.HTML('
🎡 PrismAudio
') gr.HTML( '
' 'Upload a video and a text prompt β€” ' @@ -543,13 +540,8 @@ def build_ui() -> gr.Blocks: height=400, ) - # ====================================================== - # Example prompts - # ====================================================== - gr.Markdown("---") - gr.Markdown("### πŸ’‘ Example Prompts (click to fill)") - - # ====================================================== + + # ====================================================== # Instructions # ====================================================== with gr.Accordion("πŸ“– Instructions", open=False): @@ -602,7 +594,7 @@ SYNCHFORMER_CKPT_PATH = {SYNCHFORMER_CKPT_PATH} if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description="ThinkSound Gradio App") + parser = argparse.ArgumentParser(description="PrismAudio Gradio App") parser.add_argument("--server_name", type=str, default="0.0.0.0", help="Gradio server host") parser.add_argument("--server_port", type=int, default=7860, diff --git a/data_utils/extract_training_audio.py b/data_utils/extract_training_audio.py index ed6e4a23fe64b844b5022bda921621e88fd91a66..a74c047f3aec7f76965001726c9de70e0c829959 100644 --- a/data_utils/extract_training_audio.py +++ b/data_utils/extract_training_audio.py @@ -124,7 +124,7 @@ if __name__ == '__main__': parser.add_argument('--duration_sec', type=float, default=9.0, help='Duration of the audio in seconds') parser.add_argument('--audio_samples', type=int, default=397312, help='Number of audio samples') parser.add_argument('--vae_ckpt', type=str, default='ckpts/vae.ckpt', help='Path to the VAE checkpoint') - parser.add_argument('--vae_config', type=str, default='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file') + parser.add_argument('--vae_config', type=str, default='PrismAudio/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file') parser.add_argument('--synchformer_ckpt', type=str, default='ckpts/synchformer_state_dict.pth', help='Path to the Synchformer checkpoint') parser.add_argument('--start-row', type=int, default=0, help='start row') parser.add_argument('--end-row', type=int, default=None, help='end row') diff --git a/data_utils/extract_training_video.py b/data_utils/extract_training_video.py index 92ea01c61e08d49bd4572d28ce5d30d4e5283b19..53086afcab5f34e9ef446a36707a744b2d99038c 100644 --- a/data_utils/extract_training_video.py +++ b/data_utils/extract_training_video.py @@ -129,7 +129,7 @@ if __name__ == '__main__': parser.add_argument('--duration_sec', type=float, default=9.0, help='Duration of the audio in seconds') parser.add_argument('--audio_samples', type=int, default=397312, help='Number of audio samples') parser.add_argument('--vae_ckpt', type=str, default='ckpts/vae.ckpt', help='Path to the VAE checkpoint') - parser.add_argument('--vae_config', type=str, default='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file') + parser.add_argument('--vae_config', type=str, default='PrismAudio/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file') parser.add_argument('--synchformer_ckpt', type=str, default='ckpts/synchformer_state_dict.pth', help='Path to the Synchformer checkpoint') parser.add_argument('--start-row', type=int, default=0, help='start row') parser.add_argument('--end-row', type=int, default=None, help='end row') diff --git a/data_utils/prismaudio_data_process.py b/data_utils/prismaudio_data_process.py index c4491e3620bac5801883049f34ccacb43a028ab3..f40797efd3a540024d7320ce99541cad9d0ec27f 100644 --- a/data_utils/prismaudio_data_process.py +++ b/data_utils/prismaudio_data_process.py @@ -137,7 +137,7 @@ if __name__ == '__main__': parser.add_argument('--save-dir', default='results') parser.add_argument('--sample_rate', type=int, default=44100, help='Audio sample rate') parser.add_argument('--vae_ckpt', type=str, default='ckpts/vae.ckpt', help='Path to the VAE checkpoint') - parser.add_argument('--vae_config', type=str, default='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file') + parser.add_argument('--vae_config', type=str, default='PrismAudio/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file') parser.add_argument('--synchformer_ckpt', type=str, default='ckpts/synchformer_state_dict.pth', help='Path to the Synchformer checkpoint') parser.add_argument('--start-row', '-s', type=int, default=0, help='Start row index') parser.add_argument('--end-row', '-e', type=int, default=None, help='End row index') diff --git a/data_utils/v2a_utils/feature_utils_224.py b/data_utils/v2a_utils/feature_utils_224.py index 3bcb012470a9ebbdbd34e8e37a920aa359d44b40..8f84a6cd54a5fee8701eaa8ba203e3c281817c03 100644 --- a/data_utils/v2a_utils/feature_utils_224.py +++ b/data_utils/v2a_utils/feature_utils_224.py @@ -7,9 +7,9 @@ import torch.nn.functional as F from einops import rearrange from open_clip import create_model_from_pretrained from torchvision.transforms import Normalize -from ThinkSound.models.factory import create_model_from_config -from ThinkSound.models.utils import load_ckpt_state_dict -from ThinkSound.models.utils import copy_state_dict +from PrismAudio.models.factory import create_model_from_config +from PrismAudio.models.utils import load_ckpt_state_dict +from PrismAudio.models.utils import copy_state_dict from transformers import AutoModel from transformers import AutoProcessor from transformers import T5EncoderModel, AutoTokenizer diff --git a/data_utils/v2a_utils/feature_utils_224_audio.py b/data_utils/v2a_utils/feature_utils_224_audio.py index b869edf4620f2ae6c1b928d3bf844292b3dcba45..ffda01ff723731629bae66973daf558a3f44efb6 100644 --- a/data_utils/v2a_utils/feature_utils_224_audio.py +++ b/data_utils/v2a_utils/feature_utils_224_audio.py @@ -7,9 +7,9 @@ import torch.nn.functional as F from einops import rearrange # from open_clip import create_model_from_pretrained from torchvision.transforms import Normalize -from ThinkSound.models.factory import create_model_from_config -from ThinkSound.models.utils import load_ckpt_state_dict -from ThinkSound.training.utils import copy_state_dict +from PrismAudio.models.factory import create_model_from_config +from PrismAudio.models.utils import load_ckpt_state_dict +from PrismAudio.training.utils import copy_state_dict from transformers import AutoModel from transformers import AutoProcessor from transformers import T5EncoderModel, AutoTokenizer diff --git a/data_utils/v2a_utils/feature_utils_288.py b/data_utils/v2a_utils/feature_utils_288.py index 9b12e224087cc0afad7dccae78bc9c96fbcf5260..239099f99463e7c55fe7185b763c4b0d3bd80f7b 100644 --- a/data_utils/v2a_utils/feature_utils_288.py +++ b/data_utils/v2a_utils/feature_utils_288.py @@ -5,8 +5,8 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange from torchvision.transforms import Normalize -from ThinkSound.models.factory import create_model_from_config -from ThinkSound.models.utils import load_ckpt_state_dict +from PrismAudio.models.factory import create_model_from_config +from PrismAudio.models.utils import load_ckpt_state_dict import einshape import sys import os diff --git a/eval_batch.py b/eval_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..047a147e41fba8c5eb6a3d61848c420caff79a22 --- /dev/null +++ b/eval_batch.py @@ -0,0 +1,136 @@ +from prefigure.prefigure import get_all_args, push_wandb_config +import json +import os +import re +import torch +import torchaudio +from lightning.pytorch import seed_everything +import random +from datetime import datetime +import numpy as np + +from PrismAudio.data.datamodule import DataModule +from PrismAudio.models import create_model_from_config +from PrismAudio.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model +from PrismAudio.inference.sampling import sample, sample_discrete_euler +from pathlib import Path +from tqdm import tqdm + + +def predict_step(diffusion, batch, diffusion_objective, device='cuda:0'): + diffusion = diffusion.to(device) + + reals, metadata = batch + ids = [item['id'] for item in metadata] + batch_size, length = reals.shape[0], reals.shape[2] + with torch.amp.autocast('cuda'): + conditioning = diffusion.conditioner(metadata, device) + + video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0) + conditioning['metaclip_features'][~video_exist] = diffusion.model.model.empty_clip_feat + conditioning['sync_features'][~video_exist] = diffusion.model.model.empty_sync_feat + + cond_inputs = diffusion.get_conditioning_inputs(conditioning) + if batch_size > 1: + noise_list = [] + for _ in range(batch_size): + noise_1 = torch.randn([1, diffusion.io_channels, length]).to(device) # ζ―ζ¬‘η”ŸζˆζŽ¨θΏ›RNGηŠΆζ€ + noise_list.append(noise_1) + noise = torch.cat(noise_list, dim=0) + else: + noise = torch.randn([batch_size, diffusion.io_channels, length]).to(device) + + with torch.amp.autocast('cuda'): + + model = diffusion.model + if diffusion_objective == "v": + fakes = sample(model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True) + elif diffusion_objective == "rectified_flow": + fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True) + if diffusion.pretransform is not None: + fakes = diffusion.pretransform.decode(fakes) + + audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + return audios + + +def main(): + args = get_all_args() + + if args.save_dir == '': + args.save_dir = args.results_dir + + + seed = args.seed + if os.environ.get("SLURM_PROCID") is not None: + seed += int(os.environ.get("SLURM_PROCID")) + seed_everything(seed, workers=True) + + # Load config + if args.model_config == '': + args.model_config = "PrismAudio/configs/model_configs/thinksound.json" + with open(args.model_config) as f: + model_config = json.load(f) + + duration = float(args.duration_sec) + sample_rate = model_config["sample_rate"] + latent_length = round(44100 / 64 / 32 * duration) + + model_config["sample_size"] = duration * sample_rate + model_config["model"]["diffusion"]["config"]["sync_seq_len"] = 24 * int(duration) + model_config["model"]["diffusion"]["config"]["clip_seq_len"] = 8 * int(duration) + model_config["model"]["diffusion"]["config"]["latent_seq_len"] = latent_length + + model = create_model_from_config(model_config) + if args.compile: + model = torch.compile(model) + + model.load_state_dict(torch.load(args.ckpt_dir)) + vae_state = load_ckpt_state_dict(args.pretransform_ckpt_path, prefix='autoencoder.') + model.pretransform.load_state_dict(vae_state) + + + if args.dataset_config == '': + args.dataset_config = "PrismAudio/configs/multimodal_dataset_demo.json" + with open(args.dataset_config) as f: + dataset_config = json.load(f) + + for td in dataset_config["test_datasets"]: + td["path"] = args.results_dir + + dm = DataModule( + dataset_config, + batch_size=args.batch_size, + test_batch_size=args.test_batch_size, + num_workers=args.num_workers, + sample_rate=model_config["sample_rate"], + sample_size=(float)(args.duration_sec) * model_config["sample_rate"], + audio_channels=model_config.get("audio_channels", 2), + latent_length=round(44100/64/32*duration), + ) + dm.setup('predict') + dl = dm.predict_dataloader() + + current_date = datetime.now() + formatted_date = current_date.strftime('%m%d') + + audio_dir = os.path.join(args.save_dir,f'{formatted_date}_batch_size'+str(args.test_batch_size)) + os.makedirs(audio_dir,exist_ok=True) + + for batch in tqdm(dl, desc="Predicting"): + audio = predict_step( + model, + batch=batch, + diffusion_objective=model_config["model"]["diffusion"]["diffusion_objective"], + device='cuda:0' + ) + + _, metadata = batch + ids = [item['id'] for item in metadata] + + for i in range(audio.size(0)): + id_str = ids[i] if i < len(ids) else f"unknown_{i}" + torchaudio.save(os.path.join(audio_dir, f"{id_str}.wav"), audio[i], 44100) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/predict.py b/predict.py index 44fbe7276fd289c69f403bab1270adbdad0d1443..7062b5a92c969d1ef6824972b944dec767f5df48 100644 --- a/predict.py +++ b/predict.py @@ -8,9 +8,9 @@ from lightning.pytorch import seed_everything import random from datetime import datetime import numpy as np -from ThinkSound.models import create_model_from_config -from ThinkSound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model -from ThinkSound.inference.sampling import sample, sample_discrete_euler +from PrismAudio.models import create_model_from_config +from PrismAudio.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model +from PrismAudio.inference.sampling import sample, sample_discrete_euler from pathlib import Path @@ -114,7 +114,7 @@ def main(): #Get JSON config from args.model_config if args.model_config == '': - args.model_config = "ThinkSound/configs/model_configs/thinksound.json" + args.model_config = "PrismAudio/configs/model_configs/thinksound.json" with open(args.model_config) as f: model_config = json.load(f) diff --git a/scripts/PrismAudio/grpo_1node4gpus.sh b/scripts/PrismAudio/grpo_1node4gpus.sh index e360e3c449bd3c413bd04ad54d411378731a282d..7f07c22df35e0ba88fda8b144fdc7782ce99ee88 100644 --- a/scripts/PrismAudio/grpo_1node4gpus.sh +++ b/scripts/PrismAudio/grpo_1node4gpus.sh @@ -1,4 +1,5 @@ + export NCCL_IB_DISABLE=1 export NCCL_IB_HCA=mlx5 export NCCL_DEBUG=WARN diff --git a/set_up.sh b/scripts/PrismAudio/setup/build_env.sh similarity index 100% rename from set_up.sh rename to scripts/PrismAudio/setup/build_env.sh diff --git a/scripts/ThinkSound/demo.sh b/scripts/ThinkSound/demo.sh index cbc681b98a866d69d34db5c227c406dbd80b721f..96b4f724da1c6ee0bed1ab1ba47c185580a56674 100755 --- a/scripts/ThinkSound/demo.sh +++ b/scripts/ThinkSound/demo.sh @@ -66,7 +66,7 @@ fi echo "⏳ Running model inference..." python predict.py \ --model-config "$model_config" \ - --duration-sec "$DURATION_SEC" \ + --duration-sec "$DURATION" \ --results-dir "results"\ if [ $? -ne 0 ]; then diff --git a/setup.sh b/setup.sh index d328f1a822a81e305ba1647909a0774c93ee3118..0500c615734a4e68c7c60407f363393bcc92fa4f 100644 --- a/setup.sh +++ b/setup.sh @@ -5,5 +5,3 @@ cd .. pip install -r scripts/PrismAudio/setup/requirements.txt pip install tensorflow-cpu==2.15.0 pip install facenet_pytorch==2.6.0 --no-deps - -conda install -y -c conda-forge 'ffmpeg<7' diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..74f90e5546c286347c0d9f2346c1ae61b485cf5e --- /dev/null +++ b/train.py @@ -0,0 +1,193 @@ +from prefigure.prefigure import get_all_args, push_wandb_config +import json +import os +import torch +import torchaudio +# import pytorch_lightning as pl +import lightning as L +from lightning.pytorch.callbacks import Timer, ModelCheckpoint, BasePredictionWriter +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.tuner import Tuner +from lightning.pytorch import seed_everything +import random +from datetime import datetime +# from PrismAudio.data.dataset import create_dataloader_from_config +from PrismAudio.data.datamodule import DataModule +from PrismAudio.models import create_model_from_config +from PrismAudio.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model +from PrismAudio.training import create_training_wrapper_from_config, create_demo_callback_from_config +from PrismAudio.training.utils import copy_state_dict + +class ExceptionCallback(Callback): + def on_exception(self, trainer, module, err): + print(f'{type(err).__name__}: {err}') + +class ModelConfigEmbedderCallback(Callback): + def __init__(self, model_config): + self.model_config = model_config + + def on_save_checkpoint(self, trainer, pl_module, checkpoint): + checkpoint["model_config"] = self.model_config + +class CustomWriter(BasePredictionWriter): + + def __init__(self, output_dir, write_interval='batch'): + super().__init__(write_interval) + self.output_dir = output_dir + + def write_on_batch_end(self, trainer, pl_module, predictions, batch_indices, batch, batch_idx, dataloader_idx): + + audios = predictions + ids = [item['id'] for item in batch[1]] + # θŽ·ε–ε½“ε‰ζ—₯期 + current_date = datetime.now() + + # ζ ΌεΌεŒ–ζ—₯期为 'MMDD' 归式 + formatted_date = current_date.strftime('%m%d') + if trainer.ckpt_path is None: + global_step = pl_module.global_step // 1000 + else: + global_step = int(trainer.ckpt_path.split("-step=")[-1].split(".")[0]) // 1000 + os.makedirs(os.path.join(self.output_dir, f'{formatted_date}_step{global_step}k'),exist_ok=True) + for audio, id in zip(audios, ids): + save_path = os.path.join(self.output_dir, f'{formatted_date}_step{global_step}k', f'{id}.wav') + torchaudio.save(save_path, audio, 44100) + +def main(): + + args = get_all_args() + + seed = args.seed + + # Set a different seed for each process if using SLURM + if os.environ.get("SLURM_PROCID") is not None: + seed += int(os.environ.get("SLURM_PROCID")) + + # random.seed(seed) + # torch.manual_seed(seed) + seed_everything(seed, workers=True) + print('########################') + print(f'precision is {args.precision}') + print('########################') + #Get JSON config from args.model_config + with open(args.model_config) as f: + model_config = json.load(f) + + with open(args.dataset_config) as f: + dataset_config = json.load(f) + + # train_dl = create_dataloader_from_config( + # dataset_config, + # batch_size=args.batch_size, + # num_workers=args.num_workers, + # sample_rate=model_config["sample_rate"], + # sample_size=model_config["sample_size"], + # audio_channels=model_config.get("audio_channels", 2), + # ) + dm = DataModule( + dataset_config, + batch_size=args.batch_size, + test_batch_size=args.test_batch_size, + num_workers=args.num_workers, + sample_rate=model_config["sample_rate"], + sample_size=model_config["sample_size"], + audio_channels=model_config.get("audio_channels", 2), + repeat_num=args.repeat_num + ) + + model = create_model_from_config(model_config) + + ## speed by torch.compile + if args.compile: + model = torch.compile(model) + + if args.pretrained_ckpt_path: + copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path,prefix='diffusion.')) # autoencoder. diffusion. + + if args.remove_pretransform_weight_norm == "pre_load": + remove_weight_norm_from_model(model.pretransform) + # import ipdb + # ipdb.set_trace() + if args.pretransform_ckpt_path: + load_vae_state = load_ckpt_state_dict(args.pretransform_ckpt_path, prefix='autoencoder.') + # new_state_dict = {k.replace("autoencoder.", ""): v for k, v in load_vae_state.items() if k.startswith("autoencoder.")} + model.pretransform.load_state_dict(load_vae_state) + + # Remove weight_norm from the pretransform if specified + if args.remove_pretransform_weight_norm == "post_load": + remove_weight_norm_from_model(model.pretransform) + + training_wrapper = create_training_wrapper_from_config(model_config, model) + + wandb_logger = L.pytorch.loggers.WandbLogger(project=args.name) + wandb_logger.watch(training_wrapper) + + exc_callback = ExceptionCallback() + + if args.save_dir and isinstance(wandb_logger.experiment.id, str): + checkpoint_dir = os.path.join(args.save_dir, wandb_logger.experiment.project, wandb_logger.experiment.id, "checkpoints") + else: + checkpoint_dir = None + + # ckpt_callback = ModelCheckpoint(every_n_train_steps=args.checkpoint_every, dirpath=checkpoint_dir, monitor='val_loss', mode='min', save_top_k=14) + ckpt_callback = ModelCheckpoint(every_n_train_steps=args.checkpoint_every, dirpath=checkpoint_dir, monitor='epoch', mode='max', save_top_k=14) + save_model_config_callback = ModelConfigEmbedderCallback(model_config) + # audio_dir = os.path.join(args.save_dir, args.name, "audios") + # pred_writer = CustomWriter(output_dir=audio_dir, write_interval="batch") + timer = Timer(duration="00:16:00:00") + demo_callback = create_demo_callback_from_config(model_config, demo_dl=dm) + + #Combine args and config dicts + args_dict = vars(args) + args_dict.update({"model_config": model_config}) + args_dict.update({"dataset_config": dataset_config}) + push_wandb_config(wandb_logger, args_dict) + + #Set multi-GPU strategy if specified + if args.strategy: + if args.strategy == "deepspeed": + from pytorch_lightning.strategies import DeepSpeedStrategy + strategy = DeepSpeedStrategy(stage=2, + contiguous_gradients=True, + overlap_comm=True, + reduce_scatter=True, + reduce_bucket_size=5e8, + allgather_bucket_size=5e8, + load_full_weights=True + ) + else: + strategy = args.strategy + else: + strategy = 'ddp_find_unused_parameters_true' if args.num_gpus > 1 else "auto" + + trainer = L.Trainer( + devices=args.num_gpus, + accelerator="gpu", + num_nodes = args.num_nodes, + strategy=strategy, + precision=args.precision, + accumulate_grad_batches=args.accum_batches, + callbacks=[ckpt_callback, demo_callback, exc_callback, save_model_config_callback, timer], + logger=wandb_logger, + log_every_n_steps=1, + max_epochs=90, + default_root_dir=args.save_dir, + gradient_clip_val=args.gradient_clip_val, + reload_dataloaders_every_n_epochs = 0, + check_val_every_n_epoch=2, + ) + + # query training/validation/test time (in seconds) + # timer.time_elapsed("train") + # timer.start_time("validate") + # tuner = Tuner(trainer) + # Auto-scale batch size by growing it exponentially (default) + # tuner.scale_batch_size(training_wrapper, mode="power") + # tuner.lr_find(training_wrapper) + # trainer.tune(training_wrapper, train_dl, ckpt_path=args.ckpt_path if args.ckpt_path else None) + # trainer.validate(training_wrapper, dm) + trainer.fit(training_wrapper, dm, ckpt_path=args.ckpt_path if args.ckpt_path else None) + # trainer.predict(training_wrapper, dm, return_predictions=False) + +if __name__ == '__main__': + main() \ No newline at end of file