Sarthak committed on
Commit
8083c06
Β·
1 Parent(s): ff551a2

feat(distiller): add checkpointing and refactor analyze.py

Browse files

This change introduces checkpointing to the tokenlearn featurization and training processes. This allows the processes to resume where they left off after an interruption, and to skip steps that have already completed. It also adds a --clear-checkpoints flag to force fresh featurization and training.

Additionally, minor refactoring was done to use a list comprehension in analyze.py.

Files changed (2) hide show
  1. src/distiller/analyze.py +6 -5
  2. src/distiller/distill.py +211 -86
src/distiller/analyze.py CHANGED
@@ -496,10 +496,11 @@ class CodeSearchNetAnalyzer:
496
  return
497
 
498
  # Find all our model directories
499
- our_model_dirs = []
500
- for model_dir in final_models_dir.iterdir():
501
- if model_dir.is_dir() and "code_model2vec" in model_dir.name:
502
- our_model_dirs.append(model_dir)
 
503
 
504
  logger.info(f"πŸ“ Found {len(our_model_dirs)} distilled model directories")
505
 
@@ -1567,7 +1568,7 @@ This report presents a comprehensive analysis of Model2Vec distillation experime
1567
  if self.model_specs:
1568
  successful_specs = {k: v for k, v in self.model_specs.items() if v.get("analysis_successful", False)}
1569
  if successful_specs:
1570
- report += f"""
1571
 
1572
  ### πŸ“Š Model Specifications Analysis
1573
 
 
496
  return
497
 
498
  # Find all our model directories
499
+ our_model_dirs = [
500
+ model_dir
501
+ for model_dir in final_models_dir.iterdir()
502
+ if model_dir.is_dir() and "code_model2vec" in model_dir.name
503
+ ]
504
 
505
  logger.info(f"πŸ“ Found {len(our_model_dirs)} distilled model directories")
506
 
 
1568
  if self.model_specs:
1569
  successful_specs = {k: v for k, v in self.model_specs.items() if v.get("analysis_successful", False)}
1570
  if successful_specs:
1571
+ report += """
1572
 
1573
  ### πŸ“Š Model Specifications Analysis
1574
 
