alpercagann commited on
Commit
ba78f27
·
1 Parent(s): 3e50399

Update app.py to use the new SonicDiffusion controller

Browse files
Files changed (1) hide show
  1. app.py +120 -35
app.py CHANGED
@@ -15,7 +15,7 @@ os.makedirs("outputs", exist_ok=True)
15
  # Import required packages
16
  import gradio as gr
17
 
18
- # Try importing key packages
19
  packages = {
20
  "torch": None,
21
  "transformers": None,
@@ -26,54 +26,139 @@ packages = {
26
  for package in packages.keys():
27
  try:
28
  module = __import__(package)
29
- packages[package] = module.__version__
30
- print(f"{package} version: {module.__version__}")
 
 
 
 
31
  except ImportError as e:
32
  print(f"{package} import error: {e}")
33
 
34
  # Import our controller
35
- from controller import SimpleSonicDiffusionController
36
 
37
  # Initialize controller
38
- controller = SimpleSonicDiffusionController()
39
 
40
  # Create the Gradio interface
41
- with gr.Blocks(title="SonicDiffusion - Progressive Setup") as demo:
42
- gr.Markdown("# SonicDiffusion - Building Up")
 
43
 
44
- status_output = gr.Textbox(label="Status", value="System initialized. Click 'Check System' to verify setup.")
45
 
46
- with gr.Tab("System Check"):
47
- check_btn = gr.Button("Check System")
48
-
49
- def check_system():
50
- status = ["Package Status:"]
51
-
52
- # Check package availability
53
- for package, version in packages.items():
54
- status.append(f"{package}: {version if version else 'Not Available'}")
 
 
 
55
 
56
- # Check directories
57
- status.append("\nDirectory Status:")
58
- asset_status = controller.get_asset_status()
59
- for dir_name, dir_status in asset_status.items():
60
- if dir_name in ["assets", "ckpts", "outputs"]:
61
- status.append(f"Directory '{dir_name}': {dir_status}")
62
 
63
- return "\n".join(status)
64
-
65
- check_btn.click(fn=check_system, outputs=status_output)
66
-
67
- with gr.Tab("Model"):
68
- load_model_btn = gr.Button("Load Model")
69
- load_model_btn.click(fn=controller.load_model, outputs=status_output)
 
 
 
70
 
71
- with gr.Tab("Generate"):
72
- text_input = gr.Textbox(label="Enter a prompt", value="Hello, SonicDiffusion!")
73
- gen_btn = gr.Button("Process Text")
74
- gen_output = gr.Textbox(label="Output")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- gen_btn.click(fn=controller.generate, inputs=[text_input], outputs=gen_output)
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  demo.launch()
 
15
  # Import required packages
16
  import gradio as gr
17
 
18
+ # Try importing key packages and print versions
19
  packages = {
20
  "torch": None,
21
  "transformers": None,
 
26
  for package in packages.keys():
27
  try:
28
  module = __import__(package)
29
+ try:
30
+ packages[package] = module.__version__
31
+ print(f"{package} version: {module.__version__}")
32
+ except AttributeError:
33
+ packages[package] = "Installed (version unknown)"
34
+ print(f"{package} installed (version unknown)")
35
  except ImportError as e:
36
  print(f"{package} import error: {e}")
37
 
38
  # Import our controller
39
+ from controller import SonicDiffusionController
40
 
41
  # Initialize controller
42
+ controller = SonicDiffusionController()
43
 
44
  # Create the Gradio interface
45
+ with gr.Blocks(title="SonicDiffusion") as demo:
46
+ gr.Markdown("# SonicDiffusion - Audio-to-Image Generation")
47
+ gr.Markdown("Generate images conditioned on audio inputs using Stable Diffusion")
48
 
49
+ status_output = gr.Textbox(label="Status", value="System initialized. Check dependencies and download assets first.")
50
 
51
+ with gr.Tab("1. Setup"):
52
+ with gr.Row():
53
+ with gr.Column():
54
+ check_deps_btn = gr.Button("Check Dependencies")
55
+
56
+ def format_deps(deps):
57
+ return "\n".join([f"{pkg}: {vers}" for pkg, vers in deps.items()])
58
+
59
+ check_deps_btn.click(
60
+ fn=lambda: format_deps(controller.check_dependencies()),
61
+ outputs=status_output
62
+ )
63
 
64
+ with gr.Column():
65
+ check_assets_btn = gr.Button("Check Assets")
 
 
 
 
66
 
67
+ def format_assets(assets):
68
+ return "\n".join([f"{path}: {'✓' if exists else '✗'}" for path, exists in assets.items()])
69
+
70
+ check_assets_btn.click(
71
+ fn=lambda: format_assets(controller.check_assets()),
72
+ outputs=status_output
73
+ )
74
+
75
+ download_assets_btn = gr.Button("Download Required Assets")
76
+ download_assets_btn.click(fn=controller.download_assets, outputs=status_output)
77
 
78
+ with gr.Tab("2. Generate"):
79
+ with gr.Row():
80
+ with gr.Column():
81
+ model_dropdown = gr.Dropdown(
82
+ label="Select Model",
83
+ choices=["Landscape Model", "Greatest Hits Model"],
84
+ value="Landscape Model"
85
+ )
86
+
87
+ load_model_btn = gr.Button("Load Selected Model")
88
+ load_model_btn.click(
89
+ fn=controller.load_model,
90
+ inputs=[model_dropdown],
91
+ outputs=status_output
92
+ )
93
+
94
+ prompt_input = gr.Textbox(
95
+ label="Prompt",
96
+ placeholder="Enter a descriptive prompt...",
97
+ value="a high quality photograph of a fantasy landscape"
98
+ )
99
+
100
+ audio_input = gr.Audio(
101
+ label="Upload Audio",
102
+ type="filepath",
103
+ sources=["upload"]
104
+ )
105
+
106
+ with gr.Row():
107
+ cfg_scale = gr.Slider(
108
+ label="CFG Scale",
109
+ minimum=1.0,
110
+ maximum=20.0,
111
+ value=7.5,
112
+ step=0.5
113
+ )
114
+
115
+ steps = gr.Slider(
116
+ label="Steps",
117
+ minimum=20,
118
+ maximum=100,
119
+ value=50,
120
+ step=5
121
+ )
122
+
123
+ generate_btn = gr.Button("Generate", variant="primary")
124
+
125
+ with gr.Column():
126
+ output_image = gr.Image(label="Generated Image")
127
+
128
+ generate_btn.click(
129
+ fn=controller.generate,
130
+ inputs=[prompt_input, audio_input, cfg_scale, steps],
131
+ outputs=output_image
132
+ )
133
 
134
+ with gr.Row():
135
+ gr.Markdown("### Example Audio Files")
136
+ examples = [
137
+ ["a serene landscape with mountains and a lake", "assets/fire_crackling.wav", 7.5, 50],
138
+ ["a mysterious forest at night with glowing elements", "assets/plastic_bag.wav", 7.5, 50]
139
+ ]
140
+ gr.Examples(
141
+ examples=examples,
142
+ inputs=[prompt_input, audio_input, cfg_scale, steps],
143
+ outputs=output_image,
144
+ fn=controller.generate
145
+ )
146
 
147
  if __name__ == "__main__":
148
+ # Attempt to download example audio files if they don't exist
149
+ if not os.path.exists("assets/fire_crackling.wav") or not os.path.exists("assets/plastic_bag.wav"):
150
+ try:
151
+ from download_assets import download_gdrive_file
152
+
153
+ assets = {
154
+ "assets/fire_crackling.wav": "1vOAZcbkpo_hre2g26n--lUXdwbTQp22k",
155
+ "assets/plastic_bag.wav": "15igeDor7a47a-oluSCfO6GeUvFVl2ttb"
156
+ }
157
+
158
+ for path, file_id in assets.items():
159
+ if not os.path.exists(path):
160
+ download_gdrive_file(file_id, path)
161
+ except Exception as e:
162
+ print(f"Error downloading example audio files: {e}")
163
+
164
  demo.launch()