CMI_Autoeval / app.py
apple
Fix: move theme to gr.Blocks()
746823a
#!/usr/bin/env python3
"""CMI-RM 音乐评分应用 — Gradio Web UI
上传音频文件,输入文字描述,获取 AI 音乐质量评分。
可部署到 HuggingFace Spaces(免费 CPU 版)。
"""
import logging
import os
import sys
from pathlib import Path
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
# ---------------------------------------------------------------------------
# 路径设置
# ---------------------------------------------------------------------------
# Directory containing this file, and the bundled model source tree under it.
_ROOT = Path(__file__).resolve().parent
_MODEL_SRC = _ROOT / "models" / "cmi-rm" / "src"

# Put both directories at the front of sys.path (skipping any already listed)
# so the project-local imports that follow can be resolved. _ROOT is handled
# last and therefore ends up ahead of _MODEL_SRC when both are inserted.
for _entry in (str(_MODEL_SRC), str(_ROOT)):
    if _entry not in sys.path:
        sys.path.insert(0, _entry)
del _entry
from baselines.inference import RewardModelInference
# Configure root logging at import time: INFO level, "time | level | message".
logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s", level=logging.INFO)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# 模型下载与加载
# ---------------------------------------------------------------------------
# Hugging Face Hub repository the model files are downloaded from.
MODEL_REPO = "HaiwenXia/CMI-RM"
# Local directory the model files are downloaded into and loaded from.
MODEL_DIR = _ROOT.joinpath("baselines", "model")
def download_model():
    """Make sure the checkpoint and its config exist under ``MODEL_DIR``.

    If both ``model.safetensors`` and ``config.yaml`` are already present the
    download is skipped. Otherwise ``MODEL_DIR`` is created if needed and both
    files are fetched from ``MODEL_REPO`` with ``hf_hub_download``.

    Returns:
        The path ``MODEL_DIR / "model.safetensors"`` as a string.
    """
    weights_name = "model.safetensors"
    config_name = "config.yaml"
    weights_path = MODEL_DIR / weights_name
    config_path = MODEL_DIR / config_name

    already_present = weights_path.exists() and config_path.exists()
    if already_present:
        logger.info("模型文件已存在,跳过下载")
    else:
        MODEL_DIR.mkdir(parents=True, exist_ok=True)
        logger.info("正在从 %s 下载模型...", MODEL_REPO)
        # Fetch the weights first, then the config, both into MODEL_DIR.
        for filename in (weights_name, config_name):
            hf_hub_download(MODEL_REPO, filename, local_dir=str(MODEL_DIR))
        logger.info("模型下载完成")

    return str(weights_path)
def load_model():
    """Construct the inference wrapper from the local checkpoint.

    Obtains the checkpoint path from ``download_model()`` and builds a
    ``RewardModelInference`` with ``device="cpu"``, ``mode="final"`` and
    ``bf16=False``.

    Returns:
        The constructed ``RewardModelInference`` instance.
    """
    checkpoint_file = download_model()
    logger.info("正在加载模型(CPU 模式)...")

    inference_options = {
        "checkpoint": checkpoint_file,
        "device": "cpu",
        "mode": "final",
        "bf16": False,
    }
    reward_model = RewardModelInference(**inference_options)

    logger.info("模型加载完成")
    return reward_model
# Process-wide model instance; filled in by the first get_model() call.
MODEL = None


def get_model():
    """Return the shared model instance, loading it on first use.

    The result of ``load_model()`` is stored in the module-level ``MODEL`` so
    that subsequent calls return the same object without reloading.
    """
    global MODEL
    if MODEL is not None:
        return MODEL
    MODEL = load_model()
    return MODEL
