vyluong committed on
Commit
d975a12
·
verified ·
1 Parent(s): 03282ea

Update app/services/denoiser.py

Browse files
Files changed (1) hide show
  1. app/services/denoiser.py +50 -19
app/services/denoiser.py CHANGED
@@ -5,7 +5,11 @@ import logging
5
  import torch
6
  import torchaudio
7
 
8
- from df.enhance import enhance, init_df
 
 
 
 
9
 
10
  from app.core.config import get_settings
11
 
@@ -17,6 +21,8 @@ class DenoiserService:
17
 
18
  _model = None
19
  _df_state = None
 
 
20
 
21
  @classmethod
22
  def _load_model(cls):
@@ -24,64 +30,89 @@ class DenoiserService:
24
  if cls._model is not None:
25
  return
26
 
 
 
 
27
  logger.info("Loading DeepFilterNet...")
28
 
29
  model, df_state, _ = init_df()
30
 
 
 
 
 
 
31
  cls._model = model
32
  cls._df_state = df_state
33
 
34
- logger.info("DeepFilterNet READY")
35
 
 
36
  @classmethod
37
- async def enhance_audio(
38
- cls,
39
- input_path: Path
40
- ) -> Path:
41
 
42
  if not settings.enable_denoiser:
43
  return input_path
44
 
45
- loop = asyncio.get_event_loop()
46
 
47
  return await loop.run_in_executor(
48
  None,
49
  lambda: cls._run_enhancement(input_path)
50
  )
51
 
 
52
  @classmethod
53
- def _run_enhancement(
54
- cls,
55
- input_path: Path
56
- ) -> Path:
57
 
58
  try:
59
-
60
  cls._load_model()
61
 
 
 
 
62
  audio, sr = torchaudio.load(str(input_path))
63
 
64
- enhanced = enhance(
65
- cls._model,
66
- cls._df_state,
67
- audio
68
- )
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  output_path = (
71
  settings.processed_dir /
72
  f"{input_path.stem}_enhanced.wav"
73
  )
74
 
 
 
 
 
 
75
  torchaudio.save(
76
  str(output_path),
77
- enhanced.cpu(),
78
  sr
79
  )
80
 
 
 
81
  return output_path
82
 
83
  except Exception as e:
84
 
85
- logger.exception(e)
86
 
 
87
  return input_path
 
5
  import torch
6
  import torchaudio
7
 
8
+ try:
9
+ from df.enhance import enhance, init_df
10
+ DF_AVAILABLE = True
11
+ except Exception:
12
+ DF_AVAILABLE = False
13
 
14
  from app.core.config import get_settings
15
 
 
21
 
22
  _model = None
23
  _df_state = None
24
+ _device = None
25
+
26
 
27
  @classmethod
28
  def _load_model(cls):
 
30
  if cls._model is not None:
31
  return
32
 
33
+ if not DF_AVAILABLE:
34
+ raise ImportError("DeepFilterNet is not available")
35
+
36
  logger.info("Loading DeepFilterNet...")
37
 
38
  model, df_state, _ = init_df()
39
 
40
+ cls._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
+
42
+ model = model.to(cls._device)
43
+ model.eval()
44
+
45
  cls._model = model
46
  cls._df_state = df_state
47
 
48
+ logger.info(f"DeepFilterNet READY on {cls._device}")
49
 
50
+
51
  @classmethod
52
+ async def enhance_audio(cls, input_path: Path) -> Path:
 
 
 
53
 
54
  if not settings.enable_denoiser:
55
  return input_path
56
 
57
+ loop = asyncio.get_running_loop()
58
 
59
  return await loop.run_in_executor(
60
  None,
61
  lambda: cls._run_enhancement(input_path)
62
  )
63
 
64
+
65
  @classmethod
66
+ def _run_enhancement(cls, input_path: Path) -> Path:
 
 
 
67
 
68
  try:
 
69
  cls._load_model()
70
 
71
+ # ----------------------------
72
+ # LOAD AUDIO
73
+ # ----------------------------
74
  audio, sr = torchaudio.load(str(input_path))
75
 
76
+ # mono conversion
77
+ if audio.shape[0] > 1:
78
+ audio = torch.mean(audio, dim=0, keepdim=True)
79
+
80
+ audio = audio.float()
81
+
82
+ # move to device
83
+ audio = audio.to(cls._device)
84
+
85
+
86
+ with torch.no_grad():
87
+ enhanced = enhance(
88
+ cls._model,
89
+ cls._df_state,
90
+ audio
91
+ )
92
 
93
  output_path = (
94
  settings.processed_dir /
95
  f"{input_path.stem}_enhanced.wav"
96
  )
97
 
98
+ output_path.parent.mkdir(parents=True, exist_ok=True)
99
+
100
+ # move back CPU before save
101
+ enhanced = enhanced.cpu()
102
+
103
  torchaudio.save(
104
  str(output_path),
105
+ enhanced,
106
  sr
107
  )
108
 
109
+ logger.info(f"Denoised audio saved: {output_path}")
110
+
111
  return output_path
112
 
113
  except Exception as e:
114
 
115
+ logger.exception("DeepFilterNet enhancement failed")
116
 
117
+ # fallback = original file
118
  return input_path