Update Exp2_Global_ocean_simulation/train.py
Browse files
Exp2_Global_ocean_simulation/train.py
CHANGED
|
@@ -275,7 +275,7 @@ class Trainer():
|
|
| 275 |
|
| 276 |
if self.params.enable_amp:
|
| 277 |
self.gscaler.update()
|
| 278 |
-
break
|
| 279 |
|
| 280 |
tr_time += time.time() - tr_start
|
| 281 |
|
|
@@ -308,8 +308,8 @@ class Trainer():
|
|
| 308 |
sample_idx = np.random.randint(len(self.valid_data_loader))
|
| 309 |
with torch.no_grad():
|
| 310 |
for i, data in enumerate(self.valid_data_loader, 0):
|
| 311 |
-
if i > 1:
|
| 312 |
-
|
| 313 |
inp, tar = map(lambda x: x.to(self.device, dtype=torch.float), data)
|
| 314 |
# gen = self.model(inp)
|
| 315 |
num_steps = params.multi_steps_finetune
|
|
|
|
| 275 |
|
| 276 |
if self.params.enable_amp:
|
| 277 |
self.gscaler.update()
|
| 278 |
+
# break
|
| 279 |
|
| 280 |
tr_time += time.time() - tr_start
|
| 281 |
|
|
|
|
| 308 |
sample_idx = np.random.randint(len(self.valid_data_loader))
|
| 309 |
with torch.no_grad():
|
| 310 |
for i, data in enumerate(self.valid_data_loader, 0):
|
| 311 |
+
# if i > 1:
|
| 312 |
+
# break
|
| 313 |
inp, tar = map(lambda x: x.to(self.device, dtype=torch.float), data)
|
| 314 |
# gen = self.model(inp)
|
| 315 |
num_steps = params.multi_steps_finetune
|