# ---------------------------------------------------------------------------
# 推理函数
# ---------------------------------------------------------------------------
def score_audio(audio_path, text_prompt, lyrics):
    """Score a single audio file against a text prompt.

    Written as a generator so the UI can show a progress message while the
    model runs: the first yielded value is a "working" Markdown message, the
    second is the Markdown table with the scores.

    Args:
        audio_path: Path of the audio file, or ``None`` if none was provided.
        text_prompt: Text description of the music; must contain
            non-whitespace text. ``None`` is treated like an empty prompt.
        lyrics: Optional lyrics; ``None`` or an empty value is passed to the
            model as ``""``.

    Yields:
        Markdown strings: first the progress message, then the result table
        with the ``alignment`` and ``quality`` scores to four decimals.

    Raises:
        gr.Error: If no audio was provided or the prompt is missing or blank.
    """
    if audio_path is None:
        raise gr.Error("请上传一个音频文件")
    # The prompt can arrive as None rather than ""; normalise it before
    # stripping so a missing prompt produces the validation error below
    # instead of an AttributeError.
    prompt = (text_prompt or "").strip()
    if not prompt:
        raise gr.Error("请输入文字描述(Prompt)")

    # Step 1: show the loading state before any slow work starts.
    yield (
        "⏳ **正在分析音频...**\n\n"
        "模型正在聆听你的音乐,预计需要 30-60 秒,请稍候。\n\n"
        "```\n"
        "[ ░░░░░░░░░░░░░░░░░░░░ ] 分析中...\n"
        "```"
    )

    model = get_model()
    lyrics = lyrics.strip() if lyrics else ""
    logger.info("开始评分: %s | prompt=%s", audio_path, text_prompt[:50])
    scores = model.score(
        audio=audio_path,
        text=prompt,
        lyrics=lyrics,
        max_dur=30.0,
    )
    alignment = scores["alignment"]
    quality = scores["quality"]
    logger.info("评分结果: alignment=%.4f, quality=%.4f", alignment, quality)

    # Step 2: replace the loading message with the result table.
    yield (
        f"## 评分结果\n\n"
        f"| 指标 | 分数 |\n"
        f"|------|------|\n"
        f"| **Alignment(文本对齐度)** | {alignment:.4f} |\n"
        f"| **Quality(音乐质量)** | {quality:.4f} |\n"
    )
