alpercagann commited on
Commit
8c0dbae
·
1 Parent(s): 32002e9

Create more complete SonicDiffusion controller

Browse files
Files changed (1) hide show
  1. controller.py +140 -70
controller.py CHANGED
@@ -1,14 +1,22 @@
1
  import os
2
  import sys
 
3
 
4
- class SimpleSonicDiffusionController:
5
- """A simplified version of the controller with minimal dependencies"""
6
 
7
  def __init__(self):
8
  self.model_loaded = False
9
- self.tokenizer_loaded = False
10
- self.pipe_loaded = False
11
  self.device = self._get_device()
 
 
 
 
 
 
 
 
 
12
 
13
  def _get_device(self):
14
  """Determine the available device (CPU or CUDA)"""
@@ -24,86 +32,148 @@ class SimpleSonicDiffusionController:
24
  print("PyTorch not available, using CPU")
25
  return "cpu"
26
 
27
- def load_model(self):
28
- """Load a simple model to verify libraries are working"""
29
- status_messages = []
 
 
 
 
 
 
 
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  try:
32
- # Test PyTorch
33
- import torch
34
- self.test_tensor = torch.rand(3, 3)
35
- status_messages.append("✓ PyTorch loaded successfully")
36
 
37
- # Try loading a simple tokenizer from transformers
38
- try:
39
- from transformers import AutoTokenizer
40
- self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
41
- self.tokenizer_loaded = True
42
- status_messages.append("✓ Transformers tokenizer loaded")
43
- except Exception as e:
44
- status_messages.append(f"✗ Transformers error: {str(e)}")
45
 
46
- # Try loading a simple pipeline from diffusers
47
- try:
48
- from diffusers import DiffusionPipeline
49
- # Just check if the class exists, don't actually load a model
50
- self.pipe_class = DiffusionPipeline
51
- self.pipe_loaded = True
52
- status_messages.append("✓ Diffusers available")
53
- except Exception as e:
54
- status_messages.append(f"✗ Diffusers error: {str(e)}")
 
 
 
55
 
56
- self.model_loaded = True
57
- return "\n".join(status_messages)
58
 
 
 
 
 
 
 
 
 
 
59
  except Exception as e:
60
- return f"Error loading model: {str(e)}"
 
61
 
62
- def generate(self, text_prompt, audio_path=None):
63
- """Generate text using available libraries"""
64
- if not self.model_loaded:
65
- return "Error: Model not loaded. Please click 'Load Model' first."
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- results = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  try:
70
- # Use tokenizer if available
71
- if self.tokenizer_loaded:
72
- tokens = self.tokenizer(text_prompt, return_tensors="pt")
73
- token_count = len(tokens['input_ids'][0])
74
- results.append(f"Transformers: Tokenized into {token_count} tokens")
75
 
76
- # Check diffusers status
77
- if self.pipe_loaded:
78
- results.append("Diffusers is available for pipeline creation")
79
- else:
80
- results.append("Diffusers is not available")
 
 
81
 
82
- return "\n".join(results)
 
83
 
 
 
 
 
 
 
 
 
 
84
  except Exception as e:
85
- return f"Error during generation: {str(e)}"
86
-
87
- def get_asset_status(self):
88
- """Check if required directories and files exist"""
89
- asset_status = {}
90
-
91
- # Check directories
92
- for dir_name in ["assets", "ckpts", "outputs"]:
93
- asset_status[dir_name] = "✓" if os.path.exists(dir_name) else "✗"
94
 
95
- # Check library availability
96
- asset_status["PyTorch"] = "✓" if self._check_import("torch") else "✗"
97
- asset_status["Transformers"] = "✓" if self._check_import("transformers") else "✗"
98
- asset_status["Diffusers"] = "✓" if self._check_import("diffusers") else "✗"
99
- asset_status["Accelerate"] = "✓" if self._check_import("accelerate") else "✗"
100
 
101
- return asset_status
102
-
103
- def _check_import(self, module_name):
104
- """Check if a module can be imported"""
105
- try:
106
- __import__(module_name)
107
- return True
108
- except ImportError:
109
- return False
 
1
  import os
2
  import sys
3
+ import traceback
4
 
5
+ class SonicDiffusionController:
6
+ """Controller for SonicDiffusion with asset downloading support"""
7
 
8
  def __init__(self):
9
  self.model_loaded = False
 
 
10
  self.device = self._get_device()
11
+ self.required_assets = {
12
+ "ckpts/landscape.pt": "1-oTNIjCZq3_mGI1XRfzDyCnmjXCvd0Vh",
13
+ "ckpts/greatest_hits.pt": "1wGDCB4iRFi4kf7bsFXV3qkc9_jvyNrCa",
14
+ "ckpts/audio_projector_landscape.pth": "1BdjzRJOC8bvyPgrAkJJcCaN3EEJg3STm",
15
+ "ckpts/audio_projector_gh.pth": "19Uk68PXVOjE3TJl86H-IlMaM1URhU33a",
16
+ "ckpts/CLAP_weights_2022.pth": "1VK22jxHkFwpxknxQBLd6kIgO5WxQdLFP",
17
+ "assets/fire_crackling.wav": "1vOAZcbkpo_hre2g26n--lUXdwbTQp22k",
18
+ "assets/plastic_bag.wav": "15igeDor7a47a-oluSCfO6GeUvFVl2ttb"
19
+ }
20
 
21
  def _get_device(self):
22
  """Determine the available device (CPU or CUDA)"""
 
32
  print("PyTorch not available, using CPU")
33
  return "cpu"
34
 
