hikmatfarhat committed on
Commit
cd4cfe7
·
1 Parent(s): f4025ed

Upload WGAN_GP

Browse files
Files changed (4) hide show
  1. config.json +80 -0
  2. config.py +6 -0
  3. model.py +17 -0
  4. model.safetensors +3 -0
config.json ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "WGAN_GP"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "config.Config",
7
+ "AutoModel": "model.WGAN_GP"
8
+ },
9
+ "batch_size": 256,
10
+ "cfg": {
11
+ "batch_size": 256,
12
+ "comet_project": "wgan-gp",
13
+ "comet_workspace": "wgan",
14
+ "d_iter_per_g": 5,
15
+ "data_dir": "/home/user/PyTorch-Lightning-GAN/GAN/celeba",
16
+ "device": "cuda",
17
+ "epochs": 500,
18
+ "final_activation": {
19
+ "g": "tanh"
20
+ },
21
+ "images_dir": "images",
22
+ "img_ch": 3,
23
+ "imsize": 64,
24
+ "lr": {
25
+ "d": 0.0001,
26
+ "g": 0.0001
27
+ },
28
+ "name": "WGAN-GP",
29
+ "norm_type": {
30
+ "d": "GroupNorm",
31
+ "g": "GroupNorm"
32
+ },
33
+ "num_sample_epochs": 20,
34
+ "num_workers": 4,
35
+ "resume": false,
36
+ "samples_dir": "samples",
37
+ "save_image_freq": 50,
38
+ "save_model_freq": 100,
39
+ "seed": 42,
40
+ "use_fabric": false,
41
+ "w_gp": 10,
42
+ "weights_dir": "chkpt",
43
+ "zdim": 128
44
+ },
45
+ "comet_project": "wgan-gp",
46
+ "comet_workspace": "wgan",
47
+ "d_iter_per_g": 5,
48
+ "data_dir": "/home/user/PyTorch-Lightning-GAN/GAN/celeba",
49
+ "device": "cuda",
50
+ "epochs": 500,
51
+ "final_activation": {
52
+ "g": "tanh"
53
+ },
54
+ "images_dir": "images",
55
+ "img_ch": 3,
56
+ "imsize": 64,
57
+ "lr": {
58
+ "d": 0.0001,
59
+ "g": 0.0001
60
+ },
61
+ "model_type": "WGAN_GP",
62
+ "name": "WGAN-GP",
63
+ "norm_type": {
64
+ "d": "GroupNorm",
65
+ "g": "GroupNorm"
66
+ },
67
+ "num_sample_epochs": 20,
68
+ "num_workers": 4,
69
+ "resume": false,
70
+ "samples_dir": "samples",
71
+ "save_image_freq": 50,
72
+ "save_model_freq": 100,
73
+ "seed": 42,
74
+ "torch_dtype": "float32",
75
+ "transformers_version": "4.35.2",
76
+ "use_fabric": false,
77
+ "w_gp": 10,
78
+ "weights_dir": "chkpt",
79
+ "zdim": 128
80
+ }
config.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
class Config(PretrainedConfig):
    """Hugging Face configuration for the WGAN-GP model.

    All keyword arguments are handed to ``PretrainedConfig`` (which registers
    them as individual attributes for serialization) and are additionally kept
    as a raw dict under ``cfg`` so the model can read hyperparameters with
    ``config.cfg["imsize"]``-style lookups.
    """

    model_type = "WGAN_GP"

    def __init__(self, **kwargs):
        # Base class stores/serializes each kwarg as a top-level attribute.
        super().__init__(**kwargs)
        # Keep the raw kwargs dict as well; model.py reads config.cfg[...].
        # NOTE(review): this duplicates every field in the saved config.json
        # (once top-level, once under "cfg") — visible in the committed file.
        # Dropping it would break model.py, so it is preserved as-is.
        self.cfg = kwargs
model.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import huggingface_hub
2
+ from .config import Config
3
+ from transformers import PreTrainedModel
4
+ from dcgan import Generator
5
+
6
+ # config = Config()
7
+ # config.save_pretrained("WGAN-GP")
8
class WGAN_GP(PreTrainedModel):
    """Hugging Face wrapper exporting the WGAN-GP generator.

    Only the generator half of the GAN is wrapped; a forward pass maps a
    latent input through the generator and returns its output.
    """

    config_class = Config

    def __init__(self, config):
        super().__init__(config)
        # Hyperparameters come from the raw dict that Config stores as .cfg.
        hp = config.cfg
        self.generator = Generator(
            hp["imsize"],
            hp["img_ch"],
            hp["zdim"],
            hp["norm_type"]["g"],
            hp["final_activation"]["g"],
        )

    def forward(self, input):  # parameter name kept for keyword-caller compatibility
        # Delegate directly to the generator.
        return self.generator(input)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06bd22721a4ef7a2db42ec94c3a9b5fca1aca5bbd7352502d3e6617e2e84639f
3
+ size 52470080