garyuzair committed on
Commit
f60e24b
Β·
verified Β·
1 Parent(s): 367b87b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +331 -437
app.py CHANGED
@@ -1,487 +1,381 @@
1
  import streamlit as st
2
- import torch
3
- import numpy as np
4
- from transformers import AutoProcessor, BlipForConditionalGeneration, MusicgenForConditionalGeneration
5
- import imageio
6
  from PIL import Image
7
- import soundfile as sf
 
 
8
  import os
9
  import tempfile
10
- import subprocess
11
- from pydub import AudioSegment, effects
12
- import moviepy.editor as mpy
13
- import time
14
- from concurrent.futures import ThreadPoolExecutor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # Set page configuration
17
- st.set_page_config(page_title="🎬 AI SoundFX Studio", layout="wide", initial_sidebar_state="expanded")
 
 
 
18
 
19
- # Enhanced CSS for better UI
 
 
 
20
  st.markdown("""
21
- <style>
22
- .reportview-container {
23
- background: #1a1a1a;
24
- color: white;
25
- }
26
- .sidebar .sidebar-content {
27
- width: 350px;
28
- }
29
- .stProgress > div > div {
30
- background-color: #4CAF50;
31
- }
32
- .stButton>button {
33
- background-color: #4CAF50;
34
- color: white;
35
- }
36
- </style>
37
- """, unsafe_allow_html=True)
38
 
39
- # Optional scene detection
40
- scene_detect_available = True
41
- try:
42
- from scenedetect import VideoManager, SceneManager
43
- from scenedetect.detectors import ContentDetector
44
- except ImportError:
45
- scene_detect_available = False
 
46
 
47
- # Model Management
48
- @st.cache_resource
49
- def load_blip_model():
50
- """Load BLIP model with optimized settings"""
51
  try:
52
- processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
53
- model = BlipForConditionalGeneration.from_pretrained(
54
- "Salesforce/blip-image-captioning-base",
55
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
56
- low_mem=True
57
- ).to("cuda" if torch.cuda.is_available() else "cpu")
 
58
  return processor, model
59
  except Exception as e:
60
- st.error(f"BLIP model load error: {str(e)}")
 
61
  return None, None
62
 
63
- @st.cache_resource
64
- def load_musicgen_model(model_name="facebook/musicgen-medium"):
65
- """Load MusicGen model with optimized settings"""
66
  try:
67
- model = MusicgenForConditionalGeneration.from_pretrained(
68
- model_name,
69
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
70
- low_cpu_mem_usage=True
71
- ).to("cuda" if torch.cuda.is_available() else "cpu")
72
- processor = AutoProcessor.from_pretrained(model_name)
 
 
 
73
  return processor, model
74
  except Exception as e:
75
- st.error(f"MusicGen model load error: {str(e)}")
 
76
  return None, None
77
 
78
- def extract_frames(video_path, num_frames, method="uniform", segment_start=0, segment_end=None):
79
- """Optimized frame extraction with smart sampling"""
 
80
  try:
81
- video = imageio.get_reader(video_path, "ffmpeg")
82
- meta = video.get_meta_data()
83
- fps = meta['fps']
84
- total_frames = int(meta['duration'] * fps)
85
-
86
- # Optimize frame count based on duration
87
- actual_num_frames = min(num_frames, int(total_frames / 5) or 5)
88
-
89
- if segment_end is None:
90
- segment_end = total_frames / fps
91
-
92
- start_frame = int(segment_start * fps)
93
- end_frame = int(segment_end * fps)
94
- total_segment_frames = end_frame - start_frame
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- # Smart frame selection
97
- if method == "scene" and scene_detect_available:
98
  try:
