import os
import sys
import traceback

import numpy as np
import torch

try:
    # Pillow is only needed to render generate()'s fallback error image; a
    # missing install should not prevent this controller module from loading.
    from PIL import Image
except ImportError:
    Image = None


class SonicDiffusionController:
    """Controller for SonicDiffusion with GPU support.

    Manages dependency/asset checks, Google-Drive asset downloads, loading of
    the custom Stable Diffusion pipeline + CLAP audio encoder, and
    audio-conditioned image generation. Methods return human-readable status
    strings (consumed by a UI) rather than raising.
    """

    def __init__(self):
        # No model is available until load_model() succeeds.
        self.model_loaded = False
        self.model_type = None
        self.sr = 44100  # audio sample rate passed to the CLAP encoder
        self.device = self._get_device()
        # Required local asset path -> Google Drive file id.
        self.required_assets = {
            "ckpts/landscape.pt": "1-oTNIjCZq3_mGI1XRfzDyCnmjXCvd0Vh",
            "ckpts/greatest_hits.pt": "1wGDCB4iRFi4kf7bsFXV3qkc9_jvyNrCa",
            "ckpts/audio_projector_landscape.pth": "1BdjzRJOC8bvyPgrAkJJcCaN3EEJg3STm",
            "ckpts/audio_projector_gh.pth": "19Uk68PXVOjE3TJl86H-IlMaM1URhU33a",
            "ckpts/CLAP_weights_2022.pth": "1VK22jxHkFwpxknxQBLd6kIgO5WxQdLFP",
            "assets/fire_crackling.wav": "1vOAZcbkpo_hre2g26n--lUXdwbTQp22k",
            "assets/plastic_bag.wav": "15igeDor7a47a-oluSCfO6GeUvFVl2ttb",
        }

    def _get_device(self):
        """Return "cuda" when torch reports an available CUDA device, else "cpu"."""
        try:
            if torch.cuda.is_available():
                print(f"CUDA available: {torch.cuda.get_device_name(0)}")
                return "cuda"
            print("CUDA not available, using CPU")
            return "cpu"
        except Exception:
            # Defensive: a broken torch/CUDA runtime degrades to CPU instead
            # of taking the whole controller down.
            print("PyTorch not available, using CPU")
            return "cpu"

    def check_dependencies(self):
        """Return a dict mapping each required package to its version string,
        "Installed (version unknown)", or "Not installed"."""
        packages = (
            "torch",
            "transformers",
            "diffusers",
            "accelerate",
            "einops",
            "omegaconf",
            "librosa",
        )
        status = {}
        for name in packages:
            try:
                module = __import__(name)
            except ImportError:
                status[name] = "Not installed"
            else:
                status[name] = getattr(module, "__version__", "Installed (version unknown)")
        return status

    def check_assets(self):
        """Return {asset_path: exists_on_disk} for every required asset."""
        return {path: os.path.exists(path) for path in self.required_assets}

    def download_assets(self, specific_asset=None):
        """Download missing assets from Google Drive.

        Args:
            specific_asset: optional single asset path to restrict the
                download to; must be a key of ``required_assets``.

        Returns:
            A human-readable status/progress string (never raises).
        """
        try:
            os.makedirs("assets", exist_ok=True)
            os.makedirs("ckpts", exist_ok=True)

            assets_to_download = self.required_assets
            if specific_asset:
                if specific_asset not in self.required_assets:
                    return f"Asset {specific_asset} not found in required assets list"
                assets_to_download = {specific_asset: self.required_assets[specific_asset]}

            # Only fetch what is actually missing on disk.
            missing_assets = {
                path: file_id
                for path, file_id in assets_to_download.items()
                if not os.path.exists(path)
            }
            if not missing_assets:
                return "All required assets already exist"

            # Imported lazily, and only once a download is actually needed, so
            # a missing downloader module can't break the no-op path.
            from download_assets import download_gdrive_file

            results = []
            for asset_path, file_id in missing_assets.items():
                results.append(f"Downloading {asset_path}...")
                success = download_gdrive_file(file_id, asset_path)
                results.append(f" {'Success' if success else 'Failed'}")
            return "\n".join(results)
        except Exception as e:
            traceback.print_exc()
            return f"Error downloading assets: {str(e)}"

    def _import_clap_wrapper(self):
        """Return the CLAPWrapper class, writing a dummy stub module first if
        the real implementation is not importable."""
        try:
            from CLAPWrapper import CLAPWrapper
        except ImportError:
            # Fall back to a stub that returns random embeddings so the rest
            # of the pipeline can still be exercised.
            os.makedirs("CLAP/msclap", exist_ok=True)
            with open("CLAP/msclap/CLAPWrapper.py", "w") as f:
                f.write("""
class CLAPWrapper:
    def __init__(self, weights_path, use_cuda=True):
        import torch
        self.device = "cuda" if use_cuda and torch.cuda.is_available() else "cpu"
        print(f"Initialized CLAPWrapper on {self.device} (dummy implementation)")

    def get_audio_embeddings(self, audio_paths, resample=44100):
        import torch
        # Return random embeddings for now
        return torch.randn(1, 1024).to(self.device), None
""")
            if "CLAP/msclap" not in sys.path:
                sys.path.append("CLAP/msclap")
            from CLAPWrapper import CLAPWrapper
        return CLAPWrapper

    def _import_audio_projector(self):
        """Return the Adapter class, writing a dummy stub module first if the
        real audio projector implementation is missing.

        NOTE(review): in the original code this stub-creation branch was
        unreachable — the `from ldm...` import above it raised first. The
        existence check now runs before the import.
        """
        if not os.path.exists("ldm/modules/encoders/audio_projector_res.py"):
            os.makedirs("ldm/modules/encoders", exist_ok=True)
            with open("ldm/modules/encoders/audio_projector_res.py", "w") as f:
                f.write("""
import torch
import torch.nn as nn

class Adapter(nn.Module):
    def __init__(self, audio_token_count=77, transformer_layer_count=4):
        super().__init__()
        self.audio_token_count = audio_token_count
        self.transformer_layer_count = transformer_layer_count
        self.proj = nn.Linear(1024, 768 * audio_token_count)

    def forward(self, x):
        # Simple implementation for now
        batch_size = x.shape[0]
        x = self.proj(x)
        x = x.reshape(batch_size, self.audio_token_count, 768)
        return x
""")
        from ldm.modules.encoders.audio_projector_res import Adapter
        return Adapter

    def load_model(self, model_type="Landscape Model"):
        """Load the selected SonicDiffusion model onto ``self.device``.

        Args:
            model_type: "Landscape Model" or "Greatest Hits Model".

        Returns:
            A human-readable status string; on success sets
            ``self.model_loaded``/``self.model_type`` and populates
            ``self.unet``, ``self.pipeline``, ``self.audio_encoder`` and
            ``self.audio_projector``.
        """
        if model_type not in ("Landscape Model", "Greatest Hits Model"):
            return f"Unknown model type: {model_type}"

        # Pick the checkpoint pair for the requested model.
        if model_type == "Landscape Model":
            gate_dict_path = "ckpts/landscape.pt"
            audio_projector_path = "ckpts/audio_projector_landscape.pth"
        else:
            gate_dict_path = "ckpts/greatest_hits.pt"
            audio_projector_path = "ckpts/audio_projector_gh.pth"
        clap_weights = "ckpts/CLAP_weights_2022.pth"

        # If any required file is absent, try to fetch everything first.
        required_files = [gate_dict_path, audio_projector_path, clap_weights]
        missing_files = [f for f in required_files if not os.path.exists(f)]
        if missing_files:
            return self.download_assets()

        try:
            # Make the bundled CLAP package importable (once).
            clap_path = "CLAP/msclap"
            if os.path.exists(clap_path) and clap_path not in sys.path:
                sys.path.append(clap_path)

            try:
                from unet2d_custom import UNet2DConditionModel
                from pipeline_stable_diffusion_custom import StableDiffusionPipeline

                CLAPWrapper = self._import_clap_wrapper()
                Adapter = self._import_audio_projector()

                model_id = "CompVis/stable-diffusion-v1-4"
                try:
                    # Custom UNet with audio adapters on the last two blocks.
                    self.unet = UNet2DConditionModel.from_pretrained(
                        model_id,
                        subfolder="unet",
                        use_adapter_list=[False, True, True],
                        low_cpu_mem_usage=True,
                    ).to(self.device)

                    # fp16 only on CUDA; fp32 keeps CPU inference numerically safe.
                    self.pipeline = StableDiffusionPipeline.from_pretrained(
                        model_id,
                        torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                    ).to(self.device)

                    # Copy the trained adapter ("gate") weights into the UNet.
                    try:
                        gate_dict = torch.load(gate_dict_path, map_location=self.device)
                        for name, param in self.unet.named_parameters():
                            if "adapter" in name:
                                param.data = gate_dict[name].to(self.device)
                    except Exception as e:
                        print(f"Error loading gate dictionary: {e}")

                    # Swap the custom UNet into the pipeline.
                    self.pipeline.unet = self.unet

                    # CLAP audio encoder + projector mapping audio embeddings
                    # into the text-conditioning space.
                    try:
                        self.audio_encoder = CLAPWrapper(
                            clap_weights, use_cuda=(self.device == "cuda")
                        )
                        self.audio_projector = Adapter(
                            audio_token_count=77, transformer_layer_count=4
                        ).to(self.device)
                        self.audio_projector.load_state_dict(
                            torch.load(audio_projector_path, map_location=self.device)
                        )
                        self.audio_projector.eval()
                    except Exception as e:
                        print(f"Error loading audio components: {e}")

                    self.model_loaded = True
                    self.model_type = model_type
                    return f"{model_type} loaded successfully"
                except Exception as e:
                    traceback.print_exc()
                    return f"Simplified model check - files exist but full loading failed: {str(e)}"
            except Exception as e:
                traceback.print_exc()
                return f"Error importing custom pipeline modules: {str(e)}"
        except Exception as e:
            traceback.print_exc()
            return f"Error loading model: {str(e)}"

    def generate(self, text_prompt, audio_path=None, cfg_scale=7.5, steps=50):
        """Generate an image conditioned on a text prompt and an audio clip.

        Args:
            text_prompt: text prompt for the diffusion pipeline.
            audio_path: path to an audio file (required despite the default).
            cfg_scale: classifier-free guidance scale.
            steps: number of inference steps.

        Returns:
            A PIL image on success; an error string for precondition failures;
            a white image with the error text drawn on it if generation raises.
        """
        if not self.model_loaded:
            return "Error: Model not loaded. Please click 'Load Model' first."
        if not audio_path:
            return "Error: Audio file is required"
        if not os.path.exists(audio_path):
            return f"Error: Audio file {audio_path} does not exist"

        try:
            with torch.no_grad():
                # Conditional branch: project the CLAP embedding of the clip.
                audio_emb, _ = self.audio_encoder.get_audio_embeddings(
                    [audio_path], resample=self.sr
                )
                audio_proj = self.audio_projector(audio_emb.unsqueeze(1))

                # Unconditional branch for CFG: a zero embedding.
                uncond_emb = torch.zeros(1, 1024).to(self.device)
                audio_uc = self.audio_projector(uncond_emb.unsqueeze(1))

                # [unconditional, conditional] context, as CFG expects.
                audio_context = torch.cat([audio_uc, audio_proj]).to(self.device)

                print(
                    f"Generating image with prompt: '{text_prompt}', CFG: {cfg_scale}, Steps: {steps}"
                )
                image = self.pipeline(
                    prompt=text_prompt,
                    audio_context=audio_context,
                    guidance_scale=cfg_scale,
                    num_inference_steps=steps,
                )

                # Persist a timestamped copy alongside returning it to the UI.
                os.makedirs("outputs", exist_ok=True)
                from datetime import datetime

                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                output_path = f"outputs/generated_{timestamp}.png"
                image.images[0].save(output_path)
                print(f"Image saved to {output_path}")

                return image.images[0]
        except Exception as e:
            traceback.print_exc()
            # Render the error onto a blank canvas so the UI still shows something.
            from PIL import Image as PILImage, ImageDraw
            error_img = PILImage.new('RGB', (512, 512), color=(255, 255, 255))
            draw = ImageDraw.Draw(error_img)
            draw.text((10, 250), f"Error: {str(e)}", fill=(0, 0, 0))
            return error_img