Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """CMI-RM 音乐评分应用 — Gradio Web UI | |
| 上传音频文件,输入文字描述,获取 AI 音乐质量评分。 | |
| 可部署到 HuggingFace Spaces(免费 CPU 版)。 | |
| """ | |
| import logging | |
| import os | |
| import sys | |
| from pathlib import Path | |
| import gradio as gr | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
# ---------------------------------------------------------------------------
# Path setup
# ---------------------------------------------------------------------------
# Directory containing this file, and the model source tree below it.
_ROOT = Path(__file__).resolve().parent
_MODEL_SRC = _ROOT.joinpath("models", "cmi-rm", "src")

# Put both directories on the import search path unless already present.
# _MODEL_SRC is handled first so that _ROOT, inserted afterwards at index 0,
# ends up ahead of it.
for _search_dir in (str(_MODEL_SRC), str(_ROOT)):
    if _search_dir not in sys.path:
        sys.path.insert(0, _search_dir)
del _search_dir
| from baselines.inference import RewardModelInference | |
# Configure root logging at import time: INFO level, with timestamp, level
# name and message separated by " | ".
_LOG_FORMAT = "%(asctime)s | %(levelname)s | %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Model download and loading
# ---------------------------------------------------------------------------
# Hugging Face Hub repository the model files are fetched from.
MODEL_REPO = "HaiwenXia/CMI-RM"
# Local directory the model files are stored in.
MODEL_DIR = _ROOT.joinpath("baselines", "model")
def download_model():
    """Make sure the model files exist locally and return the checkpoint path.

    Looks for ``model.safetensors`` and ``config.yaml`` under ``MODEL_DIR``.
    If both are present nothing is fetched; otherwise ``MODEL_DIR`` is created
    and both files are requested from ``MODEL_REPO`` through
    ``hf_hub_download`` with ``MODEL_DIR`` as the ``local_dir``.

    Returns:
        ``MODEL_DIR / "model.safetensors"`` as a string.
    """
    # Checkpoint first, config second: same order for the check and the fetch.
    required_files = ("model.safetensors", "config.yaml")
    checkpoint_path = str(MODEL_DIR / required_files[0])

    if all((MODEL_DIR / filename).exists() for filename in required_files):
        logger.info("模型文件已存在,跳过下载")
        return checkpoint_path

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    logger.info("正在从 %s 下载模型...", MODEL_REPO)
    for filename in required_files:
        hf_hub_download(MODEL_REPO, filename, local_dir=str(MODEL_DIR))
    logger.info("模型下载完成")
    return checkpoint_path
def load_model():
    """Construct the reward-model inference object from the local checkpoint.

    Calls ``download_model()`` to obtain the checkpoint path, then builds a
    ``RewardModelInference`` with ``device="cpu"``, ``mode="final"`` and
    ``bf16=False``.

    Returns:
        The constructed ``RewardModelInference`` instance.
    """
    checkpoint = download_model()
    logger.info("正在加载模型(CPU 模式)...")
    inference_options = {
        "checkpoint": checkpoint,
        "device": "cpu",
        "mode": "final",
        "bf16": False,
    }
    reward_model = RewardModelInference(**inference_options)
    logger.info("模型加载完成")
    return reward_model
# Process-wide model instance; stays None until get_model() first loads it.
MODEL = None


def get_model():
    """Return the shared model instance, loading it on the first call.

    Returns:
        The object stored in the module-level ``MODEL``; when that is still
        ``None`` it is first filled with the result of ``load_model()``.
    """
    global MODEL
    if MODEL is not None:
        return MODEL
    MODEL = load_model()
    return MODEL
