alpercagann commited on
Commit
c9ef435
·
1 Parent(s): 517b2f4

Add simplified controller and update app

Browse files
Files changed (2) hide show
  1. app.py +42 -56
  2. controller.py +45 -129
app.py CHANGED
@@ -7,14 +7,15 @@ print(f"Python version: {sys.version}")
7
  print(f"Working directory: {os.getcwd()}")
8
  print(f"Directory contents: {os.listdir('.')}")
9
 
10
- # Import gradio first - this is our most essential dependency
11
- import gradio as gr
 
 
12
 
13
- # Try importing other packages (but don't fail if they're missing)
14
- torch_available = False
15
- transformers_available = False
16
- diffusers_available = False
17
 
 
18
  try:
19
  import torch
20
  print(f"PyTorch version: {torch.__version__}")
@@ -24,63 +25,48 @@ try:
24
  torch_available = True
25
  except ImportError as e:
26
  print(f"PyTorch import error: {e}")
 
27
 
28
- try:
29
- import transformers
30
- print(f"Transformers version: {transformers.__version__}")
31
- transformers_available = True
32
- except ImportError as e:
33
- print(f"Transformers import error: {e}")
34
 
35
- try:
36
- import diffusers
37
- print(f"Diffusers version: {diffusers.__version__}")
38
- diffusers_available = True
39
- except ImportError as e:
40
- print(f"Diffusers import error: {e}")
41
 
42
- # Simple demo interface
43
- def hello(name):
44
- if not name:
45
- name = "World"
46
 
47
- status = []
48
- if torch_available:
49
- status.append("PyTorch ✓")
50
- else:
51
- status.append("PyTorch ✗")
52
 
53
- if transformers_available:
54
- status.append("Transformers ")
55
- else:
56
- status.append("Transformers ✗")
57
 
58
- if diffusers_available:
59
- status.append("Diffusers ✓")
60
- else:
61
- status.append("Diffusers ✗")
 
 
 
 
 
 
 
 
 
 
62
 
63
- return f"Hello, {name}!\n\nPackage Status:\n" + "\n".join(status)
64
-
65
- # Create the Gradio interface
66
- demo = gr.Interface(
67
- fn=hello,
68
- inputs="text",
69
- outputs="text",
70
- title="SonicDiffusion - Setup Status",
71
- description="This app shows which packages are successfully installed."
72
- )
73
 
74
  if __name__ == "__main__":
75
- # Try installing packages at runtime if they're not available
76
- if not torch_available:
77
- print("Attempting to install PyTorch...")
78
- try:
79
- import subprocess
80
- subprocess.check_call([sys.executable, "-m", "pip", "install", "torch==2.0.1"])
81
- print("PyTorch installed successfully!")
82
- except Exception as e:
83
- print(f"Error installing PyTorch: {e}")
84
-
85
- # Launch the demo
86
  demo.launch()
 
7
  print(f"Working directory: {os.getcwd()}")
8
  print(f"Directory contents: {os.listdir('.')}")
9
 
10
+ # Create necessary directories
11
+ os.makedirs("assets", exist_ok=True)
12
+ os.makedirs("ckpts", exist_ok=True)
13
+ os.makedirs("outputs", exist_ok=True)
14
 
15
+ # Import required packages
16
+ import gradio as gr
 
 
17
 
18
+ # Try importing torch
19
  try:
20
  import torch
21
  print(f"PyTorch version: {torch.__version__}")
 
25
  torch_available = True
26
  except ImportError as e:
27
  print(f"PyTorch import error: {e}")
28
+ torch_available = False
29
 
30
+ # Import our controller
31
+ from controller import SimpleSonicDiffusionController
 
 
 
 
32
 
33
+ # Initialize controller
34
+ controller = SimpleSonicDiffusionController()
 
 
 
 
35
 
36
+ # Create the Gradio interface
37
+ with gr.Blocks(title="SonicDiffusion - Progressive Setup") as demo:
38
+ gr.Markdown("# SonicDiffusion - Simplified Version")
 
39
 
40
+ status_output = gr.Textbox(label="Status", value="System initialized. Click 'Check System' to verify setup.")
 
 
 
 
41
 
42
+ with gr.Tab("System Check"):
43
+ check_btn = gr.Button("Check System")
 
 
44
 
