Update InpaintReward.py
Browse files- InpaintReward.py +4 -4
InpaintReward.py
CHANGED
|
@@ -54,7 +54,7 @@ class ViTBlock(nn.Module):
|
|
| 54 |
return x
|
| 55 |
|
| 56 |
|
| 57 |
-
class
|
| 58 |
def __init__(self, config, device='cpu'):
|
| 59 |
super().__init__()
|
| 60 |
self.config = config
|
|
@@ -62,7 +62,7 @@ class ImageReward(nn.Module):
|
|
| 62 |
|
| 63 |
self.clip_model, self.preprocess = clip.load("ViT-B/32")
|
| 64 |
self.clip_model = self.clip_model.float()
|
| 65 |
-
self.mlp = MLP(self.config['
|
| 66 |
self.vit_block = ViTBlock(self.config["ViT"]["feature_dim"], self.config["ViT"]["num_heads"], self.config["ViT"]["mlp_dim"])
|
| 67 |
|
| 68 |
self.toImage = transforms.ToPILImage()
|
|
@@ -127,7 +127,7 @@ class ImageReward(nn.Module):
|
|
| 127 |
|
| 128 |
|
| 129 |
|
| 130 |
-
class
|
| 131 |
def __init__(self, config, device='cpu'):
|
| 132 |
super().__init__()
|
| 133 |
self.config = config
|
|
@@ -136,7 +136,7 @@ class ImageRewardGroup(nn.Module):
|
|
| 136 |
self.clip_model, self.preprocess = clip.load("ViT-B/32", device="cuda") #clip.load(config['clip_model'], device="cuda" if torch.cuda.is_available() else "cpu")
|
| 137 |
|
| 138 |
self.clip_model = self.clip_model.float()
|
| 139 |
-
self.mlp = MLP(config['
|
| 140 |
self.vit_block = ViTBlock(self.config["ViT"]["feature_dim"], self.config["ViT"]["num_heads"], self.config["ViT"]["mlp_dim"])
|
| 141 |
|
| 142 |
if self.config.fix_base:
|
|
|
|
| 54 |
return x
|
| 55 |
|
| 56 |
|
| 57 |
+
class InpaintReward(nn.Module):
|
| 58 |
def __init__(self, config, device='cpu'):
|
| 59 |
super().__init__()
|
| 60 |
self.config = config
|
|
|
|
| 62 |
|
| 63 |
self.clip_model, self.preprocess = clip.load("ViT-B/32")
|
| 64 |
self.clip_model = self.clip_model.float()
|
| 65 |
+
self.mlp = MLP(self.config['Reward']['mlp_dim'])
|
| 66 |
self.vit_block = ViTBlock(self.config["ViT"]["feature_dim"], self.config["ViT"]["num_heads"], self.config["ViT"]["mlp_dim"])
|
| 67 |
|
| 68 |
self.toImage = transforms.ToPILImage()
|
|
|
|
| 127 |
|
| 128 |
|
| 129 |
|
| 130 |
+
class InpaintRewardGroup(nn.Module):
|
| 131 |
def __init__(self, config, device='cpu'):
|
| 132 |
super().__init__()
|
| 133 |
self.config = config
|
|
|
|
| 136 |
self.clip_model, self.preprocess = clip.load("ViT-B/32", device="cuda") #clip.load(config['clip_model'], device="cuda" if torch.cuda.is_available() else "cpu")
|
| 137 |
|
| 138 |
self.clip_model = self.clip_model.float()
|
| 139 |
+
self.mlp = MLP(config['Reward']['mlp_dim'])
|
| 140 |
self.vit_block = ViTBlock(self.config["ViT"]["feature_dim"], self.config["ViT"]["num_heads"], self.config["ViT"]["mlp_dim"])
|
| 141 |
|
| 142 |
if self.config.fix_base:
|