# KU_SW_Academy / audio_processing/effect_chain.py
# Last commit 3bfa04b (heybaeheef): "Fix: Add DiffVox parameter conversion (sigmoid/minmax)"
"""
Effect Chain - DiffVox ํŒŒ๋ผ๋ฏธํ„ฐ ํ˜ธํ™˜ ๋ฒ„์ „
==========================================
LLM์ด ์ถœ๋ ฅํ•˜๋Š” DiffVox ํ˜•์‹ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ Pedalboard ์ดํŽ™ํŠธ๋กœ ๋ณ€ํ™˜
"""
import numpy as np
import soundfile as sf
import torch
from typing import Dict, List, Optional
from pedalboard import (
Pedalboard,
Compressor,
Gain,
HighShelfFilter,
LowShelfFilter,
PeakFilter,
Delay,
Reverb,
Distortion,
Limiter
)
class ParameterConverter:
    """Convert DiffVox-style parameters into values usable by Pedalboard.

    Some parameters (frequencies, Q, mixes, distortion) arrive as raw,
    unconstrained values and are squashed into their valid range with a
    sigmoid; EQ gains arrive directly in dB and are only clamped.
    """

    def __init__(self, sr: int = 44100):
        """
        Args:
            sr: Sample rate in Hz. Its Nyquist frequency (``sr / 2``) is the
                upper bound of the frequency mapping.
        """
        self.sr = sr

    def sigmoid(self, x: float) -> float:
        """Logistic sigmoid ``1 / (1 + exp(-x))``.

        Evaluated as ``exp(-log(1 + exp(-x)))`` via ``np.logaddexp`` so that
        large negative inputs do not overflow ``np.exp`` (the naive form emits
        a RuntimeWarning there). Works element-wise on NumPy arrays too.
        """
        return np.exp(-np.logaddexp(0.0, -x))

    def minmax(self, x: float, min_val: float, max_val: float) -> float:
        """Map an unconstrained value into ``[min_val, max_val]`` via the sigmoid."""
        return self.sigmoid(x) * (max_val - min_val) + min_val

    def convert_freq(self, original_value: float, min_freq: float = 20.0) -> float:
        """Map a raw frequency parameter into ``[min_freq, sr / 2]`` Hz."""
        max_freq = self.sr / 2.0  # Nyquist
        return self.minmax(original_value, min_freq, max_freq)

    def convert_q(self, original_value: float) -> float:
        """Map a raw Q parameter into ``[0.1, 10.0]``."""
        return self.minmax(original_value, 0.1, 10.0)

    def _sigmoid_or_default(self, params: Dict, key: str, default: float) -> float:
        """Return ``sigmoid(params[key])`` if the key is present, else ``default``.

        ``default`` is already expressed in the converted (0..1) domain, so it
        must NOT be pushed through the sigmoid a second time.
        """
        if key in params:
            return self.sigmoid(params[key])
        return default

    def convert_params(self, raw_params: Dict[str, float]) -> Dict[str, float]:
        """Convert raw LLM output parameters into directly usable values.

        Example LLM output keys:
            - ``eq_peak1.params.gain``: -0.35 (already in dB)
            - ``eq_peak1.params.parametrizations.freq.original``: -2.57 (needs conversion)
            - ``eq_peak1.params.parametrizations.Q.original``: -4.13 (needs conversion)

        Missing keys fall back to defaults that are used as-is, e.g. no
        distortion, delay feedback 0.3, delay mix 0.2, final wet mix 0.5.

        Returns:
            Dict with keys (in this order) ``eq_peak{1,2}_{freq,gain,q}``,
            ``eq_lowshelf_{freq,gain}``, ``eq_highshelf_{freq,gain}``,
            ``distortion``, ``delay_time``, ``delay_feedback``, ``delay_mix``
            and ``final_wet_mix``.
        """
        converted = {}

        # Peak EQs
        for prefix in ("eq_peak1", "eq_peak2"):
            converted[f"{prefix}_freq"] = self._get_freq(raw_params, prefix)
            converted[f"{prefix}_gain"] = self._get_gain(raw_params, prefix)
            converted[f"{prefix}_q"] = self._get_q(raw_params, prefix)

        # Shelves
        converted["eq_lowshelf_freq"] = self._get_freq(raw_params, "eq_lowshelf", default=200.0)
        converted["eq_lowshelf_gain"] = self._get_gain(raw_params, "eq_lowshelf")
        converted["eq_highshelf_freq"] = self._get_freq(raw_params, "eq_highshelf", default=8000.0)
        converted["eq_highshelf_gain"] = self._get_gain(raw_params, "eq_highshelf")

        # Distortion: DiffVox parametrizes it as sigmoid(x) * 0.1; absent -> off.
        converted["distortion"] = 0.1 * self._sigmoid_or_default(raw_params, "distortion_amount", 0.0)

        # Delay: the time is passed through unchanged; feedback/mix are sigmoid-parametrized.
        converted["delay_time"] = raw_params.get("delay.delay_time", 0.02)
        converted["delay_feedback"] = self._sigmoid_or_default(raw_params, "delay.feedback", 0.3)
        converted["delay_mix"] = self._sigmoid_or_default(raw_params, "delay.mix", 0.2)

        # Final wet/dry mix
        converted["final_wet_mix"] = self._sigmoid_or_default(raw_params, "final_wet_mix", 0.5)
        return converted

    def _get_freq(self, params: Dict, prefix: str, default: float = 1000.0) -> float:
        """Extract a frequency (Hz) for ``prefix``, converting raw values.

        The ``parametrizations.freq.original`` key is always treated as raw. A
        direct ``.params.freq`` value already inside ``[20, sr / 2]`` is taken
        as Hz; anything outside that range is treated as raw and converted.
        """
        key_param = f"{prefix}.params.parametrizations.freq.original"
        key_direct = f"{prefix}.params.freq"
        if key_param in params:
            return self.convert_freq(params[key_param])
        if key_direct in params:
            val = params[key_direct]
            return val if 20 <= val <= self.sr / 2 else self.convert_freq(val)
        return default

    def _get_gain(self, params: Dict, prefix: str) -> float:
        """Extract a gain in dB (no conversion), clamped to [-12, +12] dB; 0.0 if absent."""
        key = f"{prefix}.params.gain"
        gain = params.get(key, 0.0)
        return max(-12.0, min(12.0, gain))

    def _get_q(self, params: Dict, prefix: str, default: float = 1.0) -> float:
        """Extract a Q value for ``prefix``, converting raw values.

        Lookup order: ``parametrizations.Q.original`` (always raw), then the
        direct ``.params.q`` and ``.params.Q`` keys. A direct value inside
        ``[0.1, 10]`` is used as-is; otherwise it is treated as raw.
        """
        key_param = f"{prefix}.params.parametrizations.Q.original"
        if key_param in params:
            return self.convert_q(params[key_param])
        for key in (f"{prefix}.params.q", f"{prefix}.params.Q"):
            if key in params:
                val = params[key]
                return val if 0.1 <= val <= 10 else self.convert_q(val)
        return default
