manoskary committed on
Commit
5be99b2
·
1 Parent(s): 59f9e42

Add initial project files including .gitignore, README, app.py, and requirements.txt

Browse files
Files changed (4) hide show
  1. .gitignore +7 -0
  2. README.md +37 -5
  3. app.py +481 -0
  4. requirements.txt +11 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+
4
+ checkpoints/
5
+ outputs/
6
+ .gradio/
7
+ .codex
README.md CHANGED
@@ -1,15 +1,47 @@
1
  ---
2
  title: Woosh DFlow
3
- emoji: 💻
4
  colorFrom: red
5
- colorTo: gray
6
  sdk: gradio
7
  sdk_version: 6.12.0
8
- python_version: '3.12'
9
  app_file: app.py
10
  pinned: false
11
  license: mit
12
- short_description: 'Woosh: Sound Effect Generative Model '
 
 
 
 
 
 
 
13
  ---
14
 
15
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: Woosh DFlow
3
+ emoji: 🔊
4
  colorFrom: red
5
+ colorTo: green
6
  sdk: gradio
7
  sdk_version: 6.12.0
8
+ python_version: '3.12.12'
9
  app_file: app.py
10
  pinned: false
11
  license: mit
12
+ fullWidth: true
13
+ startup_duration_timeout: 1h
14
+ short_description: 'Woosh-DFlow text-to-audio sound effect generation'
15
+ tags:
16
+ - audio-generation
17
+ - text-to-audio
18
+ - gradio
19
+ - zerogpu
20
  ---
21
 
22
+ # Woosh-DFlow
23
+
24
+ Text-to-audio sound effect generation with Sony AI's distilled Woosh-DFlow model.
25
+
26
+ The app downloads the official `Woosh-DFlow.zip` checkpoint from the
27
+ `SonyResearch/Woosh` v1.0.0 GitHub release when the Space starts, then loads the
28
+ model for ZeroGPU inference. The first build or cold start can take a while
29
+ because the checkpoint is about 1.2 GB.
30
+
31
+ ## Notes
32
+
33
+ - Inference is decorated with `@spaces.GPU`, so select ZeroGPU hardware in the
34
+ Space settings.
35
+ - The Woosh inference source is installed from the upstream GitHub repository at
36
+ a pinned commit.
37
+ - The upstream code is MIT/Apache-2.0. The released model weights are
38
+ CC-BY-NC, as stated by the upstream project.
39
+
40
+ ## Local Run
41
+
42
+ ```bash
43
+ python app.py
44
+ ```
45
+
46
+ Use `WOOSH_CHECKPOINT_DIR=/path/to/checkpoints/Woosh-DFlow` to point the app at
47
+ an existing checkpoint directory.
app.py ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Woosh-DFlow text-to-audio Space."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import logging
7
+ import os
8
+ import shutil
9
+ import threading
10
+ import time
11
+ import zipfile
12
+ from pathlib import Path
13
+ from typing import Callable
14
+
15
try:
    import spaces
except ImportError:
    # The `spaces` package is optional: without it the app still imports and
    # runs, with `spaces.GPU` reduced to a decorator that changes nothing.

    class _SpacesFallback:
        """Minimal stand-in exposing a no-op ``GPU`` decorator."""

        @staticmethod
        def GPU(*args, **kwargs):
            """Support both ``@spaces.GPU`` and ``@spaces.GPU(...)`` usage.

            Returns the decorated function itself when applied bare, otherwise
            a decorator that returns its argument unchanged.
            """

            def passthrough(fn: Callable):
                return fn

            # Bare usage means exactly one positional callable and no keywords.
            used_bare = len(args) == 1 and not kwargs and callable(args[0])
            return args[0] if used_bare else passthrough

    spaces = _SpacesFallback()
31
+
32
+ import gradio as gr
33
+ import requests
34
+ import torch
35
+
36
+ from woosh.components.base import LoadConfig
37
+ from woosh.inference.flowmap_sampler import sample_euler
38
+ from woosh.model.flowmap_from_pretrained import FlowMapFromPretrained
39
+
40
+
41
# Configure logging once at import time; app messages go through `log`.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("woosh_space")

