yoyolicoris committed on
Commit
bbb9d09
·
1 Parent(s): f4d4abb

feat: read device configuration from DEVICE.txt for dynamic device management

Browse files
Files changed (1) hide show
  1. app.py +16 -14
app.py CHANGED
@@ -47,8 +47,7 @@ Try to play around with the sliders and buttons and see what you can come up wit
47
  > **_Note:_** To upload your own audio, click X on the top right corner of the input audio block.
48
  """
49
 
50
- # DEVICE = "cpu"
51
- DEVICE = Path("DEVICE.txt").read_text()
52
  SLIDER_MAX = 3
53
  SLIDER_MIN = -3
54
  NUMBER_OF_PCS = 4
@@ -153,15 +152,16 @@ meter = pyln.Meter(44100)
153
 
154
 
155
  def get_embedding_model(embedding: str) -> Callable:
 
156
  match embedding:
157
  case "afx-rep":
158
- afx_rep = load_param_model().to(DEVICE)
159
  two_chs_emb_fn = lambda x: get_param_embeds(x, afx_rep, 44100)
160
  case "mfcc":
161
- mfcc = load_mfcc_feature_extractor().to(DEVICE)
162
  two_chs_emb_fn = lambda x: get_feature_embeds(x, mfcc)
163
  case "mir":
164
- mir = load_mir_feature_extractor().to(DEVICE)
165
  two_chs_emb_fn = lambda x: get_feature_embeds(x, mir)
166
  case _:
167
  raise ValueError(f"Unknown encoder: {embedding}")
@@ -190,23 +190,24 @@ def inference(
190
  optimiser,
191
  lr,
192
  ):
 
193
  if method == "Mean":
194
- return gaussian_params_dict[dataset][0].to(DEVICE)
195
 
196
  ref = convert2float(*ref_audio)
197
  ref_loudness = meter.integrated_loudness(ref)
198
  ref = pyln.normalize.loudness(ref, ref_loudness, -18.0)
199
- ref = torch.from_numpy(ref).float().T.unsqueeze(0).to(DEVICE)
200
 
201
  y = convert2float(*input_audio)
202
  loudness = meter.integrated_loudness(y)
203
  y = pyln.normalize.loudness(y, loudness, -18.0)
204
- y = torch.from_numpy(y).float().T.unsqueeze(0).to(DEVICE)
205
 
206
  if y.shape[1] != 1:
207
  y = y.mean(dim=1, keepdim=True)
208
 
209
- fx = deepcopy(global_fx).to(DEVICE)
210
  fx.train()
211
 
212
  two_chs_emb_fn = chain_functions(
@@ -225,8 +226,8 @@ def inference(
225
  fx,
226
  two_chs_emb_fn,
227
  to_fx_state_dict,
228
- partial(logp_x, *[x.to(DEVICE) for x in gaussian_params_dict[dataset]]),
229
- internal_mean.to(DEVICE),
230
  ref,
231
  y,
232
  optimiser_type=optimiser,
@@ -242,14 +243,15 @@ def inference(
242
 
243
 
244
  def render(y, remove_approx, ratio, vec):
 
245
  y = convert2float(*y)
246
  loudness = meter.integrated_loudness(y)
247
  y = pyln.normalize.loudness(y, loudness, -18.0)
248
- y = torch.from_numpy(y).float().T.unsqueeze(0).to(DEVICE)
249
  if remove_approx:
250
- infer_fx = instantiate(rt_config).to(DEVICE)
251
  else:
252
- infer_fx = instantiate(fx_config).to(DEVICE)
253
 
254
  infer_fx.load_state_dict(vec2dict(vec), strict=False)
255
  # fx.apply(partial(clip_delay_eq_Q, Q=0.707))
 
47
  > **_Note:_** To upload your own audio, click X on the top right corner of the input audio block.
48
  """
49
 
50
+ # device = "cpu"
 
51
  SLIDER_MAX = 3
52
  SLIDER_MIN = -3
53
  NUMBER_OF_PCS = 4
 
152
 
153
 
154
  def get_embedding_model(embedding: str) -> Callable:
155
+ device = Path("device.txt").read_text()
156
  match embedding:
157
  case "afx-rep":
158
+ afx_rep = load_param_model().to(device)
159
  two_chs_emb_fn = lambda x: get_param_embeds(x, afx_rep, 44100)
160
  case "mfcc":
161
+ mfcc = load_mfcc_feature_extractor().to(device)
162
  two_chs_emb_fn = lambda x: get_feature_embeds(x, mfcc)
163
  case "mir":
164
+ mir = load_mir_feature_extractor().to(device)
165
  two_chs_emb_fn = lambda x: get_feature_embeds(x, mir)
166
  case _:
167
  raise ValueError(f"Unknown encoder: {embedding}")
 
190
  optimiser,
191
  lr,
192
  ):
193
+ device = Path("device.txt").read_text()
194
  if method == "Mean":
195
+ return gaussian_params_dict[dataset][0].to(device)
196
 
197
  ref = convert2float(*ref_audio)
198
  ref_loudness = meter.integrated_loudness(ref)
199
  ref = pyln.normalize.loudness(ref, ref_loudness, -18.0)
200
+ ref = torch.from_numpy(ref).float().T.unsqueeze(0).to(device)
201
 
202
  y = convert2float(*input_audio)
203
  loudness = meter.integrated_loudness(y)
204
  y = pyln.normalize.loudness(y, loudness, -18.0)
205
+ y = torch.from_numpy(y).float().T.unsqueeze(0).to(device)
206
 
207
  if y.shape[1] != 1:
208
  y = y.mean(dim=1, keepdim=True)
209
 
210
+ fx = deepcopy(global_fx).to(device)
211
  fx.train()
212
 
213
  two_chs_emb_fn = chain_functions(
 
226
  fx,
227
  two_chs_emb_fn,
228
  to_fx_state_dict,
229
+ partial(logp_x, *[x.to(device) for x in gaussian_params_dict[dataset]]),
230
+ internal_mean.to(device),
231
  ref,
232
  y,
233
  optimiser_type=optimiser,
 
243
 
244
 
245
  def render(y, remove_approx, ratio, vec):
246
+ device = Path("device.txt").read_text()
247
  y = convert2float(*y)
248
  loudness = meter.integrated_loudness(y)
249
  y = pyln.normalize.loudness(y, loudness, -18.0)
250
+ y = torch.from_numpy(y).float().T.unsqueeze(0).to(device)
251
  if remove_approx:
252
+ infer_fx = instantiate(rt_config).to(device)
253
  else:
254
+ infer_fx = instantiate(fx_config).to(device)
255
 
256
  infer_fx.load_state_dict(vec2dict(vec), strict=False)
257
  # fx.apply(partial(clip_delay_eq_Q, Q=0.707))