class EffectChain:
    """DiffVox-compatible effect chain rendered with Pedalboard.

    Raw (LLM-produced) parameters are converted with ``ParameterConverter``
    and turned into a fixed-order chain: compressor -> peak EQs -> shelves ->
    distortion -> delay -> limiter. The result is then wet/dry mixed with the
    input signal.
    """

    def __init__(self, sample_rate: int = 44100):
        """Create the chain.

        Args:
            sample_rate: Sample rate (Hz) used for the default parameter
                converter. ``process`` switches to a converter built for the
                input file's own rate when the two differ.
        """
        self.sample_rate = sample_rate
        self.converter = ParameterConverter(sr=sample_rate)
        self.available_effects = [
            "eq_peak1", "eq_peak2",
            "eq_lowshelf", "eq_highshelf",
            "distortion", "delay", "compressor",
        ]

    def get_available_effects(self) -> List[str]:
        """Return the names of the supported effects.

        A copy is returned so callers cannot mutate this instance's list.
        """
        return list(self.available_effects)

    @staticmethod
    def _clamp(value: float, low: float, high: float) -> float:
        """Clamp ``value`` into ``[low, high]`` (``high`` wins if ``low > high``)."""
        return min(high, max(low, value))

    @staticmethod
    def _max_cutoff_hz(sample_rate: Optional[int]) -> float:
        """Return the highest filter cutoff to use.

        This is 20 kHz, further capped just below Nyquist when the sample rate
        is known. For 44.1/48 kHz audio this is still 20 kHz.
        """
        if sample_rate is None:
            return 20000.0
        return min(20000.0, 0.99 * sample_rate / 2.0)

    def _build_pedalboard(
        self,
        params: Dict[str, float],
        sample_rate: Optional[int] = None,
    ) -> Pedalboard:
        """Build the Pedalboard chain from converted parameters.

        Args:
            params: Output of ``ParameterConverter.convert_params``.
            sample_rate: Rate of the audio the board will process. When given,
                filter cutoffs are kept below its Nyquist frequency. When
                omitted, the fixed 20 kHz ceiling is used.

        Returns:
            A ``Pedalboard``. EQ stages are added only when ``|gain| > 0.1`` dB,
            distortion only when its amount exceeds 0.005, and delay only when
            its mix exceeds 0.01. The compressor and the limiter are always
            present.
        """
        max_cutoff = self._max_cutoff_hz(sample_rate)
        effects = []

        # 1. Compressor (always applied)
        effects.append(Compressor(
            threshold_db=-18.0,
            ratio=2.0,
            attack_ms=10.0,
            release_ms=100.0,
        ))

        # 2-3. Peak EQs
        for idx in (1, 2):
            gain = params[f"eq_peak{idx}_gain"]
            if abs(gain) > 0.1:
                freq = params[f"eq_peak{idx}_freq"]
                q = params[f"eq_peak{idx}_q"]
                print(f" [EQ Peak {idx}] freq={freq:.1f}Hz, gain={gain:.2f}dB, Q={q:.2f}")
                effects.append(PeakFilter(
                    cutoff_frequency_hz=self._clamp(freq, 20, max_cutoff),
                    gain_db=gain,
                    q=self._clamp(q, 0.1, 10),
                ))

        # 4. Low shelf (cutoff limited to 20 Hz .. 2 kHz)
        gain_low = params["eq_lowshelf_gain"]
        if abs(gain_low) > 0.1:
            freq_low = params["eq_lowshelf_freq"]
            print(f" [Low Shelf] freq={freq_low:.1f}Hz, gain={gain_low:.2f}dB")
            effects.append(LowShelfFilter(
                cutoff_frequency_hz=self._clamp(freq_low, 20, min(2000, max_cutoff)),
                gain_db=gain_low,
                q=0.707,
            ))

        # 5. High shelf (cutoff limited to 1 kHz .. max_cutoff)
        gain_high = params["eq_highshelf_gain"]
        if abs(gain_high) > 0.1:
            freq_high = params["eq_highshelf_freq"]
            print(f" [High Shelf] freq={freq_high:.1f}Hz, gain={gain_high:.2f}dB")
            effects.append(HighShelfFilter(
                cutoff_frequency_hz=self._clamp(freq_high, 1000, max_cutoff),
                gain_db=gain_high,
                q=0.707,
            ))

        # 6. Distortion (amount 0.1 -> 20 dB drive, capped at 20 dB)
        dist = params["distortion"]
        if dist > 0.005:
            drive_db = dist * 200
            print(f" [Distortion] drive={drive_db:.1f}dB")
            effects.append(Distortion(drive_db=min(20, drive_db)))

        # 7. Delay
        delay_mix = params["delay_mix"]
        if delay_mix > 0.01:
            delay_time = params["delay_time"]
            delay_feedback = params["delay_feedback"]
            print(f" [Delay] time={delay_time:.3f}s, feedback={delay_feedback:.2f}, mix={delay_mix:.2f}")
            effects.append(Delay(
                delay_seconds=self._clamp(delay_time, 0.01, 1.0),
                feedback=self._clamp(delay_feedback, 0.0, 0.9),
                mix=self._clamp(delay_mix, 0.0, 1.0),
            ))

        # 8. Limiter (always last)
        effects.append(Limiter(threshold_db=-1.0))
        return Pedalboard(effects)

    def process(
        self,
        input_path: str,
        output_path: str,
        parameters: Dict[str, float]
    ) -> bool:
        """Render ``input_path`` through the chain and write ``output_path``.

        Args:
            input_path: Audio file readable by ``soundfile``.
            output_path: Destination file, written at the input's sample rate.
            parameters: Raw (unconverted) DiffVox-style parameters.

        Returns:
            True on success.

        Raises:
            Any exception raised while reading, processing or writing. It is
            printed together with its traceback and then re-raised.
        """
        try:
            # 1. Load audio as a 2-D (frames, channels) float32 array.
            audio, sr = sf.read(input_path)
            if audio.ndim == 1:
                audio = audio[:, np.newaxis]
            audio = audio.astype(np.float32)

            # 2. Convert parameters. Use a converter for the file's real rate so
            #    frequency mappings stay below that rate's Nyquist.
            print("\n [EffectChain] 파라미터 변환 중...")
            converter = self.converter if self.converter.sr == sr else ParameterConverter(sr=sr)
            converted_params = converter.convert_params(parameters)
            print(" [EffectChain] 변환된 파라미터:")
            for key, value in converted_params.items():
                print(f" {key}: {value:.4f}")

            # 3. Build the Pedalboard chain.
            print("\n [EffectChain] 이펙트 체인 구성 중...")
            board = self._build_pedalboard(converted_params, sample_rate=sr)

            # 4. Apply the effects.
            processed = board(audio, sr)

            # 5. Wet/dry mix, trimmed to the shorter of the two signals.
            wet_mix = converted_params["final_wet_mix"]
            print(f" [Mix] wet={wet_mix:.2f}, dry={1-wet_mix:.2f}")
            min_len = min(len(audio), len(processed))
            output = audio[:min_len] * (1 - wet_mix) + processed[:min_len] * wet_mix

            # Guard against clipping.
            output = np.clip(output, -1.0, 1.0)

            # 6. Save.
            sf.write(output_path, output, sr)
            print(f"\n [EffectChain] ✅ 처리 완료: {output_path}")
            return True
        except Exception as e:
            print(f" [EffectChain] ❌ 처리 실패: {e}")
            import traceback
            traceback.print_exc()
            # Bare raise keeps the original traceback intact.
            raise