# Directory containing this file; relative paths are resolved against it.
APP_DIR = Path(__file__).resolve().parent
# Name of the checkpoint directory and of the downloaded archive.
CHECKPOINT_NAME = "Woosh-DFlow"
DEFAULT_CHECKPOINT_URL = (
    "https://github.com/SonyResearch/Woosh/releases/download/v1.0.0/"
    "Woosh-DFlow.zip"
)
# The WOOSH_CHECKPOINT_URL environment variable overrides the default URL.
CHECKPOINT_URL = os.getenv("WOOSH_CHECKPOINT_URL", DEFAULT_CHECKPOINT_URL)
# Sample rate attached to every generated clip handed to the audio widgets.
SAMPLE_RATE = 48_000
# The initial noise passed to the sampler has shape
# (variants, LATENT_CHANNELS, LATENT_FRAMES).
LATENT_CHANNELS = 128
LATENT_FRAMES = 501
# Forwarded to `sample_euler` as `num_steps` and `renoise` respectively.
# NOTE(review): the schedule has one entry per step — confirm that is required.
GENERATION_STEPS = 4
RENOISE_SCHEDULE = [0.0, 0.5, 0.5, 0.3]
# Upper bound on takes per request; also the number of audio outputs in the UI.
MAX_VARIANTS = 2

# Lazily populated model cache; written only while `_model_lock` is held.
_model = None
_device = None
_model_lock = threading.Lock()
# Message of the exception raised during startup preload, if any.
_startup_error: str | None = None
62
+
63
+
64
+ def _resolve_app_path(value: str) -> Path:
65
+ path = Path(value).expanduser()
66
+ if path.is_absolute():
67
+ return path
68
+ return APP_DIR / path
69
+
70
+
71
# Checkpoint directory: WOOSH_CHECKPOINT_DIR when set, otherwise
# checkpoints/<CHECKPOINT_NAME>; relative values are anchored at APP_DIR.
# main() rebinds this from the --checkpoint command-line option.
CHECKPOINT_DIR = _resolve_app_path(
    os.getenv("WOOSH_CHECKPOINT_DIR", f"checkpoints/{CHECKPOINT_NAME}")
)
74
+
75
+
76
+ def _checkpoint_ready(path: Path) -> bool:
77
+ return path.exists() and path.is_dir() and any(path.iterdir())
78
+
79
+
80
def _download_file(url: str, destination: Path) -> None:
    """Stream *url* to *destination* through a sibling ``.partial`` file.

    The payload is written to ``<destination>.partial`` and renamed onto
    *destination* only after the whole body has been written. Progress is
    logged at most once every ten seconds, as a percentage when the server
    reports ``content-length`` and as megabytes otherwise.

    Raises whatever ``requests`` raises for connection errors or for a
    non-success HTTP status (via ``raise_for_status``).
    """
    destination.parent.mkdir(parents=True, exist_ok=True)
    partial_path = destination.with_suffix(destination.suffix + ".partial")

    log.info("Downloading %s to %s", url, destination)
    with requests.get(url, stream=True, timeout=(10, 120)) as response:
        response.raise_for_status()
        expected_bytes = int(response.headers.get("content-length", 0))
        received_bytes = 0
        last_report = time.perf_counter()

        with partial_path.open("wb") as sink:
            # filter(None, ...) drops empty chunks.
            chunks = response.iter_content(chunk_size=8 * 1024 * 1024)
            for chunk in filter(None, chunks):
                sink.write(chunk)
                received_bytes += len(chunk)

                now = time.perf_counter()
                if now - last_report <= 10:
                    continue
                if expected_bytes:
                    log.info(
                        "Checkpoint download %.1f%%",
                        received_bytes / expected_bytes * 100,
                    )
                else:
                    log.info(
                        "Checkpoint download %.1f MB",
                        received_bytes / 1024 / 1024,
                    )
                last_report = now

    partial_path.replace(destination)
