Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,644 Bytes
bd096d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2023 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: mica@tue.mpg.de
import importlib
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
def find_model_using_name(model_dir, model_name):
    """Import ``<model_dir>.<model_name>`` and return the attribute named after it.

    Adapted from the pix2pix framework:
    https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/__init__.py#L25

    Matching rules:
    - Underscores are removed from ``model_name`` only.
    - The comparison is case-insensitive.
    So ``model_name="my_model"`` matches a module-level name ``MyModel``.

    Any module-level attribute with a matching name is accepted, not only
    classes. If several names match, the last one in the module's
    namespace order wins.

    Args:
        model_dir: dotted package path containing the model module.
        model_name: module name inside ``model_dir``. It also names the
            attribute to look up.

    Returns:
        The matching module attribute (normally the model class).

    Raises:
        ImportError: if the module cannot be imported.
        SystemExit: with exit status 1 if no matching attribute is found.
            A message is printed first.
    """
    # import "model_dir/modelname.py"
    model_filename = model_dir + "." + model_name
    modellib = importlib.import_module(model_filename, package=model_dir)

    target_model_name = model_name.replace('_', '')
    target_lower = target_model_name.lower()  # hoisted out of the loop

    model = None
    for name, obj in modellib.__dict__.items():
        if name.lower() == target_lower:
            model = obj  # no break: keep the original "last match wins" behaviour

    if model is None:
        print("In %s.py, there should be a class with class name that matches %s in lowercase." % (model_filename, target_model_name))
        # Exit with a non-zero status: the original exit(0) reported success on
        # failure and depended on the site-provided `exit` builtin.
        raise SystemExit(1)
    return model
def visualize_grid(visdict, savepath=None, size=224, dim=1, return_gird=True):
    """Tile a dict of image batches into a single uint8 BGR image.

    Each entry ``visdict[key]`` is unpacked as a 4-D tensor ``(B, C, H, W)``.
    Per entry:
    - Every image is resized with ``F.interpolate``, keeping aspect ratio.
      The width becomes ``size`` for ``dim=1``; the height becomes ``size``
      for ``dim=2``.
    - The batch is laid out in one row (``nrow=B``, no padding).

    The per-key rows are then concatenated in iteration order of
    ``visdict``: vertically for ``dim=1``, horizontally for ``dim=2``.
    With ``dim=1`` each row is ``B * size`` pixels wide, so all batches must
    share the same ``B`` for the concatenation to succeed.

    Args:
        visdict: mapping from a label to a ``(B, C, H, W)`` tensor.
            Values are expected in [0, 1]; they are scaled by 255 and
            clipped to [0, 255].
        savepath: if truthy, the image is written there with ``cv2.imwrite``.
        size: target width (``dim=1``) or height (``dim=2``) of each image.
        dim: 1 stacks keys vertically; 2 stacks them horizontally.
        return_gird: (sic, kept for compatibility) if true, return the image.
            Otherwise return None.

    Returns:
        An ``(H, W, 3)`` ``np.uint8`` array with channels reversed
        (RGB -> BGR, the order ``cv2.imwrite`` expects), or None when
        ``return_gird`` is false.

    Raises:
        ValueError: if ``dim`` is not 1 or 2.
    """
    # A real check instead of `assert`, which is stripped under `python -O`.
    if dim not in (1, 2):
        raise ValueError("dim must be 1 (vertical) or 2 (horizontal), got %r" % (dim,))

    rows = []
    for key in visdict:
        images = visdict[key]
        b, _, h, w = images.shape
        if dim == 2:
            # Fixed height; clamp so extreme aspect ratios never round to 0 px.
            new_h, new_w = size, max(1, int(w * size / h))
        else:
            # Fixed width; same clamp on the derived height.
            new_h, new_w = max(1, int(h * size / w)), size
        resized = F.interpolate(images, [new_h, new_w]).detach().cpu()
        rows.append(torchvision.utils.make_grid(resized, nrow=b, padding=0))

    grid = torch.cat(rows, dim)
    # (C, H, W) -> (H, W, C), scale to [0, 255], then reverse channels to BGR.
    # Fancy indexing already returns a copy, so no explicit .copy() is needed.
    grid_image = (grid.numpy().transpose(1, 2, 0) * 255)[:, :, [2, 1, 0]]
    grid_image = np.clip(grid_image, 0, 255).astype(np.uint8)

    if savepath:
        cv2.imwrite(savepath, grid_image)
    if return_gird:
        return grid_image
|