Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| import traceback | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| class SonicDiffusionController: | |
| """Controller for SonicDiffusion with GPU support""" | |
| def __init__(self): | |
| self.model_loaded = False | |
| self.sr = 44100 # Sample rate for audio | |
| self.device = self._get_device() | |
| self.required_assets = { | |
| "ckpts/landscape.pt": "1-oTNIjCZq3_mGI1XRfzDyCnmjXCvd0Vh", | |
| "ckpts/greatest_hits.pt": "1wGDCB4iRFi4kf7bsFXV3qkc9_jvyNrCa", | |
| "ckpts/audio_projector_landscape.pth": "1BdjzRJOC8bvyPgrAkJJcCaN3EEJg3STm", | |
| "ckpts/audio_projector_gh.pth": "19Uk68PXVOjE3TJl86H-IlMaM1URhU33a", | |
| "ckpts/CLAP_weights_2022.pth": "1VK22jxHkFwpxknxQBLd6kIgO5WxQdLFP", | |
| "assets/fire_crackling.wav": "1vOAZcbkpo_hre2g26n--lUXdwbTQp22k", | |
| "assets/plastic_bag.wav": "15igeDor7a47a-oluSCfO6GeUvFVl2ttb" | |
| } | |
| def _get_device(self): | |
| """Determine the available device (CPU or CUDA)""" | |
| try: | |
| import torch | |
| if torch.cuda.is_available(): | |
| print(f"CUDA available: {torch.cuda.get_device_name(0)}") | |
| return "cuda" | |
| else: | |
| print("CUDA not available, using CPU") | |
| return "cpu" | |
| except ImportError: | |
| print("PyTorch not available, using CPU") | |
| return "cpu" | |
| def check_dependencies(self): | |
| """Check if all required dependencies are installed""" | |
| dependencies = { | |
| "torch": None, | |
| "transformers": None, | |
| "diffusers": None, | |
| "accelerate": None, | |
| "einops": None, | |
| "omegaconf": None, | |
| "librosa": None | |
| } | |
| for package in dependencies.keys(): | |
| try: | |
| module = __import__(package) | |
| try: | |
| dependencies[package] = module.__version__ | |
| except AttributeError: | |
| dependencies[package] = "Installed (version unknown)" | |
| except ImportError: | |
| dependencies[package] = "Not installed" | |
| return dependencies | |
| def check_assets(self): | |
| """Check which assets exist and which need to be downloaded""" | |
| asset_status = {} | |
| for asset_path in self.required_assets.keys(): | |
| asset_status[asset_path] = os.path.exists(asset_path) | |
| return asset_status | |
| def download_assets(self, specific_asset=None): | |
| """Download required assets""" | |
| try: | |
| # Import the asset downloading function | |
| from download_assets import get_gdrive_file_id, download_gdrive_file | |
| # Create necessary directories | |
| os.makedirs("assets", exist_ok=True) | |
| os.makedirs("ckpts", exist_ok=True) | |
| assets_to_download = self.required_assets | |
| if specific_asset: | |
| if specific_asset in self.required_assets: | |
| assets_to_download = {specific_asset: self.required_assets[specific_asset]} | |
| else: | |
| return f"Asset {specific_asset} not found in required assets list" | |
| # Check which assets need to be downloaded | |
| missing_assets = {} | |
| for asset_path, file_id in assets_to_download.items(): | |
| if not os.path.exists(asset_path): | |
| missing_assets[asset_path] = file_id | |
| if not missing_assets: | |
| return "All required assets already exist" | |
| # Download missing assets | |
| results = [] | |
| for asset_path, file_id in missing_assets.items(): | |
| results.append(f"Downloading {asset_path}...") | |
| success = download_gdrive_file(file_id, asset_path) | |
| results.append(f" {'Success' if success else 'Failed'}") | |
| return "\n".join(results) | |
| except Exception as e: | |
| traceback.print_exc() | |
| return f"Error downloading assets: {str(e)}" | |
| def load_model(self, model_type="Landscape Model"): | |
| """Load the selected SonicDiffusion model""" | |
| if model_type not in ["Landscape Model", "Greatest Hits Model"]: | |
| return f"Unknown model type: {model_type}" | |
| # Determine which assets we need | |
| if model_type == "Landscape Model": | |
| gate_dict_path = "ckpts/landscape.pt" | |
| audio_projector_path = "ckpts/audio_projector_landscape.pth" | |
| else: | |
| gate_dict_path = "ckpts/greatest_hits.pt" | |
| audio_projector_path = "ckpts/audio_projector_gh.pth" | |
| clap_weights = "ckpts/CLAP_weights_2022.pth" | |
| # Check if assets exist | |
| required_files = [gate_dict_path, audio_projector_path, clap_weights] | |
| missing_files = [f for f in required_files if not os.path.exists(f)] | |
| if missing_files: | |
| return self.download_assets() | |
| try: | |
| # Import necessary modules | |
| import sys | |
| import torch | |
| # Add CLAP module to the path | |
| clap_path = 'CLAP/msclap' | |
| if os.path.exists(clap_path): | |
| sys.path.append(clap_path) | |
| # Load models from our custom pipeline | |
| try: | |
| from unet2d_custom import UNet2DConditionModel | |
| from pipeline_stable_diffusion_custom import StableDiffusionPipeline | |
| from ldm.modules.encoders.audio_projector_res import Adapter | |
| # Check if CLAP module exists | |
| clap_wrapper_exists = False | |
| try: | |
| from CLAPWrapper import CLAPWrapper | |
| clap_wrapper_exists = True | |
| except ImportError: | |
| # If CLAPWrapper doesn't exist, create a dummy directory and a basic implementation | |
| os.makedirs("CLAP/msclap", exist_ok=True) | |
| with open("CLAP/msclap/CLAPWrapper.py", "w") as f: | |
| f.write(""" | |
| class CLAPWrapper: | |
| def __init__(self, weights_path, use_cuda=True): | |
| import torch | |
| self.device = "cuda" if use_cuda and torch.cuda.is_available() else "cpu" | |
| print(f"Initialized CLAPWrapper on {self.device} (dummy implementation)") | |
| def get_audio_embeddings(self, audio_paths, resample=44100): | |
| import torch | |
| import numpy as np | |
| # Return random embeddings for now | |
| return torch.randn(1, 1024).to(self.device), None | |
| """) | |
| # Try importing it now | |
| sys.path.append("CLAP/msclap") | |
| from CLAPWrapper import CLAPWrapper | |
| clap_wrapper_exists = True | |
| if not os.path.exists("ldm/modules/encoders/audio_projector_res.py"): | |
| # Create the necessary directory structure and a basic implementation | |
| os.makedirs("ldm/modules/encoders", exist_ok=True) | |
| with open("ldm/modules/encoders/audio_projector_res.py", "w") as f: | |
| f.write(""" | |
| import torch | |
| import torch.nn as nn | |
| class Adapter(nn.Module): | |
| def __init__(self, audio_token_count=77, transformer_layer_count=4): | |
| super().__init__() | |
| import torch.nn as nn | |
| self.audio_token_count = audio_token_count | |
| self.transformer_layer_count = transformer_layer_count | |
| self.proj = nn.Linear(1024, 768 * audio_token_count) | |
| def forward(self, x): | |
| # Simple implementation for now | |
| batch_size = x.shape[0] | |
| x = self.proj(x) | |
| x = x.reshape(batch_size, self.audio_token_count, 768) | |
| return x | |
| """) | |
| # Import it | |
| from ldm.modules.encoders.audio_projector_res import Adapter | |
| # Now try to load the models | |
| model_id = "CompVis/stable-diffusion-v1-4" | |
| # Try loading UNet | |
| try: | |
| self.unet = UNet2DConditionModel.from_pretrained( | |
| model_id, | |
| subfolder="unet", | |
| use_adapter_list=[False, True, True], | |
| low_cpu_mem_usage=True | |
| ).to(self.device) | |
| # Try loading the pipeline | |
| self.pipeline = StableDiffusionPipeline.from_pretrained( | |
| model_id, | |
| torch_dtype=torch.float16 if self.device == "cuda" else torch.float32 | |
| ).to(self.device) | |
| # Load gate dictionary | |
| try: | |
| gate_dict = torch.load(gate_dict_path, map_location=self.device) | |
| for name, param in self.unet.named_parameters(): | |
| if "adapter" in name: | |
| param.data = gate_dict[name].to(self.device) | |
| except Exception as e: | |
| print(f"Error loading gate dictionary: {e}") | |
| # Set UNet in pipeline | |
| self.pipeline.unet = self.unet | |
| # Load CLAP encoder and audio projector | |
| try: | |
| self.audio_encoder = CLAPWrapper(clap_weights, use_cuda=(self.device=="cuda")) | |
| self.audio_projector = Adapter(audio_token_count=77, transformer_layer_count=4).to(self.device) | |
| self.audio_projector.load_state_dict(torch.load(audio_projector_path, map_location=self.device)) | |
| self.audio_projector.eval() | |
| except Exception as e: | |
| print(f"Error loading audio components: {e}") | |
| self.model_loaded = True | |
| self.model_type = model_type | |
| return f"{model_type} loaded successfully" | |
| except Exception as e: | |
| traceback.print_exc() | |
| # Try using a simplified approach with direct file access | |
| return f"Simplified model check - files exist but full loading failed: {str(e)}" | |
| except Exception as e: | |
| traceback.print_exc() | |
| return f"Error importing custom pipeline modules: {str(e)}" | |
| except Exception as e: | |
| traceback.print_exc() | |
| return f"Error loading model: {str(e)}" | |
| def generate(self, text_prompt, audio_path=None, cfg_scale=7.5, steps=50): | |
| """Generate an image using SonicDiffusion with the specified inputs""" | |
| if not self.model_loaded: | |
| return "Error: Model not loaded. Please click 'Load Model' first." | |
| if not audio_path: | |
| return "Error: Audio file is required" | |
| if not os.path.exists(audio_path): | |
| return f"Error: Audio file {audio_path} does not exist" | |
| try: | |
| with torch.no_grad(): | |
| # Process audio input | |
| audio_emb, _ = self.audio_encoder.get_audio_embeddings([audio_path], resample=self.sr) | |
| audio_proj = self.audio_projector(audio_emb.unsqueeze(1)) | |
| # Create unconditional embedding | |
| audio_emb = torch.zeros(1, 1024).to(self.device) | |
| audio_uc = self.audio_projector(audio_emb.unsqueeze(1)) | |
| # Combine for context | |
| audio_context = torch.cat([audio_uc, audio_proj]).to(self.device) | |
| # Generate image | |
| print(f"Generating image with prompt: '{text_prompt}', CFG: {cfg_scale}, Steps: {steps}") | |
| image = self.pipeline( | |
| prompt=text_prompt, | |
| audio_context=audio_context, | |
| guidance_scale=cfg_scale, | |
| num_inference_steps=steps | |
| ) | |
| # Save a copy of the generated image | |
| os.makedirs("outputs", exist_ok=True) | |
| from datetime import datetime | |
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | |
| output_path = f"outputs/generated_{timestamp}.png" | |
| image.images[0].save(output_path) | |
| print(f"Image saved to {output_path}") | |
| return image.images[0] | |
| except Exception as e: | |
| traceback.print_exc() | |
| # Create a simple error image | |
| error_img = Image.new('RGB', (512, 512), color=(255, 255, 255)) | |
| import PIL.ImageDraw | |
| draw = PIL.ImageDraw.Draw(error_img) | |
| draw.text((10, 250), f"Error: {str(e)}", fill=(0, 0, 0)) | |
| return error_img | |