wkqin commited on
Commit
1792312
·
verified ·
1 Parent(s): 6e6b1b7

Upload guider.py

Browse files
Files changed (1) hide show
  1. normal_guider_net/guider.py +19 -0
normal_guider_net/guider.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
3
+ from diffusers.models.modeling_utils import ModelMixin
4
+
5
+
6
+ class GuiderNet(ModelMixin, ConfigMixin):
7
+ @register_to_config
8
+ def __init__(self, in_channels=3, mid_channels=4, out_channels=8):
9
+ super().__init__()
10
+ self.layers = nn.Sequential(
11
+ nn.Conv2d(in_channels, mid_channels, 4, 2, 1),
12
+ nn.SiLU(),
13
+ nn.Conv2d(mid_channels, mid_channels, 4, 2, 1),
14
+ nn.SiLU(),
15
+ nn.Conv2d(mid_channels, out_channels, 4, 2, 1),
16
+ )
17
+
18
+ def forward(self, x):
19
+ return self.layers(x)