diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..4ac0c8d74a4d21f5d400e8a809a435b15af69efc --- /dev/null +++ b/app.py @@ -0,0 +1,519 @@ +import os +import traceback + +import gradio as gr +import numpy as np +import pyrootutils +import torch +from loguru import logger +from transformers import AutoTokenizer +from vllm import LLM, SamplingParams, TokensPrompt +from funasr_onnx import Paraformer +from huggingface_hub import snapshot_download + +from tools.wer import compute_wers + +os.environ["EINX_FILTER_TRACEBACK"] = "false" +os.environ["VLLM_USE_V1"] = "0" + +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +from i18n import i18n +from text.chn_text_norm.text import Text as ChnNormedText +from xcodec2.modeling_xcodec2 import XCodec2Model + + +TEXTBOX_PLACEHOLDER = i18n("Put your text here.") + +# ===== Hugging Face Model IDs ===== +LLASA_MODEL_ID = "ASLP-lab/VoiceSculptor" +LLASA_SUBFOLDER = "LLaSA-Instruct-3B" +XCODEC_MODEL_ID = "HKUSTAudio/xcodec2" +PARAFORMER_REPO_ID = "funasr/Paraformer-large" + +# logo +LOGO_URL = "https://raw.githubusercontent.com/ASLP-lab/VoiceSculptor/main/assets/logo.png" + + +def normalize_text_final(user_input: str) -> str: + return ChnNormedText(raw_text=user_input).normalize() + + +def extract_speech_ids(speech_tokens_str): + speech_ids = [] + for token_str in speech_tokens_str: + if token_str.startswith("<|s_") and token_str.endswith("|>"): + num_str = token_str[4:-2] + speech_ids.append(int(num_str)) + else: + logger.warning(f"Unexpected token: {token_str}") + return speech_ids + + +def get_asr(asr_model: Paraformer, wav_list: list[np.ndarray]) -> list[str]: + """wav_list: list of 1D numpy waveform (16k)""" + try: + result = asr_model(wav_list) + if isinstance(result, dict): + result = [result] + + texts = [] + for res in result: + preds = res.get("preds", None) + if preds is None: + texts.append(res.get("text", "")) + else: + texts.append(preds[0] if 
len(preds) > 0 else "") + + # 容错:batch 返回数量不一致 -> fallback + if len(texts) != len(wav_list): + logger.warning(f"[ASR] batch返回数量不一致: got {len(texts)} expect {len(wav_list)},fallback逐条补齐") + texts = [] + for w in wav_list: + try: + r = asr_model(w) + if isinstance(r, list) and len(r) > 0: + r0 = r[0] + preds = r0.get("preds", None) + texts.append(preds[0] if preds else r0.get("text", "")) + elif isinstance(r, dict): + preds = r.get("preds", None) + texts.append(preds[0] if preds else r.get("text", "")) + else: + texts.append("") + except Exception: + texts.append("") + return texts + + except Exception as e: + logger.warning(f"[ASR] batch失败,fallback逐条: {e}") + texts = [] + for w in wav_list: + try: + r = asr_model(w) + if isinstance(r, list) and len(r) > 0: + r0 = r[0] + preds = r0.get("preds", None) + texts.append(preds[0] if preds else r0.get("text", "")) + elif isinstance(r, dict): + preds = r.get("preds", None) + texts.append(preds[0] if preds else r.get("text", "")) + else: + texts.append("") + except Exception: + texts.append("") + return texts + + +def inference_batch( + model: LLM, + codec_model: XCodec2Model, + device: str, + tokenizer: AutoTokenizer, + refined_text: str, + instruct_text: str, + control_tags: str, + batch_size: int = 5, +) -> list[tuple[int, np.ndarray]]: + refined_text_norm = normalize_text_final(refined_text) + instruct_text_norm = normalize_text_final(instruct_text) + + if len(refined_text_norm) < 5: + raise ValueError("输入文本长度不能少于5个字符") + if len(refined_text_norm) > 150: + raise ValueError("输入文本长度不能超过150个字符") + + target_text = instruct_text_norm + "<|endofprompt|>" + control_tags + refined_text_norm + formatted_text = f"<|TEXT_UNDERSTANDING_START|>{target_text}<|TEXT_UNDERSTANDING_END|>" + chat = [ + {"role": "user", "content": "Convert the text to speech:" + formatted_text}, + {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"}, + ] + + with torch.no_grad(): + input_ids = tokenizer.apply_chat_template( + chat, + 
tokenize=True, + return_tensors="pt", + continue_final_message=True, + ).to(device) + + speech_end_id = tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>") + prompt_ids = input_ids.squeeze(0).tolist() + prompts = [TokensPrompt(prompt_token_ids=prompt_ids) for _ in range(batch_size)] + + base_seed = int.from_bytes(os.urandom(4), "little") + + try: + sampling_params_list = [ + SamplingParams( + temperature=0.9, + top_p=0.95, + top_k=15, + max_tokens=2048, + repetition_penalty=1.05, + stop_token_ids=[speech_end_id], + seed=base_seed + i, + ) + for i in range(batch_size) + ] + outputs = model.generate(prompts=prompts, sampling_params=sampling_params_list) + except TypeError: + logger.warning("[vLLM] 当前版本不支持 SamplingParams(seed=...),将不带 seed 生成") + sampling_params = SamplingParams( + temperature=0.9, + top_p=0.95, + top_k=15, + max_tokens=2048, + repetition_penalty=1.05, + stop_token_ids=[speech_end_id], + ) + outputs = model.generate(prompts=prompts, sampling_params=sampling_params) + + audios: list[tuple[int, np.ndarray]] = [] + for out in outputs: + token_ids = out.outputs[0].token_ids + if len(token_ids) > 0 and token_ids[-1] == speech_end_id: + token_ids = token_ids[:-1] + + speech_tokens = tokenizer.batch_decode(token_ids, skip_special_tokens=True) + speech_tokens = extract_speech_ids(speech_tokens) + + speech_tokens_t = torch.tensor(speech_tokens, device=device).unsqueeze(0).unsqueeze(0) + wav = codec_model.decode_code(speech_tokens_t) + wav = wav.squeeze(0).squeeze(0).detach().cpu().numpy().astype(np.float32) + audios.append((16000, wav)) + + return audios + + +def build_app(): + device = "cuda" if torch.cuda.is_available() else "cpu" + logger.info(f"✅ Loading models on device={device}") + + # ===== LLaSA ===== + tokenizer = AutoTokenizer.from_pretrained(LLASA_MODEL_ID, subfolder=LLASA_SUBFOLDER, trust_remote_code=True) + + model = LLM( + model=LLASA_MODEL_ID, + gpu_memory_utilization=0.90, + max_model_len=2048, + enable_prefix_caching=True, + 
dtype="auto", + quantization=None, + enforce_eager=False, + kv_cache_dtype="auto", + trust_remote_code=True, + hf_model_subfolder=LLASA_SUBFOLDER, + ) + + # ===== XCodec2 ===== + codec_model = XCodec2Model.from_pretrained(XCODEC_MODEL_ID).eval().to(device) + + # ===== Paraformer ===== + paraformer_dir = snapshot_download( + repo_id=PARAFORMER_REPO_ID, + local_dir="checkpoints/Paraformer-large", + local_dir_use_symlinks=False, + ) + asr_model = Paraformer(paraformer_dir, batch_size=5, quantize=True) + + logger.info("✅ Models loaded: VoiceSculptor + xcodec2 + Paraformer") + + INSTRUCT_TEMPLATES = { + "自定义": "", + "default": "这是一位男性评书表演者,用传统说唱腔调,以变速节奏和韵律感极强的语速讲述江湖故事,音量时高时低,充满江湖气。", + "幼儿园女教师-温柔甜美": "这是一位幼儿园女教师,用甜美明亮的嗓音,以极慢且富有耐心的语速,带着温柔鼓励的情感,用标准普通话给小朋友讲睡前故事,音量轻柔适中,咬字格外清晰。", + "电台主播-平静温柔": "深夜电台主播,男性、音调偏低、语速偏慢、音量小;情绪平静带点忧伤,语气温柔;音色微哑", + "成熟御姐-冷静坚定": "成熟御姐风格,音调偏低、语速正常、音量中等;情绪冷静,语气不容置疑的坚定;音色偏磁性,吐字清晰", + "年轻妈妈-温暖安抚": "年轻妈妈哄孩子入睡,女性、音调柔和偏低、语速偏慢、音量偏小但清晰;情绪温暖安抚、充满耐心与爱意,语气轻柔哄劝、像贴近耳边低声说话;音色软糯,吐字清晰、节奏舒缓。", + "小女孩-尖锐清脆": "一位7岁的小女孩,用天真高亢的童声,以不稳定的快节奏,充满兴奋和炫耀地背诵乘法口诀,音调忽高忽低,带着儿童特有的尖锐清脆。", + "老奶奶-沙哑低沉": "一位慈祥的老奶奶,用沙哑低沉的嗓音,以极慢而温暖的语速讲述民间传说,音量微弱但清晰,带着怀旧和神秘的情感。", + "诗歌朗诵-雄浑有力": "一位男性现代诗朗诵者,用深沉磁性的低音,以顿挫有力的节奏演绎艾青诗歌,音量洪亮,情感激昂澎湃。", + "童话风格-甜美夸张": "这是一位女性童话旁白朗诵者,用甜美夸张的童声,以跳跃变化的语速讲述《安徒生童话》,音调偏高,充满奇幻色彩。", + "评书风格-抑扬顿挫": "这是一位男性评书表演者,用传统说唱腔调,以变速节奏和韵律感极强的语速讲述江湖故事,音量时高时低,充满江湖气。", + "新闻风格-平静专业": "这是一位女性新闻主播,用标准普通话以清晰明亮的中高音,以平稳专业的语速播报时事新闻,音量洪亮,情感客观中立。", + "相声风格-夸张幽默": "这是一位男性相声表演者,用夸张幽默的嗓音,以时快时慢的节奏抖包袱,音调起伏大,充满喜感和节奏感。", + "游戏直播-亢奋激昂": "这是一位男性游戏解说,用亢奋激昂的嗓音,以极快且情绪化的语速直播电竞比赛,音量突然爆发,充满悬念和热血。", + "悬疑小说-低沉神秘": "一位男性悬疑小说演播者,用低沉神秘的嗓音,以时快时慢的变速节奏营造紧张氛围,音量忽高忽低,充满悬念感。", + "戏剧表演-夸张戏剧": "这是一位男性戏剧表演者,用夸张戏剧化的嗓音,以忽高忽低的音调和时快时慢的语速表演独白,充满张力。", + "法治节目-庄严庄重": "这是一位男性法治节目主持人,用严肃庄重的嗓音,以平稳有力的语速讲述案件,音量适中,体现法律的威严。", + "纪录片旁白-低沉磁性": "这是一位男性纪录片旁白,用深沉磁性的嗓音,以缓慢而富有画面感的语速讲述自然奇观,音量适中,充满敬畏和诗意。", + "广告配音-沧桑浑厚": "这是一位男性白酒品牌广告配音,用沧桑浑厚的嗓音,以缓慢而豪迈的语速,音量洪亮,传递历史底蕴和男人情怀。", + "冥想引导师-空灵悠长": "一位女性冥想引导师,用空灵悠长的气声,以极慢而飘渺的语速,配合环境音效,音量轻柔,营造禅意空间。", + 
"ASMR-气声耳语": "一位女性ASMR主播,用气声耳语,以极慢而细腻的语速,配合唇舌音,音量极轻,营造极度放松的氛围。", + } + + TEXT_REQUIREMENTS = { + "自定义": "", + "default": "话说那武松,提着哨棒,直奔景阳冈。天色将晚,酒劲上头,只听一阵狂风,老虎来啦!", + "幼儿园女教师-温柔甜美": "月亮婆婆升上天空啦,星星宝宝都困啦。小白兔躺在床上,盖好小被子,闭上眼睛。兔妈妈轻轻地唱着摇篮曲:睡吧睡吧,我亲爱的宝贝。", + "电台主播-平静温柔": "大家好,欢迎收听你的月亮我的心,好男人就是我,我就是:曾小贤。", + "成熟御姐-冷静坚定": "别担心,我不会让你输,把那些乱七八糟的念头先收起来,姐姐带你赢。", + "年轻妈妈-温暖安抚": "从前有座山,山里有座庙,庙里面有个小和尚,小和尚在给老和尚讲故事,他说:从前有座山,山里有座庙,庙里面有个小和尚。", + "小女孩-尖锐清脆": "一一得一!一二得二!一三得三!我会背乘法口诀啦!老师今天表扬我啦!妈妈说我最棒!", + "老奶奶-沙哑低沉": "很久很久以前,在山的那边,住着一只会说话的狐狸。它常常在月圆之夜,变成美丽的姑娘,来到村子里。", + "诗歌朗诵-雄浑有力": "为什么我的眼里常含泪水?因为我对这土地爱得深沉。这土地,这河流,这吹刮着的暴风。", + "童话风格-甜美夸张": "在一个很冷很冷的夜晚,小女孩擦亮了一根火柴。突然,温暖的火炉出现了!她觉得自己好像坐在火炉旁。", + "评书风格-抑扬顿挫": "话说那武松,提着哨棒,直奔景阳冈。天色将晚,酒劲上头,只听一阵狂风,老虎来啦!", + "新闻风格-平静专业": "本台讯,今日凌晨,我国成功发射新一代载人飞船试验船。此次任务验证了多项关键技术,为后续空间站建设奠定基础。", + "相声风格-夸张幽默": "我这个人啊,最大的优点就是太谦虚。谦虚到什么程度?连谦虚本身都觉得我太谦虚了!", + "游戏直播-亢奋激昂": "大招!大招好了!开团了!ACE!团灭!这波操作神了!冠军相尽显无疑!", + "悬疑小说-低沉神秘": "深夜,他独自走在空无一人的小巷。脚步声,回声,还有……另一个人的呼吸声。他猛地回头——什么也没有。", + "戏剧表演-夸张戏剧": "我疯了!彻底疯了!你们都说我疯了!可疯的是这个世界!清醒的人反而被当成疯子!", + "法治节目-庄严庄重": "天网恢恢,疏而不漏。任何触犯法律的行为,终将受到公正的审判。正义或许会迟到,但绝不会缺席。", + "纪录片旁白-低沉磁性": "在这片广袤的非洲草原上,生命与死亡每天都在上演。猎豹的速度,羚羊的敏捷,都是生存的代价。", + "广告配音-沧桑浑厚": "一杯敬过往,一杯敬远方。传承千年的酿造工艺,只在每一滴醇香。老朋友,值得好酒。", + "冥想引导师-空灵悠长": "想象你是一片叶子,随风飘落。没有牵挂,没有重量。只有呼吸,只有当下,只有宁静。", + "ASMR-气声耳语": "现在,让我在你耳边轻声细语。听到我的声音了吗?放松你的头皮,感受每一个毛孔都在呼吸。", + } + + def build_control_tags(age, gender, pitch, pitch_var, volume, speed, emo): + tag_map = { + "小孩": "<|小孩|>", "青年": "<|青年|>", "中年": "<|中年|>", "老年": "<|老年|>", + "男性": "<|男性|>", "女性": "<|女性|>", + "音调很高": "<|音调很高|>", "音调较高": "<|音调较高|>", "音调中等": "<|音调中等|>", + "音调较低": "<|音调较低|>", "音调很低": "<|音调很低|>", + "音调变化很强": "<|音调变化很强|>", "音调变化较强": "<|音调变化较强|>", "音调变化一般": "<|音调变化一般|>", + "音调变化较弱": "<|音调变化较弱|>", "音调变化很弱": "<|音调变化很弱|>", + "音量很大": "<|音量很大|>", "音量较大": "<|音量较大|>", "音量中等": "<|音量中等|>", + "音量较小": "<|音量较小|>", "音量很小": "<|音量很小|>", + "语速很快": "<|语速很快|>", "语速较快": "<|语速较快|>", "语速中等": "<|语速中等|>", + "语速较慢": "<|语速较慢|>", "语速很慢": "<|语速很慢|>", + "开心": "<|开心|>", "生气": 
"<|生气|>", "难过": "<|难过|>", "惊讶": "<|惊讶|>", "厌恶": "<|厌恶|>", "害怕": "<|害怕|>", + } + tags = [] + for v in [gender, age, speed, volume, pitch, pitch_var, emo]: + if v != "不指定": + tags.append(tag_map[v]) + return "".join(tags) + + def inference_select_best3(refined_text, instruct_text, age, gender, pitch, pitch_var, volume, speed, emo): + control_tags = build_control_tags(age, gender, pitch, pitch_var, volume, speed, emo) + try: + audios5 = inference_batch( + model=model, + codec_model=codec_model, + device=device, + tokenizer=tokenizer, + refined_text=refined_text, + instruct_text=instruct_text, + control_tags=control_tags, + batch_size=5, + ) + wav_list = [wav for (_, wav) in audios5] + asr_texts = get_asr(asr_model, wav_list) + + refined_text_norm = normalize_text_final(refined_text) + gt_texts = [refined_text_norm] * len(asr_texts) + wers = compute_wers(gt_texts, asr_texts, lang="zh") + + for i, (hyp, w) in enumerate(zip(asr_texts, wers)): + logger.info(f"[ASR/WER] idx={i} wer={w:.4f} gt='{refined_text_norm}' asr='{hyp}'") + + best_idx = np.argsort(np.array(wers))[:3].tolist() + logger.info(f"[ASR/WER] best_idx={best_idx} best_wers={[float(wers[i]) for i in best_idx]}") + best3 = [audios5[i] for i in best_idx] + return best3[0], best3[1], best3[2] + except Exception as e: + logger.error(f"推理/ASR/WER 失败: {e}", exc_info=True) + logger.error("错误详细信息:\n" + traceback.format_exc()) + return None, None, None + + THEME = gr.themes.Soft( + primary_hue="orange", + secondary_hue="cyan", + neutral_hue="slate", + ) + + CUSTOM_CSS = """ + /* layout */ + #vs-root {max-width: 1180px; margin: 0 auto;} + #vs-header {padding: 14px 14px 4px 14px;} + #vs-card {border-radius: 14px; padding: 14px; border: 1px solid rgba(0,0,0,0.08);} + + /* ===== VoiceSculptor palette (from logo) ===== */ + :root, .gradio-container { + --vs-orange: #FF6A00; + --vs-orange2:#FFB000; + --vs-teal: #00A6C6; + --vs-blue: #0B2E8A; + --vs-teal-a: rgba(0,166,198,.18); + } + + /* primary button */ + 
.gr-button-primary, button.primary { + background: linear-gradient(90deg, var(--vs-orange), var(--vs-orange2)) !important; + border: none !important; + color: white !important; + } + .gr-button-primary:hover, button.primary:hover { + filter: brightness(1.03); + } + .gr-button-primary:active, button.primary:active { + filter: brightness(0.98); + } + + /* links */ + .gradio-container a { + color: var(--vs-teal) !important; + } + .gradio-container a:hover { + text-decoration: underline; + } + + /* focus ring / active border for inputs */ + textarea:focus, input:focus { + border-color: var(--vs-teal) !important; + box-shadow: 0 0 0 3px var(--vs-teal-a) !important; + outline: none !important; + } + /* some gradio versions wrap inputs in these */ + .gr-input:focus-within, .gr-text-input:focus-within, .gr-box:focus-within { + border-color: var(--vs-teal) !important; + box-shadow: 0 0 0 3px var(--vs-teal-a) !important; + } + + /* accordion highlight */ + .gr-accordion .label, .gr-accordion summary { + color: var(--vs-blue) !important; + } + """ + + DEFAULT_STYLE = "评书风格-抑扬顿挫" + template_choices = [k for k in INSTRUCT_TEMPLATES.keys() if k not in ("default",)] + + BEST_PRACTICE_MD = """ +## Best Practice Guide(音色设计) + +完整指南请见:Voice Design README +https://github.com/ASLP-lab/VoiceSculptor/blob/main/docs/voice_design.md + +### 关键约束 +- **voice_prompt ≤ 200 字** +- **当前仅支持中文** +- **待合成文本长度 ≥ 5 个字** + +### 写法建议 +- **具体**:用可感知特质词(低沉/清脆/沙哑/明亮、语速快慢、音量大小等),避免“好听/不错”。 +- **完整**:建议覆盖 **3–4 个维度**(人设/场景 + 性别/年龄 + 音调/语速 + 音质/情绪)。 +- **客观**:描述声音特征与表达方式,避免“我喜欢/很棒”。 +- **不做模仿**:禁止“像某明星/某演员”,只描述声音特质本身。 +- **尽量精炼**:每个词都承载信息,避免重复强调(如“非常非常”)。 + +### 参考模板 +> - 这是一位男性评书表演者,用传统说唱腔调,以变速节奏和韵律感极强的语速讲述江湖故事,音量时高时低,充满江湖气。 +> - 深夜电台主播,男性、音调偏低、语速偏慢、音量小;情绪平静带点忧伤,语气温柔;音色微哑。 +> - 成熟御姐风格,音调偏低、语速正常、音量中等;情绪冷静,语气不容置疑的坚定;音色偏磁性,吐字清晰。 + + +### 细粒度控制提示 +- 细粒度控制(年龄/性别/音调/语速/音量/情感等)**建议与指令描述保持一致**,尽量避免相互矛盾(如指令写“低沉慢速”,细粒度却选“音调很高/语速很快”)。 +""" + + with gr.Blocks(theme=THEME, css=CUSTOM_CSS) as app: + with 
gr.Column(elem_id="vs-root"): + with gr.Row(elem_id="vs-header"): + gr.HTML(f""" +
+ Voice Sculptor Logo +
+
Voice Sculptor
+
+ {i18n('An instruct text-to-speech solution based on LLaSA and CosyVoice2 developed by the ASLP lab and collaborators.')} +
+
+
+ """) + + with gr.Row(): + # Left: Controls + Guide + with gr.Column(scale=5, elem_id="vs-card"): + gr.Markdown("### 🪄 Voice Design(捏音色)") + + with gr.Accordion("🎭 风格与文本", open=True): + instruct_template = gr.Dropdown( + choices=template_choices, + value=DEFAULT_STYLE, + label=i18n("指令风格(必选)"), + interactive=True, + ) + + instruct_text = gr.Textbox( + label=i18n("指令文本"), + placeholder=TEXTBOX_PLACEHOLDER, + lines=4, + value=INSTRUCT_TEMPLATES.get(DEFAULT_STYLE, INSTRUCT_TEMPLATES["default"]), + ) + + text = gr.Textbox( + label=i18n("待合成文本"), + placeholder=TEXTBOX_PLACEHOLDER, + lines=4, + value=TEXT_REQUIREMENTS.get(DEFAULT_STYLE, TEXT_REQUIREMENTS["default"]), + ) + + with gr.Accordion("🎛️ 细粒度声音控制(可选)", open=False): + with gr.Row(): + age_ctrl = gr.Dropdown(label="年龄", choices=["不指定", "小孩", "青年", "中年", "老年"], value="不指定") + gender_ctrl = gr.Dropdown(label="性别", choices=["不指定", "男性", "女性"], value="不指定") + + with gr.Row(): + pitch_ctrl = gr.Dropdown( + label="音调高度", + choices=["不指定", "音调很高", "音调较高", "音调中等", "音调较低", "音调很低"], + value="不指定", + ) + pitch_var_ctrl = gr.Dropdown( + label="音调变化", + choices=["不指定", "音调变化很强", "音调变化较强", "音调变化一般", "音调变化较弱", "音调变化很弱"], + value="不指定", + ) + + with gr.Row(): + volume_ctrl = gr.Dropdown( + label="音量", + choices=["不指定", "音量很大", "音量较大", "音量中等", "音量较小", "音量很小"], + value="不指定", + ) + speed_ctrl = gr.Dropdown( + label="语速", + choices=["不指定", "语速很快", "语速较快", "语速中等", "语速较慢", "语速很慢"], + value="不指定", + ) + + emo_ctrl = gr.Dropdown( + label="情感", + choices=["不指定", "开心", "生气", "难过", "惊讶", "厌恶", "害怕"], + value="不指定", + ) + + with gr.Accordion("📚 Best Practice Guide", open=False): + gr.Markdown(BEST_PRACTICE_MD) + + def apply_template(tpl_name): + return INSTRUCT_TEMPLATES.get(tpl_name, ""), TEXT_REQUIREMENTS.get(tpl_name, "") + + instruct_template.change(apply_template, inputs=[instruct_template], outputs=[instruct_text, text]) + + # Right: Results + Generate + with gr.Column(scale=5, elem_id="vs-card"): + gr.Markdown("### 🎵 Results") + 
generate = gr.Button("🎧 Generate", variant="primary") + audio_output1 = gr.Audio(label=i18n("Generated Audio 1"), type="numpy", interactive=False) + audio_output2 = gr.Audio(label=i18n("Generated Audio 2"), type="numpy", interactive=False) + audio_output3 = gr.Audio(label=i18n("Generated Audio 3"), type="numpy", interactive=False) + + generate.click( + fn=inference_select_best3, + inputs=[text, instruct_text, age_ctrl, gender_ctrl, pitch_ctrl, pitch_var_ctrl, volume_ctrl, speed_ctrl, emo_ctrl], + outputs=[audio_output1, audio_output2, audio_output3], + ) + + return app + + +if __name__ == "__main__": + demo = build_app() + demo.launch() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..bca7cc2412345f8e728453fa1ddf17cbf15959a1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +gradio +torch +transformers +vllm +funasr-onnx +huggingface_hub +jiwer +zhon +loguru +pyrootutils +jieba +torchtune +torchao +vector_quantize_pytorch \ No newline at end of file diff --git a/tools/__pycache__/wer.cpython-310.pyc b/tools/__pycache__/wer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc3291d706a254b003ed995c1e4ef69d6f7a7235 Binary files /dev/null and b/tools/__pycache__/wer.cpython-310.pyc differ diff --git a/tools/wer.py b/tools/wer.py new file mode 100644 index 0000000000000000000000000000000000000000..a931843906d79947e0348729d5ab3f15d22f2189 --- /dev/null +++ b/tools/wer.py @@ -0,0 +1,59 @@ +# tools/wer.py +from __future__ import annotations + +from typing import List, Tuple +import string + +from jiwer import process_words +from zhon.hanzi import punctuation as zh_punctuation + +# 中文标点 + 英文标点 + '-' +_PUNCTUATION_ALL = zh_punctuation + string.punctuation + "-" + + +def _normalize_pair(gt: str, gen: str, lang: str) -> Tuple[str, str]: + gt = "" if gt is None else str(gt) + gen = "" if gen is None else str(gen) + + # 去标点(保留 "'") + for x in _PUNCTUATION_ALL: + if x == 
"'": + continue + gt = gt.replace(x, "") + gen = gen.replace(x, "") + + # 统一空格与连字符 + gt = gt.replace(" ", " ").replace("-", " ") + gen = gen.replace(" ", " ").replace("-", " ") + + if lang == "zh": + # 把“字”当作 token + gt = " ".join([ch for ch in gt]) + gen = " ".join([ch for ch in gen]) + elif lang == "en": + gt = gt.lower() + gen = gen.lower() + else: + raise NotImplementedError("lang must be 'zh' or 'en'") + + return gt, gen + + +def compute_wers(gt_texts: List[str], gen_texts: List[str], lang: str = "zh") -> List[float]: + if len(gt_texts) != len(gen_texts): + raise ValueError(f"Length mismatch: {len(gt_texts)} != {len(gen_texts)}") + + wers: List[float] = [] + for gt_raw, gen_raw in zip(gt_texts, gen_texts): + gt_norm, gen_norm = _normalize_pair(gt_raw, gen_raw, lang=lang) + measures = process_words(reference=gt_norm, hypothesis=gen_norm) + wers.append(float(measures.wer)) + return wers + + + +if __name__ == "__main__": + gt = ["你好世界啊", "今天天气不对", "abc-def"] + gen = ["你好,世界!", "今天 天气 不错", "abc def"] + print(compute_wers(gt, gen, lang="zh")) + print(compute_wers(["Hello World"], ["hello, world!"], lang="en")) diff --git a/xcodec2/.gitattributes b/xcodec2/.gitattributes new file mode 100755 index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b --- /dev/null +++ b/xcodec2/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs 
diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/xcodec2/.vscode/settings.json b/xcodec2/.vscode/settings.json new file mode 100755 index 0000000000000000000000000000000000000000..a8c200329368ddfdbe180fdbc2deda24ed7a9ce4 --- /dev/null +++ b/xcodec2/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "python-envs.defaultEnvManager": "ms-python.python:conda", + "python-envs.defaultPackageManager": "ms-python.python:conda", + "python-envs.pythonProjects": [] +} \ No newline at end of file diff --git a/xcodec2/README.md b/xcodec2/README.md new file mode 100755 index 0000000000000000000000000000000000000000..62dad23a13f625c72f95a3977bee90e18d19622f --- /dev/null +++ b/xcodec2/README.md @@ -0,0 +1,69 @@ +--- +license: cc-by-nc-4.0 +tags: +- audio-to-audio +pipeline_tag: audio-to-audio +--- + +[![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2502.04128) +**Update (2025-02-13):** Add [Llasa finetune instruction](https://github.com/zhenye234/LLaSA_training/tree/main/finetune). + +**Update (2025-02-07):** Our paper has been released! 
+ + +## Paper + +LLaSA: Scaling Train Time and Inference Time Compute for LLaMA based Speech Synthesis + +Codec Does Matter: Exploring the Semantic Shortcoming of Codec for Audio Language Model (AAAI 2025, xcodec 1.0) + + +# Getting Started with XCodec2 on Hugging Face +XCodec2 is a speech tokenizer that offers the following key features: + +1. **Single Vector Quantization** +2. **50 Tokens per Second** +3. **Multilingual Speech Semantic Support and High-Quality Speech Reconstruction** + + +To use `xcodec2`, ensure you have it installed. You can install it using the following command: + +```bash +conda create -n xcodec2 python=3.9 +conda activate xcodec2 +pip install xcodec2 (Use `xcodec2==0.1.5` for codec inference and llasa fine-tuning. I’ve removed unnecessary dependencies, and it works fine in my testing. However, I’m not sure if other problems may arise. If you prefer more stability, I recommend using `xcodec2==0.1.3` which accurately aligns during my codec training.) + +``` +Then, +```python +import torch +import soundfile as sf +from transformers import AutoConfig + + +from xcodec2.modeling_xcodec2 import XCodec2Model + +model_path = "HKUSTAudio/xcodec2" + +model = XCodec2Model.from_pretrained(model_path) +model.eval().cuda() + + +wav, sr = sf.read("test.wav") +wav_tensor = torch.from_numpy(wav).float().unsqueeze(0) # Shape: (1, T) + + +with torch.no_grad(): + # Only 16khz speech + # Only supports single input. For batch inference, please refer to the link below. + vq_code = model.encode_code(input_waveform=wav_tensor) + print("Code:", vq_code ) + + recon_wav = model.decode_code(vq_code).cpu() # Shape: (1, 1, T') + + +sf.write("reconstructed.wav", recon_wav[0, 0, :].numpy(), sr) +print("Done! Check reconstructed.wav") +``` + +# If you want to train your own xcodec2, batch inference, or large-scale code extraction, the code is released [here](https://github.com/zhenye234/X-Codec-2.0). 
\ No newline at end of file diff --git a/xcodec2/__init__.py b/xcodec2/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/xcodec2/__pycache__/__init__.cpython-310.pyc b/xcodec2/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5d53d148b6f56ba74300128c778774d3ee3c2ef Binary files /dev/null and b/xcodec2/__pycache__/__init__.cpython-310.pyc differ diff --git a/xcodec2/__pycache__/configuration_bigcodec.cpython-310.pyc b/xcodec2/__pycache__/configuration_bigcodec.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..992407af71e7ff35b0b1a5eadf2a3d304d2075cc Binary files /dev/null and b/xcodec2/__pycache__/configuration_bigcodec.cpython-310.pyc differ diff --git a/xcodec2/__pycache__/configuration_bigcodec.cpython-38.pyc b/xcodec2/__pycache__/configuration_bigcodec.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..c632a2e8cf7379b702ea3522e842e3c437aef32f Binary files /dev/null and b/xcodec2/__pycache__/configuration_bigcodec.cpython-38.pyc differ diff --git a/xcodec2/__pycache__/modeling_xcodec2.cpython-310.pyc b/xcodec2/__pycache__/modeling_xcodec2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41f284587665be62e83026658f7b64e8c59d298a Binary files /dev/null and b/xcodec2/__pycache__/modeling_xcodec2.cpython-310.pyc differ diff --git a/xcodec2/__pycache__/modeling_xcodec2.cpython-38.pyc b/xcodec2/__pycache__/modeling_xcodec2.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..6c6a4c39986e4cd619c09b9588e28f0dc5ad9811 Binary files /dev/null and b/xcodec2/__pycache__/modeling_xcodec2.cpython-38.pyc differ diff --git a/xcodec2/config.json b/xcodec2/config.json new file mode 100755 index 0000000000000000000000000000000000000000..5acd94f351833f55f40fd49e4f9da49d34c8540f --- /dev/null +++ 
b/xcodec2/config.json @@ -0,0 +1,11 @@ +{ + "model_type": "xcodec2", + "semantic_hidden_size": 1024, + "codec_encoder_hidden_size": 1024, + "codec_decoder_hidden_size": 1024, + "use_vocos": true, + "architectures": [ + "XCodec2Model" + ] + } + \ No newline at end of file diff --git a/xcodec2/configuration_bigcodec.py b/xcodec2/configuration_bigcodec.py new file mode 100755 index 0000000000000000000000000000000000000000..c7aa72b64eca8cae1220413f47039a112100d866 --- /dev/null +++ b/xcodec2/configuration_bigcodec.py @@ -0,0 +1,19 @@ +from transformers import PretrainedConfig + +class BigCodecConfig(PretrainedConfig): + model_type = "bigcodec" + + def __init__( + self, + # 下面这些只是示例超参 + semantic_hidden_size=1024, + codec_encoder_hidden_size=1024, + codec_decoder_hidden_size=1024, + use_vocos=True, + **kwargs + ): + super().__init__(**kwargs) + self.semantic_hidden_size = semantic_hidden_size + self.codec_encoder_hidden_size = codec_encoder_hidden_size + self.codec_decoder_hidden_size = codec_decoder_hidden_size + self.use_vocos = use_vocos diff --git a/xcodec2/modeling_xcodec2.py b/xcodec2/modeling_xcodec2.py new file mode 100755 index 0000000000000000000000000000000000000000..93b12ec6a6108213e19a3ad87057ee22d17994c2 --- /dev/null +++ b/xcodec2/modeling_xcodec2.py @@ -0,0 +1,164 @@ +import torch +import torch.nn as nn +from transformers import PreTrainedModel +from xcodec2.configuration_bigcodec import BigCodecConfig + +from xcodec2.vq.codec_encoder import CodecEncoder_Transformer +from xcodec2.vq.codec_decoder_vocos import CodecDecoderVocos +from xcodec2.vq.module import SemanticEncoder +from transformers import AutoFeatureExtractor, Wav2Vec2BertModel + +class XCodec2Model(PreTrainedModel): + config_class = BigCodecConfig + + def __init__(self, config: BigCodecConfig): + super().__init__(config) + + # 1) 语义模型 + self.semantic_model = Wav2Vec2BertModel.from_pretrained( + "facebook/w2v-bert-2.0", + output_hidden_states=True + ) + self.semantic_model.eval() + + 
self.SemanticEncoder_module = SemanticEncoder( + config.semantic_hidden_size, + config.semantic_hidden_size, + config.semantic_hidden_size + ) + + # 2) Codec Encoder + self.CodecEnc = CodecEncoder_Transformer() + + # 3) Codec Decoder + self.generator = CodecDecoderVocos() + + # 4) 两个全连接层 + self.fc_prior = nn.Linear(2048, 2048) + self.fc_post_a = nn.Linear(2048, 1024) + feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0") + self.feature_extractor = feature_extractor + + def forward(self, input_waveform, sample_rate=16000): + """ + 这里的 forward 不一定要叫 forward,也可以拆成别的方法; + 但是如果想兼容 pipeline,需要在 forward 里给出核心逻辑。 + + 参数: + input_waveform: [batch_size, waveform_length] + sample_rate: 默认 16000 + 返回: + 重构后的语音音频 (Tensor) + """ + # 1) 特征提取 + # 如果需要 padding,可以在这里做 + input_features = self.feature_extractor( + input_waveform, + sampling_rate=sample_rate, + return_tensors="pt" + ).input_features.to(self.device) # [batch, frames, feat_dim] + + # 2) 语义层 + semantic_output = self.semantic_model(input_features) + semantic_hidden_16 = semantic_output.hidden_states[16] # 取第16层 + semantic_hidden_16 = semantic_hidden_16.transpose(1, 2) # [batch, hidden_dim, frames] + semantic_encoded = self.SemanticEncoder_module(semantic_hidden_16) + + # 3) codec encoder + wav = input_waveform.unsqueeze(1).to(self.device) # shape: [batch, 1, time] + vq_emb = self.CodecEnc(wav) # [batch, time//down, 1024] 只是示例 + vq_emb = vq_emb.transpose(1, 2) # -> [batch, 1024, frames] + + # 对齐语义向量的时间帧数,这里只做示例处理 + # 真实做法里可能要先对齐维度 + if vq_emb.shape[-1] != semantic_encoded.shape[-1]: + # 简单强行截断或补零都行,需要你自己决定 + min_len = min(vq_emb.shape[-1], semantic_encoded.shape[-1]) + vq_emb = vq_emb[:, :, :min_len] + semantic_encoded = semantic_encoded[:, :, :min_len] + + # 4) 拼接 + concat_emb = torch.cat([semantic_encoded, vq_emb], dim=1) # [batch, 1024 + 1024, frames] + + # 5) fc_prior + concat_emb = self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2) + + # 6) decoder 的量化部分 + _, vq_code, _ = 
self.generator(concat_emb, vq=True) + vq_post_emb = self.generator.quantizer.get_output_from_indices(vq_code.transpose(1, 2)) + vq_post_emb = vq_post_emb.transpose(1, 2) + + # 7) fc_post_a + vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(1, 2) + + # 8) 最后解码成波形 + recon_audio = self.generator(vq_post_emb.transpose(1, 2), vq=False)[0] + # recon_audio: [batch, time] + return recon_audio + + def encode_code(self, input_waveform, sample_rate=16000): + """ + 将输入的音频编码为代码表示。 + + 参数: + input_waveform: [batch_size, waveform_length] + sample_rate: 默认 16000 + 返回: + 编码后的代码 (Tensor) + """ + with torch.no_grad(): + # 1) 特征提取 + input_features = self.feature_extractor( + input_waveform, + sampling_rate=sample_rate, + return_tensors="pt" + ).input_features.to(self.device) # [batch, frames, feat_dim] + + # 2) 语义层 + semantic_output = self.semantic_model(input_features) + semantic_hidden_16 = semantic_output.hidden_states[16] # 取第16层 + semantic_hidden_16 = semantic_hidden_16.transpose(1, 2) # [batch, hidden_dim, frames] + semantic_encoded = self.SemanticEncoder_module(semantic_hidden_16) + + # 3) codec encoder + wav = input_waveform.unsqueeze(1).to(self.device) # shape: [batch, 1, time] + vq_emb = self.CodecEnc(wav) # [batch, time//down, 1024] 只是示例 + vq_emb = vq_emb.transpose(1, 2) # -> [batch, 1024, frames] + + # 对齐语义向量的时间帧数,这里只做示例处理 + if vq_emb.shape[-1] != semantic_encoded.shape[-1]: + min_len = min(vq_emb.shape[-1], semantic_encoded.shape[-1]) + vq_emb = vq_emb[:, :, :min_len] + semantic_encoded = semantic_encoded[:, :, :min_len] + + # 4) 拼接 + concat_emb = torch.cat([semantic_encoded, vq_emb], dim=1) # [batch, 2048, frames] + + # 5) fc_prior + concat_emb = self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2) + + # 6) decoder 的量化部分,获取code + _, vq_code, _ = self.generator(concat_emb, vq=True) + # vq_code: [batch, frames] + return vq_code + + def decode_code(self, vq_code): + """ + 将编码后的代码解码回音频。 + + 参数: + vq_code: 编码后的代码 (Tensor) [batch, frames] + 返回: + 解码后的音频 
(Tensor) [batch, waveform_length] + """ + with torch.no_grad(): + # 获取量化后的嵌入 + vq_post_emb = self.generator.quantizer.get_output_from_indices(vq_code.transpose(1, 2)) + vq_post_emb = vq_post_emb.transpose(1, 2) # [batch, 1024, frames] + + # 7) fc_post_a + vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(1, 2) # [batch, 1024, frames] + + # 8) 最后解码成波形 + recon_audio = self.generator(vq_post_emb.transpose(1, 2), vq=False)[0] # [batch, time] + return recon_audio diff --git a/xcodec2/module.py b/xcodec2/module.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/xcodec2/vq/__init__.py b/xcodec2/vq/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..9de36961fa5ddf77090a182fc427437ccdc7aff9 --- /dev/null +++ b/xcodec2/vq/__init__.py @@ -0,0 +1,4 @@ +from xcodec2.vq.codec_encoder import CodecEncoder +from xcodec2.vq.codec_decoder import CodecDecoder +from xcodec2.vq.codec_decoder_vocos import CodecDecoderVocos +from xcodec2.vq.codec_encoder import CodecEncoder_Transformer,CodecEncoder_only_Transformer \ No newline at end of file diff --git a/xcodec2/vq/__pycache__/__init__.cpython-310.pyc b/xcodec2/vq/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe88c3604bf4cdd17ba29e04c47a0aea3222e111 Binary files /dev/null and b/xcodec2/vq/__pycache__/__init__.cpython-310.pyc differ diff --git a/xcodec2/vq/__pycache__/__init__.cpython-311.pyc b/xcodec2/vq/__pycache__/__init__.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..3a622e8da2ba3969a82261d2c1e6a3dbe2a87155 Binary files /dev/null and b/xcodec2/vq/__pycache__/__init__.cpython-311.pyc differ diff --git a/xcodec2/vq/__pycache__/__init__.cpython-312.pyc b/xcodec2/vq/__pycache__/__init__.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..a1bd6ded01b7b342483f5e8ab7993e74a8b74054 Binary 
files /dev/null and b/xcodec2/vq/__pycache__/__init__.cpython-312.pyc differ diff --git a/xcodec2/vq/__pycache__/__init__.cpython-38.pyc b/xcodec2/vq/__pycache__/__init__.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..63230e0918a56b5b56c5d83e79c957975782b555 Binary files /dev/null and b/xcodec2/vq/__pycache__/__init__.cpython-38.pyc differ diff --git a/xcodec2/vq/__pycache__/__init__.cpython-39.pyc b/xcodec2/vq/__pycache__/__init__.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..ee6edbe0d9811e1f6d087bee9c44604e457ae546 Binary files /dev/null and b/xcodec2/vq/__pycache__/__init__.cpython-39.pyc differ diff --git a/xcodec2/vq/__pycache__/activations.cpython-310.pyc b/xcodec2/vq/__pycache__/activations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d8b95e80d0e513cfa122854d290eed64a52a6a3 Binary files /dev/null and b/xcodec2/vq/__pycache__/activations.cpython-310.pyc differ diff --git a/xcodec2/vq/__pycache__/activations.cpython-311.pyc b/xcodec2/vq/__pycache__/activations.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..b68f81ee8be448f64ef4c8f7bab31d57bdd561a8 Binary files /dev/null and b/xcodec2/vq/__pycache__/activations.cpython-311.pyc differ diff --git a/xcodec2/vq/__pycache__/activations.cpython-312.pyc b/xcodec2/vq/__pycache__/activations.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..7b65c1363e2538d7f7da65eacf3801535d9dfe7e Binary files /dev/null and b/xcodec2/vq/__pycache__/activations.cpython-312.pyc differ diff --git a/xcodec2/vq/__pycache__/activations.cpython-38.pyc b/xcodec2/vq/__pycache__/activations.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..a8d378a7fc0cf94615324f3cc048787815f89879 Binary files /dev/null and b/xcodec2/vq/__pycache__/activations.cpython-38.pyc differ diff --git a/xcodec2/vq/__pycache__/activations.cpython-39.pyc 
b/xcodec2/vq/__pycache__/activations.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..a93430eb9e182116727bb2fd9a8c1ced9e1a423b Binary files /dev/null and b/xcodec2/vq/__pycache__/activations.cpython-39.pyc differ diff --git a/xcodec2/vq/__pycache__/blocks.cpython-310.pyc b/xcodec2/vq/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d49d37165ecdc9f6f8ff41aabb31134c35baa720 Binary files /dev/null and b/xcodec2/vq/__pycache__/blocks.cpython-310.pyc differ diff --git a/xcodec2/vq/__pycache__/blocks.cpython-38.pyc b/xcodec2/vq/__pycache__/blocks.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..c8b1ace989b247dbfd268e8693d20d09558e7858 Binary files /dev/null and b/xcodec2/vq/__pycache__/blocks.cpython-38.pyc differ diff --git a/xcodec2/vq/__pycache__/blocks.cpython-39.pyc b/xcodec2/vq/__pycache__/blocks.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..fa36890d26c5d256baf108973f548e68237fc1a9 Binary files /dev/null and b/xcodec2/vq/__pycache__/blocks.cpython-39.pyc differ diff --git a/xcodec2/vq/__pycache__/bs_roformer5.cpython-310.pyc b/xcodec2/vq/__pycache__/bs_roformer5.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bf1a481e797c324cd5c7ea897c596eb9ffebbe2 Binary files /dev/null and b/xcodec2/vq/__pycache__/bs_roformer5.cpython-310.pyc differ diff --git a/xcodec2/vq/__pycache__/bs_roformer5.cpython-38.pyc b/xcodec2/vq/__pycache__/bs_roformer5.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..8c4edcbdb6c7963eabc1201b3020b3178343a672 Binary files /dev/null and b/xcodec2/vq/__pycache__/bs_roformer5.cpython-38.pyc differ diff --git a/xcodec2/vq/__pycache__/bs_roformer5.cpython-39.pyc b/xcodec2/vq/__pycache__/bs_roformer5.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..92d8ebd686f5ca067f8afedd3c11df3eae114f8b 
Binary files /dev/null and b/xcodec2/vq/__pycache__/bs_roformer5.cpython-39.pyc differ diff --git a/xcodec2/vq/__pycache__/codec_decoder.cpython-310.pyc b/xcodec2/vq/__pycache__/codec_decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcd2f514cda5c60fbf4267573e0f421799879ebd Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_decoder.cpython-310.pyc differ diff --git a/xcodec2/vq/__pycache__/codec_decoder.cpython-311.pyc b/xcodec2/vq/__pycache__/codec_decoder.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..d3bba0654a05fc2d5df7f0cfa85d4e1e8f9e43a9 Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_decoder.cpython-311.pyc differ diff --git a/xcodec2/vq/__pycache__/codec_decoder.cpython-312.pyc b/xcodec2/vq/__pycache__/codec_decoder.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..9fd232bd7d3e8217ae32122d85158c8d3d07a0ba Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_decoder.cpython-312.pyc differ diff --git a/xcodec2/vq/__pycache__/codec_decoder.cpython-38.pyc b/xcodec2/vq/__pycache__/codec_decoder.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..e5ab02d4e11a9b3abe48f4eb7d1b1b982be70703 Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_decoder.cpython-38.pyc differ diff --git a/xcodec2/vq/__pycache__/codec_decoder.cpython-39.pyc b/xcodec2/vq/__pycache__/codec_decoder.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..342c797d5e48be38a60b3987b4ccf03e115c42f1 Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_decoder.cpython-39.pyc differ diff --git a/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-310.pyc b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b45e5d90c8b7cb9202abee91370584c8fec28f9f Binary files /dev/null and 
b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-310.pyc differ diff --git a/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-311.pyc b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..16d7559f3e464f7cc8229cf5cf3864cf6ce5270a Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-311.pyc differ diff --git a/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-312.pyc b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..4d14cbb3475d4adca86027ce48447562dd8bd872 Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-312.pyc differ diff --git a/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-38.pyc b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..094936bcc87dbe3f0480acf7b5987be727685911 Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-38.pyc differ diff --git a/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-39.pyc b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..039c2ff83bd0dece5f89baf003ee5c212a9ec9eb Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-39.pyc differ diff --git a/xcodec2/vq/__pycache__/codec_encoder.cpython-310.pyc b/xcodec2/vq/__pycache__/codec_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1815fe3ea18bd9f317d6f1b2b21496cc2b24ed25 Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_encoder.cpython-310.pyc differ diff --git a/xcodec2/vq/__pycache__/codec_encoder.cpython-311.pyc b/xcodec2/vq/__pycache__/codec_encoder.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..a95e106c52b918fc9879df154585f48f5b97f084 Binary files /dev/null and 
b/xcodec2/vq/__pycache__/codec_encoder.cpython-311.pyc differ diff --git a/xcodec2/vq/__pycache__/codec_encoder.cpython-312.pyc b/xcodec2/vq/__pycache__/codec_encoder.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..033919b2465e77531ac1f4340a42e8247dcf927e Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_encoder.cpython-312.pyc differ diff --git a/xcodec2/vq/__pycache__/codec_encoder.cpython-38.pyc b/xcodec2/vq/__pycache__/codec_encoder.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..4f4810f05dd5515271540d51492b58aafe648829 Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_encoder.cpython-38.pyc differ diff --git a/xcodec2/vq/__pycache__/codec_encoder.cpython-39.pyc b/xcodec2/vq/__pycache__/codec_encoder.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..a6776bd150fda7c289c24c8798d44fb5abe57993 Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_encoder.cpython-39.pyc differ diff --git a/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-310.pyc b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db91e97b2f8bb8d07f0c21a706ee6aeab66561a9 Binary files /dev/null and b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-310.pyc differ diff --git a/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-311.pyc b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..d1018bb516e0ca39612d3046cc66e93d3e9fc857 Binary files /dev/null and b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-311.pyc differ diff --git a/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-312.pyc b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..a79936f97258dfac4e4c1fa00bcba84459f754bf Binary 
files /dev/null and b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-312.pyc differ diff --git a/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-38.pyc b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..f60d99f878ba719a277329f713bdca110ebb514b Binary files /dev/null and b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-38.pyc differ diff --git a/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-39.pyc b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..ff4eb90349eef0207e03b2d991f85499da62ab59 Binary files /dev/null and b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-39.pyc differ diff --git a/xcodec2/vq/__pycache__/module.cpython-310.pyc b/xcodec2/vq/__pycache__/module.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a30f28a9cf972e1615e89597e4cdef4c03fd3d4 Binary files /dev/null and b/xcodec2/vq/__pycache__/module.cpython-310.pyc differ diff --git a/xcodec2/vq/__pycache__/module.cpython-311.pyc b/xcodec2/vq/__pycache__/module.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..8bae56a74114e5ecb523e39275bf451e406e43a2 Binary files /dev/null and b/xcodec2/vq/__pycache__/module.cpython-311.pyc differ diff --git a/xcodec2/vq/__pycache__/module.cpython-312.pyc b/xcodec2/vq/__pycache__/module.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..67dc7809ed198ecba09efecfa433224dc7b5ea4a Binary files /dev/null and b/xcodec2/vq/__pycache__/module.cpython-312.pyc differ diff --git a/xcodec2/vq/__pycache__/module.cpython-38.pyc b/xcodec2/vq/__pycache__/module.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..d3726fc890588cf3cf4e921e4d0a9d446561befc Binary files /dev/null and b/xcodec2/vq/__pycache__/module.cpython-38.pyc differ 
diff --git a/xcodec2/vq/__pycache__/module.cpython-39.pyc b/xcodec2/vq/__pycache__/module.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..97d6cb68f45b74a72d220b758c224fc1d7bcf974 Binary files /dev/null and b/xcodec2/vq/__pycache__/module.cpython-39.pyc differ diff --git a/xcodec2/vq/__pycache__/residual_vq.cpython-310.pyc b/xcodec2/vq/__pycache__/residual_vq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a5ed5c3cef57c636101212049f045ad72269253 Binary files /dev/null and b/xcodec2/vq/__pycache__/residual_vq.cpython-310.pyc differ diff --git a/xcodec2/vq/__pycache__/residual_vq.cpython-311.pyc b/xcodec2/vq/__pycache__/residual_vq.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..92069858ec62cb51a7cce7cdf1550695c656e557 Binary files /dev/null and b/xcodec2/vq/__pycache__/residual_vq.cpython-311.pyc differ diff --git a/xcodec2/vq/__pycache__/residual_vq.cpython-312.pyc b/xcodec2/vq/__pycache__/residual_vq.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..4d81960f8bbec9c309fa08267b44245d8a5b4959 Binary files /dev/null and b/xcodec2/vq/__pycache__/residual_vq.cpython-312.pyc differ diff --git a/xcodec2/vq/__pycache__/residual_vq.cpython-38.pyc b/xcodec2/vq/__pycache__/residual_vq.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..91a0dc321a82434f8f32fd177bbca7f59d4c276d Binary files /dev/null and b/xcodec2/vq/__pycache__/residual_vq.cpython-38.pyc differ diff --git a/xcodec2/vq/__pycache__/residual_vq.cpython-39.pyc b/xcodec2/vq/__pycache__/residual_vq.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..c004cb105a51fd161b325e6568b965634f0cddce Binary files /dev/null and b/xcodec2/vq/__pycache__/residual_vq.cpython-39.pyc differ diff --git a/xcodec2/vq/__pycache__/unet.cpython-312.pyc b/xcodec2/vq/__pycache__/unet.cpython-312.pyc new file mode 100755 index 
0000000000000000000000000000000000000000..e431b555864d1029ed9735ed1bad0bae1b75c4f8 Binary files /dev/null and b/xcodec2/vq/__pycache__/unet.cpython-312.pyc differ diff --git a/xcodec2/vq/__pycache__/unet.cpython-39.pyc b/xcodec2/vq/__pycache__/unet.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..9e12a03c0fba3df01e7d87e249fd77a3bab599f1 Binary files /dev/null and b/xcodec2/vq/__pycache__/unet.cpython-39.pyc differ diff --git a/xcodec2/vq/activations.py b/xcodec2/vq/activations.py new file mode 100755 index 0000000000000000000000000000000000000000..2444a7bd9d52018e97892820a072b39a21245372 --- /dev/null +++ b/xcodec2/vq/activations.py @@ -0,0 +1,120 @@ +# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. +# LICENSE is in incl_licenses directory. + +import torch +from torch import nn, sin, pow +from torch.nn import Parameter + + +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. 
+ ''' + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + Snake ∶= x + 1/a * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + +class SnakeBeta(nn.Module): + ''' + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. 
+ ''' + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.bias = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.bias = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.bias.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.bias.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x \ No newline at end of file diff --git a/xcodec2/vq/alias_free_torch/__init__.py b/xcodec2/vq/alias_free_torch/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..a2318b63198250856809c0cb46210a4147b829bc --- /dev/null +++ b/xcodec2/vq/alias_free_torch/__init__.py @@ -0,0 +1,6 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. 
+ +from .filter import * +from .resample import * +from .act import * \ No newline at end of file diff --git a/xcodec2/vq/alias_free_torch/__pycache__/__init__.cpython-310.pyc b/xcodec2/vq/alias_free_torch/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..661c1297fc3bbe2f0fa5c1197243e51c4d4cb263 Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/__init__.cpython-310.pyc differ diff --git a/xcodec2/vq/alias_free_torch/__pycache__/__init__.cpython-311.pyc b/xcodec2/vq/alias_free_torch/__pycache__/__init__.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..8c624a5d6d4dc886d41f429249030bde5539887e Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/__init__.cpython-311.pyc differ diff --git a/xcodec2/vq/alias_free_torch/__pycache__/__init__.cpython-312.pyc b/xcodec2/vq/alias_free_torch/__pycache__/__init__.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..a43e41d400b719ffcbd2edeb74835601d56febf5 Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/__init__.cpython-312.pyc differ diff --git a/xcodec2/vq/alias_free_torch/__pycache__/__init__.cpython-38.pyc b/xcodec2/vq/alias_free_torch/__pycache__/__init__.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..c68c37cfa7cf1fa7585601af0b02e7842b96e717 Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/__init__.cpython-38.pyc differ diff --git a/xcodec2/vq/alias_free_torch/__pycache__/__init__.cpython-39.pyc b/xcodec2/vq/alias_free_torch/__pycache__/__init__.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..01dd0800751de52eba5308427e45658cfa47442d Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/__init__.cpython-39.pyc differ diff --git a/xcodec2/vq/alias_free_torch/__pycache__/act.cpython-310.pyc 
b/xcodec2/vq/alias_free_torch/__pycache__/act.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b149407ee3f1bcd541d4cf0e96496b383e51302 Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/act.cpython-310.pyc differ diff --git a/xcodec2/vq/alias_free_torch/__pycache__/act.cpython-311.pyc b/xcodec2/vq/alias_free_torch/__pycache__/act.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..6db1958197fb0a213b4ff1df75dcfbf89317ac7d Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/act.cpython-311.pyc differ diff --git a/xcodec2/vq/alias_free_torch/__pycache__/act.cpython-312.pyc b/xcodec2/vq/alias_free_torch/__pycache__/act.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..8cfb68556b30eb0c90e80734d84ce4630c84f6fe Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/act.cpython-312.pyc differ diff --git a/xcodec2/vq/alias_free_torch/__pycache__/act.cpython-38.pyc b/xcodec2/vq/alias_free_torch/__pycache__/act.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..7e857551f8e5e2cf76b60615e55d954038fc623d Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/act.cpython-38.pyc differ diff --git a/xcodec2/vq/alias_free_torch/__pycache__/act.cpython-39.pyc b/xcodec2/vq/alias_free_torch/__pycache__/act.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..11d880c40d0fed4924c1d5245722c5218e185ef2 Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/act.cpython-39.pyc differ diff --git a/xcodec2/vq/alias_free_torch/__pycache__/filter.cpython-310.pyc b/xcodec2/vq/alias_free_torch/__pycache__/filter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..879b4e00e5fd77097e91242846374d5dae2c579c Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/filter.cpython-310.pyc differ diff --git 
a/xcodec2/vq/alias_free_torch/__pycache__/filter.cpython-311.pyc b/xcodec2/vq/alias_free_torch/__pycache__/filter.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..ce1c1840a49147f4771cb6c43c75fafdff331ebf Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/filter.cpython-311.pyc differ diff --git a/xcodec2/vq/alias_free_torch/__pycache__/filter.cpython-312.pyc b/xcodec2/vq/alias_free_torch/__pycache__/filter.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..e7c8b0886d3813d2f3184509934becbdd6a7cf6f Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/filter.cpython-312.pyc differ diff --git a/xcodec2/vq/alias_free_torch/__pycache__/filter.cpython-38.pyc b/xcodec2/vq/alias_free_torch/__pycache__/filter.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..f6066d9ddb5cb651a8080ac437a8b64c94060de2 Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/filter.cpython-38.pyc differ diff --git a/xcodec2/vq/alias_free_torch/__pycache__/filter.cpython-39.pyc b/xcodec2/vq/alias_free_torch/__pycache__/filter.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..46ca966e33e20803fba84a5da40097a83907bc7d Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/filter.cpython-39.pyc differ diff --git a/xcodec2/vq/alias_free_torch/__pycache__/resample.cpython-310.pyc b/xcodec2/vq/alias_free_torch/__pycache__/resample.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2829b6c928284aa5de8360dcfca87edcb435945 Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/resample.cpython-310.pyc differ diff --git a/xcodec2/vq/alias_free_torch/__pycache__/resample.cpython-311.pyc b/xcodec2/vq/alias_free_torch/__pycache__/resample.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..51f3f964031846bf45ae58bc751a8d74afe4495f 
Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/resample.cpython-311.pyc differ diff --git a/xcodec2/vq/alias_free_torch/__pycache__/resample.cpython-312.pyc b/xcodec2/vq/alias_free_torch/__pycache__/resample.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..bd086d178282ace7cda97ede64bd598c43012dab Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/resample.cpython-312.pyc differ diff --git a/xcodec2/vq/alias_free_torch/__pycache__/resample.cpython-38.pyc b/xcodec2/vq/alias_free_torch/__pycache__/resample.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..d0f4291e2f540087842b52b38ac8c302bacd150e Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/resample.cpython-38.pyc differ diff --git a/xcodec2/vq/alias_free_torch/__pycache__/resample.cpython-39.pyc b/xcodec2/vq/alias_free_torch/__pycache__/resample.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..15613b4678ac49403758cba125a01e482c12b390 Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/resample.cpython-39.pyc differ diff --git a/xcodec2/vq/alias_free_torch/act.py b/xcodec2/vq/alias_free_torch/act.py new file mode 100755 index 0000000000000000000000000000000000000000..028debd697dd60458aae75010057df038bd3518a --- /dev/null +++ b/xcodec2/vq/alias_free_torch/act.py @@ -0,0 +1,28 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. 
+ +import torch.nn as nn +from .resample import UpSample1d, DownSample1d + + +class Activation1d(nn.Module): + def __init__(self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x \ No newline at end of file diff --git a/xcodec2/vq/alias_free_torch/filter.py b/xcodec2/vq/alias_free_torch/filter.py new file mode 100755 index 0000000000000000000000000000000000000000..7ad6ea87c1f10ddd94c544037791d7a4634d5ae1 --- /dev/null +++ b/xcodec2/vq/alias_free_torch/filter.py @@ -0,0 +1,95 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +if 'sinc' in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where(x == 0, + torch.tensor(1., device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. 
+def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] + even = (kernel_size % 2 == 0) + half_size = kernel_size // 2 + + #For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.: + beta = 0.1102 * (A - 8.7) + elif A >= 21.: + beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) + else: + beta = 0. + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = (torch.arange(-half_size, half_size) + 0.5) + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__(self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = 'replicate', + kernel_size: int = 12): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. 
+ super().__init__() + if cutoff < -0.: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = (kernel_size % 2 == 0) + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + #input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), + stride=self.stride, groups=C) + + return out \ No newline at end of file diff --git a/xcodec2/vq/alias_free_torch/resample.py b/xcodec2/vq/alias_free_torch/resample.py new file mode 100755 index 0000000000000000000000000000000000000000..750e6c3402cc5ac939c4b9d075246562e0e1d1a7 --- /dev/null +++ b/xcodec2/vq/alias_free_torch/resample.py @@ -0,0 +1,49 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. 
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn
from torch.nn import functional as F

from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d


class UpSample1d(nn.Module):
    """Anti-aliased upsampling by an integer ratio using a transposed conv."""

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        # Crop amounts that undo the replicate padding after the transposed conv.
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        kernel = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
                                      half_width=0.6 / ratio,
                                      kernel_size=self.kernel_size)
        self.register_buffer("filter", kernel)

    def forward(self, x):
        """Upsample x of shape [B, C, T] to [B, C, T * ratio]."""
        _, channels, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode='replicate')
        # The gain of `ratio` compensates for the zeros the transposed
        # convolution inserts; the kernel is applied depthwise.
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(channels, -1, -1),
            stride=self.stride, groups=channels)
        return x[..., self.pad_left:-self.pad_right]


class DownSample1d(nn.Module):
    """Anti-aliased downsampling: a low-pass filter with stride == ratio."""

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
                                       half_width=0.6 / ratio,
                                       stride=ratio,
                                       kernel_size=self.kernel_size)

    def forward(self, x):
        return self.lowpass(x)
class FeedForwardModule(nn.Module):
    """Base module whose subclasses assign ``self.net``; forward just applies it."""

    def __init__(self) -> None:
        super().__init__()
        self.net = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class Residual(nn.Module):
    """Wrap a module with an additive skip connection: y = f(x) + x."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.module(x) + x


class DilatedConvolutionalUnit(FeedForwardModule):
    """activation -> dilated 'same' conv -> activation -> 1x1 conv at hidden_dim."""

    def __init__(self,
                 hidden_dim: int,
                 dilation: int,
                 kernel_size: int,
                 activation: "ModuleFactory",
                 normalization: Callable[[nn.Module], nn.Module] = lambda x: x) -> None:
        super().__init__()
        # This padding keeps the sequence length unchanged for odd kernels.
        same_pad = ((kernel_size - 1) * dilation) // 2
        self.net = nn.Sequential(
            activation(),
            normalization(
                nn.Conv1d(in_channels=hidden_dim,
                          out_channels=hidden_dim,
                          kernel_size=kernel_size,
                          dilation=dilation,
                          padding=same_pad)),
            activation(),
            nn.Conv1d(in_channels=hidden_dim,
                      out_channels=hidden_dim,
                      kernel_size=1),
        )


class UpsamplingUnit(FeedForwardModule):
    """activation -> strided transposed conv upsampling time by `stride`."""

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 stride: int,
                 activation: "ModuleFactory",
                 normalization: Callable[[nn.Module], nn.Module] = lambda x: x) -> None:
        super().__init__()
        self.net = nn.Sequential(
            activation(),
            normalization(
                nn.ConvTranspose1d(
                    in_channels=input_dim,
                    out_channels=output_dim,
                    kernel_size=2 * stride,
                    stride=stride,
                    # padding/output_padding chosen so T_out == T_in * stride
                    # for both even and odd strides.
                    padding=stride // 2 + stride % 2,
                    output_padding=1 if stride % 2 != 0 else 0)))


class DownsamplingUnit(FeedForwardModule):
    """activation -> strided conv downsampling time by `stride`."""

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 stride: int,
                 activation: "ModuleFactory",
                 normalization: Callable[[nn.Module], nn.Module] = lambda x: x) -> None:
        super().__init__()
        self.net = nn.Sequential(
            activation(),
            normalization(
                nn.Conv1d(
                    in_channels=input_dim,
                    out_channels=output_dim,
                    kernel_size=2 * stride,
                    stride=stride,
                    padding=stride // 2 + stride % 2)))
downsampling_unit: Type[DownsamplingUnit], + ratios: Sequence[int], + dilations: Union[Sequence[int], Sequence[Sequence[int]]], + pre_network_conv: Type[nn.Conv1d], + post_network_conv: Type[nn.Conv1d], + normalization: Callable[[nn.Module], + nn.Module] = lambda x: x) -> None: + super().__init__() + channels = capacity * 2**np.arange(len(ratios) + 1) + + dilations_list = self.normalize_dilations(dilations, ratios) + + net = [normalization(pre_network_conv(out_channels=channels[0]))] + + for ratio, dilations, input_dim, output_dim in zip( + ratios, dilations_list, channels[:-1], channels[1:]): + for dilation in dilations: + net.append(Residual(dilated_unit(input_dim, dilation))) + net.append(downsampling_unit(input_dim, output_dim, ratio)) + + net.append(post_network_conv(in_channels=output_dim)) + + self.net = nn.Sequential(*net) + + @staticmethod + def normalize_dilations(dilations: Union[Sequence[int], + Sequence[Sequence[int]]], + ratios: Sequence[int]): + if isinstance(dilations[0], int): + dilations = [dilations for _ in ratios] + return dilations + + +class DilatedResidualDecoder(FeedForwardModule): + + def __init__( + self, + capacity: int, + dilated_unit: Type[DilatedConvolutionalUnit], + upsampling_unit: Type[UpsamplingUnit], + ratios: Sequence[int], + dilations: Union[Sequence[int], Sequence[Sequence[int]]], + pre_network_conv: Type[nn.Conv1d], + post_network_conv: Type[nn.Conv1d], + normalization: Callable[[nn.Module], + nn.Module] = lambda x: x) -> None: + super().__init__() + channels = capacity * 2**np.arange(len(ratios) + 1) + channels = channels[::-1] + + dilations_list = self.normalize_dilations(dilations, ratios) + dilations_list = dilations_list[::-1] + + net = [pre_network_conv(out_channels=channels[0])] + + for ratio, dilations, input_dim, output_dim in zip( + ratios, dilations_list, channels[:-1], channels[1:]): + net.append(upsampling_unit(input_dim, output_dim, ratio)) + for dilation in dilations: + 
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer norm (no mean subtraction), as used in Llama.

    https://github.com/meta-llama/llama/blob/main/llama/model.py
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Scale by the reciprocal RMS over the last dimension, then apply
        # the learned per-channel gain.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps) * self.weight


class MLP(nn.Module):
    """Position-wise feed-forward block: 4x expansion -> SiLU -> projection."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.fc1 = nn.Linear(dim, 4 * dim, bias=False)
        self.silu = nn.SiLU()
        self.fc2 = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.fc2(self.silu(self.fc1(x)))
class Attention(nn.Module):
    """Multi-head self-attention with rotary position embeddings and SDPA."""

    def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings):
        super().__init__()
        assert dim % n_heads == 0

        self.n_heads = n_heads
        self.dim = dim
        self.rotary_embed = rotary_embed

        # Fused scaled-dot-product attention is required (torch >= 2.0).
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        assert self.flash, "Must have flash attention."

        self.c_attn = nn.Linear(dim, 3 * dim, bias=False)
        self.c_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        r"""(b, t, h*d) -> (b, t, h*d).

        b: batch_size, t: time steps, h: heads_num, d: heads_dim.
        """
        qkv = self.c_attn(x)
        # Split the fused projection into per-head q, k, v: each (b, h, t, d).
        q, k, v = rearrange(qkv, 'b t (r h d) -> r b h t d', r=3, h=self.n_heads)

        q = self.rotary_embed(q)
        k = self.rotary_embed(k)

        # The assert above guarantees SDPA exists, so no fallback branch.
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0, is_causal=False)
        y = rearrange(y, 'b h t d -> b t (h d)')
        return self.c_proj(y)


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: RMSNorm->attention and RMSNorm->MLP, both residual."""

    def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads

        self.att_norm = RMSNorm(dim)
        self.ffn_norm = RMSNorm(dim)
        self.att = Attention(dim=dim, n_heads=n_heads, rotary_embed=rotary_embed)
        self.mlp = MLP(dim=dim)

    def forward(self, x: torch.Tensor):
        x = x + self.att(self.att_norm(x))
        return x + self.mlp(self.ffn_norm(x))


if __name__ == '__main__':
    # Smoke test for a single transformer block.
    rotary_embed_128 = RotaryPositionalEmbeddings(dim=128)
    transformer_block = TransformerBlock(dim=1024, n_heads=8,
                                         rotary_embed=rotary_embed_128)
    x = torch.randn(2, 128, 1024)
    y = transformer_block(x)
    print(y.shape)
import numpy as np
import torch
import torch.nn as nn
from torch.nn import utils

from xcodec2.vq.residual_vq import ResidualVQ
from xcodec2.vq.module import WNConv1d, DecoderBlock, ResLSTM
from xcodec2.vq.alias_free_torch import *  # provides Activation1d
from xcodec2.vq import activations
from xcodec2.vq import blocks as blocks
from xcodec2.vq.bs_roformer5 import TransformerBlock
from torchtune.modules import RotaryPositionalEmbeddings


def init_weights(m):
    """Truncated-normal init for Conv1d weights; zero the bias when present."""
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        # BUG FIX: Conv1d layers created with bias=False have m.bias is None;
        # the previous unconditional constant_ call raised on such layers.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


class CodecDecoder(nn.Module):
    """BigCodec-style decoder: residual-VQ bottleneck plus an upsampling conv stack.

    ``forward(x, vq=True)`` only runs the quantizer and returns
    (quantized, codes, commit_loss); ``forward(z, vq=False)`` decodes
    quantized features to a waveform in [-1, 1].
    """

    def __init__(self,
                 in_channels=1024,
                 upsample_initial_channel=1536,
                 ngf=48,
                 use_rnn=True,
                 rnn_bidirectional=False,
                 rnn_num_layers=2,
                 up_ratios=(5, 4, 4, 4, 2),
                 dilations=(1, 3, 9),
                 vq_num_quantizers=1,
                 vq_dim=2048,
                 vq_commit_weight=0.25,
                 vq_weight_init=False,
                 vq_full_commit_loss=False,
                 codebook_size=16384,
                 codebook_dim=32,
                 ):
        super().__init__()
        self.hop_length = np.prod(up_ratios)
        self.ngf = ngf
        self.up_ratios = up_ratios

        self.quantizer = ResidualVQ(
            num_quantizers=vq_num_quantizers,
            dim=vq_dim,  # double the dim for acoustic and semantic
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            threshold_ema_dead_code=2,
            commitment=vq_commit_weight,
            weight_init=vq_weight_init,
            full_commit_loss=vq_full_commit_loss,
        )

        channels = upsample_initial_channel
        layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]

        if use_rnn:
            layers += [
                ResLSTM(channels,
                        num_layers=rnn_num_layers,
                        bidirectional=rnn_bidirectional)
            ]

        for i, stride in enumerate(up_ratios):
            input_dim = channels // 2**i
            output_dim = channels // 2**(i + 1)
            layers += [DecoderBlock(input_dim, output_dim, stride, dilations)]

        # NOTE: output_dim intentionally carries the last loop iteration's
        # value (the narrowest channel count) into the output head.
        layers += [
            Activation1d(activation=activations.SnakeBeta(output_dim, alpha_logscale=True)),
            WNConv1d(output_dim, 1, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

        self.reset_parameters()

    def forward(self, x, vq=True):
        if vq is True:
            x, q, commit_loss = self.quantizer(x)
            return x, q, commit_loss
        return self.model(x)

    def vq2emb(self, vq):
        """Map code indices back to embeddings (quantizer in eval mode)."""
        self.quantizer = self.quantizer.eval()
        return self.quantizer.vq2emb(vq)

    def get_emb(self):
        """Return the quantizer codebook embeddings (quantizer in eval mode)."""
        self.quantizer = self.quantizer.eval()
        return self.quantizer.get_emb()

    def inference_vq(self, vq):
        # Decode a single (C, T) code embedding by adding a batch axis.
        return self.model(vq[None, :, :])

    def inference_0(self, x):
        x, q, loss, perp = self.quantizer(x)
        return self.model(x), None

    def inference(self, x):
        return self.model(x), None

    def remove_weight_norm(self):
        """Remove weight normalization from all layers that carry it."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to every (transposed) Conv1d layer."""

        def _apply_weight_norm(m):
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        self.apply(init_weights)
TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) + for _ in range(depth) + ] + + self.transformers = nn.Sequential(*transformer_blocks) + + self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) + + self.conv_blocks = blocks.DilatedResidualDecoder( + capacity=self.capacity, + dilated_unit=self.dilated_unit, + upsampling_unit=self.upsampling_unit, + ratios=up_ratios, # 逆转编码器的下采样比率 + dilations=dilations, + pre_network_conv=self.pre_conv, + post_network_conv=self.post_conv, + ) + + + + self.reset_parameters() + + def forward(self, x, vq=True): + if vq is True: + x, q, commit_loss = self.quantizer(x) + return x, q, commit_loss + x= self.transformers(x) + x = self.final_layer_norm(x) + x = x.permute(0, 2, 1) + x = self.conv_blocks(x) + return x + + def vq2emb(self, vq): + self.quantizer = self.quantizer.eval() + x = self.quantizer.vq2emb(vq) + return x + + def get_emb(self): + self.quantizer = self.quantizer.eval() + embs = self.quantizer.get_emb() + return embs + + def inference_vq(self, vq): + x = vq[None,:,:] + x = self.model(x) + return x + + def inference_0(self, x): + x, q, loss, perp = self.quantizer(x) + x = self.model(x) + return x, None + + def inference(self, x): + x = self.model(x) + return x, None + + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + def pre_conv(self, out_channels): + return nn.Conv1d(in_channels=self.hidden_dim, out_channels=out_channels, kernel_size=1) + 
+ # 定义后处理卷积层,将模型的输出映射到最终的输出通道数 + def post_conv(self,in_channels): + return nn.Conv1d(in_channels=in_channels, out_channels=1, kernel_size=1) + + def dilated_unit(self, hidden_dim, dilation): + return blocks.DilatedConvolutionalUnit( + hidden_dim=hidden_dim, + dilation=dilation, + kernel_size=3, + activation=nn.ReLU , + normalization=utils.weight_norm + ) + + # 定义上采样单元 + def upsampling_unit(self,input_dim, output_dim, stride): + return blocks.UpsamplingUnit( + input_dim=input_dim, + output_dim=output_dim, + stride=stride, + activation=nn.ReLU , + normalization=utils.weight_norm + ) + +def main(): + # 设置设备 + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + + # 初始化模型 + model = CodecDecoder_oobleck_Transformer().to(device) + print("Model initialized.") + + # 创建测试输入: batch_size x in_channels x sequence_length + batch_size = 2 + in_channels = 1024 + sequence_length = 100 # 示例长度,可以根据需要调整 + dummy_input = torch.randn(batch_size, sequence_length, in_channels).to(device) + print(f"Dummy input shape: {dummy_input.shape}") + + # 将模型设为评估模式 + model.eval() + + + + output_no_vq = model(dummy_input, vq=False) + c=1 + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/xcodec2/vq/codec_decoder_vocos.py b/xcodec2/vq/codec_decoder_vocos.py new file mode 100755 index 0000000000000000000000000000000000000000..ba7b5dc86b078515bf291ee3cbf45979bf4b7dff --- /dev/null +++ b/xcodec2/vq/codec_decoder_vocos.py @@ -0,0 +1,638 @@ +import sys +sys.path.append('/aifs4su/data/zheny/bigcodec_final/BigCodec_conv_transformer_vocos') +import numpy as np +import torch +import torch.nn as nn +from xcodec2.vq.residual_vq import ResidualVQ +from xcodec2.vq.module import WNConv1d, DecoderBlock, ResLSTM +from xcodec2.vq.alias_free_torch import * +from xcodec2.vq import activations +from typing import Optional +from xcodec2.vq.module import ConvNeXtBlock, AdaLayerNorm +from xcodec2.vq.bs_roformer5 import TransformerBlock +# from 
from torchtune.modules import RotaryPositionalEmbeddings
from vector_quantize_pytorch import ResidualFSQ
from torch.nn import Module, ModuleList


class ISTFT(nn.Module):
    """Inverse STFT supporting "center" and "same" padding.

    torch.istft only allows center padding because the NOLA (Nonzero Overlap
    Add) check fails at the edges otherwise; see
    https://github.com/pytorch/pytorch/issues/62323. For neural vocoding we
    want "same" padding (analogous to CNNs), so that path performs the
    overlap-add manually and trims the padded samples, which keeps NOLA
    satisfied on the retained region.

    Args:
        n_fft: Size of the Fourier transform.
        hop_length: Distance between neighboring sliding-window frames.
        win_length: Size of the window frame and STFT filter.
        padding: "center" or "same" (default).
    """

    def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.register_buffer("window", torch.hann_window(win_length))

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Reconstruct a (B, L) waveform from a complex (B, N, T) spectrogram."""
        if self.padding == "center":
            # torch handles this case natively.
            return torch.istft(spec, self.n_fft, self.hop_length,
                               self.win_length, self.window, center=True)
        elif self.padding == "same":
            pad = (self.win_length - self.hop_length) // 2
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        assert spec.dim() == 3, "Expected a 3D tensor as input"
        B, N, T = spec.shape

        # Per-frame inverse FFT, then window every frame.
        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        ifft = ifft * self.window[None, :, None]

        # Overlap-add via fold, then trim the "same" padding.
        output_size = (T - 1) * self.hop_length + self.win_length
        y = torch.nn.functional.fold(
            ifft, output_size=(1, output_size),
            kernel_size=(1, self.win_length), stride=(1, self.hop_length),
        )[:, 0, 0, pad:-pad]

        # Sum of squared windows, used to undo the analysis/synthesis gain.
        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(
            window_sq, output_size=(1, output_size),
            kernel_size=(1, self.win_length), stride=(1, self.hop_length),
        ).squeeze()[pad:-pad]

        # NOLA check: every retained sample needs nonzero window energy.
        assert (window_envelope > 1e-11).all()
        return y / window_envelope


class FourierHead(nn.Module):
    """Base class for inverse-Fourier heads; subclasses map (B, L, H) -> (B, T)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("Subclasses must implement the forward method.")


class ISTFTHead(FourierHead):
    """Predicts STFT magnitude/phase from hidden states and inverts them to audio.

    Args:
        dim: Hidden dimension of the model.
        n_fft: Size of the Fourier transform.
        hop_length: Hop between frames; should match the input feature rate.
        padding: "center" or "same" (default).
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        out_dim = n_fft + 2
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length,
                           win_length=n_fft, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(B, L, H) hidden states -> ((B, 1, T) audio, (B, n_fft+2, L) head output)."""
        x_pred = self.out(x)
        x_pred = x_pred.transpose(1, 2)
        mag, p = x_pred.chunk(2, dim=1)
        # Magnitude is predicted in log space; clip to avoid huge values.
        mag = torch.exp(mag)
        mag = torch.clip(mag, max=1e2)
        # Build the complex spectrogram directly from cos/sin of the phase;
        # recomputing the phase with atan2 would add cost for no benefit.
        real = torch.cos(p)
        imag = torch.sin(p)
        S = mag * (real + 1j * imag)
        audio = self.istft(S)
        return audio.unsqueeze(1), x_pred
def nonlinearity(x):
    # Swish / SiLU activation.
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    """GroupNorm over `num_groups` groups; in_channels must be divisible by it."""
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels,
                              eps=1e-6, affine=True)


class ResnetBlock(nn.Module):
    """1-D ResNet block: two norm -> swish -> conv stages plus a shortcut.

    Optionally conditions on a time embedding ``temb`` through a linear
    projection added after the first convolution. When in/out channel counts
    differ, the shortcut is either a 3-tap conv (``conv_shortcut=True``) or a
    1x1 conv.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv1d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv1d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv1d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv1d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb=None):
        h = self.conv1(nonlinearity(self.norm1(x)))

        if temb is not None:
            # BUG FIX: this block operates on 3-D (B, C, T) tensors, but the
            # projection was broadcast with [:, :, None, None] — a leftover
            # from the 2-D image version — which raised a broadcasting error.
            # A single trailing axis is correct for Conv1d activations.
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None]

        h = self.conv2(self.dropout(nonlinearity(self.norm2(h))))

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + + h_ = self.proj_out(h_) + + return x + h_ + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + + +class Backbone(nn.Module): + """Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, + C denotes output features, and L is the sequence length. + + Returns: + Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, + and H denotes the model dimension. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class VocosBackbone(Backbone): + """ + Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. + num_layers (int): Number of ConvNeXtBlock layers. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional model. Defaults to None. 
+ """ + + def __init__( + self, hidden_dim=1024,depth=12,heads=16,pos_meb_dim=64): + super().__init__() + + self.embed = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=7, padding=3) + + + + self.temb_ch = 0 + block_in = hidden_dim + dropout = 0.1 + + prior_net : tp.List[nn.Module] = [ + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + ] + self.prior_net = nn.Sequential(*prior_net) + + depth = depth + time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) + + + transformer_blocks = [ + TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) + for _ in range(depth) + ] + + + self.transformers = nn.Sequential(*transformer_blocks) + self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) + post_net : tp.List[nn.Module] = [ + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + ] + self.post_net = nn.Sequential(*post_net) + + def forward(self, x: torch.Tensor ) -> torch.Tensor: + x = x.transpose(1, 2) + x = self.embed(x) + x = self.prior_net(x) + x = x.transpose(1, 2) + x= self.transformers(x) + x = x.transpose(1, 2) + x = self.post_net(x) + x = x.transpose(1, 2) + x = self.final_layer_norm(x) + return x + + + + + + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + +class CodecDecoderVocos(nn.Module): + def __init__(self, + hidden_dim=1024, + depth=12, + heads=16, + pos_meb_dim=64, + hop_length=320, + vq_num_quantizers=1, + vq_dim=2048, #1024 2048 + vq_commit_weight=0.25, + vq_weight_init=False, + vq_full_commit_loss=False, + codebook_size=16384, + codebook_dim=16, + ): + super().__init__() + self.hop_length = hop_length + + self.quantizer = 
ResidualFSQ( + dim = vq_dim, + levels = [4, 4, 4, 4, 4,4,4,4], + num_quantizers = 1 + ) + + # self.quantizer = ResidualVQ( + # num_quantizers=vq_num_quantizers, + # dim=vq_dim, + # codebook_size=codebook_size, + # codebook_dim=codebook_dim, + # threshold_ema_dead_code=2, + # commitment=vq_commit_weight, + # weight_init=vq_weight_init, + # full_commit_loss=vq_full_commit_loss, + # ) + + + self.backbone = VocosBackbone( hidden_dim=hidden_dim,depth=depth,heads=heads,pos_meb_dim=pos_meb_dim) + + self.head = ISTFTHead(dim=hidden_dim, n_fft=self.hop_length*4, hop_length=self.hop_length, padding="same") + + self.reset_parameters() + + def forward(self, x, vq=True): + if vq is True: + # x, q, commit_loss = self.quantizer(x) + x = x.permute(0, 2, 1) + x, q = self.quantizer(x) + x = x.permute(0, 2, 1) + q = q.permute(0, 2, 1) + return x, q, None + x = self.backbone(x) + x,_ = self.head(x) + + return x ,_ + + def vq2emb(self, vq): + self.quantizer = self.quantizer.eval() + x = self.quantizer.vq2emb(vq) + return x + + def get_emb(self): + self.quantizer = self.quantizer.eval() + embs = self.quantizer.get_emb() + return embs + + def inference_vq(self, vq): + x = vq[None,:,:] + x = self.model(x) + return x + + def inference_0(self, x): + x, q, loss, perp = self.quantizer(x) + x = self.model(x) + return x, None + + def inference(self, x): + x = self.model(x) + return x, None + + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + + 
+class CodecDecoderVocos_transpose(nn.Module): + def __init__(self, + hidden_dim=1024, + depth=12, + heads=16, + pos_meb_dim=64, + hop_length=320, + vq_num_quantizers=1, + vq_dim=1024, #1024 2048 + vq_commit_weight=0.25, + vq_weight_init=False, + vq_full_commit_loss=False, + codebook_size=16384, + codebook_dim=16, + ): + super().__init__() + self.hop_length = hop_length + + + self.quantizer = ResidualVQ( + num_quantizers=vq_num_quantizers, + dim=vq_dim, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + threshold_ema_dead_code=2, + commitment=vq_commit_weight, + weight_init=vq_weight_init, + full_commit_loss=vq_full_commit_loss, + ) + + + self.backbone = VocosBackbone( hidden_dim=hidden_dim,depth=depth,heads=heads,pos_meb_dim=pos_meb_dim) + + self.inverse_mel_conv = nn.Sequential( + nn.GELU(), + nn.ConvTranspose1d( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + stride=2, + padding=1, + output_padding=1 # 确保输出长度与编码前匹配 + ), + nn.GELU(), + nn.ConvTranspose1d( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + padding=1 + ) + ) + + self.head = ISTFTHead(dim=hidden_dim, n_fft=self.hop_length*4, hop_length=self.hop_length, padding="same") + + self.reset_parameters() + + def forward(self, x, vq=True): + if vq is True: + x, q, commit_loss = self.quantizer(x) + return x, q, commit_loss + x = self.backbone(x) + x,_ = self.head(x) + + return x ,_ + + def vq2emb(self, vq): + self.quantizer = self.quantizer.eval() + x = self.quantizer.vq2emb(vq) + return x + + def get_emb(self): + self.quantizer = self.quantizer.eval() + embs = self.quantizer.get_emb() + return embs + + def inference_vq(self, vq): + x = vq[None,:,:] + x = self.model(x) + return x + + def inference_0(self, x): + x, q, loss, perp = self.quantizer(x) + x = self.model(x) + return x, None + + def inference(self, x): + x = self.model(x) + return x, None + + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + 
def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + + + +def main(): + # 设置设备 + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + + # 初始化模型 + model = CodecDecoderVocos_transpose().to(device) + print("Model initialized.") + + # 创建测试输入: batch_size x in_channels x sequence_length + batch_size = 2 + in_channels = 1024 + sequence_length = 50 # 示例长度,可以根据需要调整 + dummy_input = torch.randn(batch_size, in_channels, sequence_length).to(device) + print(f"Dummy input shape: {dummy_input.shape}") + + # 将模型设为评估模式 + model.eval() + + # 前向传播(使用 VQ) + # with torch.no_grad(): + # try: + # output, q, commit_loss = model(dummy_input, vq=True) + # print("Forward pass with VQ:") + # print(f"Output shape: {output.shape}") + # print(f"Quantized codes shape: {q.shape}") + # print(f"Commitment loss: {commit_loss}") + # except Exception as e: + # print(f"Error during forward pass with VQ: {e}") + + # 前向传播(不使用 VQ) + with torch.no_grad(): + # try: + output_no_vq = model(dummy_input, vq=False) + print("\nForward pass without VQ:") + print(f"Output shape: {output_no_vq.shape}") + c=1 + # except Exception as e: + # print(f"Error during forward pass without VQ: {e}") + + + # model_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters()) + # model_size_mb = model_size_bytes / (1024 ** 2) + # print(f"Model size: {model_size_bytes} bytes ({model_size_mb:.2f} MB)") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/xcodec2/vq/codec_encoder.py 
b/xcodec2/vq/codec_encoder.py new file mode 100755 index 0000000000000000000000000000000000000000..717399f8c1c2eff77904a02058a888e550d75cf8 --- /dev/null +++ b/xcodec2/vq/codec_encoder.py @@ -0,0 +1,335 @@ +import sys + +import torch +from torch import nn +import numpy as np +from xcodec2.vq.module import WNConv1d, EncoderBlock, ResLSTM +from xcodec2.vq.alias_free_torch import * +from xcodec2.vq import activations +from xcodec2.vq.bs_roformer5 import TransformerBlock + +from torchtune.modules import RotaryPositionalEmbeddings +import xcodec2.vq.blocks as blocks +from torch.nn import utils +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + +class CodecEncoder(nn.Module): + def __init__(self, + ngf=48, + use_rnn=True, + rnn_bidirectional=False, + rnn_num_layers=2, + up_ratios=(2, 2, 4, 4, 5), + dilations=(1, 3, 9), + out_channels=1024): + super().__init__() + self.hop_length = np.prod(up_ratios) + self.ngf = ngf + self.up_ratios = up_ratios + + # Create first convolution + d_model = ngf + self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for i, stride in enumerate(up_ratios): + d_model *= 2 + self.block += [EncoderBlock(d_model, stride=stride, dilations=dilations)] + # RNN + if use_rnn: + self.block += [ + ResLSTM(d_model, + num_layers=rnn_num_layers, + bidirectional=rnn_bidirectional + ) + ] + # Create last convolution + self.block += [ + Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)), + WNConv1d(d_model, out_channels, kernel_size=3, padding=1), + ] + + # Wrap black into nn.Sequential + self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + self.reset_parameters() + + def forward(self, x): + out = self.block(x) + return out + + def inference(self, x): + return self.block(x) + + def remove_weight_norm(self): + """Remove weight normalization module from 
all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + +class Transpose(nn.Module): + def __init__(self, dim1, dim2): + super(Transpose, self).__init__() + self.dim1 = dim1 + self.dim2 = dim2 + + def forward(self, x): + return x.transpose(self.dim1, self.dim2) + +class CodecEncoder_Transformer(nn.Module): + def __init__(self, + ngf=48, + up_ratios=[2, 2, 4, 4, 5], + dilations=(1, 3, 9), + hidden_dim=1024, + depth=12, + heads=12, + pos_meb_dim=64, + ): + super().__init__() + self.hop_length = np.prod(up_ratios) + self.ngf =ngf + self.up_ratios = up_ratios + + d_model = ngf + self.conv_blocks = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + + + for i, stride in enumerate(up_ratios): + d_model *= 2 + self.conv_blocks += [EncoderBlock(d_model, stride=stride, dilations=dilations)] + + self.conv_blocks = nn.Sequential(*self.conv_blocks) + + + # time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) + + + # transformer_blocks = [ + # TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) + # for _ in range(depth) + # ] + + + # self.transformers = nn.Sequential(*transformer_blocks) + + # self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) + + self.conv_final_block = [ + Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)), + WNConv1d(d_model, hidden_dim, kernel_size=3, padding=1), + ] + self.conv_final_block = nn.Sequential(*self.conv_final_block) + + self.reset_parameters() + + def forward(self, x): + x = self.conv_blocks(x) + # x = x.permute(0, 2, 1) + # x= 
self.transformers(x) + # x = self.final_layer_norm(x) + # x = x.permute(0, 2, 1) + x = self.conv_final_block (x) + x = x.permute(0, 2, 1) + return x + + def inference(self, x): + return self.block(x) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + + +class Codec_oobleck_Transformer(nn.Module): + def __init__(self, + ngf=32, + up_ratios=(2, 2,4,4, 5), + dilations=(1, 3, 9), + hidden_dim=1024, + depth=12, + heads=16, + pos_meb_dim=64, + ): + super().__init__() + self.hop_length = np.prod(up_ratios) + self.ngf =ngf + self.up_ratios = up_ratios + self.hidden_dim = hidden_dim + + + self.conv_blocks = blocks.DilatedResidualEncoder( + capacity=ngf, + dilated_unit=self.dilated_unit, + downsampling_unit=self.downsampling_unit, + ratios=up_ratios, + dilations=dilations, + pre_network_conv=self.pre_conv, + post_network_conv=self.post_conv, + ) + + + time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) + + transformer_blocks = [ + TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) + for _ in range(depth) + ] + + self.transformers = nn.Sequential(*transformer_blocks) + + self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) + + + self.reset_parameters() + + def forward(self, x): + x = self.conv_blocks(x) + x = x.permute(0, 2, 1) + x= self.transformers(x) + x = self.final_layer_norm(x) + return x + + def inference(self, x): + return self.block(x) + + def remove_weight_norm(self): + """Remove weight normalization module 
from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + def dilated_unit(self,hidden_dim, dilation): + return blocks.DilatedConvolutionalUnit(hidden_dim, + dilation, + kernel_size=3, + activation=nn.ReLU, + normalization=utils.weight_norm) + + def downsampling_unit(self, input_dim: int, output_dim: int, stride: int): + return blocks.DownsamplingUnit(input_dim, + output_dim, + stride, + nn.ReLU, + normalization=utils.weight_norm) + + def pre_conv(self,out_channels): + return nn.Conv1d(1, out_channels, 1) + + def post_conv(self,in_channels): + return nn.Conv1d(in_channels, self.hidden_dim, 1) + + + + + +class CodecEncoder_only_Transformer(nn.Module): + def __init__(self,hidden_dim=1024,depth=12,heads=16,pos_meb_dim=64): + super().__init__() + # self.embed = nn.Linear(input_dim, hidden_dim )input_dim=300, + + depth = depth + time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) + + + transformer_blocks = [ + TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) + for _ in range(depth) + ] + + + self.transformers = nn.Sequential(*transformer_blocks) + + self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) + + def forward(self, x: torch.Tensor ) -> torch.Tensor: + # x = self.embed(x) + + + x= self.transformers(x) + x = self.final_layer_norm(x) + + return x + + + + + + + +def get_model_size(model): + # 计算总参数数 + total_params = sum(p.numel() for p in model.parameters()) + + # 假设每个参数都是32位浮点数,计算模型大小(以字节为单位) + model_size_bytes = total_params # 每个参数4字节 + + # 转换为更易读的单位(例如,MB) + model_size_mb = 
model_size_bytes / (1024 ** 2) + + return total_params, model_size_mb + +if __name__ == '__main__': + model = Codec_oobleck_Transformer() + x = torch.randn(1, 1, 16000) # example input tensor + output = model(x) + print("Output shape:", output.shape) diff --git a/xcodec2/vq/factorized_vector_quantize.py b/xcodec2/vq/factorized_vector_quantize.py new file mode 100755 index 0000000000000000000000000000000000000000..35f0c66736112f771a1933ca7e156b8cd5259e66 --- /dev/null +++ b/xcodec2/vq/factorized_vector_quantize.py @@ -0,0 +1,109 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +class FactorizedVectorQuantize(nn.Module): + def __init__(self, dim, codebook_size, codebook_dim, commitment, **kwargs): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.commitment = commitment + + if dim != self.codebook_dim: + self.in_proj = weight_norm(nn.Linear(dim, self.codebook_dim)) + self.out_proj = weight_norm(nn.Linear(self.codebook_dim, dim)) + else: + self.in_proj = nn.Identity() + self.out_proj = nn.Identity() + self._codebook = nn.Embedding(codebook_size, self.codebook_dim) + + @property + def codebook(self): + return self._codebook + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + # transpose since we use linear + + z = rearrange(z, "b d t -> b 
t d") + + # Factorized codes project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x T x D) + z_e = rearrange(z_e, "b t d -> b d t") + z_q, indices = self.decode_latents(z_e) + + + if self.training: + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction='none').mean([1, 2]) * self.commitment + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction='none').mean([1, 2]) + commit_loss = commitment_loss + codebook_loss + else: + commit_loss = torch.zeros(z.shape[0], device = z.device) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = rearrange(z_q, "b d t -> b t d") + z_q = self.out_proj(z_q) + z_q = rearrange(z_q, "b t d -> b d t") + + return z_q, indices, commit_loss + + def vq2emb(self, vq, proj=True): + emb = self.embed_code(vq) + if proj: + emb = self.out_proj(emb) + return emb + + def get_emb(self): + return self.codebook.weight + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices \ No newline at end of file diff --git a/xcodec2/vq/module.py b/xcodec2/vq/module.py new file mode 100755 index 0000000000000000000000000000000000000000..0c4f69b351abbc3906ced487f4609ed784c29975 --- /dev/null +++ b/xcodec2/vq/module.py @@ -0,0 +1,420 @@ +import torch.nn as nn +from einops import 
rearrange +from . import activations +from .alias_free_torch import * +from torch.nn.utils import weight_norm + +from typing import Optional, Tuple + +from torch.nn.utils import weight_norm, remove_weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Activation1d(activation=activations.SnakeBeta(dim, alpha_logscale=True)), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Activation1d(activation=activations.SnakeBeta(dim, alpha_logscale=True)), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + return x + self.block(x) + +class EncoderBlock(nn.Module): + def __init__(self, dim: int = 16, stride: int = 1, dilations = (1, 3, 9)): + super().__init__() + runits = [ResidualUnit(dim // 2, dilation=d) for d in dilations] + self.block = nn.Sequential( + *runits, + Activation1d(activation=activations.SnakeBeta(dim//2, alpha_logscale=True)), + WNConv1d( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=stride // 2 + stride % 2, + ), + ) + + def forward(self, x): + return self.block(x) + +class DecoderBlock(nn.Module): + def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, dilations = (1, 3, 9)): + super().__init__() + self.block = nn.Sequential( + Activation1d(activation=activations.SnakeBeta(input_dim, alpha_logscale=True)), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=stride // 2 + stride % 2, + output_padding= stride % 2, + ) + ) + self.block.extend([ResidualUnit(output_dim, dilation=d) for d in dilations]) + + def forward(self, x): + return self.block(x) + +class ResLSTM(nn.Module): + def __init__(self, dimension: 
int, + num_layers: int = 2, + bidirectional: bool = False, + skip: bool = True): + super().__init__() + self.skip = skip + self.lstm = nn.LSTM(dimension, dimension if not bidirectional else dimension // 2, + num_layers, batch_first=True, + bidirectional=bidirectional) + + def forward(self, x): + """ + Args: + x: [B, F, T] + + Returns: + y: [B, F, T] + """ + x = rearrange(x, "b f t -> b t f") + y, _ = self.lstm(x) + if self.skip: + y = y + x + y = rearrange(y, "b t f -> b f t") + return y + + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. + + Args: + dim (int): Number of input channels. + intermediate_dim (int): Dimensionality of the intermediate layer. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional LayerNorm. Defaults to None. 
+ """ + + def __init__( + self, + dim: int, + intermediate_dim: int, + layer_scale_init_value: float, + adanorm_num_embeddings: Optional[int] = None, + ): + super().__init__() + self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.adanorm = adanorm_num_embeddings is not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(intermediate_dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + + def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor: + residual = x + x = self.dwconv(x) + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + if self.adanorm: + assert cond_embedding_id is not None + x = self.norm(x, cond_embedding_id) + else: + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + + x = residual + x + return x + + +class AdaLayerNorm(nn.Module): + """ + Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes + + Args: + num_embeddings (int): Number of embeddings. + embedding_dim (int): Dimension of the embeddings. 
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = embedding_dim + self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + torch.nn.init.ones_(self.scale.weight) + torch.nn.init.zeros_(self.shift.weight) + + def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor: + scale = self.scale(cond_embedding_id) + shift = self.shift(cond_embedding_id) + x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) + x = x * scale + shift + return x + + +class ResBlock1(nn.Module): + """ + ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, + but without upsampling layers. + + Args: + dim (int): Number of input channels. + kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. + dilation (tuple[int], optional): Dilation factors for the dilated convolutions. + Defaults to (1, 3, 5). + lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. + Defaults to 0.1. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. 
+ """ + + def __init__( + self, + dim: int, + kernel_size: int = 3, + dilation: Tuple[int, int, int] = (1, 3, 5), + lrelu_slope: float = 0.1, + layer_scale_init_value: Optional[float] = None, + ): + super().__init__() + self.lrelu_slope = lrelu_slope + self.convs1 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[0], + padding=self.get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[1], + padding=self.get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[2], + padding=self.get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + ] + ) + + self.gamma = nn.ParameterList( + [ + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): + xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) + xt = c1(xt) + xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope) + xt = c2(xt) + if gamma is not None: + xt = gamma * xt + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + 
for l in self.convs2: + remove_weight_norm(l) + + @staticmethod + def get_padding(kernel_size: int, dilation: int = 1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor: + """ + Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. + + Args: + x (Tensor): Input tensor. + clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. + + Returns: + Tensor: Element-wise logarithm of the input tensor with clipping applied. + """ + return torch.log(torch.clip(x, min=clip_val)) + + +def symlog(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * torch.log1p(x.abs()) + + +def symexp(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * (torch.exp(x.abs()) - 1) + + + +class SemanticEncoder(nn.Module): + def __init__( + self, + input_channels: int, + code_dim: int, + encode_channels: int, + kernel_size: int = 3, + bias: bool = True, + ): + super(SemanticEncoder, self).__init__() + + # 初始卷积,将 input_channels 映射到 encode_channels + self.initial_conv = nn.Conv1d( + in_channels=input_channels, + out_channels=encode_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=False + ) + + # 残差块 + self.residual_blocks = nn.Sequential( + nn.ReLU(inplace=True), + nn.Conv1d( + encode_channels, + encode_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=bias + ), + nn.ReLU(inplace=True), + nn.Conv1d( + encode_channels, + encode_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=bias + ) + ) + + # 最终卷积,将 encode_channels 映射到 code_dim + self.final_conv = nn.Conv1d( + in_channels=encode_channels, + out_channels=code_dim, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=False + ) + + def forward(self, x): + """ + 前向传播方法。 + + Args: + x (Tensor): 输入张量,形状为 (Batch, Input_channels, Length) + + 
Returns: + Tensor: 编码后的张量,形状为 (Batch, Code_dim, Length) + """ + x = self.initial_conv(x) # (Batch, Encode_channels, Length) + x = self.residual_blocks(x) + x # 残差连接 + x = self.final_conv(x) # (Batch, Code_dim, Length) + return x + +class SemanticDecoder(nn.Module): + def __init__( + self, + code_dim: int, + output_channels: int, + decode_channels: int, + kernel_size: int = 3, + bias: bool = True, + ): + super(SemanticDecoder, self).__init__() + + # Initial convolution to map code_dim to decode_channels + self.initial_conv = nn.Conv1d( + in_channels=code_dim, + out_channels=decode_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=False + ) + + # Residual Blocks + self.residual_blocks = nn.Sequential( + nn.ReLU(inplace=True), + nn.Conv1d(decode_channels, decode_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias), + nn.ReLU(inplace=True), + nn.Conv1d(decode_channels, decode_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias) + ) + + # Final convolution to map decode_channels to output_channels + self.final_conv = nn.Conv1d( + in_channels=decode_channels, + out_channels=output_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=False + ) + + def forward(self, z): + # z: (Batch, Code_dim, Length) + x = self.initial_conv(z) # (Batch, Decode_channels, Length) + x = self.residual_blocks(x) + x # Residual connection + x = self.final_conv(x) # (Batch, Output_channels, Length) + return x \ No newline at end of file diff --git a/xcodec2/vq/residual_vq.py b/xcodec2/vq/residual_vq.py new file mode 100755 index 0000000000000000000000000000000000000000..40d3338fd940aa2e41177827d6e24c5269765b86 --- /dev/null +++ b/xcodec2/vq/residual_vq.py @@ -0,0 +1,53 @@ +import math +import torch +from torch import nn +from .factorized_vector_quantize import FactorizedVectorQuantize + +class ResidualVQ(nn.Module): + def __init__( + self, + *, + 
num_quantizers, + codebook_size, + **kwargs + ): + super().__init__() + VQ = FactorizedVectorQuantize + if type(codebook_size) == int: + codebook_size = [codebook_size] * num_quantizers + self.layers = nn.ModuleList([VQ(codebook_size=size, **kwargs) for size in codebook_size]) + self.num_quantizers = num_quantizers + + def forward(self, x): + quantized_out = 0. + residual = x + + all_losses = [] + all_indices = [] + + for idx, layer in enumerate(self.layers): + quantized, indices, loss = layer(residual) + + residual = residual - quantized + + quantized_out = quantized_out + quantized + + loss = loss.mean() + + all_indices.append(indices) + all_losses.append(loss) + all_losses, all_indices = map(torch.stack, (all_losses, all_indices)) + return quantized_out, all_indices, all_losses + + def vq2emb(self, vq, proj=True): + # [B, T, num_quantizers] + quantized_out = 0. + for idx, layer in enumerate(self.layers): + quantized = layer.vq2emb(vq[:, :, idx], proj=proj) + quantized_out = quantized_out + quantized + return quantized_out + def get_emb(self): + embs = [] + for idx, layer in enumerate(self.layers): + embs.append(layer.get_emb()) + return embs diff --git a/xcodec2/vq/unet.py b/xcodec2/vq/unet.py new file mode 100755 index 0000000000000000000000000000000000000000..ca31029d0866b61663f75045c7770cc7208d9482 --- /dev/null +++ b/xcodec2/vq/unet.py @@ -0,0 +1,210 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +import numpy as np + + +class EncoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=(3, 3)): + super(EncoderBlock, self).__init__() + + self.pool_size = 2 + + self.conv_block = ConvBlock(in_channels, out_channels, kernel_size) + + def forward(self, x): + latent = self.conv_block(x) + output = F.avg_pool2d(latent, kernel_size=self.pool_size) + return output, latent + +class DecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=(3, 3)): + 
super(DecoderBlock, self).__init__() + + stride = 2 + + self.upsample = nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=stride, + stride=stride, + padding=(0, 0), + bias=False, + ) + + self.conv_block = ConvBlock(in_channels * 2, out_channels, kernel_size) + + def forward(self, x, latent): + x = self.upsample(x) + x = torch.cat((x, latent), dim=1) + output = self.conv_block(x) + return output + + +class UNet(nn.Module): + def __init__(self,freq_dim=1281,out_channel=1024): + super(UNet, self).__init__() + + self.downsample_ratio = 16 + + + in_channels = 1 #self.audio_channels * self.cmplx_num + + self.encoder_block1 = EncoderBlock(in_channels, 16) + self.encoder_block2 = EncoderBlock(16, 64) + self.encoder_block3 = EncoderBlock(64, 256) + self.encoder_block4 = EncoderBlock(256, 1024) + self.middle = EncoderBlock(1024, 1024) + self.decoder_block1 = DecoderBlock(1024, 256) + self.decoder_block2 = DecoderBlock(256, 64) + self.decoder_block3 = DecoderBlock(64, 16) + self.decoder_block4 = DecoderBlock(16, 16) + + self.fc = nn.Linear(freq_dim*16, out_channel) + + def forward(self, x_ori): + """ + Args: + complex_sp: (batch_size, channels_num, time_steps, freq_bins),复数张量 + + Returns: + output: (batch_size, channels_num, time_steps, freq_bins),复数张量 + """ + + + x= self.process_image(x_ori) + x1, latent1 = self.encoder_block1(x) + x2, latent2 = self.encoder_block2(x1) + x3, latent3 = self.encoder_block3(x2) + x4, latent4 = self.encoder_block4(x3) + _, h = self.middle(x4) + x5 = self.decoder_block1(h, latent4) + x6 = self.decoder_block2(x5, latent3) + x7 = self.decoder_block3(x6, latent2) + x8 = self.decoder_block4(x7, latent1) + x= self.unprocess_image(x8,x_ori.shape[2]) + x = x.permute(0, 2, 1, 3).contiguous() # 将形状变为 [6, 256, 16, 1024] + x = x.view(x.size(0), x.size(1), -1) + x= self.fc(x) + + return x + + def process_image(self, x): + """ + 处理频谱以便可以被 downsample_ratio 整除。 + + Args: + x: (B, C, T, F) + + Returns: + output: (B, C, 
T_padded, F_reduced) + """ + + B, C, T, Freq = x.shape + + pad_len = ( + int(np.ceil(T / self.downsample_ratio)) * self.downsample_ratio + - T + ) + x = F.pad(x, pad=(0, 0, 0, pad_len)) + + output = x[:, :, :, 0 : Freq - 1] + + return output + + def unprocess_image(self, x,time_steps): + """ + 恢复频谱到原始形状。 + + Args: + x: (B, C, T_padded, F_reduced) + + Returns: + output: (B, C, T_original, F_original) + """ + x = F.pad(x, pad=(0, 1)) + + output = x[:, :,0:time_steps, :] + + return output + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=(3, 3)): + super(ConvBlock, self).__init__() + + padding = [kernel_size[0] // 2, kernel_size[1] // 2] + + self.bn1 = nn.BatchNorm2d(in_channels) + self.bn2 = nn.BatchNorm2d(out_channels) + + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + bias=False, + ) + + self.conv2 = nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + bias=False, + ) + + if in_channels != out_channels: + self.shortcut = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(1, 1), + padding=(0, 0), + ) + self.is_shortcut = True + else: + self.is_shortcut = False + + def forward(self, x): + h = self.conv1(F.leaky_relu_(self.bn1(x))) + h = self.conv2(F.leaky_relu_(self.bn2(h))) + + if self.is_shortcut: + return self.shortcut(x) + h + else: + return x + h + + +def test_unet(): + # 定义输入参数 + batch_size = 6 + channels = 1 # 音频通道数 + time_steps = 256 # 时间步数 + freq_bins = 1024 # 频率 bins 数 + + # 创建一个随机的复数张量作为输入 + real_part = torch.randn(batch_size, channels, time_steps, freq_bins) + imag_part = torch.randn(batch_size, channels, time_steps, freq_bins) + complex_sp = real_part #torch.complex(real_part, imag_part) + + # 实例化 UNet 模型 + model = UNet() + + # 前向传播 + output = model(complex_sp) + + # 输出输入和输出的形状 + print("输入形状:", complex_sp.shape) + print("输出形状:", output.shape) 
    # The model consumes and produces real-valued tensors.
    assert not torch.is_complex(output), "output should be real-valued"

    # fc maps (B, T, 16*freq_dim) -> (B, T, 1024), so output is (B, T, 1024)
    assert output.shape == (batch_size, time_steps, 1024), "unexpected output shape"