111
+
112
+
113
def ensure_checkpoint(path: Path | None = None) -> Path:
    """Make sure the checkpoint directory exists and is populated.

    If *path* (default: ``CHECKPOINT_DIR``) is already a non-empty directory
    it is returned unchanged. Otherwise the release archive is downloaded
    (unless a previously downloaded copy exists), extracted under ``APP_DIR``
    and, if the extracted ``CHECKPOINT_NAME`` directory did not land at
    *path*, moved there.

    Args:
        path: Target checkpoint directory, or ``None`` for ``CHECKPOINT_DIR``.

    Returns:
        The directory that holds the checkpoint files.

    Raises:
        RuntimeError: No populated ``CHECKPOINT_NAME`` directory could be
            found after extraction.
        zipfile.BadZipFile: The archive is still invalid after one re-download.
    """
    if path is None:
        path = CHECKPOINT_DIR
    if _checkpoint_ready(path):
        return path

    archive_path = APP_DIR / "checkpoints" / ".downloads" / f"{CHECKPOINT_NAME}.zip"
    if not archive_path.exists():
        _download_file(CHECKPOINT_URL, archive_path)

    def _extract() -> None:
        with zipfile.ZipFile(archive_path) as archive:
            archive.extractall(APP_DIR)

    log.info("Extracting %s", archive_path)
    try:
        _extract()
    except zipfile.BadZipFile:
        # A cached archive may be corrupt or truncated: fetch it once more
        # and let a second failure propagate to the caller.
        log.warning("Checkpoint archive was invalid; downloading it again.")
        archive_path.unlink(missing_ok=True)
        _download_file(CHECKPOINT_URL, archive_path)
        _extract()

    if _checkpoint_ready(path):
        return path

    # The archive layout did not match `path`; look for the extracted
    # directory elsewhere under APP_DIR. Empty directories are ignored so an
    # unrelated placeholder is never mistaken for the checkpoint.
    candidates = [
        candidate
        for candidate in APP_DIR.rglob(CHECKPOINT_NAME)
        if candidate != path and _checkpoint_ready(candidate)
    ]
    if candidates:
        # Prefer the shallowest match so the choice does not depend on the
        # traversal order of rglob().
        source = min(candidates, key=lambda candidate: len(candidate.parts))
        path.parent.mkdir(parents=True, exist_ok=True)
        if path.is_dir():
            # `path` is an empty directory here (it failed the readiness
            # check). shutil.move() would place `source` *inside* an existing
            # directory, yielding path/<CHECKPOINT_NAME>/..., so remove the
            # empty placeholder first.
            path.rmdir()
        shutil.move(str(source), str(path))

    if not _checkpoint_ready(path):
        raise RuntimeError(
            f"Could not find {CHECKPOINT_NAME} after extracting {archive_path}."
        )

    return path
