YoshimuraHiroto committed
Commit 33bee55 (verified)
1 Parent(s): 4e3cc0e

Upload finetune_large_geolifeclef.py with huggingface_hub

Files changed (1)
  1. finetune_large_geolifeclef.py +264 -0
finetune_large_geolifeclef.py ADDED
@@ -0,0 +1,264 @@
"""
Finetune BFM Large model on GeoLifeCLEF 500 species data.

Based on the paper's pipeline (arXiv:2507.09080v2):
1. Load BFM Large pretrained model from safetensors
2. Wrap with BFMRaw (replace encoder/decoder with species-specific ones)
3. Train with L1 loss + AdamW + CosineAnnealing on GeoLifeCLEF data
4. Save finetuned checkpoint

Usage:
    conda run -n bfm python finetune_large_geolifeclef.py 2>&1 | tee finetune_large.log
"""

import math
import os
import sys
import time
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Add bfm-model to path
PROJECT_ROOT = Path(__file__).resolve().parent
BFM_MODEL_DIR = PROJECT_ROOT / "bfm-model"
sys.path.insert(0, str(BFM_MODEL_DIR))

SAFETENSORS_PATH = PROJECT_ROOT / "bfm-pretrained" / "bfm-pretrain-large.safetensors"
OUTPUT_DIR = PROJECT_ROOT / "outputs_finetune_large"
CHECKPOINT_DIR = OUTPUT_DIR / "checkpoints"

# ─── Training hyperparameters (from finetune_config.yaml) ───
NUM_SPECIES = 500
BATCH_SIZE = 1
NUM_EPOCHS = 100
LEARNING_RATE = 3e-4
VAL_EVERY = 5
NUM_WORKERS = 8
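# BATCH_SIZE = 1 matches the fixed batch_size=1 the BFM backbone is built
# with in build_base_model() below.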

# ─── Model config (Large) ───
MODEL_CONFIG = {
    "embed_dim": 512, "depth": 10, "patch_size": 8,
    "swin_backbone_size": "large", "num_heads": 16, "head_dim": 64,
    "H": 160, "W": 280, "num_latent_tokens": 8,
    "perceiver_latents": 16100, "T": 2,
}

SWIN_LARGE_CONFIG = {
    "swin_encoder_depths": (2, 2, 2), "swin_encoder_num_heads": (8, 16, 32),
    "swin_decoder_depths": (2, 2, 2), "swin_decoder_num_heads": (32, 16, 8),
    "swin_window_size": (1, 4, 5), "swin_mlp_ratio": 4.0,
    "swin_qkv_bias": True, "swin_drop_rate": 0.0,
    "swin_attn_drop_rate": 0.0, "swin_drop_path_rate": 0.1,
    "use_lora": False,
}
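# Layout note (assumption, standard ViT/Swin patching): with H = 160, W = 280
# and patch_size = 8, the backbone sees a 20 x 35 token grid, which the
# (1, 4, 5) swin_window_size tiles evenly (20 / 4 = 5, 35 / 5 = 7 windows).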

SPECIES_VARS = [
    "1340361", "1340503", "1536449", "1898286", "1920506", "2430567",
    "2431885", "2433433", "2434779", "2435240", "2435261", "2437394",
    "2441454", "2473958", "2491534", "2891770", "3034825", "4408498",
    "5218786", "5219073", "5219173", "5219219", "5844449", "8002952",
    "8077224", "8894817", "8909809", "9809229",
]
SURFACE_VARS = ["t2m", "msl", "slt", "z", "u10", "v10", "lsm"]
EDAPHIC_VARS = ["swvl1", "swvl2", "stl1", "stl2"]
ATMOS_VARS = ["z", "t", "u", "v", "q"]
CLIMATE_VARS = [
    "smlt", "tp", "csfr", "avg_sdswrf", "avg_snswrf", "avg_snlwrf",
    "avg_tprate", "avg_sdswrfcs", "sd", "t2m", "d2m",
]
VEGETATION_VARS = ["NDVI"]
LAND_VARS = ["Land"]
AGRICULTURE_VARS = ["Agriculture", "Arable", "Cropland"]
FOREST_VARS = ["Forest"]
REDLIST_VARS = ["RLI"]
MISC_VARS = ["avg_slhtf", "avg_pevr"]
ATMOS_LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]
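# SPECIES_VARS lists the 28 species IDs the base model was pretrained with
# (species_num = len(SPECIES_VARS) in build_base_model()); per the pipeline
# described in the module docstring, BFMRaw later swaps in heads for the
# NUM_SPECIES = 500 GeoLifeCLEF targets.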


def build_base_model():
    """Build BFM Large model."""
    from bfm_model.bfm.model import BFM
    model = BFM(
        surface_vars=SURFACE_VARS, edaphic_vars=EDAPHIC_VARS,
        atmos_vars=ATMOS_VARS, climate_vars=CLIMATE_VARS,
        species_vars=SPECIES_VARS, vegetation_vars=VEGETATION_VARS,
        land_vars=LAND_VARS, agriculture_vars=AGRICULTURE_VARS,
        forest_vars=FOREST_VARS, redlist_vars=REDLIST_VARS,
        misc_vars=MISC_VARS, atmos_levels=ATMOS_LEVELS,
        species_num=len(SPECIES_VARS),
        H=MODEL_CONFIG["H"], W=MODEL_CONFIG["W"],
        num_latent_tokens=MODEL_CONFIG["num_latent_tokens"],
        backbone_type="swin", patch_size=MODEL_CONFIG["patch_size"],
        embed_dim=MODEL_CONFIG["embed_dim"],
        num_heads=MODEL_CONFIG["num_heads"],
        head_dim=MODEL_CONFIG["head_dim"],
        depth=MODEL_CONFIG["depth"],
        perceiver_latents=MODEL_CONFIG["perceiver_latents"],
        batch_size=1, td_learning=True, use_mask="no",
        **SWIN_LARGE_CONFIG,
    )
    return model


def load_pretrained_weights(model):
    """Load pretrained safetensors weights into base model."""
    from safetensors.torch import load_file
    print(f"Loading weights from {SAFETENSORS_PATH.name}...")
    state = load_file(str(SAFETENSORS_PATH), device="cpu")
    # strict=False so partial key coverage is tolerated; "_latent_parameter_list"
    # keys are counted separately as aliases rather than genuinely missing weights.
    missing, unexpected = model.load_state_dict(state, strict=False)
    alias_missing = [k for k in missing if "_latent_parameter_list" in k]
    real_missing = [k for k in missing if "_latent_parameter_list" not in k]
    print(f" Total missing: {len(missing)} ({len(alias_missing)} aliases)")
    print(f" Real missing: {len(real_missing)}")
    print(f" Unexpected: {len(unexpected)}")
    return model


def train_epoch(model, dataloader, optimizer, criterion, scheduler, device):
    """One training epoch."""
    model.train()
    epoch_loss = 0.0
    for batch_idx, sample in enumerate(dataloader):
        batch = sample["batch"]
        batch["species_distribution"] = batch["species_distribution"].to(device)
        targets = sample["target"].to(device)
        optimizer.zero_grad()
        outputs = model(batch)
        loss = criterion(outputs, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        epoch_loss += loss.item()
        # Log every batch (with BATCH_SIZE = 1 each batch is a single sample)
        print(f" Batch {batch_idx+1}/{len(dataloader)}, Loss: {loss.item():.6f}")
    return epoch_loss / max(len(dataloader), 1)


def validate_epoch(model, dataloader, criterion, device):
    """One validation epoch."""
    model.eval()
    epoch_loss = 0.0
    with torch.inference_mode():
        for sample in dataloader:
            batch = sample["batch"]
            batch["species_distribution"] = batch["species_distribution"].to(device)
            targets = sample["target"].to(device)
            outputs = model(batch)
            loss = criterion(outputs, targets)
            epoch_loss += loss.item()
    return epoch_loss / max(len(dataloader), 1)


def save_checkpoint(model, optimizer, epoch, loss, path):
    """Save training checkpoint."""
    os.makedirs(path, exist_ok=True)
    filepath = path / "best_checkpoint.pth"
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
    }, filepath)
    print(f" Checkpoint saved: epoch={epoch}, loss={loss:.6f}")


def main():
    print("=" * 70)
    print("BFM Large Model GeoLifeCLEF Finetuning")
    print("=" * 70)

    assert SAFETENSORS_PATH.exists(), f"Weights not found: {SAFETENSORS_PATH}"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    torch.set_float32_matmul_precision("highest")

    # 1. Build base model and load pretrained weights
    print("\nBuilding BFM Large model...")
    base_model = build_base_model()
    total_params = sum(p.numel() for p in base_model.parameters())
    print(f"Base model parameters: {total_params / 1e6:.1f}M")

    base_model = load_pretrained_weights(base_model)

    # 2. Wrap with BFMRaw for species finetuning
    print("\nWrapping with BFMRaw for species finetuning...")
    from bfm_finetune.bfm_mod import BFMRaw
    model = BFMRaw(base_model=base_model, n_species=NUM_SPECIES, mode="train")
    model.to(device)

    # 3. Setup datasets
    print("\nLoading GeoLifeCLEF datasets...")
    from bfm_finetune.dataloaders.geolifeclef_species.dataloader import GeoLifeCLEFSpeciesDataset
    from bfm_finetune.dataloaders.dataloader_utils import custom_collate_fn

    train_dataset = GeoLifeCLEFSpeciesDataset(
        num_species=NUM_SPECIES, mode="train", negative_lon_mode="ignore",
    )
    val_dataset = GeoLifeCLEFSpeciesDataset(
        num_species=NUM_SPECIES, mode="val", negative_lon_mode="ignore",
    )

    train_dataloader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True,
        collate_fn=custom_collate_fn, num_workers=NUM_WORKERS,
    )
    val_dataloader = DataLoader(
        val_dataset, batch_size=1, shuffle=False,
        collate_fn=custom_collate_fn, num_workers=NUM_WORKERS,
    )
    print(f"Train: {len(train_dataset)} samples, Val: {len(val_dataset)} samples")

    # 4. Setup optimizer, scheduler, loss
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=LEARNING_RATE,
        weight_decay=0.0001, betas=(0.9, 0.95), eps=1e-8,
    )
    criterion = nn.L1Loss()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=NUM_EPOCHS * len(train_dataloader),
        eta_min=LEARNING_RATE / 10,
    )
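    # scheduler.step() is called once per batch in train_epoch(), so
    # T_max = NUM_EPOCHS * len(train_dataloader) counts optimizer steps and
    # the LR anneals from 3e-4 down to eta_min = 3e-5.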

    # 5. Training loop
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    best_loss = float("inf")
    start_time = time.time()

    print(f"\nStarting training: {NUM_EPOCHS} epochs, batch_size={BATCH_SIZE}")
    print("-" * 70)

    for epoch in range(1, NUM_EPOCHS + 1):
        t0 = time.time()
        train_loss = train_epoch(model, train_dataloader, optimizer, criterion, scheduler, device)
        epoch_time = time.time() - t0

        log_msg = f"Epoch {epoch}/{NUM_EPOCHS}, Train Loss: {train_loss:.6f}, Time: {epoch_time:.1f}s"

        if epoch % VAL_EVERY == 0:
            val_loss = validate_epoch(model, val_dataloader, criterion, device)
            log_msg += f", Val Loss: {val_loss:.6f}"

            if val_loss < best_loss:
                best_loss = val_loss
                save_checkpoint(model, optimizer, epoch, best_loss, CHECKPOINT_DIR)
        else:
            # Save based on train loss if no validation this epoch
            if train_loss < best_loss:
                best_loss = train_loss
                save_checkpoint(model, optimizer, epoch, best_loss, CHECKPOINT_DIR)

        print(log_msg)

    total_time = time.time() - start_time
    print(f"\nTraining complete! Total time: {total_time/60:.1f} minutes")
    print(f"Best loss: {best_loss:.6f}")
    print(f"Checkpoint saved to: {CHECKPOINT_DIR}")


if __name__ == "__main__":
    main()
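
# Reload sketch (illustrative; assumes the finetuned BFMRaw model and optimizer
# are rebuilt exactly as in main() before restoring state):
#   ckpt = torch.load(CHECKPOINT_DIR / "best_checkpoint.pth", map_location="cpu")
#   model.load_state_dict(ckpt["model_state_dict"])
#   optimizer.load_state_dict(ckpt["optimizer_state_dict"])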