#!/usr/bin/env python3
"""CMI-RM music scoring app — Gradio web UI.

Upload an audio file, enter a text description, and get AI music quality
scores. Can be deployed to HuggingFace Spaces (free CPU tier).
"""

import logging
import os
import sys
from pathlib import Path
from typing import Iterator, Optional, Tuple

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

# ---------------------------------------------------------------------------
# Path setup
# ---------------------------------------------------------------------------
_ROOT = Path(__file__).resolve().parent
_MODEL_SRC = _ROOT / "models" / "cmi-rm" / "src"

# Extend sys.path ahead of the project-local import that follows.
if str(_MODEL_SRC) not in sys.path:
    sys.path.insert(0, str(_MODEL_SRC))
if str(_ROOT) not in sys.path:
    sys.path.insert(0, str(_ROOT))

from baselines.inference import RewardModelInference  # noqa: E402

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Model download and loading
# ---------------------------------------------------------------------------
MODEL_REPO = "HaiwenXia/CMI-RM"
MODEL_DIR = _ROOT / "baselines" / "model"

# Value forwarded as ``max_dur`` to every ``model.score`` call
# (presumably seconds of audio — confirm against RewardModelInference).
_MAX_DUR = 30.0


def download_model() -> str:
    """Download the model files from the HuggingFace Hub if missing locally.

    Returns:
        Path (as a string) of the checkpoint file inside ``MODEL_DIR``.
    """
    ckpt = MODEL_DIR / "model.safetensors"
    cfg = MODEL_DIR / "config.yaml"
    if ckpt.exists() and cfg.exists():
        logger.info("模型文件已存在,跳过下载")
        return str(ckpt)

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    logger.info("正在从 %s 下载模型...", MODEL_REPO)
    for filename in (ckpt.name, cfg.name):
        hf_hub_download(
            repo_id=MODEL_REPO,
            filename=filename,
            local_dir=str(MODEL_DIR),
        )
    logger.info("模型下载完成")
    return str(ckpt)


def load_model() -> RewardModelInference:
    """Fetch the checkpoint if needed and build the CPU inference wrapper."""
    ckpt_path = download_model()
    logger.info("正在加载模型(CPU 模式)...")
    model = RewardModelInference(
        checkpoint=ckpt_path,
        device="cpu",
        mode="final",
        bf16=False,
    )
    logger.info("模型加载完成")
    return model


# Global model instance: created on the first ``get_model()`` call and reused.
MODEL = None


def get_model() -> RewardModelInference:
    """Return the shared model instance, loading it on first use."""
    global MODEL
    if MODEL is None:
        MODEL = load_model()
    return MODEL


# ---------------------------------------------------------------------------
# Inference functions
# ---------------------------------------------------------------------------
def _clean_text_inputs(
    text_prompt: Optional[str], lyrics: Optional[str]
) -> Tuple[str, str]:
    """Validate the prompt and return ``(prompt, lyrics)``, both stripped.

    ``None`` is treated like an empty string for both fields, so a missing
    prompt produces the user-facing error instead of an AttributeError.

    Raises:
        gr.Error: If the prompt is missing or contains only whitespace.
    """
    prompt = (text_prompt or "").strip()
    if not prompt:
        raise gr.Error("请输入文字描述(Prompt)")
    return prompt, (lyrics or "").strip()


def _winner(a: float, b: float) -> str:
    """Return the label for whichever score is higher, or the tie label."""
    if a > b:
        return "A 胜出 ✓"
    if b > a:
        return "B 胜出 ✓"
    return "平局"


def score_audio(
    audio_path: Optional[str],
    text_prompt: Optional[str],
    lyrics: Optional[str],
) -> Iterator[str]:
    """Score one audio file (generator: yields a loading message, then results).

    Args:
        audio_path: File path of the uploaded audio, or ``None`` if absent.
        text_prompt: Text description the audio is scored against.
        lyrics: Optional lyrics; empty or ``None`` is passed on as ``""``.

    Yields:
        Markdown strings: first a progress placeholder, then the score table.

    Raises:
        gr.Error: If no audio was uploaded or the prompt is empty.
    """
    if audio_path is None:
        raise gr.Error("请上传一个音频文件")
    prompt, lyrics = _clean_text_inputs(text_prompt, lyrics)

    # Step 1: show the loading state before any slow work starts.
    yield (
        "⏳ **正在分析音频...**\n\n"
        "模型正在聆听你的音乐,预计需要 30-60 秒,请稍候。\n\n"
        "```\n"
        "[ ░░░░░░░░░░░░░░░░░░░░ ] 分析中...\n"
        "```"
    )

    model = get_model()
    logger.info("开始评分: %s | prompt=%s", audio_path, text_prompt[:50])

    scores = model.score(
        audio=audio_path,
        text=prompt,
        lyrics=lyrics,
        max_dur=_MAX_DUR,
    )

    alignment = scores["alignment"]
    quality = scores["quality"]
    logger.info("评分结果: alignment=%.4f, quality=%.4f", alignment, quality)

    # Step 2: show the results.
    yield (
        f"## 评分结果\n\n"
        f"| 指标 | 分数 |\n"
        f"|------|------|\n"
        f"| **Alignment(文本对齐度)** | {alignment:.4f} |\n"
        f"| **Quality(音乐质量)** | {quality:.4f} |\n"
    )