45
+ def check_system():
46
+ status = []
47
+
48
+ # Check PyTorch
49
+ status.append(f"PyTorch: {'Available' if torch_available else 'Not Available'}")
50
+
51
+ # Check directories
52
+ asset_status = controller.get_asset_status()
53
+ for dir_name, dir_status in asset_status.items():
54
+ status.append(f"Directory '{dir_name}': {dir_status}")
55
+
56
+ return "\n".join(status)
57
+
58
+ check_btn.click(fn=check_system, outputs=status_output)
59
 
60
+ with gr.Tab("Model"):
61
+ load_model_btn = gr.Button("Load Model")
62
+ load_model_btn.click(fn=controller.load_model, outputs=status_output)
63
+
64
+ with gr.Tab("Generate"):
65
+ text_input = gr.Textbox(label="Prompt")
66
+ gen_btn = gr.Button("Generate")
67
+ gen_output = gr.Textbox(label="Output")
68
+
69
+ gen_btn.click(fn=controller.generate, inputs=[text_input], outputs=gen_output)
70
 
71
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
72
  demo.launch()
controller.py CHANGED
@@ -1,140 +1,56 @@
1
  import os
2
- import torch
3
- from unet2d_custom import UNet2DConditionModel
4
- from pipeline_stable_diffusion_custom import StableDiffusionPipeline
5
- from ldm.modules.encoders.audio_projector_res import Adapter
6
 
7
- class SonicDiffusionController:
8
- def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
9
- self.device = device
10
- print(f"Using device: {self.device}")
11
- self.sr = 44100
12
  self.model_loaded = False
 
13
 
14
- def load_model(self,
15
- gate_dict_path="ckpts/landscape.pt",
16
- clap_path="CLAP/msclap",
17
- clap_weights="ckpts/CLAP_weights_2022.pth",
18
- adapter_ckpt_path="ckpts/audio_projector_landscape.pth"):
19
- """Load the model conditionally based on environment and availability"""
20
  try:
21
- # First, check if the required files exist
22
- for path in [gate_dict_path, adapter_ckpt_path]:
23
- if not os.path.exists(path):
24
- print(f"Warning: {path} not found, trying to download...")
25
- # You could add auto-download here
26
-
27
- print("Loading models - this may take a moment...")
28
-
29
- # Try to load the model with appropriate settings for your hardware
30
- model_id = "CompVis/stable-diffusion-v1-4"
31
- self.unet = UNet2DConditionModel.from_pretrained(
32
- model_id,
33
- subfolder="unet",
34
- use_adapter_list=[False, True, True],
35
- low_cpu_mem_usage=True,
36
- device_map="auto" # Let PyTorch decide the mapping
37
- )
38
-
39
- self.pipeline = StableDiffusionPipeline.from_pretrained(
40
- model_id,
41
- use_safetensors=True,
42
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
43
- )
44
-
45
- # Move models to the appropriate device
46
- self.unet = self.unet.to(self.device)
47
- self.pipeline = self.pipeline.to(self.device)
48
-
49
- # Load gate dictionary
50
- gate_dict = torch.load(gate_dict_path, map_location=self.device)
51
- for name, param in self.unet.named_parameters():
52
- if "adapter" in name:
53
- param.data = gate_dict[name].to(self.device)
54
-
55
- # Set pipeline's UNet
56
- self.pipeline.unet = self.unet
57
-
58
- # Import and load audio encoder
59
- import sys
60
- sys.path.append(clap_path)
61
- from CLAPWrapper import CLAPWrapper
62
-
63
- self.audio_encoder = CLAPWrapper(clap_weights, use_cuda=(self.device=="cuda"))
64
- self.audio_projector = Adapter(audio_token_count=77, transformer_layer_count=4).to(self.device)
65
- self.audio_projector.load_state_dict(torch.load(adapter_ckpt_path, map_location=self.device))
66
- self.audio_projector.eval()
67
-
68
- self.model_loaded = True
69
- print("Model loaded successfully!")
70
- return True
71
-
72
- except Exception as e:
73
- print(f"Failed to load model: {e}")
74
- import traceback
75
- traceback.print_exc()
76
- return False
77
 
78
- def generate(self, file=None, audio=None, prompt=None, cfg_scale=5, num_inference_steps=50):
79
- """Generate an image from audio input"""
80
- if not self.model_loaded:
81
- raise ValueError("Model not loaded. Call load_model() first.")
82
-
83
  try:
