WEN0256 commited on
Commit
32da9de
Β·
verified Β·
1 Parent(s): fdbfb40

Initial release: Segformer85Mv1 (Segformer-b5 fine-tuned, 8-class apple orchard)

Browse files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ samples/sample_00_frame_2575.jpg filter=lfs diff=lfs merge=lfs -text
37
+ samples/sample_05_frame_3371.jpg filter=lfs diff=lfs merge=lfs -text
38
+ samples/sample_09_frame_4009.jpg filter=lfs diff=lfs merge=lfs -text
39
+ v6_OOD_full_res.mp4 filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ tags:
6
+ - semantic-segmentation
7
+ - segformer
8
+ - agriculture
9
+ - orchard
10
+ - apple
11
+ - outdoor
12
+ library_name: transformers
13
+ pipeline_tag: image-segmentation
14
+ base_model: nvidia/segformer-b5-finetuned-ade-640-640
15
+ ---
16
+
17
+ # Segformer85Mv1 β€” Apple Orchard Semantic Segmentation
18
+
19
+ A Segformer-B5 (85M parameters) fine-tuned for **8-class semantic segmentation** of outdoor apple orchard scenes captured from a robotic platform.
20
+
21
+ ## Quick Use
22
+
23
+ ```python
24
+ from huggingface_hub import hf_hub_download
25
+ from transformers import SegformerForSemanticSegmentation
26
+ import torch, cv2, numpy as np
27
+ import torch.nn.functional as F
28
+
29
+ # 1. Download weights
30
+ ckpt_path = hf_hub_download(repo_id="YOUR_USER/Segformer85Mv1", filename="Segformer85Mv1.pt")
31
+
32
+ # 2. Init architecture from base + load fine-tuned weights
33
+ NAMES = ["tree","ground","person","sky","road","mountain","building","background"]
34
+ model = SegformerForSemanticSegmentation.from_pretrained(
35
+ "nvidia/segformer-b5-finetuned-ade-640-640",
36
+ num_labels=8,
37
+ id2label={i:n for i,n in enumerate(NAMES)},
38
+ label2id={n:i for i,n in enumerate(NAMES)},
39
+ ignore_mismatched_sizes=True,
40
+ ).cuda().eval()
41
+ model.load_state_dict(torch.load(ckpt_path, map_location="cuda")["model"])
42
+
43
+ # 3. Inference
44
+ img = cv2.imread("your_image.jpg")
45
+ H, W = img.shape[:2]
46
+ H32, W32 = (H//32)*32, (W//32)*32
47
+ rgb = cv2.cvtColor(cv2.resize(img, (W32, H32)), cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
48
+ mean = np.array([0.485, 0.456, 0.406]); std = np.array([0.229, 0.224, 0.225])
49
+ x = torch.from_numpy(((rgb - mean) / std).transpose(2,0,1)).unsqueeze(0).float().cuda()
50
+
51
+ with torch.no_grad():
52
+ logits = model(pixel_values=x).logits
53
+ logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
54
+ pred = logits.argmax(1)[0].cpu().numpy() # H x W, values 0..7
55
+ ```
56
+
57
+ A ready-to-use `predict.py` is included in this repo.
58
+
59
+ ## Classes (id β†’ name)
60
+
61
+ | ID | Class | Notes |
62
+ |----|-------------|--------------------------------------------------------|
63
+ | 0 | **tree** | Apple trees (priority class for downstream tasks) |
64
+ | 1 | ground | Grass / dirt / orchard floor |
65
+ | 2 | person | Workers in scene |
66
+ | 3 | sky | |
67
+ | 4 | road | Path between rows |
68
+ | 5 | mountain | Distant terrain (often confused with sky in fog) |
69
+ | 6 | building | Sheds, equipment shelters |
70
+ | 7 | background | Unknown / unlabeled regions (model output rare) |
71
+
72
+ ## Architecture & Preprocessing
73
+
74
+ | | |
75
+ |---|---|
76
+ | Base model | `nvidia/segformer-b5-finetuned-ade-640-640` |
77
+ | Parameters | ~85M |
78
+ | Decoder head | Reinitialized for 8 classes |
79
+ | Input format | RGB, normalized with ImageNet mean/std |
80
+ | `mean` | `[0.485, 0.456, 0.406]` |
81
+ | `std` | `[0.229, 0.224, 0.225]` |
82
+ | Input resolution | Any HΓ—W where both are multiples of 32 |
83
+ | Trained at | 1024Γ—576 (native 16:9) |
84
+ | Recommended inference | 1280Γ—704 or original native (snap to 32-multiple) |
85
+ | Precision | bfloat16 fine β€” model weights stored in fp32 |
86
+
87
+ ## Performance (NO data leakage)
88
+
89
+ Validated on a temporally-disjoint hold-out (frames 4501+ from training set):
90
+
91
+ | Metric | Value |
92
+ |---|---|
93
+ | **Tree IoU** | **0.742** |
94
+ | **mIoU (7 real classes)** | **0.714** |
95
+ | **Pixel accuracy** | **0.834** |
96
+
97
+ ### Per-class IoU
98
+ | Class | IoU | Precision | Recall |
99
+ |---|---|---|---|
100
+ | tree | 0.742 | 0.79 | 0.93 |
101
+ | ground | 0.851 | 0.91 | 0.93 |
102
+ | person | 0.719 | 0.82 | 0.85 |
103
+ | sky | 0.769 | 0.83 | 0.91 |
104
+ | road | 0.804 | 0.86 | 0.92 |
105
+ | mountain | 0.437 | 0.62 | 0.66 |
106
+ | building | 0.711 | 0.84 | 0.83 |
107
+
108
+ (Reported values from epoch 21 best-tree checkpoint on the no-leak validation split.)
109
+
110
+ ### OOD evaluation
111
+ On a completely held-out recording (1912 frames from `oak_0415_twoRadar_1`, never seen in training), mean prediction confidence is **0.939**, with model predicting `tree` on 41.8% of pixels and falling back to `background` on only 7.4% β€” indicating strong out-of-distribution generalization.
112
+
113
+ ## Training Data
114
+
115
+ - ~5300 frames from a single oak_0415_oneRadar_1 recording
116
+ - Initial annotations from 3 separate Roboflow projects (SAM-assisted polygons), merged + class-aligned (`vines`β†’`tree`, `moutain`β†’`mountain` typo fixed)
117
+ - Pseudo-labels generated by an earlier model to fill SAM annotation gaps
118
+ - Temporal split: frames `<=4500` train (5177 samples), frames `>4500` validation (155 samples) β€” **no neighbor leakage**
119
+
120
+ ## Training Recipe
121
+
122
+ | Hyperparameter | Value |
123
+ |---|---|
124
+ | Optimizer | AdamW, weight_decay 0.01 |
125
+ | LR | 2e-5, cosine schedule |
126
+ | Epochs | 30 |
127
+ | Batch | 2 Γ— grad_accum 4 (effective 8) |
128
+ | Resolution | 1024Γ—576 |
129
+ | Precision | bfloat16 |
130
+ | Loss | weighted cross-entropy |
131
+ | Class weights | tree 1.5, ground 0.5, person 1.5, sky 1.0, road 1.0, mountain 1.0, building 1.0, background 0.1 |
132
+ | Hardware | RTX 5090 (32 GB), ~2.3 hours |
133
+
134
+ ## Files in This Repo
135
+
136
+ | File | Purpose |
137
+ |---|---|
138
+ | `Segformer85Mv1.pt` | Fine-tuned weights (339 MB) |
139
+ | `predict.py` | Standalone inference script |
140
+ | `README.md` | This file |
141
+ | `samples/*.jpg` | Side-by-side prediction examples |
142
+ | `train_v6_5090.py` | Training script (for reproduction) |
143
+ | `history_v6.json` | Per-epoch training history |
144
+ | `v6_OOD_full_res.mp4` | 1-minute OOD inference video at native resolution |
145
+
146
+ ## License
147
+
148
+ Apache 2.0
Segformer85Mv1.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8acb2e185d8a0fdef3002f65bb77ab935c8b825a745585d484913511ea2192ec
3
+ size 338890309
history_v6.json ADDED
@@ -0,0 +1,572 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "epoch": 1,
4
+ "train_loss": 0.7317733424766079,
5
+ "val_loss": 0.3614382008329416,
6
+ "pixel_accuracy": 0.7693707866053427,
7
+ "mIoU_7": 0.6290054460148573,
8
+ "mIoU_8": 0.551595071231773,
9
+ "tree_iou": 0.6935091765839309,
10
+ "per_class_iou": {
11
+ "tree": 0.6935091765839309,
12
+ "ground": 0.8269576770586198,
13
+ "person": 0.6239303161783188,
14
+ "sky": 0.6214335001151623,
15
+ "road": 0.7967923018977053,
16
+ "mountain": 0.32553944888984093,
17
+ "building": 0.5148757013804227,
18
+ "background": 0.009722447750183178
19
+ }
20
+ },
21
+ {
22
+ "epoch": 2,
23
+ "train_loss": 0.33158363348054465,
24
+ "val_loss": 0.29247073370676774,
25
+ "pixel_accuracy": 0.7828593483107918,
26
+ "mIoU_7": 0.671867518383689,
27
+ "mIoU_8": 0.5965795049819951,
28
+ "tree_iou": 0.6851515450971696,
29
+ "per_class_iou": {
30
+ "tree": 0.6851515450971696,
31
+ "ground": 0.8329074795641042,
32
+ "person": 0.651328150567213,
33
+ "sky": 0.7052746613315422,
34
+ "road": 0.7775950133402935,
35
+ "mountain": 0.38359572918942386,
36
+ "building": 0.6672200495960767,
37
+ "background": 0.0695634111701382
38
+ }
39
+ },
40
+ {
41
+ "epoch": 3,
42
+ "train_loss": 0.2567576938229606,
43
+ "val_loss": 0.2792640005548795,
44
+ "pixel_accuracy": 0.7981670639420922,
45
+ "mIoU_7": 0.6746137194404379,
46
+ "mIoU_8": 0.610904000372088,
47
+ "tree_iou": 0.7087253188861061,
48
+ "per_class_iou": {
49
+ "tree": 0.7087253188861061,
50
+ "ground": 0.8361002659414856,
51
+ "person": 0.6559418314019541,
52
+ "sky": 0.7082608919602776,
53
+ "road": 0.7365729363277305,
54
+ "mountain": 0.3941387795068854,
55
+ "building": 0.6825560120586263,
56
+ "background": 0.16493596689363926
57
+ }
58
+ },
59
+ {
60
+ "epoch": 4,
61
+ "train_loss": 0.22328906943892032,
62
+ "val_loss": 0.2556124304731687,
63
+ "pixel_accuracy": 0.8095350696194556,
64
+ "mIoU_7": 0.6815957786979532,
65
+ "mIoU_8": 0.6245835393810028,
66
+ "tree_iou": 0.7157414904333625,
67
+ "per_class_iou": {
68
+ "tree": 0.7157414904333625,
69
+ "ground": 0.8427510979252235,
70
+ "person": 0.6691967663941234,
71
+ "sky": 0.7399725381481819,
72
+ "road": 0.7453247701300584,
73
+ "mountain": 0.3829273869357179,
74
+ "building": 0.6752564009190044,
75
+ "background": 0.2254978641623496
76
+ }
77
+ },
78
+ {
79
+ "epoch": 5,
80
+ "train_loss": 0.2025882395349984,
81
+ "val_loss": 0.2460598509090069,
82
+ "pixel_accuracy": 0.8063361492635529,
83
+ "mIoU_7": 0.6819090058037277,
84
+ "mIoU_8": 0.6238768515193147,
85
+ "tree_iou": 0.7115214580042588,
86
+ "per_class_iou": {
87
+ "tree": 0.7115214580042588,
88
+ "ground": 0.840159545948353,
89
+ "person": 0.6882897261490508,
90
+ "sky": 0.7281685473144833,
91
+ "road": 0.7760738885802295,
92
+ "mountain": 0.36856786582115475,
93
+ "building": 0.6605820088085645,
94
+ "background": 0.21765177152842294
95
+ }
96
+ },
97
+ {
98
+ "epoch": 6,
99
+ "train_loss": 0.18754121326082826,
100
+ "val_loss": 0.2658998461870047,
101
+ "pixel_accuracy": 0.8175638397107415,
102
+ "mIoU_7": 0.6910897251785011,
103
+ "mIoU_8": 0.6396066650450611,
104
+ "tree_iou": 0.7255735452977728,
105
+ "per_class_iou": {
106
+ "tree": 0.7255735452977728,
107
+ "ground": 0.8462776574160872,
108
+ "person": 0.6831329647948793,
109
+ "sky": 0.7372118420359767,
110
+ "road": 0.7377976663573189,
111
+ "mountain": 0.424413863299673,
112
+ "building": 0.6832205370477994,
113
+ "background": 0.27922524411098154
114
+ }
115
+ },
116
+ {
117
+ "epoch": 7,
118
+ "train_loss": 0.1783967717151968,
119
+ "val_loss": 0.24774512161429113,
120
+ "pixel_accuracy": 0.8165311970591118,
121
+ "mIoU_7": 0.6933117282517155,
122
+ "mIoU_8": 0.6395462936236171,
123
+ "tree_iou": 0.7274226749021537,
124
+ "per_class_iou": {
125
+ "tree": 0.7274226749021537,
126
+ "ground": 0.8418132457425621,
127
+ "person": 0.6861982683859594,
128
+ "sky": 0.7479193848975584,
129
+ "road": 0.7410099984549848,
130
+ "mountain": 0.4233301221532889,
131
+ "building": 0.6854884032255009,
132
+ "background": 0.2631882512269295
133
+ }
134
+ },
135
+ {
136
+ "epoch": 8,
137
+ "train_loss": 0.1686619697196245,
138
+ "val_loss": 0.253682183722655,
139
+ "pixel_accuracy": 0.8146241000049003,
140
+ "mIoU_7": 0.6906289318837233,
141
+ "mIoU_8": 0.635875803992791,
142
+ "tree_iou": 0.724371348212418,
143
+ "per_class_iou": {
144
+ "tree": 0.724371348212418,
145
+ "ground": 0.8416168509199589,
146
+ "person": 0.689128159543506,
147
+ "sky": 0.7491636164004881,
148
+ "road": 0.7221377666908946,
149
+ "mountain": 0.41500210217063455,
150
+ "building": 0.6929826792481624,
151
+ "background": 0.2526039087562656
152
+ }
153
+ },
154
+ {
155
+ "epoch": 9,
156
+ "train_loss": 0.1617005393484829,
157
+ "val_loss": 0.23389703856828886,
158
+ "pixel_accuracy": 0.8142122877114135,
159
+ "mIoU_7": 0.6982474481555915,
160
+ "mIoU_8": 0.6396787822355385,
161
+ "tree_iou": 0.7227781453845477,
162
+ "per_class_iou": {
163
+ "tree": 0.7227781453845477,
164
+ "ground": 0.8463292649711305,
165
+ "person": 0.6948543902963943,
166
+ "sky": 0.752873076350096,
167
+ "road": 0.7767962580351105,
168
+ "mountain": 0.4051208118881919,
169
+ "building": 0.6889801901636698,
170
+ "background": 0.22969812079516688
171
+ }
172
+ },
173
+ {
174
+ "epoch": 10,
175
+ "train_loss": 0.15628578751225605,
176
+ "val_loss": 0.247397372737909,
177
+ "pixel_accuracy": 0.8216093548737119,
178
+ "mIoU_7": 0.6951037061409514,
179
+ "mIoU_8": 0.6432917480279269,
180
+ "tree_iou": 0.7360921756649008,
181
+ "per_class_iou": {
182
+ "tree": 0.7360921756649008,
183
+ "ground": 0.8459931111438093,
184
+ "person": 0.6749524232501306,
185
+ "sky": 0.7544135951669563,
186
+ "road": 0.7459900655039325,
187
+ "mountain": 0.41176574357825835,
188
+ "building": 0.6965188286786725,
189
+ "background": 0.28060804123675503
190
+ }
191
+ },
192
+ {
193
+ "epoch": 11,
194
+ "train_loss": 0.15047259357868703,
195
+ "val_loss": 0.2422896781219886,
196
+ "pixel_accuracy": 0.8186814503003192,
197
+ "mIoU_7": 0.7032928273867227,
198
+ "mIoU_8": 0.6488940620062726,
199
+ "tree_iou": 0.7280322295553546,
200
+ "per_class_iou": {
201
+ "tree": 0.7280322295553546,
202
+ "ground": 0.8452403621614176,
203
+ "person": 0.6885070862388089,
204
+ "sky": 0.7378996317184476,
205
+ "road": 0.7874709202653031,
206
+ "mountain": 0.4391025863133733,
207
+ "building": 0.6967969754543536,
208
+ "background": 0.26810270434312317
209
+ }
210
+ },
211
+ {
212
+ "epoch": 12,
213
+ "train_loss": 0.14541167222608278,
214
+ "val_loss": 0.2657107286728345,
215
+ "pixel_accuracy": 0.8122323094303036,
216
+ "mIoU_7": 0.6944432923305518,
217
+ "mIoU_8": 0.6353890116604214,
218
+ "tree_iou": 0.7179542327564974,
219
+ "per_class_iou": {
220
+ "tree": 0.7179542327564974,
221
+ "ground": 0.8458071815929991,
222
+ "person": 0.6630505235819328,
223
+ "sky": 0.7576548123175906,
224
+ "road": 0.7707661859292342,
225
+ "mountain": 0.4066247100942614,
226
+ "building": 0.6992454000413479,
227
+ "background": 0.22200904696950824
228
+ }
229
+ },
230
+ {
231
+ "epoch": 13,
232
+ "train_loss": 0.14175742986817863,
233
+ "val_loss": 0.245059463649224,
234
+ "pixel_accuracy": 0.8285006506041387,
235
+ "mIoU_7": 0.709682438526708,
236
+ "mIoU_8": 0.6600046944565648,
237
+ "tree_iou": 0.737251946536209,
238
+ "per_class_iou": {
239
+ "tree": 0.737251946536209,
240
+ "ground": 0.8495676511625408,
241
+ "person": 0.673414491471947,
242
+ "sky": 0.7658854974430468,
243
+ "road": 0.8044619128651328,
244
+ "mountain": 0.4364369682032032,
245
+ "building": 0.7007586020048767,
246
+ "background": 0.31226048596556205
247
+ }
248
+ },
249
+ {
250
+ "epoch": 14,
251
+ "train_loss": 0.13955131033507437,
252
+ "val_loss": 0.24541893630073622,
253
+ "pixel_accuracy": 0.823048581359207,
254
+ "mIoU_7": 0.703872681203359,
255
+ "mIoU_8": 0.6523065270693985,
256
+ "tree_iou": 0.7334317964255183,
257
+ "per_class_iou": {
258
+ "tree": 0.7334317964255183,
259
+ "ground": 0.8472548995668966,
260
+ "person": 0.6849725597168292,
261
+ "sky": 0.7430631079001473,
262
+ "road": 0.7870281350545207,
263
+ "mountain": 0.42822861819540836,
264
+ "building": 0.7031296515641924,
265
+ "background": 0.2913434481316755
266
+ }
267
+ },
268
+ {
269
+ "epoch": 15,
270
+ "train_loss": 0.1340014274150041,
271
+ "val_loss": 0.2613945401822909,
272
+ "pixel_accuracy": 0.8256342077767977,
273
+ "mIoU_7": 0.6938176732869097,
274
+ "mIoU_8": 0.6465352082869225,
275
+ "tree_iou": 0.7348229701448902,
276
+ "per_class_iou": {
277
+ "tree": 0.7348229701448902,
278
+ "ground": 0.8485500320261311,
279
+ "person": 0.6985548375080566,
280
+ "sky": 0.7719071628391762,
281
+ "road": 0.6875423844689514,
282
+ "mountain": 0.4175557643370586,
283
+ "building": 0.697790561684104,
284
+ "background": 0.3155579532870117
285
+ }
286
+ },
287
+ {
288
+ "epoch": 16,
289
+ "train_loss": 0.13028158216666677,
290
+ "val_loss": 0.2640461309407002,
291
+ "pixel_accuracy": 0.8236345844884072,
292
+ "mIoU_7": 0.7017641778623044,
293
+ "mIoU_8": 0.6513057693665132,
294
+ "tree_iou": 0.729911002309811,
295
+ "per_class_iou": {
296
+ "tree": 0.729911002309811,
297
+ "ground": 0.8492910742408395,
298
+ "person": 0.6838423200192018,
299
+ "sky": 0.7641037715450866,
300
+ "road": 0.777327883240598,
301
+ "mountain": 0.41684152019417936,
302
+ "building": 0.6910316734864147,
303
+ "background": 0.29809690989597504
304
+ }
305
+ },
306
+ {
307
+ "epoch": 17,
308
+ "train_loss": 0.1271395583290152,
309
+ "val_loss": 0.27516031752412135,
310
+ "pixel_accuracy": 0.8244798229586694,
311
+ "mIoU_7": 0.7030521402715308,
312
+ "mIoU_8": 0.6534132836443587,
313
+ "tree_iou": 0.7358827999718607,
314
+ "per_class_iou": {
315
+ "tree": 0.7358827999718607,
316
+ "ground": 0.8468523343784162,
317
+ "person": 0.6691681595051834,
318
+ "sky": 0.7460565593685016,
319
+ "road": 0.7955708148842865,
320
+ "mountain": 0.42532427764901415,
321
+ "building": 0.7025100361434535,
322
+ "background": 0.3059412872541537
323
+ }
324
+ },
325
+ {
326
+ "epoch": 18,
327
+ "train_loss": 0.12417552829900011,
328
+ "val_loss": 0.26260221988344806,
329
+ "pixel_accuracy": 0.8236053794942876,
330
+ "mIoU_7": 0.7054233496573896,
331
+ "mIoU_8": 0.6541734279184714,
332
+ "tree_iou": 0.7336979074099337,
333
+ "per_class_iou": {
334
+ "tree": 0.7336979074099337,
335
+ "ground": 0.8459052222065085,
336
+ "person": 0.7072702433260717,
337
+ "sky": 0.7598311921866788,
338
+ "road": 0.7741989806072286,
339
+ "mountain": 0.4169527905509687,
340
+ "building": 0.7001071113143372,
341
+ "background": 0.29542397574604473
342
+ }
343
+ },
344
+ {
345
+ "epoch": 19,
346
+ "train_loss": 0.12182419967637549,
347
+ "val_loss": 0.27713373312965417,
348
+ "pixel_accuracy": 0.8270421947629648,
349
+ "mIoU_7": 0.7109813189537684,
350
+ "mIoU_8": 0.6606370811375022,
351
+ "tree_iou": 0.7323743055163786,
352
+ "per_class_iou": {
353
+ "tree": 0.7323743055163786,
354
+ "ground": 0.8498375951196969,
355
+ "person": 0.7007254650111793,
356
+ "sky": 0.7758277755445154,
357
+ "road": 0.7673058681573506,
358
+ "mountain": 0.44176907525611614,
359
+ "building": 0.7090291480711419,
360
+ "background": 0.3082274164236383
361
+ }
362
+ },
363
+ {
364
+ "epoch": 20,
365
+ "train_loss": 0.11663105687877909,
366
+ "val_loss": 0.2742788792611697,
367
+ "pixel_accuracy": 0.8236675740997423,
368
+ "mIoU_7": 0.705239040664458,
369
+ "mIoU_8": 0.6555626455120473,
370
+ "tree_iou": 0.7277708011216564,
371
+ "per_class_iou": {
372
+ "tree": 0.7277708011216564,
373
+ "ground": 0.8480108982494552,
374
+ "person": 0.6999480566951554,
375
+ "sky": 0.7583926266348064,
376
+ "road": 0.7733277370427203,
377
+ "mountain": 0.43148052924849667,
378
+ "building": 0.6977426356589147,
379
+ "background": 0.3078278794451741
380
+ }
381
+ },
382
+ {
383
+ "epoch": 21,
384
+ "train_loss": 0.11508058199955516,
385
+ "val_loss": 0.27676021279050755,
386
+ "pixel_accuracy": 0.8313707139756944,
387
+ "mIoU_7": 0.7077708636483279,
388
+ "mIoU_8": 0.6614555034088554,
389
+ "tree_iou": 0.7424294121477844,
390
+ "per_class_iou": {
391
+ "tree": 0.7424294121477844,
392
+ "ground": 0.8522079523217,
393
+ "person": 0.7032718147223185,
394
+ "sky": 0.7628166116510707,
395
+ "road": 0.7671078882037391,
396
+ "mountain": 0.4261684625639688,
397
+ "building": 0.700393903927713,
398
+ "background": 0.3372479817325494
399
+ }
400
+ },
401
+ {
402
+ "epoch": 22,
403
+ "train_loss": 0.11490547226906007,
404
+ "val_loss": 0.2687373280716248,
405
+ "pixel_accuracy": 0.8267316264490927,
406
+ "mIoU_7": 0.7110329608311641,
407
+ "mIoU_8": 0.6609831987414228,
408
+ "tree_iou": 0.7378164635727757,
409
+ "per_class_iou": {
410
+ "tree": 0.7378164635727757,
411
+ "ground": 0.8448131953009181,
412
+ "person": 0.706298048708027,
413
+ "sky": 0.7649988259226514,
414
+ "road": 0.768533202732033,
415
+ "mountain": 0.44301062115660517,
416
+ "building": 0.7117603684251378,
417
+ "background": 0.3106348641132343
418
+ }
419
+ },
420
+ {
421
+ "epoch": 23,
422
+ "train_loss": 0.1098061478228943,
423
+ "val_loss": 0.2768534162105658,
424
+ "pixel_accuracy": 0.8206308016213038,
425
+ "mIoU_7": 0.7051513810529587,
426
+ "mIoU_8": 0.6541264498668504,
427
+ "tree_iou": 0.7231827678494563,
428
+ "per_class_iou": {
429
+ "tree": 0.7231827678494563,
430
+ "ground": 0.8449066487884843,
431
+ "person": 0.7041216903044678,
432
+ "sky": 0.7592148050549444,
433
+ "road": 0.7857372699186778,
434
+ "mountain": 0.41639223489260957,
435
+ "building": 0.7025042505620707,
436
+ "background": 0.2969519315640925
437
+ }
438
+ },
439
+ {
440
+ "epoch": 24,
441
+ "train_loss": 0.10950512597927539,
442
+ "val_loss": 0.2952564288026247,
443
+ "pixel_accuracy": 0.8288865721781195,
444
+ "mIoU_7": 0.7106030798944316,
445
+ "mIoU_8": 0.665005005296281,
446
+ "tree_iou": 0.730085894441625,
447
+ "per_class_iou": {
448
+ "tree": 0.730085894441625,
449
+ "ground": 0.8504673603873821,
450
+ "person": 0.6969281633100608,
451
+ "sky": 0.7696089528996574,
452
+ "road": 0.7867176272154283,
453
+ "mountain": 0.4279298391265215,
454
+ "building": 0.7124837218803463,
455
+ "background": 0.3458184831092274
456
+ }
457
+ },
458
+ {
459
+ "epoch": 25,
460
+ "train_loss": 0.10687056659552423,
461
+ "val_loss": 0.27761973994664657,
462
+ "pixel_accuracy": 0.8238253138825885,
463
+ "mIoU_7": 0.7077265497389625,
464
+ "mIoU_8": 0.6572216095334821,
465
+ "tree_iou": 0.7262833333589551,
466
+ "per_class_iou": {
467
+ "tree": 0.7262833333589551,
468
+ "ground": 0.8512492951444962,
469
+ "person": 0.7094425059787509,
470
+ "sky": 0.7588241111321905,
471
+ "road": 0.7839445591226314,
472
+ "mountain": 0.41280223596732046,
473
+ "building": 0.7115398074683938,
474
+ "background": 0.30368702809511827
475
+ }
476
+ },
477
+ {
478
+ "epoch": 26,
479
+ "train_loss": 0.10538023605171182,
480
+ "val_loss": 0.29933813543846977,
481
+ "pixel_accuracy": 0.823246825296819,
482
+ "mIoU_7": 0.704192355270756,
483
+ "mIoU_8": 0.6550478112982461,
484
+ "tree_iou": 0.7267114852570552,
485
+ "per_class_iou": {
486
+ "tree": 0.7267114852570552,
487
+ "ground": 0.8469380423129873,
488
+ "person": 0.7013946479260358,
489
+ "sky": 0.7600105651077275,
490
+ "road": 0.7588123322822986,
491
+ "mountain": 0.43047006821479455,
492
+ "building": 0.7050093457943926,
493
+ "background": 0.311036003490677
494
+ }
495
+ },
496
+ {
497
+ "epoch": 27,
498
+ "train_loss": 0.10386164612072006,
499
+ "val_loss": 0.3018250732849806,
500
+ "pixel_accuracy": 0.8294269083221326,
501
+ "mIoU_7": 0.7132434231102097,
502
+ "mIoU_8": 0.6670003332357785,
503
+ "tree_iou": 0.7363907878200537,
504
+ "per_class_iou": {
505
+ "tree": 0.7363907878200537,
506
+ "ground": 0.8504632417020571,
507
+ "person": 0.7128458671987324,
508
+ "sky": 0.7624585962209668,
509
+ "road": 0.7741146775923515,
510
+ "mountain": 0.4371976195380884,
511
+ "building": 0.7192331716992189,
512
+ "background": 0.3432987041147584
513
+ }
514
+ },
515
+ {
516
+ "epoch": 28,
517
+ "train_loss": 0.10010939652540676,
518
+ "val_loss": 0.3033689678861545,
519
+ "pixel_accuracy": 0.8270183713632673,
520
+ "mIoU_7": 0.7110530306755144,
521
+ "mIoU_8": 0.6650936852452135,
522
+ "tree_iou": 0.7283072811263064,
523
+ "per_class_iou": {
524
+ "tree": 0.7283072811263064,
525
+ "ground": 0.8456942664597132,
526
+ "person": 0.6923932394957782,
527
+ "sky": 0.7610996834802041,
528
+ "road": 0.7837289807636174,
529
+ "mountain": 0.4502767990300785,
530
+ "building": 0.7158709643729031,
531
+ "background": 0.34337826723310705
532
+ }
533
+ },
534
+ {
535
+ "epoch": 29,
536
+ "train_loss": 0.09945478258795541,
537
+ "val_loss": 0.3067954242802583,
538
+ "pixel_accuracy": 0.8338931613498264,
539
+ "mIoU_7": 0.7143796374000682,
540
+ "mIoU_8": 0.6711484387659205,
541
+ "tree_iou": 0.7407570830601974,
542
+ "per_class_iou": {
543
+ "tree": 0.7407570830601974,
544
+ "ground": 0.851453401283257,
545
+ "person": 0.7188091768041799,
546
+ "sky": 0.7689343025621244,
547
+ "road": 0.7752271933914192,
548
+ "mountain": 0.43486270051421433,
549
+ "building": 0.710613604185085,
550
+ "background": 0.3685300483268865
551
+ }
552
+ },
553
+ {
554
+ "epoch": 30,
555
+ "train_loss": 0.09672546130497084,
556
+ "val_loss": 0.32073069191896,
557
+ "pixel_accuracy": 0.8235090577046931,
558
+ "mIoU_7": 0.7008779030646494,
559
+ "mIoU_8": 0.655579892357319,
560
+ "tree_iou": 0.7274351359731892,
561
+ "per_class_iou": {
562
+ "tree": 0.7274351359731892,
563
+ "ground": 0.844361516651275,
564
+ "person": 0.6928363707324503,
565
+ "sky": 0.7461845320594144,
566
+ "road": 0.7788302132855058,
567
+ "mountain": 0.40112782563995303,
568
+ "building": 0.7153697271107579,
569
+ "background": 0.3384938174060061
570
+ }
571
+ }
572
+ ]
predict.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Segformer85Mv1 β€” apple-orchard semantic segmentation inference.
2
+
3
+ Usage:
4
+ python predict.py input.jpg # writes input_pred.png + input_overlay.jpg
5
+ python predict.py --dir frames/ --out out/ # batch process a folder
6
+
7
+ Classes (id β†’ name):
8
+ 0 tree 1 ground 2 person 3 sky
9
+ 4 road 5 mountain 6 building 7 background
10
+ """
11
+ from __future__ import annotations
12
+ import argparse
13
+ import os
14
+ from pathlib import Path
15
+
16
+ import cv2
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from transformers import SegformerForSemanticSegmentation
21
+
22
+ # ─── config ───
23
+ BASE_MODEL = "nvidia/segformer-b5-finetuned-ade-640-640"
24
+ WEIGHTS_PATH = os.environ.get("SEGFORMER85MV1_WEIGHTS", "Segformer85Mv1.pt") # local file or full path
25
+ NAMES = ["tree", "ground", "person", "sky", "road", "mountain", "building", "background"]
26
+ PALETTE = np.array([
27
+ [60, 220, 60], # tree - green
28
+ [40, 100, 160], # ground - brown
29
+ [40, 40, 230], # person - red
30
+ [230, 200, 60], # sky - cyan
31
+ [140, 140, 140], # road - gray
32
+ [180, 60, 180], # mountain - purple
33
+ [50, 220, 220], # building - yellow
34
+ [100, 100, 100], # background - mid-gray
35
+ ], dtype=np.uint8)
36
+ IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
37
+ IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
38
+
39
+
40
+ def load_model(weights_path: str | Path = WEIGHTS_PATH, device: str = "cuda"):
41
+ """Load Segformer85Mv1. Returns model in eval mode on the target device."""
42
+ model = SegformerForSemanticSegmentation.from_pretrained(
43
+ BASE_MODEL,
44
+ num_labels=len(NAMES),
45
+ id2label={i: n for i, n in enumerate(NAMES)},
46
+ label2id={n: i for i, n in enumerate(NAMES)},
47
+ ignore_mismatched_sizes=True,
48
+ ).to(device)
49
+ ckpt = torch.load(weights_path, map_location=device, weights_only=False)
50
+ state = ckpt["model"] if "model" in ckpt else ckpt
51
+ model.load_state_dict(state)
52
+ model.eval()
53
+ return model
54
+
55
+
56
+ def preprocess(bgr_img: np.ndarray) -> tuple[torch.Tensor, tuple[int, int]]:
57
+ """BGR uint8 image β†’ normalized tensor sized to 32 multiples; returns (tensor, original (H,W))."""
58
+ H, W = bgr_img.shape[:2]
59
+ H32, W32 = (H // 32) * 32, (W // 32) * 32
60
+ if H32 == 0 or W32 == 0:
61
+ raise ValueError(f"Image too small: {W}x{H}")
62
+ rgb = cv2.cvtColor(cv2.resize(bgr_img, (W32, H32)), cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
63
+ rgb = (rgb - IMAGENET_MEAN) / IMAGENET_STD
64
+ x = torch.from_numpy(rgb.transpose(2, 0, 1)).unsqueeze(0).float()
65
+ return x, (H, W)
66
+
67
+
68
+ def predict(model, bgr_img: np.ndarray, device: str = "cuda") -> np.ndarray:
69
+ """Run inference on one BGR image. Returns (H,W) uint8 mask with class ids 0..7."""
70
+ x, (H, W) = preprocess(bgr_img)
71
+ x = x.to(device)
72
+ with torch.no_grad():
73
+ logits = model(pixel_values=x).logits
74
+ logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
75
+ return logits.argmax(1)[0].cpu().numpy().astype(np.uint8)
76
+
77
+
78
+ def colorize(mask: np.ndarray) -> np.ndarray:
79
+ """class-id mask (H,W) β†’ BGR color visualization (H,W,3)."""
80
+ return PALETTE[mask]
81
+
82
+
83
+ def overlay(bgr_img: np.ndarray, mask: np.ndarray, alpha: float = 0.45) -> np.ndarray:
84
+ """Blend prediction over original image."""
85
+ return cv2.addWeighted(bgr_img, 1 - alpha, colorize(mask), alpha, 0)
86
+
87
+
88
+ def main():
89
+ ap = argparse.ArgumentParser(description="Segformer85Mv1 inference (8-class outdoor segmentation).")
90
+ ap.add_argument("input", nargs="?", help="Single image path")
91
+ ap.add_argument("--dir", help="Directory of images to process")
92
+ ap.add_argument("--out", default=".", help="Output directory")
93
+ ap.add_argument("--weights", default=WEIGHTS_PATH, help="Path to Segformer85Mv1.pt")
94
+ ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
95
+ args = ap.parse_args()
96
+
97
+ if not args.input and not args.dir:
98
+ ap.print_help()
99
+ return
100
+
101
+ print(f"loading model from {args.weights} on {args.device} ...")
102
+ model = load_model(args.weights, device=args.device)
103
+ out_dir = Path(args.out); out_dir.mkdir(parents=True, exist_ok=True)
104
+
105
+ paths = []
106
+ if args.dir:
107
+ paths = sorted(p for p in Path(args.dir).iterdir() if p.suffix.lower() in {".jpg", ".jpeg", ".png", ".bmp"})
108
+ if args.input:
109
+ paths.append(Path(args.input))
110
+
111
+ for p in paths:
112
+ img = cv2.imread(str(p))
113
+ if img is None:
114
+ print(f" skip (unreadable): {p}")
115
+ continue
116
+ mask = predict(model, img, device=args.device)
117
+ cv2.imwrite(str(out_dir / f"{p.stem}_pred.png"), mask) # raw class-id mask
118
+ cv2.imwrite(str(out_dir / f"{p.stem}_overlay.jpg"), overlay(img, mask)) # visualization
119
+ # quick stats
120
+ counts = np.bincount(mask.flatten(), minlength=len(NAMES))
121
+ top = counts.argmax()
122
+ print(f" {p.name:<40} top class: {NAMES[top]} ({100*counts[top]/counts.sum():.1f}%)")
123
+
124
+ print(f"\noutputs -> {out_dir.resolve()}")
125
+
126
+
127
+ if __name__ == "__main__":
128
+ main()
samples/sample_00_frame_2575.jpg ADDED

Git LFS Details

  • SHA256: 6b5b4fa73c3be9450d1cb2ab2dc658a87833308e4adc172ba087a012177d5196
  • Pointer size: 131 Bytes
  • Size of remote file: 576 kB
samples/sample_05_frame_3371.jpg ADDED

Git LFS Details

  • SHA256: 4dc36e08d9b76a7f2d76a1b28fd8f6a880a00fd5a9b31f37114b4b682d1b77a9
  • Pointer size: 131 Bytes
  • Size of remote file: 656 kB
samples/sample_09_frame_4009.jpg ADDED

Git LFS Details

  • SHA256: b9b1d27f602db898962a224a955209d03c949296b02822491df8c5e6a205d135
  • Pointer size: 131 Bytes
  • Size of remote file: 601 kB
train_v6_5090.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """V6 β€” final, all-problems-fixed training on RTX 5090.
2
+
3
+ Fixes vs v4 (the leaky 0.78 mIoU):
4
+ 1. TEMPORAL split (frame_<=4500 train, frame_>4500 val) β€” zero neighbor leakage
5
+ 2. Native 1280x704 input (16:9, no padding, no resizing artifacts)
6
+ 3. Segformer-b5 (85M params, 4x v4's b2 capacity)
7
+ 4. batch 4 + BF16 (saturates 5090's 32GB VRAM)
8
+ 5. Global confusion-matrix IoU (not per-batch noisy averages)
9
+ 6. Pseudo-labels (carry over - they were generated by v4 on full images)
10
+ """
11
+ from __future__ import annotations
12
+ import json, re, time
13
+ from pathlib import Path
14
+ import numpy as np, cv2, torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from torch.utils.data import Dataset, DataLoader
18
+ from torch.amp import GradScaler, autocast
19
+ import albumentations as A
20
+ from transformers import SegformerForSemanticSegmentation
21
+
22
+ # ───────────── config ─────────────
23
+ ROOT = Path("/workspace/agmotree/dataset")
24
+ IMG_DIR = ROOT / "train/images"
25
+ MSK_DIR = ROOT / "train/masks_pseudo"
26
+ OUT_DIR = Path("/workspace/agmotree/v6_output")
27
+ OUT_DIR.mkdir(parents=True, exist_ok=True)
28
+
29
+ MODEL_NAME = "nvidia/segformer-b5-finetuned-ade-640-640"
30
+ NUM_CLASSES = 8
31
+ NAMES = ["tree", "ground", "person", "sky", "road", "mountain", "building", "background"]
32
+
33
+ IMG_W = 1024
34
+ IMG_H = 576 # 32 multiple closest to native 720
35
+ BATCH_SIZE = 2
36
+ GRAD_ACCUM = 4
37
+ EPOCHS = 30
38
+ LR = 2e-5
39
+ WEIGHT_DECAY = 1e-2
40
+ NUM_WORKERS = 8
41
+ SEED = 42
42
+ DEVICE = "cuda"
43
+ SPLIT_FRAME = 4500 # frames<=4500 β†’ train, >4500 β†’ val (NO LEAK)
44
+
45
+ # Hand-tuned class weights (proven in v4 - prevents collapse)
46
+ WEIGHTS = np.array([
47
+ 1.5, # tree - priority class
48
+ 0.5, # ground - very common
49
+ 1.5, # person
50
+ 1.0, # sky
51
+ 1.0, # road
52
+ 1.0, # mountain
53
+ 1.0, # building
54
+ 0.1, # background - low but trainable
55
+ ])
56
+
57
+ print(f"=== V6 / RTX 5090 / NO LEAK ===")
58
+ print(f" model: {MODEL_NAME}")
59
+ print(f" input: {IMG_W}x{IMG_H} (native 16:9)")
60
+ print(f" batch: {BATCH_SIZE} x grad_accum {GRAD_ACCUM} = effective {BATCH_SIZE*GRAD_ACCUM}")
61
+ print(f" LR: {LR}, epochs: {EPOCHS}")
62
+ print(f" TEMPORAL split: train frame<={SPLIT_FRAME}, val frame>{SPLIT_FRAME}")
63
+
64
+
65
+ # ───────────── data ─────────────
66
+ def frame_num(p: Path) -> int:
67
+ m = re.match(r"frame_(\d+)", p.stem)
68
+ return int(m.group(1)) if m else -1
69
+
70
+ all_imgs = sorted(IMG_DIR.glob("*.jpg"))
71
+ train_imgs = [p for p in all_imgs if frame_num(p) <= SPLIT_FRAME]
72
+ val_imgs = [p for p in all_imgs if frame_num(p) > SPLIT_FRAME]
73
+ train_nums = set(frame_num(p) for p in train_imgs)
74
+ val_nums = set(frame_num(p) for p in val_imgs)
75
+ print(f" train: {len(train_imgs)} files, frames {min(train_nums)}-{max(train_nums)}")
76
+ print(f" val: {len(val_imgs)} files, frames {min(val_nums)}-{max(val_nums)}")
77
+ print(f" overlap (must be 0): {len(train_nums & val_nums)}")
78
+ assert len(train_nums & val_nums) == 0
79
+
80
+ train_tf = A.Compose([
81
+ A.Resize(IMG_H, IMG_W),
82
+ A.HorizontalFlip(p=0.5),
83
+ A.RandomBrightnessContrast(0.2, 0.2, p=0.5),
84
+ A.HueSaturationValue(10, 15, 10, p=0.3),
85
+ A.GaussianBlur(blur_limit=(3, 5), p=0.2),
86
+ A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
87
+ ])
88
+ val_tf = A.Compose([
89
+ A.Resize(IMG_H, IMG_W),
90
+ A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
91
+ ])
92
+
93
+ class SegDataset(Dataset):
94
+ def __init__(self, paths, tf):
95
+ self.paths = paths; self.tf = tf
96
+ def __len__(self): return len(self.paths)
97
+ def __getitem__(self, i):
98
+ ip = self.paths[i]
99
+ img = cv2.imread(str(ip)); img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
100
+ msk = cv2.imread(str(MSK_DIR / (ip.stem + ".png")), cv2.IMREAD_GRAYSCALE)
101
+ out = self.tf(image=img, mask=msk)
102
+ return (torch.from_numpy(out["image"]).permute(2,0,1).float(),
103
+ torch.from_numpy(out["mask"]).long())
104
+
105
+
106
+ # ───────────── train ─────────────
107
+ log_path = OUT_DIR / "training_log_v6.txt"
108
+
109
+ def log(msg):
110
+ print(msg, flush=True)
111
+ with log_path.open("a", encoding="utf-8") as f:
112
+ f.write(msg + "\n")
113
+
114
+ def compute_iou_global(cm):
115
+ n = cm.shape[0]; ious = np.zeros(n)
116
+ for c in range(n):
117
+ tp = cm[c,c]; fp = cm[:,c].sum()-tp; fn = cm[c,:].sum()-tp
118
+ ious[c] = tp/(tp+fp+fn) if (tp+fp+fn) > 0 else float("nan")
119
+ return ious
120
+
121
+ def main():
122
+ log_path.write_text("")
123
+ train_ds = SegDataset(train_imgs, train_tf)
124
+ val_ds = SegDataset(val_imgs, val_tf)
125
+ train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
126
+ num_workers=NUM_WORKERS, pin_memory=True, drop_last=True,
127
+ persistent_workers=True)
128
+ val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
129
+ num_workers=NUM_WORKERS, pin_memory=True,
130
+ persistent_workers=True)
131
+
132
+ log(f"=== V6 / RTX 5090 / NO LEAK ===")
133
+ log(f"input: {IMG_W}x{IMG_H} batch: {BATCH_SIZE}x{GRAD_ACCUM} LR: {LR}")
134
+ log(f"split: TEMPORAL train={len(train_imgs)} val={len(val_imgs)} no overlap")
135
+ log(f"loading {MODEL_NAME} ...")
136
+ model = SegformerForSemanticSegmentation.from_pretrained(
137
+ MODEL_NAME, num_labels=NUM_CLASSES,
138
+ id2label={i:n for i,n in enumerate(NAMES)},
139
+ label2id={n:i for i,n in enumerate(NAMES)},
140
+ ignore_mismatched_sizes=True,
141
+ ).to(DEVICE)
142
+ log(f" params: {sum(p.numel() for p in model.parameters())/1e6:.1f}M")
143
+
144
+ cw = torch.tensor(WEIGHTS, dtype=torch.float32, device=DEVICE)
145
+ loss_fn = nn.CrossEntropyLoss(weight=cw)
146
+ optim = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
147
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=EPOCHS*len(train_loader))
148
+ # BF16 doesn't need GradScaler, but we keep it for safety/compat
149
+ scaler = GradScaler("cuda")
150
+
151
+ log(f"device: {torch.cuda.get_device_name(0)} vram: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")
152
+ log(f"train batches: {len(train_loader)} val batches: {len(val_loader)}")
153
+
154
+ best_tree_iou = -1.0
155
+ best_miou = -1.0
156
+ history = []
157
+
158
+ for epoch in range(1, EPOCHS+1):
159
+ model.train()
160
+ t0 = time.time()
161
+ epoch_loss = 0.0
162
+ optim.zero_grad()
163
+ for step,(x,y) in enumerate(train_loader):
164
+ x = x.to(DEVICE, non_blocking=True); y = y.to(DEVICE, non_blocking=True)
165
+ with autocast("cuda", dtype=torch.bfloat16):
166
+ out = model(pixel_values=x)
167
+ logits = F.interpolate(out.logits, size=y.shape[-2:], mode="bilinear", align_corners=False)
168
+ loss = loss_fn(logits, y) / GRAD_ACCUM
169
+ loss.backward()
170
+ if (step+1) % GRAD_ACCUM == 0:
171
+ optim.step(); optim.zero_grad(); sched.step()
172
+ epoch_loss += loss.item() * GRAD_ACCUM
173
+ train_loss = epoch_loss / len(train_loader)
174
+
175
+ model.eval()
176
+ cm = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.int64)
177
+ val_loss = 0.0
178
+ with torch.no_grad():
179
+ for x,y in val_loader:
180
+ x = x.to(DEVICE, non_blocking=True); y = y.to(DEVICE, non_blocking=True)
181
+ with autocast("cuda", dtype=torch.bfloat16):
182
+ out = model(pixel_values=x)
183
+ logits = F.interpolate(out.logits, size=y.shape[-2:], mode="bilinear", align_corners=False)
184
+ val_loss += loss_fn(logits, y).item()
185
+ preds = logits.argmax(1).cpu().numpy()
186
+ ys = y.cpu().numpy()
187
+ for tc in range(NUM_CLASSES):
188
+ mt = (ys == tc)
189
+ if not mt.any(): continue
190
+ for pc in range(NUM_CLASSES):
191
+ cm[tc, pc] += int(((preds == pc) & mt).sum())
192
+ val_loss /= max(1, len(val_loader))
193
+ per_iou = compute_iou_global(cm)
194
+ miou_7 = float(np.nanmean(per_iou[:7]))
195
+ miou_8 = float(np.nanmean(per_iou))
196
+ tree_iou = float(per_iou[0])
197
+ pix_acc = float(np.diag(cm).sum() / cm.sum())
198
+
199
+ elapsed = time.time() - t0
200
+ log(f"epoch {epoch:02d}/{EPOCHS} tloss={train_loss:.4f} vloss={val_loss:.4f} "
201
+ f"pix_acc={pix_acc:.3f} mIoU(7)={miou_7:.3f} tree={tree_iou:.3f} ({elapsed:.0f}s)")
202
+ log(" per-class IoU: " + ", ".join(f"{n}={v:.3f}" for n,v in zip(NAMES, per_iou)))
203
+
204
+ history.append({
205
+ "epoch": epoch, "train_loss": float(train_loss), "val_loss": float(val_loss),
206
+ "pixel_accuracy": pix_acc, "mIoU_7": miou_7, "mIoU_8": miou_8, "tree_iou": tree_iou,
207
+ "per_class_iou": {n: float(v) for n, v in zip(NAMES, per_iou)},
208
+ })
209
+
210
+ torch.save({"model": model.state_dict(), "epoch": epoch, "miou_7": miou_7, "tree_iou": tree_iou},
211
+ OUT_DIR / "v6_last.pt")
212
+ if tree_iou > best_tree_iou:
213
+ best_tree_iou = tree_iou
214
+ torch.save({"model": model.state_dict(), "epoch": epoch, "miou_7": miou_7, "tree_iou": tree_iou},
215
+ OUT_DIR / "v6_best_tree.pt")
216
+ log(f" saved v6_best_tree.pt (tree IoU {tree_iou:.3f})")
217
+ if miou_7 > best_miou:
218
+ best_miou = miou_7
219
+ torch.save({"model": model.state_dict(), "epoch": epoch, "miou_7": miou_7, "tree_iou": tree_iou},
220
+ OUT_DIR / "v6_best_miou.pt")
221
+
222
+ (OUT_DIR / "history_v6.json").write_text(json.dumps(history, indent=2))
223
+
224
+ log(f"\n=== DONE ===")
225
+ log(f"best tree IoU (NO LEAK): {best_tree_iou:.3f}")
226
+ log(f"best mIoU(7) (NO LEAK): {best_miou:.3f}")
227
+
228
+
229
+ if __name__ == "__main__":
230
+ main()
v6_OOD_full_res.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec19068d5aaddf3d1afcf18b7f35a355bbe3ef3bfab59a18080253210373442f
3
+ size 246499502