| |
| |
| |
| |
| @@ -21,7 +21,7 @@ class VQGANTextureAwareSpatialHierarchyInferenceModel(): |
| |
| def __init__(self, opt): |
| self.opt = opt |
| - self.device = torch.device('cuda') |
| + self.device = torch.device(opt['device']) |
| self.is_train = opt['is_train'] |
| |
| self.top_encoder = Encoder( |
| |
| |
| |
| |
| @@ -20,7 +20,7 @@ class HierarchyVQSpatialTextureAwareModel(): |
| |
| def __init__(self, opt): |
| self.opt = opt |
| - self.device = torch.device('cuda') |
| + self.device = torch.device(opt['device']) |
| self.top_encoder = Encoder( |
| ch=opt['top_ch'], |
| num_res_blocks=opt['top_num_res_blocks'], |
| |
| |
| |
| |
| @@ -22,7 +22,7 @@ class ParsingGenModel(): |
| |
| def __init__(self, opt): |
| self.opt = opt |
| - self.device = torch.device('cuda') |
| + self.device = torch.device(opt['device']) |
| self.is_train = opt['is_train'] |
| |
| self.attr_embedder = ShapeAttrEmbedding( |
| |
| |
| |
| |
| @@ -23,7 +23,7 @@ class BaseSampleModel(): |
| |
| def __init__(self, opt): |
| self.opt = opt |
| - self.device = torch.device('cuda') |
| + self.device = torch.device(opt['device']) |
| |
| # hierarchical VQVAE |
| self.decoder = Decoder( |
| @@ -123,7 +123,7 @@ class BaseSampleModel(): |
| |
| def load_top_pretrain_models(self): |
| # load pretrained vqgan |
| - top_vae_checkpoint = torch.load(self.opt['top_vae_path']) |
| + top_vae_checkpoint = torch.load(self.opt['top_vae_path'], map_location=self.device) |
| |
| self.decoder.load_state_dict( |
| top_vae_checkpoint['decoder'], strict=True) |
| @@ -137,7 +137,7 @@ class BaseSampleModel(): |
| self.top_post_quant_conv.eval() |
| |
| def load_bot_pretrain_network(self): |
| - checkpoint = torch.load(self.opt['bot_vae_path']) |
| + checkpoint = torch.load(self.opt['bot_vae_path'], map_location=self.device) |
| self.bot_decoder_res.load_state_dict( |
| checkpoint['bot_decoder_res'], strict=True) |
| self.decoder.load_state_dict(checkpoint['decoder'], strict=True) |
| @@ -153,7 +153,7 @@ class BaseSampleModel(): |
| |
| def load_pretrained_segm_token(self): |
| # load pretrained vqgan for segmentation mask |
| - segm_token_checkpoint = torch.load(self.opt['segm_token_path']) |
| + segm_token_checkpoint = torch.load(self.opt['segm_token_path'], map_location=self.device) |
| self.segm_encoder.load_state_dict( |
| segm_token_checkpoint['encoder'], strict=True) |
| self.segm_quantizer.load_state_dict( |
| @@ -166,7 +166,7 @@ class BaseSampleModel(): |
| self.segm_quant_conv.eval() |
| |
| def load_index_pred_network(self): |
| - checkpoint = torch.load(self.opt['pretrained_index_network']) |
| + checkpoint = torch.load(self.opt['pretrained_index_network'], map_location=self.device) |
| self.index_pred_guidance_encoder.load_state_dict( |
| checkpoint['guidance_encoder'], strict=True) |
| self.index_pred_decoder.load_state_dict( |
| @@ -176,7 +176,7 @@ class BaseSampleModel(): |
| self.index_pred_decoder.eval() |
| |
| def load_sampler_pretrained_network(self): |
| - checkpoint = torch.load(self.opt['pretrained_sampler']) |
| + checkpoint = torch.load(self.opt['pretrained_sampler'], map_location=self.device) |
| self.sampler_fn.load_state_dict(checkpoint, strict=True) |
| self.sampler_fn.eval() |
| |
| @@ -397,7 +397,7 @@ class SampleFromPoseModel(BaseSampleModel): |
| [185, 210, 205], [130, 165, 180], [225, 141, 151]] |
| |
| def load_shape_generation_models(self): |
| - checkpoint = torch.load(self.opt['pretrained_parsing_gen']) |
| + checkpoint = torch.load(self.opt['pretrained_parsing_gen'], map_location=self.device) |
| |
| self.shape_attr_embedder.load_state_dict( |
| checkpoint['embedder'], strict=True) |
| |
| |
| |
| |
| @@ -21,7 +21,7 @@ class TransformerTextureAwareModel(): |
| |
| def __init__(self, opt): |
| self.opt = opt |
| - self.device = torch.device('cuda') |
| + self.device = torch.device(opt['device']) |
| self.is_train = opt['is_train'] |
| |
| # VQVAE for image |
| @@ -317,10 +317,10 @@ class TransformerTextureAwareModel(): |
| def sample_fn(self, temp=1.0, sample_steps=None): |
| self._denoise_fn.eval() |
| |
| - b, device = self.image.size(0), 'cuda' |
| + b = self.image.size(0) |
| x_t = torch.ones( |
| - (b, np.prod(self.shape)), device=device).long() * self.mask_id |
| - unmasked = torch.zeros_like(x_t, device=device).bool() |
| + (b, np.prod(self.shape)), device=self.device).long() * self.mask_id |
| + unmasked = torch.zeros_like(x_t, device=self.device).bool() |
| sample_steps = list(range(1, sample_steps + 1)) |
| |
| texture_mask_flatten = self.texture_tokens.view(-1) |
| @@ -336,11 +336,11 @@ class TransformerTextureAwareModel(): |
| |
| for t in reversed(sample_steps): |
| print(f'Sample timestep {t:4d}', end='\r') |
| - t = torch.full((b, ), t, device=device, dtype=torch.long) |
| + t = torch.full((b, ), t, device=self.device, dtype=torch.long) |
| |
| # where to unmask |
| changes = torch.rand( |
| - x_t.shape, device=device) < 1 / t.float().unsqueeze(-1) |
| + x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1) |
| # don't unmask somewhere already unmasked |
| changes = torch.bitwise_xor(changes, |
| torch.bitwise_and(changes, unmasked)) |
| |
| |
| |
| |
| @@ -20,7 +20,7 @@ class VQModel(): |
| def __init__(self, opt): |
| super().__init__() |
| self.opt = opt |
| - self.device = torch.device('cuda') |
| + self.device = torch.device(opt['device']) |
| self.encoder = Encoder( |
| ch=opt['ch'], |
| num_res_blocks=opt['num_res_blocks'], |
| @@ -390,7 +390,7 @@ class VQImageSegmTextureModel(VQImageModel): |
| |
| def __init__(self, opt): |
| self.opt = opt |
| - self.device = torch.device('cuda') |
| + self.device = torch.device(opt['device']) |
| self.encoder = Encoder( |
| ch=opt['ch'], |
| num_res_blocks=opt['num_res_blocks'], |
|
|