153
+
154
+
155
def select_device() -> str:
    """Pick the torch device string: ``"cuda"``, then ``"mps"``, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    # `torch.backends.mps` may be absent, hence the getattr guard.
    mps_backend = getattr(torch.backends, "mps", None)
    has_mps = mps_backend is not None and mps_backend.is_available()
    return "mps" if has_mps else "cpu"
162
+
163
+
164
def get_model():
    """Return the cached ``(model, device)`` pair, loading it on first use.

    The whole check-and-load sequence runs under ``_model_lock``. Exceptions
    raised while fetching the checkpoint or constructing the model propagate
    to the caller and leave the cache empty, so a later call retries.
    """
    global _device, _model

    with _model_lock:
        if _model is None:
            checkpoint_path = ensure_checkpoint()
            _device = select_device()
            log.info("Loading %s on %s", checkpoint_path, _device)
            loaded = FlowMapFromPretrained(LoadConfig(path=str(checkpoint_path)))
            _model = loaded.eval().to(_device)
            log.info("Model loaded")
        return _model, _device
178
+
179
+
180
+ def _seed_everything(seed: int) -> int:
181
+ if seed < 0:
182
+ seed = int.from_bytes(os.urandom(4), "big") % (2**31)
183
+ torch.manual_seed(seed)
184
+ if torch.cuda.is_available():
185
+ torch.cuda.manual_seed_all(seed)
186
+ return seed
187
+
188
+
189
def _format_audio_batch(audio: torch.Tensor) -> list[tuple[int, object]]:
    """Convert a batch of decoded clips into ``(SAMPLE_RATE, ndarray)`` pairs.

    Each clip is divided by its absolute peak only when that peak exceeds 1.0
    (quieter clips keep their level) and is then clamped to ``[-1, 1]``.
    A 2-D clip whose first dimension is 1 is squeezed to 1-D; any other 2-D
    clip is transposed (presumably channels-first to samples-first — confirm
    against the autoencoder output layout).
    """

    def _to_numpy(clip: torch.Tensor):
        scale = clip.abs().max().clamp(min=1.0)
        clip = (clip / scale).clamp(-1.0, 1.0)
        if clip.ndim == 2:
            clip = clip.squeeze(0) if clip.shape[0] == 1 else clip.transpose(0, 1)
        return clip.numpy()

    host_audio = audio.detach().cpu().float()
    return [(SAMPLE_RATE, _to_numpy(clip)) for clip in host_audio]
201
+
202
+
203
@spaces.GPU(duration=120)
@torch.inference_mode()
def generate(
    prompt: str,
    variants: int,
    cfg_scale: float,
    seed: int,
    progress=gr.Progress(track_tqdm=False),
):
    """Generate up to ``MAX_VARIANTS`` audio takes for a text prompt.

    Args:
        prompt: Sound description; surrounding whitespace is stripped.
        variants: Requested number of takes, clamped to ``[1, MAX_VARIANTS]``.
            ``None`` is treated as 1.
        cfg_scale: Guidance value forwarded to the sampler as ``cfg``.
        seed: RNG seed; a negative value or ``None`` selects a random seed.
        progress: Gradio progress tracker.

    Returns:
        A list with one ``gr.update`` per audio output (unused outputs are
        cleared and hidden) followed by a Markdown summary string.

    Raises:
        gr.Error: The prompt is empty, or the model could not be loaded.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        raise gr.Error("Enter a short sound description.")

    # A cleared number field (or an API caller omitting the value) arrives as
    # None; fall back to defaults instead of failing with a TypeError in int().
    variants = 1 if variants is None else int(variants)
    variants = max(1, min(variants, MAX_VARIANTS))
    cfg_scale = float(cfg_scale)
    seed = _seed_everything(-1 if seed is None else int(seed))
    try:
        model, device = get_model()
    except Exception as exc:
        raise gr.Error(f"Could not load Woosh-DFlow: {exc}") from exc

    progress(0.1, desc="Preparing text conditioning")
    noise = torch.randn(variants, LATENT_CHANNELS, LATENT_FRAMES, device=device)
    cond = model.get_cond(
        {"audio": None, "description": [prompt] * variants},
        no_dropout=True,
        device=device,
    )

    progress(0.35, desc="Synthesizing latent audio")
    start_time = time.perf_counter()
    latents = sample_euler(
        model=model,
        noise=noise,
        cond=cond,
        num_steps=GENERATION_STEPS,
        renoise=RENOISE_SCHEDULE,
        cfg=cfg_scale,
    )

    progress(0.75, desc="Decoding waveform")
    audio = model.autoencoder.inverse(latents)
    outputs = _format_audio_batch(audio)
    # Stop the clock only after the batch has been copied to host memory:
    # CUDA kernels are queued asynchronously, so reading the timer right after
    # the decode call could omit work that has not finished yet.
    elapsed = time.perf_counter() - start_time

    if device == "cuda":
        torch.cuda.empty_cache()

    audio_updates = [
        gr.update(value=value, visible=True) for value in outputs[:MAX_VARIANTS]
    ]
    # Pad so the result always matches the fixed number of audio outputs.
    while len(audio_updates) < MAX_VARIANTS:
        audio_updates.append(gr.update(value=None, visible=False))

    details = (
        f"Generated {variants} take{'s' if variants != 1 else ''} "
        f"in {elapsed:.1f}s on {device}. "
        f"Seed: `{seed}`. Steps: `{GENERATION_STEPS}`. "
        f"Sample rate: `{SAMPLE_RATE} Hz`."
    )
    progress(1.0, desc="Done")
    return [*audio_updates, details]
