Sarthak commited on
Commit Β·
8083c06
1
Parent(s): ff551a2
feat(distiller): add checkpointing and refactor analyze.py
Browse filesThis change introduces checkpointing to the tokenlearn featurization and training processes. This allows the processes to resume from where they left off if they are interrupted or have already completed. It also adds a --clear-checkpoints flag to force fresh featurization and training.
Additionally, minor refactoring was done to use list comprehension in analyze.py
- src/distiller/analyze.py +6 -5
- src/distiller/distill.py +211 -86
src/distiller/analyze.py
CHANGED
|
@@ -496,10 +496,11 @@ class CodeSearchNetAnalyzer:
|
|
| 496 |
return
|
| 497 |
|
| 498 |
# Find all our model directories
|
| 499 |
-
our_model_dirs = [
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
|
|
|
| 503 |
|
| 504 |
logger.info(f"π Found {len(our_model_dirs)} distilled model directories")
|
| 505 |
|
|
@@ -1567,7 +1568,7 @@ This report presents a comprehensive analysis of Model2Vec distillation experime
|
|
| 1567 |
if self.model_specs:
|
| 1568 |
successful_specs = {k: v for k, v in self.model_specs.items() if v.get("analysis_successful", False)}
|
| 1569 |
if successful_specs:
|
| 1570 |
-
report +=
|
| 1571 |
|
| 1572 |
### π Model Specifications Analysis
|
| 1573 |
|
|
|
|
| 496 |
return
|
| 497 |
|
| 498 |
# Find all our model directories
|
| 499 |
+
our_model_dirs = [
|
| 500 |
+
model_dir
|
| 501 |
+
for model_dir in final_models_dir.iterdir()
|
| 502 |
+
if model_dir.is_dir() and "code_model2vec" in model_dir.name
|
| 503 |
+
]
|
| 504 |
|
| 505 |
logger.info(f"π Found {len(our_model_dirs)} distilled model directories")
|
| 506 |
|
|
|
|
| 1568 |
if self.model_specs:
|
| 1569 |
successful_specs = {k: v for k, v in self.model_specs.items() if v.get("analysis_successful", False)}
|
| 1570 |
if successful_specs:
|
| 1571 |
+
report += """
|
| 1572 |
|
| 1573 |
### π Model Specifications Analysis
|
| 1574 |
|
src/distiller/distill.py
CHANGED
|
@@ -866,7 +866,7 @@ def tokenlearn_training(
|
|
| 866 |
student_model.save_pretrained(str(model_dir))
|
| 867 |
logger.info(f"πΎ Saved base model to {model_dir}")
|
| 868 |
|
| 869 |
-
# Step 2: Create features using
|
| 870 |
logger.info("π Step 2: Creating features using sentence transformer...")
|
| 871 |
|
| 872 |
# Get teacher model name/path for tokenlearn
|
|
@@ -878,107 +878,153 @@ def tokenlearn_training(
|
|
| 878 |
if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"):
|
| 879 |
teacher_model_name = first_module.auto_model.name_or_path
|
| 880 |
|
| 881 |
-
if not teacher_model_name:
|
| 882 |
-
logger.warning("β οΈ Could not determine teacher model name, using fallback")
|
| 883 |
-
teacher_model_name = "BAAI/bge-base-en-v1.5" # Fallback to a common model
|
| 884 |
-
|
| 885 |
logger.info(f"π Using teacher model: {teacher_model_name}")
|
| 886 |
|
| 887 |
-
|
| 888 |
-
|
| 889 |
-
|
| 890 |
-
|
| 891 |
-
|
| 892 |
-
"tokenlearn.featurize",
|
| 893 |
-
"--model-name",
|
| 894 |
-
str(teacher_model_name),
|
| 895 |
-
"--output-dir",
|
| 896 |
-
str(features_dir),
|
| 897 |
-
"--dataset-path",
|
| 898 |
-
str(distillation_config.tokenlearn_dataset),
|
| 899 |
-
"--dataset-name",
|
| 900 |
-
str(distillation_config.tokenlearn_dataset_name),
|
| 901 |
-
"--dataset-split",
|
| 902 |
-
"train",
|
| 903 |
-
"--key",
|
| 904 |
-
str(distillation_config.tokenlearn_text_key), # Use configured text field
|
| 905 |
-
"--batch-size",
|
| 906 |
-
"1024", # Optimized batch size for A100-40G
|
| 907 |
-
]
|
| 908 |
|
| 909 |
-
|
| 910 |
-
|
| 911 |
-
|
| 912 |
-
|
| 913 |
-
|
| 914 |
-
|
| 915 |
-
|
| 916 |
-
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
check=False,
|
| 922 |
-
)
|
| 923 |
|
| 924 |
-
|
| 925 |
-
logger.error(f"β Featurization failed with return code: {result.returncode}")
|
| 926 |
-
logger.error("π₯ Tokenlearn featurization is required for training - cannot proceed")
|
| 927 |
-
msg = f"Tokenlearn featurization failed with return code: {result.returncode}"
|
| 928 |
-
raise RuntimeError(msg)
|
| 929 |
|
| 930 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 931 |
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 935 |
|
| 936 |
-
|
| 937 |
-
|
| 938 |
-
|
| 939 |
-
|
| 940 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 941 |
|
| 942 |
# Step 3: Train using tokenlearn-train
|
| 943 |
logger.info("π Step 3: Training using tokenlearn...")
|
| 944 |
|
| 945 |
-
|
| 946 |
-
|
| 947 |
-
|
| 948 |
-
|
| 949 |
-
|
| 950 |
-
|
| 951 |
-
|
| 952 |
-
|
| 953 |
-
|
| 954 |
-
|
| 955 |
-
|
| 956 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 957 |
|
| 958 |
-
|
| 959 |
-
|
| 960 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 961 |
|
| 962 |
-
|
| 963 |
-
train_cmd
|
| 964 |
-
|
| 965 |
-
timeout=distillation_config.tokenlearn_timeout_train,
|
| 966 |
-
check=False,
|
| 967 |
-
)
|
| 968 |
|
| 969 |
-
|
| 970 |
-
|
| 971 |
-
|
| 972 |
-
|
| 973 |
-
|
|
|
|
| 974 |
|
| 975 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 976 |
|
| 977 |
-
|
| 978 |
-
|
| 979 |
-
|
| 980 |
-
|
| 981 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 982 |
|
| 983 |
# Step 4: Load the trained model and apply post-training re-regularization
|
| 984 |
logger.info("π¦ Step 4: Loading trained model and applying post-training re-regularization...")
|
|
@@ -1256,6 +1302,9 @@ def run_local_distillation(
|
|
| 1256 |
if model in models_to_distill:
|
| 1257 |
clear_model_cache(model)
|
| 1258 |
|
|
|
|
|
|
|
|
|
|
| 1259 |
for teacher_model in models_to_distill:
|
| 1260 |
result = distill_single_teacher(
|
| 1261 |
teacher_model=teacher_model,
|
|
@@ -1453,6 +1502,9 @@ def main(
|
|
| 1453 |
clear_cache: Annotated[
|
| 1454 |
bool, typer.Option(help="Clear HuggingFace cache for problematic models before distillation")
|
| 1455 |
] = False,
|
|
|
|
|
|
|
|
|
|
| 1456 |
) -> None:
|
| 1457 |
"""Unified distillation command with optional training."""
|
| 1458 |
logger.info("π Starting unified Model2Vec distillation workflow")
|
|
@@ -1475,6 +1527,32 @@ def main(
|
|
| 1475 |
if model in models_to_distill:
|
| 1476 |
clear_model_cache(model)
|
| 1477 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1478 |
# Run distillation workflow
|
| 1479 |
if use_beam:
|
| 1480 |
results = run_beam_distillation(
|
|
@@ -1822,5 +1900,52 @@ def baai_bge_model_distillation(
|
|
| 1822 |
return None
|
| 1823 |
|
| 1824 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1825 |
if __name__ == "__main__":
|
| 1826 |
typer.run(main)
|
|
|
|
| 866 |
student_model.save_pretrained(str(model_dir))
|
| 867 |
logger.info(f"πΎ Saved base model to {model_dir}")
|
| 868 |
|
| 869 |
+
# Step 2: Create features using sentence transformer
|
| 870 |
logger.info("π Step 2: Creating features using sentence transformer...")
|
| 871 |
|
| 872 |
# Get teacher model name/path for tokenlearn
|
|
|
|
| 878 |
if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"):
|
| 879 |
teacher_model_name = first_module.auto_model.name_or_path
|
| 880 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 881 |
logger.info(f"π Using teacher model: {teacher_model_name}")
|
| 882 |
|
| 883 |
+
# Check if featurization already completed (checkpoint detection)
|
| 884 |
+
featurization_complete_marker = features_dir / ".featurization_complete"
|
| 885 |
+
if featurization_complete_marker.exists() and verify_featurization_output(features_dir):
|
| 886 |
+
logger.info("β
Found existing featurization checkpoint with valid output files")
|
| 887 |
+
logger.info(f"π Using cached features from: {features_dir}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 888 |
|
| 889 |
+
# Verify marker is still valid
|
| 890 |
+
output_files = list(features_dir.glob("*.npy")) + list(features_dir.glob("*.json"))
|
| 891 |
+
logger.info(f"π Found {len(output_files)} cached feature files")
|
| 892 |
+
else:
|
| 893 |
+
if featurization_complete_marker.exists():
|
| 894 |
+
logger.warning("β οΈ Featurization marker exists but output files are missing - re-running featurization")
|
| 895 |
+
featurization_complete_marker.unlink()
|
| 896 |
+
logger.info("π No valid featurization checkpoint found - starting featurization...")
|
| 897 |
+
|
| 898 |
+
if not teacher_model_name:
|
| 899 |
+
logger.warning("β οΈ Could not determine teacher model name, using fallback")
|
| 900 |
+
teacher_model_name = "BAAI/bge-base-en-v1.5" # Fallback to a common model
|
|
|
|
|
|
|
| 901 |
|
| 902 |
+
logger.info(f"π Using teacher model: {teacher_model_name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 903 |
|
| 904 |
+
try:
|
| 905 |
+
# Use configured dataset for code specialization
|
| 906 |
+
featurize_cmd = [
|
| 907 |
+
"python",
|
| 908 |
+
"-m",
|
| 909 |
+
"tokenlearn.featurize",
|
| 910 |
+
"--model-name",
|
| 911 |
+
str(teacher_model_name),
|
| 912 |
+
"--output-dir",
|
| 913 |
+
str(features_dir),
|
| 914 |
+
"--dataset-path",
|
| 915 |
+
str(distillation_config.tokenlearn_dataset),
|
| 916 |
+
"--dataset-name",
|
| 917 |
+
str(distillation_config.tokenlearn_dataset_name),
|
| 918 |
+
"--dataset-split",
|
| 919 |
+
"train",
|
| 920 |
+
"--key",
|
| 921 |
+
str(distillation_config.tokenlearn_text_key), # Use configured text field
|
| 922 |
+
"--batch-size",
|
| 923 |
+
"1024", # Optimized batch size for A100-40G
|
| 924 |
+
]
|
| 925 |
|
| 926 |
+
logger.info("π Running tokenlearn featurization...")
|
| 927 |
+
logger.info(
|
| 928 |
+
f"π Dataset: {distillation_config.tokenlearn_dataset} (config: {distillation_config.tokenlearn_dataset_name})"
|
| 929 |
+
)
|
| 930 |
+
logger.info(f"π Text field: {distillation_config.tokenlearn_text_key}")
|
| 931 |
+
logger.info(f"Command: {' '.join(featurize_cmd)}")
|
| 932 |
+
print(f"\nπ Executing: {' '.join(featurize_cmd)}\n")
|
| 933 |
+
|
| 934 |
+
result = subprocess.run( # noqa: S603
|
| 935 |
+
featurize_cmd,
|
| 936 |
+
text=True,
|
| 937 |
+
timeout=distillation_config.tokenlearn_timeout_featurize,
|
| 938 |
+
check=False,
|
| 939 |
+
)
|
| 940 |
|
| 941 |
+
if result.returncode != 0:
|
| 942 |
+
logger.error(f"β Featurization failed with return code: {result.returncode}")
|
| 943 |
+
logger.error("π₯ Tokenlearn featurization is required for training - cannot proceed")
|
| 944 |
+
msg = f"Tokenlearn featurization failed with return code: {result.returncode}"
|
| 945 |
+
raise RuntimeError(msg)
|
| 946 |
+
|
| 947 |
+
logger.info("β
Featurization completed successfully")
|
| 948 |
+
|
| 949 |
+
# Create checkpoint marker to indicate featurization is complete
|
| 950 |
+
featurization_complete_marker.touch()
|
| 951 |
+
logger.info(f"πΎ Created featurization checkpoint: {featurization_complete_marker}")
|
| 952 |
+
|
| 953 |
+
# Generate token frequencies for post-training re-regularization
|
| 954 |
+
logger.info("π Computing token frequencies for SIF weighting...")
|
| 955 |
+
compute_token_frequencies_for_sif(teacher_model, features_dir)
|
| 956 |
+
|
| 957 |
+
except Exception as e:
|
| 958 |
+
logger.exception("π₯ Tokenlearn featurization failed")
|
| 959 |
+
logger.exception("π₯ Tokenlearn featurization is required for training - cannot proceed")
|
| 960 |
+
msg = f"Tokenlearn featurization failed: {e}"
|
| 961 |
+
raise RuntimeError(msg) from e
|
| 962 |
|
| 963 |
# Step 3: Train using tokenlearn-train
|
| 964 |
logger.info("π Step 3: Training using tokenlearn...")
|
| 965 |
|
| 966 |
+
# Check if training already completed (checkpoint detection)
|
| 967 |
+
training_complete_marker = trained_dir / ".training_complete"
|
| 968 |
+
if training_complete_marker.exists() and verify_training_output(trained_dir):
|
| 969 |
+
logger.info("β
Found existing training checkpoint with valid model files")
|
| 970 |
+
logger.info(f"π Using cached trained model from: {trained_dir}")
|
| 971 |
+
|
| 972 |
+
# Show available model files
|
| 973 |
+
model_files = []
|
| 974 |
+
for pattern in ["*.json", "*.safetensors", "*.bin"]:
|
| 975 |
+
model_files.extend(list(trained_dir.glob(pattern)))
|
| 976 |
+
for subdir in ["model", "model_weighted"]:
|
| 977 |
+
subdir_path = trained_dir / subdir
|
| 978 |
+
if subdir_path.exists():
|
| 979 |
+
model_files.extend(list(subdir_path.glob(pattern)))
|
| 980 |
+
logger.info(f"π Found {len(model_files)} cached model files")
|
| 981 |
+
else:
|
| 982 |
+
if training_complete_marker.exists():
|
| 983 |
+
logger.warning("β οΈ Training marker exists but model files are missing - re-running training")
|
| 984 |
+
training_complete_marker.unlink()
|
| 985 |
+
logger.info("π No valid training checkpoint found - starting training...")
|
| 986 |
|
| 987 |
+
try:
|
| 988 |
+
train_cmd = [
|
| 989 |
+
"python",
|
| 990 |
+
"-m",
|
| 991 |
+
"tokenlearn.train",
|
| 992 |
+
"--model-name",
|
| 993 |
+
str(teacher_model_name),
|
| 994 |
+
"--data-path",
|
| 995 |
+
str(features_dir),
|
| 996 |
+
"--save-path",
|
| 997 |
+
str(trained_dir),
|
| 998 |
+
]
|
| 999 |
|
| 1000 |
+
logger.info("π Running tokenlearn training...")
|
| 1001 |
+
logger.info(f"Command: {' '.join(train_cmd)}")
|
| 1002 |
+
print(f"\nπ Executing: {' '.join(train_cmd)}\n")
|
|
|
|
|
|
|
|
|
|
| 1003 |
|
| 1004 |
+
result = subprocess.run( # noqa: S603
|
| 1005 |
+
train_cmd,
|
| 1006 |
+
text=True,
|
| 1007 |
+
timeout=distillation_config.tokenlearn_timeout_train,
|
| 1008 |
+
check=False,
|
| 1009 |
+
)
|
| 1010 |
|
| 1011 |
+
if result.returncode != 0:
|
| 1012 |
+
logger.error(f"β Tokenlearn training failed with return code: {result.returncode}")
|
| 1013 |
+
logger.error("π₯ Tokenlearn training is required - cannot proceed")
|
| 1014 |
+
msg = f"Tokenlearn training failed with return code: {result.returncode}"
|
| 1015 |
+
raise RuntimeError(msg)
|
| 1016 |
|
| 1017 |
+
logger.info("β
Tokenlearn training completed successfully")
|
| 1018 |
+
|
| 1019 |
+
# Create checkpoint marker to indicate training is complete
|
| 1020 |
+
training_complete_marker.touch()
|
| 1021 |
+
logger.info(f"πΎ Created training checkpoint: {training_complete_marker}")
|
| 1022 |
+
|
| 1023 |
+
except Exception as e:
|
| 1024 |
+
logger.exception("π₯ Tokenlearn training failed")
|
| 1025 |
+
logger.exception("π₯ Tokenlearn training is required - cannot proceed")
|
| 1026 |
+
msg = f"Tokenlearn training failed: {e}"
|
| 1027 |
+
raise RuntimeError(msg) from e
|
| 1028 |
|
| 1029 |
# Step 4: Load the trained model and apply post-training re-regularization
|
| 1030 |
logger.info("π¦ Step 4: Loading trained model and applying post-training re-regularization...")
|
|
|
|
| 1302 |
if model in models_to_distill:
|
| 1303 |
clear_model_cache(model)
|
| 1304 |
|
| 1305 |
+
# Clear tokenlearn checkpoints if requested (for training mode)
|
| 1306 |
+
# Note: Checkpoint clearing is handled at the main function level
|
| 1307 |
+
# Run distillation workflow
|
| 1308 |
for teacher_model in models_to_distill:
|
| 1309 |
result = distill_single_teacher(
|
| 1310 |
teacher_model=teacher_model,
|
|
|
|
| 1502 |
clear_cache: Annotated[
|
| 1503 |
bool, typer.Option(help="Clear HuggingFace cache for problematic models before distillation")
|
| 1504 |
] = False,
|
| 1505 |
+
clear_checkpoints: Annotated[
|
| 1506 |
+
bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training")
|
| 1507 |
+
] = False,
|
| 1508 |
) -> None:
|
| 1509 |
"""Unified distillation command with optional training."""
|
| 1510 |
logger.info("π Starting unified Model2Vec distillation workflow")
|
|
|
|
| 1527 |
if model in models_to_distill:
|
| 1528 |
clear_model_cache(model)
|
| 1529 |
|
| 1530 |
+
# Clear tokenlearn checkpoints if requested (for training mode)
|
| 1531 |
+
if clear_checkpoints and train:
|
| 1532 |
+
import tempfile
|
| 1533 |
+
|
| 1534 |
+
logger.info("π§Ή Clearing tokenlearn checkpoints to force fresh featurization and training...")
|
| 1535 |
+
for teacher_model in models_to_distill:
|
| 1536 |
+
teacher_name = teacher_model.split("/")[-1].replace("-", "_")
|
| 1537 |
+
|
| 1538 |
+
# Construct checkpoint paths using secure temporary directory
|
| 1539 |
+
temp_dir = Path(tempfile.gettempdir()) / f"tokenlearn_{teacher_name}"
|
| 1540 |
+
features_dir = temp_dir / "features"
|
| 1541 |
+
trained_dir = temp_dir / "trained"
|
| 1542 |
+
|
| 1543 |
+
# Also check local paths
|
| 1544 |
+
local_temp = Path("temp") / f"tokenlearn_{teacher_name}"
|
| 1545 |
+
local_features = local_temp / "features"
|
| 1546 |
+
local_trained = local_temp / "trained"
|
| 1547 |
+
|
| 1548 |
+
# Clear checkpoints for all possible paths
|
| 1549 |
+
for feat_dir, train_dir in [(features_dir, trained_dir), (local_features, local_trained)]:
|
| 1550 |
+
if feat_dir.exists() or train_dir.exists():
|
| 1551 |
+
clear_tokenlearn_checkpoints(feat_dir, train_dir)
|
| 1552 |
+
logger.info(f"ποΈ Cleared checkpoints for {teacher_model}")
|
| 1553 |
+
elif clear_checkpoints and not train:
|
| 1554 |
+
logger.warning("β οΈ --clear-checkpoints flag is only relevant when training is enabled (--train)")
|
| 1555 |
+
|
| 1556 |
# Run distillation workflow
|
| 1557 |
if use_beam:
|
| 1558 |
results = run_beam_distillation(
|
|
|
|
| 1900 |
return None
|
| 1901 |
|
| 1902 |
|
| 1903 |
+
def clear_tokenlearn_checkpoints(features_dir: Path, trained_dir: Path) -> None:
|
| 1904 |
+
"""Clear tokenlearn checkpoint markers to force re-execution of steps."""
|
| 1905 |
+
featurization_marker = features_dir / ".featurization_complete"
|
| 1906 |
+
training_marker = trained_dir / ".training_complete"
|
| 1907 |
+
|
| 1908 |
+
if featurization_marker.exists():
|
| 1909 |
+
featurization_marker.unlink()
|
| 1910 |
+
logger.info(f"ποΈ Cleared featurization checkpoint: {featurization_marker}")
|
| 1911 |
+
|
| 1912 |
+
if training_marker.exists():
|
| 1913 |
+
training_marker.unlink()
|
| 1914 |
+
logger.info(f"ποΈ Cleared training checkpoint: {training_marker}")
|
| 1915 |
+
|
| 1916 |
+
|
| 1917 |
+
def verify_featurization_output(features_dir: Path) -> bool:
|
| 1918 |
+
"""Verify that featurization output files actually exist and are valid."""
|
| 1919 |
+
if not features_dir.exists():
|
| 1920 |
+
return False
|
| 1921 |
+
|
| 1922 |
+
# Check for expected tokenlearn output files
|
| 1923 |
+
|
| 1924 |
+
# Check if any expected files exist
|
| 1925 |
+
return any(list(features_dir.glob(file_pattern)) for file_pattern in ["*.npy", "*.json", "*.pt", "*.pkl"])
|
| 1926 |
+
|
| 1927 |
+
|
| 1928 |
+
def verify_training_output(trained_dir: Path) -> bool:
|
| 1929 |
+
"""Verify that training output files actually exist and are valid."""
|
| 1930 |
+
if not trained_dir.exists():
|
| 1931 |
+
return False
|
| 1932 |
+
|
| 1933 |
+
# Check for model files
|
| 1934 |
+
model_files = ["config.json", "model.safetensors", "modules.json", "tokenizer.json"]
|
| 1935 |
+
for model_file in model_files:
|
| 1936 |
+
if (trained_dir / model_file).exists():
|
| 1937 |
+
return True
|
| 1938 |
+
|
| 1939 |
+
# Check for alternative model directory structure
|
| 1940 |
+
for subdir in ["model", "model_weighted"]:
|
| 1941 |
+
subdir_path = trained_dir / subdir
|
| 1942 |
+
if subdir_path.exists():
|
| 1943 |
+
for model_file in model_files:
|
| 1944 |
+
if (subdir_path / model_file).exists():
|
| 1945 |
+
return True
|
| 1946 |
+
|
| 1947 |
+
return False
|
| 1948 |
+
|
| 1949 |
+
|
| 1950 |
if __name__ == "__main__":
|
| 1951 |
typer.run(main)
|