Sarthak commited on
Commit ·
72121b3
1
Parent(s): fba41e9
feat(distiller): add option to skip post-training re-regularization
Browse files

This change introduces an option to skip the post-training re-regularization step in the tokenlearn training pipeline. This can be useful for debugging or experimentation, or when re-regularization is not desired.
- src/distiller/config.py +3 -0
- src/distiller/distill.py +26 -8
src/distiller/config.py
CHANGED
|
@@ -216,6 +216,9 @@ class DistillationConfig(BaseModel):
|
|
| 216 |
tokenlearn_timeout_featurize: int = 21600 # 6 hour timeout for featurization (dataset needs ~5 hours)
|
| 217 |
tokenlearn_timeout_train: int = 7200 # 2 hour timeout for training
|
| 218 |
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
distillation_config = DistillationConfig()
|
| 221 |
|
|
|
|
| 216 |
tokenlearn_timeout_featurize: int = 21600 # 6 hour timeout for featurization (dataset needs ~5 hours)
|
| 217 |
tokenlearn_timeout_train: int = 7200 # 2 hour timeout for training
|
| 218 |
|
| 219 |
+
# Post-training configuration
|
| 220 |
+
skip_post_training_regularization: bool = False # Skip PCA + SIF re-regularization step
|
| 221 |
+
|
| 222 |
|
| 223 |
distillation_config = DistillationConfig()
|
| 224 |
|
src/distiller/distill.py
CHANGED
|
@@ -855,6 +855,7 @@ def tokenlearn_training(
|
|
| 855 |
student_model: Any,
|
| 856 |
teacher_model: SentenceTransformer,
|
| 857 |
checkpoint_manager: BeamCheckpointManager | None = None, # noqa: ARG001
|
|
|
|
| 858 |
) -> Any:
|
| 859 |
"""
|
| 860 |
Perform tokenlearn training following the official POTION approach.
|
|
@@ -1122,13 +1123,17 @@ def tokenlearn_training(
|
|
| 1122 |
logger.info("π Loading model from tokenlearn training...")
|
| 1123 |
trained_model = StaticModel.from_pretrained(str(trained_model_path))
|
| 1124 |
|
| 1125 |
-
# Apply post-training re-regularization (POTION Step 4)
|
| 1126 |
-
|
| 1127 |
-
|
| 1128 |
-
|
| 1129 |
-
|
| 1130 |
-
|
| 1131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1132 |
|
| 1133 |
return final_model
|
| 1134 |
|
|
@@ -1286,7 +1291,12 @@ def distill_single_teacher(
|
|
| 1286 |
teacher_st_model = load_model_with_flash_attention(teacher_model, device)
|
| 1287 |
|
| 1288 |
# Perform tokenlearn training (POTION approach)
|
| 1289 |
-
final_model = tokenlearn_training(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1290 |
|
| 1291 |
# Save final model
|
| 1292 |
final_dir.mkdir(parents=True, exist_ok=True)
|
|
@@ -1634,10 +1644,18 @@ def main(
|
|
| 1634 |
clear_checkpoints: Annotated[
|
| 1635 |
bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training")
|
| 1636 |
] = False,
|
|
|
|
|
|
|
|
|
|
| 1637 |
) -> None:
|
| 1638 |
"""Unified distillation command with optional training."""
|
| 1639 |
logger.info("π Starting unified Model2Vec distillation workflow")
|
| 1640 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1641 |
logger.info(f"π Training mode: {'Tokenlearn (POTION) training' if train else 'Basic distillation only'}")
|
| 1642 |
logger.info(f"βοΈ Execution: {'Beam' if use_beam else 'Local'}")
|
| 1643 |
|
|
|
|
| 855 |
student_model: Any,
|
| 856 |
teacher_model: SentenceTransformer,
|
| 857 |
checkpoint_manager: BeamCheckpointManager | None = None, # noqa: ARG001
|
| 858 |
+
skip_post_training_regularization: bool = False,
|
| 859 |
) -> Any:
|
| 860 |
"""
|
| 861 |
Perform tokenlearn training following the official POTION approach.
|
|
|
|
| 1123 |
logger.info("π Loading model from tokenlearn training...")
|
| 1124 |
trained_model = StaticModel.from_pretrained(str(trained_model_path))
|
| 1125 |
|
| 1126 |
+
# Apply post-training re-regularization (POTION Step 4) unless skipped
|
| 1127 |
+
if skip_post_training_regularization:
|
| 1128 |
+
logger.info("βοΈ Skipping post-training re-regularization (PCA + SIF weighting) as requested")
|
| 1129 |
+
final_model = trained_model
|
| 1130 |
+
logger.info("✅ Tokenlearn training pipeline completed successfully (without re-regularization)")
|
| 1131 |
+
else:
|
| 1132 |
+
logger.info("π§ Applying post-training re-regularization (PCA + SIF weighting)...")
|
| 1133 |
+
final_model = apply_post_training_regularization(
|
| 1134 |
+
trained_model, features_dir, pca_dims=distillation_config.optimal_pca_dims
|
| 1135 |
+
)
|
| 1136 |
+
logger.info("✅ Tokenlearn training pipeline with post-training re-regularization completed successfully")
|
| 1137 |
|
| 1138 |
return final_model
|
| 1139 |
|
|
|
|
| 1291 |
teacher_st_model = load_model_with_flash_attention(teacher_model, device)
|
| 1292 |
|
| 1293 |
# Perform tokenlearn training (POTION approach)
|
| 1294 |
+
final_model = tokenlearn_training(
|
| 1295 |
+
base_model,
|
| 1296 |
+
teacher_st_model,
|
| 1297 |
+
checkpoint_mgr,
|
| 1298 |
+
skip_post_training_regularization=distillation_config.skip_post_training_regularization,
|
| 1299 |
+
)
|
| 1300 |
|
| 1301 |
# Save final model
|
| 1302 |
final_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 1644 |
clear_checkpoints: Annotated[
|
| 1645 |
bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training")
|
| 1646 |
] = False,
|
| 1647 |
+
skip_ptr: Annotated[
|
| 1648 |
+
bool, typer.Option("--skip-ptr", help="Skip post-training re-regularization (PCA + SIF weighting) step")
|
| 1649 |
+
] = False,
|
| 1650 |
) -> None:
|
| 1651 |
"""Unified distillation command with optional training."""
|
| 1652 |
logger.info("π Starting unified Model2Vec distillation workflow")
|
| 1653 |
|
| 1654 |
+
# Set post-training regularization flag in config
|
| 1655 |
+
distillation_config.skip_post_training_regularization = skip_ptr
|
| 1656 |
+
if skip_ptr and train:
|
| 1657 |
+
logger.info("βοΈ Post-training re-regularization will be skipped (PCA + SIF weighting disabled)")
|
| 1658 |
+
|
| 1659 |
logger.info(f"π Training mode: {'Tokenlearn (POTION) training' if train else 'Basic distillation only'}")
|
| 1660 |
logger.info(f"βοΈ Execution: {'Beam' if use_beam else 'Local'}")
|
| 1661 |
|