Spaces:
Runtime error
Runtime error
Commit ·
aecaaea
1
Parent(s): f8fea2a
Fix causal cropping for input metrics
Browse files- remfx/models.py +4 -2
remfx/models.py
CHANGED
|
@@ -188,8 +188,9 @@ class RemFX(pl.LightningModule):
|
|
| 188 |
|
| 189 |
loss, output = self.model((x, y))
|
| 190 |
# Crop target to match output
|
|
|
|
| 191 |
if output.shape[-1] < y.shape[-1]:
|
| 192 |
-
|
| 193 |
self.log(f"{mode}_loss", loss)
|
| 194 |
# Metric logging
|
| 195 |
with torch.no_grad():
|
|
@@ -204,13 +205,14 @@ class RemFX(pl.LightningModule):
|
|
| 204 |
continue
|
| 205 |
self.log(
|
| 206 |
f"{mode}_{metric}",
|
| 207 |
-
negate * self.metrics[metric](output,
|
| 208 |
on_step=False,
|
| 209 |
on_epoch=True,
|
| 210 |
logger=True,
|
| 211 |
prog_bar=True,
|
| 212 |
sync_dist=True,
|
| 213 |
)
|
|
|
|
| 214 |
self.log(
|
| 215 |
f"Input_{metric}",
|
| 216 |
negate * self.metrics[metric](x, y),
|
|
|
|
| 188 |
|
| 189 |
loss, output = self.model((x, y))
|
| 190 |
# Crop target to match output
|
| 191 |
+
target = y
|
| 192 |
if output.shape[-1] < y.shape[-1]:
|
| 193 |
+
target = causal_crop(y, output.shape[-1])
|
| 194 |
self.log(f"{mode}_loss", loss)
|
| 195 |
# Metric logging
|
| 196 |
with torch.no_grad():
|
|
|
|
| 205 |
continue
|
| 206 |
self.log(
|
| 207 |
f"{mode}_{metric}",
|
| 208 |
+
negate * self.metrics[metric](output, target),
|
| 209 |
on_step=False,
|
| 210 |
on_epoch=True,
|
| 211 |
logger=True,
|
| 212 |
prog_bar=True,
|
| 213 |
sync_dist=True,
|
| 214 |
)
|
| 215 |
+
|
| 216 |
self.log(
|
| 217 |
f"Input_{metric}",
|
| 218 |
negate * self.metrics[metric](x, y),
|