| | |
| | |
| | |
| | |
| | @@ -19,7 +19,7 @@ class ModulationModule(Module): |
| | |
| | def forward(self, x, embedding, cut_flag): |
| | x = self.fc(x) |
| | - x = self.norm(x) |
| | + x = self.norm(x) |
| | if cut_flag == 1: |
| | return x |
| | gamma = self.gamma_function(embedding.float()) |
| | @@ -39,20 +39,20 @@ class SubHairMapper(Module): |
| | def forward(self, x, embedding, cut_flag=0): |
| | x = self.pixelnorm(x) |
| | for modulation_module in self.modulation_module_list: |
| | - x = modulation_module(x, embedding, cut_flag) |
| | + x = modulation_module(x, embedding, cut_flag) |
| | return x |
| | |
| | -class HairMapper(Module): |
| | +class HairMapper(Module): |
| | def __init__(self, opts): |
| | super(HairMapper, self).__init__() |
| | self.opts = opts |
| | - self.clip_model, self.preprocess = clip.load("ViT-B/32", device="cuda") |
| | + self.clip_model, self.preprocess = clip.load("ViT-B/32", device=opts.device) |
| | self.transform = transforms.Compose([transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))]) |
| | self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224)) |
| | self.hairstyle_cut_flag = 0 |
| | self.color_cut_flag = 0 |
| | |
| | - if not opts.no_coarse_mapper: |
| | + if not opts.no_coarse_mapper: |
| | self.course_mapping = SubHairMapper(opts, 4) |
| | if not opts.no_medium_mapper: |
| | self.medium_mapping = SubHairMapper(opts, 4) |
| | @@ -70,13 +70,13 @@ class HairMapper(Module): |
| | elif hairstyle_tensor.shape[1] != 1: |
| | hairstyle_embedding = self.gen_image_embedding(hairstyle_tensor, self.clip_model, self.preprocess).unsqueeze(1).repeat(1, 18, 1).detach() |
| | else: |
| | - hairstyle_embedding = torch.ones(x.shape[0], 18, 512).cuda() |
| | + hairstyle_embedding = torch.ones(x.shape[0], 18, 512).to(self.opts.device) |
| | if color_text_inputs.shape[1] != 1: |
| | color_embedding = self.clip_model.encode_text(color_text_inputs).unsqueeze(1).repeat(1, 18, 1).detach() |
| | elif color_tensor.shape[1] != 1: |
| | color_embedding = self.gen_image_embedding(color_tensor, self.clip_model, self.preprocess).unsqueeze(1).repeat(1, 18, 1).detach() |
| | else: |
| | - color_embedding = torch.ones(x.shape[0], 18, 512).cuda() |
| | + color_embedding = torch.ones(x.shape[0], 18, 512).to(self.opts.device) |
| | |
| | |
| | if (hairstyle_text_inputs.shape[1] == 1) and (hairstyle_tensor.shape[1] == 1): |
| | @@ -106,4 +106,4 @@ class HairMapper(Module): |
| | x_fine = torch.zeros_like(x_fine) |
| | |
| | out = torch.cat([x_coarse, x_medium, x_fine], dim=1) |
| | - return out |
| | \ No newline at end of file |
| | + return out |
| |
|