Spaces:
Running
on
A10G
Running
on
A10G
| """ | |
| AI Effector Model - DiffVox LLM integrated version | |
| ================================================== | |
| Predicts effector parameters from audio using a CLAP encoder plus a fine-tuned LLM. | |
| Automatically converts DiffVox LLM parameters into MagicPath web parameters. | |
| """ | |
| import json | |
| import re | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, Any, Optional | |
| import torch | |
| # AI ๋ชจ๋ธ ๊ด๋ จ import (์ค์น ํ์) | |
| try: | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import PeftModel | |
| TRANSFORMERS_AVAILABLE = True | |
| except ImportError: | |
| TRANSFORMERS_AVAILABLE = False | |
| print("[AIEffector] transformers/peft ๋ฏธ์ค์น - ํ๋ฆฌ์ ๋ชจ๋๋ก ๋์") | |
| # CLAP ์ธ์ฝ๋ (๋ณ๋ ํ์ผ) | |
| try: | |
| from models.audio_encoder import AudioEncoder | |
| AUDIO_ENCODER_AVAILABLE = True | |
| except ImportError: | |
| AUDIO_ENCODER_AVAILABLE = False | |
| print("[AIEffector] AudioEncoder ๋ฏธ์ค์น - ํ๋ฆฌ์ ๋ชจ๋๋ก ๋์") | |
class ParameterMapper:
    """Convert DiffVox LLM parameters into MagicPath web parameters.

    ``DIFFVOX_TO_WEB`` renames the DiffVox module-path keys emitted by the LLM
    (e.g. ``"eq_lowshelf.params.gain"``) to the flat keys used by the web UI.
    ``VALUE_TRANSFORMS`` rescales each value from the normalized -1..1 range
    assumed by the comments below into the unit the web UI expects.
    NOTE(review): confirm the LLM really emits -1..1 for every key; the
    transforms are only meaningful on that range.
    """

    # DiffVox LLM key -> MagicPath web key
    DIFFVOX_TO_WEB = {
        # EQ low shelf
        "eq_lowshelf.params.gain": "eq_lowshelf_gain",
        "eq_lowshelf.params.parametrizations.freq.original": "eq_lowshelf_freq",
        # EQ high shelf
        "eq_highshelf.params.gain": "eq_highshelf_gain",
        "eq_highshelf.params.parametrizations.freq.original": "eq_highshelf_freq",
        # EQ peak 1
        "eq_peak1.params.gain": "eq_peak1_gain",
        "eq_peak1.params.parametrizations.freq.original": "eq_peak1_freq",
        "eq_peak1.params.parametrizations.Q.original": "eq_peak1_q",
        # EQ peak 2
        "eq_peak2.params.gain": "eq_peak2_gain",
        "eq_peak2.params.parametrizations.freq.original": "eq_peak2_freq",
        "eq_peak2.params.parametrizations.Q.original": "eq_peak2_q",
        # Delay
        "delay.delay_time": "delay_time",
        "delay.feedback": "delay_feedback",
        "delay.mix": "delay_mix",
        # Distortion
        "distortion_amount": "distortion_amount",
        # Master
        "final_wet_mix": "final_wet_mix",
    }

    # Reverse mapping: MagicPath web key -> DiffVox LLM key
    WEB_TO_DIFFVOX = {v: k for k, v in DIFFVOX_TO_WEB.items()}

    # Value rescaling rules (normalized value -> real value), keyed by web key.
    VALUE_TRANSFORMS = {
        # EQ gain: -1..1 -> -12..12 dB
        "eq_lowshelf_gain": lambda x: x * 12,
        "eq_highshelf_gain": lambda x: x * 12,
        "eq_peak1_gain": lambda x: x * 12,
        "eq_peak2_gain": lambda x: x * 12,
        # EQ freq: -1..1 -> 20..20000 Hz on a log scale
        # (original author's note: a different log-scale inverse may be needed)
        "eq_lowshelf_freq": lambda x: 20 * (20000/20) ** ((x + 1) / 2),
        "eq_highshelf_freq": lambda x: 20 * (20000/20) ** ((x + 1) / 2),
        "eq_peak1_freq": lambda x: 20 * (20000/20) ** ((x + 1) / 2),
        "eq_peak2_freq": lambda x: 20 * (20000/20) ** ((x + 1) / 2),
        # Q: -1..1 -> 0.1..10 on a log scale
        "eq_peak1_q": lambda x: 0.1 * (10/0.1) ** ((x + 1) / 2),
        "eq_peak2_q": lambda x: 0.1 * (10/0.1) ** ((x + 1) / 2),
        # Delay time: -1..1 -> 0..1000 ms
        "delay_time": lambda x: (x + 1) / 2 * 1000,
        # Delay feedback: -1..1 -> 0..1
        "delay_feedback": lambda x: (x + 1) / 2,
        # Delay mix: -1..1 -> 0..1
        "delay_mix": lambda x: (x + 1) / 2,
        # Distortion: -1..1 -> 0..1
        "distortion_amount": lambda x: (x + 1) / 2,
        # Wet mix: -1..1 -> 0..1
        "final_wet_mix": lambda x: (x + 1) / 2,
    }

    @classmethod
    def diffvox_to_web(cls, diffvox_params: Dict[str, Any]) -> Dict[str, float]:
        """Convert DiffVox LLM output into MagicPath web parameters.

        Keys not listed in ``DIFFVOX_TO_WEB`` are ignored.

        Values that cannot be interpreted as a finite number are dropped, so
        the caller's defaults apply. This covers strings, lists, ``None`` and
        NaN/inf. It also covers values whose transform overflows. Numeric
        strings such as ``"0.25"`` are accepted.

        Args:
            diffvox_params: Mapping of DiffVox key -> value, typically from
                ``ParameterParser.parse``. ``None`` or empty yields ``{}``.

        Returns:
            Mapping of web key -> rescaled float value.
        """
        import math  # function-scope import keeps the file's import header untouched

        web_params: Dict[str, float] = {}
        if not diffvox_params:
            return web_params

        for diffvox_key, raw_value in diffvox_params.items():
            web_key = cls.DIFFVOX_TO_WEB.get(diffvox_key)
            if web_key is None:
                continue  # not a parameter the web UI knows about

            # Coerce first: applying the lambdas to a str/list would do sequence
            # repetition (e.g. "0.5" * 12) instead of arithmetic.
            try:
                value = float(raw_value)
            except (TypeError, ValueError):
                print(f"[ParameterMapper] non-numeric value skipped: {diffvox_key}={raw_value!r}")
                continue
            if not math.isfinite(value):
                print(f"[ParameterMapper] non-finite value skipped: {diffvox_key}={raw_value!r}")
                continue

            transform = cls.VALUE_TRANSFORMS.get(web_key)
            if transform is None:
                web_params[web_key] = value
                continue

            try:
                converted = transform(value)
            except (OverflowError, ZeroDivisionError) as exc:
                # e.g. the log-scale frequency transform overflows for large inputs
                print(f"[ParameterMapper] transform failed for {web_key}: {exc}")
                continue
            if math.isfinite(converted):
                web_params[web_key] = converted
            else:
                print(f"[ParameterMapper] out-of-range result skipped: {web_key}={converted!r}")
        return web_params
