Sarthak committed on
Commit
729d700
·
1 Parent(s): 93151b9

refactor(distiller): improve beam distillation and tokenlearn integration

Browse files

This commit introduces separate Beam functions for distillation and training, enabling more modular and controllable workflows. It also enhances tokenlearn integration by using persistent directories for caching and checkpointing, and improves error handling for training failures.

The changes also include validation of the model to check that the vocabulary and embedding sizes match, which can highlight issues in downstream usage.

Files changed (2) hide show
  1. patches/tokenlearn.patch +25 -0
  2. src/distiller/distill.py +381 -256
patches/tokenlearn.patch ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --- a/tokenlearn/pretrain.py
2
+ +++ b/tokenlearn/pretrain.py
3
+ @@ -38,7 +38,10 @@ class FinetunableStaticModel(nn.Module):
4
+ """Run the model using input IDs."""
5
+ input_ids = input_ids.view(-1)
6
+ input_ids = input_ids[input_ids != self.pad_token_id]
7
+ - w = self.w[input_ids]
8
+ + # Fix for index out of bounds issue
9
+ + # Clamp input_ids to valid range to prevent IndexError during training
10
+ + valid_input_ids = torch.clamp(input_ids, 0, self.w.shape[0] - 1)
11
+ + w = self.w[valid_input_ids]
12
+ return self.sub_forward(w)
13
+
14
+ def forward(self, x):
15
+ @@ -46,7 +49,10 @@ class FinetunableStaticModel(nn.Module):
16
+ # Add a small epsilon to avoid division by zero
17
+ length = zeros.sum(1) + 1e-16
18
+ - embedded = self.embeddings(input_ids)
19
+ + # Fix for embedding index out of bounds issue
20
+ + # Clamp input_ids to valid embedding range
21
+ + valid_input_ids = torch.clamp(input_ids, 0, self.embeddings.num_embeddings - 1)
22
+ + embedded = self.embeddings(valid_input_ids)
23
+ # Zero out the padding
24
+ embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
25
+ # Simulate actual mean
src/distiller/distill.py CHANGED
@@ -49,6 +49,7 @@ from .config import (
49
  directories,
50
  distillation_config,
51
  get_distillation_function_kwargs,
 
52
  get_volume_config,
53
  languages_config,
54
  )
