Sarthak committed
Commit 72121b3 · 1 Parent(s): fba41e9

feat(distiller): add option to skip post-training re-regularization


This change introduces an option to skip the post-training re-regularization step (PCA + SIF weighting) in the tokenlearn training pipeline. This is useful for debugging and experimentation, or when re-regularization is not desired.

src/distiller/config.py CHANGED
@@ -216,6 +216,9 @@ class DistillationConfig(BaseModel):
     tokenlearn_timeout_featurize: int = 21600  # 6 hour timeout for featurization (dataset needs ~5 hours)
     tokenlearn_timeout_train: int = 7200  # 2 hour timeout for training
 
+    # Post-training configuration
+    skip_post_training_regularization: bool = False  # Skip PCA + SIF re-regularization step
+
 
 distillation_config = DistillationConfig()
 
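The new field rides on the module-level config singleton, which the CLI mutates at startup (see the distill.py hunks below). A minimal sketch of that pattern, assuming Pydantic's default mutable models; the one-field model here is illustrative, not the full DistillationConfig:

    from pydantic import BaseModel

    class DistillationConfig(BaseModel):
        # Defaults to False so existing configs keep the old behavior
        skip_post_training_regularization: bool = False

    distillation_config = DistillationConfig()  # shared, process-global instance

    # The CLI later overrides the default for the whole run:
    distillation_config.skip_post_training_regularization = True
    assert distillation_config.skip_post_training_regularization

Because every caller reads the same instance, flipping the flag once at the CLI boundary is enough to affect each distill_single_teacher call in the run.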
 
src/distiller/distill.py CHANGED
@@ -855,6 +855,7 @@ def tokenlearn_training(
     student_model: Any,
     teacher_model: SentenceTransformer,
     checkpoint_manager: BeamCheckpointManager | None = None,  # noqa: ARG001
+    skip_post_training_regularization: bool = False,
 ) -> Any:
     """
     Perform tokenlearn training following the official POTION approach.
@@ -1122,13 +1123,17 @@ def tokenlearn_training(
     logger.info("🔄 Loading model from tokenlearn training...")
     trained_model = StaticModel.from_pretrained(str(trained_model_path))
 
-    # Apply post-training re-regularization (POTION Step 4)
-    logger.info("🔧 Applying post-training re-regularization (PCA + SIF weighting)...")
-    final_model = apply_post_training_regularization(
-        trained_model, features_dir, pca_dims=distillation_config.optimal_pca_dims
-    )
-
-    logger.info("✅ Tokenlearn training pipeline with post-training re-regularization completed successfully")
+    # Apply post-training re-regularization (POTION Step 4) unless skipped
+    if skip_post_training_regularization:
+        logger.info("⏭️ Skipping post-training re-regularization (PCA + SIF weighting) as requested")
+        final_model = trained_model
+        logger.info("✅ Tokenlearn training pipeline completed successfully (without re-regularization)")
+    else:
+        logger.info("🔧 Applying post-training re-regularization (PCA + SIF weighting)...")
+        final_model = apply_post_training_regularization(
+            trained_model, features_dir, pca_dims=distillation_config.optimal_pca_dims
+        )
+        logger.info("✅ Tokenlearn training pipeline with post-training re-regularization completed successfully")
 
     return final_model
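For orientation, the step being made optional is POTION's Step 4: re-project the trained embeddings with PCA, then apply SIF weighting so frequent tokens contribute less to sentence vectors. Below is a minimal sketch of the technique, not the repo's apply_post_training_regularization; the function name, the SIF constant a = 1e-3, and the token-count source are assumptions:

    import numpy as np
    from sklearn.decomposition import PCA

    def reregularize_sketch(
        embeddings: np.ndarray,    # (vocab_size, dim) trained token embeddings
        token_counts: np.ndarray,  # (vocab_size,) corpus counts, e.g. from the featurization output
        pca_dims: int,
        sif_a: float = 1e-3,       # assumed smoothing constant (Arora et al., 2017)
    ) -> np.ndarray:
        # Step 4a: PCA re-projection (centers the space and drops noisy directions)
        reduced = PCA(n_components=pca_dims).fit_transform(embeddings)
        # Step 4b: SIF weighting, w(t) = a / (a + p(t)), down-weighting frequent tokens
        probs = token_counts / token_counts.sum()
        return reduced * (sif_a / (sif_a + probs))[:, None]

With skip_post_training_regularization=True, the trained_model is instead returned as-is, keeping its original dimensionality and frequency balance.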
 
@@ -1286,7 +1291,12 @@ def distill_single_teacher(
     teacher_st_model = load_model_with_flash_attention(teacher_model, device)
 
     # Perform tokenlearn training (POTION approach)
-    final_model = tokenlearn_training(base_model, teacher_st_model, checkpoint_mgr)
+    final_model = tokenlearn_training(
+        base_model,
+        teacher_st_model,
+        checkpoint_mgr,
+        skip_post_training_regularization=distillation_config.skip_post_training_regularization,
+    )
 
     # Save final model
     final_dir.mkdir(parents=True, exist_ok=True)
@@ -1634,10 +1644,18 @@ def main(
     clear_checkpoints: Annotated[
         bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training")
     ] = False,
+    skip_ptr: Annotated[
+        bool, typer.Option("--skip-ptr", help="Skip post-training re-regularization (PCA + SIF weighting) step")
+    ] = False,
 ) -> None:
     """Unified distillation command with optional training."""
     logger.info("🚀 Starting unified Model2Vec distillation workflow")
 
+    # Set post-training regularization flag in config
+    distillation_config.skip_post_training_regularization = skip_ptr
+    if skip_ptr and train:
+        logger.info("⏭️ Post-training re-regularization will be skipped (PCA + SIF weighting disabled)")
+
     logger.info(f"🎓 Training mode: {'Tokenlearn (POTION) training' if train else 'Basic distillation only'}")
     logger.info(f"☁️ Execution: {'Beam' if use_beam else 'Local'}")
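The Annotated/typer.Option pattern above is standard typer; here is a self-contained sketch of how the flag surfaces on the command line (the app and command names are illustrative, only --skip-ptr comes from this change):

    from typing import Annotated

    import typer

    app = typer.Typer()

    @app.command()
    def main(
        skip_ptr: Annotated[
            bool, typer.Option("--skip-ptr", help="Skip post-training re-regularization")
        ] = False,
    ) -> None:
        # In the real command this sets distillation_config.skip_post_training_regularization
        typer.echo(f"skip_ptr={skip_ptr}")

    if __name__ == "__main__":
        app()  # e.g. `python demo.py --skip-ptr` prints skip_ptr=True

Since the option is a bool defaulting to False, omitting --skip-ptr preserves the current pipeline behavior.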
 
 
1661