Ada312 committed on
Commit
a66bed7
·
verified ·
1 Parent(s): 310317c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -18
app.py CHANGED
@@ -1,41 +1,57 @@
1
- import os, numpy as np, torch, gradio as gr, librosa
 
 
 
 
2
  from huggingface_hub import hf_hub_download
3
- from model import DCCRN # 确保你上传了 model.py utils 依赖
4
 
 
5
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
6
- SR = 16000 # 你的模型训练采样率
7
 
8
- # 从环境变量读取模型仓库名和权重文件
9
- REPO_ID = os.getenv("MODEL_REPO_ID", "Ada312/DCCRN") # 你的模型仓库
10
- FILENAME = os.getenv("MODEL_FILENAME", "dccrn.ckpt") # 权重文件
11
- TOKEN = os.getenv("HF_TOKEN") # 如果模型仓库是私有,就需要这个
12
 
13
- # 下载权重到本地缓存
14
  ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=TOKEN)
15
 
16
- # 初始化模型并加载权重
17
- net = DCCRN()
18
  ckpt = torch.load(ckpt_path, map_location=DEVICE)
19
  state = ckpt.get("state_dict", ckpt)
20
- state = {k.replace("model.","").replace("module.",""): v for k,v in state.items()}
21
  net.load_state_dict(state, strict=False)
22
  net.to(DEVICE).eval()
23
 
24
- # 推理函数:输入 noisy audio → 输出 enhanced audio
25
  def enhance(audio_path: str):
26
  wav, _ = librosa.load(audio_path, sr=SR, mono=True)
27
- x = torch.from_numpy(wav).float().to(DEVICE)[None, None, :]
 
 
 
28
  with torch.no_grad():
29
- y = net(x).squeeze().cpu().numpy()
 
 
 
 
 
 
30
  return (SR, y)
31
 
32
- # Gradio 界面
33
  with gr.Blocks() as demo:
34
  gr.Markdown("## 🎧 DCCRN Speech Enhancement\n上传或录音,点击“去噪”。")
35
  with gr.Row():
36
- inp = gr.Audio(sources=["upload","microphone"], type="filepath", label="Noisy speech")
37
  out = gr.Audio(label="Enhanced speech")
38
- gr.Button("去噪").click(enhance, inputs=inp, outputs=out)
 
 
39
 
40
- demo.queue(concurrency_count=1, max_size=8)
 
41
  demo.launch()
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import gradio as gr
5
+ import librosa
6
  from huggingface_hub import hf_hub_download
7
+ from model import DCCRN # 确保已有 model.py utils/ 依赖
8
 
9
# ===== Basic configuration =====
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Input audio is resampled to this rate before inference and the output is
# reported at the same rate; override with the SAMPLE_RATE environment variable.
SR = int(os.getenv("SAMPLE_RATE", "16000"))

# Model repository and weight file, both overridable through the environment.
REPO_ID = os.getenv("MODEL_REPO_ID", "Ada312/DCCRN")
FILENAME = os.getenv("MODEL_FILENAME", "dccrn.ckpt")
TOKEN = os.getenv("HF_TOKEN")  # only needed when the model repository is private

# ===== Download and load the weights =====
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=TOKEN)

net = DCCRN()  # if training used custom constructor arguments, pass them here

# NOTE(review): torch.load unpickles the downloaded file, which can execute
# arbitrary code. Only point MODEL_REPO_ID at repositories you trust, or switch
# to weights_only=True once the checkpoint is confirmed to hold plain tensors.
ckpt = torch.load(ckpt_path, map_location=DEVICE)
# Accept both a wrapped checkpoint ({"state_dict": ...}) and a bare state dict.
state = ckpt.get("state_dict", ckpt)


def _strip_wrapper_prefixes(key: str) -> str:
    """Remove leading ``model.`` / ``module.`` wrapper prefixes from a key.

    Only prefixes are removed: a plain ``str.replace`` would also rewrite
    matches in the middle of a name (``submodel.weight`` -> ``subweight``) and
    the parameter would then silently fail to load.
    """
    changed = True
    while changed:
        changed = False
        for prefix in ("model.", "module."):
            if key.startswith(prefix):
                key = key[len(prefix):]
                changed = True
    return key


state = {_strip_wrapper_prefixes(k): v for k, v in state.items()}

# strict=False tolerates key mismatches, so report them instead of silently
# running inference with partially (or entirely) uninitialised weights.
load_result = net.load_state_dict(state, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"[app] checkpoint/model key mismatch: "
        f"{len(load_result.missing_keys)} missing {load_result.missing_keys[:5]}, "
        f"{len(load_result.unexpected_keys)} unexpected {load_result.unexpected_keys[:5]}"
    )
net.to(DEVICE).eval()
27
 
28
# ===== Inference =====
def enhance(audio_path: str):
    """Denoise one audio file with the module-level ``net`` model.

    Args:
        audio_path: Path of the noisy recording to load. May be ``None`` or
            empty when the handler is triggered without any audio.

    Returns:
        A ``(SR, y)`` tuple: the sample rate the audio was resampled to and
        the model output as a squeezed NumPy array.

    Raises:
        gr.Error: If no audio was supplied or the file decodes to zero samples.
    """
    if not audio_path:
        # Fail with a readable message instead of an opaque loader traceback.
        raise gr.Error("Please upload or record some audio first.")

    wav, _ = librosa.load(audio_path, sr=SR, mono=True)
    if wav.size == 0:
        raise gr.Error("The audio file contains no samples.")

    # mono=True yields a 1-D signal, so add the batch axis directly: [1, T].
    x = torch.from_numpy(wav).float().to(DEVICE).unsqueeze(0)

    with torch.no_grad():
        # Try [B, 1, T] first; if the model rejects it, fall back to [B, T].
        try:
            y = net(x.unsqueeze(1))  # [1, 1, T]
        except Exception as first_err:
            # Best-effort fallback is kept, but the first failure is logged so
            # a genuine model error is not hidden behind the retry.
            print(f"[enhance] [B, 1, T] input failed ({first_err!r}); retrying with [B, T]")
            y = net(x)  # [1, T]

    y = y.squeeze().detach().cpu().numpy()
    return (SR, y)
44
 
45
# ===== Gradio interface =====
# NOTE(review): the original indentation was lost in this view; the button is
# assumed to sit below the row, inside the Blocks context — confirm.
with gr.Blocks() as demo:
    gr.Markdown("## 🎧 DCCRN Speech Enhancement\n上传或录音,点击“去噪”。")
    with gr.Row():
        inp = gr.Audio(
            label="Noisy speech",
            type="filepath",
            sources=["upload", "microphone"],
        )
        out = gr.Audio(label="Enhanced speech")
    btn = gr.Button("去噪")
    # The concurrency cap is set on the event listener itself.
    btn.click(fn=enhance, inputs=inp, outputs=out, concurrency_limit=1)

# Queue: only the backlog cap is kept (the deprecated concurrency_count
# argument is no longer used).
demo.queue(max_size=8)
demo.launch()