99
- video_manager = VideoManager([video_path])
100
- scene_manager = SceneManager()
101
- scene_manager.add_detector(ContentDetector(threshold=30))
102
- video_manager.set_downscale_factor(2)
103
- video_manager.start()
104
- scene_manager.detect_scenes(frame_source=video_manager)
105
- scene_list = scene_manager.get_scene_list()
106
- segment_scenes = [scene for scene in scene_list
107
- if segment_start <= scene[0].get_seconds() < segment_end]
108
-
109
- frames = []
110
- for scene in segment_scenes[:actual_num_frames]:
111
- frame = video_manager.get_frame(scene[0].get_frames())
112
- if frame is not None:
113
- frames.append(Image.fromarray(frame))
114
- video_manager.release()
115
-
116
- # Fill remaining frames if needed
117
- if len(frames) < actual_num_frames and total_segment_frames > 0:
118
- remaining = actual_num_frames - len(frames)
119
- step = max(1, total_segment_frames // (remaining + 1))
120
- for i in range(1, remaining + 1):
121
- frame_idx = start_frame + i * step
122
- if frame_idx < end_frame:
123
- frames.append(Image.fromarray(video.get_data(frame_idx)))
124
- return frames[:actual_num_frames]
125
- except Exception as e:
126
- st.warning(f"Scene detection failed: {e}. Using uniform extraction.")
127
-
128
- # Uniform extraction with numpy optimization
129
- frame_indices = np.linspace(start_frame, end_frame, actual_num_frames, endpoint=False).astype(int)
130
- frames = []
131
- for idx in frame_indices:
132
- if idx < total_frames:
133
- frame = video.get_data(idx)
134
- frames.append(Image.fromarray(frame))
135
- return frames[:actual_num_frames]
136
-
137
  except Exception as e:
138
- st.error(f"Frame extraction error: {str(e)}")
 
139
  return []
 
 
 
140
 
141
- def generate_captions_parallel(frames, processor, model):
142
- """Parallel caption generation with error handling"""
143
- def process_frame(frame):
144
- try:
145
- inputs = processor(images=frame, return_tensors="pt").to(model.device)
146
- out = model.generate(**inputs, max_length=25, num_beams=3)
147
- return processor.decode(out[0], skip_special_tokens=True)
148
- except Exception as e:
149
- st.warning(f"Captioning error: {str(e)}")
150
- return ""
151
-
152
- with ThreadPoolExecutor() as executor:
153
- return list(executor.map(process_frame, frames))
154
-
155
- def enhance_prompt(descriptions, mood="default"):
156
- """Advanced prompt engineering with pattern recognition"""
157
- if not descriptions:
158
- return f"{mood} ambient sound with subtle effects"
159
 
160
- combined = ". ".join(set(desc.lower() for desc in descriptions))
 
 
161
 
162
- # Enhanced pattern matching dictionary
163
- pattern_map = {
164
- ('walk', 'run'): "crisp footsteps on varied surfaces with immersive movement sounds",
165
- ('car', 'drive', 'vehicle'): "roaring engine, tire screeches, and dynamic road noise with spatial positioning",
166
- ('talk', 'person', 'people', 'conversation'): "lively voices, crowd murmur, and spatial chatter with natural reverb",
167
- ('wind', 'tree', 'forest', 'nature'): "rustling leaves, gentle wind gusts, and natural ambiance with atmospheric depth",
168
- ('crash', 'fall', 'impact'): "intense crash impact, debris scattering, and sharp transient effects with dynamic range",
169
- ('water', 'ocean', 'sea'): "realistic water movement, wave dynamics, and aquatic ambiance",
170
- ('fire', 'explosion'): "realistic fire crackling, explosions, and heat distortion audio",
171
- ('space', 'sci-fi'): "futuristic ambient textures, synth effects, and spatial audio design"
172
- }
173
 
174
- # Advanced pattern matching logic
175
- matched_patterns = []
176
- for keywords, effect in pattern_map.items():
177
- if any(keyword in combined for keyword in keywords):
178
- matched_patterns.append(effect)
179
 
180
- if matched_patterns:
181
- return f"{mood} {combined}, {'; '.join(matched_patterns)}, cinematic sound design with spatial audio"
182
- return f"{mood} {combined}, rich ambient soundscape with professional effects, 4K audio resolution"
 
 
 
 
 
183
 
184
- def generate_audio(prompt, processor, model, duration, sample_rate=44100):
185
- """Optimized audio generation with smart parameters"""
186
  try:
187
- inputs = processor(text=[prompt], padding=True, return_tensors="pt").to(model.device)
188
-
189
- audio_values = model.generate(
190
- **inputs,
191
- max_new_tokens=int(256 * (duration / 8)),
192
- num_beams=3,
193
- early_stopping=True,
194
- do_sample=True,
195
- temperature=0.85,
196
- guidance_scale=5.0,
197
- top_k=80,
198
- top_p=0.85
199
- )
200
 
201
- audio_array = audio_values[0].cpu().numpy()
202
- audio_array = np.tanh(audio_array) # Faster than clip + max normalization
203
- return audio_array
204
-
205
- except Exception as e:
206
- st.error(f"Audio generation error: {str(e)}")
207
- return np.zeros(int(duration * sample_rate))
208
 
209
- def apply_audio_effects(audio_path, settings):
210
- """Enhanced audio effects processing"""
211
- try:
212
- sound = AudioSegment.from_wav(audio_path)
213
-
214
- # Reverb
215
- if settings['reverb_ms'] > 0:
216
- sound = sound.overlay(sound - 15, position=settings['reverb_ms'])
217
-
218
- # Echo
219
- if settings['echo_ms'] > 0:
220
- echo = sound - 15
221
- sound = sound.overlay(echo, position=settings['echo_ms'])
222
 
223
- # Filters
224
- if settings['highpass'] > 0:
225
- sound = sound.high_pass_filter(settings['highpass'])
226
- if settings['lowpass'] < 20000:
227
- sound = sound.low_pass_filter(settings['lowpass'])
228
-
229
- # Dynamic processing
230
- if settings['compress']:
231
- sound = effects.compress_dynamic_range(sound)
232
-
233
- # Stereo imaging
234
- sound = sound.pan(settings['stereo_pan'])
235
- sound = effects.normalize(sound)
236
-
237
- processed_path = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
238
- sound.export(processed_path, format="wav")
239
- return processed_path
240
 
 
 
 
 
 
 
241
  except Exception as e:
242
- st.error(f"Audio effects error: {str(e)}")
243
- return audio_path
 
 
 
 
 
 
 
 
 
 
244
 
245
- def sync_audio_video(video_path, audio_path, output_path, mix_original=False,
246
- original_volume=0.5, generated_volume=0.5):
247
- """Enhanced video/audio synchronization"""
248
  try:
249
- if mix_original:
250
- video_clip = mpy.VideoFileClip(video_path)
251
- if video_clip.audio:
252
- original_audio_seg = AudioSegment.from_file(video_path, format="mp4")
253
- generated_audio_seg = AudioSegment.from_wav(audio_path)
254
-
255
- # Volume adjustment
256
- original_audio_seg = original_audio_seg - (20 * (1 - original_volume))
257
- generated_audio_seg = generated_audio_seg - (20 * (1 - generated_volume))
258
-
259
- mixed_audio = original_audio_seg.overlay(generated_audio_seg)
260
- mixed_path = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
261
- mixed_audio.export(mixed_path, format="wav")
262
- audio_path = mixed_path
263
- else:
264
- st.warning("No original audio found. Using generated audio only.")
265
-
266
- # FFmpeg command with hardware acceleration
267
- cmd = [
268
- 'ffmpeg',
269
- '-i', video_path,
270
- '-i', audio_path,
271
- '-c:v', 'copy',
272
- '-c:a', 'aac',
273
- '-map', '0:v:0',
274
- '-map', '1:a:0',
275
- '-shortest',
276
- '-y',
277
- '-preset', 'ultrafast',
278
- '-vsync', '2',
279
- output_path
280
- ]
281
- subprocess.run(cmd, check=True, stderr=subprocess.PIPE, stdout=subprocess.PIPE)
282
 
283
- except Exception as e:
284
- st.error(f"Sync error: {str(e)}")
285
 
286
- def unload_model(model):
287
- """Memory management utility"""
288
- if torch.cuda.is_available():
289
- model.to("cpu")
290
- torch.cuda.empty_cache()
291
 
292
- def main():
293
- st.title("🎬 AI SoundFX Studio")
294
- st.markdown("### Create immersive soundscapes from video with optimized AI processing")
295
-
296
- # Initialize session state
297
- if 'processing_time' not in st.session_state:
298
- st.session_state.processing_time = 0
299
-
300
- # User Guide
301
- with st.expander("πŸ“– User Guide & Tips"):
302
- st.markdown("""
303
- **How to Use:**
304
- 1. Upload a video file (MP4, MOV, AVI)
305
- 2. Choose between Automatic (AI-generated) or Manual sound description
306
- 3. Adjust settings in the sidebar:
307
- - Model size (small/medium/large)
308
- - Frame analysis parameters
309
- - Audio effects customization
310
- 4. Click "Generate Sound Effects"
311
- 5. Download the enhanced video
312
-
313
- **Optimization Tips:**
314
- - Use "small" model for quick previews
315
- - Enable "Scene Detection" for better context
316
- - Adjust audio effects for custom sound design
317
- - Use "Mix with Original Audio" for balanced results
318
- """)
319
-
320
- # Sidebar Settings
321
- with st.sidebar:
322
- st.header("βš™οΈ Processing Settings")
323
-
324
- # Processing Mode
325
- prompt_mode = st.selectbox("Prompt Generation", ["Automatic", "Manual"])
326
-
327
- # Model Selection
328
- model_size = st.selectbox("Model Size", ["small", "medium", "large"], index=1,
329
- help="Larger models = better quality but slower processing")
330
-
331
- # Audio Mixing
332
- mix_original = st.checkbox("Mix with Original Audio", value=False)
333
- col1, col2 = st.columns(2)
334
- with col1:
335
- original_vol = st.slider("Original Volume", 0.0, 1.0, 0.5) if mix_original else 0.5
336
- with col2:
337
- generated_vol = st.slider("Generated Volume", 0.0, 1.0, 0.5) if mix_original else 0.5
338
-
339
- # Frame Analysis
340
- st.subheader("πŸŽ₯ Frame Analysis")
341
- num_frames = st.slider("Frames to Analyze", 3, 10, 5,
342
- help="More frames improve accuracy but increase processing time")
343
- frame_method = st.selectbox("Frame Extraction",
344
- ["Uniform", "Scene"] if scene_detect_available else ["Uniform"],
345
- help="Scene detection provides better contextual analysis")
346
 
347
- # Audio Effects
348
- st.subheader("πŸŽ›οΈ Audio Effects")
349
- effects_settings = {
350
- 'reverb_ms': st.slider("Reverb (ms)", 0, 500, 50),
351
- 'echo_ms': st.slider("Echo (ms)", 0, 1000, 100),
352
- 'highpass': st.slider("High-pass Filter (Hz)", 0, 3000, 50),
353
- 'lowpass': st.slider("Low-pass Filter (Hz)", 5000, 20000, 12000),
354
- 'compress': st.checkbox("Dynamic Compression", value=True),
355
- 'stereo_pan': st.slider("Stereo Pan (-1 left, 1 right)", -1.0, 1.0, 0.0)
356
- }
357
 
358
- # Performance Presets
359
- st.subheader("⚑ Performance")
360
- quality_preset = st.selectbox("Quality Preset", ["Fast", "Balanced", "High Quality"])
361
- presets = {
362
- "Fast": {"num_frames": 3, "model_size": "small"},
363
- "Balanced": {"num_frames": 5, "model_size": "medium"},
364
- "High Quality": {"num_frames": 8, "model_size": "large"}
365
- }
366
- if quality_preset != "Balanced":
367
- num_frames = presets[quality_preset]["num_frames"]
368
- model_size = presets[quality_preset]["model_size"]
369
 
370
- st.info("Processing time estimate: 2-5 minutes (varies by settings)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
 
372
- # Main Content Area
373
- uploaded_file = st.file_uploader("Upload Video File", type=["mp4", "mov", "avi"])
 
374
 
375
- if uploaded_file:
376
- # Create temporary files
377
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp:
378
- tmp.write(uploaded_file.read())
379
- video_path = tmp.name
380
-
381
- # Video Preview
382
- st.video(video_path)
383
-
384
- # Get video duration
385
- video_clip = mpy.VideoFileClip(video_path)
386
- duration = video_clip.duration
387
- video_clip.close()
388
-
389
- # Prompt Generation
390
- if prompt_mode == "Automatic":
391
- with st.spinner("Analyzing video content..."):
392
- blip_processor, blip_model = load_blip_model()
393
- if not blip_processor or not blip_model:
394
- st.error("Failed to load BLIP model")
395
- return
396
-
397
- frames = extract_frames(video_path, num_frames, frame_method)
398
- if not frames:
399
- st.error("No frames extracted. Try a different video or settings.")
400
- return
401
-
402
- # Display analyzed frames
403
- cols = st.columns(len(frames))
404
- for col, frame in zip(cols, frames):
405
- with col:
406
- st.image(frame, use_column_width=True)
407
-
408
- descriptions = generate_captions_parallel(frames, blip_processor, blip_model)
409
- unload_model(blip_model)
410
-
411
- # Mood selection
412
- mood = st.selectbox("Sound Mood", [
413
- "default", "dramatic", "ambient", "action", "sci-fi", "horror", "comedy"
414
- ], help="Select the overall atmosphere for the sound design")
415
-
416
- # Enhanced prompt with AI suggestions
417
- text_prompt = enhance_prompt(descriptions, mood)
418
- st.subheader("Generated Prompt")
419
- text_prompt = st.text_area("Edit Prompt", text_prompt, height=150)
420
- st.markdown("*Suggested modifications: Add specific instrument types, intensity levels, or emotional cues*")
421
- else:
422
- st.subheader("Enter Sound Description")
423
- text_prompt = st.text_area("Describe the desired sound effects",
424
- "E.g., 'Cinematic trailer music with thunderous impacts and soaring strings'",
425
- height=150)
426
-
427
- # Generation Button
428
- if st.button("πŸ”Š Generate Sound Effects", key="generate", use_container_width=True):
429
- start_time = time.time()
430
-
431
- # Progress tracking
432
- progress_bar = st.progress(0)
433
- status_text = st.empty()
434
- status_text.text("Loading models...")
435
-
436
- # Load MusicGen model
437
- musicgen_processor, musicgen_model = load_musicgen_model(f"facebook/musicgen-{model_size}")
438
- if not musicgen_processor or not musicgen_model:
439
- st.error("Failed to load MusicGen model")
440
- return
441
- progress_bar.progress(20)
442
-
443
- # Audio Generation
444
- status_text.text("Generating audio...")
445
- audio_array = generate_audio(text_prompt, musicgen_processor, musicgen_model, duration)
446
- unload_model(musicgen_model)
447
-
448
- temp_audio = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
449
- sf.write(temp_audio, audio_array, 44100)
450
- progress_bar.progress(50)
451
-
452
- # Apply Effects
453
- status_text.text("Applying audio effects...")
454
- processed_audio = apply_audio_effects(temp_audio, effects_settings)
455
- progress_bar.progress(75)
456
-
457
- # Sync with Video
458
- status_text.text("Syncing with video...")
459
- output_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
460
- sync_audio_video(video_path, processed_audio, output_video, mix_original, original_vol, generated_vol)
461
- progress_bar.progress(100)
462
-
463
- # Finalize
464
- status_text.text("Processing complete!")
465
- st.success("βœ… Sound effects applied successfully!")
466
-
467
- # Display result
468
- st.video(output_video)
469
 
470
- # Download button
471
- with open(output_video, "rb") as f:
472
- st.download_button("πŸ“₯ Download Enhanced Video",
473
- f, "enhanced_video.mp4", "video/mp4",
474
- use_container_width=True)
 
 
 
 
475
 
476
- # Timing info
477
- processing_time = time.time() - start_time
478
- st.session_state.processing_time = processing_time
479
- st.info(f"⏱️ Processing time: {processing_time:.1f} seconds")
 
 
 
 
 
480
 
481
- # Cleanup
482
- for file in [video_path, temp_audio, processed_audio, output_video]:
483
- if os.path.exists(file):
484
- os.remove(file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
 
486
- if __name__ == "__main__":
487
- main()
 
1
  import streamlit as st
 
 
 
 
2
  from PIL import Image
3
+ import numpy as np
4
+ import torch
5
+ import gc
6
  import os
7
  import tempfile
8
+ import math
9
+ import imageio
10
+ import traceback
11
+ import scipy.io.wavfile # For saving WAV files
12
+
13
# --- Attempt to import moviepy for video processing ---
# Guarded import: moviepy pulls in ffmpeg at import time and can fail for many
# reasons on a fresh host, so any failure downgrades to an in-app warning
# instead of crashing the whole Streamlit app.
try:
    import moviepy.editor as mpy
    MOVIEPY_AVAILABLE = True
except Exception as e:  # Catch any exception during moviepy import
    MOVIEPY_AVAILABLE = False
    st.warning(
        f"MoviePy library could not be loaded (Error: {e}). "
        "Video syncing features will be disabled. "
        "If running locally, ensure MoviePy and its dependency ffmpeg are correctly installed."
    )
    print(f"MoviePy load error: {e}")
25
+
26
+
27
# --- Model Configuration ---
# Hugging Face model ids; both are the smallest variants to fit a CPU-only tier.
IMAGE_CAPTION_MODEL = "Salesforce/blip-image-captioning-base"
AUDIO_GEN_MODEL = "facebook/musicgen-small"

# --- Constants ---
DEFAULT_NUM_FRAMES = 2          # frames sampled for captioning by default
DEFAULT_AUDIO_DURATION_S = 7    # Slightly increased default
MAX_FRAMES_TO_SHOW_UI = 2       # Reducing for smaller UI footprint
DEVICE = torch.device("cpu")    # Explicitly use CPU
36
 
37
+ # --- Page Setup ---
38
+ st.set_page_config(page_title="AI Video Sound Designer (HF Space)", layout="wide", page_icon="🎬")
39
+
40
+ st.title("🎬 AI Video Sound Designer")
41
  st.markdown("""
42
+ Upload a short video (MP4, MOV, AVI). The tool will:
43
+ 1. Extract frames.
44
+ 2. Analyze frames with **BLIP** to generate sound ideas.
45
+ 3. Synthesize audio with **MusicGen** based on these ideas.
46
+ 4. Optionally, combine the new audio with your video.
47
+ ---
48
+ **Important:** This app runs on CPU. **Audio generation can be very slow (several minutes for a few seconds of audio).** Please be patient!
49
+ """)
 
 
 
 
 
 
 
 
 
50
 
51
+ # --- Utility Functions ---
52
+ def clear_memory(model_obj=None, processor_obj=None):
53
+ if model_obj:
54
+ del model_obj
55
+ if processor_obj:
56
+ del processor_obj
57
+ gc.collect()
58
+ print("Memory cleared.")
59
 
60
@st.cache_resource(show_spinner="Loading Image Analysis Model...")
def load_image_caption_model_and_processor():
    """Load and cache the BLIP captioning processor and model on DEVICE (CPU).

    Returns:
        (processor, model) on success, or (None, None) after reporting the
        failure through ``st.error``.
    """
    try:
        from transformers import BlipProcessor, BlipForConditionalGeneration

        print(f"Loading Image Captioning Model: {IMAGE_CAPTION_MODEL} to {DEVICE}")
        blip_processor = BlipProcessor.from_pretrained(IMAGE_CAPTION_MODEL)
        # Plain from_pretrained + .to(DEVICE); no low-memory loading flags.
        blip_model = BlipForConditionalGeneration.from_pretrained(IMAGE_CAPTION_MODEL)
        blip_model = blip_model.to(DEVICE)
        blip_model.eval()  # inference only
        st.toast("Image Analysis model (BLIP) loaded!", icon="🖼️")
        return blip_processor, blip_model
    except Exception as e:
        st.error(f"Error loading BLIP model ({IMAGE_CAPTION_MODEL}): {e}")
        st.error(traceback.format_exc())
        return None, None
75
 
76
@st.cache_resource(show_spinner="Loading Audio Generation Model (can be slow)...")
def load_audio_gen_model_and_processor():
    """Load and cache the MusicGen processor and model on DEVICE (CPU).

    Returns:
        (processor, model) on success, or (None, None) after reporting the
        failure through ``st.error``.
    """
    try:
        from transformers import AutoProcessor, MusicgenForConditionalGeneration
        print(f"Loading Audio Generation Model: {AUDIO_GEN_MODEL} to {DEVICE}")
        processor = AutoProcessor.from_pretrained(AUDIO_GEN_MODEL)
        # Standard loading for MusicGen on CPU.
        # `low_cpu_mem_usage` could be used here if accelerate is properly configured,
        # but for simplicity and robustness on free tier, direct .to(DEVICE) is safer.
        model = MusicgenForConditionalGeneration.from_pretrained(AUDIO_GEN_MODEL).to(DEVICE)
        model.eval()  # Set to evaluation mode
        st.toast("Audio Generation model (MusicGen) loaded! (CPU generation will be slow)", icon="🎶")
        return processor, model
    except Exception as e:
        st.error(f"Error loading MusicGen model ({AUDIO_GEN_MODEL}): {e}")
        st.error(traceback.format_exc())
        return None, None
93
 
94
def extract_frames_from_video(video_path, num_frames_to_extract):
    """Sample up to ``num_frames_to_extract`` evenly spaced RGB PIL frames.

    Args:
        video_path: Path to a video file readable by the imageio ffmpeg plugin.
        num_frames_to_extract: Desired number of frames (capped by video length).

    Returns:
        list[PIL.Image.Image]; [] on any failure (reported via st.error).

    Fixes vs. previous version:
      * ``imageio.core.fetching.NeedDownloadError`` was removed from modern
        imageio; referencing it inside an ``except`` clause raises
        AttributeError *while handling* the original error. We now catch
        ``OSError`` (the usual ffmpeg-backend failure) directly.
      * The reader is closed exactly once, in ``finally`` — the redundant
        in-``try`` close calls are gone.
    """
    frames = []
    reader = None
    try:
        reader = imageio.get_reader(video_path, "ffmpeg")
        total_frames_in_video = 0
        try:  # Prefer an exact frame count when the backend can provide one.
            total_frames_in_video = reader.count_frames()
        except Exception:
            # Fall back to duration * fps from metadata.
            meta_data = reader.get_meta_data()
            duration = meta_data.get('duration')
            fps = meta_data.get('fps', 25)  # Default FPS if not found
            if duration and fps:
                total_frames_in_video = int(duration * fps)

        if not total_frames_in_video or total_frames_in_video < 1:
            # Fallback: stream frames sequentially when length is unknown.
            print("Video length unknown or zero, attempting to read initial frames.")
            temp_frames = []
            for i, frame_data in enumerate(reader):
                temp_frames.append(Image.fromarray(frame_data).convert("RGB"))
                if len(temp_frames) >= num_frames_to_extract * 2:  # Read a bit more
                    break
            if not temp_frames:
                st.error("Could not extract any frames. Video might be empty or corrupted.")
                return []
            # Select evenly spaced frames from what was read.
            indices = np.linspace(0, len(temp_frames) - 1, num_frames_to_extract, dtype=int, endpoint=True)
            return [temp_frames[i] for i in indices]

        # Frame count is known: sample evenly across the whole video.
        num_to_sample = min(num_frames_to_extract, total_frames_in_video)
        indices = np.linspace(0, total_frames_in_video - 1, num_to_sample, dtype=int, endpoint=True)

        for i in indices:
            try:
                frame_data = reader.get_data(i)
                frames.append(Image.fromarray(frame_data).convert("RGB"))
            except Exception as frame_e:
                # A single bad frame should not abort the whole extraction.
                st.warning(f"Skipping problematic frame at index {i}: {frame_e}")
        return frames
    except OSError as e_ffmpeg:
        # Typically a missing or incompatible ffmpeg backend.
        st.error(f"FFmpeg error during frame extraction: {e_ffmpeg}. Ensure ffmpeg is available.")
        return []
    except Exception as e:
        st.error(f"Could not extract frames: {e}")
        st.error(traceback.format_exc())
        return []
    finally:
        # Single authoritative close for every exit path.
        if reader:
            reader.close()
148
 
149
def generate_sound_prompt_from_frames(frames, caption_proc, caption_mod):
    """Caption each frame with BLIP and fold the captions into one MusicGen prompt.

    Args:
        frames: list of PIL images (may be empty).
        caption_proc: BLIP processor.
        caption_mod: BLIP captioning model (already on DEVICE).

    Returns:
        A text prompt string; generic fallback prompts are returned when there
        are no frames or no usable captions.
    """
    if not frames: return "ambient background noise"

    descriptions = []
    # BLIP doesn't need a complex instruction, it captions directly.
    # We ask for "sound-producing elements" in post-processing of the descriptions.

    progress_bar = st.progress(0.0, text="Analyzing frames for sound ideas...")
    for i, frame in enumerate(frames):
        try:
            inputs = caption_proc(images=frame, return_tensors="pt").to(DEVICE)
            generated_ids = caption_mod.generate(**inputs, max_new_tokens=40)  # Shorter captions
            description = caption_proc.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
            if description: descriptions.append(description)
            progress_bar.progress((i + 1) / len(frames), text=f"Frame {i+1}/{len(frames)} analyzed.")
        except Exception as e:
            # Best-effort: a frame that fails to caption is simply skipped.
            st.warning(f"Could not get description for a frame: {e}")
    progress_bar.empty()

    if not descriptions: return "general ambiance, subtle environmental sounds"

    # dict.fromkeys removes duplicates while preserving first-seen order.
    unique_descs = list(dict.fromkeys(descriptions))
    combined_prompt = ". ".join(unique_descs)
    # Refine for MusicGen
    final_prompt = f"Soundscape for a scene with: {combined_prompt}. Emphasize distinct sounds and overall mood."
    # Limit prompt length for MusicGen
    if len(final_prompt) > 300:  # Arbitrary limit for conciseness
        final_prompt = final_prompt[:300] + "..."
    return final_prompt
178
 
179
def generate_audio_from_prompt(prompt, duration_s, audio_proc, audio_mod, guidance, temp):
    """Synthesize roughly ``duration_s`` seconds of audio from ``prompt``.

    Args:
        prompt: Text description fed to MusicGen.
        duration_s: Target duration in seconds (token budget is capped at 1500).
        audio_proc / audio_mod: MusicGen processor and model (on DEVICE).
        guidance: Classifier-free guidance scale.
        temp: Sampling temperature.

    Returns:
        (audio_array, sampling_rate) with the waveform peak-normalized to 0.9,
        or (None, None) on failure (after reporting via st.error).

    Fix vs. previous version: the Encodec config attribute for the token rate
    is ``frame_rate`` (tokens of audio per second); the previously used
    ``token_per_second`` does not exist and raised AttributeError.
    """
    try:
        inputs = audio_proc(text=[prompt], return_tensors="pt", padding=True).to(DEVICE)

        # getattr with a 50/s fallback (the standard Encodec rate for the
        # 32 kHz MusicGen checkpoints) guards against config-layout changes.
        tokens_per_second = getattr(audio_mod.config.audio_encoder, "frame_rate", 50)
        max_new_tokens = min(int(duration_s * tokens_per_second), 1500)  # Cap for stability

        with st.spinner(f"Synthesizing {duration_s}s audio (CPU: This is the SLOW part!)... Please wait patiently."):
            audio_values = audio_mod.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,  # Sampling is important for diversity
                guidance_scale=guidance,
                temperature=temp,
            )

        # audio_values is (batch, channels, samples); take the first mono track.
        audio_array = audio_values[0, 0].cpu().numpy()
        sampling_rate = audio_mod.config.audio_encoder.sampling_rate

        # Peak-normalize with headroom; emit silence if output is near-silent.
        peak = np.abs(audio_array).max()
        if peak > 1e-5:  # Avoid division by zero or near-zero
            audio_array = (audio_array / peak) * 0.9
        else:
            audio_array = np.zeros_like(audio_array)
        return audio_array, sampling_rate
    except Exception as e:
        st.error(f"Error generating audio: {e}")
        st.error(traceback.format_exc())
        return None, None
208
+
209
def combine_audio_video(video_path, audio_arr, sr, mix_orig):
    """Mux the generated audio (optionally mixed with the original) into the video.

    Args:
        video_path: Path to the source video file.
        audio_arr: 1-D float waveform (numpy array) of the generated audio.
        sr: Sampling rate of ``audio_arr``.
        mix_orig: If True and the video has audio, overlay original (at 0.6x)
            with generated (at 0.9x) instead of replacing it.

    Returns:
        Path to the written MP4 on success; None on failure or when MoviePy
        is unavailable (errors are reported via st.error).
    """
    if not MOVIEPY_AVAILABLE:
        st.error("MoviePy is unavailable. Cannot combine audio and video.")
        return None

    out_vid_path = None
    tmp_audio_path = None
    # Pre-declare every clip so the finally block can close whichever were created.
    vid_clip = gen_audio_clip = final_aud = comp_clip = None
    try:
        # Write the waveform to a temp WAV that MoviePy can open as a clip.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio_f:
            scipy.io.wavfile.write(tmp_audio_f.name, sr, audio_arr.astype(np.float32))  # Ensure float32 for some moviepy versions
            tmp_audio_path = tmp_audio_f.name

        vid_clip = mpy.VideoFileClip(video_path)
        gen_audio_clip = mpy.AudioFileClip(tmp_audio_path)

        # Loop the generated audio if shorter than the video, then trim exactly.
        target_duration = vid_clip.duration
        if gen_audio_clip.duration < target_duration:
            gen_audio_clip = gen_audio_clip.fx(mpy.afx.audio_loop, duration=target_duration)
        gen_audio_clip = gen_audio_clip.subclip(0, target_duration)

        if mix_orig and vid_clip.audio:
            orig_audio = vid_clip.audio.volumex(0.6)  # Lower original audio
            gen_audio_clip = gen_audio_clip.volumex(0.9)  # Slightly boost generated
            final_aud = mpy.CompositeAudioClip([orig_audio, gen_audio_clip]).set_duration(target_duration)
        else:
            final_aud = gen_audio_clip.set_duration(target_duration)

        comp_clip = vid_clip.set_audio(final_aud)

        # Only the temp file *name* is needed; write_videofile reopens it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_out_vid_f:
            out_vid_path = tmp_out_vid_f.name

        comp_clip.write_videofile(
            out_vid_path, codec="libx264", audio_codec="aac",
            threads=max(1, os.cpu_count() // 2), logger=None  # Be less verbose
        )
        return out_vid_path
    except Exception as e:
        st.error(f"Error combining audio/video: {e}")
        st.error(traceback.format_exc())
        return None
    finally:
        # Close all clips
        if vid_clip: vid_clip.close()
        if gen_audio_clip: gen_audio_clip.close()
        # final_aud is usually a derivative, not needing separate close if others are.
        if comp_clip: comp_clip.close()
        if tmp_audio_path and os.path.exists(tmp_audio_path): os.remove(tmp_audio_path)
259
+
260
# --- Sidebar for Settings ---
# Widget values below feed the main pipeline: frame count for BLIP analysis,
# audio duration and sampling parameters for MusicGen, and the mix toggle.
with st.sidebar:
    st.header("⚙️ Settings")
    num_frames_analysis = st.slider("Frames to Analyze", 1, 4, DEFAULT_NUM_FRAMES, 1,
                                    help="Fewer frames = faster analysis.")
    audio_duration = st.slider("Target Audio Duration (s)", 3, 15, DEFAULT_AUDIO_DURATION_S, 1,  # Max 15s for CPU
                               help="Shorter = MUCH faster on CPU. MusicGen is slow.")

    st.subheader("MusicGen Parameters")
    guidance = st.slider("Guidance Scale", 1.0, 7.0, 3.0, 0.5)
    temperature = st.slider("Temperature", 0.5, 1.5, 1.0, 0.1)

    # Mixing requires MoviePy; default to False so the name is always bound.
    mix_audio = False
    if MOVIEPY_AVAILABLE:
        st.subheader("Video Output")
        mix_audio = st.checkbox("Mix with original video audio", value=False)
276
+
277
+ # --- Main Application ---
278
+ uploaded_file = st.file_uploader("πŸ“€ Upload Video (MP4, MOV, AVI - short clips best):", type=["mp4", "mov", "avi"])
279
+
280
+ if 'generated_audio_path_sess' not in st.session_state:
281
+ st.session_state.generated_audio_path_sess = None
282
+ if 'output_video_path_sess' not in st.session_state:
283
+ st.session_state.output_video_path_sess = None
284
+
285
if uploaded_file is not None:
    # Preview the raw upload before any processing.
    st.video(uploaded_file)

    if st.button("✨ Generate Sound Design!", type="primary", use_container_width=True):
        # --- Clear previous results from session state and disk ---
        # Remove stale temp files from an earlier run; failures here are
        # non-fatal (best-effort cleanup), so they are only printed.
        for key in ['generated_audio_path_sess', 'output_video_path_sess']:
            if st.session_state.get(key) and os.path.exists(st.session_state[key]):
                try: os.remove(st.session_state[key])
                except Exception as e_rem: print(f"Error removing old temp file: {e_rem}")
            st.session_state[key] = None
        clear_memory() # General memory clear

        temp_video_path = None
        try:
            # Persist the upload to disk: downstream frame extraction and
            # moviepy need a real file path, not the uploader's BytesIO.
            with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_v:
                tmp_v.write(uploaded_file.read()) # Use read() for BytesIO from uploader
                temp_video_path = tmp_v.name

            # === Stage 1: Frame Extraction ===
            st.markdown("--- \n### 1. Extracting Frames")
            frames = extract_frames_from_video(temp_video_path, num_frames_analysis)
            # st.stop() halts the script run entirely on failure.
            if not frames: st.error("Frame extraction failed. Cannot continue."); st.stop()
            st.success(f"Extracted {len(frames)} frames.")
            if frames:
                # Show at most MAX_FRAMES_TO_SHOW_UI thumbnails, one per column.
                cols = st.columns(min(len(frames), MAX_FRAMES_TO_SHOW_UI))
                for i, frame_img in enumerate(frames[:len(cols)]):
                    cols[i].image(frame_img, caption=f"Frame {i+1}", use_column_width=True)

            # === Stage 2: Image Captioning (Sound Prompt Generation) ===
            st.markdown("--- \n### 2. Analyzing Frames for Sound Ideas (BLIP)")
            cap_proc, cap_model = load_image_caption_model_and_processor()
            sound_prompt = "ambient environmental sounds" # Default
            if cap_proc and cap_model:
                sound_prompt = generate_sound_prompt_from_frames(frames, cap_proc, cap_model)
                clear_memory(cap_model, cap_proc) # Unload BLIP
            else: st.error("BLIP model failed to load. Using default sound prompt.")
            st.info(f"✍️ **Sound Prompt for MusicGen:** {sound_prompt}")

            # === Stage 3: Audio Generation ===
            st.markdown("--- \n### 3. Synthesizing Audio (MusicGen)")
            st.warning("🎧 This step is very slow on CPU. Your patience is appreciated!")
            aud_proc, aud_model = load_audio_gen_model_and_processor()
            gen_aud_arr, s_r = None, None
            if aud_proc and aud_model:
                gen_aud_arr, s_r = generate_audio_from_prompt(sound_prompt, audio_duration, aud_proc, aud_model, guidance, temperature)
                clear_memory(aud_model, aud_proc) # Unload MusicGen
            else: st.error("MusicGen model failed to load. Cannot generate audio.")

            if gen_aud_arr is not None and s_r is not None:
                st.success("Audio successfully generated!")
                st.audio(gen_aud_arr, sample_rate=s_r)
                # Write the waveform to a temp .wav (float32) and remember the
                # path in session state so a later rerun can re-offer it.
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_a_out:
                    scipy.io.wavfile.write(tmp_a_out.name, s_r, gen_aud_arr.astype(np.float32))
                    st.session_state.generated_audio_path_sess = tmp_a_out.name
                with open(st.session_state.generated_audio_path_sess, "rb") as f_aud:
                    st.download_button("πŸ“₯ Download Audio Only (.wav)", f_aud, "generated_sound.wav", "audio/wav")

                # === Stage 4: (Optional) Video and Audio Syncing ===
                if MOVIEPY_AVAILABLE:
                    st.markdown("--- \n### 4. Combining Audio with Video")
                    with st.spinner("Processing video with new audio... (also can be slow)"):
                        out_vid_p = combine_audio_video(temp_video_path, gen_aud_arr, s_r, mix_audio)
                    if out_vid_p and os.path.exists(out_vid_p):
                        st.success("Video processing complete!")
                        st.video(out_vid_p)
                        st.session_state.output_video_path_sess = out_vid_p
                        with open(out_vid_p, "rb") as f_vid:
                            st.download_button("🎬 Download Video with New Sound (.mp4)", f_vid, "video_with_sound.mp4", "video/mp4")
                    # NOTE(review): MOVIEPY_AVAILABLE is already True inside this
                    # branch, so this elif is equivalent to a plain else here.
                    elif MOVIEPY_AVAILABLE: st.error("Failed to combine audio and video.")
            else: st.error("Audio generation failed. Video syncing skipped.")
        except Exception as e_main:
            st.error(f"An unexpected error occurred in main processing: {e_main}")
            st.error(traceback.format_exc())
        finally:
            # Always delete the temp copy of the upload and free model memory.
            if temp_video_path and os.path.exists(temp_video_path): os.remove(temp_video_path)
            clear_memory() # Final general clear

    # Show download buttons for files from a previous successful run in the same session
    elif st.session_state.generated_audio_path_sess and os.path.exists(st.session_state.generated_audio_path_sess):
        st.markdown("---")
        st.write("Previously generated audio available:")
        st.audio(st.session_state.generated_audio_path_sess)
        with open(st.session_state.generated_audio_path_sess, "rb") as f_aud_prev:
            st.download_button("πŸ“₯ Download Previous Audio (.wav)", f_aud_prev, "generated_sound_prev.wav", "audio/wav", key="prev_aud_dl")

        if st.session_state.output_video_path_sess and os.path.exists(st.session_state.output_video_path_sess) and MOVIEPY_AVAILABLE:
            st.markdown("---") # This might appear even if audio only was generated, so careful with flow
            st.write("Previously generated video available:")
            st.video(st.session_state.output_video_path_sess)
            with open(st.session_state.output_video_path_sess, "rb") as f_vid_prev:
                st.download_button("🎬 Download Previous Video (.mp4)", f_vid_prev, "video_with_sound_prev.mp4", "video/mp4", key="prev_vid_dl")

else:
    # No upload yet: prompt the user.
    st.info("☝️ Upload a video to begin.")
379
 
380
+ st.markdown("---")
381
+ st.caption("Built for Hugging Face Spaces (CPU). Patience is key for generation times!")