csaybar commited on
Commit
054d3ca
·
verified ·
1 Parent(s): 1e5be1e

Upload 4 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ single/example_data.safetensor filter=lfs diff=lfs merge=lfs -text
37
+ single/model.safetensor filter=lfs diff=lfs merge=lfs -text
single/example_data.safetensor ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a66d52bb558f756d105b41ead9386cdd6f04b4ac9cdc0173b5632aa00f35b244
3
+ size 524504
single/load.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ import safetensors.torch
3
+ import segmentation_models_pytorch as smp
4
+ import matplotlib.pyplot as plt
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
+ class SegformerBranch(nn.Module):
12
+ def __init__(self, in_channels=4, classes=4):
13
+ super(SegformerBranch, self).__init__()
14
+ self.segformer = smp.Segformer(
15
+ encoder_name="mobilenet_v2",
16
+ encoder_weights=None,
17
+ in_channels=in_channels,
18
+ classes=classes,
19
+ )
20
+
21
+ def forward(self, x):
22
+ return self.segformer(x)
23
+
24
+
25
+ class PixelWiseNet(nn.Module):
26
+ def __init__(self, in_channels=4, out_channels=4, base_channels=32):
27
+ super(PixelWiseNet, self).__init__()
28
+ self.conv1 = nn.Conv2d(in_channels, base_channels, kernel_size=1, bias=False)
29
+ self.bn1 = nn.BatchNorm2d(base_channels)
30
+ self.conv2 = nn.Conv2d(base_channels, base_channels, kernel_size=1, bias=False)
31
+ self.bn2 = nn.BatchNorm2d(base_channels)
32
+ self.conv3 = nn.Conv2d(base_channels, out_channels, kernel_size=1, bias=False)
33
+
34
+ def forward(self, x):
35
+ x = F.relu(self.bn1(self.conv1(x)))
36
+ x = F.relu(self.bn2(self.conv2(x)))
37
+ x = self.conv3(x)
38
+ return x
39
+
40
+ class CombinedNet(nn.Module):
41
+ def __init__(self, in_channels=4, classes=4, base_channels=32, benchmark=False):
42
+ super(CombinedNet, self).__init__()
43
+ self.seg_branch = SegformerBranch(in_channels=in_channels, classes=classes)
44
+ self.pixel_branch = PixelWiseNet(in_channels=in_channels, out_channels=classes, base_channels=base_channels)
45
+ self.fusion_conv = nn.Conv2d(classes, classes, kernel_size=1, bias=False)
46
+ self.benchmark = benchmark
47
+
48
+ def forward(self, x):
49
+ seg_out = self.seg_branch(x)
50
+ pixel_out = self.pixel_branch(x)
51
+ fused = seg_out + pixel_out
52
+ out = self.fusion_conv(fused)
53
+ if self.benchmark:
54
+ out = torch.sigmoid(out)
55
+ return out
56
+
57
+
58
+
59
+ # MLSTAC API -----------------------------------------------------------------------
60
+ def example_data(path: pathlib.Path, *args, **kwargs):
61
+ data_f = path / "example_data.safetensor"
62
+ sample = safetensors.torch.load_file(data_f)
63
+ return sample["image"]
64
+
65
+ def trainable_model(path, device: str = "cpu", *args, **kwargs):
66
+ trainable_f = path / "model.safetensor"
67
+
68
+ # Load model parameters
69
+ cloud_model_weights = safetensors.torch.load_file(trainable_f)
70
+ cloud_model = CombinedNet(classes=1)
71
+ cloud_model.load_state_dict(cloud_model_weights)
72
+ cloud_model = cloud_model.eval()
73
+
74
+ return cloud_model
75
+
76
+
77
+ def compiled_model(path, device: str = "cpu", *args, **kwargs):
78
+ trainable_f = path / "model.safetensor"
79
+
80
+ # Load model parameters
81
+ cloud_model_weights = safetensors.torch.load_file(trainable_f)
82
+ cloud_model = CombinedNet(classes=1, benchmark=True)
83
+ cloud_model.load_state_dict(cloud_model_weights)
84
+ cloud_model = cloud_model.eval()
85
+
86
+ # Move model to device
87
+ cloud_model = cloud_model.to(device)
88
+
89
+ # Desativate gradients
90
+ for param in cloud_model.parameters():
91
+ param.requires_grad = False
92
+
93
+ return cloud_model
94
+
95
+ def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
96
+ # Load model
97
+ model = compiled_model(path, device, benchmark=True)
98
+
99
+ # Load data
100
+ probav = example_data(path)
101
+
102
+ # Run model
103
+ cloudprobs = model(probav.float().unsqueeze(0).to(device)).squeeze(0).cpu()
104
+
105
+ #Display results
106
+ fig, ax = plt.subplots(1, 2, figsize=(8, 4))
107
+ ax[0].imshow(probav[[2, 1, 0]].cpu().detach().numpy().transpose(1, 2, 0))
108
+ ax[0].set_title("Input")
109
+ ax[1].imshow(cloudprobs[0].cpu().detach().numpy(), cmap="gray")
110
+ ax[1].set_title("Output")
111
+ for a in ax:
112
+ a.axis("off")
113
+ fig.tight_layout()
114
+ return fig
single/mlm.json ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "type": "Feature",
3
+ "stac_version": "1.1.0",
4
+ "stac_extensions": [
5
+ "https://stac-extensions.github.io/mlm/v1.4.0/schema.json"
6
+ ],
7
+ "id": "SegFormerPlusMLP",
8
+ "geometry": {
9
+ "type": "Polygon",
10
+ "coordinates": [
11
+ [
12
+ [
13
+ -180.0,
14
+ -90.0
15
+ ],
16
+ [
17
+ -180.0,
18
+ 90.0
19
+ ],
20
+ [
21
+ 180.0,
22
+ 90.0
23
+ ],
24
+ [
25
+ 180.0,
26
+ -90.0
27
+ ],
28
+ [
29
+ -180.0,
30
+ -90.0
31
+ ]
32
+ ]
33
+ ]
34
+ },
35
+ "bbox": [
36
+ -180,
37
+ -90,
38
+ 180,
39
+ 90
40
+ ],
41
+ "properties": {
42
+ "start_datetime": "1900-01-01T00:00:00Z",
43
+ "end_datetime": "9999-01-01T00:00:00Z",
44
+ "description": "A two branch model for cloud detection in PROBA and SPOTVGT images. The first branch is a SegFormer model for semantic segmentation, and the second branch is a MLP model. The model is trained on PROBA-V and SPOTVGT images.",
45
+ "dependencies": [
46
+ "torch",
47
+ "safetensors.torch",
48
+ "semantic-segmentation-models-pytorch"
49
+ ],
50
+ "mlm:framework": "pytorch",
51
+ "mlm:framework_version": "2.1.2+cu121",
52
+ "file:size": 40455416,
53
+ "mlm:memory_size": 1,
54
+ "mlm:accelerator": "cuda",
55
+ "mlm:accelerator_constrained": false,
56
+ "mlm:accelerator_summary": "Unknown",
57
+ "mlm:name": "SegFormerPlusMLP",
58
+ "mlm:architecture": "SegFormer and MLP",
59
+ "mlm:tasks": [
60
+ "cloud detection"
61
+ ],
62
+ "mlm:input": [
63
+ {
64
+ "name": "ProbaVGT and SPOTVGT images",
65
+ "bands": [
66
+ "Blue[B0]",
67
+ "Red[B1]",
68
+ "Near-Infrared[B3]",
69
+ "SWIR[MIR]"
70
+ ],
71
+ "input": {
72
+ "shape": [
73
+ -1,
74
+ 4,
75
+ 128,
76
+ 128
77
+ ],
78
+ "dim_order": [
79
+ "batch",
80
+ "channel",
81
+ "height",
82
+ "width"
83
+ ],
84
+ "data_type": "float32"
85
+ },
86
+ "pre_processing_function": null
87
+ }
88
+ ],
89
+ "mlm:output": [
90
+ {
91
+ "name": "cloud mask",
92
+ "tasks": [
93
+ "cloud detection"
94
+ ],
95
+ "result": {
96
+ "shape": [
97
+ -1,
98
+ 1,
99
+ 128,
100
+ 128
101
+ ],
102
+ "dim_order": [
103
+ "batch",
104
+ "channel",
105
+ "height",
106
+ "width"
107
+ ],
108
+ "data_type": "uint8"
109
+ },
110
+ "classification:classes": [],
111
+ "post_processing_function": null
112
+ }
113
+ ],
114
+ "mlm:total_parameters": 12894526,
115
+ "mlm:pretrained": true,
116
+ "datetime": null
117
+ },
118
+ "links": [],
119
+ "assets": {
120
+ "model": {
121
+ "href": "https://huggingface.co/tacofoundation/PROBAandSPOT/resolve/main/single/model.safetensor",
122
+ "type": "application/octet-stream; application=safetensor",
123
+ "title": "Pytorch model weights checkpoint",
124
+ "description": "The weights of the model in safetensor format.",
125
+ "mlm:artifact_type": "safetensor.torch.save_file",
126
+ "roles": [
127
+ "mlm:model",
128
+ "mlm:weights",
129
+ "data"
130
+ ]
131
+ },
132
+ "source_code": {
133
+ "href": "https://huggingface.co/tacofoundation/PROBAandSPOT/resolve/main/single/load.py",
134
+ "type": "text/x-python",
135
+ "title": "Model load script",
136
+ "description": "Python script to load the model.",
137
+ "roles": [
138
+ "mlm:source_code",
139
+ "code"
140
+ ]
141
+ },
142
+ "example_data": {
143
+ "href": "https://huggingface.co/tacofoundation/PROBAandSPOT/resolve/main/single/example_data.safetensor",
144
+ "type": "application/octet-stream; application=safetensors",
145
+ "title": "Example Sentinel-2 image",
146
+ "description": "Example Sentinel-2 image for model inference.",
147
+ "roles": [
148
+ "mlm:example_data",
149
+ "data"
150
+ ]
151
+ }
152
+ },
153
+ "collection": "SegFormerPlusMLP"
154
+ }
single/model.safetensor ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:97d2f0c8ddf006459d66fcb9a1484e5a7b97b944899fd6a4947996693e3a4f19
3
+ size 11885088