# ---------------------------------------------------------------------------
# Inference functions
# ---------------------------------------------------------------------------
def score_audio(audio_path, text_prompt, lyrics):
    """Score a single audio file against a prompt, streaming status updates.

    This is a generator: it first yields a "loading" Markdown message and,
    once scoring has finished, a Markdown table with the two scores.

    Args:
        audio_path: Audio file reference forwarded to ``model.score``;
            ``None`` is treated as "nothing uploaded".
        text_prompt: Text description of the music. ``None``, empty and
            whitespace-only values are rejected.
        lyrics: Optional lyrics; ``None`` or empty becomes ``""``.

    Yields:
        Markdown strings: the progress notice, then the result table.

    Raises:
        gr.Error: If no audio was supplied or the prompt is missing/blank.
    """
    if audio_path is None:
        raise gr.Error("请上传一个音频文件")
    # Normalise the prompt once. The value may arrive as None (the lyrics
    # argument is already guarded the same way); that must surface as the
    # user-facing validation error rather than an AttributeError on .strip().
    prompt = (text_prompt or "").strip()
    if not prompt:
        raise gr.Error("请输入文字描述(Prompt)")

    # Step 1: show the loading state before any slow work starts.
    yield (
        "⏳ **正在分析音频...**\n\n"
        "模型正在聆听你的音乐,预计需要 30-60 秒,请稍候。\n\n"
        "```\n"
        "[ ░░░░░░░░░░░░░░░░░░░░ ] 分析中...\n"
        "```"
    )

    model = get_model()
    lyrics = lyrics.strip() if lyrics else ""
    logger.info("开始评分: %s | prompt=%s", audio_path, text_prompt[:50])
    # NOTE(review): max_dur=30.0 presumably limits how much audio is analysed;
    # confirm against RewardModelInference.score.
    scores = model.score(
        audio=audio_path,
        text=prompt,
        lyrics=lyrics,
        max_dur=30.0,
    )
    alignment = scores["alignment"]
    quality = scores["quality"]
    logger.info("评分结果: alignment=%.4f, quality=%.4f", alignment, quality)

    # Step 2: show the result table.
    yield (
        f"## 评分结果\n\n"
        f"| 指标 | 分数 |\n"
        f"|------|------|\n"
        f"| **Alignment(文本对齐度)** | {alignment:.4f} |\n"
        f"| **Quality(音乐质量)** | {quality:.4f} |\n"
    )
def score_compare(audio_a, audio_b, text_prompt, lyrics):
    """Score two audio files against the same prompt and compare them.

    This is a generator: it yields a progress message before each of the two
    scoring passes and finally a Markdown comparison table with a per-metric
    winner.

    Args:
        audio_a: First audio file reference forwarded to ``model.score``.
        audio_b: Second audio file reference forwarded to ``model.score``.
        text_prompt: Text description shared by both files. ``None``, empty
            and whitespace-only values are rejected.
        lyrics: Optional lyrics; ``None`` or empty becomes ``""``.

    Yields:
        Markdown strings: progress for A, progress for B, then the table.

    Raises:
        gr.Error: If either audio is missing or the prompt is missing/blank.
    """
    if audio_a is None or audio_b is None:
        raise gr.Error("请上传两个音频文件进行对比")
    # Normalise the prompt once. The value may arrive as None (the lyrics
    # argument is already guarded the same way); that must surface as the
    # user-facing validation error rather than an AttributeError on .strip().
    prompt = (text_prompt or "").strip()
    if not prompt:
        raise gr.Error("请输入文字描述(Prompt)")
    lyrics = lyrics.strip() if lyrics else ""

    def winner(a, b):
        """Return the label for the higher of two scores, or a tie label."""
        if a > b:
            return "A 胜出 ✓"
        elif b > a:
            return "B 胜出 ✓"
        return "平局"

    # Step 1: score A. The status is yielded before get_model() so the
    # progress message is visible even if the model still has to be loaded,
    # matching the order used by score_audio.
    yield (
        "⏳ **正在分析音频 A...**\n\n"
        "模型正在聆听第 1 首音乐,预计共需 1-2 分钟。\n\n"
        "```\n"
        "[ ██████████░░░░░░░░░░ ] 1/2 分析中...\n"
        "```"
    )
    model = get_model()
    scores_a = model.score(audio=audio_a, text=prompt, lyrics=lyrics, max_dur=30.0)

    # Step 2: score B.
    yield (
        "⏳ **正在分析音频 B...**\n\n"
        "第 1 首已完成,正在聆听第 2 首...\n\n"
        "```\n"
        "[ ████████████████████░] 2/2 分析中...\n"
        "```"
    )
    scores_b = model.score(audio=audio_b, text=prompt, lyrics=lyrics, max_dur=30.0)

    # Step 3: show the comparison table.
    yield (
        f"## 对比结果\n\n"
        f"| 指标 | 音频 A | 音频 B | 胜出 |\n"
        f"|------|--------|--------|------|\n"
        f"| **Alignment** | {scores_a['alignment']:.4f} | {scores_b['alignment']:.4f} | {winner(scores_a['alignment'], scores_b['alignment'])} |\n"
        f"| **Quality** | {scores_a['quality']:.4f} | {scores_b['quality']:.4f} | {winner(scores_a['quality'], scores_b['quality'])} |\n"
    )
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# NOTE(review): the nesting below was reconstructed from a listing that had
# lost its indentation; confirm the placement of the click handlers and of
# the footer (assumed to sit below the tabs, not inside the second tab).
def _upload_audio(label):
    """Create an upload-only audio input configured with type="filepath"."""
    return gr.Audio(label=label, type="filepath", sources=["upload"])


