jlynxdev commited on
Commit
52e4145
·
0 Parent(s):

initial commit

Browse files
Files changed (7) hide show
  1. .gitattributes +35 -0
  2. .gitignore +2 -0
  3. README.md +43 -0
  4. config.json +6 -0
  5. generator.pth +3 -0
  6. model.py +126 -0
  7. model.safetensors +3 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .venv
2
+ .idea
README.md ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ To load and initialize the `Generator` model from the repository, follow these steps:
2
+
3
+ 1. **Install Required Packages**: Ensure you have the necessary Python packages installed:
4
+
5
+ ```bash
6
+ pip install torch omegaconf huggingface_hub
7
+ ```
8
+
9
+ 2. **Download Model Files**: Retrieve the `generator.pth`, `config.json`, and `model.py` files from the Hugging Face repository. You can use the `huggingface_hub` library for this:
10
+
11
+ ```python
12
+ from huggingface_hub import hf_hub_download
13
+
14
+ repo_id = "Kiwinicki/sat2map-generator"
15
+ generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
16
+ config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
17
+ model_path = hf_hub_download(repo_id=repo_id, filename="model.py")
18
+ ```
19
+
20
+ 3. **Load the Model**: Incorporate the downloaded `model.py` to define the `Generator` class, then load the model's state dictionary and configuration:
21
+
22
+ ```python
23
+ import torch
24
+ import json
25
+ from omegaconf import OmegaConf
26
+ import sys
27
+ from pathlib import Path
28
+ from model import Generator
29
+
30
+ # Load configuration
31
+ with open(config_path, "r") as f:
32
+ config_dict = json.load(f)
33
+ cfg = OmegaConf.create(config_dict)
34
+
35
+ # Initialize and load the generator model
36
+ generator = Generator(cfg)
37
+ generator.load_state_dict(torch.load(generator_path, map_location="cpu", weights_only=True))
38
+ generator.eval()
39
+ x = torch.randn([1, cfg['channels'], 256, 256])
40
+ out = generator(x)
41
+ ```
42
+
43
+ Here, `generator` is the initialized model ready for inference.
config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "channels": 3,
3
+ "num_features": 64,
4
+ "num_residuals": 12,
5
+ "depth": 4
6
+ }
generator.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:925df068c3a6b7110b3be435eb4432399fd337e61ca2b512462f2b596864eca9
3
+ size 59701794
model.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import tanh, Tensor
2
+ import torch.nn as nn
3
+ from omegaconf import DictConfig
4
+ from abc import ABC, abstractmethod
5
+
6
+
7
class BaseGenerator(ABC, nn.Module):
    """Abstract base class for image generators.

    Records the image channel count and obliges subclasses to provide
    a ``forward`` implementation.
    """

    def __init__(self, channels: int = 3):
        super().__init__()
        # Channel count of the images this generator consumes/produces.
        self.channels = channels

    @abstractmethod
    def forward(self, x: Tensor) -> Tensor:
        """Transform an input image batch into an output image batch."""
        ...
15
+
16
+
17
class Generator(BaseGenerator):
    """ResNet-style encoder/decoder generator (CycleGAN-like architecture).

    Pipeline: 7x7 reflect-padded conv stem -> two stride-2 downsampling
    ``ConvBlock``s -> ``cfg.num_residuals`` ``ResidualBlock``s -> two
    stride-2 transposed-conv ``ConvBlock``s -> 7x7 reflect-padded conv
    head, with ``tanh`` applied in ``forward``.

    Args:
        cfg: configuration with ``channels``, ``num_features`` and
            ``num_residuals`` entries (see config.json).
    """

    def __init__(self, cfg: DictConfig):
        super().__init__(cfg.channels)
        self.cfg = cfg
        # NOTE: kept as ``self.model`` — checkpoint state_dict keys depend on it.
        self.model = self._construct_model()

    def _construct_model(self):
        """Assemble the whole network as one nested ``nn.Sequential``."""
        feats = self.cfg.num_features

        # Stem: wide receptive field without spatial downsampling.
        stem = nn.Sequential(
            nn.Conv2d(
                self.cfg.channels,
                feats,
                kernel_size=7,
                stride=1,
                padding=3,
                padding_mode="reflect",
            ),
            nn.ReLU(inplace=True),
        )

        # Encoder: two stages, each doubling channels and halving H/W.
        encoder = nn.Sequential(
            ConvBlock(feats, feats * 2, kernel_size=3, stride=2, padding=1),
            ConvBlock(feats * 2, feats * 4, kernel_size=3, stride=2, padding=1),
        )

        # Bottleneck: a stack of residual blocks at the smallest resolution.
        bottleneck = nn.Sequential(
            *(ResidualBlock(feats * 4) for _ in range(self.cfg.num_residuals))
        )

        # Decoder: mirrors the encoder with transposed convs (down=False).
        decoder = nn.Sequential(
            ConvBlock(
                feats * 4,
                feats * 2,
                down=False,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            ),
            ConvBlock(
                feats * 2,
                feats,
                down=False,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            ),
        )

        # Head: project back to the image channel count.
        head = nn.Conv2d(
            feats,
            self.cfg.channels,
            kernel_size=7,
            stride=1,
            padding=3,
            padding_mode="reflect",
        )

        return nn.Sequential(stem, encoder, bottleneck, decoder, head)

    def forward(self, x: Tensor) -> Tensor:
        """Run the network and squash outputs into [-1, 1] with tanh."""
        return tanh(self.model(x))
96
+
97
+
98
class ConvBlock(nn.Module):
    """Convolution (or transposed convolution) + InstanceNorm + optional ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        down: if True, use a reflect-padded ``nn.Conv2d`` (downsampling
            path); otherwise an ``nn.ConvTranspose2d`` (upsampling path).
        use_activation: append an in-place ReLU when True, otherwise
            ``nn.Identity`` (used for the second conv of a residual block).
        **kwargs: forwarded to the conv layer (kernel_size, stride,
            padding, output_padding, ...).
    """

    def __init__(
        self, in_channels, out_channels, down=True, use_activation=True, **kwargs
    ):
        super().__init__()
        if down:
            conv = nn.Conv2d(
                in_channels, out_channels, padding_mode="reflect", **kwargs
            )
        else:
            conv = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
        activation = nn.ReLU(inplace=True) if use_activation else nn.Identity()
        # Same Sequential layout (conv, norm, act) as the checkpoint expects.
        self.conv = nn.Sequential(conv, nn.InstanceNorm2d(out_channels), activation)

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv -> instance norm -> (ReLU | identity)."""
        return self.conv(x)
113
+
114
+
115
class ResidualBlock(nn.Module):
    """Two 3x3 ConvBlocks with an identity skip connection.

    Args:
        channels: feature-map channel count (input and output are equal,
            so the skip addition is shape-safe).
    """

    def __init__(self, channels: int):
        super().__init__()
        # Second ConvBlock is built without an activation, so only the
        # normalized conv output is added onto the skip path.
        self.block = nn.Sequential(
            ConvBlock(channels, channels, kernel_size=3, padding=1),
            ConvBlock(
                channels, channels, use_activation=False, kernel_size=3, padding=1
            ),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x + F(x)`` where ``F`` is the two-conv branch."""
        residual = self.block(x)
        return x + residual
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:15e1ccacf5b528313d57c55df11eccf643e3efb54a1089ffaf52766c3e4174d4
3
+ size 59680580