Spaces:
Runtime error
Runtime error
Fixed normalization
Browse files
app.py
CHANGED
|
@@ -43,7 +43,7 @@ def waveformer(audio, label_choices):
|
|
| 43 |
if fs != 44100:
|
| 44 |
raise ValueError(fs)
|
| 45 |
mixture = torch.from_numpy(
|
| 46 |
-
mixture).unsqueeze(0).unsqueeze(0).to(torch.float)
|
| 47 |
|
| 48 |
# Construct the query vector
|
| 49 |
if len(label_choices) == 0:
|
|
@@ -53,7 +53,7 @@ def waveformer(audio, label_choices):
|
|
| 53 |
query[0, TARGETS.index(t)] = 1.
|
| 54 |
|
| 55 |
with torch.no_grad():
|
| 56 |
-
output = model(mixture, query)
|
| 57 |
|
| 58 |
return fs, output.squeeze(0).squeeze(0).to(torch.short).numpy()
|
| 59 |
|
|
|
|
| 43 |
if fs != 44100:
|
| 44 |
raise ValueError(fs)
|
| 45 |
mixture = torch.from_numpy(
|
| 46 |
+
mixture).unsqueeze(0).unsqueeze(0).to(torch.float) / (2.0 ** 15)
|
| 47 |
|
| 48 |
# Construct the query vector
|
| 49 |
if len(label_choices) == 0:
|
|
|
|
| 53 |
query[0, TARGETS.index(t)] = 1.
|
| 54 |
|
| 55 |
with torch.no_grad():
|
| 56 |
+
output = (2.0 ** 15) * model(mixture, query)
|
| 57 |
|
| 58 |
return fs, output.squeeze(0).squeeze(0).to(torch.short).numpy()
|
| 59 |
|