class ParameterParser:
    """Extract a parameter dictionary from raw LLM output text."""

    # Fallback `"key": number` pattern. It is used when no complete JSON object
    # can be decoded, e.g. when output is cut off by max_new_tokens. The number
    # part accepts an optional sign, a decimal point and an exponent, so "1e-3"
    # is not truncated to "1". Bare "-" or "." are not matched.
    _KEY_VALUE_PATTERN = re.compile(
        r'"([^"]+)"\s*:\s*([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)'
    )

    @staticmethod
    def parse(llm_output: str) -> Optional[Dict[str, Any]]:
        """Return the first parameter dictionary found in ``llm_output``.

        Strategy:
          1. Starting at each ``{`` in turn, decode a JSON value and return
             the first non-empty dict. The JSON decoder itself handles
             arbitrary nesting and braces inside strings, and the outermost
             object is returned.
          2. Otherwise collect every ``"key": number`` pair. This recovers
             values from truncated or otherwise malformed JSON; later
             duplicates win.

        Args:
            llm_output: Text generated by the LLM.

        Returns:
            The parsed dictionary, or None when nothing usable is found. Empty
            or None input also yields None.
        """
        if not llm_output:
            return None

        # Method 1: a complete JSON object.
        decoder = json.JSONDecoder()
        start = llm_output.find("{")
        while start != -1:
            try:
                candidate, _end = decoder.raw_decode(llm_output, start)
            except json.JSONDecodeError:
                pass
            else:
                if isinstance(candidate, dict) and candidate:
                    return candidate
            start = llm_output.find("{", start + 1)

        # Method 2: loose "key": number pairs. The regex only matches valid
        # float literals, so float() cannot fail here.
        params = {
            key: float(number)
            for key, number in ParameterParser._KEY_VALUE_PATTERN.findall(llm_output)
        }
        return params or None
