| | |
| | |
| | |
| | |
| | |
| |
|
| | import sys |
| |
|
| | import tqdm |
| | from torch.utils.data import DataLoader |
| | from torch.utils.data.distributed import DistributedSampler |
| |
|
| | from .utils import apply_model, average_metric, center_trim |
| |
|
| |
|
def train_model(epoch,
                dataset,
                model,
                criterion,
                optimizer,
                augment,
                quantizer=None,
                diffq=0,
                repeat=1,
                device="cpu",
                seed=None,
                workers=4,
                world_size=1,
                batch_size=16):
    """Run one training epoch (possibly repeated `repeat` times) over `dataset`.

    Args:
        epoch (int): current epoch index, used for progress display and to
            seed the distributed sampler's shuffling.
        dataset: map-style dataset yielding per-item source stems
            (assumes each item stacks the sources on dim 0 so that
            ``sources.sum(dim=1)`` after batching forms the mixture —
            TODO confirm against the dataset implementation).
        model: separation model; called as ``model(mix)``.
        criterion: loss comparing estimates to (trimmed) target sources.
        optimizer: optimizer stepping `model`'s parameters.
        augment: callable applied to the batched sources before mixing.
        quantizer: optional quantizer exposing ``model_size()``; when given,
            the (differentiable) model size is added to the training loss.
        diffq (float): weight of the model-size penalty term.
        repeat (int): number of passes over the loader within this epoch.
        device: device the batches are moved to.
        seed: optional seed folded into the sampler epoch so different runs
            shuffle differently.
        workers (int): DataLoader worker processes.
        world_size (int): number of distributed processes; >1 enables
            DistributedSampler and splits `batch_size` across ranks.
        batch_size (int): global batch size (divided by `world_size`).

    Returns:
        tuple: (average training loss for the last repetition, last model
        size as a float — 0 when no quantizer is used).
    """
    if world_size > 1:
        sampler = DistributedSampler(dataset)
        # Fold epoch, repeat and seed into the sampler epoch so each
        # repetition (and each seeded run) gets a distinct shuffle.
        sampler_epoch = epoch * repeat
        if seed is not None:
            sampler_epoch += seed * 1000
        sampler.set_epoch(sampler_epoch)
        # Per-rank batch size: the global batch is split across processes.
        batch_size //= world_size
        loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=workers)
    else:
        loader = DataLoader(dataset, batch_size=batch_size, num_workers=workers, shuffle=True)
    current_loss = 0
    model_size = 0
    for repetition in range(repeat):
        tq = tqdm.tqdm(loader,
                       ncols=120,
                       desc=f"[{epoch:03d}] train ({repetition + 1}/{repeat})",
                       leave=False,
                       file=sys.stdout,
                       unit=" batch")
        total_loss = 0
        for idx, sources in enumerate(tq):
            if len(sources) < batch_size:
                # Skip incomplete trailing batches (keeps per-rank batch
                # shapes identical). NOTE(review): skipped batches still
                # count in the `1 + idx` denominator below, slightly
                # biasing the reported average — presumably intentional.
                continue
            sources = sources.to(device)
            sources = augment(sources)
            # The mixture is the sum of the (augmented) source stems.
            mix = sources.sum(dim=1)

            estimates = model(mix)
            # Trim targets to match the model output length (the model may
            # return a shorter signal due to valid convolutions).
            sources = center_trim(sources, estimates)
            loss = criterion(estimates, sources)
            model_size = 0
            if quantizer is not None:
                # Differentiable model-size estimate (a tensor), penalized
                # below so training trades accuracy against size.
                model_size = quantizer.model_size()

            train_loss = loss + diffq * model_size
            train_loss.backward()
            # Total gradient L2 norm across all parameters, for monitoring.
            grad_norm = 0
            for p in model.parameters():
                if p.grad is not None:
                    grad_norm += p.grad.data.norm()**2
            grad_norm = grad_norm**0.5
            optimizer.step()
            optimizer.zero_grad()

            if quantizer is not None:
                # Convert the tensor model size to a float for display/return.
                model_size = model_size.item()

            total_loss += loss.item()
            current_loss = total_loss / (1 + idx)
            tq.set_postfix(loss=f"{current_loss:.4f}", ms=f"{model_size:.2f}",
                           grad=f"{grad_norm:.5f}")

            # Free batch tensors eagerly to reduce peak memory before the
            # next iteration.
            del sources, mix, estimates, loss, train_loss

        if world_size > 1:
            # Advance the sampler epoch so the next repetition reshuffles.
            sampler.epoch += 1

    if world_size > 1:
        # Average the final per-rank loss across all distributed processes.
        current_loss = average_metric(current_loss)
    return current_loss, model_size
| |
|
| |
|
def validate_model(epoch,
                   dataset,
                   model,
                   criterion,
                   device="cpu",
                   rank=0,
                   world_size=1,
                   shifts=0,
                   overlap=0.25,
                   split=False):
    """Evaluate `model` on the validation `dataset` for one epoch.

    Each distributed rank processes a strided subset of the tracks
    (tracks rank, rank + world_size, ...), and the per-rank averages are
    combined with `average_metric` when running distributed.

    Args:
        epoch (int): epoch index, only used for the progress bar label.
        dataset: map-style dataset; item 0 of each track is the mixture,
            the remaining entries are the target sources.
        model: separation model, applied through `apply_model`.
        criterion: loss comparing estimates to the target sources.
        device: device the track tensors are moved to.
        rank (int): this process's rank.
        world_size (int): total number of distributed processes.
        shifts, overlap, split: forwarded to `apply_model` to control
            shift equivariance trick and chunked evaluation.

    Returns:
        float: validation loss averaged over this rank's tracks (and over
        all ranks when `world_size > 1`).
    """
    # This rank's share of the tracks.
    track_ids = range(rank, len(dataset), world_size)
    num_tracks = len(track_ids)
    progress = tqdm.tqdm(track_ids,
                         ncols=120,
                         desc=f"[{epoch:03d}] valid",
                         leave=False,
                         file=sys.stdout,
                         unit=" track")
    running_loss = 0
    for track_id in progress:
        streams = dataset[track_id]
        # Cap extremely long tracks to bound memory usage, then move to device.
        streams = streams[..., :15_000_000]
        streams = streams.to(device)
        mix, sources = streams[0], streams[1:]
        estimates = apply_model(model, mix, shifts=shifts, split=split, overlap=overlap)
        loss = criterion(estimates, sources)
        running_loss += loss.item() / num_tracks
        # Release per-track tensors before loading the next (long) track.
        del estimates, streams, sources

    if world_size > 1:
        running_loss = average_metric(running_loss, num_tracks)
    return running_loss
| |
|