Sarthak commited on
Commit ·
729d700
1
Parent(s): 93151b9
refactor(distiller): improve beam distillation and tokenlearn integration
Browse filesThis commit introduces separate Beam functions for distillation and training, enabling more modular and controllable workflows. It also enhances tokenlearn integration by using persistent directories for caching and checkpointing, and improves error handling for training failures.
The changes also include validation of the model to check vocab and embedding sizes match, which can highlight issues in downstream usage.
- patches/tokenlearn.patch +25 -0
- src/distiller/distill.py +381 -256
patches/tokenlearn.patch
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--- a/tokenlearn/pretrain.py
|
| 2 |
+
+++ b/tokenlearn/pretrain.py
|
| 3 |
+
@@ -38,7 +38,10 @@ class FinetunableStaticModel(nn.Module):
|
| 4 |
+
"""Run the model using input IDs."""
|
| 5 |
+
input_ids = input_ids.view(-1)
|
| 6 |
+
input_ids = input_ids[input_ids != self.pad_token_id]
|
| 7 |
+
- w = self.w[input_ids]
|
| 8 |
+
+ # Fix for index out of bounds issue
|
| 9 |
+
+ # Clamp input_ids to valid range to prevent IndexError during training
|
| 10 |
+
+ valid_input_ids = torch.clamp(input_ids, 0, self.w.shape[0] - 1)
|
| 11 |
+
+ w = self.w[valid_input_ids]
|
| 12 |
+
return self.sub_forward(w)
|
| 13 |
+
|
| 14 |
+
def forward(self, x):
|
| 15 |
+
@@ -46,7 +49,10 @@ class FinetunableStaticModel(nn.Module):
|
| 16 |
+
# Add a small epsilon to avoid division by zero
|
| 17 |
+
length = zeros.sum(1) + 1e-16
|
| 18 |
+
- embedded = self.embeddings(input_ids)
|
| 19 |
+
+ # Fix for embedding index out of bounds issue
|
| 20 |
+
+ # Clamp input_ids to valid embedding range
|
| 21 |
+
+ valid_input_ids = torch.clamp(input_ids, 0, self.embeddings.num_embeddings - 1)
|
| 22 |
+
+ embedded = self.embeddings(valid_input_ids)
|
| 23 |
+
# Zero out the padding
|
| 24 |
+
embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
|
| 25 |
+
# Simulate actual mean
|
src/distiller/distill.py
CHANGED
|
@@ -49,6 +49,7 @@ from .config import (
|
|
| 49 |
directories,
|
| 50 |
distillation_config,
|
| 51 |
get_distillation_function_kwargs,
|
|
|
|
| 52 |
get_volume_config,
|
| 53 |
languages_config,
|
| 54 |
)
|
|
@@ -358,6 +359,21 @@ def simple_distillation(
|
|
| 358 |
|
| 359 |
logger.info("✅ Core distillation completed successfully")
|
| 360 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
# Save the model
|
| 362 |
model.save_pretrained(str(output_path))
|
| 363 |
logger.info(f"💾 Model saved to {output_path}")
|
|
@@ -772,7 +788,11 @@ def apply_post_training_regularization(
|
|
| 772 |
logger.info(f"🔄 Applying PCA with {pca_dims} dimensions...")
|
| 773 |
|
| 774 |
# Get current embeddings
|
| 775 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 776 |
original_shape = embeddings.shape
|
| 777 |
logger.info(f"Original embedding shape: {original_shape}")
|
| 778 |
|
|
@@ -846,229 +866,288 @@ def tokenlearn_training(
|
|
| 846 |
4. Post-training re-regularization (PCA + SIF weighting)
|
| 847 |
"""
|
| 848 |
import subprocess
|
| 849 |
-
import tempfile
|
| 850 |
from pathlib import Path
|
| 851 |
|
| 852 |
logger.info("🧪 Starting tokenlearn training (POTION approach)...")
|
| 853 |
|
| 854 |
-
# Create
|
| 855 |
-
|
| 856 |
-
|
| 857 |
-
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
|
| 862 |
-
|
| 863 |
-
|
| 864 |
-
|
| 865 |
-
|
| 866 |
-
|
| 867 |
-
|
| 868 |
-
|
| 869 |
-
|
| 870 |
-
|
| 871 |
-
|
| 872 |
-
|
| 873 |
-
|
| 874 |
-
|
| 875 |
-
|
| 876 |
-
|
| 877 |
-
|
| 878 |
-
|
| 879 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 880 |
|
| 881 |
logger.info(f"📊 Using teacher model: {teacher_model_name}")
|
| 882 |
|
| 883 |
-
|
| 884 |
-
|
| 885 |
-
|
| 886 |
-
|
| 887 |
-
|
| 888 |
-
|
| 889 |
-
|
| 890 |
-
|
| 891 |
-
|
| 892 |
-
|
| 893 |
-
|
| 894 |
-
|
| 895 |
-
|
| 896 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 897 |
|
| 898 |
-
|
| 899 |
-
|
| 900 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 901 |
|
| 902 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 903 |
|
| 904 |
-
|
| 905 |
-
# Use configured dataset for code specialization
|
| 906 |
-
featurize_cmd = [
|
| 907 |
-
"python",
|
| 908 |
-
"-m",
|
| 909 |
-
"tokenlearn.featurize",
|
| 910 |
-
"--model-name",
|
| 911 |
-
str(teacher_model_name),
|
| 912 |
-
"--output-dir",
|
| 913 |
-
str(features_dir),
|
| 914 |
-
"--dataset-path",
|
| 915 |
-
str(distillation_config.tokenlearn_dataset),
|
| 916 |
-
"--dataset-name",
|
| 917 |
-
str(distillation_config.tokenlearn_dataset_name),
|
| 918 |
-
"--dataset-split",
|
| 919 |
-
"train",
|
| 920 |
-
"--key",
|
| 921 |
-
str(distillation_config.tokenlearn_text_key), # Use configured text field
|
| 922 |
-
"--batch-size",
|
| 923 |
-
"1024", # Optimized batch size for A100-40G
|
| 924 |
-
]
|
| 925 |
-
|
| 926 |
-
logger.info("🔄 Running tokenlearn featurization...")
|
| 927 |
-
logger.info(
|
| 928 |
-
f"📊 Dataset: {distillation_config.tokenlearn_dataset} (config: {distillation_config.tokenlearn_dataset_name})"
|
| 929 |
-
)
|
| 930 |
-
logger.info(f"📝 Text field: {distillation_config.tokenlearn_text_key}")
|
| 931 |
-
logger.info(f"Command: {' '.join(featurize_cmd)}")
|
| 932 |
-
print(f"\n🔄 Executing: {' '.join(featurize_cmd)}\n")
|
| 933 |
-
|
| 934 |
-
result = subprocess.run( # noqa: S603
|
| 935 |
-
featurize_cmd,
|
| 936 |
-
text=True,
|
| 937 |
-
timeout=distillation_config.tokenlearn_timeout_featurize,
|
| 938 |
-
check=False,
|
| 939 |
-
)
|
| 940 |
|
| 941 |
-
|
| 942 |
-
|
| 943 |
-
|
| 944 |
-
msg = f"Tokenlearn featurization failed with return code: {result.returncode}"
|
| 945 |
-
raise RuntimeError(msg)
|
| 946 |
|
| 947 |
-
|
|
|
|
|
|
|
| 948 |
|
| 949 |
-
|
| 950 |
-
|
| 951 |
-
|
|
|
|
|
|
|
| 952 |
|
| 953 |
-
|
| 954 |
-
|
| 955 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 956 |
|
| 957 |
-
|
| 958 |
-
|
| 959 |
-
|
| 960 |
-
|
| 961 |
-
|
| 962 |
-
|
| 963 |
-
|
| 964 |
-
|
| 965 |
-
|
| 966 |
-
|
| 967 |
-
|
| 968 |
-
|
| 969 |
-
logger.info("✅ Found existing training checkpoint with valid model files")
|
| 970 |
-
logger.info(f"📂 Using cached trained model from: {trained_dir}")
|
| 971 |
-
|
| 972 |
-
# Show available model files
|
| 973 |
-
model_files = []
|
| 974 |
-
for pattern in ["*.json", "*.safetensors", "*.bin"]:
|
| 975 |
-
model_files.extend(list(trained_dir.glob(pattern)))
|
| 976 |
-
for subdir in ["model", "model_weighted"]:
|
| 977 |
-
subdir_path = trained_dir / subdir
|
| 978 |
-
if subdir_path.exists():
|
| 979 |
-
model_files.extend(list(subdir_path.glob(pattern)))
|
| 980 |
-
logger.info(f"📁 Found {len(model_files)} cached model files")
|
| 981 |
-
else:
|
| 982 |
-
if training_complete_marker.exists():
|
| 983 |
-
logger.warning("⚠️ Training marker exists but model files are missing - re-running training")
|
| 984 |
-
training_complete_marker.unlink()
|
| 985 |
-
logger.info("🔄 No valid training checkpoint found - starting training...")
|
| 986 |
|
| 987 |
-
|
| 988 |
-
|
| 989 |
-
|
| 990 |
-
"-m",
|
| 991 |
-
"tokenlearn.train",
|
| 992 |
-
"--model-name",
|
| 993 |
-
str(teacher_model_name),
|
| 994 |
-
"--data-path",
|
| 995 |
-
str(features_dir),
|
| 996 |
-
"--save-path",
|
| 997 |
-
str(trained_dir),
|
| 998 |
-
]
|
| 999 |
-
|
| 1000 |
-
logger.info("���� Running tokenlearn training...")
|
| 1001 |
-
logger.info(f"Command: {' '.join(train_cmd)}")
|
| 1002 |
-
print(f"\n🎓 Executing: {' '.join(train_cmd)}\n")
|
| 1003 |
-
|
| 1004 |
-
result = subprocess.run( # noqa: S603
|
| 1005 |
-
train_cmd,
|
| 1006 |
-
text=True,
|
| 1007 |
-
timeout=distillation_config.tokenlearn_timeout_train,
|
| 1008 |
-
check=False,
|
| 1009 |
-
)
|
| 1010 |
|
| 1011 |
-
|
| 1012 |
-
|
| 1013 |
-
|
| 1014 |
-
|
| 1015 |
-
|
|
|
|
|
|
|
| 1016 |
|
| 1017 |
-
|
|
|
|
| 1018 |
|
| 1019 |
-
#
|
| 1020 |
-
|
| 1021 |
-
|
|
|
|
|
|
|
| 1022 |
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
|
| 1028 |
|
| 1029 |
-
|
| 1030 |
-
|
|
|
|
| 1031 |
|
| 1032 |
-
|
| 1033 |
-
|
| 1034 |
-
|
| 1035 |
-
# Load the trained model from tokenlearn
|
| 1036 |
-
trained_model_path = trained_dir / "model"
|
| 1037 |
-
if not trained_model_path.exists():
|
| 1038 |
-
# Try alternative paths
|
| 1039 |
-
possible_paths = [
|
| 1040 |
-
trained_dir / "model_weighted",
|
| 1041 |
-
trained_dir,
|
| 1042 |
-
]
|
| 1043 |
-
|
| 1044 |
-
for path in possible_paths:
|
| 1045 |
-
if path.exists() and any(path.glob("*.json")):
|
| 1046 |
-
trained_model_path = path
|
| 1047 |
-
break
|
| 1048 |
-
else:
|
| 1049 |
-
logger.error(f"❌ Could not find trained model in {trained_dir}")
|
| 1050 |
-
msg = f"Tokenlearn training failed - no model found in {trained_dir}"
|
| 1051 |
raise RuntimeError(msg)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1052 |
|
| 1053 |
-
#
|
| 1054 |
-
|
| 1055 |
-
|
| 1056 |
|
| 1057 |
-
|
| 1058 |
-
logger.
|
| 1059 |
-
|
| 1060 |
-
|
| 1061 |
-
)
|
| 1062 |
|
| 1063 |
-
|
|
|
|
| 1064 |
|
| 1065 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1066 |
|
| 1067 |
-
|
| 1068 |
-
|
| 1069 |
-
|
| 1070 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1071 |
raise RuntimeError(msg) from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1072 |
|
| 1073 |
|
| 1074 |
def distill_single_teacher(
|
|
@@ -1118,7 +1197,6 @@ def distill_single_teacher(
|
|
| 1118 |
|
| 1119 |
# Initialize Beam utilities if requested
|
| 1120 |
checkpoint_mgr = None
|
| 1121 |
-
model_mgr = None
|
| 1122 |
if use_beam_utilities:
|
| 1123 |
try:
|
| 1124 |
_, checkpoint_mgr, model_mgr, _ = create_beam_utilities(VOLUME_CONFIG.name, VOLUME_CONFIG.mount_path)
|
|
@@ -1197,44 +1275,65 @@ def distill_single_teacher(
|
|
| 1197 |
|
| 1198 |
existing_base = str(base_dir)
|
| 1199 |
|
| 1200 |
-
|
| 1201 |
-
|
| 1202 |
-
|
| 1203 |
-
|
| 1204 |
-
|
| 1205 |
-
# Load teacher model for training
|
| 1206 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 1207 |
-
teacher_st_model = load_model_with_flash_attention(teacher_model, device)
|
| 1208 |
-
|
| 1209 |
-
# Perform tokenlearn training (POTION approach)
|
| 1210 |
-
final_model = tokenlearn_training(base_model, teacher_st_model, checkpoint_mgr)
|
| 1211 |
-
|
| 1212 |
-
# Save final model
|
| 1213 |
-
final_dir.mkdir(parents=True, exist_ok=True)
|
| 1214 |
-
final_model.save_pretrained(str(final_dir))
|
| 1215 |
-
|
| 1216 |
-
# Sync final model and training checkpoints to Beam
|
| 1217 |
-
if use_beam_utilities:
|
| 1218 |
-
sync_model_to_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities)
|
| 1219 |
-
if checkpoint_mgr:
|
| 1220 |
-
sync_checkpoints_to_beam(VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints)
|
| 1221 |
|
| 1222 |
-
|
| 1223 |
-
|
| 1224 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1225 |
|
| 1226 |
-
|
| 1227 |
-
|
| 1228 |
-
|
| 1229 |
-
|
| 1230 |
-
|
| 1231 |
-
|
| 1232 |
-
|
| 1233 |
-
|
| 1234 |
-
|
| 1235 |
-
|
| 1236 |
|
| 1237 |
-
|
| 1238 |
|
| 1239 |
return {
|
| 1240 |
"teacher_model": teacher_model,
|
|
@@ -1318,6 +1417,9 @@ def run_local_distillation(
|
|
| 1318 |
|
| 1319 |
if result["status"] == "success" or result["status"].startswith("skipped"):
|
| 1320 |
successful_models.append(teacher_name)
|
|
|
|
|
|
|
|
|
|
| 1321 |
|
| 1322 |
# Summary
|
| 1323 |
logger.info("\n🏆 DISTILLATION WORKFLOW COMPLETE!")
|
|
@@ -1349,16 +1451,13 @@ def run_local_distillation(
|
|
| 1349 |
return results_summary
|
| 1350 |
|
| 1351 |
|
| 1352 |
-
|
| 1353 |
-
def _beam_distill_models(
|
| 1354 |
teacher_models: list[str] | None = None,
|
| 1355 |
enable_training: bool = False,
|
| 1356 |
pca_dims: int | None = None,
|
| 1357 |
clear_cache: bool = False,
|
| 1358 |
) -> dict[str, Any]:
|
| 1359 |
-
"""
|
| 1360 |
-
logger.info("☁️ Running distillation on Beam")
|
| 1361 |
-
|
| 1362 |
# Apply patches
|
| 1363 |
patch_success = apply_local_patches()
|
| 1364 |
if patch_success:
|
|
@@ -1404,6 +1503,9 @@ def _beam_distill_models(
|
|
| 1404 |
|
| 1405 |
if result["status"] == "success" or result["status"].startswith("skipped"):
|
| 1406 |
successful_models.append(teacher_name)
|
|
|
|
|
|
|
|
|
|
| 1407 |
|
| 1408 |
# Summary
|
| 1409 |
logger.info("\n🏆 BEAM DISTILLATION WORKFLOW COMPLETE!")
|
|
@@ -1429,6 +1531,30 @@ def _beam_distill_models(
|
|
| 1429 |
return results_summary
|
| 1430 |
|
| 1431 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1432 |
def run_beam_distillation(
|
| 1433 |
teacher_models: list[str] | None = None,
|
| 1434 |
enable_training: bool = False,
|
|
@@ -1439,8 +1565,11 @@ def run_beam_distillation(
|
|
| 1439 |
logger.info("☁️ Running distillation on Beam with local sync")
|
| 1440 |
|
| 1441 |
try:
|
|
|
|
|
|
|
|
|
|
| 1442 |
# Run distillation on Beam
|
| 1443 |
-
results =
|
| 1444 |
|
| 1445 |
# Check if Beam execution was successful
|
| 1446 |
if not results:
|
|
@@ -1529,27 +1658,23 @@ def main(
|
|
| 1529 |
|
| 1530 |
# Clear tokenlearn checkpoints if requested (for training mode)
|
| 1531 |
if clear_checkpoints and train:
|
| 1532 |
-
import tempfile
|
| 1533 |
-
|
| 1534 |
logger.info("🧹 Clearing tokenlearn checkpoints to force fresh featurization and training...")
|
| 1535 |
for teacher_model in models_to_distill:
|
| 1536 |
-
|
| 1537 |
-
|
| 1538 |
-
#
|
| 1539 |
-
|
| 1540 |
-
|
| 1541 |
-
|
| 1542 |
-
|
| 1543 |
-
|
| 1544 |
-
|
| 1545 |
-
|
| 1546 |
-
|
| 1547 |
-
|
| 1548 |
-
|
| 1549 |
-
|
| 1550 |
-
|
| 1551 |
-
clear_tokenlearn_checkpoints(feat_dir, train_dir)
|
| 1552 |
-
logger.info(f"🗑️ Cleared checkpoints for {teacher_model}")
|
| 1553 |
elif clear_checkpoints and not train:
|
| 1554 |
logger.warning("⚠️ --clear-checkpoints flag is only relevant when training is enabled (--train)")
|
| 1555 |
|
|
|
|
| 49 |
directories,
|
| 50 |
distillation_config,
|
| 51 |
get_distillation_function_kwargs,
|
| 52 |
+
get_training_function_kwargs,
|
| 53 |
get_volume_config,
|
| 54 |
languages_config,
|
| 55 |
)
|
|
|
|
| 359 |
|
| 360 |
logger.info("✅ Core distillation completed successfully")
|
| 361 |
|
| 362 |
+
# Validate model before saving
|
| 363 |
+
if hasattr(model, "tokenizer") and hasattr(model, "embedding"):
|
| 364 |
+
vocab_size = len(model.tokenizer.get_vocab())
|
| 365 |
+
embedding_size = model.embedding.shape[0]
|
| 366 |
+
|
| 367 |
+
logger.info("📊 Model validation:")
|
| 368 |
+
logger.info(f" - Vocabulary size: {vocab_size}")
|
| 369 |
+
logger.info(f" - Embedding matrix size: {embedding_size}")
|
| 370 |
+
|
| 371 |
+
if vocab_size != embedding_size:
|
| 372 |
+
logger.warning(f"⚠️ Vocabulary size mismatch: vocab={vocab_size}, embeddings={embedding_size}")
|
| 373 |
+
logger.warning("⚠️ This may cause issues in downstream usage")
|
| 374 |
+
else:
|
| 375 |
+
logger.info("✅ Vocabulary and embedding sizes match")
|
| 376 |
+
|
| 377 |
# Save the model
|
| 378 |
model.save_pretrained(str(output_path))
|
| 379 |
logger.info(f"💾 Model saved to {output_path}")
|
|
|
|
| 788 |
logger.info(f"🔄 Applying PCA with {pca_dims} dimensions...")
|
| 789 |
|
| 790 |
# Get current embeddings
|
| 791 |
+
# Handle both torch tensors and numpy arrays
|
| 792 |
+
if hasattr(model.embedding, "cpu"):
|
| 793 |
+
embeddings = model.embedding.cpu().numpy().astype(np.float64)
|
| 794 |
+
else:
|
| 795 |
+
embeddings = model.embedding.astype(np.float64)
|
| 796 |
original_shape = embeddings.shape
|
| 797 |
logger.info(f"Original embedding shape: {original_shape}")
|
| 798 |
|
|
|
|
| 866 |
4. Post-training re-regularization (PCA + SIF weighting)
|
| 867 |
"""
|
| 868 |
import subprocess
|
|
|
|
| 869 |
from pathlib import Path
|
| 870 |
|
| 871 |
logger.info("🧪 Starting tokenlearn training (POTION approach)...")
|
| 872 |
|
| 873 |
+
# Create persistent directories for tokenlearn workflow (for checkpoint preservation)
|
| 874 |
+
teacher_model_name = getattr(teacher_model, "model_name", None)
|
| 875 |
+
if not teacher_model_name and hasattr(teacher_model, "_modules") and len(teacher_model._modules) > 0: # noqa: SLF001
|
| 876 |
+
# Try to extract from the first module if it's a SentenceTransformer
|
| 877 |
+
first_module = next(iter(teacher_model._modules.values())) # noqa: SLF001
|
| 878 |
+
if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"):
|
| 879 |
+
teacher_model_name = first_module.auto_model.name_or_path
|
| 880 |
+
|
| 881 |
+
if not teacher_model_name:
|
| 882 |
+
teacher_model_name = "unknown_teacher"
|
| 883 |
+
|
| 884 |
+
# Use persistent directory for tokenlearn checkpoints
|
| 885 |
+
teacher_slug = teacher_model_name.replace("/", "_").replace("-", "_")
|
| 886 |
+
persistent_tokenlearn_dir = Path(directories.base).parent / "tokenlearn_cache" / teacher_slug
|
| 887 |
+
|
| 888 |
+
features_dir = persistent_tokenlearn_dir / "features"
|
| 889 |
+
model_dir = persistent_tokenlearn_dir / "base_model"
|
| 890 |
+
trained_dir = persistent_tokenlearn_dir / "trained_model"
|
| 891 |
+
|
| 892 |
+
features_dir.mkdir(parents=True, exist_ok=True)
|
| 893 |
+
model_dir.mkdir(parents=True, exist_ok=True)
|
| 894 |
+
trained_dir.mkdir(parents=True, exist_ok=True)
|
| 895 |
+
|
| 896 |
+
logger.info(f"📁 Using persistent tokenlearn directory: {persistent_tokenlearn_dir}")
|
| 897 |
+
|
| 898 |
+
# Save the base distilled model for tokenlearn
|
| 899 |
+
student_model.save_pretrained(str(model_dir))
|
| 900 |
+
logger.info(f"💾 Saved base model to {model_dir}")
|
| 901 |
+
|
| 902 |
+
# Step 2: Create features using sentence transformer
|
| 903 |
+
logger.info("🔍 Step 2: Creating features using sentence transformer...")
|
| 904 |
+
|
| 905 |
+
# Get teacher model name/path for tokenlearn
|
| 906 |
+
teacher_model_name = getattr(teacher_model, "model_name", None)
|
| 907 |
+
if not teacher_model_name and hasattr(teacher_model, "_modules") and len(teacher_model._modules) > 0: # noqa: SLF001
|
| 908 |
+
# Try to extract from the first module if it's a SentenceTransformer
|
| 909 |
+
# _modules is a dict-like container, get the first module by iterating
|
| 910 |
+
first_module = next(iter(teacher_model._modules.values())) # noqa: SLF001
|
| 911 |
+
if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"):
|
| 912 |
+
teacher_model_name = first_module.auto_model.name_or_path
|
| 913 |
+
|
| 914 |
+
logger.info(f"📊 Using teacher model: {teacher_model_name}")
|
| 915 |
+
|
| 916 |
+
# Check if featurization already completed (checkpoint detection)
|
| 917 |
+
featurization_complete_marker = features_dir / ".featurization_complete"
|
| 918 |
+
if featurization_complete_marker.exists() and verify_featurization_output(features_dir):
|
| 919 |
+
logger.info("✅ Found existing featurization checkpoint with valid output files")
|
| 920 |
+
logger.info(f"📂 Using cached features from: {features_dir}")
|
| 921 |
+
|
| 922 |
+
# Verify marker is still valid
|
| 923 |
+
output_files = list(features_dir.glob("*.npy")) + list(features_dir.glob("*.json"))
|
| 924 |
+
logger.info(f"📁 Found {len(output_files)} cached feature files")
|
| 925 |
+
else:
|
| 926 |
+
if featurization_complete_marker.exists():
|
| 927 |
+
logger.warning("⚠️ Featurization marker exists but output files are missing - re-running featurization")
|
| 928 |
+
featurization_complete_marker.unlink()
|
| 929 |
+
logger.info("🔄 No valid featurization checkpoint found - starting featurization...")
|
| 930 |
+
|
| 931 |
+
if not teacher_model_name:
|
| 932 |
+
logger.warning("⚠️ Could not determine teacher model name, using fallback")
|
| 933 |
+
teacher_model_name = "BAAI/bge-base-en-v1.5" # Fallback to a common model
|
| 934 |
|
| 935 |
logger.info(f"📊 Using teacher model: {teacher_model_name}")
|
| 936 |
|
| 937 |
+
try:
|
| 938 |
+
# Use configured dataset for code specialization
|
| 939 |
+
featurize_cmd = [
|
| 940 |
+
"python",
|
| 941 |
+
"-m",
|
| 942 |
+
"tokenlearn.featurize",
|
| 943 |
+
"--model-name",
|
| 944 |
+
str(teacher_model_name),
|
| 945 |
+
"--output-dir",
|
| 946 |
+
str(features_dir),
|
| 947 |
+
"--dataset-path",
|
| 948 |
+
str(distillation_config.tokenlearn_dataset),
|
| 949 |
+
"--dataset-name",
|
| 950 |
+
str(distillation_config.tokenlearn_dataset_name),
|
| 951 |
+
"--dataset-split",
|
| 952 |
+
"train",
|
| 953 |
+
"--key",
|
| 954 |
+
str(distillation_config.tokenlearn_text_key), # Use configured text field
|
| 955 |
+
"--batch-size",
|
| 956 |
+
"1024", # Optimized batch size for A100-40G
|
| 957 |
+
]
|
| 958 |
|
| 959 |
+
logger.info("🔄 Running tokenlearn featurization...")
|
| 960 |
+
logger.info(
|
| 961 |
+
f"📊 Dataset: {distillation_config.tokenlearn_dataset} (config: {distillation_config.tokenlearn_dataset_name})"
|
| 962 |
+
)
|
| 963 |
+
logger.info(f"📝 Text field: {distillation_config.tokenlearn_text_key}")
|
| 964 |
+
logger.info(f"Command: {' '.join(featurize_cmd)}")
|
| 965 |
+
print(f"\n🔄 Executing: {' '.join(featurize_cmd)}\n")
|
| 966 |
+
|
| 967 |
+
result = subprocess.run( # noqa: S603
|
| 968 |
+
featurize_cmd,
|
| 969 |
+
text=True,
|
| 970 |
+
timeout=distillation_config.tokenlearn_timeout_featurize,
|
| 971 |
+
check=False,
|
| 972 |
+
)
|
| 973 |
|
| 974 |
+
if result.returncode != 0:
|
| 975 |
+
logger.error(f"❌ Featurization failed with return code: {result.returncode}")
|
| 976 |
+
logger.error("💥 Tokenlearn featurization is required for training - cannot proceed")
|
| 977 |
+
msg = f"Tokenlearn featurization failed with return code: {result.returncode}"
|
| 978 |
+
raise RuntimeError(msg)
|
| 979 |
|
| 980 |
+
logger.info("✅ Featurization completed successfully")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 981 |
|
| 982 |
+
# Create checkpoint marker to indicate featurization is complete
|
| 983 |
+
featurization_complete_marker.touch()
|
| 984 |
+
logger.info(f"💾 Created featurization checkpoint: {featurization_complete_marker}")
|
|
|
|
|
|
|
| 985 |
|
| 986 |
+
# Generate token frequencies for post-training re-regularization
|
| 987 |
+
logger.info("📊 Computing token frequencies for SIF weighting...")
|
| 988 |
+
compute_token_frequencies_for_sif(teacher_model, features_dir)
|
| 989 |
|
| 990 |
+
except Exception as e:
|
| 991 |
+
logger.exception("💥 Tokenlearn featurization failed")
|
| 992 |
+
logger.exception("💥 Tokenlearn featurization is required for training - cannot proceed")
|
| 993 |
+
msg = f"Tokenlearn featurization failed: {e}"
|
| 994 |
+
raise RuntimeError(msg) from e
|
| 995 |
|
| 996 |
+
# Step 3: Train using tokenlearn-train
|
| 997 |
+
logger.info("🎓 Step 3: Training using tokenlearn...")
|
| 998 |
+
|
| 999 |
+
# Check if training already completed (checkpoint detection)
|
| 1000 |
+
training_complete_marker = trained_dir / ".training_complete"
|
| 1001 |
+
training_fallback_marker = trained_dir / ".training_fallback"
|
| 1002 |
+
|
| 1003 |
+
if training_complete_marker.exists() and verify_training_output(trained_dir):
|
| 1004 |
+
logger.info("✅ Found existing training checkpoint with valid model files")
|
| 1005 |
+
logger.info(f"📂 Using cached trained model from: {trained_dir}")
|
| 1006 |
+
|
| 1007 |
+
# Show available model files
|
| 1008 |
+
model_files = []
|
| 1009 |
+
for pattern in ["*.json", "*.safetensors", "*.bin"]:
|
| 1010 |
+
model_files.extend(list(trained_dir.glob(pattern)))
|
| 1011 |
+
for subdir in ["model", "model_weighted"]:
|
| 1012 |
+
subdir_path = trained_dir / subdir
|
| 1013 |
+
if subdir_path.exists():
|
| 1014 |
+
model_files.extend(list(subdir_path.glob(pattern)))
|
| 1015 |
+
logger.info(f"📁 Found {len(model_files)} cached model files")
|
| 1016 |
+
elif training_fallback_marker.exists():
|
| 1017 |
+
logger.warning("⚠️ Training fallback marker found - tokenlearn failed previously")
|
| 1018 |
+
logger.info("🔄 Proceeding with fallback to base model (simple distillation)")
|
| 1019 |
+
# Skip training and proceed to model loading (will fallback to base model)
|
| 1020 |
+
else:
|
| 1021 |
+
if training_complete_marker.exists():
|
| 1022 |
+
logger.warning("⚠️ Training marker exists but model files are missing - re-running training")
|
| 1023 |
+
training_complete_marker.unlink()
|
| 1024 |
+
logger.info("🔄 No valid training checkpoint found - starting training...")
|
| 1025 |
|
| 1026 |
+
try:
|
| 1027 |
+
train_cmd = [
|
| 1028 |
+
"python",
|
| 1029 |
+
"-m",
|
| 1030 |
+
"tokenlearn.train",
|
| 1031 |
+
"--model-name",
|
| 1032 |
+
str(teacher_model_name),
|
| 1033 |
+
"--data-path",
|
| 1034 |
+
str(features_dir),
|
| 1035 |
+
"--save-path",
|
| 1036 |
+
str(trained_dir),
|
| 1037 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1038 |
|
| 1039 |
+
logger.info("🔄 Running tokenlearn training...")
|
| 1040 |
+
logger.info(f"Command: {' '.join(train_cmd)}")
|
| 1041 |
+
print(f"\n🎓 Executing: {' '.join(train_cmd)}\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1042 |
|
| 1043 |
+
result = subprocess.run( # noqa: S603
|
| 1044 |
+
train_cmd,
|
| 1045 |
+
text=True,
|
| 1046 |
+
capture_output=True, # Capture stdout and stderr
|
| 1047 |
+
timeout=distillation_config.tokenlearn_timeout_train,
|
| 1048 |
+
check=False,
|
| 1049 |
+
)
|
| 1050 |
|
| 1051 |
+
if result.returncode != 0:
|
| 1052 |
+
logger.error(f"❌ Tokenlearn training failed with return code: {result.returncode}")
|
| 1053 |
|
| 1054 |
+
# Log the actual error output for debugging
|
| 1055 |
+
if result.stderr:
|
| 1056 |
+
logger.error(f"stderr: {result.stderr}")
|
| 1057 |
+
if result.stdout:
|
| 1058 |
+
logger.info(f"stdout: {result.stdout}")
|
| 1059 |
|
| 1060 |
+
# Check if it's the token-vector mismatch issue
|
| 1061 |
+
error_output = str(result.stderr) + str(result.stdout)
|
| 1062 |
+
if "Number of tokens" in error_output and "does not match number of vectors" in error_output:
|
| 1063 |
+
logger.error("🔧 Token-vector mismatch detected in tokenlearn")
|
| 1064 |
+
logger.error("💥 This is a known issue with tokenlearn/Model2Vec integration")
|
| 1065 |
|
| 1066 |
+
# Create training marker to indicate we tried but failed
|
| 1067 |
+
training_fallback_marker = trained_dir / ".training_fallback"
|
| 1068 |
+
training_fallback_marker.touch()
|
| 1069 |
|
| 1070 |
+
logger.error("❌ Tokenlearn training failed due to token-vector mismatch")
|
| 1071 |
+
msg = f"Tokenlearn training failed with token-vector mismatch: {error_output}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1072 |
raise RuntimeError(msg)
|
| 1073 |
+
logger.error("💥 Tokenlearn training failed with different error")
|
| 1074 |
+
msg = f"Tokenlearn training failed with return code: {result.returncode}"
|
| 1075 |
+
raise RuntimeError(msg)
|
| 1076 |
+
logger.info("✅ Tokenlearn training completed successfully")
|
| 1077 |
|
| 1078 |
+
# Create checkpoint marker to indicate training is complete
|
| 1079 |
+
training_complete_marker.touch()
|
| 1080 |
+
logger.info(f"💾 Created training checkpoint: {training_complete_marker}")
|
| 1081 |
|
| 1082 |
+
except Exception as e:
|
| 1083 |
+
logger.exception("💥 Tokenlearn training failed")
|
| 1084 |
+
logger.exception("💥 Tokenlearn training is required - cannot proceed")
|
| 1085 |
+
msg = f"Tokenlearn training failed: {e}"
|
| 1086 |
+
raise RuntimeError(msg) from e
|
| 1087 |
|
| 1088 |
+
# Step 4: Load the trained model and apply post-training re-regularization
|
| 1089 |
+
logger.info("📦 Step 4: Loading trained model and applying post-training re-regularization...")
|
| 1090 |
|
| 1091 |
+
# Check if we need to use fallback due to tokenlearn failure
|
| 1092 |
+
training_fallback_marker = trained_dir / ".training_fallback"
|
| 1093 |
+
if training_fallback_marker.exists():
|
| 1094 |
+
logger.error("❌ Tokenlearn training failed previously - cannot return trained model")
|
| 1095 |
+
logger.error("💥 Training was requested but failed - this would be misleading to return base model")
|
| 1096 |
+
msg = "Tokenlearn training failed - cannot proceed with training pipeline"
|
| 1097 |
+
raise RuntimeError(msg)
|
| 1098 |
|
| 1099 |
+
try:
|
| 1100 |
+
from model2vec.model import StaticModel
|
| 1101 |
+
|
| 1102 |
+
# Load the trained model from tokenlearn
|
| 1103 |
+
trained_model_path = trained_dir / "model"
|
| 1104 |
+
if not trained_model_path.exists():
|
| 1105 |
+
# Try alternative paths
|
| 1106 |
+
possible_paths = [
|
| 1107 |
+
trained_dir / "model_weighted",
|
| 1108 |
+
trained_dir,
|
| 1109 |
+
]
|
| 1110 |
+
|
| 1111 |
+
for path in possible_paths:
|
| 1112 |
+
if path.exists() and any(path.glob("*.json")):
|
| 1113 |
+
trained_model_path = path
|
| 1114 |
+
break
|
| 1115 |
+
else:
|
| 1116 |
+
logger.error(f"❌ Could not find trained model in {trained_dir}")
|
| 1117 |
+
logger.error("💥 Training was requested but no trained model found - cannot proceed")
|
| 1118 |
+
msg = f"Trained model not found in {trained_dir} - training pipeline failed"
|
| 1119 |
+
raise RuntimeError(msg)
|
| 1120 |
+
|
| 1121 |
+
# Load the model before re-regularization
|
| 1122 |
+
logger.info("🔄 Loading model from tokenlearn training...")
|
| 1123 |
+
trained_model = StaticModel.from_pretrained(str(trained_model_path))
|
| 1124 |
+
|
| 1125 |
+
# Apply post-training re-regularization (POTION Step 4)
|
| 1126 |
+
logger.info("🔧 Applying post-training re-regularization (PCA + SIF weighting)...")
|
| 1127 |
+
final_model = apply_post_training_regularization(
|
| 1128 |
+
trained_model, features_dir, pca_dims=distillation_config.optimal_pca_dims
|
| 1129 |
+
)
|
| 1130 |
+
|
| 1131 |
+
logger.info("✅ Tokenlearn training pipeline with post-training re-regularization completed successfully")
|
| 1132 |
+
|
| 1133 |
+
return final_model
|
| 1134 |
+
|
| 1135 |
+
except ValueError as e:
|
| 1136 |
+
if "Number of tokens" in str(e) and "does not match number of vectors" in str(e):
|
| 1137 |
+
logger.exception("💥 Token-vector mismatch in tokenlearn training")
|
| 1138 |
+
logger.exception("Error details")
|
| 1139 |
+
logger.exception("🔧 This is a known issue with tokenlearn/Model2Vec integration")
|
| 1140 |
+
logger.exception("💥 Training was requested but failed due to token-vector mismatch")
|
| 1141 |
+
msg = f"Tokenlearn training failed due to token-vector mismatch: {e}"
|
| 1142 |
raise RuntimeError(msg) from e
|
| 1143 |
+
logger.exception("💥 Failed to load tokenlearn trained model")
|
| 1144 |
+
msg = f"Failed to load tokenlearn trained model: {e}"
|
| 1145 |
+
raise RuntimeError(msg) from e
|
| 1146 |
+
except Exception as e:
|
| 1147 |
+
logger.exception("💥 Failed to load tokenlearn trained model")
|
| 1148 |
+
logger.exception("💥 Cannot load trained model - training failed")
|
| 1149 |
+
msg = f"Failed to load tokenlearn trained model: {e}"
|
| 1150 |
+
raise RuntimeError(msg) from e
|
| 1151 |
|
| 1152 |
|
| 1153 |
def distill_single_teacher(
|
|
|
|
| 1197 |
|
| 1198 |
# Initialize Beam utilities if requested
|
| 1199 |
checkpoint_mgr = None
|
|
|
|
| 1200 |
if use_beam_utilities:
|
| 1201 |
try:
|
| 1202 |
_, checkpoint_mgr, model_mgr, _ = create_beam_utilities(VOLUME_CONFIG.name, VOLUME_CONFIG.mount_path)
|
|
|
|
| 1275 |
|
| 1276 |
existing_base = str(base_dir)
|
| 1277 |
|
| 1278 |
+
# Step 3: Handle final model creation
|
| 1279 |
+
if enable_training and base_model is not None:
|
| 1280 |
+
# Perform tokenlearn training (POTION approach)
|
| 1281 |
+
logger.info(f"🧪 Starting tokenlearn training for {teacher_name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1282 |
|
| 1283 |
+
try:
|
| 1284 |
+
# Load teacher model for training
|
| 1285 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 1286 |
+
teacher_st_model = load_model_with_flash_attention(teacher_model, device)
|
| 1287 |
+
|
| 1288 |
+
# Perform tokenlearn training (POTION approach)
|
| 1289 |
+
final_model = tokenlearn_training(base_model, teacher_st_model, checkpoint_mgr)
|
| 1290 |
+
|
| 1291 |
+
# Save final model
|
| 1292 |
+
final_dir.mkdir(parents=True, exist_ok=True)
|
| 1293 |
+
final_model.save_pretrained(str(final_dir))
|
| 1294 |
+
|
| 1295 |
+
# Sync final model and training checkpoints to Beam
|
| 1296 |
+
if use_beam_utilities:
|
| 1297 |
+
sync_model_to_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities)
|
| 1298 |
+
if checkpoint_mgr:
|
| 1299 |
+
sync_checkpoints_to_beam(
|
| 1300 |
+
VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints
|
| 1301 |
+
)
|
| 1302 |
+
|
| 1303 |
+
del teacher_st_model
|
| 1304 |
+
if torch.cuda.is_available():
|
| 1305 |
+
torch.cuda.empty_cache()
|
| 1306 |
+
|
| 1307 |
+
except RuntimeError as e:
|
| 1308 |
+
# Training failed - clean up and return failure
|
| 1309 |
+
logger.exception(f"❌ Training failed for {teacher_name}")
|
| 1310 |
+
|
| 1311 |
+
# Clean up teacher model if it was loaded
|
| 1312 |
+
if "teacher_st_model" in locals():
|
| 1313 |
+
del teacher_st_model
|
| 1314 |
+
if torch.cuda.is_available():
|
| 1315 |
+
torch.cuda.empty_cache()
|
| 1316 |
+
|
| 1317 |
+
return {
|
| 1318 |
+
"teacher_model": teacher_model,
|
| 1319 |
+
"teacher_name": teacher_name,
|
| 1320 |
+
"status": "failed_training",
|
| 1321 |
+
"error": f"Training failed: {e!s}",
|
| 1322 |
+
"base_path": existing_base, # Base model was created successfully
|
| 1323 |
+
}
|
| 1324 |
|
| 1325 |
+
else:
|
| 1326 |
+
# Copy base to final (no training)
|
| 1327 |
+
logger.info(f"📁 Copying base to final for {teacher_name}")
|
| 1328 |
+
if not copy_base_to_final(teacher_name, enable_training):
|
| 1329 |
+
return {
|
| 1330 |
+
"teacher_model": teacher_model,
|
| 1331 |
+
"teacher_name": teacher_name,
|
| 1332 |
+
"status": "failed_copy_to_final",
|
| 1333 |
+
"error": "Failed to copy base to final",
|
| 1334 |
+
}
|
| 1335 |
|
| 1336 |
+
total_time = time.time() - start_time
|
| 1337 |
|
| 1338 |
return {
|
| 1339 |
"teacher_model": teacher_model,
|
|
|
|
| 1417 |
|
| 1418 |
if result["status"] == "success" or result["status"].startswith("skipped"):
|
| 1419 |
successful_models.append(teacher_name)
|
| 1420 |
+
elif result["status"] == "failed_training":
|
| 1421 |
+
# Note: Training failed but base model may still be available
|
| 1422 |
+
logger.warning(f"⚠️ Training failed for {teacher_name}, but base distillation may have succeeded")
|
| 1423 |
|
| 1424 |
# Summary
|
| 1425 |
logger.info("\n🏆 DISTILLATION WORKFLOW COMPLETE!")
|
|
|
|
| 1451 |
return results_summary
|
| 1452 |
|
| 1453 |
|
| 1454 |
+
def _beam_distill_internal(
|
|
|
|
| 1455 |
teacher_models: list[str] | None = None,
|
| 1456 |
enable_training: bool = False,
|
| 1457 |
pca_dims: int | None = None,
|
| 1458 |
clear_cache: bool = False,
|
| 1459 |
) -> dict[str, Any]:
|
| 1460 |
+
"""Shared internal implementation for beam distillation."""
|
|
|
|
|
|
|
| 1461 |
# Apply patches
|
| 1462 |
patch_success = apply_local_patches()
|
| 1463 |
if patch_success:
|
|
|
|
| 1503 |
|
| 1504 |
if result["status"] == "success" or result["status"].startswith("skipped"):
|
| 1505 |
successful_models.append(teacher_name)
|
| 1506 |
+
elif result["status"] == "failed_training":
|
| 1507 |
+
# Note: Training failed but base model may still be available
|
| 1508 |
+
logger.warning(f"⚠️ Training failed for {teacher_name}, but base distillation may have succeeded")
|
| 1509 |
|
| 1510 |
# Summary
|
| 1511 |
logger.info("\n🏆 BEAM DISTILLATION WORKFLOW COMPLETE!")
|
|
|
|
| 1531 |
return results_summary
|
| 1532 |
|
| 1533 |
|
| 1534 |
+
@function(**get_training_function_kwargs())
|
| 1535 |
+
def _beam_train_models(
|
| 1536 |
+
teacher_models: list[str] | None = None,
|
| 1537 |
+
enable_training: bool = True,
|
| 1538 |
+
pca_dims: int | None = None,
|
| 1539 |
+
clear_cache: bool = False,
|
| 1540 |
+
) -> dict[str, Any]:
|
| 1541 |
+
"""Beam function for training (distillation + tokenlearn)."""
|
| 1542 |
+
logger.info("☁️ Running training on Beam")
|
| 1543 |
+
return _beam_distill_internal(teacher_models, enable_training, pca_dims, clear_cache)
|
| 1544 |
+
|
| 1545 |
+
|
| 1546 |
+
@function(**get_distillation_function_kwargs())
|
| 1547 |
+
def _beam_distill_models(
|
| 1548 |
+
teacher_models: list[str] | None = None,
|
| 1549 |
+
enable_training: bool = False,
|
| 1550 |
+
pca_dims: int | None = None,
|
| 1551 |
+
clear_cache: bool = False,
|
| 1552 |
+
) -> dict[str, Any]:
|
| 1553 |
+
"""Beam function for basic distillation only."""
|
| 1554 |
+
logger.info("☁️ Running distillation on Beam")
|
| 1555 |
+
return _beam_distill_internal(teacher_models, enable_training, pca_dims, clear_cache)
|
| 1556 |
+
|
| 1557 |
+
|
| 1558 |
def run_beam_distillation(
|
| 1559 |
teacher_models: list[str] | None = None,
|
| 1560 |
enable_training: bool = False,
|
|
|
|
| 1565 |
logger.info("☁️ Running distillation on Beam with local sync")
|
| 1566 |
|
| 1567 |
try:
|
| 1568 |
+
# Choose appropriate beam function based on training flag
|
| 1569 |
+
beam_function = _beam_train_models if enable_training else _beam_distill_models
|
| 1570 |
+
|
| 1571 |
# Run distillation on Beam
|
| 1572 |
+
results = beam_function.remote(teacher_models, enable_training, pca_dims, clear_cache)
|
| 1573 |
|
| 1574 |
# Check if Beam execution was successful
|
| 1575 |
if not results:
|
|
|
|
| 1658 |
|
| 1659 |
# Clear tokenlearn checkpoints if requested (for training mode)
|
| 1660 |
if clear_checkpoints and train:
|
|
|
|
|
|
|
| 1661 |
logger.info("🧹 Clearing tokenlearn checkpoints to force fresh featurization and training...")
|
| 1662 |
for teacher_model in models_to_distill:
|
| 1663 |
+
teacher_model.split("/")[-1].replace("-", "_")
|
| 1664 |
+
|
| 1665 |
+
# Use the same persistent directory structure as the training function
|
| 1666 |
+
teacher_slug = teacher_model.replace("/", "_").replace("-", "_")
|
| 1667 |
+
persistent_tokenlearn_dir = Path(LOCAL_BASE_DIR).parent / "tokenlearn_cache" / teacher_slug
|
| 1668 |
+
|
| 1669 |
+
features_dir = persistent_tokenlearn_dir / "features"
|
| 1670 |
+
trained_dir = persistent_tokenlearn_dir / "trained_model"
|
| 1671 |
+
|
| 1672 |
+
# Clear persistent tokenlearn checkpoints
|
| 1673 |
+
if features_dir.exists() or trained_dir.exists():
|
| 1674 |
+
clear_tokenlearn_checkpoints(features_dir, trained_dir)
|
| 1675 |
+
logger.info(f"🗑️ Cleared persistent tokenlearn checkpoints for {teacher_model}")
|
| 1676 |
+
else:
|
| 1677 |
+
logger.info(f"ℹ️ No tokenlearn checkpoints found for {teacher_model}")
|
|
|
|
|
|
|
| 1678 |
elif clear_checkpoints and not train:
|
| 1679 |
logger.warning("⚠️ --clear-checkpoints flag is only relevant when training is enabled (--train)")
|
| 1680 |
|