Spaces:
Running
on
T4
Running
on
T4
try to figure out how ZeroGPU works
Browse files
Architectures/ControllabilityGAN/wgan/wgan_qc.py
CHANGED
|
@@ -245,7 +245,10 @@ class WassersteinGanQuadraticCost(torch.nn.Module):
|
|
| 245 |
latent_samples = latent_samples.to(self.device)
|
| 246 |
if nograd:
|
| 247 |
with torch.no_grad():
|
| 248 |
-
|
|
|
|
|
|
|
|
|
|
| 249 |
else:
|
| 250 |
generated_data = self.G(latent_samples)
|
| 251 |
self.G.train()
|
|
|
|
| 245 |
latent_samples = latent_samples.to(self.device)
|
| 246 |
if nograd:
|
| 247 |
with torch.no_grad():
|
| 248 |
+
if isinstance(self.G, torch.nn.parallel.DataParallel):
|
| 249 |
+
generated_data = self.G.module(latent_samples.to("cpu"), return_intermediate=return_intermediate)
|
| 250 |
+
else:
|
| 251 |
+
generated_data = self.G(latent_samples.to("cpu"), return_intermediate=return_intermediate)
|
| 252 |
else:
|
| 253 |
generated_data = self.G(latent_samples)
|
| 254 |
self.G.train()
|