"""
AI Effector Model - DiffVox LLM ํ†ตํ•ฉ ๋ฒ„์ „
==========================================
CLAP ์ธ์ฝ”๋” + ํ•™์Šต๋œ LLM์„ ์‚ฌ์šฉํ•˜์—ฌ ์˜ค๋””์˜ค์—์„œ ์ดํŽ™ํ„ฐ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์˜ˆ์ธก
DiffVox LLM ํŒŒ๋ผ๋ฏธํ„ฐ โ†’ MagicPath ์›น ํŒŒ๋ผ๋ฏธํ„ฐ ์ž๋™ ๋ณ€ํ™˜
"""
import json
import re
import os
from pathlib import Path
from typing import Dict, Any, Optional
import torch

# AI model imports (installation required; optional for preset mode)
try:
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("[AIEffector] transformers/peft not installed - running in preset mode")

# CLAP encoder (separate module)
try:
    from models.audio_encoder import AudioEncoder
    AUDIO_ENCODER_AVAILABLE = True
except ImportError:
    AUDIO_ENCODER_AVAILABLE = False
    print("[AIEffector] AudioEncoder not available - running in preset mode")


class ParameterMapper:
    """Converts between DiffVox LLM parameters and MagicPath web parameters."""

    # DiffVox LLM → MagicPath web mapping
    DIFFVOX_TO_WEB = {
        # EQ Low Shelf
        "eq_lowshelf.params.gain": "eq_lowshelf_gain",
        "eq_lowshelf.params.parametrizations.freq.original": "eq_lowshelf_freq",
        # EQ High Shelf
        "eq_highshelf.params.gain": "eq_highshelf_gain",
        "eq_highshelf.params.parametrizations.freq.original": "eq_highshelf_freq",
        # EQ Peak 1
        "eq_peak1.params.gain": "eq_peak1_gain",
        "eq_peak1.params.parametrizations.freq.original": "eq_peak1_freq",
        "eq_peak1.params.parametrizations.Q.original": "eq_peak1_q",
        # EQ Peak 2
        "eq_peak2.params.gain": "eq_peak2_gain",
        "eq_peak2.params.parametrizations.freq.original": "eq_peak2_freq",
        "eq_peak2.params.parametrizations.Q.original": "eq_peak2_q",
        # Delay
        "delay.delay_time": "delay_time",
        "delay.feedback": "delay_feedback",
        "delay.mix": "delay_mix",
        # Distortion
        "distortion_amount": "distortion_amount",
        # Master
        "final_wet_mix": "final_wet_mix",
    }

    # Reverse mapping
    WEB_TO_DIFFVOX = {v: k for k, v in DIFFVOX_TO_WEB.items()}

    # Value conversion rules (normalized value → real-world value)
    VALUE_TRANSFORMS = {
        # EQ gain: -1~1 → -12~12 dB
        "eq_lowshelf_gain": lambda x: x * 12,
        "eq_highshelf_gain": lambda x: x * 12,
        "eq_peak1_gain": lambda x: x * 12,
        "eq_peak2_gain": lambda x: x * 12,
        # EQ freq: -1~1 → 20~20000 Hz (log-scale mapping)
        "eq_lowshelf_freq": lambda x: 20 * (20000 / 20) ** ((x + 1) / 2),
        "eq_highshelf_freq": lambda x: 20 * (20000 / 20) ** ((x + 1) / 2),
        "eq_peak1_freq": lambda x: 20 * (20000 / 20) ** ((x + 1) / 2),
        "eq_peak2_freq": lambda x: 20 * (20000 / 20) ** ((x + 1) / 2),
        # Q: -1~1 → 0.1~10 (log-scale mapping)
        "eq_peak1_q": lambda x: 0.1 * (10 / 0.1) ** ((x + 1) / 2),
        "eq_peak2_q": lambda x: 0.1 * (10 / 0.1) ** ((x + 1) / 2),
        # Delay time: -1~1 → 0~1000 ms
        "delay_time": lambda x: (x + 1) / 2 * 1000,
        # Delay feedback: -1~1 → 0~1
        "delay_feedback": lambda x: (x + 1) / 2,
        # Delay mix: -1~1 → 0~1
        "delay_mix": lambda x: (x + 1) / 2,
        # Distortion: -1~1 → 0~1
        "distortion_amount": lambda x: (x + 1) / 2,
        # Wet mix: -1~1 → 0~1
        "final_wet_mix": lambda x: (x + 1) / 2,
    }

    @classmethod
    def diffvox_to_web(cls, diffvox_params: Dict[str, float]) -> Dict[str, float]:
        """Convert DiffVox LLM output to MagicPath web parameters."""
        web_params = {}
        for diffvox_key, value in diffvox_params.items():
            # Key conversion; skip keys that have no mapping
            if diffvox_key in cls.DIFFVOX_TO_WEB:
                web_key = cls.DIFFVOX_TO_WEB[diffvox_key]
            else:
                continue
            # Value conversion (fall back to the raw value on failure)
            if web_key in cls.VALUE_TRANSFORMS:
                try:
                    web_params[web_key] = cls.VALUE_TRANSFORMS[web_key](value)
                except Exception:
                    web_params[web_key] = value
            else:
                web_params[web_key] = value
        return web_params
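

# A worked example of the mapping above (illustrative only; assumes the LLM emits
# normalized values in [-1, 1], as the transforms imply):
#   ParameterMapper.diffvox_to_web({"eq_peak1.params.gain": 0.5, "delay.mix": 0.0})
#   -> {"eq_peak1_gain": 6.0, "delay_mix": 0.5}
# i.e. 0.5 * 12 = 6.0 dB and (0.0 + 1) / 2 = 0.5 mix; unmapped keys are dropped.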


class ParameterParser:
    """Extracts a parameter JSON object from raw LLM output."""

    @staticmethod
    def parse(llm_output: str) -> Optional[Dict]:
        """Extract a parameter dictionary from LLM output."""
        # Method 1: look for a JSON block (flat object first, then one level of nesting)
        json_patterns = [
            r'\{[^{}]*\}',
            r'\{(?:[^{}]|\{[^{}]*\})*\}',
        ]
        for pattern in json_patterns:
            matches = re.findall(pattern, llm_output, re.DOTALL)
            for match in matches:
                try:
                    params = json.loads(match)
                    if isinstance(params, dict) and len(params) > 0:
                        return params
                except json.JSONDecodeError:
                    continue
        # Method 2: fall back to "key": value pattern matching
        param_pattern = r'"([^"]+)":\s*([-\d.]+)'
        matches = re.findall(param_pattern, llm_output)
        if matches:
            params = {}
            for key, value in matches:
                try:
                    params[key] = float(value)
                except ValueError:
                    params[key] = value
            if params:
                return params
        return None
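

# A minimal sketch of the parsing behaviour (the LLM output below is an assumed,
# illustrative shape; any prose around the JSON object is ignored by the regex search):
#   ParameterParser.parse('Parameters: {"delay.mix": 0.2, "final_wet_mix": -0.4} done')
#   -> {"delay.mix": 0.2, "final_wet_mix": -0.4}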


class AIEffector:
    """AI-based effect parameter prediction model - DiffVox LLM integration."""

    # Default parameters
    DEFAULT_PARAMS = {
        "eq_lowshelf_gain": 0.0,
        "eq_lowshelf_freq": 200,
        "eq_highshelf_gain": 0.0,
        "eq_highshelf_freq": 8000,
        "eq_peak1_gain": 0.0,
        "eq_peak1_freq": 1000,
        "eq_peak1_q": 1.0,
        "eq_peak2_gain": 0.0,
        "eq_peak2_freq": 3000,
        "eq_peak2_q": 1.0,
        "compressor_threshold": -24,
        "compressor_ratio": 4.0,
        "compressor_attack": 5,
        "compressor_release": 50,
        "compressor_makeup": 0.0,
        "distortion_amount": 0.0,
        "distortion_tone": 0.5,
        "delay_time": 250,
        "delay_feedback": 0.3,
        "delay_mix": 0.0,
        "reverb_room_size": 0.5,
        "reverb_damping": 0.5,
        "reverb_wet_dry": 0.0,
        "final_wet_mix": 0.5
    }

    # Presets (used as a fallback when no model is loaded)
    PRESETS = {
        "warm": {
            "eq_lowshelf_gain": 5.5,
            "eq_lowshelf_freq": 200,
            "eq_highshelf_gain": -1.5,
            "eq_highshelf_freq": 8000,
            "eq_peak1_gain": 2.0,
            "eq_peak1_freq": 400,
            "eq_peak1_q": 1.0,
            "compressor_threshold": -18,
            "compressor_ratio": 3.0,
            "distortion_amount": 0.05,
            "reverb_room_size": 0.4,
            "reverb_wet_dry": 0.15,
            "final_wet_mix": 0.5
        },
        "bright": {
            "eq_lowshelf_gain": -2.0,
            "eq_lowshelf_freq": 150,
            "eq_highshelf_gain": 4.0,
            "eq_highshelf_freq": 6000,
            "eq_peak1_gain": 1.0,
            "eq_peak1_freq": 3000,
            "compressor_threshold": -20,
            "compressor_ratio": 6.0,
            "reverb_room_size": 0.3,
            "reverb_wet_dry": 0.1,
            "final_wet_mix": 0.5
        },
    }

    def __init__(
        self,
        model_path: Optional[str] = None,
        base_model_name: str = "Qwen/Qwen3-8B",
        audio_feature_dim: int = 64,
        use_huggingface: bool = True
    ):
        """
        Initialize the AI model.

        Args:
            model_path: Path to the trained LoRA model (local path or Hugging Face repo)
            base_model_name: Base LLM model name
            audio_feature_dim: Audio feature dimension (CLAP output)
            use_huggingface: If True, treat model_path as a Hugging Face repo
        """
        self.model = None
        self.tokenizer = None
        self.audio_encoder = None
        self.model_loaded = False
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.base_model_name = base_model_name
        self.audio_feature_dim = audio_feature_dim
        self.use_huggingface = use_huggingface
        if model_path:
            self._load_model(model_path)

    def _load_model(self, model_path: str):
        """Load the trained LoRA model (local path or Hugging Face repo)."""
        if not TRANSFORMERS_AVAILABLE:
            print("[AIEffector] transformers/peft not installed")
            return
        # Determine whether model_path is a local path or a Hugging Face repo
        is_local = os.path.exists(model_path)
        if not is_local and not self.use_huggingface:
            print(f"[AIEffector] Local model path not found: {model_path}")
            return
        try:
            if self.use_huggingface and not is_local:
                print(f"[AIEffector] Loading model from Hugging Face: {model_path}")
            else:
                print(f"[AIEffector] Loading local model: {model_path}")
            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.base_model_name,
                trust_remote_code=True
            )
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # Load base model
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_name,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
            )
            # Apply LoRA adapter (Hugging Face repo or local path)
            self.model = PeftModel.from_pretrained(
                base_model,
                model_path,  # Hugging Face repo name or local path
                is_trainable=False
            )
            self.model.eval()
            # Load audio encoder
            if AUDIO_ENCODER_AVAILABLE:
                self.audio_encoder = AudioEncoder(
                    output_dim=self.audio_feature_dim,
                    reduction_method="pool"
                )
                print("[AIEffector] AudioEncoder loaded")
            self.model_loaded = True
            print("[AIEffector] ✅ Model loaded")
        except Exception as e:
            print(f"[AIEffector] ❌ Model load failed: {e}")
            import traceback
            traceback.print_exc()
            self.model_loaded = False

    def is_loaded(self) -> bool:
        """Return whether the AI model has been loaded."""
        return self.model_loaded

    def predict(self, audio_path: str, text_prompt: str) -> Dict[str, float]:
        """
        Predict effect parameters from audio and a text prompt.

        Args:
            audio_path: Path to the input audio file
            text_prompt: User text command

        Returns:
            Effect parameter dictionary in MagicPath web format
        """
        if self.model_loaded and self.audio_encoder:
            return self._predict_with_model(audio_path, text_prompt)
        else:
            return self._predict_with_preset(text_prompt)

    def _predict_with_model(self, audio_path: str, text_prompt: str) -> Dict[str, float]:
        """Run inference with the trained DiffVox LLM."""
        try:
            # 1. Extract audio features
            audio_features = self.audio_encoder.get_audio_features(audio_path)
            if not audio_features:
                print("[AIEffector] Audio feature extraction failed, falling back to preset")
                return self._predict_with_preset(text_prompt)
            # 2. Build the prompt (same format as train_model.py)
            audio_state_str = json.dumps(audio_features)
            prompt = f"""Task: Convert text to audio parameters.
Audio: {audio_state_str}
Text: {text_prompt}
Parameters:"""
            # 3. LLM inference (greedy decoding; temperature is ignored when do_sample=False)
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=1500
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=500,
                    temperature=0.1,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id,
                )
            generated_text = self.tokenizer.decode(
                outputs[0][inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            ).strip()
            print(f"[AIEffector] LLM output: {generated_text[:200]}...")
            # 4. Parse parameters
            diffvox_params = ParameterParser.parse(generated_text)
            if not diffvox_params:
                print("[AIEffector] Parameter parsing failed, falling back to preset")
                return self._predict_with_preset(text_prompt)
            # 5. Convert DiffVox parameters to web parameters
            web_params = ParameterMapper.diffvox_to_web(diffvox_params)
            # 6. Merge with defaults
            result = self.DEFAULT_PARAMS.copy()
            result.update(web_params)
            print(f"[AIEffector] ✅ AI parameters generated: {len(web_params)} parameters")
            return result
        except Exception as e:
            print(f"[AIEffector] Inference error: {e}")
            import traceback
            traceback.print_exc()
            return self._predict_with_preset(text_prompt)

    def _predict_with_preset(self, text_prompt: str) -> Dict[str, float]:
        """Return preset-based parameters (fallback)."""
        prompt_lower = text_prompt.lower()
        for preset_name, preset_params in self.PRESETS.items():
            if preset_name in prompt_lower:
                print(f"[AIEffector] Preset match: '{preset_name}'")
                result = self.DEFAULT_PARAMS.copy()
                result.update(preset_params)
                return result
        print("[AIEffector] No preset matched, returning defaults")
        return self.DEFAULT_PARAMS.copy()
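

# Minimal usage sketch (assumptions: no trained LoRA checkpoint is supplied, so the
# effector runs in the preset fallback mode and never reads the audio file; the file
# name "vocal.wav" and the prompt text below are purely illustrative).
if __name__ == "__main__":
    effector = AIEffector()  # no model_path -> preset/fallback mode
    params = effector.predict("vocal.wav", "make it warm and round")  # matches the "warm" preset
    print(json.dumps(params, indent=2))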