AbstractPhil committed (verified) · Commit f3986cf · Parent(s): 0420e01

Create trainer.py

Files changed (1):
  1. trainer.py +566 -0
trainer.py ADDED (new file, 566 lines):
"""
BERT-Thetis Colab Training Script
----------------------------------
Pretrain BERT-Thetis on WikiText-103 with Masked Language Modeling.

Designed for Google Colab with:
- Easy setup and installation
- HuggingFace Hub integration
- Memory-efficient training
- Progress tracking and logging
- Automatic checkpointing

Author: AbstractPhil + Claude Sonnet 4.5
License: MIT
"""

import os
import math
import time
import json
from pathlib import Path
from typing import Optional, Dict, Any
from dataclasses import dataclass, field

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm.auto import tqdm

# Import BERT-Thetis
from geovocab2.train.model.core.bert_thetis import (
    ThetisConfig,
    ThetisForMaskedLM
)

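# Colab setup (illustrative sketch): the third-party requirements can typically be
# installed with
#   !pip install torch datasets transformers huggingface_hub tqdm
# geovocab2 (which provides bert_thetis) is a project-local package and must also be
# importable in the runtime before this script is run.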

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Configuration
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

@dataclass
class TrainingConfig:
    """Training configuration for Colab."""

    # Model
    model_name: str = "bert-thetis-tiny-wikitext103"
    crystal_dim: int = 256
    num_layers: int = 4
    num_attention_heads: int = 4
    intermediate_size: int = 1024
    vocab_size: int = 30522
    beatrix_levels: int = 16
    max_position_embeddings: int = 512

    # Dataset
    dataset_name: str = "wikitext"
    dataset_config: str = "wikitext-103-raw-v1"
    tokenizer_name: str = "bert-base-uncased"
    max_length: int = 128
    mlm_probability: float = 0.15

    # Training
    num_epochs: int = 10
    batch_size: int = 64
    gradient_accumulation_steps: int = 2
    learning_rate: float = 5e-4
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1
    max_grad_norm: float = 1.0

    # Hardware
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 2
    pin_memory: bool = True
    mixed_precision: bool = True  # Use AMP for faster training

    # Checkpointing
    save_steps: int = 1000
    eval_steps: int = 500
    logging_steps: int = 100
    save_total_limit: int = 3  # Note: old checkpoints are not pruned by this script

    # HuggingFace Hub
    push_to_hub: bool = True
    hub_model_id: str = "AbstractPhil/bert-thetis-tiny-wikitext103"
    hub_token: Optional[str] = None  # Will read from HF_TOKEN env var

    # Paths
    output_dir: str = "./thetis-outputs"
    cache_dir: str = "./cache"

    def __post_init__(self):
        """Set up paths and device."""
        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(self.cache_dir, exist_ok=True)

        # Get HF token from environment if not provided
        if self.hub_token is None:
            self.hub_token = os.environ.get("HF_TOKEN")

        print(f"🚒 BERT-Thetis Training Configuration")
        print(f" Device: {self.device}")
        print(f" Mixed Precision: {self.mixed_precision}")
        print(f" Model: {self.model_name}")
        print(f" Dataset: {self.dataset_name}/{self.dataset_config}")
        print(f" Output: {self.output_dir}")
        print(f" Push to Hub: {self.push_to_hub}")
        if self.push_to_hub:
            print(f" Hub Repo: {self.hub_model_id}")

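# Quick-run override (illustrative; any of the fields above can be passed as keyword
# arguments to the dataclass):
#   config = TrainingConfig(num_epochs=1, batch_size=16, push_to_hub=False)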

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Dataset
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

class MaskedLMDataset(Dataset):
    """Dataset for Masked Language Modeling."""

    def __init__(
        self,
        texts,
        tokenizer,
        max_length: int = 128,
        mlm_probability: float = 0.15
    ):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mlm_probability = mlm_probability

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]

        # Tokenize
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # Create masked version
        labels = input_ids.clone()

        # Select tokens to mask
        probability_matrix = torch.full(labels.shape, self.mlm_probability)

        # Don't mask special tokens (pass the whole list, not individual tokens)
        special_tokens_mask = self.tokenizer.get_special_tokens_mask(
            labels.tolist(), already_has_special_tokens=True
        )
        probability_matrix.masked_fill_(
            torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0
        )

        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # Only compute loss on masked tokens

        # 80% of the time, replace with [MASK]
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        input_ids[indices_replaced] = self.tokenizer.mask_token_id

        # 10% of the time, replace with a random token
        # (a 0.5 draw over the remaining 20% of masked positions yields the 10%/10% split)
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        input_ids[indices_random] = random_words[indices_random]

        # Remaining 10% of the time, keep the original token

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

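# Shape sanity check (illustrative; assumes a BERT-style tokenizer with a [MASK]
# token has already been loaded, e.g. via AutoTokenizer.from_pretrained):
#   ds = MaskedLMDataset(["Thetis was a sea nymph."], tokenizer, max_length=128)
#   item = ds[0]
#   assert item["input_ids"].shape == item["labels"].shape == (128,)
#   # positions where labels == -100 are ignored by the MLM loss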

def prepare_datasets(config: TrainingConfig):
    """Load and prepare WikiText-103 datasets."""
    print(f"\n📚 Loading {config.dataset_name}...")

    # Load dataset
    dataset = load_dataset(
        config.dataset_name,
        config.dataset_config,
        cache_dir=config.cache_dir
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        config.tokenizer_name,
        cache_dir=config.cache_dir
    )

    # Filter out empty texts
    def is_valid(example):
        return len(example["text"].strip()) > 0

    train_texts = [ex["text"] for ex in dataset["train"] if is_valid(ex)]
    val_texts = [ex["text"] for ex in dataset["validation"] if is_valid(ex)]

    print(f" Train samples: {len(train_texts):,}")
    print(f" Val samples: {len(val_texts):,}")

    # Create datasets
    train_dataset = MaskedLMDataset(
        train_texts,
        tokenizer,
        config.max_length,
        config.mlm_probability
    )

    val_dataset = MaskedLMDataset(
        val_texts,
        tokenizer,
        config.max_length,
        config.mlm_probability
    )

    return train_dataset, val_dataset, tokenizer


# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Training Loop
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

class ThetisTrainer:
    """Trainer for BERT-Thetis with MLM."""

    def __init__(
        self,
        model: ThetisForMaskedLM,
        train_dataset: Dataset,
        val_dataset: Dataset,
        config: TrainingConfig
    ):
        self.model = model
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.config = config

        # Move model to device
        self.model.to(config.device)

        # Data loaders
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory
        )

        self.val_loader = DataLoader(
            val_dataset,
            batch_size=config.batch_size * 2,  # Larger batch for eval
            shuffle=False,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory
        )

        # Optimizer: no weight decay on biases and LayerNorm weights
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": config.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]

        self.optimizer = AdamW(optimizer_grouped_parameters, lr=config.learning_rate)

        # Scheduler (OneCycleLR handles warmup via pct_start; warmup_steps is reported for information)
        total_steps = len(self.train_loader) * config.num_epochs // config.gradient_accumulation_steps
        warmup_steps = int(total_steps * config.warmup_ratio)

        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config.learning_rate,
            total_steps=total_steps,
            pct_start=config.warmup_ratio,
            anneal_strategy="cos"
        )

        # Mixed precision
        self.scaler = torch.amp.GradScaler('cuda') if config.mixed_precision and config.device == 'cuda' else None

        # Training state
        self.global_step = 0
        self.epoch = 0
        self.best_val_loss = float("inf")

        print(f"\n🎯 Training Setup")
        print(f" Total steps: {total_steps:,}")
        print(f" Warmup steps: {warmup_steps:,}")
        print(f" Effective batch size: {config.batch_size * config.gradient_accumulation_steps}")

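    # Example arithmetic with the defaults above (illustrative): batch_size 64 with
    # gradient_accumulation_steps 2 gives an effective batch of 128 sequences per
    # optimizer step, and warmup covers the first 10% of steps (warmup_ratio = 0.1).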
    def train_epoch(self):
        """Train for one epoch."""
        self.model.train()
        total_loss = 0

        progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.epoch + 1}")

        for step, batch in enumerate(progress_bar):
            # Move to device
            batch = {k: v.to(self.config.device) for k, v in batch.items()}

            # Forward pass
            with torch.amp.autocast('cuda', enabled=self.config.mixed_precision and self.config.device == 'cuda'):
                loss, _ = self.model(
                    token_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"]
                )
                loss = loss / self.config.gradient_accumulation_steps

            # Backward pass
            if self.scaler is not None:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            total_loss += loss.item()

            # Update weights
            if (step + 1) % self.config.gradient_accumulation_steps == 0:
                if self.scaler is not None:
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                    self.optimizer.step()

                self.scheduler.step()
                self.optimizer.zero_grad()
                self.global_step += 1

                # Update progress bar
                progress_bar.set_postfix({
                    "loss": f"{loss.item() * self.config.gradient_accumulation_steps:.4f}",
                    "lr": f"{self.scheduler.get_last_lr()[0]:.2e}"
                })

                # Logging
                if self.global_step % self.config.logging_steps == 0:
                    avg_loss = total_loss / self.config.logging_steps
                    print(f"\n Step {self.global_step}: loss={avg_loss:.4f}, lr={self.scheduler.get_last_lr()[0]:.2e}")
                    total_loss = 0

                # Evaluation
                if self.global_step % self.config.eval_steps == 0:
                    val_loss = self.evaluate()
                    print(f" Validation loss: {val_loss:.4f}")

                    # Save best model
                    if val_loss < self.best_val_loss:
                        self.best_val_loss = val_loss
                        self.save_checkpoint("best")
                        print(f" ✓ New best model saved!")

                    self.model.train()

                # Save checkpoint
                if self.global_step % self.config.save_steps == 0:
                    self.save_checkpoint(f"step-{self.global_step}")

    @torch.no_grad()
    def evaluate(self):
        """Evaluate on validation set."""
        self.model.eval()
        total_loss = 0
        total_steps = 0

        for batch in tqdm(self.val_loader, desc="Evaluating", leave=False):
            batch = {k: v.to(self.config.device) for k, v in batch.items()}

            with torch.amp.autocast('cuda', enabled=self.config.mixed_precision and self.config.device == 'cuda'):
                loss, _ = self.model(
                    token_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"]
                )

            total_loss += loss.item()
            total_steps += 1

        return total_loss / total_steps

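    # MLM quality is often reported as perplexity, i.e. math.exp of the mean loss
    # returned above; e.g. a validation loss of 6.0 corresponds to a perplexity of
    # roughly 400. (Illustrative note; this script only reports the raw loss.)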
    def train(self):
        """Full training loop."""
        print(f"\n🚀 Starting Training")
        print("=" * 70)

        start_time = time.time()

        for epoch in range(self.config.num_epochs):
            self.epoch = epoch
            print(f"\n📖 Epoch {epoch + 1}/{self.config.num_epochs}")

            self.train_epoch()

            # Epoch evaluation
            val_loss = self.evaluate()
            print(f"\n Epoch {epoch + 1} validation loss: {val_loss:.4f}")

            # Save epoch checkpoint
            self.save_checkpoint(f"epoch-{epoch + 1}")

        # Final evaluation
        final_val_loss = self.evaluate()
        print(f"\n✅ Training Complete!")
        print(f" Final validation loss: {final_val_loss:.4f}")
        print(f" Best validation loss: {self.best_val_loss:.4f}")
        print(f" Total time: {(time.time() - start_time) / 3600:.2f} hours")

        # Save final model
        self.save_checkpoint("final")

        # Push to hub
        if self.config.push_to_hub:
            self.push_to_hub()

    def save_checkpoint(self, name: str):
        """Save model checkpoint."""
        output_dir = Path(self.config.output_dir) / name
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save model weights
        torch.save(self.model.state_dict(), output_dir / "pytorch_model.bin")

        # Save model config
        config_dict = {
            "crystal_dim": self.config.crystal_dim,
            "num_layers": self.config.num_layers,
            "num_attention_heads": self.config.num_attention_heads,
            "intermediate_size": self.config.intermediate_size,
            "vocab_size": self.config.vocab_size,
            "beatrix_levels": self.config.beatrix_levels,
            "max_position_embeddings": self.config.max_position_embeddings,
        }

        with open(output_dir / "config.json", "w") as f:
            json.dump(config_dict, f, indent=2)

        # Save training state
        state = {
            "global_step": self.global_step,
            "epoch": self.epoch,
            "best_val_loss": self.best_val_loss,
        }
        torch.save(state, output_dir / "training_state.pt")

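    # Reloading a saved checkpoint later (illustrative sketch; assumes ThetisConfig
    # accepts the saved fields as keyword arguments):
    #   with open("./thetis-outputs/best/config.json") as f:
    #       cfg = ThetisConfig(**json.load(f))
    #   model = ThetisForMaskedLM(cfg)
    #   model.load_state_dict(torch.load("./thetis-outputs/best/pytorch_model.bin", map_location="cpu"))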
    def push_to_hub(self):
        """Push model to HuggingFace Hub."""
        if not self.config.hub_token:
            print("⚠️ No HuggingFace token found. Skipping push to hub.")
            return

        print(f"\n📤 Pushing to HuggingFace Hub: {self.config.hub_model_id}")

        try:
            from huggingface_hub import HfApi, create_repo

            api = HfApi(token=self.config.hub_token)

            # Create repo if it doesn't exist
            try:
                create_repo(
                    repo_id=self.config.hub_model_id,
                    token=self.config.hub_token,
                    exist_ok=True
                )
            except Exception as e:
                print(f" Repo creation: {e}")

            # Upload best checkpoint
            best_dir = Path(self.config.output_dir) / "best"
            if best_dir.exists():
                api.upload_folder(
                    folder_path=str(best_dir),
                    repo_id=self.config.hub_model_id,
                    token=self.config.hub_token
                )
                print(f" ✓ Best model uploaded!")

            # Upload final checkpoint
            final_dir = Path(self.config.output_dir) / "final"
            if final_dir.exists():
                api.upload_folder(
                    folder_path=str(final_dir),
                    repo_id=self.config.hub_model_id,
                    path_in_repo="final",
                    token=self.config.hub_token
                )
                print(f" ✓ Final model uploaded!")

        except Exception as e:
            print(f"⚠️ Failed to push to hub: {e}")


# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Main Entry Point
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

def main():
    """Main training function."""
    # Configuration
    config = TrainingConfig()

    # Prepare datasets
    train_dataset, val_dataset, tokenizer = prepare_datasets(config)

    # Create model
    print(f"\n🏗️ Creating BERT-Thetis model...")
    model_config = ThetisConfig(
        crystal_dim=config.crystal_dim,
        num_vertices=5,
        num_layers=config.num_layers,
        num_attention_heads=config.num_attention_heads,
        intermediate_size=config.intermediate_size,
        vocab_size=config.vocab_size,
        beatrix_levels=config.beatrix_levels,
        max_position_embeddings=config.max_position_embeddings,
    )

    model = ThetisForMaskedLM(model_config)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f" Total parameters: {total_params:,}")
    print(f" Trainable parameters: {trainable_params:,}")

    # Create trainer
    trainer = ThetisTrainer(model, train_dataset, val_dataset, config)

    # Train
    trainer.train()

    print("\n🎉 All done! BERT-Thetis is ready to sail!")


if __name__ == "__main__":
    main()
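# Typical Colab usage (illustrative): make a Hub token available before running, e.g.
#   import os; os.environ["HF_TOKEN"] = "hf_..."   # or huggingface_hub.login()
#   main()
# To skip the upload step, set push_to_hub=False on TrainingConfig.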