def _prompt_box(placeholder):
    """Create the two-line prompt textbox with the given placeholder."""
    return gr.Textbox(
        label="文字描述(Prompt)",
        placeholder=placeholder,
        lines=2,
    )


def _lyrics_box():
    """Create the three-line optional lyrics textbox."""
    return gr.Textbox(
        label="歌词(可选)",
        placeholder="如果有歌词,粘贴到这里",
        lines=3,
    )


with gr.Blocks(title="CMI-RM 音乐评分", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# 🎵 CMI-RM 音乐评分系统\n"
        "上传 AI 生成的音乐,输入生成时的文字描述(Prompt),获取质量评分。\n"
        "- **Alignment**:音乐与文字描述的匹配程度\n"
        "- **Quality**:音乐整体质量\n"
    )

    with gr.Tabs():
        # ---- Tab 1: single-track scoring ----
        with gr.TabItem("单曲评分"):
            with gr.Row():
                with gr.Column(scale=1):
                    audio_input = _upload_audio("上传音频")
                    text_input = _prompt_box(
                        "例如:A cheerful acoustic folk song with strumming guitar and whistling."
                    )
                    lyrics_input = _lyrics_box()
                    score_btn = gr.Button("开始评分", variant="primary", size="lg")
                with gr.Column(scale=1):
                    result_output = gr.Markdown(label="评分结果")

            # score_audio is a generator, so each yielded value is sent to
            # result_output in turn.
            score_btn.click(
                fn=score_audio,
                inputs=[audio_input, text_input, lyrics_input],
                outputs=result_output,
            )

        # ---- Tab 2: A/B comparison ----
        with gr.TabItem("A/B 对比"):
            with gr.Row():
                with gr.Column(scale=1):
                    audio_a = _upload_audio("音频 A")
                    audio_b = _upload_audio("音频 B")
                    text_compare = _prompt_box("两个音频使用相同的 Prompt")
                    lyrics_compare = _lyrics_box()
                    compare_btn = gr.Button("开始对比", variant="primary", size="lg")
                with gr.Column(scale=1):
                    compare_output = gr.Markdown(label="对比结果")

            compare_btn.click(
                fn=score_compare,
                inputs=[audio_a, audio_b, text_compare, lyrics_compare],
                outputs=compare_output,
            )

    # Footer with the project link and the expected CPU latency.
    gr.Markdown(
        "---\n"
        "*基于 [CMI-RewardBench](https://arxiv.org/abs/2603.00610) 项目 | "
        "CPU 推理约 30-60 秒/首*"
    )
if __name__ == "__main__":
    # Preload the model at startup so it is already cached in MODEL when the
    # first request arrives.
    get_model()
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)