35
+ def check_dependencies(self):
36
+ """Check if all required dependencies are installed"""
37
+ dependencies = {
38
+ "torch": None,
39
+ "transformers": None,
40
+ "diffusers": None,
41
+ "accelerate": None,
42
+ "einops": None,
43
+ "omegaconf": None,
44
+ "librosa": None
45
+ }
46
 
47
+ for package in dependencies.keys():
48
+ try:
49
+ module = __import__(package)
50
+ try:
51
+ dependencies[package] = module.__version__
52
+ except AttributeError:
53
+ dependencies[package] = "Installed (version unknown)"
54
+ except ImportError:
55
+ dependencies[package] = "Not installed"
56
+
57
+ return dependencies
58
+
59
+ def check_assets(self):
60
+ """Check which assets exist and which need to be downloaded"""
61
+ asset_status = {}
62
+
63
+ for asset_path in self.required_assets.keys():
64
+ asset_status[asset_path] = os.path.exists(asset_path)
65
+
66
+ return asset_status
67
+
68
+ def download_assets(self, specific_asset=None):
69
+ """Download required assets"""
70
  try:
71
+ # Import the asset downloading function
72
+ from download_assets import get_gdrive_file_id, download_gdrive_file
 
 
73
 
74
+ # Create necessary directories
75
+ os.makedirs("assets", exist_ok=True)
76
+ os.makedirs("ckpts", exist_ok=True)
 
 
 
 
 
77
 
78
+ assets_to_download = self.required_assets
79
+ if specific_asset:
80
+ if specific_asset in self.required_assets:
81
+ assets_to_download = {specific_asset: self.required_assets[specific_asset]}
82
+ else:
83
+ return f"Asset {specific_asset} not found in required assets list"
84
+
85
+ # Check which assets need to be downloaded
86
+ missing_assets = {}
87
+ for asset_path, file_id in assets_to_download.items():
88
+ if not os.path.exists(asset_path):
89
+ missing_assets[asset_path] = file_id
90
 
91
+ if not missing_assets:
92
+ return "All required assets already exist"
93
 
94
+ # Download missing assets
95
+ results = []
96
+ for asset_path, file_id in missing_assets.items():
97
+ results.append(f"Downloading {asset_path}...")
98
+ success = download_gdrive_file(file_id, asset_path)
99
+ results.append(f" {'Success' if success else 'Failed'}")
100
+
101
+ return "\n".join(results)
102
+
103
  except Exception as e:
104
+ traceback.print_exc()
105
+ return f"Error downloading assets: {str(e)}"
106
 
107
+ def load_model(self, model_type="Landscape Model"):
108
+ """Load the selected SonicDiffusion model"""
109
+ if model_type not in ["Landscape Model", "Greatest Hits Model"]:
110
+ return f"Unknown model type: {model_type}"
111
+
112
+ # Determine which assets we need
113
+ if model_type == "Landscape Model":
114
+ gate_dict_path = "ckpts/landscape.pt"
115
+ audio_projector_path = "ckpts/audio_projector_landscape.pth"
116
+ else:
117
+ gate_dict_path = "ckpts/greatest_hits.pt"
118
+ audio_projector_path = "ckpts/audio_projector_gh.pth"
119
+
120
+ clap_path = "CLAP/msclap"
121
+ clap_weights = "ckpts/CLAP_weights_2022.pth"
122
 
123
+ # Check if assets exist
124
+ required_files = [gate_dict_path, audio_projector_path, clap_weights]
125
+ missing_files = [f for f in required_files if not os.path.exists(f)]
126
+
127
+ if missing_files:
128
+ # Download missing files
129
+ for file_path in missing_files:
130
+ if file_path in self.required_assets:
131
+ try:
132
+ from download_assets import download_gdrive_file
133
+ download_gdrive_file(self.required_assets[file_path], file_path)
134
+ except Exception as e:
135
+ return f"Failed to download {file_path}: {str(e)}"
136
+ else:
137
+ return f"Missing required file {file_path} and no download source available"
138
 
139
  try:
140
+ # Simple test of loading the model components
141
+ import torch
 
 
 
142
 
143
+ # Load a small test tensor to verify PyTorch works
144
+ self.test_tensor = torch.rand(3, 3).to(self.device)
145
+
146
+ # Just check if we can access the file
147
+ with open(gate_dict_path, 'rb') as f:
148
+ # Just read a small part to verify the file exists and is readable
149
+ f.read(10)
150
 
151
+ with open(audio_projector_path, 'rb') as f:
152
+ f.read(10)
153
 
154
+ with open(clap_weights, 'rb') as f:
155
+ f.read(10)
156
+
157
+ # For now, just mark as loaded - we'll implement real loading later
158
+ self.model_loaded = True
159
+ self.model_type = model_type
160
+
161
+ return f"{model_type} files verified and accessible"
162
+
163
  except Exception as e:
164
+ traceback.print_exc()
165
+ return f"Error loading model: {str(e)}"
166
+
167
+ def generate(self, text_prompt, audio_path=None, cfg_scale=7.5, steps=50):
168
+ """Generate an image using SonicDiffusion with the specified inputs"""
169
+ if not self.model_loaded:
170
+ return "Error: Model not loaded. Please click 'Load Model' first."
 
 
171
 
172
+ if not audio_path:
173
+ return "Error: Audio file is required"
174
+
175
+ if not os.path.exists(audio_path):
176
+ return f"Error: Audio file {audio_path} does not exist"
177
 
178
+ # Return info about what would be generated
179
+ return f"Would generate image with:\nModel: {self.model_type}\nPrompt: {text_prompt}\nAudio: {audio_path}\nCFG Scale: {cfg_scale}\nSteps: {steps}\n\nFull implementation coming soon!"