| class AIEffector: | |
| """AI ๊ธฐ๋ฐ ์ดํํฐ ํ๋ผ๋ฏธํฐ ์์ธก ๋ชจ๋ธ - DiffVox LLM ํตํฉ""" | |
| # ๊ธฐ๋ณธ ํ๋ผ๋ฏธํฐ | |
| DEFAULT_PARAMS = { | |
| "eq_lowshelf_gain": 0.0, | |
| "eq_lowshelf_freq": 200, | |
| "eq_highshelf_gain": 0.0, | |
| "eq_highshelf_freq": 8000, | |
| "eq_peak1_gain": 0.0, | |
| "eq_peak1_freq": 1000, | |
| "eq_peak1_q": 1.0, | |
| "eq_peak2_gain": 0.0, | |
| "eq_peak2_freq": 3000, | |
| "eq_peak2_q": 1.0, | |
| "compressor_threshold": -24, | |
| "compressor_ratio": 4.0, | |
| "compressor_attack": 5, | |
| "compressor_release": 50, | |
| "compressor_makeup": 0.0, | |
| "distortion_amount": 0.0, | |
| "distortion_tone": 0.5, | |
| "delay_time": 250, | |
| "delay_feedback": 0.3, | |
| "delay_mix": 0.0, | |
| "reverb_room_size": 0.5, | |
| "reverb_damping": 0.5, | |
| "reverb_wet_dry": 0.0, | |
| "final_wet_mix": 0.5 | |
| } | |
| # ํ๋ฆฌ์ (fallback์ฉ) | |
| PRESETS = { | |
| "warm": { | |
| "eq_lowshelf_gain": 5.5, | |
| "eq_lowshelf_freq": 200, | |
| "eq_highshelf_gain": -1.5, | |
| "eq_highshelf_freq": 8000, | |
| "eq_peak1_gain": 2.0, | |
| "eq_peak1_freq": 400, | |
| "eq_peak1_q": 1.0, | |
| "compressor_threshold": -18, | |
| "compressor_ratio": 3.0, | |
| "distortion_amount": 0.05, | |
| "reverb_room_size": 0.4, | |
| "reverb_wet_dry": 0.15, | |
| "final_wet_mix": 0.5 | |
| }, | |
| "bright": { | |
| "eq_lowshelf_gain": -2.0, | |
| "eq_lowshelf_freq": 150, | |
| "eq_highshelf_gain": 4.0, | |
| "eq_highshelf_freq": 6000, | |
| "eq_peak1_gain": 1.0, | |
| "eq_peak1_freq": 3000, | |
| "compressor_threshold": -20, | |
| "compressor_ratio": 6.0, | |
| "reverb_room_size": 0.3, | |
| "reverb_wet_dry": 0.1, | |
| "final_wet_mix": 0.5 | |
| }, | |
| } | |
| def __init__( | |
| self, | |
| model_path: Optional[str] = None, | |
| base_model_name: str = "Qwen/Qwen3-8B", | |
| audio_feature_dim: int = 64, | |
| use_huggingface: bool = True | |
| ): | |
| """ | |
| AI ๋ชจ๋ธ ์ด๊ธฐํ | |
| Args: | |
| model_path: ํ์ต๋ LoRA ๋ชจ๋ธ ๊ฒฝ๋ก (๋ก์ปฌ ๋๋ Hugging Face ๋ ํฌ) | |
| base_model_name: ๋ฒ ์ด์ค LLM ๋ชจ๋ธ ์ด๋ฆ | |
| audio_feature_dim: ์ค๋์ค ํน์ง ์ฐจ์ (CLAP ์ถ๋ ฅ) | |
| use_huggingface: True๋ฉด model_path๋ฅผ Hugging Face ๋ ํฌ๋ก ๊ฐ์ฃผ | |
| """ | |
| self.model = None | |
| self.tokenizer = None | |
| self.audio_encoder = None | |
| self.model_loaded = False | |
| self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| self.base_model_name = base_model_name | |
| self.audio_feature_dim = audio_feature_dim | |
| self.use_huggingface = use_huggingface | |
| if model_path: | |
| self._load_model(model_path) | |
| def _load_model(self, model_path: str): | |
| """ํ์ต๋ LoRA ๋ชจ๋ธ ๋ก๋ (๋ก์ปฌ ๋๋ Hugging Face)""" | |
| if not TRANSFORMERS_AVAILABLE: | |
| print("[AIEffector] transformers/peft ๋ฏธ์ค์น") | |
| return | |
| # ๋ก์ปฌ ๊ฒฝ๋ก์ธ์ง Hugging Face ๋ ํฌ์ธ์ง ํ์ธ | |
| is_local = os.path.exists(model_path) | |
| if not is_local and not self.use_huggingface: | |
| print(f"[AIEffector] ๋ก์ปฌ ๋ชจ๋ธ ๊ฒฝ๋ก ์์: {model_path}") | |
| return | |
| try: | |
| if self.use_huggingface and not is_local: | |
| print(f"[AIEffector] Hugging Face์์ ๋ชจ๋ธ ๋ก๋ฉ: {model_path}") | |
| else: | |
| print(f"[AIEffector] ๋ก์ปฌ ๋ชจ๋ธ ๋ก๋ฉ: {model_path}") | |
| # ํ ํฌ๋์ด์ ๋ก๋ | |
| self.tokenizer = AutoTokenizer.from_pretrained( | |
| self.base_model_name, | |
| trust_remote_code=True | |
| ) | |
| if self.tokenizer.pad_token is None: | |
| self.tokenizer.pad_token = self.tokenizer.eos_token | |
| # ๋ฒ ์ด์ค ๋ชจ๋ธ ๋ก๋ | |
| base_model = AutoModelForCausalLM.from_pretrained( | |
| self.base_model_name, | |
| torch_dtype=torch.bfloat16, | |
| device_map="auto", | |
| trust_remote_code=True, | |
| ) | |
| # LoRA ์ด๋ํฐ ์ ์ฉ (Hugging Face ๋ ํฌ ๋๋ ๋ก์ปฌ ๊ฒฝ๋ก) | |
| self.model = PeftModel.from_pretrained( | |
| base_model, | |
| model_path, # Hugging Face ๋ ํฌ ์ด๋ฆ ๋๋ ๋ก์ปฌ ๊ฒฝ๋ก | |
| is_trainable=False | |
| ) | |
| self.model.eval() | |
| # ์ค๋์ค ์ธ์ฝ๋ ๋ก๋ | |
| if AUDIO_ENCODER_AVAILABLE: | |
| self.audio_encoder = AudioEncoder( | |
| output_dim=self.audio_feature_dim, | |
| reduction_method="pool" | |
| ) | |
| print("[AIEffector] AudioEncoder ๋ก๋ ์๋ฃ") | |
| self.model_loaded = True | |
| print("[AIEffector] โ ๋ชจ๋ธ ๋ก๋ ์๋ฃ") | |
| except Exception as e: | |
| print(f"[AIEffector] โ ๋ชจ๋ธ ๋ก๋ ์คํจ: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| self.model_loaded = False | |
| def is_loaded(self) -> bool: | |
| """AI ๋ชจ๋ธ ๋ก๋ ์ํ ํ์ธ""" | |
| return self.model_loaded | |
| def predict(self, audio_path: str, text_prompt: str) -> Dict[str, float]: | |
| """ | |
| ์ค๋์ค์ ํ ์คํธ๋ก๋ถํฐ ์ดํํฐ ํ๋ผ๋ฏธํฐ ์์ธก | |
| Args: | |
| audio_path: ์ ๋ ฅ ์ค๋์ค ํ์ผ ๊ฒฝ๋ก | |
| text_prompt: ์ฌ์ฉ์ ํ ์คํธ ๋ช ๋ น | |
| Returns: | |
| MagicPath ์น ํ์์ ์ดํํฐ ํ๋ผ๋ฏธํฐ ๋์ ๋๋ฆฌ | |
| """ | |
| if self.model_loaded and self.audio_encoder: | |
| return self._predict_with_model(audio_path, text_prompt) | |
| else: | |
| return self._predict_with_preset(text_prompt) | |
| def _predict_with_model(self, audio_path: str, text_prompt: str) -> Dict[str, float]: | |
| """ํ์ต๋ DiffVox LLM์ผ๋ก ์ถ๋ก """ | |
| try: | |
| # 1. ์ค๋์ค ํน์ง ์ถ์ถ | |
| audio_features = self.audio_encoder.get_audio_features(audio_path) | |
| if not audio_features: | |
| print("[AIEffector] ์ค๋์ค ํน์ง ์ถ์ถ ์คํจ, ํ๋ฆฌ์ ์ฌ์ฉ") | |
| return self._predict_with_preset(text_prompt) | |
| # 2. ํ๋กฌํํธ ๊ตฌ์ฑ (train_model.py์ ๋์ผํ ํ์) | |
| audio_state_str = json.dumps(audio_features) | |
| prompt = f"""Task: Convert text to audio parameters. | |
| Audio: {audio_state_str} | |
| Text: {text_prompt} | |
| Parameters:""" | |
| # 3. LLM ์ถ๋ก | |
| inputs = self.tokenizer( | |
| prompt, | |
| return_tensors="pt", | |
| truncation=True, | |
| max_length=1500 | |
| ).to(self.device) | |
| with torch.no_grad(): | |
| outputs = self.model.generate( | |
| **inputs, | |
| max_new_tokens=500, | |
| temperature=0.1, | |
| do_sample=False, | |
| pad_token_id=self.tokenizer.eos_token_id, | |
| ) | |
| generated_text = self.tokenizer.decode( | |
| outputs[0][inputs['input_ids'].shape[1]:], | |
| skip_special_tokens=True | |
| ).strip() | |
| print(f"[AIEffector] LLM ์ถ๋ ฅ: {generated_text[:200]}...") | |
| # 4. ํ๋ผ๋ฏธํฐ ํ์ฑ | |
| diffvox_params = ParameterParser.parse(generated_text) | |
| if not diffvox_params: | |
| print("[AIEffector] ํ๋ผ๋ฏธํฐ ํ์ฑ ์คํจ, ํ๋ฆฌ์ ์ฌ์ฉ") | |
| return self._predict_with_preset(text_prompt) | |
| # 5. DiffVox โ Web ํ๋ผ๋ฏธํฐ ๋ณํ | |
| web_params = ParameterMapper.diffvox_to_web(diffvox_params) | |
| # 6. ๊ธฐ๋ณธ๊ฐ๊ณผ ๋ณํฉ | |
| result = self.DEFAULT_PARAMS.copy() | |
| result.update(web_params) | |
| print(f"[AIEffector] โ AI ํ๋ผ๋ฏธํฐ ์์ฑ ์๋ฃ: {len(web_params)}๊ฐ ํ๋ผ๋ฏธํฐ") | |
| return result | |
| except Exception as e: | |
| print(f"[AIEffector] ์ถ๋ก ์๋ฌ: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return self._predict_with_preset(text_prompt) | |
| def _predict_with_preset(self, text_prompt: str) -> Dict[str, float]: | |
| """ํ๋ฆฌ์ ๊ธฐ๋ฐ ํ๋ผ๋ฏธํฐ ๋ฐํ (fallback)""" | |
| prompt_lower = text_prompt.lower() | |
| for preset_name, preset_params in self.PRESETS.items(): | |
| if preset_name in prompt_lower: | |
| print(f"[AIEffector] ํ๋ฆฌ์ ๋งค์นญ: '{preset_name}'") | |
| result = self.DEFAULT_PARAMS.copy() | |
| result.update(preset_params) | |
| return result | |
| print("[AIEffector] ํ๋ฆฌ์ ๋งค์นญ ์คํจ, ๊ธฐ๋ณธ๊ฐ ๋ฐํ") | |
| return self.DEFAULT_PARAMS.copy() | |