265
+
266
+
267
def build_ui() -> gr.Blocks:
    """Build the Gradio interface and wire both triggers to ``generate``.

    Layout: a hero banner, a row with the prompt column (textbox, button,
    examples) and the settings column (takes, prompt strength, seed), a row
    with two audio outputs, a status Markdown line, and a licensing note.

    Returns:
        The assembled ``gr.Blocks`` app (not yet queued or launched).
    """
    # Custom styling for the page width, the hero banner and the main button.
    css = """
    .gradio-container {
        max-width: 1180px !important;
    }
    #hero {
        padding: 28px;
        border: 1px solid #d8e3df;
        border-radius: 8px;
        background: linear-gradient(135deg, #ffffff 0%, #f1faf7 100%);
    }
    #hero h1 {
        margin: 0 0 10px;
        font-size: 2.35rem;
        line-height: 1.05;
        letter-spacing: 0;
        color: #202124;
    }
    #hero p {
        margin: 0;
        color: #3a4140;
        font-size: 1.02rem;
        line-height: 1.55;
    }
    #hero .meta {
        margin-top: 14px;
        color: #007a7a;
        font-weight: 650;
    }
    .primary-button {
        min-height: 48px;
    }
    """

    theme = gr.themes.Soft(
        primary_hue="red",
        secondary_hue="teal",
        neutral_hue="gray",
        radius_size="sm",
    )

    # NOTE(review): the Gradio 6 migration guide describes `theme` and `css`
    # as moving from the Blocks constructor to launch() — confirm the pinned
    # Gradio version still honors them when passed here.
    with gr.Blocks(
        title="Woosh-DFlow",
        theme=theme,
        css=css,
        analytics_enabled=False,
    ) as demo:
        gr.HTML(
            """
            <section id="hero">
              <h1>Woosh-DFlow</h1>
              <p>
                Fast text-to-audio generation for sound effects, ambience,
                impacts, machines, weather, Foley, and synthetic UI sounds.
              </p>
              <p class="meta">
                Distilled Woosh model by Sony AI. Outputs are five-second,
                48 kHz audio clips.
              </p>
            </section>
            """
        )

        with gr.Row(equal_height=False):
            # Left column: prompt entry, trigger button and example prompts.
            with gr.Column(scale=7):
                prompt = gr.Textbox(
                    label="Sound prompt",
                    placeholder="A heavy metal door slams shut in a concrete hallway",
                    lines=4,
                    max_lines=6,
                )
                run_button = gr.Button(
                    "Generate sound",
                    variant="primary",
                    elem_classes=["primary-button"],
                )

                gr.Examples(
                    examples=[
                        "sportscar engine revving and driving away quickly",
                        "heavy rain on a tin roof with distant thunder",
                        "large wooden door creaking open in an empty hallway",
                        "arcade laser blast with a bright digital tail",
                    ],
                    inputs=prompt,
                )

            # Right column: generation settings.
            with gr.Column(scale=3):
                variants = gr.Slider(
                    minimum=1,
                    maximum=MAX_VARIANTS,
                    step=1,
                    value=1,
                    label="Takes",
                    info="Generate one or two variations per request.",
                )
                cfg_scale = gr.Slider(
                    minimum=0.0,
                    maximum=9.0,
                    step=0.1,
                    value=4.5,
                    label="Prompt strength",
                    info="Higher values follow the prompt more tightly.",
                )
                seed = gr.Number(
                    value=-1,
                    label="Seed",
                    precision=0,
                    info="Use -1 for a random seed.",
                )

        with gr.Row():
            audio_1 = gr.Audio(
                label="Take 1",
                type="numpy",
                format="wav",
                autoplay=True,
                interactive=False,
            )
            # Starts hidden; `generate` returns visibility updates per take.
            audio_2 = gr.Audio(
                label="Take 2",
                type="numpy",
                format="wav",
                visible=False,
                interactive=False,
            )

        initial_details = (
            "The first request may wait while the official DFlow checkpoint "
            "is downloaded and loaded."
        )
        # Surface a failed startup preload in the status line instead.
        if _startup_error is not None:
            initial_details = (
                "Model preload failed. Generation will retry the download and "
                f"load step. Error: `{_startup_error}`"
            )
        run_details = gr.Markdown(value=initial_details)

        # Pressing Enter in the prompt box and clicking the button run the
        # same handler; each is registered under its own API name.
        inputs = [prompt, variants, cfg_scale, seed]
        outputs = [audio_1, audio_2, run_details]
        prompt.submit(
            fn=generate,
            inputs=inputs,
            outputs=outputs,
            api_name="generate",
            show_progress="full",
        )
        run_button.click(
            fn=generate,
            inputs=inputs,
            outputs=outputs,
            api_name="generate_click",
            show_progress="full",
        )

        gr.Markdown(
            """
            Model weights are downloaded from the official
            `SonyResearch/Woosh` v1.0.0 release. The released weights are
            licensed CC-BY-NC; the upstream inference code is MIT/Apache-2.0.
            """
        )

    return demo