src/distiller/distill.py CHANGED
@@ -866,7 +866,7 @@ def tokenlearn_training(
866
  student_model.save_pretrained(str(model_dir))
867
  logger.info(f"πŸ’Ύ Saved base model to {model_dir}")
868
 
869
- # Step 2: Create features using tokenlearn-featurize
870
  logger.info("πŸ” Step 2: Creating features using sentence transformer...")
871
 
872
  # Get teacher model name/path for tokenlearn
@@ -878,107 +878,153 @@ def tokenlearn_training(
878
  if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"):
879
  teacher_model_name = first_module.auto_model.name_or_path
880
 
881
- if not teacher_model_name:
882
- logger.warning("⚠️ Could not determine teacher model name, using fallback")
883
- teacher_model_name = "BAAI/bge-base-en-v1.5" # Fallback to a common model
884
-
885
  logger.info(f"πŸ“Š Using teacher model: {teacher_model_name}")
886
 
887
- try:
888
- # Use configured dataset for code specialization
889
- featurize_cmd = [
890
- "python",
891
- "-m",
892
- "tokenlearn.featurize",
893
- "--model-name",
894
- str(teacher_model_name),
895
- "--output-dir",
896
- str(features_dir),
897
- "--dataset-path",
898
- str(distillation_config.tokenlearn_dataset),
899
- "--dataset-name",
900
- str(distillation_config.tokenlearn_dataset_name),
901
- "--dataset-split",
902
- "train",
903
- "--key",
904
- str(distillation_config.tokenlearn_text_key), # Use configured text field
905
- "--batch-size",
906
- "1024", # Optimized batch size for A100-40G
907
- ]
908
 
909
- logger.info("πŸ”„ Running tokenlearn featurization...")
910
- logger.info(
911
- f"πŸ“Š Dataset: {distillation_config.tokenlearn_dataset} (config: {distillation_config.tokenlearn_dataset_name})"
912
- )
913
- logger.info(f"πŸ“ Text field: {distillation_config.tokenlearn_text_key}")
914
- logger.info(f"Command: {' '.join(featurize_cmd)}")
915
- print(f"\nπŸ”„ Executing: {' '.join(featurize_cmd)}\n")
916
-
917
- result = subprocess.run( # noqa: S603
918
- featurize_cmd,
919
- text=True,
920
- timeout=distillation_config.tokenlearn_timeout_featurize,
921
- check=False,
922
- )
923
 
924
- if result.returncode != 0:
925
- logger.error(f"❌ Featurization failed with return code: {result.returncode}")
926
- logger.error("πŸ’₯ Tokenlearn featurization is required for training - cannot proceed")
927
- msg = f"Tokenlearn featurization failed with return code: {result.returncode}"
928
- raise RuntimeError(msg)
929
 
930
- logger.info("βœ… Featurization completed successfully")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
931
 
932
- # Generate token frequencies for post-training re-regularization
933
- logger.info("πŸ“Š Computing token frequencies for SIF weighting...")
934
- compute_token_frequencies_for_sif(teacher_model, features_dir)
 
 
 
 
 
 
 
 
 
 
 
935
 
936
- except Exception as e:
937
- logger.exception("πŸ’₯ Tokenlearn featurization failed")
938
- logger.exception("πŸ’₯ Tokenlearn featurization is required for training - cannot proceed")
939
- msg = f"Tokenlearn featurization failed: {e}"
940
- raise RuntimeError(msg) from e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
941
 
942
  # Step 3: Train using tokenlearn-train
943
  logger.info("πŸŽ“ Step 3: Training using tokenlearn...")
944
 
945
- try:
946
- train_cmd = [
947
- "python",
948
- "-m",
949
- "tokenlearn.train",
950
- "--model-name",
951
- str(teacher_model_name),
952
- "--data-path",
953
- str(features_dir),
954
- "--save-path",
955
- str(trained_dir),
956
- ]
 
 
 
 
 
 
 
 
957
 
958
- logger.info("πŸ”„ Running tokenlearn training...")
959
- logger.info(f"Command: {' '.join(train_cmd)}")
960
- print(f"\nπŸŽ“ Executing: {' '.join(train_cmd)}\n")
 
 
 
 
 
 
 
 
 
961
 
962
- result = subprocess.run( # noqa: S603
963
- train_cmd,
964
- text=True,
965
- timeout=distillation_config.tokenlearn_timeout_train,
966
- check=False,
967
- )
968
 
969
- if result.returncode != 0:
970
- logger.error(f"❌ Tokenlearn training failed with return code: {result.returncode}")
971
- logger.error("πŸ’₯ Tokenlearn training is required - cannot proceed")
972
- msg = f"Tokenlearn training failed with return code: {result.returncode}"
973
- raise RuntimeError(msg)
 
974
 
975
- logger.info("βœ… Tokenlearn training completed successfully")
 
 
 
 
976
 
977
- except Exception as e:
978
- logger.exception("πŸ’₯ Tokenlearn training failed")
979
- logger.exception("πŸ’₯ Tokenlearn training is required - cannot proceed")
980
- msg = f"Tokenlearn training failed: {e}"
981
- raise RuntimeError(msg) from e
 
 
 
 
 
 
982
 
983
  # Step 4: Load the trained model and apply post-training re-regularization
984
  logger.info("πŸ“¦ Step 4: Loading trained model and applying post-training re-regularization...")
@@ -1256,6 +1302,9 @@ def run_local_distillation(
1256
  if model in models_to_distill:
1257
  clear_model_cache(model)
1258
 
 
 
 
1259
  for teacher_model in models_to_distill:
1260
  result = distill_single_teacher(
1261
  teacher_model=teacher_model,
@@ -1453,6 +1502,9 @@ def main(
1453
  clear_cache: Annotated[
1454
  bool, typer.Option(help="Clear HuggingFace cache for problematic models before distillation")
1455
  ] = False,
 
 
 
1456
  ) -> None:
1457
  """Unified distillation command with optional training."""
1458
  logger.info("πŸš€ Starting unified Model2Vec distillation workflow")
@@ -1475,6 +1527,32 @@ def main(
1475
  if model in models_to_distill:
1476
  clear_model_cache(model)
1477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1478
  # Run distillation workflow
1479
  if use_beam:
1480
  results = run_beam_distillation(
@@ -1822,5 +1900,52 @@ def baai_bge_model_distillation(
1822
  return None
1823
 
1824
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1825
  if __name__ == "__main__":
1826
  typer.run(main)
 
866
  student_model.save_pretrained(str(model_dir))
867
  logger.info(f"πŸ’Ύ Saved base model to {model_dir}")
868
 
869
+ # Step 2: Create features using sentence transformer
870
  logger.info("πŸ” Step 2: Creating features using sentence transformer...")
871
 
872
  # Get teacher model name/path for tokenlearn
 
878
  if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"):
879
  teacher_model_name = first_module.auto_model.name_or_path
880
 
 
 
 
 
881
  logger.info(f"πŸ“Š Using teacher model: {teacher_model_name}")
882
 
883
+ # Check if featurization already completed (checkpoint detection)
884
+ featurization_complete_marker = features_dir / ".featurization_complete"
885
+ if featurization_complete_marker.exists() and verify_featurization_output(features_dir):
886
+ logger.info("βœ… Found existing featurization checkpoint with valid output files")
887
+ logger.info(f"πŸ“‚ Using cached features from: {features_dir}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
888
 
889
+ # Verify marker is still valid
890
+ output_files = list(features_dir.glob("*.npy")) + list(features_dir.glob("*.json"))
891
+ logger.info(f"πŸ“ Found {len(output_files)} cached feature files")
892
+ else:
893
+ if featurization_complete_marker.exists():
894
+ logger.warning("⚠️ Featurization marker exists but output files are missing - re-running featurization")
895
+ featurization_complete_marker.unlink()
896
+ logger.info("πŸ”„ No valid featurization checkpoint found - starting featurization...")
897
+
898
+ if not teacher_model_name:
899
+ logger.warning("⚠️ Could not determine teacher model name, using fallback")
900
+ teacher_model_name = "BAAI/bge-base-en-v1.5" # Fallback to a common model
 
 
901
 
902
+ logger.info(f"πŸ“Š Using teacher model: {teacher_model_name}")
 
 
 
 
903
 
904
+ try:
905
+ # Use configured dataset for code specialization
906
+ featurize_cmd = [
907
+ "python",
908
+ "-m",
909
+ "tokenlearn.featurize",
910
+ "--model-name",
911
+ str(teacher_model_name),
912
+ "--output-dir",
913
+ str(features_dir),
914
+ "--dataset-path",
915
+ str(distillation_config.tokenlearn_dataset),
916
+ "--dataset-name",
917
+ str(distillation_config.tokenlearn_dataset_name),
918
+ "--dataset-split",
919
+ "train",
920
+ "--key",
921
+ str(distillation_config.tokenlearn_text_key), # Use configured text field
922
+ "--batch-size",
923
+ "1024", # Optimized batch size for A100-40G
924
+ ]
925
 
926
+ logger.info("πŸ”„ Running tokenlearn featurization...")
927
+ logger.info(
928
+ f"πŸ“Š Dataset: {distillation_config.tokenlearn_dataset} (config: {distillation_config.tokenlearn_dataset_name})"
929
+ )
930
+ logger.info(f"πŸ“ Text field: {distillation_config.tokenlearn_text_key}")
931
+ logger.info(f"Command: {' '.join(featurize_cmd)}")
932
+ print(f"\nπŸ”„ Executing: {' '.join(featurize_cmd)}\n")
933
+
934
+ result = subprocess.run( # noqa: S603
935
+ featurize_cmd,
936
+ text=True,
937
+ timeout=distillation_config.tokenlearn_timeout_featurize,
938
+ check=False,
939
+ )
940
 
941
+ if result.returncode != 0:
942
+ logger.error(f"❌ Featurization failed with return code: {result.returncode}")
943
+ logger.error("πŸ’₯ Tokenlearn featurization is required for training - cannot proceed")
944
+ msg = f"Tokenlearn featurization failed with return code: {result.returncode}"
945
+ raise RuntimeError(msg)
946
+
947
+ logger.info("βœ… Featurization completed successfully")
948
+
949
+ # Create checkpoint marker to indicate featurization is complete
950
+ featurization_complete_marker.touch()
951
+ logger.info(f"πŸ’Ύ Created featurization checkpoint: {featurization_complete_marker}")
952
+
953
+ # Generate token frequencies for post-training re-regularization
954
+ logger.info("πŸ“Š Computing token frequencies for SIF weighting...")
955
+ compute_token_frequencies_for_sif(teacher_model, features_dir)
956
+
957
+ except Exception as e:
958
+ logger.exception("πŸ’₯ Tokenlearn featurization failed")
959
+ logger.exception("πŸ’₯ Tokenlearn featurization is required for training - cannot proceed")
960
+ msg = f"Tokenlearn featurization failed: {e}"
961
+ raise RuntimeError(msg) from e
962
 
963
  # Step 3: Train using tokenlearn-train
964
  logger.info("πŸŽ“ Step 3: Training using tokenlearn...")
965
 
966
+ # Check if training already completed (checkpoint detection)
967
+ training_complete_marker = trained_dir / ".training_complete"
968
+ if training_complete_marker.exists() and verify_training_output(trained_dir):
969
+ logger.info("βœ… Found existing training checkpoint with valid model files")
970
+ logger.info(f"πŸ“‚ Using cached trained model from: {trained_dir}")
971
+
972
+ # Show available model files
973
+ model_files = []
974
+ for pattern in ["*.json", "*.safetensors", "*.bin"]:
975
+ model_files.extend(list(trained_dir.glob(pattern)))
976
+ for subdir in ["model", "model_weighted"]:
977
+ subdir_path = trained_dir / subdir
978
+ if subdir_path.exists():
979
+ model_files.extend(list(subdir_path.glob(pattern)))
980
+ logger.info(f"πŸ“ Found {len(model_files)} cached model files")
981
+ else:
982
+ if training_complete_marker.exists():
983
+ logger.warning("⚠️ Training marker exists but model files are missing - re-running training")
984
+ training_complete_marker.unlink()
985
+ logger.info("πŸ”„ No valid training checkpoint found - starting training...")
986
 
987
+ try:
988
+ train_cmd = [
989
+ "python",
990
+ "-m",
991
+ "tokenlearn.train",
992
+ "--model-name",
993
+ str(teacher_model_name),
994
+ "--data-path",
995
+ str(features_dir),
996
+ "--save-path",
997
+ str(trained_dir),
998
+ ]
999
 
1000
+ logger.info("πŸ”„ Running tokenlearn training...")
1001
+ logger.info(f"Command: {' '.join(train_cmd)}")
1002
+ print(f"\nπŸŽ“ Executing: {' '.join(train_cmd)}\n")
 
 
 
1003
 
1004
+ result = subprocess.run( # noqa: S603
1005
+ train_cmd,
1006
+ text=True,
1007
+ timeout=distillation_config.tokenlearn_timeout_train,
1008
+ check=False,
1009
+ )
1010
 
1011
+ if result.returncode != 0:
1012
+ logger.error(f"❌ Tokenlearn training failed with return code: {result.returncode}")
1013
+ logger.error("πŸ’₯ Tokenlearn training is required - cannot proceed")
1014
+ msg = f"Tokenlearn training failed with return code: {result.returncode}"
1015
+ raise RuntimeError(msg)
1016
 
1017
+ logger.info("βœ… Tokenlearn training completed successfully")
1018
+
1019
+ # Create checkpoint marker to indicate training is complete
1020
+ training_complete_marker.touch()
1021
+ logger.info(f"πŸ’Ύ Created training checkpoint: {training_complete_marker}")
1022
+
1023
+ except Exception as e:
1024
+ logger.exception("πŸ’₯ Tokenlearn training failed")
1025
+ logger.exception("πŸ’₯ Tokenlearn training is required - cannot proceed")
1026
+ msg = f"Tokenlearn training failed: {e}"
1027
+ raise RuntimeError(msg) from e
1028
 
1029
  # Step 4: Load the trained model and apply post-training re-regularization
1030
  logger.info("πŸ“¦ Step 4: Loading trained model and applying post-training re-regularization...")
 
1302
  if model in models_to_distill:
1303
  clear_model_cache(model)
1304
 
1305
+ # Clear tokenlearn checkpoints if requested (for training mode)
1306
+ # Note: Checkpoint clearing is handled at the main function level
1307
+ # Run distillation workflow
1308
  for teacher_model in models_to_distill:
1309
  result = distill_single_teacher(
1310
  teacher_model=teacher_model,
 
1502
  clear_cache: Annotated[
1503
  bool, typer.Option(help="Clear HuggingFace cache for problematic models before distillation")
1504
  ] = False,
1505
+ clear_checkpoints: Annotated[
1506
+ bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training")
1507
+ ] = False,
1508
  ) -> None:
1509
  """Unified distillation command with optional training."""
1510
  logger.info("πŸš€ Starting unified Model2Vec distillation workflow")
 
1527
  if model in models_to_distill:
1528
  clear_model_cache(model)
1529
 
1530
+ # Clear tokenlearn checkpoints if requested (for training mode)
1531
+ if clear_checkpoints and train:
1532
+ import tempfile
1533
+
1534
+ logger.info("🧹 Clearing tokenlearn checkpoints to force fresh featurization and training...")
1535
+ for teacher_model in models_to_distill:
1536
+ teacher_name = teacher_model.split("/")[-1].replace("-", "_")
1537
+
1538
+ # Construct checkpoint paths using secure temporary directory
1539
+ temp_dir = Path(tempfile.gettempdir()) / f"tokenlearn_{teacher_name}"
1540
+ features_dir = temp_dir / "features"
1541
+ trained_dir = temp_dir / "trained"
1542
+
1543
+ # Also check local paths
1544
+ local_temp = Path("temp") / f"tokenlearn_{teacher_name}"
1545
+ local_features = local_temp / "features"
1546
+ local_trained = local_temp / "trained"
1547
+
1548
+ # Clear checkpoints for all possible paths
1549
+ for feat_dir, train_dir in [(features_dir, trained_dir), (local_features, local_trained)]:
1550
+ if feat_dir.exists() or train_dir.exists():
1551
+ clear_tokenlearn_checkpoints(feat_dir, train_dir)
1552
+ logger.info(f"πŸ—‘οΈ Cleared checkpoints for {teacher_model}")
1553
+ elif clear_checkpoints and not train:
1554
+ logger.warning("⚠️ --clear-checkpoints flag is only relevant when training is enabled (--train)")
1555
+
1556
  # Run distillation workflow
1557
  if use_beam:
1558
  results = run_beam_distillation(
 
1900
  return None
1901
 
1902
 
1903
def clear_tokenlearn_checkpoints(features_dir: Path, trained_dir: Path) -> None:
    """Clear tokenlearn checkpoint markers to force re-execution of steps.

    Removes the ``.featurization_complete`` marker under *features_dir* and
    the ``.training_complete`` marker under *trained_dir* when they exist, so
    the next run re-executes featurization and training instead of reusing
    cached results. Missing markers are ignored.
    """
    featurization_marker = features_dir / ".featurization_complete"
    training_marker = trained_dir / ".training_complete"

    if featurization_marker.exists():
        # missing_ok guards against a concurrent run deleting the marker
        # between the exists() check and the unlink (TOCTOU race).
        featurization_marker.unlink(missing_ok=True)
        logger.info(f"πŸ—‘οΈ Cleared featurization checkpoint: {featurization_marker}")

    if training_marker.exists():
        training_marker.unlink(missing_ok=True)
        logger.info(f"πŸ—‘οΈ Cleared training checkpoint: {training_marker}")
1915
+
1916
+
1917
def verify_featurization_output(features_dir: Path) -> bool:
    """Verify that featurization output files actually exist.

    Returns ``True`` when *features_dir* exists and contains at least one file
    matching the expected tokenlearn output patterns
    (``*.npy``, ``*.json``, ``*.pt``, ``*.pkl``).
    """
    if not features_dir.exists():
        return False

    # Stop at the first match instead of materializing a full list per
    # pattern — feature directories can hold many shard files.
    patterns = ("*.npy", "*.json", "*.pt", "*.pkl")
    return any(next(features_dir.glob(pattern), None) is not None for pattern in patterns)
1926
+
1927
+
1928
def verify_training_output(trained_dir: Path) -> bool:
    """Verify that training output model files actually exist.

    Checks *trained_dir* itself plus the alternative ``model`` and
    ``model_weighted`` subdirectories tokenlearn may write to, and returns
    ``True`` when any of the expected model artifacts is present.
    """
    if not trained_dir.exists():
        return False

    model_files = ("config.json", "model.safetensors", "modules.json", "tokenizer.json")
    # tokenlearn may save flat in trained_dir or under a subdirectory; a
    # missing subdirectory simply yields no existing candidate paths.
    candidate_dirs = (trained_dir, trained_dir / "model", trained_dir / "model_weighted")
    return any((candidate / name).exists() for candidate in candidate_dirs for name in model_files)
1948
+
1949
+
1950
  if __name__ == "__main__":
1951
  typer.run(main)