add rlhf
Browse files
app.py
CHANGED
|
@@ -37,7 +37,7 @@ setup_eval_logging()
|
|
| 37 |
|
| 38 |
OUTPUT_DIR = Path("./output/gradio")
|
| 39 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
| 40 |
-
NUM_SAMPLE =
|
| 41 |
|
| 42 |
# 创建RLHF反馈数据目录
|
| 43 |
FEEDBACK_DIR = Path("./rlhf")
|
|
@@ -175,10 +175,11 @@ def generate_audio_gradio(
|
|
| 175 |
torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
|
| 176 |
log.info(f"Audio saved to {save_path}")
|
| 177 |
save_paths.append(str(save_path))
|
|
|
|
| 178 |
if device == "cuda":
|
| 179 |
torch.cuda.empty_cache()
|
| 180 |
|
| 181 |
-
return save_paths[0], save_paths[1], prompt
|
| 182 |
|
| 183 |
|
| 184 |
# Gradio input and output components
|
|
@@ -194,12 +195,11 @@ gr_interface = gr.Interface(
|
|
| 194 |
fn=generate_audio_gradio,
|
| 195 |
inputs=[input_text, duration, cfg_strength, denoising_steps, variant],
|
| 196 |
outputs=[
|
| 197 |
-
gr.Audio(label="🎵 Audio Sample 1", type="filepath"),
|
| 198 |
-
gr.Audio(label="🎵 Audio Sample 2", type="filepath"),
|
| 199 |
gr.Textbox(label="Prompt Used", interactive=False)
|
| 200 |
],
|
| 201 |
title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
|
| 202 |
-
description="
|
| 203 |
flagging_mode="never",
|
| 204 |
examples=[
|
| 205 |
["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
|
|
@@ -218,79 +218,11 @@ gr_interface = gr.Interface(
|
|
| 218 |
cache_examples="lazy",
|
| 219 |
)
|
| 220 |
|
| 221 |
-
# ==== Preference collection UI (RLHF) ====
|
| 222 |
-
|
| 223 |
-
# 允许用户在两段音频之间选择偏好,并补充备注
|
| 224 |
-
with gr.Blocks() as pref_block:
|
| 225 |
-
gr.Markdown("## 🧠 RLHF 偏好标注")
|
| 226 |
-
gr.Markdown("生成完成后,请在下方选择您更喜欢的音频(或都不好/差不多),并可附加简短备注。点“提交偏好”即可写入 `./rlhf/user_preferences.jsonl`。")
|
| 227 |
-
|
| 228 |
-
# 这里复用上面 Interface 的输出:我们需要拿到两段音频的文件路径与使用的 prompt
|
| 229 |
-
# 为了连接这两个“界面”,再放一组可粘连的输入组件:
|
| 230 |
-
with gr.Row():
|
| 231 |
-
gen_audio1_path = gr.Textbox(label="Audio 1 路径(自动填充)", interactive=False)
|
| 232 |
-
gen_audio2_path = gr.Textbox(label="Audio 2 路径(自动填充)", interactive=False)
|
| 233 |
-
prompt_used = gr.Textbox(label="Prompt(自动填充)", interactive=False)
|
| 234 |
-
|
| 235 |
-
# 偏好选项与备注
|
| 236 |
-
pref_choice = gr.Radio(
|
| 237 |
-
["audio1", "audio2", "equal", "both_bad"],
|
| 238 |
-
value="audio1",
|
| 239 |
-
label="你更偏好哪个?",
|
| 240 |
-
info="equal=差不多; both_bad=都不好"
|
| 241 |
-
)
|
| 242 |
-
pref_comment = gr.Textbox(label="可选备注(例如:哪一段更贴合描述、是否有噪声/破音等)", lines=2)
|
| 243 |
-
|
| 244 |
-
submit_btn = gr.Button("✅ 提交偏好")
|
| 245 |
-
submit_status = gr.Markdown()
|
| 246 |
-
|
| 247 |
-
# 小工具:读取当前标注条目数
|
| 248 |
-
def _count_feedback():
|
| 249 |
-
try:
|
| 250 |
-
with open(FEEDBACK_FILE, "r", encoding="utf-8") as f:
|
| 251 |
-
return sum(1 for _ in f)
|
| 252 |
-
except FileNotFoundError:
|
| 253 |
-
return 0
|
| 254 |
-
|
| 255 |
-
refresh_btn = gr.Button("📈 刷新统计")
|
| 256 |
-
count_box = gr.Markdown()
|
| 257 |
-
|
| 258 |
-
def submit_preference_ui(a1, a2, p, pref, cmt):
|
| 259 |
-
if not a1 or not a2:
|
| 260 |
-
return "❗请先在上面的生成器里生成两段音频。"
|
| 261 |
-
# 写入 jsonl
|
| 262 |
-
msg = save_preference_feedback(p, a1, a2, pref, cmt)
|
| 263 |
-
return msg
|
| 264 |
-
|
| 265 |
-
def refresh_count_ui():
|
| 266 |
-
n = _count_feedback()
|
| 267 |
-
return f"当前已收集 **{n}** 条偏好样本。"
|
| 268 |
-
|
| 269 |
-
submit_btn.click(
|
| 270 |
-
fn=submit_preference_ui,
|
| 271 |
-
inputs=[gen_audio1_path, gen_audio2_path, prompt_used, pref_choice, pref_comment],
|
| 272 |
-
outputs=submit_status
|
| 273 |
-
)
|
| 274 |
-
refresh_btn.click(fn=refresh_count_ui, outputs=count_box)
|
| 275 |
-
|
| 276 |
-
# —— 把 Interface 的输出“联动”到偏好区:当用户生成完成后,自动把路径和 prompt 填入偏好区输入框 ——
|
| 277 |
-
def _passthrough(a1, a2, p):
|
| 278 |
-
# 直接把接口输出透传给下方偏好区
|
| 279 |
-
return a1, a2, p
|
| 280 |
-
|
| 281 |
-
# 用 Events 把 Interface 的输出连到 pref_block 的三个文本框
|
| 282 |
-
gr_interface.submit(
|
| 283 |
-
fn=_passthrough,
|
| 284 |
-
inputs=gr_interface.outputs, # [Audio1(filepath), Audio2(filepath), PromptUsed]
|
| 285 |
-
outputs=[gen_audio1_path, gen_audio2_path, prompt_used],
|
| 286 |
-
)
|
| 287 |
-
|
| 288 |
-
|
| 289 |
|
| 290 |
if __name__ == "__main__":
|
| 291 |
ensure_models_downloaded()
|
| 292 |
load_model_cache()
|
| 293 |
-
gr_interface.queue(15).launch(
|
| 294 |
|
| 295 |
# theme = gr.themes.Soft(
|
| 296 |
# primary_hue="blue",
|
|
|
|
| 37 |
|
| 38 |
OUTPUT_DIR = Path("./output/gradio")
|
| 39 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
| 40 |
+
NUM_SAMPLE = 1
|
| 41 |
|
| 42 |
# 创建RLHF反馈数据目录
|
| 43 |
FEEDBACK_DIR = Path("./rlhf")
|
|
|
|
| 175 |
torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
|
| 176 |
log.info(f"Audio saved to {save_path}")
|
| 177 |
save_paths.append(str(save_path))
|
| 178 |
+
|
| 179 |
if device == "cuda":
|
| 180 |
torch.cuda.empty_cache()
|
| 181 |
|
| 182 |
+
return save_paths[0], prompt
|
| 183 |
|
| 184 |
|
| 185 |
# Gradio input and output components
|
|
|
|
| 195 |
fn=generate_audio_gradio,
|
| 196 |
inputs=[input_text, duration, cfg_strength, denoising_steps, variant],
|
| 197 |
outputs=[
|
| 198 |
+
gr.Audio(label="🎵 Audio Sample", type="filepath"),
|
|
|
|
| 199 |
gr.Textbox(label="Prompt Used", interactive=False)
|
| 200 |
],
|
| 201 |
title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
|
| 202 |
+
description="",
|
| 203 |
flagging_mode="never",
|
| 204 |
examples=[
|
| 205 |
["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
|
|
|
|
| 218 |
cache_examples="lazy",
|
| 219 |
)
|
| 220 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
|
| 222 |
if __name__ == "__main__":
|
| 223 |
ensure_models_downloaded()
|
| 224 |
load_model_cache()
|
| 225 |
+
gr_interface.queue(15).launch()
|
| 226 |
|
| 227 |
# theme = gr.themes.Soft(
|
| 228 |
# primary_hue="blue",
|