431
+
432
+
433
def eager_load_model() -> None:
    """Preload the model at startup unless ``WOOSH_EAGER_LOAD`` disables it.

    Setting the variable to ``0``, ``false`` or ``no`` (case-insensitive)
    skips the preload. A failure is logged and its message stored in
    ``_startup_error`` rather than raised, so the UI can still start.
    """
    global _startup_error

    setting = os.getenv("WOOSH_EAGER_LOAD", "1").lower()
    if setting not in {"0", "false", "no"}:
        try:
            get_model()
        except Exception as exc:  # Deliberately broad: startup must not abort.
            _startup_error = str(exc)
            log.exception("Model preload failed")
442
+
443
+
444
def main() -> None:
    """Parse command-line options, preload the model and start the server.

    ``--checkpoint`` rebinds the module-level ``CHECKPOINT_DIR`` before the
    model is loaded. The server address and port default to the
    ``GRADIO_SERVER_NAME`` / ``GRADIO_SERVER_PORT`` environment variables.
    """
    global CHECKPOINT_DIR

    cli = argparse.ArgumentParser(description="Woosh-DFlow Gradio Space")
    cli.add_argument(
        "--checkpoint",
        type=str,
        default=str(CHECKPOINT_DIR),
        help="Path to the Woosh-DFlow checkpoint directory.",
    )
    cli.add_argument("--share", action="store_true", help="Create a public link.")
    cli.add_argument(
        "--server-name",
        default=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        help="Server address to bind.",
    )
    cli.add_argument(
        "--server-port",
        type=int,
        default=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
        help="Server port.",
    )
    options = cli.parse_args()

    CHECKPOINT_DIR = _resolve_app_path(options.checkpoint)

    # Preload before building the UI so a preload failure can be shown in it.
    eager_load_model()
    queued_app = build_ui().queue(default_concurrency_limit=1, max_size=12)
    queued_app.launch(
        show_error=True,
        share=options.share,
        server_name=options.server_name,
        server_port=options.server_port,
    )
478
+
479
+
480
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu128
2
+
3
+ torch==2.8.0
4
+ torchaudio==2.8.0
5
+ torchvision==0.23.0
6
+ gradio==6.12.0
7
+ spaces==0.48.2
8
+ requests>=2.31.0
9
+ soundfile>=0.13.1
10
+
11
+ woosh @ git+https://github.com/SonyResearch/Woosh.git@88006c57774a85bede9f87733c019664410d6f4e