Hugging Face Spaces build log: the build failed. The application source that produced the error is reproduced below.
# Standard library
import logging
from functools import lru_cache

# Third-party
import gradio as gr
import numpy as np
import soundfile as sf
import torch
from diffusers import DiffusionPipeline
from transformers import AutoModel, AutoTokenizer
# Root-logger configuration: timestamped INFO-level output so model-loading
# progress is visible in the Spaces build/runtime logs.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
| try: | |
| # Load text tokenizer and embedding model (umt5-base) | |
| def load_text_processor(): | |
| logger.info("Loading text processor (umt5-base)...") | |
| tokenizer = AutoTokenizer.from_pretrained("./umt5-base") | |
| text_model = AutoModel.from_pretrained( | |
| "./umt5-base", | |
| use_safetensors=True, | |
| torch_dtype=torch.float16, | |
| device_map="auto" | |
| ) | |
| logger.info("Text processor loaded successfully.") | |
| return tokenizer, text_model | |
| # Load the transformer backbone (phantomstep_transformer) | |
| def load_transformer(): | |
| logger.info("Loading transformer (phantomstep_transformer)...") | |
| transformer = DiffusionPipeline.from_pretrained( | |
| "./phantomstep_transformer", | |
| use_safetensors=True, | |
| torch_dtype=torch.float16, | |
| device_map="auto" | |
| ) | |
| logger.info("Transformer loaded successfully.") | |
| return transformer | |
| # Load the DCAE for audio encoding/decoding (phantomstep_dcae) | |
| def load_dcae(): | |
| logger.info("Loading DCAE (phantomstep_dcae)...") | |
| dcae = DiffusionPipeline.from_pretrained( | |
| "./phantomstep_dcae", | |
| use_safetensors=True, | |
| torch_dtype=torch.float16, | |
| device_map="auto" | |
| ) | |
| logger.info("DCAE loaded successfully.") | |
| return dcae | |
| # Load the vocoder for audio synthesis (phantomstep_vocoder) | |
| def load_vocoder(): | |
| logger.info("Loading vocoder (phantomstep_vocoder)...") | |
| vocoder = DiffusionPipeline.from_pretrained( | |
| "./phantomstep_vocoder", | |
| use_safetensors=True, | |
| torch_dtype=torch.float16, | |
| device_map="auto" | |
| ) | |
| logger.info("Vocoder loaded successfully.") | |
| return vocoder | |
| # Generate music from a text prompt | |
| def generate_music(prompt, duration=20, seed=42): | |
| logger.info(f"Generating music with prompt: {prompt}, duration: {duration}, seed: {seed}") | |
| torch.manual_seed(seed) | |
| # Load all components | |
| tokenizer, text_model = load_text_processor() | |
| transformer = load_transformer() | |
| dcae = load_dcae() | |
| vocoder = load_vocoder() | |
| # Step 1: Process text prompt to embeddings | |
| logger.info("Processing text prompt to embeddings...") | |
| inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True) | |
| inputs = {k: v.to(text_model.device) for k, v in inputs.items()} | |
| with torch.no_grad(): | |
| embeddings = text_model(**inputs).last_hidden_state.mean(dim=1) | |
| # Step 2: Pass embeddings through transformer | |
| logger.info("Generating with transformer...") | |
| transformer_output = transformer( | |
| embeddings, | |
| num_inference_steps=50, | |
| audio_length_in_s=duration | |
| ).audios[0] | |
| # Step 3: Decode audio features with DCAE | |
| logger.info("Decoding with DCAE...") | |
| dcae_output = dcae( | |
| transformer_output, | |
| num_inference_steps=50, | |
| audio_length_in_s=duration | |
| ).audios[0] | |
| # Step 4: Synthesize final audio with vocoder | |
| logger.info("Synthesizing with vocoder...") | |
| audio = vocoder( | |
| dcae_output, | |
| num_inference_steps=50, | |
| audio_length_in_s=duration | |
| ).audios[0] | |
| # Save audio to a file | |
| output_path = "output.wav" | |
| sf.write(output_path, audio, 22050) # 22kHz sample rate | |
| logger.info("Music generation complete.") | |
| return output_path | |
| # Gradio interface | |
| logger.info("Setting up Gradio interface...") | |
| with gr.Blocks(title="PhantomStep: Text-to-Music Generation ๐ต") as demo: | |
| gr.Markdown("# PhantomStep by GhostAI ๐") | |
| gr.Markdown("Enter a text prompt to generate music! ๐ถ") | |
| prompt_input = gr.Textbox(label="Text Prompt", placeholder="A jazzy piano melody with a fast tempo") | |
| duration_input = gr.Slider(label="Duration (seconds)", minimum=10, maximum=60, value=20, step=1) | |
| seed_input = gr.Number(label="Random Seed", value=42, precision=0) | |
| generate_button = gr.Button("Generate Music") | |
| audio_output = gr.Audio(label="Generated Music") | |
| generate_button.click( | |
| fn=generate_music, | |
| inputs=[prompt_input, duration_input, seed_input], | |
| outputs=audio_output | |
| ) | |
| logger.info("Launching Gradio app...") | |
| demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False) | |
| except Exception as e: | |
| logger.error(f"Failed to start the application: {str(e)}") | |
| raise |