def score_compare(
    audio_a: Optional[str],
    audio_b: Optional[str],
    text_prompt: Optional[str],
    lyrics: Optional[str],
) -> Iterator[str]:
    """Score two audio files against one prompt and compare them (generator).

    Args:
        audio_a: File path of the first audio, or ``None`` if absent.
        audio_b: File path of the second audio, or ``None`` if absent.
        text_prompt: Text description both audios are scored against.
        lyrics: Optional lyrics; empty or ``None`` is passed on as ``""``.

    Yields:
        Markdown strings: a progress message before each scoring pass, then
        the comparison table.

    Raises:
        gr.Error: If either audio is missing or the prompt is empty.
    """
    if audio_a is None or audio_b is None:
        raise gr.Error("请上传两个音频文件进行对比")
    prompt, lyrics = _clean_text_inputs(text_prompt, lyrics)

    # Step 1: score A. The status is yielded before the model is fetched so
    # that a first-use model load is not a silent wait.
    yield (
        "⏳ **正在分析音频 A...**\n\n"
        "模型正在聆听第 1 首音乐,预计共需 1-2 分钟。\n\n"
        "```\n"
        "[ ██████████░░░░░░░░░░ ] 1/2 分析中...\n"
        "```"
    )
    model = get_model()
    scores_a = model.score(
        audio=audio_a, text=prompt, lyrics=lyrics, max_dur=_MAX_DUR
    )

    # Step 2: score B.
    yield (
        "⏳ **正在分析音频 B...**\n\n"
        "第 1 首已完成,正在聆听第 2 首...\n\n"
        "```\n"
        "[ ████████████████████░] 2/2 分析中...\n"
        "```"
    )
    scores_b = model.score(
        audio=audio_b, text=prompt, lyrics=lyrics, max_dur=_MAX_DUR
    )

    align_a, align_b = scores_a["alignment"], scores_b["alignment"]
    qual_a, qual_b = scores_a["quality"], scores_b["quality"]

    # Step 3: show the comparison.
    yield (
        f"## 对比结果\n\n"
        f"| 指标 | 音频 A | 音频 B | 胜出 |\n"
        f"|------|--------|--------|------|\n"
        f"| **Alignment** | {align_a:.4f} | {align_b:.4f} | {_winner(align_a, align_b)} |\n"
        f"| **Quality** | {qual_a:.4f} | {qual_b:.4f} | {_winner(qual_a, qual_b)} |\n"
    )


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(title="CMI-RM 音乐评分", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# 🎵 CMI-RM 音乐评分系统\n"
        "上传 AI 生成的音乐,输入生成时的文字描述(Prompt),获取质量评分。\n"
        "- **Alignment**:音乐与文字描述的匹配程度\n"
        "- **Quality**:音乐整体质量\n"
    )

    with gr.Tabs():
        # ---- Tab 1: single-track scoring ----
        with gr.TabItem("单曲评分"):
            with gr.Row():
                with gr.Column(scale=1):
                    audio_input = gr.Audio(
                        label="上传音频",
                        type="filepath",
                        sources=["upload"],
                    )
                    text_input = gr.Textbox(
                        label="文字描述(Prompt)",
                        placeholder="例如:A cheerful acoustic folk song with strumming guitar and whistling.",
                        lines=2,
                    )
                    lyrics_input = gr.Textbox(
                        label="歌词(可选)",
                        placeholder="如果有歌词,粘贴到这里",
                        lines=3,
                    )
                    score_btn = gr.Button("开始评分", variant="primary", size="lg")

                with gr.Column(scale=1):
                    result_output = gr.Markdown(label="评分结果")

            score_btn.click(
                fn=score_audio,
                inputs=[audio_input, text_input, lyrics_input],
                outputs=result_output,
            )

        # ---- Tab 2: A/B comparison ----
        with gr.TabItem("A/B 对比"):
            with gr.Row():
                with gr.Column(scale=1):
                    audio_a = gr.Audio(label="音频 A", type="filepath", sources=["upload"])
                    audio_b = gr.Audio(label="音频 B", type="filepath", sources=["upload"])
                    text_compare = gr.Textbox(
                        label="文字描述(Prompt)",
                        placeholder="两个音频使用相同的 Prompt",
                        lines=2,
                    )
                    lyrics_compare = gr.Textbox(
                        label="歌词(可选)",
                        placeholder="如果有歌词,粘贴到这里",
                        lines=3,
                    )
                    compare_btn = gr.Button("开始对比", variant="primary", size="lg")

                with gr.Column(scale=1):
                    compare_output = gr.Markdown(label="对比结果")

            compare_btn.click(
                fn=score_compare,
                inputs=[audio_a, audio_b, text_compare, lyrics_compare],
                outputs=compare_output,
            )

    gr.Markdown(
        "---\n"
        "*基于 [CMI-RewardBench](https://arxiv.org/abs/2603.00610) 项目 | "
        "CPU 推理约 30-60 秒/首*"
    )


if __name__ == "__main__":
    # Preload the model at startup so the first request does not pay for it.
    get_model()
    demo.launch(server_name="0.0.0.0", server_port=7860)