84
- with torch.no_grad():
85
- # Process audio input
86
- audio_emb, _ = self.audio_encoder.get_audio_embeddings([audio], resample=self.sr)
87
- audio_proj = self.audio_projector(audio_emb.unsqueeze(1))
88
-
89
- # Create unconditional embedding
90
- audio_emb = torch.zeros(1, 1024).to(self.device)
91
- audio_uc = self.audio_projector(audio_emb.unsqueeze(1))
92
-
93
- # Combine for context
94
- audio_context = torch.cat([audio_uc, audio_proj]).to(self.device)
95
-
96
- # Generate image
97
- image = self.pipeline(
98
- prompt=prompt,
99
- audio_context=audio_context,
100
- guidance_scale=cfg_scale,
101
- num_inference_steps=num_inference_steps
102
- )
103
-
104
- return image.images[0]
105
-
106
  except Exception as e:
107
- print(f"Error in generation: {e}")
108
- import traceback
109
- traceback.print_exc()
110
-
111
- # Return a blank error image
112
- from PIL import Image, ImageDraw
113
- img = Image.new('RGB', (512, 512), color=(255, 255, 255))
114
- d = ImageDraw.Draw(img)
115
- d.text((10, 250), f"Error: {str(e)}", fill=(0, 0, 0))
116
- return img
117
 
118
- def update_audio_model(self, audio_model_update):
119
- """Update audio model based on selection"""
 
 
 
120
  try:
121
- if audio_model_update == "Landscape Model":
122
- audio_projector_path = "ckpts/audio_projector_landscape.pth"
123
- gate_dict_path = "ckpts/landscape.pt"
124
- else:
125
- audio_projector_path = "ckpts/audio_projector_gh.pth"
126
- gate_dict_path = "ckpts/greatest_hits.pt"
127
-
128
- # Load gate dictionary and update parameters
129
- gate_dict = torch.load(gate_dict_path, map_location=self.device)
130
- for name, param in self.pipeline.unet.named_parameters():
131
- if "adapter" in name:
132
- param.data = gate_dict[name].to(self.device)
133
-
134
- # Load audio projector state
135
- self.audio_projector.load_state_dict(torch.load(audio_projector_path, map_location=self.device))
136
-
137
- return "Model updated successfully"
138
  except Exception as e:
139
- print(f"Error updating audio model: {e}")
140
- return f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import sys
 
 
 
3
 
4
+ class SimpleSonicDiffusionController:
5
+ """A simplified version of the controller with minimal dependencies"""
6
+
7
+ def __init__(self):
 
8
  self.model_loaded = False
9
+ self.device = self._get_device()
10
 
11
+ def _get_device(self):
12
+ """Determine the available device (CPU or CUDA)"""
 
 
 
 
13
  try:
14
+ import torch
15
+ if torch.cuda.is_available():
16
+ print(f"CUDA available: {torch.cuda.get_device_name(0)}")
17
+ return "cuda"
18
+ else:
19
+ print("CUDA not available, using CPU")
20
+ return "cpu"
21
+ except ImportError:
22
+ print("PyTorch not available, using CPU")
23
+ return "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ def load_model(self):
26
+ """Simulated model loading"""
 
 
 
27
  try:
28
+ import torch
29
+ # Just create a simple tensor to verify PyTorch is working
30
+ self.test_tensor = torch.rand(3, 3)
31
+ self.model_loaded = True
32
+ return "Model loading simulation successful!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  except Exception as e:
34
+ return f"Error loading model: {str(e)}"
 
 
 
 
 
 
 
 
 
35
 
36
+ def generate(self, text_prompt, audio_path=None):
37
+ """Simulated generation process"""
38
+ if not self.model_loaded:
39
+ return "Error: Model not loaded. Please click 'Load Model' first."
40
+
41
  try:
42
+ import torch
43
+ # Just a placeholder - we'll implement real generation later
44
+ return f"Generated output for prompt: '{text_prompt}'"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  except Exception as e:
46
+ return f"Error during generation: {str(e)}"
47
+
48
+ def get_asset_status(self):
49
+ """Check if required directories and files exist"""
50
+ asset_status = {}
51
+
52
+ # Check directories
53
+ for dir_name in ["assets", "ckpts", "outputs"]:
54
+ asset_status[dir_name] = "✓" if os.path.exists(dir_name) else "✗"
55
+
56
+ return asset_status