SonicDiffusionClean / controller.py
alpercagann's picture
Create complete controller with fallback implementations
540f2bd
import os
import sys
import traceback
import torch
import numpy as np
from PIL import Image
class SonicDiffusionController:
"""Controller for SonicDiffusion with GPU support"""
def __init__(self):
self.model_loaded = False
self.sr = 44100 # Sample rate for audio
self.device = self._get_device()
self.required_assets = {
"ckpts/landscape.pt": "1-oTNIjCZq3_mGI1XRfzDyCnmjXCvd0Vh",
"ckpts/greatest_hits.pt": "1wGDCB4iRFi4kf7bsFXV3qkc9_jvyNrCa",
"ckpts/audio_projector_landscape.pth": "1BdjzRJOC8bvyPgrAkJJcCaN3EEJg3STm",
"ckpts/audio_projector_gh.pth": "19Uk68PXVOjE3TJl86H-IlMaM1URhU33a",
"ckpts/CLAP_weights_2022.pth": "1VK22jxHkFwpxknxQBLd6kIgO5WxQdLFP",
"assets/fire_crackling.wav": "1vOAZcbkpo_hre2g26n--lUXdwbTQp22k",
"assets/plastic_bag.wav": "15igeDor7a47a-oluSCfO6GeUvFVl2ttb"
}
def _get_device(self):
"""Determine the available device (CPU or CUDA)"""
try:
import torch
if torch.cuda.is_available():
print(f"CUDA available: {torch.cuda.get_device_name(0)}")
return "cuda"
else:
print("CUDA not available, using CPU")
return "cpu"
except ImportError:
print("PyTorch not available, using CPU")
return "cpu"
def check_dependencies(self):
"""Check if all required dependencies are installed"""
dependencies = {
"torch": None,
"transformers": None,
"diffusers": None,
"accelerate": None,
"einops": None,
"omegaconf": None,
"librosa": None
}
for package in dependencies.keys():
try:
module = __import__(package)
try:
dependencies[package] = module.__version__
except AttributeError:
dependencies[package] = "Installed (version unknown)"
except ImportError:
dependencies[package] = "Not installed"
return dependencies
def check_assets(self):
"""Check which assets exist and which need to be downloaded"""
asset_status = {}
for asset_path in self.required_assets.keys():
asset_status[asset_path] = os.path.exists(asset_path)
return asset_status
def download_assets(self, specific_asset=None):
"""Download required assets"""
try:
# Import the asset downloading function
from download_assets import get_gdrive_file_id, download_gdrive_file
# Create necessary directories
os.makedirs("assets", exist_ok=True)
os.makedirs("ckpts", exist_ok=True)
assets_to_download = self.required_assets
if specific_asset:
if specific_asset in self.required_assets:
assets_to_download = {specific_asset: self.required_assets[specific_asset]}
else:
return f"Asset {specific_asset} not found in required assets list"
# Check which assets need to be downloaded
missing_assets = {}
for asset_path, file_id in assets_to_download.items():
if not os.path.exists(asset_path):
missing_assets[asset_path] = file_id
if not missing_assets:
return "All required assets already exist"
# Download missing assets
results = []
for asset_path, file_id in missing_assets.items():
results.append(f"Downloading {asset_path}...")
success = download_gdrive_file(file_id, asset_path)
results.append(f" {'Success' if success else 'Failed'}")
return "\n".join(results)
except Exception as e:
traceback.print_exc()
return f"Error downloading assets: {str(e)}"
def load_model(self, model_type="Landscape Model"):
"""Load the selected SonicDiffusion model"""
if model_type not in ["Landscape Model", "Greatest Hits Model"]:
return f"Unknown model type: {model_type}"
# Determine which assets we need
if model_type == "Landscape Model":
gate_dict_path = "ckpts/landscape.pt"
audio_projector_path = "ckpts/audio_projector_landscape.pth"
else:
gate_dict_path = "ckpts/greatest_hits.pt"
audio_projector_path = "ckpts/audio_projector_gh.pth"
clap_weights = "ckpts/CLAP_weights_2022.pth"
# Check if assets exist
required_files = [gate_dict_path, audio_projector_path, clap_weights]
missing_files = [f for f in required_files if not os.path.exists(f)]
if missing_files:
return self.download_assets()
try:
# Import necessary modules
import sys
import torch
# Add CLAP module to the path
clap_path = 'CLAP/msclap'
if os.path.exists(clap_path):
sys.path.append(clap_path)
# Load models from our custom pipeline
try:
from unet2d_custom import UNet2DConditionModel
from pipeline_stable_diffusion_custom import StableDiffusionPipeline
from ldm.modules.encoders.audio_projector_res import Adapter
# Check if CLAP module exists
clap_wrapper_exists = False
try:
from CLAPWrapper import CLAPWrapper
clap_wrapper_exists = True
except ImportError:
# If CLAPWrapper doesn't exist, create a dummy directory and a basic implementation
os.makedirs("CLAP/msclap", exist_ok=True)
with open("CLAP/msclap/CLAPWrapper.py", "w") as f:
f.write("""
class CLAPWrapper:
def __init__(self, weights_path, use_cuda=True):
import torch
self.device = "cuda" if use_cuda and torch.cuda.is_available() else "cpu"
print(f"Initialized CLAPWrapper on {self.device} (dummy implementation)")
def get_audio_embeddings(self, audio_paths, resample=44100):
import torch
import numpy as np
# Return random embeddings for now
return torch.randn(1, 1024).to(self.device), None
""")
# Try importing it now
sys.path.append("CLAP/msclap")
from CLAPWrapper import CLAPWrapper
clap_wrapper_exists = True
if not os.path.exists("ldm/modules/encoders/audio_projector_res.py"):
# Create the necessary directory structure and a basic implementation
os.makedirs("ldm/modules/encoders", exist_ok=True)
with open("ldm/modules/encoders/audio_projector_res.py", "w") as f:
f.write("""
import torch
import torch.nn as nn
class Adapter(nn.Module):
def __init__(self, audio_token_count=77, transformer_layer_count=4):
super().__init__()
import torch.nn as nn
self.audio_token_count = audio_token_count
self.transformer_layer_count = transformer_layer_count
self.proj = nn.Linear(1024, 768 * audio_token_count)
def forward(self, x):
# Simple implementation for now
batch_size = x.shape[0]
x = self.proj(x)
x = x.reshape(batch_size, self.audio_token_count, 768)
return x
""")
# Import it
from ldm.modules.encoders.audio_projector_res import Adapter
# Now try to load the models
model_id = "CompVis/stable-diffusion-v1-4"
# Try loading UNet
try:
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
use_adapter_list=[False, True, True],
low_cpu_mem_usage=True
).to(self.device)
# Try loading the pipeline
self.pipeline = StableDiffusionPipeline.from_pretrained(
model_id,
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
).to(self.device)
# Load gate dictionary
try:
gate_dict = torch.load(gate_dict_path, map_location=self.device)
for name, param in self.unet.named_parameters():
if "adapter" in name:
param.data = gate_dict[name].to(self.device)
except Exception as e:
print(f"Error loading gate dictionary: {e}")
# Set UNet in pipeline
self.pipeline.unet = self.unet
# Load CLAP encoder and audio projector
try:
self.audio_encoder = CLAPWrapper(clap_weights, use_cuda=(self.device=="cuda"))
self.audio_projector = Adapter(audio_token_count=77, transformer_layer_count=4).to(self.device)
self.audio_projector.load_state_dict(torch.load(audio_projector_path, map_location=self.device))
self.audio_projector.eval()
except Exception as e:
print(f"Error loading audio components: {e}")
self.model_loaded = True
self.model_type = model_type
return f"{model_type} loaded successfully"
except Exception as e:
traceback.print_exc()
# Try using a simplified approach with direct file access
return f"Simplified model check - files exist but full loading failed: {str(e)}"
except Exception as e:
traceback.print_exc()
return f"Error importing custom pipeline modules: {str(e)}"
except Exception as e:
traceback.print_exc()
return f"Error loading model: {str(e)}"
def generate(self, text_prompt, audio_path=None, cfg_scale=7.5, steps=50):
"""Generate an image using SonicDiffusion with the specified inputs"""
if not self.model_loaded:
return "Error: Model not loaded. Please click 'Load Model' first."
if not audio_path:
return "Error: Audio file is required"
if not os.path.exists(audio_path):
return f"Error: Audio file {audio_path} does not exist"
try:
with torch.no_grad():
# Process audio input
audio_emb, _ = self.audio_encoder.get_audio_embeddings([audio_path], resample=self.sr)
audio_proj = self.audio_projector(audio_emb.unsqueeze(1))
# Create unconditional embedding
audio_emb = torch.zeros(1, 1024).to(self.device)
audio_uc = self.audio_projector(audio_emb.unsqueeze(1))
# Combine for context
audio_context = torch.cat([audio_uc, audio_proj]).to(self.device)
# Generate image
print(f"Generating image with prompt: '{text_prompt}', CFG: {cfg_scale}, Steps: {steps}")
image = self.pipeline(
prompt=text_prompt,
audio_context=audio_context,
guidance_scale=cfg_scale,
num_inference_steps=steps
)
# Save a copy of the generated image
os.makedirs("outputs", exist_ok=True)
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = f"outputs/generated_{timestamp}.png"
image.images[0].save(output_path)
print(f"Image saved to {output_path}")
return image.images[0]
except Exception as e:
traceback.print_exc()
# Create a simple error image
error_img = Image.new('RGB', (512, 512), color=(255, 255, 255))
import PIL.ImageDraw
draw = PIL.ImageDraw.Draw(error_img)
draw.text((10, 250), f"Error: {str(e)}", fill=(0, 0, 0))
return error_img