diff --git a/ThinkSound/__init__.py b/PrismAudio/__init__.py
similarity index 100%
rename from ThinkSound/__init__.py
rename to PrismAudio/__init__.py
diff --git a/ThinkSound/configs/model_configs/prismaudio.json b/PrismAudio/configs/model_configs/prismaudio.json
similarity index 100%
rename from ThinkSound/configs/model_configs/prismaudio.json
rename to PrismAudio/configs/model_configs/prismaudio.json
diff --git a/ThinkSound/configs/model_configs/stable_audio_2_0_vae.json b/PrismAudio/configs/model_configs/stable_audio_2_0_vae.json
similarity index 100%
rename from ThinkSound/configs/model_configs/stable_audio_2_0_vae.json
rename to PrismAudio/configs/model_configs/stable_audio_2_0_vae.json
diff --git a/ThinkSound/configs/model_configs/thinksound.json b/PrismAudio/configs/model_configs/thinksound.json
similarity index 100%
rename from ThinkSound/configs/model_configs/thinksound.json
rename to PrismAudio/configs/model_configs/thinksound.json
diff --git a/ThinkSound/configs/multimodal_dataset_demo.json b/PrismAudio/configs/multimodal_dataset_demo.json
similarity index 100%
rename from ThinkSound/configs/multimodal_dataset_demo.json
rename to PrismAudio/configs/multimodal_dataset_demo.json
diff --git a/ThinkSound/configs/multimodal_dataset_demo_prismaudio.json b/PrismAudio/configs/multimodal_dataset_demo_prismaudio.json
similarity index 59%
rename from ThinkSound/configs/multimodal_dataset_demo_prismaudio.json
rename to PrismAudio/configs/multimodal_dataset_demo_prismaudio.json
index 8a3e2320175dd91bd176d4bded1847a46e865e62..8ebd638d2e9bfe2b7934d7fe1f2cf0d49e9e0a7e 100644
--- a/ThinkSound/configs/multimodal_dataset_demo_prismaudio.json
+++ b/PrismAudio/configs/multimodal_dataset_demo_prismaudio.json
@@ -3,22 +3,22 @@
"datasets": [
{
"id": "vggsound",
- "path": "test",
- "split_path": "test/test.txt"
+ "path": "data/train",
+ "split_path": "split/train.txt"
}
],
"val_datasets": [
{
"id": "vggsound",
- "path": "test",
- "split_path": "test/test.txt"
+ "path": "data/test",
+ "split_path": "split/test.txt"
}
],
"test_datasets": [
{
"id": "vggsound",
- "path": "test",
- "split_path": "test/test.txt"
+ "path": "data/test",
+ "split_path": "split/test.txt"
}
],
"random_crop": false,
diff --git a/ThinkSound/data/__init__.py b/PrismAudio/data/__init__.py
similarity index 100%
rename from ThinkSound/data/__init__.py
rename to PrismAudio/data/__init__.py
diff --git a/ThinkSound/data/datamodule.py b/PrismAudio/data/datamodule.py
similarity index 100%
rename from ThinkSound/data/datamodule.py
rename to PrismAudio/data/datamodule.py
diff --git a/ThinkSound/data/dataset.py b/PrismAudio/data/dataset.py
similarity index 100%
rename from ThinkSound/data/dataset.py
rename to PrismAudio/data/dataset.py
diff --git a/ThinkSound/data/utils.py b/PrismAudio/data/utils.py
similarity index 100%
rename from ThinkSound/data/utils.py
rename to PrismAudio/data/utils.py
diff --git a/ThinkSound/inference/__init__.py b/PrismAudio/inference/__init__.py
similarity index 100%
rename from ThinkSound/inference/__init__.py
rename to PrismAudio/inference/__init__.py
diff --git a/ThinkSound/inference/generation.py b/PrismAudio/inference/generation.py
similarity index 100%
rename from ThinkSound/inference/generation.py
rename to PrismAudio/inference/generation.py
diff --git a/ThinkSound/inference/sampling.py b/PrismAudio/inference/sampling.py
similarity index 100%
rename from ThinkSound/inference/sampling.py
rename to PrismAudio/inference/sampling.py
diff --git a/ThinkSound/inference/utils.py b/PrismAudio/inference/utils.py
similarity index 100%
rename from ThinkSound/inference/utils.py
rename to PrismAudio/inference/utils.py
diff --git a/ThinkSound/interface/__init__.py b/PrismAudio/interface/__init__.py
similarity index 100%
rename from ThinkSound/interface/__init__.py
rename to PrismAudio/interface/__init__.py
diff --git a/ThinkSound/interface/aeiou.py b/PrismAudio/interface/aeiou.py
similarity index 100%
rename from ThinkSound/interface/aeiou.py
rename to PrismAudio/interface/aeiou.py
diff --git a/ThinkSound/interface/gradio.py b/PrismAudio/interface/gradio.py
similarity index 100%
rename from ThinkSound/interface/gradio.py
rename to PrismAudio/interface/gradio.py
diff --git a/ThinkSound/models/__init__.py b/PrismAudio/models/__init__.py
similarity index 100%
rename from ThinkSound/models/__init__.py
rename to PrismAudio/models/__init__.py
diff --git a/ThinkSound/models/adp.py b/PrismAudio/models/adp.py
similarity index 100%
rename from ThinkSound/models/adp.py
rename to PrismAudio/models/adp.py
diff --git a/ThinkSound/models/autoencoders.py b/PrismAudio/models/autoencoders.py
similarity index 100%
rename from ThinkSound/models/autoencoders.py
rename to PrismAudio/models/autoencoders.py
diff --git a/ThinkSound/models/blocks.py b/PrismAudio/models/blocks.py
similarity index 100%
rename from ThinkSound/models/blocks.py
rename to PrismAudio/models/blocks.py
diff --git a/ThinkSound/models/bottleneck.py b/PrismAudio/models/bottleneck.py
similarity index 100%
rename from ThinkSound/models/bottleneck.py
rename to PrismAudio/models/bottleneck.py
diff --git a/ThinkSound/models/codebook_patterns.py b/PrismAudio/models/codebook_patterns.py
similarity index 100%
rename from ThinkSound/models/codebook_patterns.py
rename to PrismAudio/models/codebook_patterns.py
diff --git a/ThinkSound/models/conditioners.py b/PrismAudio/models/conditioners.py
similarity index 100%
rename from ThinkSound/models/conditioners.py
rename to PrismAudio/models/conditioners.py
diff --git a/ThinkSound/models/diffusion.py b/PrismAudio/models/diffusion.py
similarity index 100%
rename from ThinkSound/models/diffusion.py
rename to PrismAudio/models/diffusion.py
diff --git a/ThinkSound/models/diffusion_prior.py b/PrismAudio/models/diffusion_prior.py
similarity index 100%
rename from ThinkSound/models/diffusion_prior.py
rename to PrismAudio/models/diffusion_prior.py
diff --git a/ThinkSound/models/discriminators.py b/PrismAudio/models/discriminators.py
similarity index 100%
rename from ThinkSound/models/discriminators.py
rename to PrismAudio/models/discriminators.py
diff --git a/ThinkSound/models/dit (1).py b/PrismAudio/models/dit (1).py
similarity index 100%
rename from ThinkSound/models/dit (1).py
rename to PrismAudio/models/dit (1).py
diff --git a/ThinkSound/models/dit.py b/PrismAudio/models/dit.py
similarity index 100%
rename from ThinkSound/models/dit.py
rename to PrismAudio/models/dit.py
diff --git a/ThinkSound/models/factory.py b/PrismAudio/models/factory.py
similarity index 100%
rename from ThinkSound/models/factory.py
rename to PrismAudio/models/factory.py
diff --git a/ThinkSound/models/lm.py b/PrismAudio/models/lm.py
similarity index 100%
rename from ThinkSound/models/lm.py
rename to PrismAudio/models/lm.py
diff --git a/ThinkSound/models/lm_backbone.py b/PrismAudio/models/lm_backbone.py
similarity index 100%
rename from ThinkSound/models/lm_backbone.py
rename to PrismAudio/models/lm_backbone.py
diff --git a/ThinkSound/models/lm_continuous.py b/PrismAudio/models/lm_continuous.py
similarity index 100%
rename from ThinkSound/models/lm_continuous.py
rename to PrismAudio/models/lm_continuous.py
diff --git a/ThinkSound/models/local_attention.py b/PrismAudio/models/local_attention.py
similarity index 100%
rename from ThinkSound/models/local_attention.py
rename to PrismAudio/models/local_attention.py
diff --git a/ThinkSound/models/meta_queries/__init__.py b/PrismAudio/models/meta_queries/__init__.py
similarity index 100%
rename from ThinkSound/models/meta_queries/__init__.py
rename to PrismAudio/models/meta_queries/__init__.py
diff --git a/ThinkSound/models/meta_queries/metaquery.py b/PrismAudio/models/meta_queries/metaquery.py
similarity index 100%
rename from ThinkSound/models/meta_queries/metaquery.py
rename to PrismAudio/models/meta_queries/metaquery.py
diff --git a/ThinkSound/models/meta_queries/model.py b/PrismAudio/models/meta_queries/model.py
similarity index 100%
rename from ThinkSound/models/meta_queries/model.py
rename to PrismAudio/models/meta_queries/model.py
diff --git a/ThinkSound/models/meta_queries/models/__init__.py b/PrismAudio/models/meta_queries/models/__init__.py
similarity index 100%
rename from ThinkSound/models/meta_queries/models/__init__.py
rename to PrismAudio/models/meta_queries/models/__init__.py
diff --git a/ThinkSound/models/meta_queries/models/process_audio_info.py b/PrismAudio/models/meta_queries/models/process_audio_info.py
similarity index 100%
rename from ThinkSound/models/meta_queries/models/process_audio_info.py
rename to PrismAudio/models/meta_queries/models/process_audio_info.py
diff --git a/ThinkSound/models/meta_queries/models/qwen25VL.py b/PrismAudio/models/meta_queries/models/qwen25VL.py
similarity index 100%
rename from ThinkSound/models/meta_queries/models/qwen25VL.py
rename to PrismAudio/models/meta_queries/models/qwen25VL.py
diff --git a/ThinkSound/models/meta_queries/models/qwen25omni.py b/PrismAudio/models/meta_queries/models/qwen25omni.py
similarity index 100%
rename from ThinkSound/models/meta_queries/models/qwen25omni.py
rename to PrismAudio/models/meta_queries/models/qwen25omni.py
diff --git a/ThinkSound/models/meta_queries/transformer_encoder.py b/PrismAudio/models/meta_queries/transformer_encoder.py
similarity index 100%
rename from ThinkSound/models/meta_queries/transformer_encoder.py
rename to PrismAudio/models/meta_queries/transformer_encoder.py
diff --git a/ThinkSound/models/mmdit.py b/PrismAudio/models/mmdit.py
similarity index 100%
rename from ThinkSound/models/mmdit.py
rename to PrismAudio/models/mmdit.py
diff --git a/ThinkSound/models/mmmodules/__init__.py b/PrismAudio/models/mmmodules/__init__.py
similarity index 100%
rename from ThinkSound/models/mmmodules/__init__.py
rename to PrismAudio/models/mmmodules/__init__.py
diff --git a/ThinkSound/models/mmmodules/ext/__init__.py b/PrismAudio/models/mmmodules/ext/__init__.py
similarity index 100%
rename from ThinkSound/models/mmmodules/ext/__init__.py
rename to PrismAudio/models/mmmodules/ext/__init__.py
diff --git a/ThinkSound/models/mmmodules/ext/rotary_embeddings.py b/PrismAudio/models/mmmodules/ext/rotary_embeddings.py
similarity index 100%
rename from ThinkSound/models/mmmodules/ext/rotary_embeddings.py
rename to PrismAudio/models/mmmodules/ext/rotary_embeddings.py
diff --git a/ThinkSound/models/mmmodules/ext/stft_converter.py b/PrismAudio/models/mmmodules/ext/stft_converter.py
similarity index 100%
rename from ThinkSound/models/mmmodules/ext/stft_converter.py
rename to PrismAudio/models/mmmodules/ext/stft_converter.py
diff --git a/ThinkSound/models/mmmodules/ext/stft_converter_mel.py b/PrismAudio/models/mmmodules/ext/stft_converter_mel.py
similarity index 100%
rename from ThinkSound/models/mmmodules/ext/stft_converter_mel.py
rename to PrismAudio/models/mmmodules/ext/stft_converter_mel.py
diff --git a/ThinkSound/models/mmmodules/model/__init__.py b/PrismAudio/models/mmmodules/model/__init__.py
similarity index 100%
rename from ThinkSound/models/mmmodules/model/__init__.py
rename to PrismAudio/models/mmmodules/model/__init__.py
diff --git a/ThinkSound/models/mmmodules/model/embeddings.py b/PrismAudio/models/mmmodules/model/embeddings.py
similarity index 100%
rename from ThinkSound/models/mmmodules/model/embeddings.py
rename to PrismAudio/models/mmmodules/model/embeddings.py
diff --git a/ThinkSound/models/mmmodules/model/flow_matching.py b/PrismAudio/models/mmmodules/model/flow_matching.py
similarity index 100%
rename from ThinkSound/models/mmmodules/model/flow_matching.py
rename to PrismAudio/models/mmmodules/model/flow_matching.py
diff --git a/ThinkSound/models/mmmodules/model/low_level.py b/PrismAudio/models/mmmodules/model/low_level.py
similarity index 100%
rename from ThinkSound/models/mmmodules/model/low_level.py
rename to PrismAudio/models/mmmodules/model/low_level.py
diff --git a/ThinkSound/models/mmmodules/model/networks.py b/PrismAudio/models/mmmodules/model/networks.py
similarity index 100%
rename from ThinkSound/models/mmmodules/model/networks.py
rename to PrismAudio/models/mmmodules/model/networks.py
diff --git a/ThinkSound/models/mmmodules/model/sequence_config.py b/PrismAudio/models/mmmodules/model/sequence_config.py
similarity index 100%
rename from ThinkSound/models/mmmodules/model/sequence_config.py
rename to PrismAudio/models/mmmodules/model/sequence_config.py
diff --git a/ThinkSound/models/mmmodules/model/transformer_layers.py b/PrismAudio/models/mmmodules/model/transformer_layers.py
similarity index 100%
rename from ThinkSound/models/mmmodules/model/transformer_layers.py
rename to PrismAudio/models/mmmodules/model/transformer_layers.py
diff --git a/ThinkSound/models/mmmodules/runner.py b/PrismAudio/models/mmmodules/runner.py
similarity index 100%
rename from ThinkSound/models/mmmodules/runner.py
rename to PrismAudio/models/mmmodules/runner.py
diff --git a/ThinkSound/models/mmmodules/sample.py b/PrismAudio/models/mmmodules/sample.py
similarity index 100%
rename from ThinkSound/models/mmmodules/sample.py
rename to PrismAudio/models/mmmodules/sample.py
diff --git a/ThinkSound/models/pqmf.py b/PrismAudio/models/pqmf.py
similarity index 100%
rename from ThinkSound/models/pqmf.py
rename to PrismAudio/models/pqmf.py
diff --git a/ThinkSound/models/pretrained.py b/PrismAudio/models/pretrained.py
similarity index 100%
rename from ThinkSound/models/pretrained.py
rename to PrismAudio/models/pretrained.py
diff --git a/ThinkSound/models/pretransforms.py b/PrismAudio/models/pretransforms.py
similarity index 100%
rename from ThinkSound/models/pretransforms.py
rename to PrismAudio/models/pretransforms.py
diff --git a/ThinkSound/models/transformer (1).py b/PrismAudio/models/transformer (1).py
similarity index 100%
rename from ThinkSound/models/transformer (1).py
rename to PrismAudio/models/transformer (1).py
diff --git a/ThinkSound/models/transformer.py b/PrismAudio/models/transformer.py
similarity index 100%
rename from ThinkSound/models/transformer.py
rename to PrismAudio/models/transformer.py
diff --git a/ThinkSound/models/utils.py b/PrismAudio/models/utils.py
similarity index 100%
rename from ThinkSound/models/utils.py
rename to PrismAudio/models/utils.py
diff --git a/ThinkSound/models/wavelets.py b/PrismAudio/models/wavelets.py
similarity index 100%
rename from ThinkSound/models/wavelets.py
rename to PrismAudio/models/wavelets.py
diff --git a/ThinkSound/training/__init__.py b/PrismAudio/training/__init__.py
similarity index 100%
rename from ThinkSound/training/__init__.py
rename to PrismAudio/training/__init__.py
diff --git a/ThinkSound/training/autoencoders.py b/PrismAudio/training/autoencoders.py
similarity index 100%
rename from ThinkSound/training/autoencoders.py
rename to PrismAudio/training/autoencoders.py
diff --git a/ThinkSound/training/autoencoders_1.py b/PrismAudio/training/autoencoders_1.py
similarity index 100%
rename from ThinkSound/training/autoencoders_1.py
rename to PrismAudio/training/autoencoders_1.py
diff --git a/ThinkSound/training/diffusion.py b/PrismAudio/training/diffusion.py
similarity index 100%
rename from ThinkSound/training/diffusion.py
rename to PrismAudio/training/diffusion.py
diff --git a/ThinkSound/training/factory.py b/PrismAudio/training/factory.py
similarity index 100%
rename from ThinkSound/training/factory.py
rename to PrismAudio/training/factory.py
diff --git a/ThinkSound/training/lm.py b/PrismAudio/training/lm.py
similarity index 100%
rename from ThinkSound/training/lm.py
rename to PrismAudio/training/lm.py
diff --git a/ThinkSound/training/lm_continuous.py b/PrismAudio/training/lm_continuous.py
similarity index 100%
rename from ThinkSound/training/lm_continuous.py
rename to PrismAudio/training/lm_continuous.py
diff --git a/ThinkSound/training/losses/__init__.py b/PrismAudio/training/losses/__init__.py
similarity index 100%
rename from ThinkSound/training/losses/__init__.py
rename to PrismAudio/training/losses/__init__.py
diff --git a/ThinkSound/training/losses/auraloss.py b/PrismAudio/training/losses/auraloss.py
similarity index 100%
rename from ThinkSound/training/losses/auraloss.py
rename to PrismAudio/training/losses/auraloss.py
diff --git a/ThinkSound/training/losses/losses.py b/PrismAudio/training/losses/losses.py
similarity index 100%
rename from ThinkSound/training/losses/losses.py
rename to PrismAudio/training/losses/losses.py
diff --git a/ThinkSound/training/utils.py b/PrismAudio/training/utils.py
similarity index 100%
rename from ThinkSound/training/utils.py
rename to PrismAudio/training/utils.py
diff --git a/app.py b/app.py
index 81b9011b066a970b993bc713d45565e3a18ee503..0212220cc4097872ec9dc036110e57b3bd4f6ad9 100644
--- a/app.py
+++ b/app.py
@@ -4,9 +4,6 @@ import subprocess
import sys
subprocess.run(["bash", "setup.sh"], check=True)
-
-os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".gradio_tmp")
-os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True)
os.environ["JAX_PLATFORMS"] = "cpu"
import gradio as gr
import logging
@@ -52,10 +49,10 @@ SAMPLE_RATE = 44100
from huggingface_hub import snapshot_download
snapshot_download(repo_id="FunAudioLLM/PrismAudio", local_dir="./ckpts")
-MODEL_CONFIG_PATH = "ThinkSound/configs/model_configs/prismaudio.json"
+MODEL_CONFIG_PATH = "PrismAudio/configs/model_configs/prismaudio.json"
CKPT_PATH = "ckpts/prismaudio.ckpt"
VAE_CKPT_PATH = "ckpts/vae.ckpt"
-VAE_CONFIG_PATH = "ThinkSound/configs/model_configs/stable_audio_2_0_vae.json"
+VAE_CONFIG_PATH = "PrismAudio/configs/model_configs/stable_audio_2_0_vae.json"
SYNCHFORMER_CKPT_PATH = "ckpts/synchformer_state_dict.pth"
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
@@ -99,8 +96,8 @@ def load_all_models():
log.info("β
FeaturesUtils loaded")
# ---- 3. Diffusion model ----
- from ThinkSound.models import create_model_from_config
- from ThinkSound.models.utils import load_ckpt_state_dict
+ from PrismAudio.models import create_model_from_config
+ from PrismAudio.models.utils import load_ckpt_state_dict
with open(MODEL_CONFIG_PATH) as f:
model_config = json.load(f)
@@ -288,7 +285,7 @@ def build_meta(info: dict, duration: float, caption: str):
def run_diffusion(audio_latent: torch.Tensor, meta: dict, duration: float) -> torch.Tensor:
"""Reuses globally loaded diffusion model β no reload per call."""
- from ThinkSound.inference.sampling import sample, sample_discrete_euler
+ from PrismAudio.inference.sampling import sample, sample_discrete_euler
import time
diffusion = _MODELS["diffusion"]
@@ -374,7 +371,7 @@ def generate_audio(video_file, caption: str):
return "\n".join(logs)
# ---- Working directory (auto-cleaned on exit) ----
- work_dir = tempfile.mkdtemp(dir=os.environ["GRADIO_TEMP_DIR"], prefix="thinksound_")
+ work_dir = tempfile.mkdtemp(dir=os.environ["GRADIO_TEMP_DIR"], prefix="PrismAudio_")
try:
# ---- Step 1: Convert / copy to mp4 ----
@@ -476,7 +473,7 @@ def generate_audio(video_file, caption: str):
def build_ui() -> gr.Blocks:
with gr.Blocks(
- title="ThinkSound - Video to Audio Generation",
+ title="PrismAudio - Video to Audio Generation",
theme=gr.themes.Soft(),
css="""
.title { text-align:center; font-size:2em; font-weight:bold; margin-bottom:.2em; }
@@ -485,7 +482,7 @@ def build_ui() -> gr.Blocks:
""",
) as demo:
- gr.HTML('
π΅ ThinkSound
')
+ gr.HTML('🎵 PrismAudio
')
gr.HTML(
''
'Upload a video and a text prompt β '
@@ -543,13 +540,8 @@ def build_ui() -> gr.Blocks:
height=400,
)
- # ======================================================
- # Example prompts
- # ======================================================
- gr.Markdown("---")
- gr.Markdown("### π‘ Example Prompts (click to fill)")
-
- # ======================================================
+
+ # ======================================================
# Instructions
# ======================================================
with gr.Accordion("π Instructions", open=False):
@@ -602,7 +594,7 @@ SYNCHFORMER_CKPT_PATH = {SYNCHFORMER_CKPT_PATH}
if __name__ == "__main__":
import argparse
- parser = argparse.ArgumentParser(description="ThinkSound Gradio App")
+ parser = argparse.ArgumentParser(description="PrismAudio Gradio App")
parser.add_argument("--server_name", type=str, default="0.0.0.0",
help="Gradio server host")
parser.add_argument("--server_port", type=int, default=7860,
diff --git a/data_utils/extract_training_audio.py b/data_utils/extract_training_audio.py
index ed6e4a23fe64b844b5022bda921621e88fd91a66..a74c047f3aec7f76965001726c9de70e0c829959 100644
--- a/data_utils/extract_training_audio.py
+++ b/data_utils/extract_training_audio.py
@@ -124,7 +124,7 @@ if __name__ == '__main__':
parser.add_argument('--duration_sec', type=float, default=9.0, help='Duration of the audio in seconds')
parser.add_argument('--audio_samples', type=int, default=397312, help='Number of audio samples')
parser.add_argument('--vae_ckpt', type=str, default='ckpts/vae.ckpt', help='Path to the VAE checkpoint')
- parser.add_argument('--vae_config', type=str, default='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file')
+ parser.add_argument('--vae_config', type=str, default='PrismAudio/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file')
parser.add_argument('--synchformer_ckpt', type=str, default='ckpts/synchformer_state_dict.pth', help='Path to the Synchformer checkpoint')
parser.add_argument('--start-row', type=int, default=0, help='start row')
parser.add_argument('--end-row', type=int, default=None, help='end row')
diff --git a/data_utils/extract_training_video.py b/data_utils/extract_training_video.py
index 92ea01c61e08d49bd4572d28ce5d30d4e5283b19..53086afcab5f34e9ef446a36707a744b2d99038c 100644
--- a/data_utils/extract_training_video.py
+++ b/data_utils/extract_training_video.py
@@ -129,7 +129,7 @@ if __name__ == '__main__':
parser.add_argument('--duration_sec', type=float, default=9.0, help='Duration of the audio in seconds')
parser.add_argument('--audio_samples', type=int, default=397312, help='Number of audio samples')
parser.add_argument('--vae_ckpt', type=str, default='ckpts/vae.ckpt', help='Path to the VAE checkpoint')
- parser.add_argument('--vae_config', type=str, default='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file')
+ parser.add_argument('--vae_config', type=str, default='PrismAudio/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file')
parser.add_argument('--synchformer_ckpt', type=str, default='ckpts/synchformer_state_dict.pth', help='Path to the Synchformer checkpoint')
parser.add_argument('--start-row', type=int, default=0, help='start row')
parser.add_argument('--end-row', type=int, default=None, help='end row')
diff --git a/data_utils/prismaudio_data_process.py b/data_utils/prismaudio_data_process.py
index c4491e3620bac5801883049f34ccacb43a028ab3..f40797efd3a540024d7320ce99541cad9d0ec27f 100644
--- a/data_utils/prismaudio_data_process.py
+++ b/data_utils/prismaudio_data_process.py
@@ -137,7 +137,7 @@ if __name__ == '__main__':
parser.add_argument('--save-dir', default='results')
parser.add_argument('--sample_rate', type=int, default=44100, help='Audio sample rate')
parser.add_argument('--vae_ckpt', type=str, default='ckpts/vae.ckpt', help='Path to the VAE checkpoint')
- parser.add_argument('--vae_config', type=str, default='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file')
+ parser.add_argument('--vae_config', type=str, default='PrismAudio/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file')
parser.add_argument('--synchformer_ckpt', type=str, default='ckpts/synchformer_state_dict.pth', help='Path to the Synchformer checkpoint')
parser.add_argument('--start-row', '-s', type=int, default=0, help='Start row index')
parser.add_argument('--end-row', '-e', type=int, default=None, help='End row index')
diff --git a/data_utils/v2a_utils/feature_utils_224.py b/data_utils/v2a_utils/feature_utils_224.py
index 3bcb012470a9ebbdbd34e8e37a920aa359d44b40..8f84a6cd54a5fee8701eaa8ba203e3c281817c03 100644
--- a/data_utils/v2a_utils/feature_utils_224.py
+++ b/data_utils/v2a_utils/feature_utils_224.py
@@ -7,9 +7,9 @@ import torch.nn.functional as F
from einops import rearrange
from open_clip import create_model_from_pretrained
from torchvision.transforms import Normalize
-from ThinkSound.models.factory import create_model_from_config
-from ThinkSound.models.utils import load_ckpt_state_dict
-from ThinkSound.models.utils import copy_state_dict
+from PrismAudio.models.factory import create_model_from_config
+from PrismAudio.models.utils import load_ckpt_state_dict
+from PrismAudio.models.utils import copy_state_dict
from transformers import AutoModel
from transformers import AutoProcessor
from transformers import T5EncoderModel, AutoTokenizer
diff --git a/data_utils/v2a_utils/feature_utils_224_audio.py b/data_utils/v2a_utils/feature_utils_224_audio.py
index b869edf4620f2ae6c1b928d3bf844292b3dcba45..ffda01ff723731629bae66973daf558a3f44efb6 100644
--- a/data_utils/v2a_utils/feature_utils_224_audio.py
+++ b/data_utils/v2a_utils/feature_utils_224_audio.py
@@ -7,9 +7,9 @@ import torch.nn.functional as F
from einops import rearrange
# from open_clip import create_model_from_pretrained
from torchvision.transforms import Normalize
-from ThinkSound.models.factory import create_model_from_config
-from ThinkSound.models.utils import load_ckpt_state_dict
-from ThinkSound.training.utils import copy_state_dict
+from PrismAudio.models.factory import create_model_from_config
+from PrismAudio.models.utils import load_ckpt_state_dict
+from PrismAudio.training.utils import copy_state_dict
from transformers import AutoModel
from transformers import AutoProcessor
from transformers import T5EncoderModel, AutoTokenizer
diff --git a/data_utils/v2a_utils/feature_utils_288.py b/data_utils/v2a_utils/feature_utils_288.py
index 9b12e224087cc0afad7dccae78bc9c96fbcf5260..239099f99463e7c55fe7185b763c4b0d3bd80f7b 100644
--- a/data_utils/v2a_utils/feature_utils_288.py
+++ b/data_utils/v2a_utils/feature_utils_288.py
@@ -5,8 +5,8 @@ import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torchvision.transforms import Normalize
-from ThinkSound.models.factory import create_model_from_config
-from ThinkSound.models.utils import load_ckpt_state_dict
+from PrismAudio.models.factory import create_model_from_config
+from PrismAudio.models.utils import load_ckpt_state_dict
import einshape
import sys
import os
diff --git a/eval_batch.py b/eval_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..047a147e41fba8c5eb6a3d61848c420caff79a22
--- /dev/null
+++ b/eval_batch.py
@@ -0,0 +1,136 @@
+from prefigure.prefigure import get_all_args, push_wandb_config
+import json
+import os
+import re
+import torch
+import torchaudio
+from lightning.pytorch import seed_everything
+import random
+from datetime import datetime
+import numpy as np
+
+from PrismAudio.data.datamodule import DataModule
+from PrismAudio.models import create_model_from_config
+from PrismAudio.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
+from PrismAudio.inference.sampling import sample, sample_discrete_euler
+from pathlib import Path
+from tqdm import tqdm
+
+
+def predict_step(diffusion, batch, diffusion_objective, device='cuda:0'):
+ diffusion = diffusion.to(device)
+
+ reals, metadata = batch
+ ids = [item['id'] for item in metadata]
+ batch_size, length = reals.shape[0], reals.shape[2]
+ with torch.amp.autocast('cuda'):
+ conditioning = diffusion.conditioner(metadata, device)
+
+ video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0)
+ conditioning['metaclip_features'][~video_exist] = diffusion.model.model.empty_clip_feat
+ conditioning['sync_features'][~video_exist] = diffusion.model.model.empty_sync_feat
+
+ cond_inputs = diffusion.get_conditioning_inputs(conditioning)
+ if batch_size > 1:
+ noise_list = []
+ for _ in range(batch_size):
+ noise_1 = torch.randn([1, diffusion.io_channels, length]).to(device) # each draw advances the RNG state
+ noise_list.append(noise_1)
+ noise = torch.cat(noise_list, dim=0)
+ else:
+ noise = torch.randn([batch_size, diffusion.io_channels, length]).to(device)
+
+ with torch.amp.autocast('cuda'):
+
+ model = diffusion.model
+ if diffusion_objective == "v":
+ fakes = sample(model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True)
+ elif diffusion_objective == "rectified_flow":
+ fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
+ if diffusion.pretransform is not None:
+ fakes = diffusion.pretransform.decode(fakes)
+
+ audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
+ return audios
+
+
+def main():
+ args = get_all_args()
+
+ if args.save_dir == '':
+ args.save_dir = args.results_dir
+
+
+ seed = args.seed
+ if os.environ.get("SLURM_PROCID") is not None:
+ seed += int(os.environ.get("SLURM_PROCID"))
+ seed_everything(seed, workers=True)
+
+ # Load config
+ if args.model_config == '':
+ args.model_config = "PrismAudio/configs/model_configs/thinksound.json"
+ with open(args.model_config) as f:
+ model_config = json.load(f)
+
+ duration = float(args.duration_sec)
+ sample_rate = model_config["sample_rate"]
+ latent_length = round(44100 / 64 / 32 * duration)
+
+ model_config["sample_size"] = duration * sample_rate
+ model_config["model"]["diffusion"]["config"]["sync_seq_len"] = 24 * int(duration)
+ model_config["model"]["diffusion"]["config"]["clip_seq_len"] = 8 * int(duration)
+ model_config["model"]["diffusion"]["config"]["latent_seq_len"] = latent_length
+
+ model = create_model_from_config(model_config)
+ if args.compile:
+ model = torch.compile(model)
+
+ model.load_state_dict(torch.load(args.ckpt_dir))
+ vae_state = load_ckpt_state_dict(args.pretransform_ckpt_path, prefix='autoencoder.')
+ model.pretransform.load_state_dict(vae_state)
+
+
+ if args.dataset_config == '':
+ args.dataset_config = "PrismAudio/configs/multimodal_dataset_demo.json"
+ with open(args.dataset_config) as f:
+ dataset_config = json.load(f)
+
+ for td in dataset_config["test_datasets"]:
+ td["path"] = args.results_dir
+
+ dm = DataModule(
+ dataset_config,
+ batch_size=args.batch_size,
+ test_batch_size=args.test_batch_size,
+ num_workers=args.num_workers,
+ sample_rate=model_config["sample_rate"],
+ sample_size=(float)(args.duration_sec) * model_config["sample_rate"],
+ audio_channels=model_config.get("audio_channels", 2),
+ latent_length=round(44100/64/32*duration),
+ )
+ dm.setup('predict')
+ dl = dm.predict_dataloader()
+
+ current_date = datetime.now()
+ formatted_date = current_date.strftime('%m%d')
+
+ audio_dir = os.path.join(args.save_dir,f'{formatted_date}_batch_size'+str(args.test_batch_size))
+ os.makedirs(audio_dir,exist_ok=True)
+
+ for batch in tqdm(dl, desc="Predicting"):
+ audio = predict_step(
+ model,
+ batch=batch,
+ diffusion_objective=model_config["model"]["diffusion"]["diffusion_objective"],
+ device='cuda:0'
+ )
+
+ _, metadata = batch
+ ids = [item['id'] for item in metadata]
+
+ for i in range(audio.size(0)):
+ id_str = ids[i] if i < len(ids) else f"unknown_{i}"
+ torchaudio.save(os.path.join(audio_dir, f"{id_str}.wav"), audio[i], 44100)
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/predict.py b/predict.py
index 44fbe7276fd289c69f403bab1270adbdad0d1443..7062b5a92c969d1ef6824972b944dec767f5df48 100644
--- a/predict.py
+++ b/predict.py
@@ -8,9 +8,9 @@ from lightning.pytorch import seed_everything
import random
from datetime import datetime
import numpy as np
-from ThinkSound.models import create_model_from_config
-from ThinkSound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
-from ThinkSound.inference.sampling import sample, sample_discrete_euler
+from PrismAudio.models import create_model_from_config
+from PrismAudio.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
+from PrismAudio.inference.sampling import sample, sample_discrete_euler
from pathlib import Path
@@ -114,7 +114,7 @@ def main():
#Get JSON config from args.model_config
if args.model_config == '':
- args.model_config = "ThinkSound/configs/model_configs/thinksound.json"
+ args.model_config = "PrismAudio/configs/model_configs/thinksound.json"
with open(args.model_config) as f:
model_config = json.load(f)
diff --git a/scripts/PrismAudio/grpo_1node4gpus.sh b/scripts/PrismAudio/grpo_1node4gpus.sh
index e360e3c449bd3c413bd04ad54d411378731a282d..7f07c22df35e0ba88fda8b144fdc7782ce99ee88 100644
--- a/scripts/PrismAudio/grpo_1node4gpus.sh
+++ b/scripts/PrismAudio/grpo_1node4gpus.sh
@@ -1,4 +1,5 @@
+
export NCCL_IB_DISABLE=1
export NCCL_IB_HCA=mlx5
export NCCL_DEBUG=WARN
diff --git a/set_up.sh b/scripts/PrismAudio/setup/build_env.sh
similarity index 100%
rename from set_up.sh
rename to scripts/PrismAudio/setup/build_env.sh
diff --git a/scripts/ThinkSound/demo.sh b/scripts/ThinkSound/demo.sh
index cbc681b98a866d69d34db5c227c406dbd80b721f..96b4f724da1c6ee0bed1ab1ba47c185580a56674 100755
--- a/scripts/ThinkSound/demo.sh
+++ b/scripts/ThinkSound/demo.sh
@@ -66,7 +66,7 @@ fi
 echo "⏳ Running model inference..."
python predict.py \
--model-config "$model_config" \
- --duration-sec "$DURATION_SEC" \
+ --duration-sec "$DURATION" \
--results-dir "results"\
if [ $? -ne 0 ]; then
diff --git a/setup.sh b/setup.sh
index d328f1a822a81e305ba1647909a0774c93ee3118..0500c615734a4e68c7c60407f363393bcc92fa4f 100644
--- a/setup.sh
+++ b/setup.sh
@@ -5,5 +5,3 @@ cd ..
pip install -r scripts/PrismAudio/setup/requirements.txt
pip install tensorflow-cpu==2.15.0
pip install facenet_pytorch==2.6.0 --no-deps
-
-conda install -y -c conda-forge 'ffmpeg<7'
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..74f90e5546c286347c0d9f2346c1ae61b485cf5e
--- /dev/null
+++ b/train.py
@@ -0,0 +1,193 @@
+from prefigure.prefigure import get_all_args, push_wandb_config
+import json
+import os
+import torch
+import torchaudio
+# import pytorch_lightning as pl
+import lightning as L
+from lightning.pytorch.callbacks import Timer, ModelCheckpoint, BasePredictionWriter
+from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.tuner import Tuner
+from lightning.pytorch import seed_everything
+import random
+from datetime import datetime
+# from PrismAudio.data.dataset import create_dataloader_from_config
+from PrismAudio.data.datamodule import DataModule
+from PrismAudio.models import create_model_from_config
+from PrismAudio.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
+from PrismAudio.training import create_training_wrapper_from_config, create_demo_callback_from_config
+from PrismAudio.training.utils import copy_state_dict
+
+class ExceptionCallback(Callback):
+ def on_exception(self, trainer, module, err):
+ print(f'{type(err).__name__}: {err}')
+
+class ModelConfigEmbedderCallback(Callback):
+ def __init__(self, model_config):
+ self.model_config = model_config
+
+ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+ checkpoint["model_config"] = self.model_config
+
+class CustomWriter(BasePredictionWriter):
+
+ def __init__(self, output_dir, write_interval='batch'):
+ super().__init__(write_interval)
+ self.output_dir = output_dir
+
+ def write_on_batch_end(self, trainer, pl_module, predictions, batch_indices, batch, batch_idx, dataloader_idx):
+
+ audios = predictions
+ ids = [item['id'] for item in batch[1]]
+        # Get the current date
+ current_date = datetime.now()
+
+        # Format the date as 'MMDD'
+ formatted_date = current_date.strftime('%m%d')
+ if trainer.ckpt_path is None:
+ global_step = pl_module.global_step // 1000
+ else:
+ global_step = int(trainer.ckpt_path.split("-step=")[-1].split(".")[0]) // 1000
+ os.makedirs(os.path.join(self.output_dir, f'{formatted_date}_step{global_step}k'),exist_ok=True)
+        for audio, audio_id in zip(audios, ids):
+            save_path = os.path.join(self.output_dir, f'{formatted_date}_step{global_step}k', f'{audio_id}.wav')
+            torchaudio.save(save_path, audio, 44100)
+
+def main():
+
+ args = get_all_args()
+
+ seed = args.seed
+
+ # Set a different seed for each process if using SLURM
+ if os.environ.get("SLURM_PROCID") is not None:
+ seed += int(os.environ.get("SLURM_PROCID"))
+
+ # random.seed(seed)
+ # torch.manual_seed(seed)
+ seed_everything(seed, workers=True)
+ print('########################')
+ print(f'precision is {args.precision}')
+ print('########################')
+ #Get JSON config from args.model_config
+ with open(args.model_config) as f:
+ model_config = json.load(f)
+
+ with open(args.dataset_config) as f:
+ dataset_config = json.load(f)
+
+ # train_dl = create_dataloader_from_config(
+ # dataset_config,
+ # batch_size=args.batch_size,
+ # num_workers=args.num_workers,
+ # sample_rate=model_config["sample_rate"],
+ # sample_size=model_config["sample_size"],
+ # audio_channels=model_config.get("audio_channels", 2),
+ # )
+ dm = DataModule(
+ dataset_config,
+ batch_size=args.batch_size,
+ test_batch_size=args.test_batch_size,
+ num_workers=args.num_workers,
+ sample_rate=model_config["sample_rate"],
+ sample_size=model_config["sample_size"],
+ audio_channels=model_config.get("audio_channels", 2),
+ repeat_num=args.repeat_num
+ )
+
+ model = create_model_from_config(model_config)
+
+ ## speed by torch.compile
+ if args.compile:
+ model = torch.compile(model)
+
+ if args.pretrained_ckpt_path:
+ copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path,prefix='diffusion.')) # autoencoder. diffusion.
+
+ if args.remove_pretransform_weight_norm == "pre_load":
+ remove_weight_norm_from_model(model.pretransform)
+ # import ipdb
+ # ipdb.set_trace()
+ if args.pretransform_ckpt_path:
+ load_vae_state = load_ckpt_state_dict(args.pretransform_ckpt_path, prefix='autoencoder.')
+ # new_state_dict = {k.replace("autoencoder.", ""): v for k, v in load_vae_state.items() if k.startswith("autoencoder.")}
+ model.pretransform.load_state_dict(load_vae_state)
+
+ # Remove weight_norm from the pretransform if specified
+ if args.remove_pretransform_weight_norm == "post_load":
+ remove_weight_norm_from_model(model.pretransform)
+
+ training_wrapper = create_training_wrapper_from_config(model_config, model)
+
+ wandb_logger = L.pytorch.loggers.WandbLogger(project=args.name)
+ wandb_logger.watch(training_wrapper)
+
+ exc_callback = ExceptionCallback()
+
+ if args.save_dir and isinstance(wandb_logger.experiment.id, str):
+ checkpoint_dir = os.path.join(args.save_dir, wandb_logger.experiment.project, wandb_logger.experiment.id, "checkpoints")
+ else:
+ checkpoint_dir = None
+
+ # ckpt_callback = ModelCheckpoint(every_n_train_steps=args.checkpoint_every, dirpath=checkpoint_dir, monitor='val_loss', mode='min', save_top_k=14)
+ ckpt_callback = ModelCheckpoint(every_n_train_steps=args.checkpoint_every, dirpath=checkpoint_dir, monitor='epoch', mode='max', save_top_k=14)
+ save_model_config_callback = ModelConfigEmbedderCallback(model_config)
+ # audio_dir = os.path.join(args.save_dir, args.name, "audios")
+ # pred_writer = CustomWriter(output_dir=audio_dir, write_interval="batch")
+ timer = Timer(duration="00:16:00:00")
+ demo_callback = create_demo_callback_from_config(model_config, demo_dl=dm)
+
+ #Combine args and config dicts
+ args_dict = vars(args)
+ args_dict.update({"model_config": model_config})
+ args_dict.update({"dataset_config": dataset_config})
+ push_wandb_config(wandb_logger, args_dict)
+
+ #Set multi-GPU strategy if specified
+ if args.strategy:
+ if args.strategy == "deepspeed":
+ from pytorch_lightning.strategies import DeepSpeedStrategy
+ strategy = DeepSpeedStrategy(stage=2,
+ contiguous_gradients=True,
+ overlap_comm=True,
+ reduce_scatter=True,
+ reduce_bucket_size=5e8,
+ allgather_bucket_size=5e8,
+ load_full_weights=True
+ )
+ else:
+ strategy = args.strategy
+ else:
+ strategy = 'ddp_find_unused_parameters_true' if args.num_gpus > 1 else "auto"
+
+ trainer = L.Trainer(
+ devices=args.num_gpus,
+ accelerator="gpu",
+ num_nodes = args.num_nodes,
+ strategy=strategy,
+ precision=args.precision,
+ accumulate_grad_batches=args.accum_batches,
+ callbacks=[ckpt_callback, demo_callback, exc_callback, save_model_config_callback, timer],
+ logger=wandb_logger,
+ log_every_n_steps=1,
+ max_epochs=90,
+ default_root_dir=args.save_dir,
+ gradient_clip_val=args.gradient_clip_val,
+ reload_dataloaders_every_n_epochs = 0,
+ check_val_every_n_epoch=2,
+ )
+
+ # query training/validation/test time (in seconds)
+ # timer.time_elapsed("train")
+ # timer.start_time("validate")
+ # tuner = Tuner(trainer)
+ # Auto-scale batch size by growing it exponentially (default)
+ # tuner.scale_batch_size(training_wrapper, mode="power")
+ # tuner.lr_find(training_wrapper)
+ # trainer.tune(training_wrapper, train_dl, ckpt_path=args.ckpt_path if args.ckpt_path else None)
+ # trainer.validate(training_wrapper, dm)
+ trainer.fit(training_wrapper, dm, ckpt_path=args.ckpt_path if args.ckpt_path else None)
+ # trainer.predict(training_wrapper, dm, return_predictions=False)
+
+if __name__ == '__main__':
+    main()