diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ac0c8d74a4d21f5d400e8a809a435b15af69efc
--- /dev/null
+++ b/app.py
@@ -0,0 +1,519 @@
+import os
+import traceback
+
+import gradio as gr
+import numpy as np
+import pyrootutils
+import torch
+from loguru import logger
+from transformers import AutoTokenizer
+from vllm import LLM, SamplingParams, TokensPrompt
+from funasr_onnx import Paraformer
+from huggingface_hub import snapshot_download
+
+from tools.wer import compute_wers
+
+os.environ["EINX_FILTER_TRACEBACK"] = "false"
+os.environ["VLLM_USE_V1"] = "0"
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+from i18n import i18n
+from text.chn_text_norm.text import Text as ChnNormedText
+from xcodec2.modeling_xcodec2 import XCodec2Model
+
+
+TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
+
+# ===== Hugging Face Model IDs =====
+LLASA_MODEL_ID = "ASLP-lab/VoiceSculptor"
+LLASA_SUBFOLDER = "LLaSA-Instruct-3B"
+XCODEC_MODEL_ID = "HKUSTAudio/xcodec2"
+PARAFORMER_REPO_ID = "funasr/Paraformer-large"
+
+# logo
+LOGO_URL = "https://raw.githubusercontent.com/ASLP-lab/VoiceSculptor/main/assets/logo.png"
+
+
+def normalize_text_final(user_input: str) -> str:
+ return ChnNormedText(raw_text=user_input).normalize()
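+
+
+# Note: normalization expands digits, dates, units, etc. into Chinese readings
+# via chn_text_norm; the exact output depends on that library's rules.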
+
+
+def extract_speech_ids(speech_tokens_str):
+ speech_ids = []
+ for token_str in speech_tokens_str:
+ if token_str.startswith("<|s_") and token_str.endswith("|>"):
+ num_str = token_str[4:-2]
+ speech_ids.append(int(num_str))
+ else:
+ logger.warning(f"Unexpected token: {token_str}")
+ return speech_ids
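+
+
+# Illustrative example: extract_speech_ids(["<|s_12|>", "<|s_345|>"]) -> [12, 345]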
+
+
+def get_asr(asr_model: Paraformer, wav_list: list[np.ndarray]) -> list[str]:
+    """wav_list: list of 1D numpy waveforms (16 kHz)."""
+
+    def recognize_one(wav: np.ndarray) -> str:
+        """Per-item fallback recognition; returns "" on any failure."""
+        try:
+            r = asr_model(wav)
+            if isinstance(r, list) and len(r) > 0:
+                r = r[0]
+            if isinstance(r, dict):
+                preds = r.get("preds", None)
+                return preds[0] if preds else r.get("text", "")
+            return ""
+        except Exception:
+            return ""
+
+    try:
+        result = asr_model(wav_list)
+        if isinstance(result, dict):
+            result = [result]
+
+        texts = []
+        for res in result:
+            preds = res.get("preds", None)
+            if preds is None:
+                texts.append(res.get("text", ""))
+            else:
+                texts.append(preds[0] if len(preds) > 0 else "")
+
+        # Fault tolerance: if the batch result count does not match, fall back to per-item recognition.
+        if len(texts) != len(wav_list):
+            logger.warning(f"[ASR] batch count mismatch: got {len(texts)}, expected {len(wav_list)}; falling back to per-item recognition")
+            texts = [recognize_one(w) for w in wav_list]
+        return texts
+
+    except Exception as e:
+        logger.warning(f"[ASR] batch recognition failed, falling back to per-item: {e}")
+        return [recognize_one(w) for w in wav_list]
+
+
+def inference_batch(
+ model: LLM,
+ codec_model: XCodec2Model,
+ device: str,
+ tokenizer: AutoTokenizer,
+ refined_text: str,
+ instruct_text: str,
+ control_tags: str,
+ batch_size: int = 5,
+) -> list[tuple[int, np.ndarray]]:
+ refined_text_norm = normalize_text_final(refined_text)
+ instruct_text_norm = normalize_text_final(instruct_text)
+
+    if len(refined_text_norm) < 5:
+        raise ValueError("The input text must be at least 5 characters long")
+    if len(refined_text_norm) > 150:
+        raise ValueError("The input text must not exceed 150 characters")
+
+ target_text = instruct_text_norm + "<|endofprompt|>" + control_tags + refined_text_norm
+ formatted_text = f"<|TEXT_UNDERSTANDING_START|>{target_text}<|TEXT_UNDERSTANDING_END|>"
+ chat = [
+ {"role": "user", "content": "Convert the text to speech:" + formatted_text},
+ {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
+ ]
+
+ with torch.no_grad():
+ input_ids = tokenizer.apply_chat_template(
+ chat,
+ tokenize=True,
+ return_tensors="pt",
+ continue_final_message=True,
+ ).to(device)
+
+ speech_end_id = tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
+ prompt_ids = input_ids.squeeze(0).tolist()
+ prompts = [TokensPrompt(prompt_token_ids=prompt_ids) for _ in range(batch_size)]
+
+ base_seed = int.from_bytes(os.urandom(4), "little")
+
+ try:
+ sampling_params_list = [
+ SamplingParams(
+ temperature=0.9,
+ top_p=0.95,
+ top_k=15,
+ max_tokens=2048,
+ repetition_penalty=1.05,
+ stop_token_ids=[speech_end_id],
+ seed=base_seed + i,
+ )
+ for i in range(batch_size)
+ ]
+ outputs = model.generate(prompts=prompts, sampling_params=sampling_params_list)
+ except TypeError:
+        logger.warning("[vLLM] This version does not support SamplingParams(seed=...); generating without seeds")
+ sampling_params = SamplingParams(
+ temperature=0.9,
+ top_p=0.95,
+ top_k=15,
+ max_tokens=2048,
+ repetition_penalty=1.05,
+ stop_token_ids=[speech_end_id],
+ )
+ outputs = model.generate(prompts=prompts, sampling_params=sampling_params)
+
+ audios: list[tuple[int, np.ndarray]] = []
+ for out in outputs:
+ token_ids = out.outputs[0].token_ids
+ if len(token_ids) > 0 and token_ids[-1] == speech_end_id:
+ token_ids = token_ids[:-1]
+
+ speech_tokens = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
+ speech_tokens = extract_speech_ids(speech_tokens)
+
+ speech_tokens_t = torch.tensor(speech_tokens, device=device).unsqueeze(0).unsqueeze(0)
+ wav = codec_model.decode_code(speech_tokens_t)
+ wav = wav.squeeze(0).squeeze(0).detach().cpu().numpy().astype(np.float32)
+ audios.append((16000, wav))
+
+ return audios
+
+
+def build_app():
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ logger.info(f"✅ Loading models on device={device}")
+
+ # ===== LLaSA =====
+ tokenizer = AutoTokenizer.from_pretrained(LLASA_MODEL_ID, subfolder=LLASA_SUBFOLDER, trust_remote_code=True)
+
+ model = LLM(
+ model=LLASA_MODEL_ID,
+ gpu_memory_utilization=0.90,
+ max_model_len=2048,
+ enable_prefix_caching=True,
+ dtype="auto",
+ quantization=None,
+ enforce_eager=False,
+ kv_cache_dtype="auto",
+ trust_remote_code=True,
+ hf_model_subfolder=LLASA_SUBFOLDER,
+ )
+
+ # ===== XCodec2 =====
+ codec_model = XCodec2Model.from_pretrained(XCODEC_MODEL_ID).eval().to(device)
+
+ # ===== Paraformer =====
+ paraformer_dir = snapshot_download(
+ repo_id=PARAFORMER_REPO_ID,
+ local_dir="checkpoints/Paraformer-large",
+ local_dir_use_symlinks=False,
+ )
+ asr_model = Paraformer(paraformer_dir, batch_size=5, quantize=True)
+
+ logger.info("✅ Models loaded: VoiceSculptor + xcodec2 + Paraformer")
+
+ INSTRUCT_TEMPLATES = {
+ "自定义": "",
+ "default": "这是一位男性评书表演者,用传统说唱腔调,以变速节奏和韵律感极强的语速讲述江湖故事,音量时高时低,充满江湖气。",
+ "幼儿园女教师-温柔甜美": "这是一位幼儿园女教师,用甜美明亮的嗓音,以极慢且富有耐心的语速,带着温柔鼓励的情感,用标准普通话给小朋友讲睡前故事,音量轻柔适中,咬字格外清晰。",
+ "电台主播-平静温柔": "深夜电台主播,男性、音调偏低、语速偏慢、音量小;情绪平静带点忧伤,语气温柔;音色微哑",
+ "成熟御姐-冷静坚定": "成熟御姐风格,音调偏低、语速正常、音量中等;情绪冷静,语气不容置疑的坚定;音色偏磁性,吐字清晰",
+ "年轻妈妈-温暖安抚": "年轻妈妈哄孩子入睡,女性、音调柔和偏低、语速偏慢、音量偏小但清晰;情绪温暖安抚、充满耐心与爱意,语气轻柔哄劝、像贴近耳边低声说话;音色软糯,吐字清晰、节奏舒缓。",
+ "小女孩-尖锐清脆": "一位7岁的小女孩,用天真高亢的童声,以不稳定的快节奏,充满兴奋和炫耀地背诵乘法口诀,音调忽高忽低,带着儿童特有的尖锐清脆。",
+ "老奶奶-沙哑低沉": "一位慈祥的老奶奶,用沙哑低沉的嗓音,以极慢而温暖的语速讲述民间传说,音量微弱但清晰,带着怀旧和神秘的情感。",
+ "诗歌朗诵-雄浑有力": "一位男性现代诗朗诵者,用深沉磁性的低音,以顿挫有力的节奏演绎艾青诗歌,音量洪亮,情感激昂澎湃。",
+ "童话风格-甜美夸张": "这是一位女性童话旁白朗诵者,用甜美夸张的童声,以跳跃变化的语速讲述《安徒生童话》,音调偏高,充满奇幻色彩。",
+ "评书风格-抑扬顿挫": "这是一位男性评书表演者,用传统说唱腔调,以变速节奏和韵律感极强的语速讲述江湖故事,音量时高时低,充满江湖气。",
+ "新闻风格-平静专业": "这是一位女性新闻主播,用标准普通话以清晰明亮的中高音,以平稳专业的语速播报时事新闻,音量洪亮,情感客观中立。",
+ "相声风格-夸张幽默": "这是一位男性相声表演者,用夸张幽默的嗓音,以时快时慢的节奏抖包袱,音调起伏大,充满喜感和节奏感。",
+ "游戏直播-亢奋激昂": "这是一位男性游戏解说,用亢奋激昂的嗓音,以极快且情绪化的语速直播电竞比赛,音量突然爆发,充满悬念和热血。",
+ "悬疑小说-低沉神秘": "一位男性悬疑小说演播者,用低沉神秘的嗓音,以时快时慢的变速节奏营造紧张氛围,音量忽高忽低,充满悬念感。",
+ "戏剧表演-夸张戏剧": "这是一位男性戏剧表演者,用夸张戏剧化的嗓音,以忽高忽低的音调和时快时慢的语速表演独白,充满张力。",
+ "法治节目-庄严庄重": "这是一位男性法治节目主持人,用严肃庄重的嗓音,以平稳有力的语速讲述案件,音量适中,体现法律的威严。",
+ "纪录片旁白-低沉磁性": "这是一位男性纪录片旁白,用深沉磁性的嗓音,以缓慢而富有画面感的语速讲述自然奇观,音量适中,充满敬畏和诗意。",
+ "广告配音-沧桑浑厚": "这是一位男性白酒品牌广告配音,用沧桑浑厚的嗓音,以缓慢而豪迈的语速,音量洪亮,传递历史底蕴和男人情怀。",
+ "冥想引导师-空灵悠长": "一位女性冥想引导师,用空灵悠长的气声,以极慢而飘渺的语速,配合环境音效,音量轻柔,营造禅意空间。",
+ "ASMR-气声耳语": "一位女性ASMR主播,用气声耳语,以极慢而细腻的语速,配合唇舌音,音量极轻,营造极度放松的氛围。",
+ }
+
+ TEXT_REQUIREMENTS = {
+ "自定义": "",
+ "default": "话说那武松,提着哨棒,直奔景阳冈。天色将晚,酒劲上头,只听一阵狂风,老虎来啦!",
+ "幼儿园女教师-温柔甜美": "月亮婆婆升上天空啦,星星宝宝都困啦。小白兔躺在床上,盖好小被子,闭上眼睛。兔妈妈轻轻地唱着摇篮曲:睡吧睡吧,我亲爱的宝贝。",
+ "电台主播-平静温柔": "大家好,欢迎收听你的月亮我的心,好男人就是我,我就是:曾小贤。",
+ "成熟御姐-冷静坚定": "别担心,我不会让你输,把那些乱七八糟的念头先收起来,姐姐带你赢。",
+ "年轻妈妈-温暖安抚": "从前有座山,山里有座庙,庙里面有个小和尚,小和尚在给老和尚讲故事,他说:从前有座山,山里有座庙,庙里面有个小和尚。",
+ "小女孩-尖锐清脆": "一一得一!一二得二!一三得三!我会背乘法口诀啦!老师今天表扬我啦!妈妈说我最棒!",
+ "老奶奶-沙哑低沉": "很久很久以前,在山的那边,住着一只会说话的狐狸。它常常在月圆之夜,变成美丽的姑娘,来到村子里。",
+ "诗歌朗诵-雄浑有力": "为什么我的眼里常含泪水?因为我对这土地爱得深沉。这土地,这河流,这吹刮着的暴风。",
+ "童话风格-甜美夸张": "在一个很冷很冷的夜晚,小女孩擦亮了一根火柴。突然,温暖的火炉出现了!她觉得自己好像坐在火炉旁。",
+ "评书风格-抑扬顿挫": "话说那武松,提着哨棒,直奔景阳冈。天色将晚,酒劲上头,只听一阵狂风,老虎来啦!",
+ "新闻风格-平静专业": "本台讯,今日凌晨,我国成功发射新一代载人飞船试验船。此次任务验证了多项关键技术,为后续空间站建设奠定基础。",
+ "相声风格-夸张幽默": "我这个人啊,最大的优点就是太谦虚。谦虚到什么程度?连谦虚本身都觉得我太谦虚了!",
+ "游戏直播-亢奋激昂": "大招!大招好了!开团了!ACE!团灭!这波操作神了!冠军相尽显无疑!",
+ "悬疑小说-低沉神秘": "深夜,他独自走在空无一人的小巷。脚步声,回声,还有……另一个人的呼吸声。他猛地回头——什么也没有。",
+ "戏剧表演-夸张戏剧": "我疯了!彻底疯了!你们都说我疯了!可疯的是这个世界!清醒的人反而被当成疯子!",
+ "法治节目-庄严庄重": "天网恢恢,疏而不漏。任何触犯法律的行为,终将受到公正的审判。正义或许会迟到,但绝不会缺席。",
+ "纪录片旁白-低沉磁性": "在这片广袤的非洲草原上,生命与死亡每天都在上演。猎豹的速度,羚羊的敏捷,都是生存的代价。",
+ "广告配音-沧桑浑厚": "一杯敬过往,一杯敬远方。传承千年的酿造工艺,只在每一滴醇香。老朋友,值得好酒。",
+ "冥想引导师-空灵悠长": "想象你是一片叶子,随风飘落。没有牵挂,没有重量。只有呼吸,只有当下,只有宁静。",
+ "ASMR-气声耳语": "现在,让我在你耳边轻声细语。听到我的声音了吗?放松你的头皮,感受每一个毛孔都在呼吸。",
+ }
+
+ def build_control_tags(age, gender, pitch, pitch_var, volume, speed, emo):
+ tag_map = {
+ "小孩": "<|小孩|>", "青年": "<|青年|>", "中年": "<|中年|>", "老年": "<|老年|>",
+ "男性": "<|男性|>", "女性": "<|女性|>",
+ "音调很高": "<|音调很高|>", "音调较高": "<|音调较高|>", "音调中等": "<|音调中等|>",
+ "音调较低": "<|音调较低|>", "音调很低": "<|音调很低|>",
+ "音调变化很强": "<|音调变化很强|>", "音调变化较强": "<|音调变化较强|>", "音调变化一般": "<|音调变化一般|>",
+ "音调变化较弱": "<|音调变化较弱|>", "音调变化很弱": "<|音调变化很弱|>",
+ "音量很大": "<|音量很大|>", "音量较大": "<|音量较大|>", "音量中等": "<|音量中等|>",
+ "音量较小": "<|音量较小|>", "音量很小": "<|音量很小|>",
+ "语速很快": "<|语速很快|>", "语速较快": "<|语速较快|>", "语速中等": "<|语速中等|>",
+ "语速较慢": "<|语速较慢|>", "语速很慢": "<|语速很慢|>",
+ "开心": "<|开心|>", "生气": "<|生气|>", "难过": "<|难过|>", "惊讶": "<|惊讶|>", "厌恶": "<|厌恶|>", "害怕": "<|害怕|>",
+ }
+ tags = []
+ for v in [gender, age, speed, volume, pitch, pitch_var, emo]:
+ if v != "不指定":
+ tags.append(tag_map[v])
+ return "".join(tags)
+
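+    # Illustrative example (hypothetical values):
+    #   build_control_tags("青年", "女性", "不指定", "不指定", "不指定", "语速较慢", "开心")
+    #   -> "<|女性|><|青年|><|语速较慢|><|开心|>"
+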
+ def inference_select_best3(refined_text, instruct_text, age, gender, pitch, pitch_var, volume, speed, emo):
+ control_tags = build_control_tags(age, gender, pitch, pitch_var, volume, speed, emo)
+ try:
+ audios5 = inference_batch(
+ model=model,
+ codec_model=codec_model,
+ device=device,
+ tokenizer=tokenizer,
+ refined_text=refined_text,
+ instruct_text=instruct_text,
+ control_tags=control_tags,
+ batch_size=5,
+ )
+ wav_list = [wav for (_, wav) in audios5]
+ asr_texts = get_asr(asr_model, wav_list)
+
+ refined_text_norm = normalize_text_final(refined_text)
+ gt_texts = [refined_text_norm] * len(asr_texts)
+ wers = compute_wers(gt_texts, asr_texts, lang="zh")
+
+ for i, (hyp, w) in enumerate(zip(asr_texts, wers)):
+ logger.info(f"[ASR/WER] idx={i} wer={w:.4f} gt='{refined_text_norm}' asr='{hyp}'")
+
+ best_idx = np.argsort(np.array(wers))[:3].tolist()
+ logger.info(f"[ASR/WER] best_idx={best_idx} best_wers={[float(wers[i]) for i in best_idx]}")
+ best3 = [audios5[i] for i in best_idx]
+ return best3[0], best3[1], best3[2]
+ except Exception as e:
+            # loguru has no exc_info kwarg; log the message, then the traceback explicitly
+            logger.error(f"Inference/ASR/WER failed: {e}")
+            logger.error("Traceback:\n" + traceback.format_exc())
+ return None, None, None
+
+ THEME = gr.themes.Soft(
+ primary_hue="orange",
+ secondary_hue="cyan",
+ neutral_hue="slate",
+ )
+
+ CUSTOM_CSS = """
+ /* layout */
+ #vs-root {max-width: 1180px; margin: 0 auto;}
+ #vs-header {padding: 14px 14px 4px 14px;}
+ #vs-card {border-radius: 14px; padding: 14px; border: 1px solid rgba(0,0,0,0.08);}
+
+ /* ===== VoiceSculptor palette (from logo) ===== */
+ :root, .gradio-container {
+ --vs-orange: #FF6A00;
+ --vs-orange2:#FFB000;
+ --vs-teal: #00A6C6;
+ --vs-blue: #0B2E8A;
+ --vs-teal-a: rgba(0,166,198,.18);
+ }
+
+ /* primary button */
+ .gr-button-primary, button.primary {
+ background: linear-gradient(90deg, var(--vs-orange), var(--vs-orange2)) !important;
+ border: none !important;
+ color: white !important;
+ }
+ .gr-button-primary:hover, button.primary:hover {
+ filter: brightness(1.03);
+ }
+ .gr-button-primary:active, button.primary:active {
+ filter: brightness(0.98);
+ }
+
+ /* links */
+ .gradio-container a {
+ color: var(--vs-teal) !important;
+ }
+ .gradio-container a:hover {
+ text-decoration: underline;
+ }
+
+ /* focus ring / active border for inputs */
+ textarea:focus, input:focus {
+ border-color: var(--vs-teal) !important;
+ box-shadow: 0 0 0 3px var(--vs-teal-a) !important;
+ outline: none !important;
+ }
+ /* some gradio versions wrap inputs in these */
+ .gr-input:focus-within, .gr-text-input:focus-within, .gr-box:focus-within {
+ border-color: var(--vs-teal) !important;
+ box-shadow: 0 0 0 3px var(--vs-teal-a) !important;
+ }
+
+ /* accordion highlight */
+ .gr-accordion .label, .gr-accordion summary {
+ color: var(--vs-blue) !important;
+ }
+ """
+
+ DEFAULT_STYLE = "评书风格-抑扬顿挫"
+ template_choices = [k for k in INSTRUCT_TEMPLATES.keys() if k not in ("default",)]
+
+    BEST_PRACTICE_MD = """
+## Best Practice Guide (Voice Design)
+
+Full guide: Voice Design README
+https://github.com/ASLP-lab/VoiceSculptor/blob/main/docs/voice_design.md
+
+### Key constraints
+- **voice_prompt ≤ 200 characters**
+- **Only Chinese is supported for now**
+- **Text to synthesize must be ≥ 5 characters**
+
+### Writing tips
+- **Be specific**: use perceivable vocal traits (deep/crisp/husky/bright, speaking rate, volume, etc.); avoid vague praise such as "sounds nice".
+- **Be complete**: cover **3-4 dimensions** (persona/scene + gender/age + pitch/speed + timbre/emotion).
+- **Be objective**: describe vocal characteristics and delivery; avoid "I like it / it's great".
+- **No imitation**: never write "sounds like some celebrity/actor"; describe only the voice traits themselves.
+- **Be concise**: every word should carry information; avoid redundant emphasis (e.g. "very very").
+
+### Reference templates
+> - 这是一位男性评书表演者,用传统说唱腔调,以变速节奏和韵律感极强的语速讲述江湖故事,音量时高时低,充满江湖气。
+> - 深夜电台主播,男性、音调偏低、语速偏慢、音量小;情绪平静带点忧伤,语气温柔;音色微哑。
+> - 成熟御姐风格,音调偏低、语速正常、音量中等;情绪冷静,语气不容置疑的坚定;音色偏磁性,吐字清晰。
+
+### Fine-grained control tips
+- The fine-grained controls (age/gender/pitch/speed/volume/emotion) **should stay consistent with the instruction text**; avoid contradictions (e.g. an instruction saying "deep and slow" while the controls select "very high pitch / very fast speech").
+"""
+
+ with gr.Blocks(theme=THEME, css=CUSTOM_CSS) as app:
+ with gr.Column(elem_id="vs-root"):
+ with gr.Row(elem_id="vs-header"):
+                gr.HTML(f"""
+                <div style="display: flex; align-items: center; gap: 14px;">
+                    <img src="{LOGO_URL}" alt="VoiceSculptor logo" style="height: 64px;" />
+                    <div>
+                        <h1 style="margin: 0;">Voice Sculptor</h1>
+                        <p style="margin: 4px 0 0 0;">
+                            {i18n('An instruct text-to-speech solution based on LLaSA and CosyVoice2 developed by the ASLP lab and collaborators.')}
+                        </p>
+                    </div>
+                </div>
+                """)
+
+ with gr.Row():
+ # Left: Controls + Guide
+ with gr.Column(scale=5, elem_id="vs-card"):
+ gr.Markdown("### 🪄 Voice Design(捏音色)")
+
+ with gr.Accordion("🎭 风格与文本", open=True):
+ instruct_template = gr.Dropdown(
+ choices=template_choices,
+ value=DEFAULT_STYLE,
+ label=i18n("指令风格(必选)"),
+ interactive=True,
+ )
+
+ instruct_text = gr.Textbox(
+ label=i18n("指令文本"),
+ placeholder=TEXTBOX_PLACEHOLDER,
+ lines=4,
+ value=INSTRUCT_TEMPLATES.get(DEFAULT_STYLE, INSTRUCT_TEMPLATES["default"]),
+ )
+
+ text = gr.Textbox(
+ label=i18n("待合成文本"),
+ placeholder=TEXTBOX_PLACEHOLDER,
+ lines=4,
+ value=TEXT_REQUIREMENTS.get(DEFAULT_STYLE, TEXT_REQUIREMENTS["default"]),
+ )
+
+ with gr.Accordion("🎛️ 细粒度声音控制(可选)", open=False):
+ with gr.Row():
+ age_ctrl = gr.Dropdown(label="年龄", choices=["不指定", "小孩", "青年", "中年", "老年"], value="不指定")
+ gender_ctrl = gr.Dropdown(label="性别", choices=["不指定", "男性", "女性"], value="不指定")
+
+ with gr.Row():
+ pitch_ctrl = gr.Dropdown(
+ label="音调高度",
+ choices=["不指定", "音调很高", "音调较高", "音调中等", "音调较低", "音调很低"],
+ value="不指定",
+ )
+ pitch_var_ctrl = gr.Dropdown(
+ label="音调变化",
+ choices=["不指定", "音调变化很强", "音调变化较强", "音调变化一般", "音调变化较弱", "音调变化很弱"],
+ value="不指定",
+ )
+
+ with gr.Row():
+ volume_ctrl = gr.Dropdown(
+ label="音量",
+ choices=["不指定", "音量很大", "音量较大", "音量中等", "音量较小", "音量很小"],
+ value="不指定",
+ )
+ speed_ctrl = gr.Dropdown(
+ label="语速",
+ choices=["不指定", "语速很快", "语速较快", "语速中等", "语速较慢", "语速很慢"],
+ value="不指定",
+ )
+
+ emo_ctrl = gr.Dropdown(
+ label="情感",
+ choices=["不指定", "开心", "生气", "难过", "惊讶", "厌恶", "害怕"],
+ value="不指定",
+ )
+
+ with gr.Accordion("📚 Best Practice Guide", open=False):
+ gr.Markdown(BEST_PRACTICE_MD)
+
+ def apply_template(tpl_name):
+ return INSTRUCT_TEMPLATES.get(tpl_name, ""), TEXT_REQUIREMENTS.get(tpl_name, "")
+
+ instruct_template.change(apply_template, inputs=[instruct_template], outputs=[instruct_text, text])
+
+ # Right: Results + Generate
+ with gr.Column(scale=5, elem_id="vs-card"):
+ gr.Markdown("### 🎵 Results")
+ generate = gr.Button("🎧 Generate", variant="primary")
+ audio_output1 = gr.Audio(label=i18n("Generated Audio 1"), type="numpy", interactive=False)
+ audio_output2 = gr.Audio(label=i18n("Generated Audio 2"), type="numpy", interactive=False)
+ audio_output3 = gr.Audio(label=i18n("Generated Audio 3"), type="numpy", interactive=False)
+
+ generate.click(
+ fn=inference_select_best3,
+ inputs=[text, instruct_text, age_ctrl, gender_ctrl, pitch_ctrl, pitch_var_ctrl, volume_ctrl, speed_ctrl, emo_ctrl],
+ outputs=[audio_output1, audio_output2, audio_output3],
+ )
+
+ return app
+
+
+if __name__ == "__main__":
+ demo = build_app()
+ demo.launch()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bca7cc2412345f8e728453fa1ddf17cbf15959a1
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,14 @@
+gradio
+numpy
+torch
+transformers
+vllm
+funasr-onnx
+huggingface_hub
+jiwer
+zhon
+loguru
+pyrootutils
+jieba
+torchtune
+torchao
+vector_quantize_pytorch
\ No newline at end of file
diff --git a/tools/__pycache__/wer.cpython-310.pyc b/tools/__pycache__/wer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc3291d706a254b003ed995c1e4ef69d6f7a7235
Binary files /dev/null and b/tools/__pycache__/wer.cpython-310.pyc differ
diff --git a/tools/wer.py b/tools/wer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a931843906d79947e0348729d5ab3f15d22f2189
--- /dev/null
+++ b/tools/wer.py
@@ -0,0 +1,59 @@
+# tools/wer.py
+from __future__ import annotations
+
+from typing import List, Tuple
+import string
+
+from jiwer import process_words
+from zhon.hanzi import punctuation as zh_punctuation
+
+# Chinese punctuation + English punctuation + '-'
+_PUNCTUATION_ALL = zh_punctuation + string.punctuation + "-"
+
+
+def _normalize_pair(gt: str, gen: str, lang: str) -> Tuple[str, str]:
+ gt = "" if gt is None else str(gt)
+ gen = "" if gen is None else str(gen)
+
+    # Strip punctuation (keep "'")
+ for x in _PUNCTUATION_ALL:
+ if x == "'":
+ continue
+ gt = gt.replace(x, "")
+ gen = gen.replace(x, "")
+
+    # Normalize whitespace and hyphens ("\u3000" is the full-width CJK space)
+    gt = gt.replace("\u3000", " ").replace("-", " ")
+    gen = gen.replace("\u3000", " ").replace("-", " ")
+
+ if lang == "zh":
+        # Treat each character as a token
+        gt = " ".join(gt)
+        gen = " ".join(gen)
+ elif lang == "en":
+ gt = gt.lower()
+ gen = gen.lower()
+ else:
+ raise NotImplementedError("lang must be 'zh' or 'en'")
+
+ return gt, gen
+
+
+def compute_wers(gt_texts: List[str], gen_texts: List[str], lang: str = "zh") -> List[float]:
+ if len(gt_texts) != len(gen_texts):
+ raise ValueError(f"Length mismatch: {len(gt_texts)} != {len(gen_texts)}")
+
+ wers: List[float] = []
+ for gt_raw, gen_raw in zip(gt_texts, gen_texts):
+ gt_norm, gen_norm = _normalize_pair(gt_raw, gen_raw, lang=lang)
+ measures = process_words(reference=gt_norm, hypothesis=gen_norm)
+ wers.append(float(measures.wer))
+ return wers
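+
+# Note: jiwer computes WER = (S + D + I) / N, where S, D, and I count
+# substitutions, deletions, and insertions against the reference, and N is the
+# reference length (characters for "zh", words for "en" after normalization).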
+
+
+if __name__ == "__main__":
+ gt = ["你好世界啊", "今天天气不对", "abc-def"]
+ gen = ["你好,世界!", "今天 天气 不错", "abc def"]
+ print(compute_wers(gt, gen, lang="zh"))
+ print(compute_wers(["Hello World"], ["hello, world!"], lang="en"))
diff --git a/xcodec2/.gitattributes b/xcodec2/.gitattributes
new file mode 100755
index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b
--- /dev/null
+++ b/xcodec2/.gitattributes
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
diff --git a/xcodec2/.vscode/settings.json b/xcodec2/.vscode/settings.json
new file mode 100755
index 0000000000000000000000000000000000000000..a8c200329368ddfdbe180fdbc2deda24ed7a9ce4
--- /dev/null
+++ b/xcodec2/.vscode/settings.json
@@ -0,0 +1,5 @@
+{
+ "python-envs.defaultEnvManager": "ms-python.python:conda",
+ "python-envs.defaultPackageManager": "ms-python.python:conda",
+ "python-envs.pythonProjects": []
+}
\ No newline at end of file
diff --git a/xcodec2/README.md b/xcodec2/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..62dad23a13f625c72f95a3977bee90e18d19622f
--- /dev/null
+++ b/xcodec2/README.md
@@ -0,0 +1,69 @@
+---
+license: cc-by-nc-4.0
+tags:
+- audio-to-audio
+pipeline_tag: audio-to-audio
+---
+
+[arXiv:2502.04128](https://arxiv.org/abs/2502.04128)
+**Update (2025-02-13):** Added the [LLaSA fine-tuning instructions](https://github.com/zhenye234/LLaSA_training/tree/main/finetune).
+
+**Update (2025-02-07):** Our paper has been released!
+
+
+## Paper
+
+LLaSA: Scaling Train Time and Inference Time Compute for LLaMA based Speech Synthesis
+
+Codec Does Matter: Exploring the Semantic Shortcoming of Codec for Audio Language Model (AAAI 2025, xcodec 1.0)
+
+
+# Getting Started with XCodec2 on Hugging Face
+XCodec2 is a speech tokenizer that offers the following key features:
+
+1. **Single Vector Quantization**
+2. **50 Tokens per Second**
+3. **Multilingual Speech Semantic Support and High-Quality Speech Reconstruction**
+
+
+To use `xcodec2`, ensure you have it installed. You can install it using the following command:
+
+```bash
+conda create -n xcodec2 python=3.9
+conda activate xcodec2
+pip install xcodec2
+```
+
+Use `xcodec2==0.1.5` for codec inference and LLaSA fine-tuning: unnecessary dependencies have been removed, and it works in my testing, though other problems may still arise. If you prefer more stability, use `xcodec2==0.1.3`, which exactly matches the setup used during my codec training.
+
+Then,
+```python
+import torch
+import soundfile as sf
+from transformers import AutoConfig
+
+
+from xcodec2.modeling_xcodec2 import XCodec2Model
+
+model_path = "HKUSTAudio/xcodec2"
+
+model = XCodec2Model.from_pretrained(model_path)
+model.eval().cuda()
+
+
+wav, sr = sf.read("test.wav")
+wav_tensor = torch.from_numpy(wav).float().unsqueeze(0) # Shape: (1, T)
+
+
+with torch.no_grad():
+    # Only 16 kHz speech is supported.
+    # Only a single input is supported; for batch inference, see the link below.
+    vq_code = model.encode_code(input_waveform=wav_tensor)
+    print("Code:", vq_code)
+
+ recon_wav = model.decode_code(vq_code).cpu() # Shape: (1, 1, T')
+
+
+sf.write("reconstructed.wav", recon_wav[0, 0, :].numpy(), sr)
+print("Done! Check reconstructed.wav")
+```
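+
+The codec only accepts 16 kHz mono input. If your audio has a different sample rate or is stereo, below is a minimal preprocessing sketch; it assumes `torchaudio` is available (not an xcodec2 dependency):
+
+```python
+import soundfile as sf
+import torch
+import torchaudio
+
+wav, sr = sf.read("test.wav")
+wav_tensor = torch.from_numpy(wav).float()
+if wav_tensor.dim() > 1:
+    wav_tensor = wav_tensor.mean(dim=-1)  # downmix stereo to mono
+wav_tensor = wav_tensor.unsqueeze(0)  # shape: (1, T)
+if sr != 16000:
+    # resample to the 16 kHz rate the codec expects
+    wav_tensor = torchaudio.functional.resample(wav_tensor, orig_freq=sr, new_freq=16000)
+# wav_tensor can now be passed to model.encode_code(...)
+```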
+
+If you want to train your own XCodec2, run batch inference, or do large-scale code extraction, the code is released [here](https://github.com/zhenye234/X-Codec-2.0).
\ No newline at end of file
diff --git a/xcodec2/__init__.py b/xcodec2/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/xcodec2/__pycache__/__init__.cpython-310.pyc b/xcodec2/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5d53d148b6f56ba74300128c778774d3ee3c2ef
Binary files /dev/null and b/xcodec2/__pycache__/__init__.cpython-310.pyc differ
diff --git a/xcodec2/__pycache__/configuration_bigcodec.cpython-310.pyc b/xcodec2/__pycache__/configuration_bigcodec.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..992407af71e7ff35b0b1a5eadf2a3d304d2075cc
Binary files /dev/null and b/xcodec2/__pycache__/configuration_bigcodec.cpython-310.pyc differ
diff --git a/xcodec2/__pycache__/configuration_bigcodec.cpython-38.pyc b/xcodec2/__pycache__/configuration_bigcodec.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..c632a2e8cf7379b702ea3522e842e3c437aef32f
Binary files /dev/null and b/xcodec2/__pycache__/configuration_bigcodec.cpython-38.pyc differ
diff --git a/xcodec2/__pycache__/modeling_xcodec2.cpython-310.pyc b/xcodec2/__pycache__/modeling_xcodec2.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..41f284587665be62e83026658f7b64e8c59d298a
Binary files /dev/null and b/xcodec2/__pycache__/modeling_xcodec2.cpython-310.pyc differ
diff --git a/xcodec2/__pycache__/modeling_xcodec2.cpython-38.pyc b/xcodec2/__pycache__/modeling_xcodec2.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..6c6a4c39986e4cd619c09b9588e28f0dc5ad9811
Binary files /dev/null and b/xcodec2/__pycache__/modeling_xcodec2.cpython-38.pyc differ
diff --git a/xcodec2/config.json b/xcodec2/config.json
new file mode 100755
index 0000000000000000000000000000000000000000..5acd94f351833f55f40fd49e4f9da49d34c8540f
--- /dev/null
+++ b/xcodec2/config.json
@@ -0,0 +1,11 @@
+{
+ "model_type": "xcodec2",
+ "semantic_hidden_size": 1024,
+ "codec_encoder_hidden_size": 1024,
+ "codec_decoder_hidden_size": 1024,
+ "use_vocos": true,
+ "architectures": [
+ "XCodec2Model"
+ ]
+ }
+
\ No newline at end of file
diff --git a/xcodec2/configuration_bigcodec.py b/xcodec2/configuration_bigcodec.py
new file mode 100755
index 0000000000000000000000000000000000000000..c7aa72b64eca8cae1220413f47039a112100d866
--- /dev/null
+++ b/xcodec2/configuration_bigcodec.py
@@ -0,0 +1,19 @@
+from transformers import PretrainedConfig
+
+class BigCodecConfig(PretrainedConfig):
+ model_type = "bigcodec"
+
+ def __init__(
+ self,
+        # The hyperparameters below are just examples
+ semantic_hidden_size=1024,
+ codec_encoder_hidden_size=1024,
+ codec_decoder_hidden_size=1024,
+ use_vocos=True,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.semantic_hidden_size = semantic_hidden_size
+ self.codec_encoder_hidden_size = codec_encoder_hidden_size
+ self.codec_decoder_hidden_size = codec_decoder_hidden_size
+ self.use_vocos = use_vocos
diff --git a/xcodec2/modeling_xcodec2.py b/xcodec2/modeling_xcodec2.py
new file mode 100755
index 0000000000000000000000000000000000000000..93b12ec6a6108213e19a3ad87057ee22d17994c2
--- /dev/null
+++ b/xcodec2/modeling_xcodec2.py
@@ -0,0 +1,164 @@
+import torch
+import torch.nn as nn
+from transformers import PreTrainedModel
+from xcodec2.configuration_bigcodec import BigCodecConfig
+
+from xcodec2.vq.codec_encoder import CodecEncoder_Transformer
+from xcodec2.vq.codec_decoder_vocos import CodecDecoderVocos
+from xcodec2.vq.module import SemanticEncoder
+from transformers import AutoFeatureExtractor, Wav2Vec2BertModel
+
+class XCodec2Model(PreTrainedModel):
+ config_class = BigCodecConfig
+
+ def __init__(self, config: BigCodecConfig):
+ super().__init__(config)
+
+        # 1) Semantic model
+ self.semantic_model = Wav2Vec2BertModel.from_pretrained(
+ "facebook/w2v-bert-2.0",
+ output_hidden_states=True
+ )
+ self.semantic_model.eval()
+
+ self.SemanticEncoder_module = SemanticEncoder(
+ config.semantic_hidden_size,
+ config.semantic_hidden_size,
+ config.semantic_hidden_size
+ )
+
+ # 2) Codec Encoder
+ self.CodecEnc = CodecEncoder_Transformer()
+
+ # 3) Codec Decoder
+ self.generator = CodecDecoderVocos()
+
+        # 4) Two fully connected layers
+ self.fc_prior = nn.Linear(2048, 2048)
+ self.fc_post_a = nn.Linear(2048, 1024)
+ feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
+ self.feature_extractor = feature_extractor
+
+ def forward(self, input_waveform, sample_rate=16000):
+        """
+        The core logic does not strictly have to live in a method named
+        ``forward``; it could be split into other methods. However, keeping it
+        in ``forward`` preserves pipeline compatibility.
+
+        Args:
+            input_waveform: [batch_size, waveform_length]
+            sample_rate: defaults to 16000
+        Returns:
+            the reconstructed speech audio (Tensor)
+        """
+        # 1) Feature extraction
+        # (padding could be applied here if needed)
+ input_features = self.feature_extractor(
+ input_waveform,
+ sampling_rate=sample_rate,
+ return_tensors="pt"
+ ).input_features.to(self.device) # [batch, frames, feat_dim]
+
+        # 2) Semantic layer
+        semantic_output = self.semantic_model(input_features)
+        semantic_hidden_16 = semantic_output.hidden_states[16]  # take the 16th hidden layer
+ semantic_hidden_16 = semantic_hidden_16.transpose(1, 2) # [batch, hidden_dim, frames]
+ semantic_encoded = self.SemanticEncoder_module(semantic_hidden_16)
+
+ # 3) codec encoder
+ wav = input_waveform.unsqueeze(1).to(self.device) # shape: [batch, 1, time]
+        vq_emb = self.CodecEnc(wav)  # [batch, time//down, 1024] (illustrative shape)
+ vq_emb = vq_emb.transpose(1, 2) # -> [batch, 1024, frames]
+
+        # Align the time frames with the semantic features (simplified handling;
+        # a real implementation might align the dimensions first)
+        if vq_emb.shape[-1] != semantic_encoded.shape[-1]:
+            # Simply truncating or zero-padding both work; decide per your needs
+ min_len = min(vq_emb.shape[-1], semantic_encoded.shape[-1])
+ vq_emb = vq_emb[:, :, :min_len]
+ semantic_encoded = semantic_encoded[:, :, :min_len]
+
+        # 4) Concatenate
+ concat_emb = torch.cat([semantic_encoded, vq_emb], dim=1) # [batch, 1024 + 1024, frames]
+
+ # 5) fc_prior
+ concat_emb = self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2)
+
+        # 6) Quantization stage of the decoder
+ _, vq_code, _ = self.generator(concat_emb, vq=True)
+ vq_post_emb = self.generator.quantizer.get_output_from_indices(vq_code.transpose(1, 2))
+ vq_post_emb = vq_post_emb.transpose(1, 2)
+
+ # 7) fc_post_a
+ vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(1, 2)
+
+        # 8) Finally decode to a waveform
+        recon_audio = self.generator(vq_post_emb.transpose(1, 2), vq=False)[0]
+        # recon_audio: [batch, 1, time]
+ return recon_audio
+
+ def encode_code(self, input_waveform, sample_rate=16000):
+        """
+        Encode the input audio into discrete codes.
+
+        Args:
+            input_waveform: [batch_size, waveform_length]
+            sample_rate: defaults to 16000
+        Returns:
+            the encoded codes (Tensor)
+        """
+ with torch.no_grad():
+            # 1) Feature extraction
+ input_features = self.feature_extractor(
+ input_waveform,
+ sampling_rate=sample_rate,
+ return_tensors="pt"
+ ).input_features.to(self.device) # [batch, frames, feat_dim]
+
+            # 2) Semantic layer
+            semantic_output = self.semantic_model(input_features)
+            semantic_hidden_16 = semantic_output.hidden_states[16]  # take the 16th hidden layer
+ semantic_hidden_16 = semantic_hidden_16.transpose(1, 2) # [batch, hidden_dim, frames]
+ semantic_encoded = self.SemanticEncoder_module(semantic_hidden_16)
+
+ # 3) codec encoder
+ wav = input_waveform.unsqueeze(1).to(self.device) # shape: [batch, 1, time]
+            vq_emb = self.CodecEnc(wav)  # [batch, time//down, 1024] (illustrative shape)
+ vq_emb = vq_emb.transpose(1, 2) # -> [batch, 1024, frames]
+
+            # Align the time frames with the semantic features (simplified handling)
+ if vq_emb.shape[-1] != semantic_encoded.shape[-1]:
+ min_len = min(vq_emb.shape[-1], semantic_encoded.shape[-1])
+ vq_emb = vq_emb[:, :, :min_len]
+ semantic_encoded = semantic_encoded[:, :, :min_len]
+
+            # 4) Concatenate
+ concat_emb = torch.cat([semantic_encoded, vq_emb], dim=1) # [batch, 2048, frames]
+
+ # 5) fc_prior
+ concat_emb = self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2)
+
+            # 6) Quantization stage of the decoder; obtain the codes
+ _, vq_code, _ = self.generator(concat_emb, vq=True)
+ # vq_code: [batch, frames]
+ return vq_code
+
+ def decode_code(self, vq_code):
+        """
+        Decode the codes back into audio.
+
+        Args:
+            vq_code: encoded codes (Tensor) [batch, frames]
+        Returns:
+            the decoded audio (Tensor) [batch, 1, waveform_length]
+        """
+ with torch.no_grad():
+            # Retrieve the quantized embeddings
+ vq_post_emb = self.generator.quantizer.get_output_from_indices(vq_code.transpose(1, 2))
+ vq_post_emb = vq_post_emb.transpose(1, 2) # [batch, 1024, frames]
+
+ # 7) fc_post_a
+ vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(1, 2) # [batch, 1024, frames]
+
+            # 8) Finally decode to a waveform
+            recon_audio = self.generator(vq_post_emb.transpose(1, 2), vq=False)[0]  # [batch, 1, time]
+ return recon_audio
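+
+
+# Illustrative round trip (shapes follow the docstrings above):
+#   codes = model.encode_code(wav)    # [batch, frames]
+#   audio = model.decode_code(codes)  # [batch, 1, time]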
diff --git a/xcodec2/module.py b/xcodec2/module.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/xcodec2/vq/__init__.py b/xcodec2/vq/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..9de36961fa5ddf77090a182fc427437ccdc7aff9
--- /dev/null
+++ b/xcodec2/vq/__init__.py
@@ -0,0 +1,4 @@
+from xcodec2.vq.codec_encoder import CodecEncoder
+from xcodec2.vq.codec_decoder import CodecDecoder
+from xcodec2.vq.codec_decoder_vocos import CodecDecoderVocos
+from xcodec2.vq.codec_encoder import CodecEncoder_Transformer, CodecEncoder_only_Transformer
\ No newline at end of file
diff --git a/xcodec2/vq/__pycache__/__init__.cpython-310.pyc b/xcodec2/vq/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe88c3604bf4cdd17ba29e04c47a0aea3222e111
Binary files /dev/null and b/xcodec2/vq/__pycache__/__init__.cpython-310.pyc differ
diff --git a/xcodec2/vq/__pycache__/__init__.cpython-311.pyc b/xcodec2/vq/__pycache__/__init__.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..3a622e8da2ba3969a82261d2c1e6a3dbe2a87155
Binary files /dev/null and b/xcodec2/vq/__pycache__/__init__.cpython-311.pyc differ
diff --git a/xcodec2/vq/__pycache__/__init__.cpython-312.pyc b/xcodec2/vq/__pycache__/__init__.cpython-312.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..a1bd6ded01b7b342483f5e8ab7993e74a8b74054
Binary files /dev/null and b/xcodec2/vq/__pycache__/__init__.cpython-312.pyc differ
diff --git a/xcodec2/vq/__pycache__/__init__.cpython-38.pyc b/xcodec2/vq/__pycache__/__init__.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..63230e0918a56b5b56c5d83e79c957975782b555
Binary files /dev/null and b/xcodec2/vq/__pycache__/__init__.cpython-38.pyc differ
diff --git a/xcodec2/vq/__pycache__/__init__.cpython-39.pyc b/xcodec2/vq/__pycache__/__init__.cpython-39.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..ee6edbe0d9811e1f6d087bee9c44604e457ae546
Binary files /dev/null and b/xcodec2/vq/__pycache__/__init__.cpython-39.pyc differ
diff --git a/xcodec2/vq/__pycache__/activations.cpython-310.pyc b/xcodec2/vq/__pycache__/activations.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2d8b95e80d0e513cfa122854d290eed64a52a6a3
Binary files /dev/null and b/xcodec2/vq/__pycache__/activations.cpython-310.pyc differ
diff --git a/xcodec2/vq/__pycache__/activations.cpython-311.pyc b/xcodec2/vq/__pycache__/activations.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..b68f81ee8be448f64ef4c8f7bab31d57bdd561a8
Binary files /dev/null and b/xcodec2/vq/__pycache__/activations.cpython-311.pyc differ
diff --git a/xcodec2/vq/__pycache__/activations.cpython-312.pyc b/xcodec2/vq/__pycache__/activations.cpython-312.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..7b65c1363e2538d7f7da65eacf3801535d9dfe7e
Binary files /dev/null and b/xcodec2/vq/__pycache__/activations.cpython-312.pyc differ
diff --git a/xcodec2/vq/__pycache__/activations.cpython-38.pyc b/xcodec2/vq/__pycache__/activations.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..a8d378a7fc0cf94615324f3cc048787815f89879
Binary files /dev/null and b/xcodec2/vq/__pycache__/activations.cpython-38.pyc differ
diff --git a/xcodec2/vq/__pycache__/activations.cpython-39.pyc b/xcodec2/vq/__pycache__/activations.cpython-39.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..a93430eb9e182116727bb2fd9a8c1ced9e1a423b
Binary files /dev/null and b/xcodec2/vq/__pycache__/activations.cpython-39.pyc differ
diff --git a/xcodec2/vq/__pycache__/blocks.cpython-310.pyc b/xcodec2/vq/__pycache__/blocks.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d49d37165ecdc9f6f8ff41aabb31134c35baa720
Binary files /dev/null and b/xcodec2/vq/__pycache__/blocks.cpython-310.pyc differ
diff --git a/xcodec2/vq/__pycache__/blocks.cpython-38.pyc b/xcodec2/vq/__pycache__/blocks.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..c8b1ace989b247dbfd268e8693d20d09558e7858
Binary files /dev/null and b/xcodec2/vq/__pycache__/blocks.cpython-38.pyc differ
diff --git a/xcodec2/vq/__pycache__/blocks.cpython-39.pyc b/xcodec2/vq/__pycache__/blocks.cpython-39.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..fa36890d26c5d256baf108973f548e68237fc1a9
Binary files /dev/null and b/xcodec2/vq/__pycache__/blocks.cpython-39.pyc differ
diff --git a/xcodec2/vq/__pycache__/bs_roformer5.cpython-310.pyc b/xcodec2/vq/__pycache__/bs_roformer5.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4bf1a481e797c324cd5c7ea897c596eb9ffebbe2
Binary files /dev/null and b/xcodec2/vq/__pycache__/bs_roformer5.cpython-310.pyc differ
diff --git a/xcodec2/vq/__pycache__/bs_roformer5.cpython-38.pyc b/xcodec2/vq/__pycache__/bs_roformer5.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..8c4edcbdb6c7963eabc1201b3020b3178343a672
Binary files /dev/null and b/xcodec2/vq/__pycache__/bs_roformer5.cpython-38.pyc differ
diff --git a/xcodec2/vq/__pycache__/bs_roformer5.cpython-39.pyc b/xcodec2/vq/__pycache__/bs_roformer5.cpython-39.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..92d8ebd686f5ca067f8afedd3c11df3eae114f8b
Binary files /dev/null and b/xcodec2/vq/__pycache__/bs_roformer5.cpython-39.pyc differ
diff --git a/xcodec2/vq/__pycache__/codec_decoder.cpython-310.pyc b/xcodec2/vq/__pycache__/codec_decoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fcd2f514cda5c60fbf4267573e0f421799879ebd
Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_decoder.cpython-310.pyc differ
diff --git a/xcodec2/vq/__pycache__/codec_decoder.cpython-311.pyc b/xcodec2/vq/__pycache__/codec_decoder.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..d3bba0654a05fc2d5df7f0cfa85d4e1e8f9e43a9
Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_decoder.cpython-311.pyc differ
diff --git a/xcodec2/vq/__pycache__/codec_decoder.cpython-312.pyc b/xcodec2/vq/__pycache__/codec_decoder.cpython-312.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..9fd232bd7d3e8217ae32122d85158c8d3d07a0ba
Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_decoder.cpython-312.pyc differ
diff --git a/xcodec2/vq/__pycache__/codec_decoder.cpython-38.pyc b/xcodec2/vq/__pycache__/codec_decoder.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..e5ab02d4e11a9b3abe48f4eb7d1b1b982be70703
Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_decoder.cpython-38.pyc differ
diff --git a/xcodec2/vq/__pycache__/codec_decoder.cpython-39.pyc b/xcodec2/vq/__pycache__/codec_decoder.cpython-39.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..342c797d5e48be38a60b3987b4ccf03e115c42f1
Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_decoder.cpython-39.pyc differ
diff --git a/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-310.pyc b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b45e5d90c8b7cb9202abee91370584c8fec28f9f
Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-310.pyc differ
diff --git a/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-311.pyc b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..16d7559f3e464f7cc8229cf5cf3864cf6ce5270a
Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-311.pyc differ
diff --git a/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-312.pyc b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-312.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..4d14cbb3475d4adca86027ce48447562dd8bd872
Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-312.pyc differ
diff --git a/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-38.pyc b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..094936bcc87dbe3f0480acf7b5987be727685911
Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-38.pyc differ
diff --git a/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-39.pyc b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-39.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..039c2ff83bd0dece5f89baf003ee5c212a9ec9eb
Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_decoder_vocos.cpython-39.pyc differ
diff --git a/xcodec2/vq/__pycache__/codec_encoder.cpython-310.pyc b/xcodec2/vq/__pycache__/codec_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1815fe3ea18bd9f317d6f1b2b21496cc2b24ed25
Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_encoder.cpython-310.pyc differ
diff --git a/xcodec2/vq/__pycache__/codec_encoder.cpython-311.pyc b/xcodec2/vq/__pycache__/codec_encoder.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..a95e106c52b918fc9879df154585f48f5b97f084
Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_encoder.cpython-311.pyc differ
diff --git a/xcodec2/vq/__pycache__/codec_encoder.cpython-312.pyc b/xcodec2/vq/__pycache__/codec_encoder.cpython-312.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..033919b2465e77531ac1f4340a42e8247dcf927e
Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_encoder.cpython-312.pyc differ
diff --git a/xcodec2/vq/__pycache__/codec_encoder.cpython-38.pyc b/xcodec2/vq/__pycache__/codec_encoder.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..4f4810f05dd5515271540d51492b58aafe648829
Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_encoder.cpython-38.pyc differ
diff --git a/xcodec2/vq/__pycache__/codec_encoder.cpython-39.pyc b/xcodec2/vq/__pycache__/codec_encoder.cpython-39.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..a6776bd150fda7c289c24c8798d44fb5abe57993
Binary files /dev/null and b/xcodec2/vq/__pycache__/codec_encoder.cpython-39.pyc differ
diff --git a/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-310.pyc b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db91e97b2f8bb8d07f0c21a706ee6aeab66561a9
Binary files /dev/null and b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-310.pyc differ
diff --git a/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-311.pyc b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..d1018bb516e0ca39612d3046cc66e93d3e9fc857
Binary files /dev/null and b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-311.pyc differ
diff --git a/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-312.pyc b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-312.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..a79936f97258dfac4e4c1fa00bcba84459f754bf
Binary files /dev/null and b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-312.pyc differ
diff --git a/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-38.pyc b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..f60d99f878ba719a277329f713bdca110ebb514b
Binary files /dev/null and b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-38.pyc differ
diff --git a/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-39.pyc b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-39.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..ff4eb90349eef0207e03b2d991f85499da62ab59
Binary files /dev/null and b/xcodec2/vq/__pycache__/factorized_vector_quantize.cpython-39.pyc differ
diff --git a/xcodec2/vq/__pycache__/module.cpython-310.pyc b/xcodec2/vq/__pycache__/module.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8a30f28a9cf972e1615e89597e4cdef4c03fd3d4
Binary files /dev/null and b/xcodec2/vq/__pycache__/module.cpython-310.pyc differ
diff --git a/xcodec2/vq/__pycache__/module.cpython-311.pyc b/xcodec2/vq/__pycache__/module.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..8bae56a74114e5ecb523e39275bf451e406e43a2
Binary files /dev/null and b/xcodec2/vq/__pycache__/module.cpython-311.pyc differ
diff --git a/xcodec2/vq/__pycache__/module.cpython-312.pyc b/xcodec2/vq/__pycache__/module.cpython-312.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..67dc7809ed198ecba09efecfa433224dc7b5ea4a
Binary files /dev/null and b/xcodec2/vq/__pycache__/module.cpython-312.pyc differ
diff --git a/xcodec2/vq/__pycache__/module.cpython-38.pyc b/xcodec2/vq/__pycache__/module.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..d3726fc890588cf3cf4e921e4d0a9d446561befc
Binary files /dev/null and b/xcodec2/vq/__pycache__/module.cpython-38.pyc differ
diff --git a/xcodec2/vq/__pycache__/module.cpython-39.pyc b/xcodec2/vq/__pycache__/module.cpython-39.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..97d6cb68f45b74a72d220b758c224fc1d7bcf974
Binary files /dev/null and b/xcodec2/vq/__pycache__/module.cpython-39.pyc differ
diff --git a/xcodec2/vq/__pycache__/residual_vq.cpython-310.pyc b/xcodec2/vq/__pycache__/residual_vq.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a5ed5c3cef57c636101212049f045ad72269253
Binary files /dev/null and b/xcodec2/vq/__pycache__/residual_vq.cpython-310.pyc differ
diff --git a/xcodec2/vq/__pycache__/residual_vq.cpython-311.pyc b/xcodec2/vq/__pycache__/residual_vq.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..92069858ec62cb51a7cce7cdf1550695c656e557
Binary files /dev/null and b/xcodec2/vq/__pycache__/residual_vq.cpython-311.pyc differ
diff --git a/xcodec2/vq/__pycache__/residual_vq.cpython-312.pyc b/xcodec2/vq/__pycache__/residual_vq.cpython-312.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..4d81960f8bbec9c309fa08267b44245d8a5b4959
Binary files /dev/null and b/xcodec2/vq/__pycache__/residual_vq.cpython-312.pyc differ
diff --git a/xcodec2/vq/__pycache__/residual_vq.cpython-38.pyc b/xcodec2/vq/__pycache__/residual_vq.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..91a0dc321a82434f8f32fd177bbca7f59d4c276d
Binary files /dev/null and b/xcodec2/vq/__pycache__/residual_vq.cpython-38.pyc differ
diff --git a/xcodec2/vq/__pycache__/residual_vq.cpython-39.pyc b/xcodec2/vq/__pycache__/residual_vq.cpython-39.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..c004cb105a51fd161b325e6568b965634f0cddce
Binary files /dev/null and b/xcodec2/vq/__pycache__/residual_vq.cpython-39.pyc differ
diff --git a/xcodec2/vq/__pycache__/unet.cpython-312.pyc b/xcodec2/vq/__pycache__/unet.cpython-312.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..e431b555864d1029ed9735ed1bad0bae1b75c4f8
Binary files /dev/null and b/xcodec2/vq/__pycache__/unet.cpython-312.pyc differ
diff --git a/xcodec2/vq/__pycache__/unet.cpython-39.pyc b/xcodec2/vq/__pycache__/unet.cpython-39.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..9e12a03c0fba3df01e7d87e249fd77a3bab599f1
Binary files /dev/null and b/xcodec2/vq/__pycache__/unet.cpython-39.pyc differ
diff --git a/xcodec2/vq/activations.py b/xcodec2/vq/activations.py
new file mode 100755
index 0000000000000000000000000000000000000000..2444a7bd9d52018e97892820a072b39a21245372
--- /dev/null
+++ b/xcodec2/vq/activations.py
@@ -0,0 +1,120 @@
+# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+from torch import nn, sin, pow
+from torch.nn import Parameter
+
+
+class Snake(nn.Module):
+ '''
+ Implementation of a sine-based periodic activation function
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+ Parameters:
+ - alpha - trainable parameter
+ References:
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://arxiv.org/abs/2006.08195
+ Examples:
+        >>> a1 = Snake(256)
+ >>> x = torch.randn(256)
+ >>> x = a1(x)
+ '''
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+ '''
+ Initialization.
+ INPUT:
+ - in_features: shape of the input
+ - alpha: trainable parameter
+ alpha is initialized to 1 by default, higher values = higher-frequency.
+ alpha will be trained along with the rest of your model.
+ '''
+ super(Snake, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale: # log scale alphas initialized to zeros
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
+ else: # linear scale alphas initialized to ones
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ '''
+ Forward pass of the function.
+ Applies the function to the input elementwise.
+ Snake ∶= x + 1/a * sin^2 (xa)
+ '''
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+ return x
+
+
+class SnakeBeta(nn.Module):
+ '''
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+ Parameters:
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ References:
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://arxiv.org/abs/2006.08195
+ Examples:
+        >>> a1 = SnakeBeta(256)
+ >>> x = torch.randn(256)
+ >>> x = a1(x)
+ '''
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+ '''
+ Initialization.
+ INPUT:
+ - in_features: shape of the input
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ alpha is initialized to 1 by default, higher values = higher-frequency.
+ beta is initialized to 1 by default, higher values = higher-magnitude.
+ alpha will be trained along with the rest of your model.
+ '''
+ super(SnakeBeta, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale: # log scale alphas initialized to zeros
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
+ self.bias = Parameter(torch.zeros(in_features) * alpha)
+ else: # linear scale alphas initialized to ones
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
+ self.bias = Parameter(torch.ones(in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+ self.bias.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ '''
+ Forward pass of the function.
+ Applies the function to the input elementwise.
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
+ '''
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ beta = self.bias.unsqueeze(0).unsqueeze(-1)
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ beta = torch.exp(beta)
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+ return x
\ No newline at end of file
diff --git a/xcodec2/vq/alias_free_torch/__init__.py b/xcodec2/vq/alias_free_torch/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..a2318b63198250856809c0cb46210a4147b829bc
--- /dev/null
+++ b/xcodec2/vq/alias_free_torch/__init__.py
@@ -0,0 +1,6 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+from .filter import *
+from .resample import *
+from .act import *
\ No newline at end of file
diff --git a/xcodec2/vq/alias_free_torch/__pycache__/__init__.cpython-310.pyc b/xcodec2/vq/alias_free_torch/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..661c1297fc3bbe2f0fa5c1197243e51c4d4cb263
Binary files /dev/null and b/xcodec2/vq/alias_free_torch/__pycache__/__init__.cpython-310.pyc differ
diff --git a/xcodec2/vq/alias_free_torch/act.py b/xcodec2/vq/alias_free_torch/act.py
new file mode 100755
index 0000000000000000000000000000000000000000..028debd697dd60458aae75010057df038bd3518a
--- /dev/null
+++ b/xcodec2/vq/alias_free_torch/act.py
@@ -0,0 +1,28 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import torch.nn as nn
+from .resample import UpSample1d, DownSample1d
+
+
+class Activation1d(nn.Module):
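+    """Anti-aliased activation: upsample, apply the pointwise nonlinearity at
+    the higher rate, then low-pass filter and downsample back, suppressing the
+    aliasing introduced by the nonlinearity (alias-free-torch recipe)."""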
+ def __init__(self,
+ activation,
+ up_ratio: int = 2,
+ down_ratio: int = 2,
+ up_kernel_size: int = 12,
+ down_kernel_size: int = 12):
+ super().__init__()
+ self.up_ratio = up_ratio
+ self.down_ratio = down_ratio
+ self.act = activation
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+ # x: [B,C,T]
+ def forward(self, x):
+ x = self.upsample(x)
+ x = self.act(x)
+ x = self.downsample(x)
+
+ return x
\ No newline at end of file
diff --git a/xcodec2/vq/alias_free_torch/filter.py b/xcodec2/vq/alias_free_torch/filter.py
new file mode 100755
index 0000000000000000000000000000000000000000..7ad6ea87c1f10ddd94c544037791d7a4634d5ae1
--- /dev/null
+++ b/xcodec2/vq/alias_free_torch/filter.py
@@ -0,0 +1,95 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+if 'sinc' in dir(torch):
+ sinc = torch.sinc
+else:
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
+ # https://adefossez.github.io/julius/julius/core.html
+ # LICENSE is in incl_licenses directory.
+ def sinc(x: torch.Tensor):
+ """
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
+ """
+ return torch.where(x == 0,
+ torch.tensor(1., device=x.device, dtype=x.dtype),
+ torch.sin(math.pi * x) / math.pi / x)
+
+
+# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
+# https://adefossez.github.io/julius/julius/lowpass.html
+# LICENSE is in incl_licenses directory.
+def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
+ even = (kernel_size % 2 == 0)
+ half_size = kernel_size // 2
+
+    # For the Kaiser window
+ delta_f = 4 * half_width
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
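+    # A estimates the stopband attenuation in dB; beta then follows the
+    # standard empirical Kaiser window design formula.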
+ if A > 50.:
+ beta = 0.1102 * (A - 8.7)
+ elif A >= 21.:
+ beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
+ else:
+ beta = 0.
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
+ if even:
+ time = (torch.arange(-half_size, half_size) + 0.5)
+ else:
+ time = torch.arange(kernel_size) - half_size
+ if cutoff == 0:
+ filter_ = torch.zeros_like(time)
+ else:
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
+ # of the constant component in the input signal.
+ filter_ /= filter_.sum()
+ filter = filter_.view(1, 1, kernel_size)
+
+ return filter
+
+
+class LowPassFilter1d(nn.Module):
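+    """Depthwise FIR low-pass filter built from a Kaiser-windowed sinc kernel;
+    the same kernel is applied to every channel via groups=C."""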
+ def __init__(self,
+ cutoff=0.5,
+ half_width=0.6,
+ stride: int = 1,
+ padding: bool = True,
+ padding_mode: str = 'replicate',
+ kernel_size: int = 12):
+ # kernel_size should be even number for stylegan3 setup,
+ # in this implementation, odd number is also possible.
+ super().__init__()
+        if cutoff < 0.:
+            raise ValueError("Minimum cutoff must be non-negative.")
+        if cutoff > 0.5:
+            raise ValueError("A cutoff above 0.5 does not make sense.")
+ self.kernel_size = kernel_size
+ self.even = (kernel_size % 2 == 0)
+ self.pad_left = kernel_size // 2 - int(self.even)
+ self.pad_right = kernel_size // 2
+ self.stride = stride
+ self.padding = padding
+ self.padding_mode = padding_mode
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+ self.register_buffer("filter", filter)
+
+    # input: [B, C, T]
+ def forward(self, x):
+ _, C, _ = x.shape
+
+ if self.padding:
+ x = F.pad(x, (self.pad_left, self.pad_right),
+ mode=self.padding_mode)
+ out = F.conv1d(x, self.filter.expand(C, -1, -1),
+ stride=self.stride, groups=C)
+
+ return out
\ No newline at end of file
diff --git a/xcodec2/vq/alias_free_torch/resample.py b/xcodec2/vq/alias_free_torch/resample.py
new file mode 100755
index 0000000000000000000000000000000000000000..750e6c3402cc5ac939c4b9d075246562e0e1d1a7
--- /dev/null
+++ b/xcodec2/vq/alias_free_torch/resample.py
@@ -0,0 +1,49 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import torch.nn as nn
+from torch.nn import functional as F
+from .filter import LowPassFilter1d
+from .filter import kaiser_sinc_filter1d
+
+
+class UpSample1d(nn.Module):
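+    """Integer-ratio upsampling via transposed convolution with a Kaiser-sinc
+    interpolation kernel; the padding is chosen so the output length is
+    exactly ratio * T."""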
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ self.ratio = ratio
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+ self.stride = ratio
+ self.pad = self.kernel_size // ratio - 1
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+ self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+ filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
+ half_width=0.6 / ratio,
+ kernel_size=self.kernel_size)
+ self.register_buffer("filter", filter)
+
+ # x: [B, C, T]
+ def forward(self, x):
+ _, C, _ = x.shape
+
+ x = F.pad(x, (self.pad, self.pad), mode='replicate')
+ x = self.ratio * F.conv_transpose1d(
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
+ x = x[..., self.pad_left:-self.pad_right]
+
+ return x
+
+
+class DownSample1d(nn.Module):
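+    """Integer-ratio downsampling: Kaiser-sinc low-pass filtering with
+    stride=ratio handled inside LowPassFilter1d."""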
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ self.ratio = ratio
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+ self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
+ half_width=0.6 / ratio,
+ stride=ratio,
+ kernel_size=self.kernel_size)
+
+    def forward(self, x):
+        return self.lowpass(x)
\ No newline at end of file
diff --git a/xcodec2/vq/blocks.py b/xcodec2/vq/blocks.py
new file mode 100755
index 0000000000000000000000000000000000000000..3996fec146cbf4f3caef4f9da3bbbe04f7729bbb
--- /dev/null
+++ b/xcodec2/vq/blocks.py
@@ -0,0 +1,183 @@
+from typing import Callable, Sequence, Type, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+ModuleFactory = Union[Type[nn.Module], Callable[[], nn.Module]]
+
+
+class FeedForwardModule(nn.Module):
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.net = None
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.net(x)
+
+
+class Residual(nn.Module):
+
+ def __init__(self, module: nn.Module) -> None:
+ super().__init__()
+ self.module = module
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.module(x) + x
+
+
+class DilatedConvolutionalUnit(FeedForwardModule):
+
+ def __init__(
+ self,
+ hidden_dim: int,
+ dilation: int,
+ kernel_size: int,
+ activation: ModuleFactory,
+ normalization: Callable[[nn.Module],
+ nn.Module] = lambda x: x) -> None:
+ super().__init__()
+ self.net = nn.Sequential(
+ activation(),
+ normalization(
+ nn.Conv1d(
+ in_channels=hidden_dim,
+ out_channels=hidden_dim,
+ kernel_size=kernel_size,
+ dilation=dilation,
+ padding=((kernel_size - 1) * dilation) // 2,
+ )),
+ activation(),
+ nn.Conv1d(in_channels=hidden_dim,
+ out_channels=hidden_dim,
+ kernel_size=1),
+ )
+
+
+class UpsamplingUnit(FeedForwardModule):
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ stride: int,
+ activation: ModuleFactory,
+ normalization: Callable[[nn.Module],
+ nn.Module] = lambda x: x) -> None:
+ super().__init__()
+ self.net = nn.Sequential(
+ activation(),
+ normalization(
+ nn.ConvTranspose1d(
+ in_channels=input_dim,
+ out_channels=output_dim,
+ kernel_size=2 * stride,
+ stride=stride,
+                    padding=stride // 2 + stride % 2,
+ output_padding=1 if stride % 2 != 0 else 0
+ )))
+
+
+class DownsamplingUnit(FeedForwardModule):
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ stride: int,
+ activation: ModuleFactory,
+ normalization: Callable[[nn.Module],
+ nn.Module] = lambda x: x) -> None:
+ super().__init__()
+ self.net = nn.Sequential(
+ activation(),
+ normalization(
+ nn.Conv1d(
+ in_channels=input_dim,
+ out_channels=output_dim,
+ kernel_size=2 * stride,
+ stride=stride,
+                    padding=stride // 2 + stride % 2,
+ )))
+
+
+class DilatedResidualEncoder(FeedForwardModule):
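+    """Encoder: pre-conv, then per stage a run of dilated residual units
+    followed by a strided DownsamplingUnit (channels double each stage),
+    and a post-conv projection."""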
+
+ def __init__(
+ self,
+ capacity: int,
+ dilated_unit: Type[DilatedConvolutionalUnit],
+ downsampling_unit: Type[DownsamplingUnit],
+ ratios: Sequence[int],
+ dilations: Union[Sequence[int], Sequence[Sequence[int]]],
+ pre_network_conv: Type[nn.Conv1d],
+ post_network_conv: Type[nn.Conv1d],
+ normalization: Callable[[nn.Module],
+ nn.Module] = lambda x: x) -> None:
+ super().__init__()
+ channels = capacity * 2**np.arange(len(ratios) + 1)
+
+ dilations_list = self.normalize_dilations(dilations, ratios)
+
+ net = [normalization(pre_network_conv(out_channels=channels[0]))]
+
+ for ratio, dilations, input_dim, output_dim in zip(
+ ratios, dilations_list, channels[:-1], channels[1:]):
+ for dilation in dilations:
+ net.append(Residual(dilated_unit(input_dim, dilation)))
+ net.append(downsampling_unit(input_dim, output_dim, ratio))
+
+ net.append(post_network_conv(in_channels=output_dim))
+
+ self.net = nn.Sequential(*net)
+
+ @staticmethod
+ def normalize_dilations(dilations: Union[Sequence[int],
+ Sequence[Sequence[int]]],
+ ratios: Sequence[int]):
+ if isinstance(dilations[0], int):
+ dilations = [dilations for _ in ratios]
+ return dilations
+
+
+class DilatedResidualDecoder(FeedForwardModule):
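+    """Mirror of the encoder: pre-conv, then per stage an UpsamplingUnit
+    followed by dilated residual units (channels halve each stage), and a
+    post-conv projection."""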
+
+ def __init__(
+ self,
+ capacity: int,
+ dilated_unit: Type[DilatedConvolutionalUnit],
+ upsampling_unit: Type[UpsamplingUnit],
+ ratios: Sequence[int],
+ dilations: Union[Sequence[int], Sequence[Sequence[int]]],
+ pre_network_conv: Type[nn.Conv1d],
+ post_network_conv: Type[nn.Conv1d],
+ normalization: Callable[[nn.Module],
+ nn.Module] = lambda x: x) -> None:
+ super().__init__()
+ channels = capacity * 2**np.arange(len(ratios) + 1)
+ channels = channels[::-1]
+
+ dilations_list = self.normalize_dilations(dilations, ratios)
+ dilations_list = dilations_list[::-1]
+
+ net = [pre_network_conv(out_channels=channels[0])]
+
+ for ratio, dilations, input_dim, output_dim in zip(
+ ratios, dilations_list, channels[:-1], channels[1:]):
+ net.append(upsampling_unit(input_dim, output_dim, ratio))
+ for dilation in dilations:
+ net.append(Residual(dilated_unit(output_dim, dilation)))
+
+ net.append(normalization(post_network_conv(in_channels=output_dim)))
+
+ self.net = nn.Sequential(*net)
+
+ @staticmethod
+ def normalize_dilations(dilations: Union[Sequence[int],
+ Sequence[Sequence[int]]],
+ ratios: Sequence[int]):
+ if isinstance(dilations[0], int):
+ dilations = [dilations for _ in ratios]
+ return dilations
\ No newline at end of file
diff --git a/xcodec2/vq/bs_roformer5.py b/xcodec2/vq/bs_roformer5.py
new file mode 100755
index 0000000000000000000000000000000000000000..08aa016d731a6a5cae3e4f38514d97187ad7adb4
--- /dev/null
+++ b/xcodec2/vq/bs_roformer5.py
@@ -0,0 +1,123 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Module, ModuleList
+import torchaudio
+from einops import rearrange
+import numpy as np
+# from rotary_embedding_torch import RotaryEmbedding
+
+from torchtune.modules import RotaryPositionalEmbeddings
+
+
+
+class RMSNorm(torch.nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6):
+ r"""https://github.com/meta-llama/llama/blob/main/llama/model.py"""
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ norm_x = torch.mean(x ** 2, dim=-1, keepdim=True)
+ output = x * torch.rsqrt(norm_x + self.eps) * self.weight
+ return output
+
+
+
+class MLP(nn.Module):
+ def __init__(self, dim: int) -> None:
+ super().__init__()
+
+ self.fc1 = nn.Linear(dim, 4 * dim, bias=False)
+ self.silu = nn.SiLU()
+ self.fc2 = nn.Linear(4 * dim, dim, bias=False)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.silu(x)
+ x = self.fc2(x)
+ return x
+
+
+class Attention(nn.Module):
+
+ def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings):
+ super().__init__()
+
+ assert dim % n_heads == 0
+
+ self.n_heads = n_heads
+ self.dim = dim
+ self.rotary_embed = rotary_embed
+
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+ assert self.flash, "Must have flash attention."
+
+ self.c_attn = nn.Linear(dim, 3 * dim, bias=False)
+ self.c_proj = nn.Linear(dim, dim, bias=False)
+
+ def forward(self, x):
+ r"""
+ Args:
+ x: (b, t, h*d)
+
+ Constants:
+ b: batch_size
+ t: time steps
+ r: 3
+ h: heads_num
+ d: heads_dim
+ """
+ B, T, C = x.size()
+
+ q, k, v = rearrange(self.c_attn(x), 'b t (r h d) -> r b h t d', r=3, h=self.n_heads)
+ # q, k, v: (b, h, t, d)
+
+ q = self.rotary_embed(q)
+ k = self.rotary_embed(k)
+
+        y = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=None, dropout_p=0, is_causal=False)
+
+ y = rearrange(y, 'b h t d -> b t (h d)')
+
+ y = self.c_proj(y)
+ # shape: (b, t, h*d)
+
+ return y
+
+
+class TransformerBlock(nn.Module):
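+    """Pre-norm transformer block: RMSNorm -> rotary self-attention and
+    RMSNorm -> SiLU MLP, each wrapped in a residual connection."""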
+ def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings):
+
+ super().__init__()
+ self.dim = dim
+ self.n_heads = n_heads
+
+ self.att_norm = RMSNorm(dim)
+ self.ffn_norm = RMSNorm(dim)
+ self.att = Attention(dim=dim, n_heads=n_heads, rotary_embed=rotary_embed)
+ self.mlp = MLP(dim=dim)
+
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ ):
+ x = x + self.att(self.att_norm(x))
+ x = x + self.mlp(self.ffn_norm(x))
+ return x
+
+
+if __name__ == '__main__':
+ rotary_embed_128 = RotaryPositionalEmbeddings(dim=128)
+ transformer_block = TransformerBlock(
+ dim=1024,
+ n_heads=8,
+ rotary_embed=rotary_embed_128
+ )
+ x = torch.randn(2, 128, 1024)
+ y = transformer_block(x)
+ print(y.shape)
\ No newline at end of file
diff --git a/xcodec2/vq/codec_decoder.py b/xcodec2/vq/codec_decoder.py
new file mode 100755
index 0000000000000000000000000000000000000000..9ea72e31ee70507b74d63752d8a8911929049e17
--- /dev/null
+++ b/xcodec2/vq/codec_decoder.py
@@ -0,0 +1,304 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from xcodec2.vq.residual_vq import ResidualVQ
+from xcodec2.vq.module import WNConv1d, DecoderBlock, ResLSTM
+from xcodec2.vq.alias_free_torch import *
+from xcodec2.vq import activations
+from xcodec2.vq import blocks as blocks
+from torch.nn import utils
+
+from xcodec2.vq.bs_roformer5 import TransformerBlock
+
+from torchtune.modules import RotaryPositionalEmbeddings
+
+def init_weights(m):
+    if isinstance(m, nn.Conv1d):
+        nn.init.trunc_normal_(m.weight, std=0.02)
+        if m.bias is not None:
+            nn.init.constant_(m.bias, 0)
+
+class CodecDecoder(nn.Module):
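+    """Convolutional decoder: a residual VQ bottleneck followed by an optional
+    ResLSTM and a stack of DecoderBlocks that upsample by prod(up_ratios)."""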
+ def __init__(self,
+ in_channels=1024,
+ upsample_initial_channel=1536,
+ ngf=48,
+ use_rnn=True,
+ rnn_bidirectional=False,
+ rnn_num_layers=2,
+ up_ratios=(5, 4, 4, 4, 2),
+ dilations=(1, 3, 9),
+ vq_num_quantizers=1,
+ vq_dim=2048,
+ vq_commit_weight=0.25,
+ vq_weight_init=False,
+ vq_full_commit_loss=False,
+ codebook_size=16384,
+ codebook_dim=32,
+ ):
+ super().__init__()
+ self.hop_length = np.prod(up_ratios)
+ self.ngf = ngf
+ self.up_ratios = up_ratios
+
+ self.quantizer = ResidualVQ(
+ num_quantizers=vq_num_quantizers,
+            dim=vq_dim,  # double the dim for acoustic and semantic
+ codebook_size=codebook_size,
+ codebook_dim=codebook_dim,
+ threshold_ema_dead_code=2,
+ commitment=vq_commit_weight,
+ weight_init=vq_weight_init,
+ full_commit_loss=vq_full_commit_loss,
+ )
+ channels = upsample_initial_channel
+ layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
+
+ if use_rnn:
+ layers += [
+ ResLSTM(channels,
+ num_layers=rnn_num_layers,
+ bidirectional=rnn_bidirectional
+ )
+ ]
+
+ for i, stride in enumerate(up_ratios):
+ input_dim = channels // 2**i
+ output_dim = channels // 2 ** (i + 1)
+ layers += [DecoderBlock(input_dim, output_dim, stride, dilations)]
+
+ layers += [
+ Activation1d(activation=activations.SnakeBeta(output_dim, alpha_logscale=True)),
+ WNConv1d(output_dim, 1, kernel_size=7, padding=3),
+ nn.Tanh(),
+ ]
+
+ self.model = nn.Sequential(*layers)
+
+ self.reset_parameters()
+
+ def forward(self, x, vq=True):
+ if vq is True:
+ x, q, commit_loss = self.quantizer(x)
+ return x, q, commit_loss
+ x = self.model(x)
+ return x
+
+ def vq2emb(self, vq):
+ self.quantizer = self.quantizer.eval()
+ x = self.quantizer.vq2emb(vq)
+ return x
+
+ def get_emb(self):
+ self.quantizer = self.quantizer.eval()
+ embs = self.quantizer.get_emb()
+ return embs
+
+ def inference_vq(self, vq):
+ x = vq[None,:,:]
+ x = self.model(x)
+ return x
+
+    def inference_0(self, x):
+        x, q, commit_loss = self.quantizer(x)
+        x = self.model(x)
+        return x, None
+
+ def inference(self, x):
+ x = self.model(x)
+ return x, None
+
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m):
+ try:
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+
+ def _apply_weight_norm(m):
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
+ torch.nn.utils.weight_norm(m)
+
+ self.apply(_apply_weight_norm)
+
+ def reset_parameters(self):
+ self.apply(init_weights)
+
+
+class CodecDecoder_oobleck_Transformer(nn.Module):
+ def __init__(self,
+ ngf=32,
+ up_ratios=(5, 4, 4, 4, 2),
+ dilations=(1, 3, 9),
+ vq_num_quantizers=1,
+ vq_dim=1024,
+ vq_commit_weight=0.25,
+ vq_weight_init=False,
+ vq_full_commit_loss=False,
+ codebook_size=16384,
+ codebook_dim=16,
+ hidden_dim=1024,
+ depth=12,
+ heads=16,
+ pos_meb_dim=64,
+ ):
+ super().__init__()
+ self.hop_length = np.prod(up_ratios)
+ self.capacity = ngf
+ self.up_ratios = up_ratios
+ self.hidden_dim = hidden_dim
+ self.quantizer = ResidualVQ(
+ num_quantizers=vq_num_quantizers,
+            dim=vq_dim,  # double the dim for acoustic and semantic
+ codebook_size=codebook_size,
+ codebook_dim=codebook_dim,
+ threshold_ema_dead_code=2,
+ commitment=vq_commit_weight,
+ weight_init=vq_weight_init,
+ full_commit_loss=vq_full_commit_loss,
+ )
+
+ time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
+
+ transformer_blocks = [
+ TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
+ for _ in range(depth)
+ ]
+
+ self.transformers = nn.Sequential(*transformer_blocks)
+
+ self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
+
+ self.conv_blocks = blocks.DilatedResidualDecoder(
+ capacity=self.capacity,
+ dilated_unit=self.dilated_unit,
+ upsampling_unit=self.upsampling_unit,
+            ratios=up_ratios,  # reverses the encoder's downsampling ratios
+ dilations=dilations,
+ pre_network_conv=self.pre_conv,
+ post_network_conv=self.post_conv,
+ )
+
+ self.reset_parameters()
+
+ def forward(self, x, vq=True):
+ if vq is True:
+ x, q, commit_loss = self.quantizer(x)
+ return x, q, commit_loss
+        x = self.transformers(x)
+ x = self.final_layer_norm(x)
+ x = x.permute(0, 2, 1)
+ x = self.conv_blocks(x)
+ return x
+
+ def vq2emb(self, vq):
+ self.quantizer = self.quantizer.eval()
+ x = self.quantizer.vq2emb(vq)
+ return x
+
+ def get_emb(self):
+ self.quantizer = self.quantizer.eval()
+ embs = self.quantizer.get_emb()
+ return embs
+
+ def inference_vq(self, vq):
+ x = vq[None,:,:]
+ x = self.model(x)
+ return x
+
+    def inference_0(self, x):
+        x, q, commit_loss = self.quantizer(x)
+        x = self.model(x)
+        return x, None
+
+ def inference(self, x):
+ x = self.model(x)
+ return x, None
+
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m):
+ try:
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+
+ def _apply_weight_norm(m):
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
+ torch.nn.utils.weight_norm(m)
+
+ self.apply(_apply_weight_norm)
+
+ def reset_parameters(self):
+ self.apply(init_weights)
+
+ def pre_conv(self, out_channels):
+ return nn.Conv1d(in_channels=self.hidden_dim, out_channels=out_channels, kernel_size=1)
+
+    # Post-network conv: maps the decoder output to the final channel count
+    def post_conv(self, in_channels):
+        return nn.Conv1d(in_channels=in_channels, out_channels=1, kernel_size=1)
+
+    def dilated_unit(self, hidden_dim, dilation):
+        return blocks.DilatedConvolutionalUnit(
+            hidden_dim=hidden_dim,
+            dilation=dilation,
+            kernel_size=3,
+            activation=nn.ReLU,
+            normalization=utils.weight_norm
+        )
+
+    # Upsampling unit factory
+    def upsampling_unit(self, input_dim, output_dim, stride):
+        return blocks.UpsamplingUnit(
+            input_dim=input_dim,
+            output_dim=output_dim,
+            stride=stride,
+            activation=nn.ReLU,
+            normalization=utils.weight_norm
+        )
+
+def main():
+    # Select device
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    print(f"Using device: {device}")
+
+    # Initialize the model
+    model = CodecDecoder_oobleck_Transformer().to(device)
+    print("Model initialized.")
+
+    # Test input: batch_size x sequence_length x hidden_dim
+    batch_size = 2
+    in_channels = 1024
+    sequence_length = 100  # example length, adjust as needed
+    dummy_input = torch.randn(batch_size, sequence_length, in_channels).to(device)
+    print(f"Dummy input shape: {dummy_input.shape}")
+
+    # Switch to evaluation mode
+    model.eval()
+
+    with torch.no_grad():
+        output_no_vq = model(dummy_input, vq=False)
+    print(f"Output shape (no VQ): {output_no_vq.shape}")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/xcodec2/vq/codec_decoder_vocos.py b/xcodec2/vq/codec_decoder_vocos.py
new file mode 100755
index 0000000000000000000000000000000000000000..ba7b5dc86b078515bf291ee3cbf45979bf4b7dff
--- /dev/null
+++ b/xcodec2/vq/codec_decoder_vocos.py
@@ -0,0 +1,638 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from xcodec2.vq.residual_vq import ResidualVQ
+from xcodec2.vq.module import WNConv1d, DecoderBlock, ResLSTM
+from xcodec2.vq.alias_free_torch import *
+from xcodec2.vq import activations
+from typing import Optional
+from xcodec2.vq.module import ConvNeXtBlock, AdaLayerNorm
+from xcodec2.vq.bs_roformer5 import TransformerBlock
+# from rotary_embedding_torch import RotaryEmbedding
+from torchtune.modules import RotaryPositionalEmbeddings
+from vector_quantize_pytorch import ResidualFSQ
+from torch.nn import Module, ModuleList
+
+
+class ISTFT(nn.Module):
+ """
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
+ See issue: https://github.com/pytorch/pytorch/issues/62323
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
+ The NOLA constraint is met as we trim padded samples anyway.
+
+ Args:
+ n_fft (int): Size of Fourier transform.
+ hop_length (int): The distance between neighboring sliding window frames.
+ win_length (int): The size of window frame and STFT filter.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
+ super().__init__()
+ if padding not in ["center", "same"]:
+ raise ValueError("Padding must be 'center' or 'same'.")
+ self.padding = padding
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.win_length = win_length
+ window = torch.hann_window(win_length)
+ self.register_buffer("window", window)
+
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
+
+ Args:
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
+ N is the number of frequency bins, and T is the number of time frames.
+
+ Returns:
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
+ """
+ if self.padding == "center":
+ # Fallback to pytorch native implementation
+ return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
+ elif self.padding == "same":
+ pad = (self.win_length - self.hop_length) // 2
+ else:
+ raise ValueError("Padding must be 'center' or 'same'.")
+
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
+ B, N, T = spec.shape
+
+ # Inverse FFT
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
+ ifft = ifft * self.window[None, :, None]
+
+ # Overlap and Add
+ output_size = (T - 1) * self.hop_length + self.win_length
+ y = torch.nn.functional.fold(
+ ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
+ )[:, 0, 0, pad:-pad]
+
+ # Window envelope
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
+ window_envelope = torch.nn.functional.fold(
+ window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
+ ).squeeze()[pad:-pad]
+
+ # Normalize
+ assert (window_envelope > 1e-11).all()
+ y = y / window_envelope
+
+ return y
+
+
+
+class FourierHead(nn.Module):
+ """Base class for inverse fourier modules."""
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+class ISTFTHead(FourierHead):
+ """
+ ISTFT Head module for predicting STFT complex coefficients.
+
+ Args:
+ dim (int): Hidden dimension of the model.
+ n_fft (int): Size of Fourier transform.
+ hop_length (int): The distance between neighboring sliding window frames, which should align with
+ the resolution of the input features.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
+ super().__init__()
+ out_dim = n_fft + 2
+ self.out = torch.nn.Linear(dim, out_dim)
+ self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the ISTFTHead module.
+
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+        x_pred = self.out(x)
+        x_pred = x_pred.transpose(1, 2)
+ mag, p = x_pred.chunk(2, dim=1)
+ mag = torch.exp(mag)
+ mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes
+ # wrapping happens here. These two lines produce real and imaginary value
+ x = torch.cos(p)
+ y = torch.sin(p)
+ # recalculating phase here does not produce anything new
+ # only costs time
+ # phase = torch.atan2(y, x)
+ # S = mag * torch.exp(phase * 1j)
+ # better directly produce the complex value
+ S = mag * (x + 1j * y)
+ audio = self.istft(S)
+        return audio.unsqueeze(1), x_pred
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv1d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv1d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv1d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv1d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb=None):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+        if temb is not None:
+            h = h + self.temb_proj(nonlinearity(temb))[:, :, None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h = q.shape
+ q = q.permute(0, 2, 1) # b,hw,c
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+def make_attn(in_channels, attn_type="vanilla"):
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ return AttnBlock(in_channels)
+
+
+class Backbone(nn.Module):
+ """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
+
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Args:
+ x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
+ C denotes output features, and L is the sequence length.
+
+ Returns:
+ Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
+ and H denotes the model dimension.
+ """
+ raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+class VocosBackbone(Backbone):
+ """
+ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
+
+ Args:
+ input_channels (int): Number of input features channels.
+ dim (int): Hidden dimension of the model.
+ intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
+ num_layers (int): Number of ConvNeXtBlock layers.
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+ None means non-conditional model. Defaults to None.
+ """
+
+ def __init__(
+ self, hidden_dim=1024,depth=12,heads=16,pos_meb_dim=64):
+ super().__init__()
+
+ self.embed = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=7, padding=3)
+
+
+
+ self.temb_ch = 0
+ block_in = hidden_dim
+ dropout = 0.1
+
+        prior_net: list[nn.Module] = [
+ ResnetBlock(in_channels=block_in,out_channels=block_in,
+ temb_channels=self.temb_ch,dropout=dropout),
+ ResnetBlock(in_channels=block_in,out_channels=block_in,
+ temb_channels=self.temb_ch,dropout=dropout),
+ ]
+ self.prior_net = nn.Sequential(*prior_net)
+
+ time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
+
+
+ transformer_blocks = [
+ TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
+ for _ in range(depth)
+ ]
+
+
+ self.transformers = nn.Sequential(*transformer_blocks)
+ self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
+        post_net: list[nn.Module] = [
+ ResnetBlock(in_channels=block_in,out_channels=block_in,
+ temb_channels=self.temb_ch,dropout=dropout),
+ ResnetBlock(in_channels=block_in,out_channels=block_in,
+ temb_channels=self.temb_ch,dropout=dropout),
+ ]
+ self.post_net = nn.Sequential(*post_net)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x.transpose(1, 2)
+        x = self.embed(x)
+        x = self.prior_net(x)
+        x = x.transpose(1, 2)
+        x = self.transformers(x)
+        x = x.transpose(1, 2)
+        x = self.post_net(x)
+        x = x.transpose(1, 2)
+        x = self.final_layer_norm(x)
+        return x
+
+
+def init_weights(m):
+    if isinstance(m, nn.Conv1d):
+        nn.init.trunc_normal_(m.weight, std=0.02)
+        if m.bias is not None:
+            nn.init.constant_(m.bias, 0)
+
+class CodecDecoderVocos(nn.Module):
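+    """Vocos-style decoder: a ResidualFSQ bottleneck, a transformer backbone,
+    and an ISTFT head that renders the waveform at hop_length resolution."""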
+ def __init__(self,
+ hidden_dim=1024,
+ depth=12,
+ heads=16,
+ pos_meb_dim=64,
+ hop_length=320,
+ vq_num_quantizers=1,
+                 vq_dim=2048,  # 1024 or 2048
+ vq_commit_weight=0.25,
+ vq_weight_init=False,
+ vq_full_commit_loss=False,
+ codebook_size=16384,
+ codebook_dim=16,
+ ):
+ super().__init__()
+ self.hop_length = hop_length
+
+        self.quantizer = ResidualFSQ(
+            dim=vq_dim,
+            levels=[4, 4, 4, 4, 4, 4, 4, 4],
+            num_quantizers=1
+        )
+
+ # self.quantizer = ResidualVQ(
+ # num_quantizers=vq_num_quantizers,
+ # dim=vq_dim,
+ # codebook_size=codebook_size,
+ # codebook_dim=codebook_dim,
+ # threshold_ema_dead_code=2,
+ # commitment=vq_commit_weight,
+ # weight_init=vq_weight_init,
+ # full_commit_loss=vq_full_commit_loss,
+ # )
+
+
+        self.backbone = VocosBackbone(hidden_dim=hidden_dim, depth=depth, heads=heads, pos_meb_dim=pos_meb_dim)
+
+ self.head = ISTFTHead(dim=hidden_dim, n_fft=self.hop_length*4, hop_length=self.hop_length, padding="same")
+
+ self.reset_parameters()
+
+ def forward(self, x, vq=True):
+ if vq is True:
+ # x, q, commit_loss = self.quantizer(x)
+ x = x.permute(0, 2, 1)
+ x, q = self.quantizer(x)
+ x = x.permute(0, 2, 1)
+ q = q.permute(0, 2, 1)
+ return x, q, None
+        x = self.backbone(x)
+        x, x_pred = self.head(x)
+
+        return x, x_pred
+
+ def vq2emb(self, vq):
+ self.quantizer = self.quantizer.eval()
+ x = self.quantizer.vq2emb(vq)
+ return x
+
+ def get_emb(self):
+ self.quantizer = self.quantizer.eval()
+ embs = self.quantizer.get_emb()
+ return embs
+
+ def inference_vq(self, vq):
+ x = vq[None,:,:]
+ x = self.model(x)
+ return x
+
+ def inference_0(self, x):
+ x, q, loss, perp = self.quantizer(x)
+ x = self.model(x)
+ return x, None
+
+ def inference(self, x):
+ x = self.model(x)
+ return x, None
+
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m):
+ try:
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+
+ def _apply_weight_norm(m):
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
+ torch.nn.utils.weight_norm(m)
+
+ self.apply(_apply_weight_norm)
+
+ def reset_parameters(self):
+ self.apply(init_weights)
+
+
+
+class CodecDecoderVocos_transpose(nn.Module):
+ def __init__(self,
+ hidden_dim=1024,
+ depth=12,
+ heads=16,
+ pos_meb_dim=64,
+ hop_length=320,
+ vq_num_quantizers=1,
+ vq_dim=1024, #1024 2048
+ vq_commit_weight=0.25,
+ vq_weight_init=False,
+ vq_full_commit_loss=False,
+ codebook_size=16384,
+ codebook_dim=16,
+ ):
+ super().__init__()
+ self.hop_length = hop_length
+
+
+ self.quantizer = ResidualVQ(
+ num_quantizers=vq_num_quantizers,
+ dim=vq_dim,
+ codebook_size=codebook_size,
+ codebook_dim=codebook_dim,
+ threshold_ema_dead_code=2,
+ commitment=vq_commit_weight,
+ weight_init=vq_weight_init,
+ full_commit_loss=vq_full_commit_loss,
+ )
+
+
+        self.backbone = VocosBackbone(hidden_dim=hidden_dim, depth=depth, heads=heads, pos_meb_dim=pos_meb_dim)
+
+ self.inverse_mel_conv = nn.Sequential(
+ nn.GELU(),
+ nn.ConvTranspose1d(
+ in_channels=hidden_dim,
+ out_channels=hidden_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+                output_padding=1  # ensure the output length matches the pre-encoding length
+ ),
+ nn.GELU(),
+ nn.ConvTranspose1d(
+ in_channels=hidden_dim,
+ out_channels=hidden_dim,
+ kernel_size=3,
+ padding=1
+ )
+ )
+
+ self.head = ISTFTHead(dim=hidden_dim, n_fft=self.hop_length*4, hop_length=self.hop_length, padding="same")
+
+ self.reset_parameters()
+
+ def forward(self, x, vq=True):
+ if vq is True:
+ x, q, commit_loss = self.quantizer(x)
+ return x, q, commit_loss
+        x = self.backbone(x)
+        x, x_pred = self.head(x)
+
+        return x, x_pred
+
+ def vq2emb(self, vq):
+ self.quantizer = self.quantizer.eval()
+ x = self.quantizer.vq2emb(vq)
+ return x
+
+ def get_emb(self):
+ self.quantizer = self.quantizer.eval()
+ embs = self.quantizer.get_emb()
+ return embs
+
+ def inference_vq(self, vq):
+ x = vq[None,:,:]
+ x = self.model(x)
+ return x
+
+ def inference_0(self, x):
+ x, q, loss, perp = self.quantizer(x)
+ x = self.model(x)
+ return x, None
+
+ def inference(self, x):
+ x = self.model(x)
+ return x, None
+
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m):
+ try:
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+
+ def _apply_weight_norm(m):
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
+ torch.nn.utils.weight_norm(m)
+
+ self.apply(_apply_weight_norm)
+
+ def reset_parameters(self):
+ self.apply(init_weights)
+
+
+
+
+def main():
+    # Select device
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    print(f"Using device: {device}")
+
+    # Initialize the model
+    model = CodecDecoderVocos_transpose().to(device)
+    print("Model initialized.")
+
+    # Test input: batch_size x sequence_length x hidden_dim
+    # (the backbone transposes to channels-first internally)
+    batch_size = 2
+    in_channels = 1024
+    sequence_length = 50  # example length, adjust as needed
+    dummy_input = torch.randn(batch_size, sequence_length, in_channels).to(device)
+    print(f"Dummy input shape: {dummy_input.shape}")
+
+    # Switch to evaluation mode
+    model.eval()
+
+    # Forward pass (with VQ)
+    # with torch.no_grad():
+    #     output, q, commit_loss = model(dummy_input, vq=True)
+    #     print("Forward pass with VQ:")
+    #     print(f"Output shape: {output.shape}")
+    #     print(f"Quantized codes shape: {q.shape}")
+    #     print(f"Commitment loss: {commit_loss}")
+
+    # Forward pass (without VQ); the model returns (audio, predicted spectrogram)
+    with torch.no_grad():
+        output_no_vq, _ = model(dummy_input, vq=False)
+    print("\nForward pass without VQ:")
+    print(f"Output shape: {output_no_vq.shape}")
+
+    # model_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
+    # model_size_mb = model_size_bytes / (1024 ** 2)
+    # print(f"Model size: {model_size_bytes} bytes ({model_size_mb:.2f} MB)")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/xcodec2/vq/codec_encoder.py b/xcodec2/vq/codec_encoder.py
new file mode 100755
index 0000000000000000000000000000000000000000..717399f8c1c2eff77904a02058a888e550d75cf8
--- /dev/null
+++ b/xcodec2/vq/codec_encoder.py
@@ -0,0 +1,335 @@
+import torch
+from torch import nn
+import numpy as np
+from xcodec2.vq.module import WNConv1d, EncoderBlock, ResLSTM
+from xcodec2.vq.alias_free_torch import *
+from xcodec2.vq import activations
+from xcodec2.vq.bs_roformer5 import TransformerBlock
+
+from torchtune.modules import RotaryPositionalEmbeddings
+import xcodec2.vq.blocks as blocks
+from torch.nn import utils
+
+
+def init_weights(m):
+    if isinstance(m, nn.Conv1d):
+        nn.init.trunc_normal_(m.weight, std=0.02)
+        if m.bias is not None:
+            nn.init.constant_(m.bias, 0)
+
+class CodecEncoder(nn.Module):
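+    """Convolutional encoder: stacked EncoderBlocks that double channels while
+    downsampling by prod(up_ratios) overall, with an optional residual LSTM."""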
+ def __init__(self,
+ ngf=48,
+ use_rnn=True,
+ rnn_bidirectional=False,
+ rnn_num_layers=2,
+ up_ratios=(2, 2, 4, 4, 5),
+ dilations=(1, 3, 9),
+ out_channels=1024):
+ super().__init__()
+ self.hop_length = np.prod(up_ratios)
+ self.ngf = ngf
+ self.up_ratios = up_ratios
+
+ # Create first convolution
+ d_model = ngf
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
+
+ # Create EncoderBlocks that double channels as they downsample by `stride`
+ for i, stride in enumerate(up_ratios):
+ d_model *= 2
+ self.block += [EncoderBlock(d_model, stride=stride, dilations=dilations)]
+ # RNN
+ if use_rnn:
+ self.block += [
+ ResLSTM(d_model,
+ num_layers=rnn_num_layers,
+ bidirectional=rnn_bidirectional
+ )
+ ]
+ # Create last convolution
+ self.block += [
+ Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)),
+ WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
+ ]
+
+        # Wrap block into nn.Sequential
+ self.block = nn.Sequential(*self.block)
+ self.enc_dim = d_model
+
+ self.reset_parameters()
+
+ def forward(self, x):
+ out = self.block(x)
+ return out
+
+ def inference(self, x):
+ return self.block(x)
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m):
+ try:
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+
+ def _apply_weight_norm(m):
+ if isinstance(m, nn.Conv1d):
+ torch.nn.utils.weight_norm(m)
+
+ self.apply(_apply_weight_norm)
+
+ def reset_parameters(self):
+ self.apply(init_weights)
+
+
+class Transpose(nn.Module):
+ def __init__(self, dim1, dim2):
+ super(Transpose, self).__init__()
+ self.dim1 = dim1
+ self.dim2 = dim2
+
+ def forward(self, x):
+ return x.transpose(self.dim1, self.dim2)
+
+class CodecEncoder_Transformer(nn.Module):
+ def __init__(self,
+ ngf=48,
+ up_ratios=[2, 2, 4, 4, 5],
+ dilations=(1, 3, 9),
+ hidden_dim=1024,
+ depth=12,
+ heads=12,
+ pos_meb_dim=64,
+ ):
+ super().__init__()
+ self.hop_length = np.prod(up_ratios)
+        self.ngf = ngf
+ self.up_ratios = up_ratios
+
+ d_model = ngf
+ self.conv_blocks = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
+
+
+ for i, stride in enumerate(up_ratios):
+ d_model *= 2
+ self.conv_blocks += [EncoderBlock(d_model, stride=stride, dilations=dilations)]
+
+ self.conv_blocks = nn.Sequential(*self.conv_blocks)
+
+
+ # time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
+
+
+ # transformer_blocks = [
+ # TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
+ # for _ in range(depth)
+ # ]
+
+
+ # self.transformers = nn.Sequential(*transformer_blocks)
+
+ # self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
+
+ self.conv_final_block = [
+ Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)),
+ WNConv1d(d_model, hidden_dim, kernel_size=3, padding=1),
+ ]
+ self.conv_final_block = nn.Sequential(*self.conv_final_block)
+
+ self.reset_parameters()
+
+ def forward(self, x):
+ x = self.conv_blocks(x)
+ # x = x.permute(0, 2, 1)
+ # x= self.transformers(x)
+ # x = self.final_layer_norm(x)
+ # x = x.permute(0, 2, 1)
+        x = self.conv_final_block(x)
+ x = x.permute(0, 2, 1)
+ return x
+
+    def inference(self, x):
+        return self.forward(x)
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m):
+ try:
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+
+ def _apply_weight_norm(m):
+ if isinstance(m, nn.Conv1d):
+ torch.nn.utils.weight_norm(m)
+
+ self.apply(_apply_weight_norm)
+
+ def reset_parameters(self):
+ self.apply(init_weights)
+
+
+
+class Codec_oobleck_Transformer(nn.Module):
+ def __init__(self,
+ ngf=32,
+                 up_ratios=(2, 2, 4, 4, 5),
+ dilations=(1, 3, 9),
+ hidden_dim=1024,
+ depth=12,
+ heads=16,
+ pos_meb_dim=64,
+ ):
+ super().__init__()
+ self.hop_length = np.prod(up_ratios)
+        self.ngf = ngf
+ self.up_ratios = up_ratios
+ self.hidden_dim = hidden_dim
+
+
+ self.conv_blocks = blocks.DilatedResidualEncoder(
+ capacity=ngf,
+ dilated_unit=self.dilated_unit,
+ downsampling_unit=self.downsampling_unit,
+ ratios=up_ratios,
+ dilations=dilations,
+ pre_network_conv=self.pre_conv,
+ post_network_conv=self.post_conv,
+ )
+
+
+ time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
+
+ transformer_blocks = [
+ TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
+ for _ in range(depth)
+ ]
+
+ self.transformers = nn.Sequential(*transformer_blocks)
+
+ self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
+
+
+ self.reset_parameters()
+
+    def forward(self, x):
+        x = self.conv_blocks(x)
+        x = x.permute(0, 2, 1)
+        x = self.transformers(x)
+        x = self.final_layer_norm(x)
+        return x
+
+    def inference(self, x):
+        return self.forward(x)
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m):
+ try:
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+
+ def _apply_weight_norm(m):
+ if isinstance(m, nn.Conv1d):
+ torch.nn.utils.weight_norm(m)
+
+ self.apply(_apply_weight_norm)
+
+ def reset_parameters(self):
+ self.apply(init_weights)
+
+    def dilated_unit(self, hidden_dim, dilation):
+        return blocks.DilatedConvolutionalUnit(hidden_dim,
+                                               dilation,
+                                               kernel_size=3,
+                                               activation=nn.ReLU,
+                                               normalization=utils.weight_norm)
+
+    def downsampling_unit(self, input_dim: int, output_dim: int, stride: int):
+        return blocks.DownsamplingUnit(input_dim,
+                                       output_dim,
+                                       stride,
+                                       nn.ReLU,
+                                       normalization=utils.weight_norm)
+
+    def pre_conv(self, out_channels):
+        return nn.Conv1d(1, out_channels, 1)
+
+    def post_conv(self, in_channels):
+        return nn.Conv1d(in_channels, self.hidden_dim, 1)
+
+
+
+
+
+class CodecEncoder_only_Transformer(nn.Module):
+    def __init__(self, hidden_dim=1024, depth=12, heads=16, pos_meb_dim=64):
+        super().__init__()
+        # self.embed = nn.Linear(input_dim, hidden_dim)  # input_dim=300
+
+        time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
+
+        transformer_blocks = [
+            TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
+            for _ in range(depth)
+        ]
+
+        self.transformers = nn.Sequential(*transformer_blocks)
+
+        self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x = self.embed(x)
+        x = self.transformers(x)
+        x = self.final_layer_norm(x)
+        return x
+
+
+def get_model_size(model):
+    # Total number of parameters
+    total_params = sum(p.numel() for p in model.parameters())
+
+    # Assume 32-bit floats: 4 bytes per parameter
+    model_size_bytes = total_params * 4
+
+    # Convert to a more readable unit (MB)
+    model_size_mb = model_size_bytes / (1024 ** 2)
+
+    return total_params, model_size_mb
+
+if __name__ == '__main__':
+ model = Codec_oobleck_Transformer()
+ x = torch.randn(1, 1, 16000) # example input tensor
+ output = model(x)
+ print("Output shape:", output.shape)
diff --git a/xcodec2/vq/factorized_vector_quantize.py b/xcodec2/vq/factorized_vector_quantize.py
new file mode 100755
index 0000000000000000000000000000000000000000..35f0c66736112f771a1933ca7e156b8cd5259e66
--- /dev/null
+++ b/xcodec2/vq/factorized_vector_quantize.py
@@ -0,0 +1,109 @@
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+class FactorizedVectorQuantize(nn.Module):
+ def __init__(self, dim, codebook_size, codebook_dim, commitment, **kwargs):
+ super().__init__()
+ self.codebook_size = codebook_size
+ self.codebook_dim = codebook_dim
+ self.commitment = commitment
+
+ if dim != self.codebook_dim:
+ self.in_proj = weight_norm(nn.Linear(dim, self.codebook_dim))
+ self.out_proj = weight_norm(nn.Linear(self.codebook_dim, dim))
+ else:
+ self.in_proj = nn.Identity()
+ self.out_proj = nn.Identity()
+ self._codebook = nn.Embedding(codebook_size, self.codebook_dim)
+
+ @property
+ def codebook(self):
+ return self._codebook
+
+ def forward(self, z):
+ """Quantized the input tensor using a fixed codebook and returns
+ the corresponding codebook vectors
+
+ Parameters
+ ----------
+ z : Tensor[B x D x T]
+
+ Returns
+ -------
+ Tensor[B x D x T]
+ Quantized continuous representation of input
+ Tensor[1]
+ Commitment loss to train encoder to predict vectors closer to codebook
+ entries
+ Tensor[1]
+ Codebook loss to update the codebook
+ Tensor[B x T]
+ Codebook indices (quantized discrete representation of input)
+ Tensor[B x D x T]
+ Projected latents (continuous representation of input before quantization)
+ """
+ # transpose since we use linear
+
+ z = rearrange(z, "b d t -> b t d")
+
+ # Factorized codes project input into low-dimensional space
+ z_e = self.in_proj(z) # z_e : (B x T x D)
+ z_e = rearrange(z_e, "b t d -> b d t")
+ z_q, indices = self.decode_latents(z_e)
+
+
+ if self.training:
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction='none').mean([1, 2]) * self.commitment
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction='none').mean([1, 2])
+ commit_loss = commitment_loss + codebook_loss
+ else:
+            commit_loss = torch.zeros(z.shape[0], device=z.device)
+
+ z_q = (
+ z_e + (z_q - z_e).detach()
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
+
+ z_q = rearrange(z_q, "b d t -> b t d")
+ z_q = self.out_proj(z_q)
+ z_q = rearrange(z_q, "b t d -> b d t")
+
+ return z_q, indices, commit_loss
+
+ def vq2emb(self, vq, proj=True):
+ emb = self.embed_code(vq)
+ if proj:
+ emb = self.out_proj(emb)
+ return emb
+
+ def get_emb(self):
+ return self.codebook.weight
+
+ def embed_code(self, embed_id):
+ return F.embedding(embed_id, self.codebook.weight)
+
+ def decode_code(self, embed_id):
+ return self.embed_code(embed_id).transpose(1, 2)
+
+ def decode_latents(self, latents):
+ encodings = rearrange(latents, "b d t -> (b t) d")
+ codebook = self.codebook.weight # codebook: (N x D)
+
+ # L2 normalize encodings and codebook
+ encodings = F.normalize(encodings)
+ codebook = F.normalize(codebook)
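+        # With unit-norm vectors, minimizing euclidean distance is equivalent
+        # to maximizing cosine similarity, so this is a cosine lookup.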
+
+ # Compute euclidean distance with codebook
+ dist = (
+ encodings.pow(2).sum(1, keepdim=True)
+ - 2 * encodings @ codebook.t()
+ + codebook.pow(2).sum(1, keepdim=True).t()
+ )
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+ z_q = self.decode_code(indices)
+ return z_q, indices
\ No newline at end of file
diff --git a/xcodec2/vq/module.py b/xcodec2/vq/module.py
new file mode 100755
index 0000000000000000000000000000000000000000..0c4f69b351abbc3906ced487f4609ed784c29975
--- /dev/null
+++ b/xcodec2/vq/module.py
@@ -0,0 +1,420 @@
+import torch
+import torch.nn as nn
+from einops import rearrange
+from . import activations
+from .alias_free_torch import *
+
+from typing import Optional, Tuple
+
+from torch.nn.utils import weight_norm, remove_weight_norm
+
+
+def WNConv1d(*args, **kwargs):
+ return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+class ResidualUnit(nn.Module):
+ def __init__(self, dim: int = 16, dilation: int = 1):
+ super().__init__()
+ pad = ((7 - 1) * dilation) // 2
+ self.block = nn.Sequential(
+ Activation1d(activation=activations.SnakeBeta(dim, alpha_logscale=True)),
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
+ Activation1d(activation=activations.SnakeBeta(dim, alpha_logscale=True)),
+ WNConv1d(dim, dim, kernel_size=1),
+ )
+
+ def forward(self, x):
+ return x + self.block(x)
+
+class EncoderBlock(nn.Module):
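+    """Stack of dilated ResidualUnits followed by a strided conv that doubles
+    the channel count (dim // 2 -> dim) while downsampling by `stride`."""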
+ def __init__(self, dim: int = 16, stride: int = 1, dilations = (1, 3, 9)):
+ super().__init__()
+ runits = [ResidualUnit(dim // 2, dilation=d) for d in dilations]
+ self.block = nn.Sequential(
+ *runits,
+ Activation1d(activation=activations.SnakeBeta(dim//2, alpha_logscale=True)),
+ WNConv1d(
+ dim // 2,
+ dim,
+ kernel_size=2 * stride,
+ stride=stride,
+ padding=stride // 2 + stride % 2,
+ ),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+class DecoderBlock(nn.Module):
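+    """Transposed conv that upsamples by `stride` (input_dim -> output_dim),
+    followed by a stack of dilated ResidualUnits."""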
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, dilations = (1, 3, 9)):
+ super().__init__()
+ self.block = nn.Sequential(
+ Activation1d(activation=activations.SnakeBeta(input_dim, alpha_logscale=True)),
+ WNConvTranspose1d(
+ input_dim,
+ output_dim,
+ kernel_size=2 * stride,
+ stride=stride,
+ padding=stride // 2 + stride % 2,
+ output_padding= stride % 2,
+ )
+ )
+ self.block.extend([ResidualUnit(output_dim, dilation=d) for d in dilations])
+
+ def forward(self, x):
+ return self.block(x)
+
+class ResLSTM(nn.Module):
+ def __init__(self, dimension: int,
+ num_layers: int = 2,
+ bidirectional: bool = False,
+ skip: bool = True):
+ super().__init__()
+ self.skip = skip
+ self.lstm = nn.LSTM(dimension, dimension if not bidirectional else dimension // 2,
+ num_layers, batch_first=True,
+ bidirectional=bidirectional)
+
+ def forward(self, x):
+ """
+ Args:
+ x: [B, F, T]
+
+ Returns:
+ y: [B, F, T]
+ """
+ x = rearrange(x, "b f t -> b t f")
+ y, _ = self.lstm(x)
+ if self.skip:
+ y = y + x
+ y = rearrange(y, "b t f -> b f t")
+ return y
+
+
+
+class ConvNeXtBlock(nn.Module):
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
+
+ Args:
+ dim (int): Number of input channels.
+ intermediate_dim (int): Dimensionality of the intermediate layer.
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+ Defaults to None.
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+ None means non-conditional LayerNorm. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ intermediate_dim: int,
+        layer_scale_init_value: Optional[float] = None,
+ adanorm_num_embeddings: Optional[int] = None,
+ ):
+ super().__init__()
+ self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
+ self.adanorm = adanorm_num_embeddings is not None
+ if adanorm_num_embeddings:
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+ else:
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+            if layer_scale_init_value is not None and layer_scale_init_value > 0
+ else None
+ )
+
+ def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+ residual = x
+ x = self.dwconv(x)
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
+ if self.adanorm:
+ assert cond_embedding_id is not None
+ x = self.norm(x, cond_embedding_id)
+ else:
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
+
+ x = residual + x
+ return x
+
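+# Usage sketch (illustrative; the dims here are made up, not values used elsewhere
+# in this diff):
+#   block = ConvNeXtBlock(dim=512, intermediate_dim=1536, layer_scale_init_value=1e-6)
+#   y = block(torch.randn(2, 512, 100))  # (B, C, T) -> (B, C, T), shape preserved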
+
+class AdaLayerNorm(nn.Module):
+ """
+ Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
+
+ Args:
+ num_embeddings (int): Number of embeddings.
+ embedding_dim (int): Dimension of the embeddings.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+ self.dim = embedding_dim
+ self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
+ self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
+ torch.nn.init.ones_(self.scale.weight)
+ torch.nn.init.zeros_(self.shift.weight)
+
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
+ scale = self.scale(cond_embedding_id)
+ shift = self.shift(cond_embedding_id)
+ x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
+ x = x * scale + shift
+ return x
+
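+# Usage sketch (illustrative; a scalar condition id broadcasts over batch and time):
+#   norm = AdaLayerNorm(num_embeddings=4, embedding_dim=512)
+#   y = norm(torch.randn(2, 100, 512), torch.tensor(0))  # (B, T, C) -> (B, T, C)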
+
+class ResBlock1(nn.Module):
+ """
+ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
+ but without upsampling layers.
+
+ Args:
+ dim (int): Number of input channels.
+ kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
+ dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
+ Defaults to (1, 3, 5).
+ lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
+ Defaults to 0.1.
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+ Defaults to None.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ kernel_size: int = 3,
+ dilation: Tuple[int, int, int] = (1, 3, 5),
+ lrelu_slope: float = 0.1,
+ layer_scale_init_value: Optional[float] = None,
+ ):
+ super().__init__()
+ self.lrelu_slope = lrelu_slope
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=self.get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=self.get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=self.get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
+ weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
+ weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
+ ]
+ )
+
+ self.gamma = nn.ParameterList(
+ [
+ nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
+ if layer_scale_init_value is not None
+ else None,
+ nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
+ if layer_scale_init_value is not None
+ else None,
+ nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
+ if layer_scale_init_value is not None
+ else None,
+ ]
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
+ xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
+ xt = c1(xt)
+ xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
+ xt = c2(xt)
+ if gamma is not None:
+ xt = gamma * xt
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+ @staticmethod
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
+ return int((kernel_size * dilation - dilation) / 2)
+
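+# Padding check (illustrative): get_padding(3, 1) = 1 and get_padding(3, 3) = 3, so
+# each dilated convolution in ResBlock1 preserves the sequence length.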
+
+def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
+ """
+ Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
+
+ Args:
+ x (Tensor): Input tensor.
+ clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
+
+ Returns:
+ Tensor: Element-wise logarithm of the input tensor with clipping applied.
+ """
+ return torch.log(torch.clip(x, min=clip_val))
+
+
+def symlog(x: torch.Tensor) -> torch.Tensor:
+ return torch.sign(x) * torch.log1p(x.abs())
+
+
+def symexp(x: torch.Tensor) -> torch.Tensor:
+ return torch.sign(x) * (torch.exp(x.abs()) - 1)
+
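+# Note: symexp inverts symlog elementwise; e.g. symexp(symlog(torch.tensor([-3., 0., 3.])))
+# recovers the input up to floating-point error.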
+
+
+class SemanticEncoder(nn.Module):
+ def __init__(
+ self,
+ input_channels: int,
+ code_dim: int,
+ encode_channels: int,
+ kernel_size: int = 3,
+ bias: bool = True,
+ ):
+ super(SemanticEncoder, self).__init__()
+
+        # Initial convolution: map input_channels to encode_channels
+ self.initial_conv = nn.Conv1d(
+ in_channels=input_channels,
+ out_channels=encode_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ bias=False
+ )
+
+        # Residual blocks
+ self.residual_blocks = nn.Sequential(
+ nn.ReLU(inplace=True),
+ nn.Conv1d(
+ encode_channels,
+ encode_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ bias=bias
+ ),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(
+ encode_channels,
+ encode_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ bias=bias
+ )
+ )
+
+        # Final convolution: map encode_channels to code_dim
+ self.final_conv = nn.Conv1d(
+ in_channels=encode_channels,
+ out_channels=code_dim,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ bias=False
+ )
+
+    def forward(self, x):
+        """
+        Forward pass.
+
+        Args:
+            x (Tensor): Input tensor of shape (Batch, Input_channels, Length)
+
+        Returns:
+            Tensor: Encoded tensor of shape (Batch, Code_dim, Length)
+        """
+        x = self.initial_conv(x)         # (Batch, Encode_channels, Length)
+        x = self.residual_blocks(x) + x  # residual connection
+        x = self.final_conv(x)           # (Batch, Code_dim, Length)
+        return x
+
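+# Usage sketch (illustrative dims, not values used elsewhere in this diff):
+#   enc = SemanticEncoder(input_channels=1024, code_dim=16, encode_channels=256)
+#   codes = enc(torch.randn(2, 1024, 50))  # (B, input_channels, T) -> (B, code_dim, T)
+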
+class SemanticDecoder(nn.Module):
+ def __init__(
+ self,
+ code_dim: int,
+ output_channels: int,
+ decode_channels: int,
+ kernel_size: int = 3,
+ bias: bool = True,
+ ):
+ super(SemanticDecoder, self).__init__()
+
+ # Initial convolution to map code_dim to decode_channels
+ self.initial_conv = nn.Conv1d(
+ in_channels=code_dim,
+ out_channels=decode_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ bias=False
+ )
+
+ # Residual Blocks
+ self.residual_blocks = nn.Sequential(
+ nn.ReLU(inplace=True),
+ nn.Conv1d(decode_channels, decode_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(decode_channels, decode_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias)
+ )
+
+ # Final convolution to map decode_channels to output_channels
+ self.final_conv = nn.Conv1d(
+ in_channels=decode_channels,
+ out_channels=output_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ bias=False
+ )
+
+ def forward(self, z):
+ # z: (Batch, Code_dim, Length)
+ x = self.initial_conv(z) # (Batch, Decode_channels, Length)
+ x = self.residual_blocks(x) + x # Residual connection
+ x = self.final_conv(x) # (Batch, Output_channels, Length)
+ return x
\ No newline at end of file
diff --git a/xcodec2/vq/residual_vq.py b/xcodec2/vq/residual_vq.py
new file mode 100755
index 0000000000000000000000000000000000000000..40d3338fd940aa2e41177827d6e24c5269765b86
--- /dev/null
+++ b/xcodec2/vq/residual_vq.py
@@ -0,0 +1,53 @@
+import torch
+from torch import nn
+from .factorized_vector_quantize import FactorizedVectorQuantize
+
+class ResidualVQ(nn.Module):
+ def __init__(
+ self,
+ *,
+ num_quantizers,
+ codebook_size,
+ **kwargs
+ ):
+ super().__init__()
+ VQ = FactorizedVectorQuantize
+        if isinstance(codebook_size, int):
+ codebook_size = [codebook_size] * num_quantizers
+ self.layers = nn.ModuleList([VQ(codebook_size=size, **kwargs) for size in codebook_size])
+ self.num_quantizers = num_quantizers
+
+ def forward(self, x):
+ quantized_out = 0.
+ residual = x
+
+ all_losses = []
+ all_indices = []
+
+ for idx, layer in enumerate(self.layers):
+ quantized, indices, loss = layer(residual)
+
+ residual = residual - quantized
+
+ quantized_out = quantized_out + quantized
+
+ loss = loss.mean()
+
+ all_indices.append(indices)
+ all_losses.append(loss)
+ all_losses, all_indices = map(torch.stack, (all_losses, all_indices))
+ return quantized_out, all_indices, all_losses
+
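+    # Shape note for forward() (illustrative): for x of shape (B, D, T), it returns
+    # quantized_out (B, D, T), all_indices (num_quantizers, B, T) and all_losses
+    # (num_quantizers,). Each layer quantizes the residual left by the previous
+    # layers, so the running sum of layer outputs successively approximates x.
+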
+ def vq2emb(self, vq, proj=True):
+ # [B, T, num_quantizers]
+ quantized_out = 0.
+ for idx, layer in enumerate(self.layers):
+ quantized = layer.vq2emb(vq[:, :, idx], proj=proj)
+ quantized_out = quantized_out + quantized
+        return quantized_out
+
+    def get_emb(self):
+ embs = []
+ for idx, layer in enumerate(self.layers):
+ embs.append(layer.get_emb())
+ return embs
diff --git a/xcodec2/vq/unet.py b/xcodec2/vq/unet.py
new file mode 100755
index 0000000000000000000000000000000000000000..ca31029d0866b61663f75045c7770cc7208d9482
--- /dev/null
+++ b/xcodec2/vq/unet.py
@@ -0,0 +1,210 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+import numpy as np
+
+
+class EncoderBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
+ super(EncoderBlock, self).__init__()
+
+ self.pool_size = 2
+
+ self.conv_block = ConvBlock(in_channels, out_channels, kernel_size)
+
+ def forward(self, x):
+ latent = self.conv_block(x)
+ output = F.avg_pool2d(latent, kernel_size=self.pool_size)
+ return output, latent
+
+class DecoderBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
+ super(DecoderBlock, self).__init__()
+
+ stride = 2
+
+ self.upsample = nn.ConvTranspose2d(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=stride,
+ stride=stride,
+ padding=(0, 0),
+ bias=False,
+ )
+
+ self.conv_block = ConvBlock(in_channels * 2, out_channels, kernel_size)
+
+ def forward(self, x, latent):
+ x = self.upsample(x)
+ x = torch.cat((x, latent), dim=1)
+ output = self.conv_block(x)
+ return output
+
+
+class UNet(nn.Module):
+    def __init__(self, freq_dim=1281, out_channel=1024):
+        super(UNet, self).__init__()
+
+        self.downsample_ratio = 16
+
+        in_channels = 1
+
+ self.encoder_block1 = EncoderBlock(in_channels, 16)
+ self.encoder_block2 = EncoderBlock(16, 64)
+ self.encoder_block3 = EncoderBlock(64, 256)
+ self.encoder_block4 = EncoderBlock(256, 1024)
+ self.middle = EncoderBlock(1024, 1024)
+ self.decoder_block1 = DecoderBlock(1024, 256)
+ self.decoder_block2 = DecoderBlock(256, 64)
+ self.decoder_block3 = DecoderBlock(64, 16)
+ self.decoder_block4 = DecoderBlock(16, 16)
+
+ self.fc = nn.Linear(freq_dim*16, out_channel)
+
+    def forward(self, x_ori):
+        """
+        Args:
+            x_ori: (batch_size, channels_num, time_steps, freq_bins), real-valued spectrogram features
+
+        Returns:
+            output: (batch_size, time_steps, out_channel)
+        """
+        x = self.process_image(x_ori)
+        x1, latent1 = self.encoder_block1(x)
+        x2, latent2 = self.encoder_block2(x1)
+        x3, latent3 = self.encoder_block3(x2)
+        x4, latent4 = self.encoder_block4(x3)
+        _, h = self.middle(x4)
+        x5 = self.decoder_block1(h, latent4)
+        x6 = self.decoder_block2(x5, latent3)
+        x7 = self.decoder_block3(x6, latent2)
+        x8 = self.decoder_block4(x7, latent1)
+        x = self.unprocess_image(x8, x_ori.shape[2])
+        x = x.permute(0, 2, 1, 3).contiguous()  # (B, C, T, F) -> (B, T, C, F)
+        x = x.view(x.size(0), x.size(1), -1)    # flatten channels and frequency into one axis
+        x = self.fc(x)                          # (B, T, out_channel)
+
+        return x
+
+    def process_image(self, x):
+        """
+        Pad the spectrogram so the time axis is divisible by downsample_ratio,
+        and drop the last frequency bin so the frequency axis pools cleanly.
+
+        Args:
+            x: (B, C, T, F)
+
+        Returns:
+            output: (B, C, T_padded, F - 1)
+        """
+        B, C, T, Freq = x.shape
+
+        pad_len = (
+            int(np.ceil(T / self.downsample_ratio)) * self.downsample_ratio
+            - T
+        )
+        x = F.pad(x, pad=(0, 0, 0, pad_len))
+
+        output = x[:, :, :, 0 : Freq - 1]
+
+        return output
+
+    def unprocess_image(self, x, time_steps):
+        """
+        Restore the spectrogram to its original shape: pad the dropped
+        frequency bin back and trim the time axis to its original length.
+
+        Args:
+            x: (B, C, T_padded, F - 1)
+
+        Returns:
+            output: (B, C, T_original, F)
+        """
+        x = F.pad(x, pad=(0, 1))
+
+        output = x[:, :, 0:time_steps, :]
+
+        return output
+
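+# Worked example (illustrative): with downsample_ratio=16 and an input of shape
+# (B, 1, 250, 1281), process_image pads the time axis to 256 and slices the
+# frequency axis to 1280; after the U-Net, unprocess_image(x, 250) pads the
+# frequency axis back to 1281 and trims the time axis back to 250.
+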
+class ConvBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
+ super(ConvBlock, self).__init__()
+
+ padding = [kernel_size[0] // 2, kernel_size[1] // 2]
+
+ self.bn1 = nn.BatchNorm2d(in_channels)
+ self.bn2 = nn.BatchNorm2d(out_channels)
+
+ self.conv1 = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ padding=padding,
+ bias=False,
+ )
+
+ self.conv2 = nn.Conv2d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ padding=padding,
+ bias=False,
+ )
+
+ if in_channels != out_channels:
+ self.shortcut = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(1, 1),
+ padding=(0, 0),
+ )
+ self.is_shortcut = True
+ else:
+ self.is_shortcut = False
+
+ def forward(self, x):
+ h = self.conv1(F.leaky_relu_(self.bn1(x)))
+ h = self.conv2(F.leaky_relu_(self.bn2(h)))
+
+ if self.is_shortcut:
+ return self.shortcut(x) + h
+ else:
+ return x + h
+
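+# Design note: ConvBlock uses pre-activation ordering (BatchNorm -> LeakyReLU -> Conv)
+# in its residual branch, with a 1x1 shortcut only when the channel count changes.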
+
+def test_unet():
+    # Input parameters; freq_bins must match the model's freq_dim (default 1281)
+    # so the frequency axis pools cleanly and the fc layer's input size matches.
+    batch_size = 6
+    channels = 1      # number of audio channels
+    time_steps = 256  # number of time frames (a multiple of downsample_ratio)
+    freq_bins = 1281  # number of frequency bins
+
+    # Create a random real-valued spectrogram tensor as input
+    x = torch.randn(batch_size, channels, time_steps, freq_bins)
+
+    # Instantiate the UNet model
+    model = UNet()
+
+    # Forward pass
+    output = model(x)
+
+    # Print the input and output shapes
+    print("Input shape:", x.shape)
+    print("Output shape:", output.shape)
+
+    # The fc head flattens channels and frequency, so the output is (B, T, out_channel)
+    assert output.shape == (batch_size, time_steps, 1024), "Unexpected output shape"
+
+    print("Test passed; the model works as expected.")
+
+# Run the test
+if __name__ == "__main__":
+    test_unet()
\ No newline at end of file