@@ -358,6 +359,21 @@ def simple_distillation(
358
 
359
  logger.info("✅ Core distillation completed successfully")
360
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
  # Save the model
362
  model.save_pretrained(str(output_path))
363
  logger.info(f"💾 Model saved to {output_path}")
@@ -772,7 +788,11 @@ def apply_post_training_regularization(
772
  logger.info(f"🔄 Applying PCA with {pca_dims} dimensions...")
773
 
774
  # Get current embeddings
775
- embeddings = model.embedding.cpu().numpy().astype(np.float64)
 
 
 
 
776
  original_shape = embeddings.shape
777
  logger.info(f"Original embedding shape: {original_shape}")
778
 
@@ -846,229 +866,288 @@ def tokenlearn_training(
846
  4. Post-training re-regularization (PCA + SIF weighting)
847
  """
848
  import subprocess
849
- import tempfile
850
  from pathlib import Path
851
 
852
  logger.info("🧪 Starting tokenlearn training (POTION approach)...")
853
 
854
- # Create temporary directories for tokenlearn workflow
855
- with tempfile.TemporaryDirectory() as temp_dir:
856
- temp_path = Path(temp_dir)
857
- features_dir = temp_path / "features"
858
- model_dir = temp_path / "base_model"
859
- trained_dir = temp_path / "trained_model"
860
-
861
- features_dir.mkdir(exist_ok=True)
862
- model_dir.mkdir(exist_ok=True)
863
- trained_dir.mkdir(exist_ok=True)
864
-
865
- # Save the base distilled model for tokenlearn
866
- student_model.save_pretrained(str(model_dir))
867
- logger.info(f"💾 Saved base model to {model_dir}")
868
-
869
- # Step 2: Create features using sentence transformer
870
- logger.info("🔍 Step 2: Creating features using sentence transformer...")
871
-
872
- # Get teacher model name/path for tokenlearn
873
- teacher_model_name = getattr(teacher_model, "model_name", None)
874
- if not teacher_model_name and hasattr(teacher_model, "_modules") and len(teacher_model._modules) > 0: # noqa: SLF001
875
- # Try to extract from the first module if it's a SentenceTransformer
876
- # _modules is a dict-like container, get the first module by iterating
877
- first_module = next(iter(teacher_model._modules.values())) # noqa: SLF001
878
- if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"):
879
- teacher_model_name = first_module.auto_model.name_or_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
880
 
881
  logger.info(f"📊 Using teacher model: {teacher_model_name}")
882
 
883
- # Check if featurization already completed (checkpoint detection)
884
- featurization_complete_marker = features_dir / ".featurization_complete"
885
- if featurization_complete_marker.exists() and verify_featurization_output(features_dir):
886
- logger.info("✅ Found existing featurization checkpoint with valid output files")
887
- logger.info(f"📂 Using cached features from: {features_dir}")
888
-
889
- # Verify marker is still valid
890
- output_files = list(features_dir.glob("*.npy")) + list(features_dir.glob("*.json"))
891
- logger.info(f"📁 Found {len(output_files)} cached feature files")
892
- else:
893
- if featurization_complete_marker.exists():
894
- logger.warning("⚠️ Featurization marker exists but output files are missing - re-running featurization")
895
- featurization_complete_marker.unlink()
896
- logger.info("🔄 No valid featurization checkpoint found - starting featurization...")
 
 
 
 
 
 
 
897
 
898
- if not teacher_model_name:
899
- logger.warning("⚠️ Could not determine teacher model name, using fallback")
900
- teacher_model_name = "BAAI/bge-base-en-v1.5" # Fallback to a common model
 
 
 
 
 
 
 
 
 
 
 
901
 
902
- logger.info(f"📊 Using teacher model: {teacher_model_name}")
 
 
 
 
903
 
904
- try:
905
- # Use configured dataset for code specialization
906
- featurize_cmd = [
907
- "python",
908
- "-m",
909
- "tokenlearn.featurize",
910
- "--model-name",
911
- str(teacher_model_name),
912
- "--output-dir",
913
- str(features_dir),
914
- "--dataset-path",
915
- str(distillation_config.tokenlearn_dataset),
916
- "--dataset-name",
917
- str(distillation_config.tokenlearn_dataset_name),
918
- "--dataset-split",
919
- "train",
920
- "--key",
921
- str(distillation_config.tokenlearn_text_key), # Use configured text field
922
- "--batch-size",
923
- "1024", # Optimized batch size for A100-40G
924
- ]
925
-
926
- logger.info("🔄 Running tokenlearn featurization...")
927
- logger.info(
928
- f"📊 Dataset: {distillation_config.tokenlearn_dataset} (config: {distillation_config.tokenlearn_dataset_name})"
929
- )
930
- logger.info(f"📝 Text field: {distillation_config.tokenlearn_text_key}")
931
- logger.info(f"Command: {' '.join(featurize_cmd)}")
932
- print(f"\n🔄 Executing: {' '.join(featurize_cmd)}\n")
933
-
934
- result = subprocess.run( # noqa: S603
935
- featurize_cmd,
936
- text=True,
937
- timeout=distillation_config.tokenlearn_timeout_featurize,
938
- check=False,
939
- )
940
 
941
- if result.returncode != 0:
942
- logger.error(f"❌ Featurization failed with return code: {result.returncode}")
943
- logger.error("💥 Tokenlearn featurization is required for training - cannot proceed")
944
- msg = f"Tokenlearn featurization failed with return code: {result.returncode}"
945
- raise RuntimeError(msg)
946
 
947
- logger.info("✅ Featurization completed successfully")
 
 
948
 
949
- # Create checkpoint marker to indicate featurization is complete
950
- featurization_complete_marker.touch()
951
- logger.info(f"💾 Created featurization checkpoint: {featurization_complete_marker}")
 
 
952
 
953
- # Generate token frequencies for post-training re-regularization
954
- logger.info("📊 Computing token frequencies for SIF weighting...")
955
- compute_token_frequencies_for_sif(teacher_model, features_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
956
 
957
- except Exception as e:
958
- logger.exception("💥 Tokenlearn featurization failed")
959
- logger.exception("💥 Tokenlearn featurization is required for training - cannot proceed")
960
- msg = f"Tokenlearn featurization failed: {e}"
961
- raise RuntimeError(msg) from e
962
-
963
- # Step 3: Train using tokenlearn-train
964
- logger.info("🎓 Step 3: Training using tokenlearn...")
965
-
966
- # Check if training already completed (checkpoint detection)
967
- training_complete_marker = trained_dir / ".training_complete"
968
- if training_complete_marker.exists() and verify_training_output(trained_dir):
969
- logger.info("✅ Found existing training checkpoint with valid model files")
970
- logger.info(f"📂 Using cached trained model from: {trained_dir}")
971
-
972
- # Show available model files
973
- model_files = []
974
- for pattern in ["*.json", "*.safetensors", "*.bin"]:
975
- model_files.extend(list(trained_dir.glob(pattern)))
976
- for subdir in ["model", "model_weighted"]:
977
- subdir_path = trained_dir / subdir
978
- if subdir_path.exists():
979
- model_files.extend(list(subdir_path.glob(pattern)))
980
- logger.info(f"📁 Found {len(model_files)} cached model files")
981
- else:
982
- if training_complete_marker.exists():
983
- logger.warning("⚠️ Training marker exists but model files are missing - re-running training")
984
- training_complete_marker.unlink()
985
- logger.info("🔄 No valid training checkpoint found - starting training...")
986
 
987
- try:
988
- train_cmd = [
989
- "python",
990
- "-m",
991
- "tokenlearn.train",
992
- "--model-name",
993
- str(teacher_model_name),
994
- "--data-path",
995
- str(features_dir),
996
- "--save-path",
997
- str(trained_dir),
998
- ]
999
-
1000
- logger.info("🔄 Running tokenlearn training...")
1001
- logger.info(f"Command: {' '.join(train_cmd)}")
1002
- print(f"\n🎓 Executing: {' '.join(train_cmd)}\n")
1003
-
1004
- result = subprocess.run( # noqa: S603
1005
- train_cmd,
1006
- text=True,
1007
- timeout=distillation_config.tokenlearn_timeout_train,
1008
- check=False,
1009
- )
1010
 
1011
- if result.returncode != 0:
1012
- logger.error(f"❌ Tokenlearn training failed with return code: {result.returncode}")
1013
- logger.error("💥 Tokenlearn training is required - cannot proceed")
1014
- msg = f"Tokenlearn training failed with return code: {result.returncode}"
1015
- raise RuntimeError(msg)
 
 
1016
 
1017
- logger.info("✅ Tokenlearn training completed successfully")
 
1018
 
1019
- # Create checkpoint marker to indicate training is complete
1020
- training_complete_marker.touch()
1021
- logger.info(f"💾 Created training checkpoint: {training_complete_marker}")
 
 
1022
 
1023
- except Exception as e:
1024
- logger.exception("💥 Tokenlearn training failed")
1025
- logger.exception("💥 Tokenlearn training is required - cannot proceed")
1026
- msg = f"Tokenlearn training failed: {e}"
1027
- raise RuntimeError(msg) from e
1028
 
1029
- # Step 4: Load the trained model and apply post-training re-regularization
1030
- logger.info("📦 Step 4: Loading trained model and applying post-training re-regularization...")
 
1031
 
1032
- try:
1033
- from model2vec.model import StaticModel
1034
-
1035
- # Load the trained model from tokenlearn
1036
- trained_model_path = trained_dir / "model"
1037
- if not trained_model_path.exists():
1038
- # Try alternative paths
1039
- possible_paths = [
1040
- trained_dir / "model_weighted",
1041
- trained_dir,
1042
- ]
1043
-
1044
- for path in possible_paths:
1045
- if path.exists() and any(path.glob("*.json")):
1046
- trained_model_path = path
1047
- break
1048
- else:
1049
- logger.error(f"❌ Could not find trained model in {trained_dir}")
1050
- msg = f"Tokenlearn training failed - no model found in {trained_dir}"
1051
  raise RuntimeError(msg)
 
 
 
 
1052
 
1053
- # Load the model before re-regularization
1054
- logger.info("🔄 Loading model from tokenlearn training...")
1055
- trained_model = StaticModel.from_pretrained(str(trained_model_path))
1056
 
1057
- # Apply post-training re-regularization (POTION Step 4)
1058
- logger.info("🔧 Applying post-training re-regularization (PCA + SIF weighting)...")
1059
- final_model = apply_post_training_regularization(
1060
- trained_model, features_dir, pca_dims=distillation_config.optimal_pca_dims
1061
- )
1062
 
1063
- logger.info("✅ Tokenlearn training pipeline with post-training re-regularization completed successfully")
 
1064
 
1065
- return final_model
 
 
 
 
 
 
1066
 
1067
- except Exception as e:
1068
- logger.exception("💥 Failed to load tokenlearn trained model")
1069
- logger.exception("💥 Cannot load trained model - training failed")
1070
- msg = f"Failed to load tokenlearn trained model: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1071
  raise RuntimeError(msg) from e
 
 
 
 
 
 
 
 
1072
 
1073
 
1074
  def distill_single_teacher(
@@ -1118,7 +1197,6 @@ def distill_single_teacher(
1118
 
1119
  # Initialize Beam utilities if requested
1120
  checkpoint_mgr = None
1121
- model_mgr = None
1122
  if use_beam_utilities:
1123
  try:
1124
  _, checkpoint_mgr, model_mgr, _ = create_beam_utilities(VOLUME_CONFIG.name, VOLUME_CONFIG.mount_path)
@@ -1197,44 +1275,65 @@ def distill_single_teacher(
1197
 
1198
  existing_base = str(base_dir)
1199
 
1200
- # Step 3: Handle final model creation
1201
- if enable_training and base_model is not None:
1202
- # Perform tokenlearn training (POTION approach)
1203
- logger.info(f"🧪 Starting tokenlearn training for {teacher_name}")
1204
-
1205
- # Load teacher model for training
1206
- device = "cuda" if torch.cuda.is_available() else "cpu"
1207
- teacher_st_model = load_model_with_flash_attention(teacher_model, device)
1208
-
1209
- # Perform tokenlearn training (POTION approach)
1210
- final_model = tokenlearn_training(base_model, teacher_st_model, checkpoint_mgr)
1211
-
1212
- # Save final model
1213
- final_dir.mkdir(parents=True, exist_ok=True)
1214
- final_model.save_pretrained(str(final_dir))
1215
-
1216
- # Sync final model and training checkpoints to Beam
1217
- if use_beam_utilities:
1218
- sync_model_to_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities)
1219
- if checkpoint_mgr:
1220
- sync_checkpoints_to_beam(VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints)
1221
 
1222
- del teacher_st_model
1223
- if torch.cuda.is_available():
1224
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1225
 
1226
- else:
1227
- # Copy base to final (no training)
1228
- logger.info(f"📁 Copying base to final for {teacher_name}")
1229
- if not copy_base_to_final(teacher_name, enable_training):
1230
- return {
1231
- "teacher_model": teacher_model,
1232
- "teacher_name": teacher_name,
1233
- "status": "failed_copy_to_final",
1234
- "error": "Failed to copy base to final",
1235
- }
1236
 
1237
- total_time = time.time() - start_time
1238
 
1239
  return {
1240
  "teacher_model": teacher_model,
@@ -1318,6 +1417,9 @@ def run_local_distillation(
1318
 
1319
  if result["status"] == "success" or result["status"].startswith("skipped"):
1320
  successful_models.append(teacher_name)
 
 
 
1321
 
1322
  # Summary
1323
  logger.info("\n🏆 DISTILLATION WORKFLOW COMPLETE!")
@@ -1349,16 +1451,13 @@ def run_local_distillation(
1349
  return results_summary
1350
 
1351
 
1352
- @function(**get_distillation_function_kwargs())
1353
- def _beam_distill_models(
1354
  teacher_models: list[str] | None = None,
1355
  enable_training: bool = False,
1356
  pca_dims: int | None = None,
1357
  clear_cache: bool = False,
1358
  ) -> dict[str, Any]:
1359
- """Internal Beam function for distillation."""
1360
- logger.info("☁️ Running distillation on Beam")
1361
-
1362
  # Apply patches
1363
  patch_success = apply_local_patches()
1364
  if patch_success:
@@ -1404,6 +1503,9 @@ def _beam_distill_models(
1404
 
1405
  if result["status"] == "success" or result["status"].startswith("skipped"):
1406
  successful_models.append(teacher_name)
 
 
 
1407
 
1408
  # Summary
1409
  logger.info("\n🏆 BEAM DISTILLATION WORKFLOW COMPLETE!")
@@ -1429,6 +1531,30 @@ def _beam_distill_models(
1429
  return results_summary
1430
 
1431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1432
  def run_beam_distillation(
1433
  teacher_models: list[str] | None = None,
1434
  enable_training: bool = False,
@@ -1439,8 +1565,11 @@ def run_beam_distillation(
1439
  logger.info("☁️ Running distillation on Beam with local sync")
1440
 
1441
  try:
 
 
 
1442
  # Run distillation on Beam
1443
- results = _beam_distill_models.remote(teacher_models, enable_training, pca_dims, clear_cache)
1444
 
1445
  # Check if Beam execution was successful
1446
  if not results:
@@ -1529,27 +1658,23 @@ def main(
1529
 
1530
  # Clear tokenlearn checkpoints if requested (for training mode)
1531
  if clear_checkpoints and train:
1532
- import tempfile
1533
-
1534
  logger.info("🧹 Clearing tokenlearn checkpoints to force fresh featurization and training...")
1535
  for teacher_model in models_to_distill:
1536
- teacher_name = teacher_model.split("/")[-1].replace("-", "_")
1537
-
1538
- # Construct checkpoint paths using secure temporary directory
1539
- temp_dir = Path(tempfile.gettempdir()) / f"tokenlearn_{teacher_name}"
1540
- features_dir = temp_dir / "features"
1541
- trained_dir = temp_dir / "trained"
1542
-
1543
- # Also check local paths
1544
- local_temp = Path("temp") / f"tokenlearn_{teacher_name}"
1545
- local_features = local_temp / "features"
1546
- local_trained = local_temp / "trained"
1547
-
1548
- # Clear checkpoints for all possible paths
1549
- for feat_dir, train_dir in [(features_dir, trained_dir), (local_features, local_trained)]:
1550
- if feat_dir.exists() or train_dir.exists():
1551
- clear_tokenlearn_checkpoints(feat_dir, train_dir)
1552
- logger.info(f"🗑️ Cleared checkpoints for {teacher_model}")
1553
  elif clear_checkpoints and not train:
1554
  logger.warning("⚠️ --clear-checkpoints flag is only relevant when training is enabled (--train)")
1555
 
 
49
  directories,
50
  distillation_config,
51
  get_distillation_function_kwargs,
52
+ get_training_function_kwargs,
53
  get_volume_config,
54
  languages_config,
55
  )
 
359
 
360
  logger.info("✅ Core distillation completed successfully")
361
 
362
+ # Validate model before saving
363
+ if hasattr(model, "tokenizer") and hasattr(model, "embedding"):
364
+ vocab_size = len(model.tokenizer.get_vocab())
365
+ embedding_size = model.embedding.shape[0]
366
+
367
+ logger.info("📊 Model validation:")
368
+ logger.info(f" - Vocabulary size: {vocab_size}")
369
+ logger.info(f" - Embedding matrix size: {embedding_size}")
370
+
371
+ if vocab_size != embedding_size:
372
+ logger.warning(f"⚠️ Vocabulary size mismatch: vocab={vocab_size}, embeddings={embedding_size}")
373
+ logger.warning("⚠️ This may cause issues in downstream usage")
374
+ else:
375
+ logger.info("✅ Vocabulary and embedding sizes match")
376
+
377
  # Save the model
378
  model.save_pretrained(str(output_path))
379
  logger.info(f"💾 Model saved to {output_path}")
 
788
  logger.info(f"🔄 Applying PCA with {pca_dims} dimensions...")
789
 
790
  # Get current embeddings
791
+ # Handle both torch tensors and numpy arrays
792
+ if hasattr(model.embedding, "cpu"):
793
+ embeddings = model.embedding.cpu().numpy().astype(np.float64)
794
+ else:
795
+ embeddings = model.embedding.astype(np.float64)
796
  original_shape = embeddings.shape
797
  logger.info(f"Original embedding shape: {original_shape}")
798
 
 
866
  4. Post-training re-regularization (PCA + SIF weighting)
867
  """
868
  import subprocess
 
869
  from pathlib import Path
870
 
871
  logger.info("🧪 Starting tokenlearn training (POTION approach)...")
872
 
873
+ # Create persistent directories for tokenlearn workflow (for checkpoint preservation)
874
+ teacher_model_name = getattr(teacher_model, "model_name", None)
875
+ if not teacher_model_name and hasattr(teacher_model, "_modules") and len(teacher_model._modules) > 0: # noqa: SLF001
876
+ # Try to extract from the first module if it's a SentenceTransformer
877
+ first_module = next(iter(teacher_model._modules.values())) # noqa: SLF001
878
+ if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"):
879
+ teacher_model_name = first_module.auto_model.name_or_path
880
+
881
+ if not teacher_model_name:
882
+ teacher_model_name = "unknown_teacher"
883
+
884
+ # Use persistent directory for tokenlearn checkpoints
885
+ teacher_slug = teacher_model_name.replace("/", "_").replace("-", "_")
886
+ persistent_tokenlearn_dir = Path(directories.base).parent / "tokenlearn_cache" / teacher_slug
887
+
888
+ features_dir = persistent_tokenlearn_dir / "features"
889
+ model_dir = persistent_tokenlearn_dir / "base_model"
890
+ trained_dir = persistent_tokenlearn_dir / "trained_model"
891
+
892
+ features_dir.mkdir(parents=True, exist_ok=True)
893
+ model_dir.mkdir(parents=True, exist_ok=True)
894
+ trained_dir.mkdir(parents=True, exist_ok=True)
895
+
896
+ logger.info(f"📁 Using persistent tokenlearn directory: {persistent_tokenlearn_dir}")
897
+
898
+ # Save the base distilled model for tokenlearn
899
+ student_model.save_pretrained(str(model_dir))
900
+ logger.info(f"💾 Saved base model to {model_dir}")
901
+
902
+ # Step 2: Create features using sentence transformer
903
+ logger.info("🔍 Step 2: Creating features using sentence transformer...")
904
+
905
+ # Get teacher model name/path for tokenlearn
906
+ teacher_model_name = getattr(teacher_model, "model_name", None)
907
+ if not teacher_model_name and hasattr(teacher_model, "_modules") and len(teacher_model._modules) > 0: # noqa: SLF001
908
+ # Try to extract from the first module if it's a SentenceTransformer
909
+ # _modules is a dict-like container, get the first module by iterating
910
+ first_module = next(iter(teacher_model._modules.values())) # noqa: SLF001
911
+ if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"):
912
+ teacher_model_name = first_module.auto_model.name_or_path
913
+
914
+ logger.info(f"📊 Using teacher model: {teacher_model_name}")
915
+
916
+ # Check if featurization already completed (checkpoint detection)
917
+ featurization_complete_marker = features_dir / ".featurization_complete"
918
+ if featurization_complete_marker.exists() and verify_featurization_output(features_dir):
919
+ logger.info("✅ Found existing featurization checkpoint with valid output files")
920
+ logger.info(f"📂 Using cached features from: {features_dir}")
921
+
922
+ # Verify marker is still valid
923
+ output_files = list(features_dir.glob("*.npy")) + list(features_dir.glob("*.json"))
924
+ logger.info(f"📁 Found {len(output_files)} cached feature files")
925
+ else:
926
+ if featurization_complete_marker.exists():
927
+ logger.warning("⚠️ Featurization marker exists but output files are missing - re-running featurization")
928
+ featurization_complete_marker.unlink()
929
+ logger.info("🔄 No valid featurization checkpoint found - starting featurization...")
930
+
931
+ if not teacher_model_name:
932
+ logger.warning("⚠️ Could not determine teacher model name, using fallback")
933
+ teacher_model_name = "BAAI/bge-base-en-v1.5" # Fallback to a common model
934
 
935
  logger.info(f"📊 Using teacher model: {teacher_model_name}")
936
 
937
+ try:
938
+ # Use configured dataset for code specialization
939
+ featurize_cmd = [
940
+ "python",
941
+ "-m",
942
+ "tokenlearn.featurize",
943
+ "--model-name",
944
+ str(teacher_model_name),
945
+ "--output-dir",
946
+ str(features_dir),
947
+ "--dataset-path",
948
+ str(distillation_config.tokenlearn_dataset),
949
+ "--dataset-name",
950
+ str(distillation_config.tokenlearn_dataset_name),
951
+ "--dataset-split",
952
+ "train",
953
+ "--key",
954
+ str(distillation_config.tokenlearn_text_key), # Use configured text field
955
+ "--batch-size",
956
+ "1024", # Optimized batch size for A100-40G
957
+ ]
958
 
959
+ logger.info("🔄 Running tokenlearn featurization...")
960
+ logger.info(
961
+ f"📊 Dataset: {distillation_config.tokenlearn_dataset} (config: {distillation_config.tokenlearn_dataset_name})"
962
+ )
963
+ logger.info(f"📝 Text field: {distillation_config.tokenlearn_text_key}")
964
+ logger.info(f"Command: {' '.join(featurize_cmd)}")
965
+ print(f"\n🔄 Executing: {' '.join(featurize_cmd)}\n")
966
+
967
+ result = subprocess.run( # noqa: S603
968
+ featurize_cmd,
969
+ text=True,
970
+ timeout=distillation_config.tokenlearn_timeout_featurize,
971
+ check=False,
972
+ )
973
 
974
+ if result.returncode != 0:
975
+ logger.error(f"❌ Featurization failed with return code: {result.returncode}")
976
+ logger.error("💥 Tokenlearn featurization is required for training - cannot proceed")
977
+ msg = f"Tokenlearn featurization failed with return code: {result.returncode}"
978
+ raise RuntimeError(msg)
979
 
980
+ logger.info("✅ Featurization completed successfully")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
981
 
982
+ # Create checkpoint marker to indicate featurization is complete
983
+ featurization_complete_marker.touch()
984
+ logger.info(f"💾 Created featurization checkpoint: {featurization_complete_marker}")
 
 
985
 
986
+ # Generate token frequencies for post-training re-regularization
987
+ logger.info("📊 Computing token frequencies for SIF weighting...")
988
+ compute_token_frequencies_for_sif(teacher_model, features_dir)
989
 
990
+ except Exception as e:
991
+ logger.exception("💥 Tokenlearn featurization failed")
992
+ logger.exception("💥 Tokenlearn featurization is required for training - cannot proceed")
993
+ msg = f"Tokenlearn featurization failed: {e}"
994
+ raise RuntimeError(msg) from e
995
 
996
+ # Step 3: Train using tokenlearn-train
997
+ logger.info("🎓 Step 3: Training using tokenlearn...")
998
+
999
+ # Check if training already completed (checkpoint detection)
1000
+ training_complete_marker = trained_dir / ".training_complete"
1001
+ training_fallback_marker = trained_dir / ".training_fallback"
1002
+
1003
+ if training_complete_marker.exists() and verify_training_output(trained_dir):
1004
+ logger.info("✅ Found existing training checkpoint with valid model files")
1005
+ logger.info(f"📂 Using cached trained model from: {trained_dir}")
1006
+
1007
+ # Show available model files
1008
+ model_files = []
1009
+ for pattern in ["*.json", "*.safetensors", "*.bin"]:
1010
+ model_files.extend(list(trained_dir.glob(pattern)))
1011
+ for subdir in ["model", "model_weighted"]:
1012
+ subdir_path = trained_dir / subdir
1013
+ if subdir_path.exists():
1014
+ model_files.extend(list(subdir_path.glob(pattern)))
1015
+ logger.info(f"📁 Found {len(model_files)} cached model files")
1016
+ elif training_fallback_marker.exists():
1017
+ logger.warning("⚠️ Training fallback marker found - tokenlearn failed previously")
1018
+ logger.info("🔄 Proceeding with fallback to base model (simple distillation)")
1019
+ # Skip training and proceed to model loading (will fallback to base model)
1020
+ else:
1021
+ if training_complete_marker.exists():
1022
+ logger.warning("⚠️ Training marker exists but model files are missing - re-running training")
1023
+ training_complete_marker.unlink()
1024
+ logger.info("🔄 No valid training checkpoint found - starting training...")
1025
 
1026
+ try:
1027
+ train_cmd = [
1028
+ "python",
1029
+ "-m",
1030
+ "tokenlearn.train",
1031
+ "--model-name",
1032
+ str(teacher_model_name),
1033
+ "--data-path",
1034
+ str(features_dir),
1035
+ "--save-path",
1036
+ str(trained_dir),
1037
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1038
 
1039
+ logger.info("🔄 Running tokenlearn training...")
1040
+ logger.info(f"Command: {' '.join(train_cmd)}")
1041
+ print(f"\n🎓 Executing: {' '.join(train_cmd)}\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1042
 
1043
+ result = subprocess.run( # noqa: S603
1044
+ train_cmd,
1045
+ text=True,
1046
+ capture_output=True, # Capture stdout and stderr
1047
+ timeout=distillation_config.tokenlearn_timeout_train,
1048
+ check=False,
1049
+ )
1050
 
1051
+ if result.returncode != 0:
1052
+ logger.error(f"❌ Tokenlearn training failed with return code: {result.returncode}")
1053
 
1054
+ # Log the actual error output for debugging
1055
+ if result.stderr:
1056
+ logger.error(f"stderr: {result.stderr}")
1057
+ if result.stdout:
1058
+ logger.info(f"stdout: {result.stdout}")
1059
 
1060
+ # Check if it's the token-vector mismatch issue
1061
+ error_output = str(result.stderr) + str(result.stdout)
1062
+ if "Number of tokens" in error_output and "does not match number of vectors" in error_output:
1063
+ logger.error("🔧 Token-vector mismatch detected in tokenlearn")
1064
+ logger.error("💥 This is a known issue with tokenlearn/Model2Vec integration")
1065
 
1066
+ # Create training marker to indicate we tried but failed
1067
+ training_fallback_marker = trained_dir / ".training_fallback"
1068
+ training_fallback_marker.touch()
1069
 
1070
+ logger.error("❌ Tokenlearn training failed due to token-vector mismatch")
1071
+ msg = f"Tokenlearn training failed with token-vector mismatch: {error_output}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1072
  raise RuntimeError(msg)
1073
+ logger.error("💥 Tokenlearn training failed with different error")
1074
+ msg = f"Tokenlearn training failed with return code: {result.returncode}"
1075
+ raise RuntimeError(msg)
1076
+ logger.info("✅ Tokenlearn training completed successfully")
1077
 
1078
+ # Create checkpoint marker to indicate training is complete
1079
+ training_complete_marker.touch()
1080
+ logger.info(f"💾 Created training checkpoint: {training_complete_marker}")
1081
 
1082
+ except Exception as e:
1083
+ logger.exception("💥 Tokenlearn training failed")
1084
+ logger.exception("💥 Tokenlearn training is required - cannot proceed")
1085
+ msg = f"Tokenlearn training failed: {e}"
1086
+ raise RuntimeError(msg) from e
1087
 
1088
+ # Step 4: Load the trained model and apply post-training re-regularization
1089
+ logger.info("📦 Step 4: Loading trained model and applying post-training re-regularization...")
1090
 
1091
+ # Check if we need to use fallback due to tokenlearn failure
1092
+ training_fallback_marker = trained_dir / ".training_fallback"
1093
+ if training_fallback_marker.exists():
1094
+ logger.error("❌ Tokenlearn training failed previously - cannot return trained model")
1095
+ logger.error("💥 Training was requested but failed - this would be misleading to return base model")
1096
+ msg = "Tokenlearn training failed - cannot proceed with training pipeline"
1097
+ raise RuntimeError(msg)
1098
 
1099
+ try:
1100
+ from model2vec.model import StaticModel
1101
+
1102
+ # Load the trained model from tokenlearn
1103
+ trained_model_path = trained_dir / "model"
1104
+ if not trained_model_path.exists():
1105
+ # Try alternative paths
1106
+ possible_paths = [
1107
+ trained_dir / "model_weighted",
1108
+ trained_dir,
1109
+ ]
1110
+
1111
+ for path in possible_paths:
1112
+ if path.exists() and any(path.glob("*.json")):
1113
+ trained_model_path = path
1114
+ break
1115
+ else:
1116
+ logger.error(f"❌ Could not find trained model in {trained_dir}")
1117
+ logger.error("💥 Training was requested but no trained model found - cannot proceed")
1118
+ msg = f"Trained model not found in {trained_dir} - training pipeline failed"
1119
+ raise RuntimeError(msg)
1120
+
1121
+ # Load the model before re-regularization
1122
+ logger.info("🔄 Loading model from tokenlearn training...")
1123
+ trained_model = StaticModel.from_pretrained(str(trained_model_path))
1124
+
1125
+ # Apply post-training re-regularization (POTION Step 4)
1126
+ logger.info("🔧 Applying post-training re-regularization (PCA + SIF weighting)...")
1127
+ final_model = apply_post_training_regularization(
1128
+ trained_model, features_dir, pca_dims=distillation_config.optimal_pca_dims
1129
+ )
1130
+
1131
+ logger.info("✅ Tokenlearn training pipeline with post-training re-regularization completed successfully")
1132
+
1133
+ return final_model
1134
+
1135
+ except ValueError as e:
1136
+ if "Number of tokens" in str(e) and "does not match number of vectors" in str(e):
1137
+ logger.exception("💥 Token-vector mismatch in tokenlearn training")
1138
+ logger.exception("Error details")
1139
+ logger.exception("🔧 This is a known issue with tokenlearn/Model2Vec integration")
1140
+ logger.exception("💥 Training was requested but failed due to token-vector mismatch")
1141
+ msg = f"Tokenlearn training failed due to token-vector mismatch: {e}"
1142
  raise RuntimeError(msg) from e
1143
+ logger.exception("💥 Failed to load tokenlearn trained model")
1144
+ msg = f"Failed to load tokenlearn trained model: {e}"
1145
+ raise RuntimeError(msg) from e
1146
+ except Exception as e:
1147
+ logger.exception("💥 Failed to load tokenlearn trained model")
1148
+ logger.exception("💥 Cannot load trained model - training failed")
1149
+ msg = f"Failed to load tokenlearn trained model: {e}"
1150
+ raise RuntimeError(msg) from e
1151
 
1152
 
1153
  def distill_single_teacher(
 
1197
 
1198
  # Initialize Beam utilities if requested
1199
  checkpoint_mgr = None
 
1200
  if use_beam_utilities:
1201
  try:
1202
  _, checkpoint_mgr, model_mgr, _ = create_beam_utilities(VOLUME_CONFIG.name, VOLUME_CONFIG.mount_path)
 
1275
 
1276
  existing_base = str(base_dir)
1277
 
1278
+ # Step 3: Handle final model creation
1279
+ if enable_training and base_model is not None:
1280
+ # Perform tokenlearn training (POTION approach)
1281
+ logger.info(f"🧪 Starting tokenlearn training for {teacher_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1282
 
1283
+ try:
1284
+ # Load teacher model for training
1285
+ device = "cuda" if torch.cuda.is_available() else "cpu"
1286
+ teacher_st_model = load_model_with_flash_attention(teacher_model, device)
1287
+
1288
+ # Perform tokenlearn training (POTION approach)
1289
+ final_model = tokenlearn_training(base_model, teacher_st_model, checkpoint_mgr)
1290
+
1291
+ # Save final model
1292
+ final_dir.mkdir(parents=True, exist_ok=True)
1293
+ final_model.save_pretrained(str(final_dir))
1294
+
1295
+ # Sync final model and training checkpoints to Beam
1296
+ if use_beam_utilities:
1297
+ sync_model_to_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities)
1298
+ if checkpoint_mgr:
1299
+ sync_checkpoints_to_beam(
1300
+ VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints
1301
+ )
1302
+
1303
+ del teacher_st_model
1304
+ if torch.cuda.is_available():
1305
+ torch.cuda.empty_cache()
1306
+
1307
+ except RuntimeError as e:
1308
+ # Training failed - clean up and return failure
1309
+ logger.exception(f"❌ Training failed for {teacher_name}")
1310
+
1311
+ # Clean up teacher model if it was loaded
1312
+ if "teacher_st_model" in locals():
1313
+ del teacher_st_model
1314
+ if torch.cuda.is_available():
1315
+ torch.cuda.empty_cache()
1316
+
1317
+ return {
1318
+ "teacher_model": teacher_model,
1319
+ "teacher_name": teacher_name,
1320
+ "status": "failed_training",
1321
+ "error": f"Training failed: {e!s}",
1322
+ "base_path": existing_base, # Base model was created successfully
1323
+ }
1324
 
1325
+ else:
1326
+ # Copy base to final (no training)
1327
+ logger.info(f"📁 Copying base to final for {teacher_name}")
1328
+ if not copy_base_to_final(teacher_name, enable_training):
1329
+ return {
1330
+ "teacher_model": teacher_model,
1331
+ "teacher_name": teacher_name,
1332
+ "status": "failed_copy_to_final",
1333
+ "error": "Failed to copy base to final",
1334
+ }
1335
 
1336
+ total_time = time.time() - start_time
1337
 
1338
  return {
1339
  "teacher_model": teacher_model,
 
1417
 
1418
  if result["status"] == "success" or result["status"].startswith("skipped"):
1419
  successful_models.append(teacher_name)
1420
+ elif result["status"] == "failed_training":
1421
+ # Note: Training failed but base model may still be available
1422
+ logger.warning(f"⚠️ Training failed for {teacher_name}, but base distillation may have succeeded")
1423
 
1424
  # Summary
1425
  logger.info("\n🏆 DISTILLATION WORKFLOW COMPLETE!")
 
1451
  return results_summary
1452
 
1453
 
1454
+ def _beam_distill_internal(
 
1455
  teacher_models: list[str] | None = None,
1456
  enable_training: bool = False,
1457
  pca_dims: int | None = None,
1458
  clear_cache: bool = False,
1459
  ) -> dict[str, Any]:
1460
+ """Shared internal implementation for beam distillation."""
 
 
1461
  # Apply patches
1462
  patch_success = apply_local_patches()
1463
  if patch_success:
 
1503
 
1504
  if result["status"] == "success" or result["status"].startswith("skipped"):
1505
  successful_models.append(teacher_name)
1506
+ elif result["status"] == "failed_training":
1507
+ # Note: Training failed but base model may still be available
1508
+ logger.warning(f"⚠️ Training failed for {teacher_name}, but base distillation may have succeeded")
1509
 
1510
  # Summary
1511
  logger.info("\n🏆 BEAM DISTILLATION WORKFLOW COMPLETE!")
 
1531
  return results_summary
1532
 
1533
 
1534
@function(**get_training_function_kwargs())
def _beam_train_models(
    teacher_models: list[str] | None = None,
    enable_training: bool = True,
    pca_dims: int | None = None,
    clear_cache: bool = False,
) -> dict[str, Any]:
    """Run the full training pipeline (distillation + tokenlearn) on Beam.

    Thin Beam entry point: requests training-sized resources via
    ``get_training_function_kwargs()`` and delegates all of the actual
    work to the shared ``_beam_distill_internal`` implementation.
    """
    logger.info("☁️ Running training on Beam")
    results = _beam_distill_internal(
        teacher_models,
        enable_training,
        pca_dims,
        clear_cache,
    )
    return results
1544
+
1545
+
1546
@function(**get_distillation_function_kwargs())
def _beam_distill_models(
    teacher_models: list[str] | None = None,
    enable_training: bool = False,
    pca_dims: int | None = None,
    clear_cache: bool = False,
) -> dict[str, Any]:
    """Run basic distillation (no tokenlearn training) on Beam.

    Thin Beam entry point: requests distillation-sized resources via
    ``get_distillation_function_kwargs()`` and delegates all of the
    actual work to the shared ``_beam_distill_internal`` implementation.
    """
    logger.info("☁️ Running distillation on Beam")
    results = _beam_distill_internal(
        teacher_models,
        enable_training,
        pca_dims,
        clear_cache,
    )
    return results
1556
+
1557
+
1558
  def run_beam_distillation(
1559
  teacher_models: list[str] | None = None,
1560
  enable_training: bool = False,
 
1565
  logger.info("☁️ Running distillation on Beam with local sync")
1566
 
1567
  try:
1568
+ # Choose appropriate beam function based on training flag
1569
+ beam_function = _beam_train_models if enable_training else _beam_distill_models
1570
+
1571
  # Run distillation on Beam
1572
+ results = beam_function.remote(teacher_models, enable_training, pca_dims, clear_cache)
1573
 
1574
  # Check if Beam execution was successful
1575
  if not results:
 
1658
 
1659
  # Clear tokenlearn checkpoints if requested (for training mode)
1660
  if clear_checkpoints and train:
 
 
1661
  logger.info("🧹 Clearing tokenlearn checkpoints to force fresh featurization and training...")
1662
  for teacher_model in models_to_distill:
1663
+ teacher_model.split("/")[-1].replace("-", "_")
1664
+
1665
+ # Use the same persistent directory structure as the training function
1666
+ teacher_slug = teacher_model.replace("/", "_").replace("-", "_")
1667
+ persistent_tokenlearn_dir = Path(LOCAL_BASE_DIR).parent / "tokenlearn_cache" / teacher_slug
1668
+
1669
+ features_dir = persistent_tokenlearn_dir / "features"
1670
+ trained_dir = persistent_tokenlearn_dir / "trained_model"
1671
+
1672
+ # Clear persistent tokenlearn checkpoints
1673
+ if features_dir.exists() or trained_dir.exists():
1674
+ clear_tokenlearn_checkpoints(features_dir, trained_dir)
1675
+ logger.info(f"🗑️ Cleared persistent tokenlearn checkpoints for {teacher_model}")
1676
+ else:
1677
+ logger.info(f"ℹ️ No tokenlearn checkpoints found for {teacher_model}")
 
 
1678
  elif clear_checkpoints and not train:
1679
  logger.warning("⚠️ --clear-checkpoints flag is only relevant when training is enabled (--train)")
1680