| |
| |
| import torch |
| import torch.nn as nn |
|
|
def norm(img):
    """Min-max rescale ``img`` in place so its values lie in ``[0, 1]``.

    The global minimum is subtracted from every element and the result is
    divided by the value range (``max - min``).

    Args:
        img: Floating-point tensor to normalize. It is modified in place.

    Returns:
        None. The result is written back into ``img``.
    """
    # min()/max() raise on a zero-element tensor; nothing to rescale there.
    if img.numel() == 0:
        return

    low = float(img.min())
    high = float(img.max())
    # Floor the divisor at 1e-5 so a constant tensor (high == low) maps to
    # zeros instead of dividing by zero.
    img.sub_(low).div_(max(high - low, 1e-5))
def random_sample(batch_size, z_dim, device):
    """Draw a batch of standard-normal noise and move it to ``device``.

    Args:
        batch_size: Size of the first dimension of the result.
        z_dim: Size of the second dimension of the result.
        device: Target passed to ``Tensor.to`` for the sampled noise.

    Returns:
        Tensor of shape ``(batch_size, z_dim, 1, 1)`` on ``device``.
    """
    latent_shape = (batch_size, z_dim, 1, 1)
    # Sample first, then transfer: the values come from torch.randn's
    # default placement and are only afterwards moved to the target device.
    noise = torch.randn(*latent_shape)
    return noise.to(device)
|
|
|
|
def init_weight(m):
    """Initialize the parameters of a single module in place.

    * ``nn.Linear``, ``nn.Conv2d``, ``nn.ConvTranspose2d``: weights are drawn
      from a normal distribution with mean 0 and std 0.02; the bias, when the
      layer has one, is set to zero.
    * ``nn.BatchNorm2d``: weight is set to 1 and bias to 0. Layers built with
      ``affine=False`` have neither parameter and are left untouched.

    Any other module type is ignored. The function takes exactly one module,
    so it has the signature ``nn.Module.apply`` expects.

    Args:
        m: The module whose parameters should be (re)initialized.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        # bias is None when the layer was created with bias=False.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif isinstance(m, nn.BatchNorm2d):
        # With affine=False both weight and bias are None; dereferencing
        # them unconditionally would raise AttributeError.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
|
|