def score_compare(audio_a, audio_b, text_prompt, lyrics):
    """Score two audio files against the same prompt and compare them.

    Written as a generator so the UI can show progress: it yields a status
    message before scoring each file, then the final comparison table.

    Args:
        audio_a: Path of the first audio file, or ``None`` if not provided.
        audio_b: Path of the second audio file, or ``None`` if not provided.
        text_prompt: Text description shared by both files; must contain
            non-whitespace text. ``None`` is treated like an empty prompt.
        lyrics: Optional lyrics; ``None`` or an empty value is passed to the
            model as ``""``.

    Yields:
        Markdown strings: a status message for file A, a status message for
        file B, then the comparison table with a per-metric winner.

    Raises:
        gr.Error: If either audio file is missing or the prompt is missing
            or blank.
    """
    if audio_a is None or audio_b is None:
        raise gr.Error("请上传两个音频文件进行对比")
    # The prompt can arrive as None rather than ""; normalise it before
    # stripping so a missing prompt produces the validation error below
    # instead of an AttributeError.
    prompt = (text_prompt or "").strip()
    if not prompt:
        raise gr.Error("请输入文字描述(Prompt)")
    lyrics = lyrics.strip() if lyrics else ""

    def winner(a, b):
        """Return the verdict label for one metric given both scores."""
        if a > b:
            return "A 胜出 ✓"
        elif b > a:
            return "B 胜出 ✓"
        return "平局"

    # Step 1: score A. The status is yielded before get_model() (as in
    # score_audio) so a lazy model load happens with the message on screen.
    yield (
        "⏳ **正在分析音频 A...**\n\n"
        "模型正在聆听第 1 首音乐,预计共需 1-2 分钟。\n\n"
        "```\n"
        "[ ██████████░░░░░░░░░░ ] 1/2 分析中...\n"
        "```"
    )
    model = get_model()
    scores_a = model.score(audio=audio_a, text=prompt, lyrics=lyrics, max_dur=30.0)

    # Step 2: score B.
    yield (
        "⏳ **正在分析音频 B...**\n\n"
        "第 1 首已完成,正在聆听第 2 首...\n\n"
        "```\n"
        "[ ████████████████████░] 2/2 分析中...\n"
        "```"
    )
    scores_b = model.score(audio=audio_b, text=prompt, lyrics=lyrics, max_dur=30.0)

    # Step 3: show the side-by-side comparison.
    yield (
        f"## 对比结果\n\n"
        f"| 指标 | 音频 A | 音频 B | 胜出 |\n"
        f"|------|--------|--------|------|\n"
        f"| **Alignment** | {scores_a['alignment']:.4f} | {scores_b['alignment']:.4f} | {winner(scores_a['alignment'], scores_b['alignment'])} |\n"
        f"| **Quality** | {scores_a['quality']:.4f} | {scores_b['quality']:.4f} | {winner(scores_a['quality'], scores_b['quality'])} |\n"
    )
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="CMI-RM 音乐评分") as demo:
    # Page header: what the app does and what the two scores mean.
    gr.Markdown(
        "# 🎵 CMI-RM 音乐评分系统\n"
        "上传 AI 生成的音乐,输入生成时的文字描述(Prompt),获取质量评分。\n"
        "- **Alignment**:音乐与文字描述的匹配程度\n"
        "- **Quality**:音乐整体质量\n"
    )

    with gr.Tabs():
        # ---- Tab 1: score a single track ----
        with gr.TabItem("单曲评分"):
            with gr.Row():
                # Left column: inputs and the trigger button.
                with gr.Column(scale=1):
                    audio_input = gr.Audio(sources=["upload"], type="filepath", label="上传音频")
                    text_input = gr.Textbox(
                        lines=2,
                        label="文字描述(Prompt)",
                        placeholder="例如:A cheerful acoustic folk song with strumming guitar and whistling.",
                    )
                    lyrics_input = gr.Textbox(
                        lines=3,
                        label="歌词(可选)",
                        placeholder="如果有歌词,粘贴到这里",
                    )
                    score_btn = gr.Button("开始评分", size="lg", variant="primary")
                # Right column: Markdown area that receives each yielded value.
                with gr.Column(scale=1):
                    result_output = gr.Markdown(label="评分结果")

            score_btn.click(
                outputs=result_output,
                inputs=[audio_input, text_input, lyrics_input],
                fn=score_audio,
            )

        # ---- Tab 2: compare two tracks ----
        with gr.TabItem("A/B 对比"):
            with gr.Row():
                # Left column: the two uploads, the shared prompt and lyrics.
                with gr.Column(scale=1):
                    audio_a = gr.Audio(sources=["upload"], type="filepath", label="音频 A")
                    audio_b = gr.Audio(sources=["upload"], type="filepath", label="音频 B")
                    text_compare = gr.Textbox(
                        lines=2,
                        label="文字描述(Prompt)",
                        placeholder="两个音频使用相同的 Prompt",
                    )
                    lyrics_compare = gr.Textbox(
                        lines=3,
                        label="歌词(可选)",
                        placeholder="如果有歌词,粘贴到这里",
                    )
                    compare_btn = gr.Button("开始对比", size="lg", variant="primary")
                # Right column: Markdown area that receives each yielded value.
                with gr.Column(scale=1):
                    compare_output = gr.Markdown(label="对比结果")

            compare_btn.click(
                outputs=compare_output,
                inputs=[audio_a, audio_b, text_compare, lyrics_compare],
                fn=score_compare,
            )

    # Page footer: project link and a note on CPU inference time.
    gr.Markdown(
        "---\n"
        "*基于 [CMI-RewardBench](https://arxiv.org/abs/2603.00610) 项目 | "
        "CPU 推理约 30-60 秒/首*"
    )
if __name__ == "__main__":
    # Load the model up front so it is already in memory when the UI starts.
    get_model()
    # Serve on all interfaces, port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")