| | """ |
| | Unified Code-Specialized Model2Vec Distillation Script. |
| | |
| | This script provides a unified approach for creating code-specialized embeddings |
| | using Model2Vec distillation with optional code-specific training. |
| | |
| | Features: |
| | - Basic distillation (default): Simple Model2Vec distillation |
| | - Advanced training (--train flag): Additional CodeSearchNet fine-tuning |
| | - Checkpoint support with Beam sync utilities |
| | - Multi-teacher model processing |
| | - Smart resume capabilities |
| | - Hierarchical storage: base → final |
| | |
| | Directory Structure: |
| | - code_model2vec/base: Basic distilled models (first step) |
| | - code_model2vec/final: Final models (copied from base or after training) |
| | |
| | Usage: |
| | distiller distill [--use-beam] [--train] # Basic distillation or with training |
| | """ |
| |
|
| | import importlib.util |
| | import json |
| | import logging |
| | import os |
| | import time |
| | from pathlib import Path |
| | from typing import Annotated, Any |
| |
|
| | import torch |
| | import typer |
| | from beam import function |
| | from sentence_transformers import SentenceTransformer |
| |
|
| | from distiller.model2vec.distill import distill |
| |
|
| | |
| | from .beam_utils import ( |
| | BeamCheckpointManager, |
| | create_beam_utilities, |
| | download_model_from_beam, |
| | sync_checkpoints_from_beam, |
| | sync_checkpoints_to_beam, |
| | upload_model_to_beam, |
| | ) |
| | from .config import ( |
| | codesearchnet_config, |
| | directories, |
| | distillation_config, |
| | get_distillation_function_kwargs, |
| | get_training_function_kwargs, |
| | get_volume_config, |
| | languages_config, |
| | ) |
| |
|
| | |
| | FLASH_ATTN_AVAILABLE = importlib.util.find_spec("flash_attn") is not None |
| |
|
| | |
| | |
| | |
| |
|
| | VOLUME_CONFIG = get_volume_config() |
| | LOCAL_BASE_DIR = directories.base |
| | LOCAL_FINAL_DIR = directories.final |
| |
|
| | logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") |
| | logger = logging.getLogger(__name__) |
| |
|
| | |
| | DEFAULT_TEACHER_MODELS = list(distillation_config.code_teacher_models) |
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | def configure_flash_attention() -> dict[str, Any]: |
| | """Configure flash attention settings and return model kwargs.""" |
| | model_kwargs: dict[str, Any] = {} |
| |
|
| | if not FLASH_ATTN_AVAILABLE: |
| | logger.info("β οΈ Flash attention not available - using standard attention") |
| | return model_kwargs |
| |
|
| | |
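| | # Assumed intent of these env vars: force flash-attention kernels, avoid torch.compile issues, and silence tokenizer fork warnings |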
| | os.environ["FLASH_ATTENTION_FORCE_USE"] = "1" |
| | |
| | os.environ["TORCH_COMPILE_DISABLE"] = "1" |
| | |
| | os.environ["TOKENIZERS_PARALLELISM"] = "false" |
| |
|
| | |
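| | # The code treats compute capability >= 7.5 as flash-attention capable |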
| | try: |
| | if torch.cuda.is_available(): |
| | device_capability = torch.cuda.get_device_capability() |
| | |
| | if device_capability[0] >= 7 and (device_capability[0] > 7 or device_capability[1] >= 5): |
| | logger.info("β
Flash attention enabled - compatible GPU detected") |
| | model_kwargs.update( |
| | { |
| | "model_kwargs": { |
| | "attn_implementation": "flash_attention_2", |
| | "torch_dtype": torch.float16, |
| | "use_flash_attention_2": True, |
| | "_attn_implementation": "flash_attention_2", |
| | } |
| | } |
| | ) |
| | else: |
| | logger.info(f"β οΈ GPU compute capability {device_capability} < 7.5 - flash attention disabled") |
| | else: |
| | logger.info("β οΈ No CUDA available - flash attention disabled") |
| | except Exception as e: |
| | logger.warning(f"β οΈ Failed to check GPU compatibility: {e} - flash attention disabled") |
| |
|
| | return model_kwargs |
| |
|
| |
|
| | def load_model_with_flash_attention(model_path: str, device: str = "auto") -> SentenceTransformer: |
| | """Load a SentenceTransformer model with flash attention if available.""" |
| | flash_kwargs = configure_flash_attention() |
| |
|
| | try: |
| | |
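| | # Try flash attention first; any failure falls through to the standard-attention load below |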
| | if flash_kwargs and "model_kwargs" in flash_kwargs: |
| | logger.info(f"π Loading model with flash attention: {Path(model_path).name}") |
| | model = SentenceTransformer(model_path, device=device, trust_remote_code=True, **flash_kwargs) |
| | logger.info("β
Model loaded successfully with flash attention") |
| | return model |
| | except Exception as e: |
| | logger.warning(f"β οΈ Failed to load with flash attention: {e}") |
| | logger.info("π Falling back to standard attention") |
| |
|
| | |
| | logger.info(f"π Loading model with standard attention: {Path(model_path).name}") |
| | model = SentenceTransformer(model_path, device=device, trust_remote_code=True) |
| | logger.info("β
Model loaded successfully with standard attention") |
| | return model |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | def get_current_config_hash(enable_training: bool) -> str: |
| | """Generate a hash of current configuration parameters for checkpoint validation.""" |
| | import hashlib |
| |
|
| | config_params = { |
| | "pca_dims": distillation_config.optimal_pca_dims, |
| | "sif_coefficient": distillation_config.sif_coefficient, |
| | "apply_zipf": distillation_config.apply_zipf, |
| | "enable_training": enable_training, |
| | } |
| |
|
| | if enable_training: |
| | |
| | # Use a deterministic digest rather than the per-process salted built-in hash() so the config hash is stable across runs |
| | tokenlearn_str = ( |
| | f"{distillation_config.tokenlearn_dataset}_{distillation_config.tokenlearn_dataset_name}_{distillation_config.tokenlearn_text_key}" |
| | ) |
| | config_params["tokenlearn_hash"] = hashlib.md5(tokenlearn_str.encode()).hexdigest()[:12] |
| |
|
| | config_str = str(sorted(config_params.items())) |
| | return hashlib.md5(config_str.encode()).hexdigest()[:12] |
| |
|
| |
|
| | def check_existing_base_model(teacher_name: str) -> str | None: |
| | """Check if base distilled model already exists locally.""" |
| | base_dir = Path(LOCAL_BASE_DIR) |
| | model_dir = base_dir / f"code_model2vec_{teacher_name}" |
| |
|
| | if model_dir.exists(): |
| | |
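| | # A usable model directory needs a config plus at least one weights file |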
| | has_config = (model_dir / "config.json").exists() |
| | has_model_file = any( |
| | [ |
| | (model_dir / "model.safetensors").exists(), |
| | (model_dir / "model.bin").exists(), |
| | (model_dir / "pytorch_model.bin").exists(), |
| | ] |
| | ) |
| |
|
| | if has_config and has_model_file: |
| | logger.info(f"β
Found existing base model: {teacher_name}") |
| | return str(model_dir) |
| |
|
| | return None |
| |
|
| |
|
| | def check_existing_final_model(teacher_name: str, enable_training: bool = False) -> str | None: |
| | """Check if final model already exists locally.""" |
| | final_dir = Path(LOCAL_FINAL_DIR) |
| |
|
| | |
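| | # Trained models are stored under a "_fine_tuned" suffix in the final directory |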
| | model_name = f"code_model2vec_{teacher_name}" |
| | if enable_training: |
| | model_name += "_fine_tuned" |
| | final_path = final_dir / model_name |
| |
|
| | if final_path.exists(): |
| | |
| | has_config = (final_path / "config.json").exists() |
| | has_model_file = any( |
| | [ |
| | (final_path / "model.safetensors").exists(), |
| | (final_path / "model.bin").exists(), |
| | (final_path / "pytorch_model.bin").exists(), |
| | ] |
| | ) |
| |
|
| | if has_config and has_model_file: |
| | logger.info(f"β
Found existing final model: {teacher_name}{'_fine_tuned' if enable_training else ''}") |
| | return str(final_path) |
| |
|
| | return None |
| |
|
| |
|
| | def copy_base_to_final(teacher_name: str, enable_training: bool = False) -> bool: |
| | """Copy base model to final directory.""" |
| | import shutil |
| |
|
| | base_path = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}" |
| |
|
| | |
| | final_model_name = f"code_model2vec_{teacher_name}" |
| | if enable_training: |
| | final_model_name += "_fine_tuned" |
| | final_path = Path(LOCAL_FINAL_DIR) / final_model_name |
| |
|
| | try: |
| | final_path.parent.mkdir(parents=True, exist_ok=True) |
| | if final_path.exists(): |
| | shutil.rmtree(final_path) |
| | shutil.copytree(base_path, final_path) |
| | logger.info(f"π Copied {teacher_name} from base to final{'_fine_tuned' if enable_training else ''}") |
| | return True |
| | except Exception: |
| | logger.exception(f"β Failed to copy {teacher_name} to final{'_fine_tuned' if enable_training else ''}") |
| | return False |
| |
|
| |
|
| | def sync_model_from_beam( |
| | teacher_name: str, |
| | target_dir: str, |
| | use_beam_utilities: bool = False, |
| | ) -> bool: |
| | """Sync model from Beam volume to local directory.""" |
| | if not use_beam_utilities: |
| | return False |
| |
|
| | try: |
| | target_path = Path(target_dir) |
| | target_path.mkdir(parents=True, exist_ok=True) |
| |
|
| | beam_model_name = f"{teacher_name}_model" |
| | success = download_model_from_beam(VOLUME_CONFIG.name, beam_model_name, str(target_path)) |
| |
|
| | if success: |
| | logger.info(f"π₯ Synced {teacher_name} from Beam to {target_dir}") |
| | return True |
| | logger.warning(f"β οΈ Failed to sync {teacher_name} from Beam") |
| | return False |
| |
|
| | except Exception as e: |
| | logger.warning(f"Failed to sync {teacher_name} from Beam: {e}") |
| | return False |
| |
|
| |
|
| | def sync_model_to_beam( |
| | teacher_name: str, |
| | source_dir: str, |
| | use_beam_utilities: bool = False, |
| | ) -> bool: |
| | """Sync model from local directory to Beam volume.""" |
| | if not use_beam_utilities: |
| | return False |
| |
|
| | try: |
| | beam_model_name = f"{teacher_name}_model" |
| | success = upload_model_to_beam(VOLUME_CONFIG.name, beam_model_name, source_dir) |
| |
|
| | if success: |
| | logger.info(f"π€ Synced {teacher_name} to Beam from {source_dir}") |
| | return True |
| | logger.warning(f"β οΈ Failed to sync {teacher_name} to Beam") |
| | return False |
| |
|
| | except Exception as e: |
| | logger.warning(f"Failed to sync {teacher_name} to Beam: {e}") |
| | return False |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | def simple_distillation( |
| | teacher_model: str, |
| | output_dir: str, |
| | pca_dims: int | None = None, |
| | retry_with_cache_clear: bool = False, |
| | ) -> Any: |
| | """ |
| | Perform simple Model2Vec distillation without additional training. |
| | |
| | Args: |
| | teacher_model: Name of teacher model |
| | output_dir: Output directory for the distilled model |
| | pca_dims: PCA dimensions (uses config default if None) |
| | retry_with_cache_clear: Whether this is a retry after clearing cache |
| | |
| | Returns: |
| | Distilled model or None if failed |
| | """ |
| | if pca_dims is None: |
| | pca_dims = int(distillation_config.optimal_pca_dims) |
| |
|
| | output_path = Path(output_dir) |
| | output_path.mkdir(parents=True, exist_ok=True) |
| |
|
| | retry_suffix = " (retry after cache clear)" if retry_with_cache_clear else "" |
| | logger.info(f"π Simple distillation{retry_suffix}: {teacher_model} β {output_dir}") |
| | logger.info(f"π PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}") |
| |
|
| | start_time = time.time() |
| |
|
| | try: |
| | |
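| | # Core Model2Vec distillation with parameters taken from distillation_config |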
| | model = distill( |
| | model_name=teacher_model, |
| | pca_dims=int(pca_dims), |
| | apply_zipf=bool(distillation_config.apply_zipf), |
| | sif_coefficient=float(distillation_config.sif_coefficient), |
| | trust_remote_code=True, |
| | ) |
| |
|
| | logger.info("β
Core distillation completed successfully") |
| |
|
| | |
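| | # Sanity check: tokenizer vocabulary and embedding matrix must stay the same size |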
| | if hasattr(model, "tokenizer") and hasattr(model, "embedding"): |
| | vocab_size = len(model.tokenizer.get_vocab()) |
| | embedding_size = model.embedding.shape[0] |
| |
|
| | logger.info("π Model validation:") |
| | logger.info(f" - Vocabulary size: {vocab_size}") |
| | logger.info(f" - Embedding matrix size: {embedding_size}") |
| |
|
| | if vocab_size != embedding_size: |
| | logger.warning(f"β οΈ Vocabulary size mismatch: vocab={vocab_size}, embeddings={embedding_size}") |
| | logger.warning("β οΈ This may cause issues in downstream usage") |
| | else: |
| | logger.info("β
Vocabulary and embedding sizes match") |
| |
|
| | |
| | model.save_pretrained(str(output_path)) |
| | logger.info(f"πΎ Model saved to {output_path}") |
| |
|
| | |
| | logger.info(f"Model type: {type(model)}") |
| | if hasattr(model, "embedding"): |
| | logger.info(f"Embedding shape: {model.embedding.shape}") |
| | logger.info(f"Embedding dtype: {model.embedding.dtype}") |
| |
|
| | total_time = time.time() - start_time |
| | logger.info(f"π Simple distillation completed in {total_time:.2f} seconds") |
| | return model |
| |
|
| | except ValueError as e: |
| | if "Number of tokens" in str(e) and "does not match number of vectors" in str(e): |
| | logger.warning(f"β οΈ Token-vector mismatch with {teacher_model} - this is a Model2Vec library issue") |
| | logger.warning(f"Error details: {e}") |
| | logger.warning("π‘ This model has incompatible tokenization. Skipping...") |
| | return None |
| | if "weight is on the meta device" in str(e): |
| | logger.warning(f"β οΈ Device placement issue with {teacher_model} - model weights on meta device") |
| | logger.warning(f"Error details: {e}") |
| | logger.warning("π‘ This model has device placement issues. Skipping...") |
| | return None |
| | raise |
| | except AttributeError as e: |
| | if "backend_tokenizer" in str(e): |
| | logger.warning(f"β οΈ Tokenizer compatibility issue with {teacher_model}") |
| | logger.warning(f"Error details: {e}") |
| | logger.warning("π‘ This model's tokenizer is incompatible with Model2Vec. Skipping...") |
| | return None |
| | raise |
| | except FileNotFoundError as e: |
| | if "transformers_modules" in str(e) or "xlm_padding.py" in str(e): |
| | logger.warning(f"β οΈ Missing custom model files for {teacher_model}") |
| | logger.warning(f"Error details: {e}") |
| |
|
| | |
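| | # A stale HuggingFace cache can leave custom module files missing; clear it and retry once |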
| | if not retry_with_cache_clear: |
| | logger.info("π§ Attempting to clear cache and retry...") |
| | if clear_model_cache(teacher_model): |
| | logger.info("π Retrying distillation after cache clear...") |
| | return simple_distillation(teacher_model, output_dir, pca_dims, retry_with_cache_clear=True) |
| |
|
| | logger.warning("π‘ This model has missing dependencies. Manual intervention may be required.") |
| | return None |
| | raise |
| | except Exception: |
| | logger.exception(f"β Simple distillation failed for {teacher_model}") |
| | return None |
| |
|
| |
|
| | def load_optimized_dataset( |
| | max_samples: int | None = None, |
| | checkpoint_manager: BeamCheckpointManager | None = None, |
| | dataset_path: str | None = None, |
| | ) -> list[str]: |
| | """Load our pre-created optimized dataset for tokenlearn training.""" |
| | from .dataset import DATASET_OUTPUT_DIR |
| | from .dataset import load_optimized_dataset as load_dataset_func |
| |
|
| | |
| | if dataset_path is None: |
| | dataset_path = distillation_config.custom_dataset_path |
| |
|
| | dataset_dir = Path(dataset_path) if dataset_path else DATASET_OUTPUT_DIR |
| |
|
| | |
| | if max_samples is None: |
| | max_samples = distillation_config.tokenlearn_max_samples |
| |
|
| | logger.info(f"π― Loading optimized dataset from {dataset_dir}") |
| | logger.info(f"π Target samples: {max_samples}") |
| |
|
| | try: |
| | |
| | df = load_dataset_func(output_dir=dataset_dir, split="train") |
| |
|
| | |
| | texts = df["text"].tolist() |
| |
|
| | |
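| | # Shuffle with a fixed seed so repeated runs sample the same subset |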
| | import random |
| |
|
| | random.seed(42) |
| | random.shuffle(texts) |
| |
|
| | |
| | if len(texts) > max_samples: |
| | texts = texts[:max_samples] |
| |
|
| | logger.info(f"β
Loaded {len(texts)} optimized training samples") |
| |
|
| | |
| | languages = df["language"].value_counts() |
| | logger.info("π Language distribution:") |
| | for lang, count in languages.items(): |
| | percentage = (count / len(df)) * 100 |
| | logger.info(f" {lang}: {count} samples ({percentage:.1f}%)") |
| |
|
| | return texts |
| |
|
| | except Exception as e: |
| | logger.warning(f"β οΈ Failed to load optimized dataset: {e}") |
| | logger.info("π Falling back to original CodeSearchNet loading...") |
| | return load_codesearchnet_dataset(max_samples, checkpoint_manager) |
| |
|
| |
|
| | def load_codesearchnet_dataset( |
| | max_samples: int | None = None, |
| | checkpoint_manager: BeamCheckpointManager | None = None, |
| | ) -> list[str]: |
| | """Load and format the CodeSearchNet dataset for token frequency computation.""" |
| | from datasets import load_dataset |
| |
|
| | |
| | if max_samples is None: |
| | max_samples = distillation_config.tokenlearn_max_samples |
| |
|
| | logger.info(f"Loading CodeSearchNet dataset from {codesearchnet_config.dataset_name}") |
| | logger.info(f"Limiting to {max_samples} samples for training efficiency") |
| | logger.info(f"Languages: {', '.join(languages_config.all)}") |
| |
|
| | |
| | texts = [] |
| | start_from = 0 |
| |
|
| | if checkpoint_manager: |
| | checkpoint_data = checkpoint_manager.load_checkpoint("dataset", 0) |
| | if checkpoint_data: |
| | cached_texts = checkpoint_data.get("data", {}).get("texts", []) |
| | if len(cached_texts) >= max_samples: |
| | logger.info(f"β
Resumed dataset loading: {len(cached_texts)} texts from checkpoint") |
| | return cached_texts[:max_samples] |
| | logger.info(f"π Partial dataset found: {len(cached_texts)} texts, continuing...") |
| | texts = cached_texts |
| | start_from = len(texts) |
| |
|
| | try: |
| | |
| | num_languages = len(languages_config.all) |
| | samples_per_language = max_samples // num_languages |
| | remaining_samples = max_samples % num_languages |
| |
|
| | logger.info(f"π Target distribution: {samples_per_language} samples per language") |
| | if remaining_samples > 0: |
| | logger.info(f"π Extra {remaining_samples} samples will be distributed to first languages") |
| |
|
| | |
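| | # Collect samples per language so the final mix stays balanced |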
| | language_texts: dict[str, list[str]] = {} |
| | total_collected = len(texts) |
| |
|
| | for i, language in enumerate(languages_config.all): |
| | if total_collected >= max_samples: |
| | break |
| |
|
| | logger.info(f"π Loading {language} training data...") |
| |
|
| | |
| | target_for_lang = samples_per_language |
| | if i < remaining_samples: |
| | target_for_lang += 1 |
| |
|
| | |
| | if language in language_texts and len(language_texts[language]) >= target_for_lang: |
| | continue |
| |
|
| | try: |
| | |
| | from datasets import load_dataset |
| |
|
| | dataset = load_dataset( |
| | codesearchnet_config.dataset_name, |
| | language, |
| | split="train", |
| | trust_remote_code=True, |
| | ) |
| |
|
| | lang_texts: list[str] = [] |
| | processed_count = 0 |
| |
|
| | for processed_count, example in enumerate(dataset, 1): |
| | if len(lang_texts) >= target_for_lang: |
| | break |
| |
|
| | |
| | doc_string = example.get("func_documentation_string", "").strip() |
| | code_string = example.get("func_code_string", "").strip() |
| |
|
| | if doc_string and code_string and len(doc_string.split()) >= 3 and len(code_string) > 50: |
| | |
| | text = f"Documentation: {doc_string}\nCode:\n{code_string}" |
| |
|
| | |
| | if len(text) <= 2048: |
| | lang_texts.append(text) |
| |
|
| | if processed_count % 5000 == 0: |
| | logger.info(f" {language}: processed {processed_count}, collected {len(lang_texts)}") |
| |
|
| | language_texts[language] = lang_texts |
| | total_collected += len(lang_texts) |
| | logger.info(f"β
{language}: collected {len(lang_texts)} samples") |
| |
|
| | except Exception as e: |
| | logger.warning(f"β οΈ Failed to load {language} data: {e}") |
| | continue |
| |
|
| | |
| | combined_texts = [] |
| |
|
| | |
| | if start_from > 0: |
| | combined_texts = texts[:start_from] |
| |
|
| | |
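| | # Interleave languages round-robin so no single language dominates the truncated head of the dataset |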
| | max_lang_samples = max(len(lang_texts) for lang_texts in language_texts.values()) if language_texts else 0 |
| |
|
| | for sample_idx in range(max_lang_samples): |
| | for language in languages_config.all: |
| | if len(combined_texts) >= max_samples: |
| | break |
| |
|
| | if language in language_texts and sample_idx < len(language_texts[language]): |
| | combined_texts.append(language_texts[language][sample_idx]) |
| |
|
| | if len(combined_texts) >= max_samples: |
| | break |
| |
|
| | |
| | combined_texts = combined_texts[:max_samples] |
| |
|
| | |
| | logger.info("π Final dataset distribution:") |
| | lang_counts: dict[str, int] = {} |
| | for text in combined_texts: |
| | |
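| | # Rough keyword heuristics, used only to log an approximate language distribution |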
| | if "def " in text and ":" in text: |
| | lang_counts["python"] = lang_counts.get("python", 0) + 1 |
| | elif "function " in text and "{" in text: |
| | lang_counts["javascript"] = lang_counts.get("javascript", 0) + 1 |
| | elif "public " in text and "class " in text: |
| | lang_counts["java"] = lang_counts.get("java", 0) + 1 |
| | elif "<?php" in text or "$" in text: |
| | lang_counts["php"] = lang_counts.get("php", 0) + 1 |
| | elif "func " in text and "end" in text: |
| | lang_counts["ruby"] = lang_counts.get("ruby", 0) + 1 |
| | elif "func " in text and "}" in text: |
| | lang_counts["go"] = lang_counts.get("go", 0) + 1 |
| | else: |
| | lang_counts["other"] = lang_counts.get("other", 0) + 1 |
| |
|
| | for lang, count in lang_counts.items(): |
| | percentage = (count / len(combined_texts)) * 100 |
| | logger.info(f" {lang}: {count} samples ({percentage:.1f}%)") |
| |
|
| | |
| | if checkpoint_manager: |
| | checkpoint_data = { |
| | "config_hash": get_current_config_hash(enable_training=True), |
| | "stage": "dataset", |
| | "step": 0, |
| | "timestamp": time.time(), |
| | "data": {"texts": combined_texts}, |
| | } |
| | checkpoint_manager.save_checkpoint("dataset", checkpoint_data, 0) |
| |
|
| | logger.info(f"Successfully loaded {len(combined_texts)} balanced code-documentation pairs from CodeSearchNet") |
| | return combined_texts |
| |
|
| | except Exception: |
| | logger.exception("Error loading CodeSearchNet dataset") |
| | return texts |
| |
|
| |
|
| | def generate_teacher_embeddings( |
| | teacher_model: SentenceTransformer, |
| | texts: list[str], |
| | checkpoint_manager: BeamCheckpointManager | None = None, |
| | ) -> torch.Tensor: |
| | """Generate teacher embeddings for code training with checkpoint support.""" |
| | logger.info(f"Generating teacher embeddings for {len(texts)} texts...") |
| |
|
| | |
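| | # Reuse cached teacher embeddings when the saved config hash matches the current one |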
| | if checkpoint_manager: |
| | volume_path = Path(VOLUME_CONFIG.mount_path) |
| | embeddings_path = volume_path / "embeddings_cache.pt" |
| | config_path = volume_path / "embeddings_config.json" |
| |
|
| | if embeddings_path.exists() and config_path.exists(): |
| | try: |
| | |
| | with config_path.open("r") as f: |
| | config_data = json.load(f) |
| |
|
| | current_hash = get_current_config_hash(enable_training=True) |
| | if config_data.get("config_hash") == current_hash: |
| | |
| | final_embeddings = torch.load(embeddings_path, map_location="cpu") |
| | num_expected = config_data.get("num_texts", len(texts)) |
| |
|
| | if final_embeddings.shape[0] >= num_expected: |
| | logger.info(f"β
Loaded embeddings from cache ({final_embeddings.shape[0]} embeddings)") |
| | return final_embeddings[: len(texts)] |
| |
|
| | except Exception as e: |
| | logger.warning(f"Failed to load embeddings cache: {e}, regenerating...") |
| |
|
| | |
| | logger.info("Generating fresh teacher embeddings...") |
| |
|
| | batch_size = 16 |
| | embeddings_list = [] |
| |
|
| | for i in range(0, len(texts), batch_size): |
| | batch_texts = texts[i : i + batch_size] |
| |
|
| | try: |
| | batch_embeddings = teacher_model.encode( |
| | batch_texts, |
| | convert_to_tensor=True, |
| | batch_size=batch_size, |
| | show_progress_bar=False, |
| | normalize_embeddings=True, |
| | ) |
| | embeddings_list.append(batch_embeddings) |
| |
|
| | if i % (batch_size * 10) == 0: |
| | logger.info(f"Generated embeddings for {i + len(batch_texts)}/{len(texts)} texts") |
| |
|
| | except torch.cuda.OutOfMemoryError: |
| | logger.warning(f"GPU OOM with batch size {batch_size}, reducing...") |
| | torch.cuda.empty_cache() |
| | batch_size = max(1, batch_size // 2) |
| |
|
| | |
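| | # Retry the failed batch with the reduced batch size |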
| | batch_embeddings = teacher_model.encode( |
| | batch_texts, |
| | convert_to_tensor=True, |
| | batch_size=batch_size, |
| | show_progress_bar=False, |
| | normalize_embeddings=True, |
| | ) |
| | embeddings_list.append(batch_embeddings) |
| |
|
| | |
| | teacher_embeddings = torch.cat(embeddings_list, dim=0) |
| |
|
| | |
| | if teacher_embeddings.dtype != torch.float32: |
| | teacher_embeddings = teacher_embeddings.to(torch.float32) |
| |
|
| | logger.info(f"Generated {teacher_embeddings.shape[0]} teacher embeddings in {teacher_embeddings.dtype}") |
| |
|
| | |
| | if checkpoint_manager: |
| | try: |
| | volume_path = Path(VOLUME_CONFIG.mount_path) |
| | embeddings_path = volume_path / "embeddings_cache.pt" |
| | config_path = volume_path / "embeddings_config.json" |
| |
|
| | |
| | torch.save(teacher_embeddings, embeddings_path) |
| |
|
| | |
| | config_data = { |
| | "config_hash": get_current_config_hash(enable_training=True), |
| | "num_texts": len(texts), |
| | "embedding_shape": list(teacher_embeddings.shape), |
| | "timestamp": time.time(), |
| | } |
| |
|
| | with config_path.open("w") as f: |
| | json.dump(config_data, f, indent=2) |
| |
|
| | logger.info("πΎ Saved embeddings cache for future runs") |
| |
|
| | except Exception as e: |
| | logger.warning(f"Failed to save embeddings cache: {e}") |
| |
|
| | return teacher_embeddings |
| |
|
| |
|
| | def tokenlearn_training( |
| | student_model: Any, |
| | teacher_model: SentenceTransformer, |
| | checkpoint_manager: BeamCheckpointManager | None = None, |
| | ) -> Any: |
| | """ |
| | Perform tokenlearn training following the official POTION approach. |
| | |
| | This follows the 4-step process: |
| | 1. Model2Vec distillation (already done - student_model) |
| | 2. Sentence transformer inference (create features) |
| | 3. Tokenlearn training |
| | """ |
| | from pathlib import Path |
| |
|
| | logger.info("π§ͺ Starting tokenlearn training (POTION approach)...") |
| |
|
| | |
| | teacher_model_name = getattr(teacher_model, "model_name", None) |
| | if not teacher_model_name and hasattr(teacher_model, "_modules") and len(teacher_model._modules) > 0: |
| | |
| | first_module = next(iter(teacher_model._modules.values())) |
| | if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"): |
| | teacher_model_name = first_module.auto_model.name_or_path |
| |
|
| | if not teacher_model_name: |
| | teacher_model_name = "unknown_teacher" |
| |
|
| | |
| | teacher_slug = teacher_model_name.replace("/", "_").replace("-", "_") |
| | persistent_tokenlearn_dir = Path(directories.base).parent / "tokenlearn_cache" / teacher_slug |
| |
|
| | features_dir = persistent_tokenlearn_dir / "features" |
| | model_dir = persistent_tokenlearn_dir / "base_model" |
| | trained_dir = persistent_tokenlearn_dir / "trained_model" |
| |
|
| | features_dir.mkdir(parents=True, exist_ok=True) |
| | model_dir.mkdir(parents=True, exist_ok=True) |
| | trained_dir.mkdir(parents=True, exist_ok=True) |
| |
|
| | logger.info(f"π Using persistent tokenlearn directory: {persistent_tokenlearn_dir}") |
| |
|
| | |
| | student_model.save_pretrained(str(model_dir)) |
| | logger.info(f"πΎ Saved base model to {model_dir}") |
| |
|
| | |
| | logger.info("π Step 2: Creating features using sentence transformer...") |
| |
|
| | |
| | teacher_model_name = getattr(teacher_model, "model_name", None) |
| | if not teacher_model_name and hasattr(teacher_model, "_modules") and len(teacher_model._modules) > 0: |
| | |
| | |
| | first_module = next(iter(teacher_model._modules.values())) |
| | if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "name_or_path"): |
| | teacher_model_name = first_module.auto_model.name_or_path |
| |
|
| | logger.info(f"π Using teacher model: {teacher_model_name}") |
| |
|
| | |
| | dataset_path, dataset_name, text_key = _prepare_tokenlearn_dataset(persistent_tokenlearn_dir) |
| |
|
| | |
| | featurization_complete_marker = features_dir / ".featurization_complete" |
| | if featurization_complete_marker.exists() and verify_featurization_output(features_dir): |
| | logger.info("β
Found existing featurization checkpoint with valid output files") |
| | logger.info(f"π Using cached features from: {features_dir}") |
| |
|
| | |
| | output_files = list(features_dir.glob("*.npy")) + list(features_dir.glob("*.json")) |
| | logger.info(f"π Found {len(output_files)} cached feature files") |
| | else: |
| | if featurization_complete_marker.exists(): |
| | logger.warning("β οΈ Featurization marker exists but output files are missing - re-running featurization") |
| | featurization_complete_marker.unlink() |
| | logger.info("π No valid featurization checkpoint found - starting featurization...") |
| |
|
| | if not teacher_model_name: |
| | logger.warning("β οΈ Could not determine teacher model name, using fallback") |
| | teacher_model_name = "BAAI/bge-base-en-v1.5" |
| |
|
| | logger.info(f"π Using teacher model: {teacher_model_name}") |
| |
|
| | try: |
| | |
| | from datasets import load_dataset |
| |
|
| | from distiller.tokenlearn.featurize import featurize |
| |
|
| | logger.info("π Running tokenlearn featurization...") |
| | logger.info(f"π Dataset: {dataset_path} (config: {dataset_name})") |
| | logger.info(f"π Text field: {text_key}") |
| |
|
| | |
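| | # dataset_name is None for a local JSON dataset; otherwise load the named config from the Hub |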
| | if dataset_name is None: |
| | |
| | dataset = load_dataset( |
| | "json", |
| | data_files=dataset_path, |
| | split="train", |
| | streaming=True, |
| | ) |
| | else: |
| | |
| | dataset = load_dataset( |
| | dataset_path, |
| | name=dataset_name, |
| | split="train", |
| | streaming=True, |
| | ) |
| |
|
| | |
| | featurize( |
| | dataset=iter(dataset), |
| | model=teacher_model, |
| | output_dir=str(features_dir), |
| | max_means=50000, |
| | batch_size=512, |
| | text_key=text_key, |
| | ) |
| |
|
| | logger.info("β
Featurization completed successfully") |
| |
|
| | |
| | featurization_complete_marker.touch() |
| | logger.info(f"πΎ Created featurization checkpoint: {featurization_complete_marker}") |
| |
|
| | except Exception as e: |
| | logger.exception("π₯ Tokenlearn featurization failed") |
| | logger.exception("π₯ Tokenlearn featurization is required for training - cannot proceed") |
| | msg = f"Tokenlearn featurization failed: {e}" |
| | raise RuntimeError(msg) from e |
| |
|
| | |
| | logger.info("π Step 3: Training using tokenlearn...") |
| |
|
| | |
| | training_complete_marker = trained_dir / ".training_complete" |
| | training_fallback_marker = trained_dir / ".training_fallback" |
| |
|
| | if training_complete_marker.exists() and verify_training_output(trained_dir): |
| | logger.info("β
Found existing training checkpoint with valid model files") |
| | logger.info(f"π Using cached trained model from: {trained_dir}") |
| |
|
| | |
| | model_files = [] |
| | for pattern in ["*.json", "*.safetensors", "*.bin"]: |
| | model_files.extend(list(trained_dir.glob(pattern))) |
| | for subdir in ["model", "model_weighted"]: |
| | subdir_path = trained_dir / subdir |
| | if subdir_path.exists(): |
| | model_files.extend(list(subdir_path.glob(pattern))) |
| | logger.info(f"π Found {len(model_files)} cached model files") |
| | elif training_fallback_marker.exists(): |
| | logger.warning("β οΈ Training fallback marker found - tokenlearn failed previously") |
| | logger.info("π Proceeding with fallback to base model (simple distillation)") |
| | |
| | else: |
| | if training_complete_marker.exists(): |
| | logger.warning("β οΈ Training marker exists but model files are missing - re-running training") |
| | training_complete_marker.unlink() |
| | logger.info("π No valid training checkpoint found - starting training...") |
| |
|
| | try: |
| | |
| | from distiller.tokenlearn.train import train_model |
| | from distiller.tokenlearn.utils import collect_means_and_texts |
| |
|
| | |
| | logger.info("π Attempting IMPROVED tokenlearn training with optimized parameters...") |
| | logger.info("π Using smaller vocabulary and conservative PCA to prevent overfitting") |
| |
|
| | |
| | paths = sorted(features_dir.glob("*.json")) |
| | train_txt, train_vec = collect_means_and_texts(paths) |
| |
|
| | logger.info(f"π Collected {len(train_txt)} texts and {train_vec.shape[0]} vectors for training") |
| |
|
| | try: |
| | |
| | trained_model = train_model( |
| | model_name=str(teacher_model_name), |
| | train_txt=train_txt, |
| | train_vec=train_vec, |
| | device="cuda" if torch.cuda.is_available() else "cpu", |
| | vocab_size=25000, |
| | pca_dims=256, |
| | ) |
| |
|
| | |
| | trained_model.save_pretrained(str(trained_dir)) |
| | logger.info("β
IMPROVED tokenlearn training completed successfully") |
| | training_complete_marker.touch() |
| | logger.info(f"πΎ Created improved training checkpoint: {training_complete_marker}") |
| |
|
| | except Exception as e: |
| | logger.warning(f"β οΈ Improved training failed: {e}") |
| | logger.info("π Falling back to CONSERVATIVE tokenlearn training...") |
| |
|
| | |
| | try: |
| | trained_model = train_model( |
| | model_name=str(teacher_model_name), |
| | train_txt=train_txt, |
| | train_vec=train_vec, |
| | device="cuda" if torch.cuda.is_available() else "cpu", |
| | vocab_size=15000, |
| | pca_dims=128, |
| | ) |
| |
|
| | |
| | trained_model.save_pretrained(str(trained_dir)) |
| | logger.info("β
Conservative tokenlearn training completed successfully") |
| | training_complete_marker.touch() |
| | logger.info(f"πΎ Created conservative training checkpoint: {training_complete_marker}") |
| |
|
| | except Exception as e2: |
| | logger.exception("β Conservative tokenlearn training also failed") |
| | logger.exception("π₯ All training approaches failed - check output above for details") |
| |
|
| | |
| | training_fallback_marker = trained_dir / ".training_fallback" |
| | training_fallback_marker.touch() |
| |
|
| | logger.exception("π₯ Tokenlearn training failed completely") |
| | msg = f"All tokenlearn training approaches failed: {e2}" |
| | raise RuntimeError(msg) from e2 |
| |
|
| | except Exception as e: |
| | logger.warning("π₯ All tokenlearn training approaches failed") |
| | logger.exception("π₯ All training approaches failed completely - cannot proceed") |
| | msg = f"All training approaches failed: {e}" |
| | raise RuntimeError(msg) from e |
| |
|
| | |
| | logger.info("π¦ Step 4: Loading trained model and applying post-training re-regularization...") |
| |
|
| | |
| | training_fallback_marker = trained_dir / ".training_fallback" |
| | if training_fallback_marker.exists(): |
| | logger.error("β Tokenlearn training failed previously - cannot return trained model") |
| | logger.error("π₯ Training was requested but failed - this would be misleading to return base model") |
| | msg = "Tokenlearn training failed - cannot proceed with training pipeline" |
| | raise RuntimeError(msg) |
| |
|
| | try: |
| | from distiller.model2vec.model import StaticModel |
| |
|
| | |
| | trained_model_path = trained_dir / "model" |
| | if not trained_model_path.exists(): |
| | |
| | possible_paths = [ |
| | trained_dir / "model_weighted", |
| | trained_dir, |
| | ] |
| |
|
| | for path in possible_paths: |
| | if path.exists() and any(path.glob("*.json")): |
| | trained_model_path = path |
| | break |
| | else: |
| | logger.error(f"β Could not find trained model in {trained_dir}") |
| | logger.error("π₯ Training was requested but no trained model found - cannot proceed") |
| | msg = f"Trained model not found in {trained_dir} - training pipeline failed" |
| | raise RuntimeError(msg) |
| |
|
| | |
| | logger.info("π Loading model from tokenlearn training...") |
| | trained_model = StaticModel.from_pretrained(str(trained_model_path)) |
| |
|
| | |
| | logger.info("β
Tokenlearn training pipeline completed successfully") |
| | return trained_model |
| |
|
| | except ValueError as e: |
| | if "Number of tokens" in str(e) and "does not match number of vectors" in str(e): |
| | logger.exception("π₯ Token-vector mismatch in tokenlearn training") |
| | logger.exception("Error details") |
| | logger.exception("π§ This is a known issue with tokenlearn/Model2Vec integration") |
| | logger.exception("π₯ Training was requested but failed due to token-vector mismatch") |
| | msg = f"Tokenlearn training failed due to token-vector mismatch: {e}" |
| | raise RuntimeError(msg) from e |
| | logger.exception("π₯ Failed to load tokenlearn trained model") |
| | msg = f"Failed to load tokenlearn trained model: {e}" |
| | raise RuntimeError(msg) from e |
| | except Exception as e: |
| | logger.exception("π₯ Failed to load tokenlearn trained model") |
| | logger.exception("π₯ Cannot load trained model - training failed") |
| | msg = f"Failed to load tokenlearn trained model: {e}" |
| | raise RuntimeError(msg) from e |
| |
|
| |
|
| | def distill_single_teacher( |
| | teacher_model: str, |
| | enable_training: bool = False, |
| | use_beam_utilities: bool = False, |
| | pca_dims: int | None = None, |
| | ) -> dict[str, Any]: |
| | """ |
| | Distill a single teacher model with optional training. |
| | |
| | Args: |
| | teacher_model: Name of teacher model |
| | enable_training: Whether to enable advanced training |
| | use_beam_utilities: Whether to use Beam utilities |
| | pca_dims: PCA dimensions |
| | |
| | Returns: |
| | Dictionary with distillation results |
| | """ |
| | teacher_name = teacher_model.split("/")[-1].replace("-", "_") |
| | base_dir = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}" |
| |
|
| | |
| | final_model_name = f"code_model2vec_{teacher_name}" |
| | if enable_training: |
| | final_model_name += "_fine_tuned" |
| | final_dir = Path(LOCAL_FINAL_DIR) / final_model_name |
| |
|
| | logger.info(f"\n{'=' * 60}") |
| | logger.info(f"π Processing teacher model: {teacher_model}") |
| | logger.info(f"π Teacher name: {teacher_name}") |
| | logger.info(f"π Training enabled: {enable_training}") |
| | logger.info(f"{'=' * 60}") |
| |
|
| | |
| | is_compatible, warning_msg = check_model_compatibility(teacher_model) |
| | if not is_compatible: |
| | logger.warning(f"β οΈ Known compatibility issue: {warning_msg}") |
| | logger.info("π§ Attempting distillation anyway, but may fail...") |
| |
|
| | |
| | workaround_type = try_model_workarounds(teacher_model) |
| | |
| |
|
| | start_time = time.time() |
| |
|
| | |
| | checkpoint_mgr = None |
| | if use_beam_utilities: |
| | try: |
| | _, checkpoint_mgr, model_mgr, _ = create_beam_utilities(VOLUME_CONFIG.name, VOLUME_CONFIG.mount_path) |
| | except Exception as e: |
| | logger.warning(f"Failed to initialize Beam utilities: {e}") |
| |
|
| | try: |
| | |
| | existing_final = check_existing_final_model(teacher_name, enable_training) |
| | if existing_final: |
| | logger.info(f"β
Final model already exists: {teacher_name}{'_fine_tuned' if enable_training else ''}") |
| | total_time = time.time() - start_time |
| | return { |
| | "teacher_model": teacher_model, |
| | "teacher_name": teacher_name, |
| | "status": "skipped_existing_final", |
| | "final_path": existing_final, |
| | "distillation_time": total_time, |
| | } |
| |
|
| | |
| | if use_beam_utilities and checkpoint_mgr: |
| | logger.info(f"π Syncing existing checkpoints for {teacher_name}...") |
| | sync_checkpoints_from_beam(VOLUME_CONFIG.name, f"distillation_{teacher_name}", directories.checkpoints) |
| | if enable_training: |
| | sync_checkpoints_from_beam(VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints) |
| |
|
| | |
| | existing_base = check_existing_base_model(teacher_name) |
| | base_model = None |
| |
|
| | if existing_base: |
| | logger.info(f"β
Found existing base model: {teacher_name}") |
| | if enable_training: |
| | |
| | from distiller.model2vec.model import StaticModel |
| |
|
| | base_model = StaticModel.from_pretrained(existing_base) |
| | elif use_beam_utilities: |
| | synced = sync_model_from_beam(teacher_name, str(base_dir), use_beam_utilities) |
| | if synced: |
| | existing_base = str(base_dir) |
| | if enable_training: |
| | from distiller.model2vec.model import StaticModel |
| |
|
| | base_model = StaticModel.from_pretrained(existing_base) |
| |
|
| | if not existing_base: |
| | |
| | logger.info(f"π Creating base model for {teacher_name}") |
| |
|
| | |
| | workaround_type = try_model_workarounds(teacher_model) |
| |
|
| | if workaround_type == "salesforce": |
| | base_model = salesforce_model_distillation(teacher_model, str(base_dir), pca_dims) |
| | elif workaround_type == "baai": |
| | base_model = baai_bge_model_distillation(teacher_model, str(base_dir), pca_dims) |
| | else: |
| | base_model = simple_distillation(teacher_model, str(base_dir), pca_dims) |
| |
|
| | if base_model is None: |
| | total_time = time.time() - start_time |
| | return { |
| | "teacher_model": teacher_model, |
| | "teacher_name": teacher_name, |
| | "status": "failed_base_distillation", |
| | "error": "Simple distillation failed", |
| | "distillation_time": total_time, |
| | } |
| |
|
| | |
| | if use_beam_utilities: |
| | sync_model_to_beam(teacher_name, str(base_dir), use_beam_utilities) |
| | if checkpoint_mgr: |
| | sync_checkpoints_to_beam( |
| | VOLUME_CONFIG.name, f"distillation_{teacher_name}", directories.checkpoints |
| | ) |
| |
|
| | existing_base = str(base_dir) |
| |
|
| | |
| | if enable_training and base_model is not None: |
| | |
| | logger.info(f"π§ͺ Starting tokenlearn training for {teacher_name}") |
| |
|
| | try: |
| | |
| | device = "cuda" if torch.cuda.is_available() else "cpu" |
| | teacher_st_model = load_model_with_flash_attention(teacher_model, device) |
| |
|
| | |
| | final_model = tokenlearn_training( |
| | base_model, |
| | teacher_st_model, |
| | checkpoint_mgr, |
| | ) |
| |
|
| | |
| | final_dir.mkdir(parents=True, exist_ok=True) |
| | final_model.save_pretrained(str(final_dir)) |
| |
|
| | |
| | if use_beam_utilities: |
| | sync_model_to_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities) |
| | if checkpoint_mgr: |
| | sync_checkpoints_to_beam( |
| | VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints |
| | ) |
| |
|
| | del teacher_st_model |
| | if torch.cuda.is_available(): |
| | torch.cuda.empty_cache() |
| |
|
| | except RuntimeError as e: |
| | |
| | logger.exception(f"β Training failed for {teacher_name}") |
| |
|
| | |
| | if "teacher_st_model" in locals(): |
| | del teacher_st_model |
| | if torch.cuda.is_available(): |
| | torch.cuda.empty_cache() |
| |
|
| | total_time = time.time() - start_time |
| | return { |
| | "teacher_model": teacher_model, |
| | "teacher_name": teacher_name, |
| | "status": "failed_training", |
| | "error": f"Training failed: {e!s}", |
| | "base_path": existing_base, |
| | "distillation_time": total_time, |
| | } |
| |
|
| | else: |
| | |
| | logger.info(f"π Copying base to final for {teacher_name}") |
| | if not copy_base_to_final(teacher_name, enable_training): |
| | total_time = time.time() - start_time |
| | return { |
| | "teacher_model": teacher_model, |
| | "teacher_name": teacher_name, |
| | "status": "failed_copy_to_final", |
| | "error": "Failed to copy base to final", |
| | "distillation_time": total_time, |
| | } |
| |
|
| | total_time = time.time() - start_time |
| | return { |
| | "teacher_model": teacher_model, |
| | "teacher_name": teacher_name, |
| | "status": "success", |
| | "enable_training": enable_training, |
| | "base_path": existing_base, |
| | "final_path": str(final_dir), |
| | "distillation_time": total_time, |
| | } |
| |
|
| | except Exception as e: |
| | logger.exception(f"β Failed to process {teacher_model}") |
| | total_time = time.time() - start_time |
| | return { |
| | "teacher_model": teacher_model, |
| | "teacher_name": teacher_name, |
| | "status": "failed", |
| | "error": str(e), |
| | "distillation_time": total_time, |
| | } |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | def run_local_distillation( |
| | teacher_models: list[str] | None = None, |
| | enable_training: bool = False, |
| | pca_dims: int | None = None, |
| | clear_cache: bool = False, |
| | ) -> dict[str, Any]: |
| | """Run distillation locally.""" |
| | logger.info("π₯οΈ Running distillation locally") |
| |
|
| | if teacher_models is None: |
| | teacher_models = DEFAULT_TEACHER_MODELS |
| |
|
| | results = {} |
| | successful_models = [] |
| |
|
| | logger.info("π Starting distillation workflow") |
| | logger.info(f"π Processing {len(teacher_models)} teacher models") |
| | logger.info(f"π Training enabled: {enable_training}") |
| |
|
| | |
| | models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS |
| |
|
| | logger.info(f"π Teacher models to process: {len(models_to_distill)}") |
| | for i, model in enumerate(models_to_distill, 1): |
| | logger.info(f" {i}. {model}") |
| |
|
| | |
| | if clear_cache: |
| | logger.info("π§Ή Clearing cache for known problematic models...") |
| | problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"] |
| | for model in problematic_models: |
| | if model in models_to_distill: |
| | clear_model_cache(model) |
| |
|
| | |
| | |
| | |
| | for teacher_model in models_to_distill: |
| | result = distill_single_teacher( |
| | teacher_model=teacher_model, |
| | enable_training=enable_training, |
| | use_beam_utilities=False, |
| | pca_dims=pca_dims, |
| | ) |
| |
|
| | teacher_name = result["teacher_name"] |
| | results[teacher_name] = result |
| |
|
| | if result["status"] == "success" or result["status"].startswith("skipped"): |
| | successful_models.append(teacher_name) |
| | elif result["status"] == "failed_training": |
| | |
| | logger.warning(f"β οΈ Training failed for {teacher_name}, but base distillation may have succeeded") |
| |
|
| | |
| | logger.info("\nπ DISTILLATION WORKFLOW COMPLETE!") |
| | logger.info(f"π Successful models: {len(successful_models)}") |
| | logger.info(f"π Training mode: {'Enabled' if enable_training else 'Basic distillation only'}") |
| |
|
| | for model_name in successful_models: |
| | result = results[model_name] |
| | logger.info(f"β
{model_name}: {result['teacher_model']}") |
| |
|
| | |
| | results_summary = { |
| | "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), |
| | "enable_training": enable_training, |
| | "successful_models": successful_models, |
| | "all_results": results, |
| | "total_successful": len(successful_models), |
| | "total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS), |
| | } |
| |
|
| | |
| | results_file = Path(LOCAL_BASE_DIR).parent / "distillation_results.json" |
| | results_file.parent.mkdir(parents=True, exist_ok=True) |
| | with results_file.open("w") as f: |
| | json.dump(results_summary, f, indent=2) |
| |
|
| | logger.info(f"π Results summary saved to: {results_file}") |
| |
|
| | return results_summary |
| |
|
| |
|
| | def _beam_distill_internal( |
| | teacher_models: list[str] | None = None, |
| | enable_training: bool = False, |
| | pca_dims: int | None = None, |
| | clear_cache: bool = False, |
| | ) -> dict[str, Any]: |
| | """Shared internal implementation for beam distillation.""" |
| | if teacher_models is None: |
| | teacher_models = DEFAULT_TEACHER_MODELS |
| |
|
| | |
| | if clear_cache: |
| | logger.info("π§Ή Clearing cache for known problematic models...") |
| | problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"] |
| | for model in problematic_models: |
| | if model in teacher_models: |
| | clear_model_cache(model) |
| |
|
| | results = {} |
| | successful_models = [] |
| |
|
| | logger.info("π Starting Beam distillation workflow") |
| | logger.info(f"π Processing {len(teacher_models)} teacher models") |
| | logger.info(f"π Training enabled: {enable_training}") |
| |
|
| | |
| | models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS |
| |
|
| | logger.info(f"π Teacher models to process: {len(models_to_distill)}") |
| | for i, model in enumerate(models_to_distill, 1): |
| | logger.info(f" {i}. {model}") |
| |
|
| | for teacher_model in models_to_distill: |
| | result = distill_single_teacher( |
| | teacher_model=teacher_model, |
| | enable_training=enable_training, |
| | use_beam_utilities=True, |
| | pca_dims=pca_dims, |
| | ) |
| |
|
| | teacher_name = result["teacher_name"] |
| | results[teacher_name] = result |
| |
|
| | if result["status"] == "success" or result["status"].startswith("skipped"): |
| | successful_models.append(teacher_name) |
| | elif result["status"] == "failed_training": |
| | |
| | logger.warning(f"β οΈ Training failed for {teacher_name}, but base distillation may have succeeded") |
| |
|
| | |
| | logger.info("\nπ BEAM DISTILLATION WORKFLOW COMPLETE!") |
| | logger.info(f"π Successful models: {len(successful_models)}") |
| |
|
| | |
| | volume_path = Path(VOLUME_CONFIG.mount_path) |
| | results_summary = { |
| | "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), |
| | "enable_training": enable_training, |
| | "successful_models": successful_models, |
| | "all_results": results, |
| | "total_successful": len(successful_models), |
| | "total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS), |
| | } |
| |
|
| | results_file = volume_path / "distillation_results.json" |
| | with results_file.open("w") as f: |
| | json.dump(results_summary, f, indent=2) |
| |
|
| | logger.info(f"π Beam results saved to: {results_file}") |
| |
|
| | return results_summary |
| |
|
| |
|
| | @function(**get_training_function_kwargs()) |
| | def _beam_train_models( |
| | teacher_models: list[str] | None = None, |
| | enable_training: bool = True, |
| | pca_dims: int | None = None, |
| | clear_cache: bool = False, |
| | ) -> dict[str, Any]: |
| | """Beam function for training (distillation + tokenlearn).""" |
| | logger.info("βοΈ Running training on Beam") |
| | return _beam_distill_internal(teacher_models, enable_training, pca_dims, clear_cache) |
| |
|
| |
|
| | @function(**get_distillation_function_kwargs()) |
| | def _beam_distill_models( |
| | teacher_models: list[str] | None = None, |
| | enable_training: bool = False, |
| | pca_dims: int | None = None, |
| | clear_cache: bool = False, |
| | ) -> dict[str, Any]: |
| | """Beam function for basic distillation only.""" |
| | logger.info("βοΈ Running distillation on Beam") |
| | return _beam_distill_internal(teacher_models, enable_training, pca_dims, clear_cache) |
| |
|
| |
|
| | def run_beam_distillation( |
| | teacher_models: list[str] | None = None, |
| | enable_training: bool = False, |
| | pca_dims: int | None = None, |
| | clear_cache: bool = False, |
| | ) -> dict[str, Any]: |
| | """Run distillation on Beam and sync results.""" |
| | logger.info("βοΈ Running distillation on Beam with local sync") |
| |
|
| | try: |
| | |
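| | # Pick the Beam function whose resource configuration matches the requested mode (training vs. basic distillation) |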
| | beam_function = _beam_train_models if enable_training else _beam_distill_models |
| |
|
| | |
| | results = beam_function.remote(teacher_models, enable_training, pca_dims, clear_cache) |
| |
|
| | |
| | if not results: |
| | logger.error("β Beam execution failed or returned no results") |
| | return { |
| | "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), |
| | "enable_training": enable_training, |
| | "successful_models": [], |
| | "all_results": {}, |
| | "total_successful": 0, |
| | "total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS), |
| | "error": "Beam execution failed", |
| | } |
| |
|
| | |
| | if results.get("successful_models"): |
| | logger.info("π₯ Syncing models from Beam to local directories...") |
| |
|
| | for teacher_name in results["successful_models"]: |
| | |
| | base_dir = Path(LOCAL_BASE_DIR) / f"code_model2vec_{teacher_name}" |
| | sync_model_from_beam(teacher_name, str(base_dir), use_beam_utilities=True) |
| |
|
| | |
| | if enable_training: |
| | final_dir = Path(LOCAL_FINAL_DIR) / f"code_model2vec_{teacher_name}" |
| | sync_model_from_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities=True) |
| | else: |
| | |
| | copy_base_to_final(teacher_name, enable_training) |
| |
|
| | logger.info("β
All models synced from Beam") |
| |
|
| | return results |
| |
|
| | except Exception as e: |
| | logger.exception("β Beam distillation failed with exception") |
| | return { |
| | "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), |
| | "enable_training": enable_training, |
| | "successful_models": [], |
| | "all_results": {}, |
| | "total_successful": 0, |
| | "total_attempted": len(teacher_models or DEFAULT_TEACHER_MODELS), |
| | "error": str(e), |
| | } |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | def main( |
| | use_beam: Annotated[bool, typer.Option(help="Use Beam for distillation")] = False, |
| | train: Annotated[bool, typer.Option(help="Enable advanced training (CodeSearchNet fine-tuning)")] = False, |
| | teacher_models: Annotated[list[str] | None, typer.Option(help="Specific teacher models to distill")] = None, |
| | pca_dims: Annotated[int | None, typer.Option(help="PCA dimensions (uses config default if not specified)")] = None, |
| | clear_cache: Annotated[ |
| | bool, typer.Option(help="Clear HuggingFace cache for problematic models before distillation") |
| | ] = False, |
| | clear_checkpoints: Annotated[ |
| | bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training") |
| | ] = False, |
| | use_optimized_dataset: Annotated[ |
| | bool, |
| | typer.Option( |
| | "--use-optimized-dataset", help="Use the pre-created optimized dataset from code_model2vec/dataset" |
| | ), |
| | ] = False, |
| | dataset_path: Annotated[ |
| | str | None, |
| | typer.Option("--dataset-path", help="Path to custom dataset directory (defaults to code_model2vec/dataset)"), |
| | ] = None, |
| | ) -> None: |
| | """Unified distillation command with optional training.""" |
| | logger.info("π Starting unified Model2Vec distillation workflow") |
| |
|
| | |
| | distillation_config.use_optimized_dataset = use_optimized_dataset |
| | distillation_config.custom_dataset_path = dataset_path |
| |
|
| | if use_optimized_dataset and train: |
| | dataset_source = dataset_path or "code_model2vec/dataset" |
| | logger.info(f"π― Using optimized dataset from: {dataset_source}") |
| | elif train: |
| | logger.info("π― Using C4 dataset for training (following POTION approach)") |
| |
|
| | logger.info(f"π Training mode: {'Tokenlearn (POTION) training' if train else 'Basic distillation only'}") |
| | logger.info(f"βοΈ Execution: {'Beam' if use_beam else 'Local'}") |
| |
|
| | |
| | models_to_distill = teacher_models if teacher_models else DEFAULT_TEACHER_MODELS |
| |
|
| | logger.info(f"π Teacher models to process: {len(models_to_distill)}") |
| | for i, model in enumerate(models_to_distill, 1): |
| | logger.info(f" {i}. {model}") |
| |
|
| | |
| | if clear_cache: |
| | logger.info("π§Ή Clearing cache for known problematic models...") |
| | problematic_models = ["BAAI/bge-code-v1", "jinaai/jina-embeddings-v3", "Salesforce/SFR-Embedding-Code-2B_R"] |
| | for model in problematic_models: |
| | if model in models_to_distill: |
| | clear_model_cache(model) |
| |
|
| | |
| | if clear_checkpoints and train: |
| | logger.info("π§Ή Clearing tokenlearn checkpoints to force fresh featurization and training...") |
| | for teacher_model in models_to_distill: |
| |
|
| | |
| | teacher_slug = teacher_model.replace("/", "_").replace("-", "_") |
| | persistent_tokenlearn_dir = Path(LOCAL_BASE_DIR).parent / "tokenlearn_cache" / teacher_slug |
| |
|
| | features_dir = persistent_tokenlearn_dir / "features" |
| | trained_dir = persistent_tokenlearn_dir / "trained_model" |
| |
|
| | |
| | if features_dir.exists() or trained_dir.exists(): |
| | clear_tokenlearn_checkpoints(features_dir, trained_dir) |
| | logger.info(f"ποΈ Cleared persistent tokenlearn checkpoints for {teacher_model}") |
| | else: |
| | logger.info(f"βΉοΈ No tokenlearn checkpoints found for {teacher_model}") |
| | elif clear_checkpoints and not train: |
| | logger.warning("β οΈ --clear-checkpoints flag is only relevant when training is enabled (--train)") |
| |
|
| | |
| | if use_beam: |
| | results = run_beam_distillation( |
| | teacher_models=models_to_distill, |
| | enable_training=train, |
| | pca_dims=pca_dims, |
| | clear_cache=clear_cache, |
| | ) |
| | else: |
| | results = run_local_distillation( |
| | teacher_models=models_to_distill, |
| | enable_training=train, |
| | pca_dims=pca_dims, |
| | clear_cache=clear_cache, |
| | ) |
| |
|
| | |
| | if not results or not isinstance(results, dict): |
| | logger.error("β Distillation workflow failed - no valid results returned") |
| | results = { |
| | "total_successful": 0, |
| | "total_attempted": len(models_to_distill), |
| | "error": "Workflow failed", |
| | } |
| |
|
| | |
| | successful_count = results.get("total_successful", 0) |
| | total_attempted = results.get("total_attempted", 0) |
| |
|
| | logger.info("\nπ UNIFIED DISTILLATION WORKFLOW COMPLETED!") |
| | logger.info(f"π Successfully processed: {successful_count}/{total_attempted} models") |
| | logger.info(f"π Base models saved to: {LOCAL_BASE_DIR}") |
| | logger.info(f"π Final models saved to: {LOCAL_FINAL_DIR}") |
| |
|
| | if train: |
| | logger.info("π Advanced training was enabled - models include CodeSearchNet specialization") |
| | else: |
| | logger.info("π Basic distillation only - use --train flag to enable advanced training") |
| |
|
| |
|
| | def check_model_compatibility(teacher_model: str) -> tuple[bool, str | None]: |
| | """ |
| | Check if a model has known compatibility issues with Model2Vec. |
| | |
| | Returns: |
| | Tuple of (is_compatible, warning_message) |
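| |
| | Example (drawn from the known-incompatible list below): |
| | >>> check_model_compatibility("BAAI/bge-code-v1") |
| | (False, 'Qwen2Tokenizer lacks backend_tokenizer attribute') |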
| | """ |
| | known_incompatible = { |
| | "BAAI/bge-code-v1": "Qwen2Tokenizer lacks backend_tokenizer attribute", |
| | "jinaai/jina-embeddings-v3": "Missing custom transformers module dependencies", |
| | "Salesforce/SFR-Embedding-Code-2B_R": "Device placement issues with meta tensors", |
| | } |
| |
|
| | if teacher_model in known_incompatible: |
| | return False, known_incompatible[teacher_model] |
| |
|
| | |
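| | # Name-based heuristics for model families that frequently fail Model2Vec distillation. |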
| | if "qwen2" in teacher_model.lower() and "bge" in teacher_model.lower(): |
| | return False, "BGE models with Qwen2 tokenizers may have compatibility issues" |
| |
|
| | if "jina" in teacher_model.lower() and "embeddings-v3" in teacher_model.lower(): |
| | return False, "Jina embeddings v3 models may have missing dependencies" |
| |
|
| | if "salesforce" in teacher_model.lower() and "sfr-embedding" in teacher_model.lower(): |
| | return False, "Salesforce SFR embedding models may have device placement issues" |
| |
|
| | return True, None |
| |
|
| |
|
| | def clear_model_cache(model_name: str) -> bool: |
| | """Clear HuggingFace cache for a specific model.""" |
| | try: |
| | import shutil |
| | from pathlib import Path |
| |
|
| | |
| | cache_dir = Path.home() / ".cache" / "huggingface" |
| |
|
| | |
| | model_slug = model_name.replace("/", "--") |
| |
|
| | |
| | transformers_cache = cache_dir / "transformers" / model_slug |
| | if transformers_cache.exists(): |
| | shutil.rmtree(transformers_cache) |
| | logger.info(f"ποΈ Cleared transformers cache for {model_name}") |
| |
|
| | |
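| | # The hub cache stores snapshots under models--<org>--<name>. |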
| | hub_cache = cache_dir / "hub" / f"models--{model_slug}" |
| | if hub_cache.exists(): |
| | shutil.rmtree(hub_cache) |
| | logger.info(f"ποΈ Cleared hub cache for {model_name}") |
| |
|
| | |
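| | # trust_remote_code modules are cached per organization, so this clears the whole org's module cache. |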
| | modules_cache = cache_dir / "modules" / "transformers_modules" / model_name.split("/")[0] |
| | if modules_cache.exists(): |
| | shutil.rmtree(modules_cache) |
| | logger.info(f"ποΈ Cleared modules cache for {model_name}") |
| |
|
| | return True |
| |
|
| | except Exception as e: |
| | logger.warning(f"Failed to clear cache for {model_name}: {e}") |
| | return False |
| |
|
| |
|
| | def try_model_workarounds(teacher_model: str) -> str | None: |
| | """ |
| | Try specific workarounds for problematic models. |
| | |
| | Returns: |
| | The type of workaround needed ("salesforce", "baai", etc.) or None if no workaround available |
| | """ |
| | if "salesforce" in teacher_model.lower() and "sfr-embedding" in teacher_model.lower(): |
| | logger.info("π§ Salesforce SFR model detected - will use specialized distillation") |
| | return "salesforce" |
| |
|
| | if "baai" in teacher_model.lower() and ("bge-code" in teacher_model.lower() or "bge-m3" in teacher_model.lower()): |
| | logger.info("π§ BAAI BGE model detected - will use specialized distillation") |
| | return "baai" |
| |
|
| | return None |
| |
|
| |
|
| | def salesforce_model_distillation( |
| | teacher_model: str, |
| | output_dir: str, |
| | pca_dims: int | None = None, |
| | ) -> Any: |
| | """Special distillation function for Salesforce SFR models that handles device placement issues.""" |
| | if pca_dims is None: |
| | pca_dims = int(distillation_config.optimal_pca_dims) |
| |
|
| | output_path = Path(output_dir) |
| | output_path.mkdir(parents=True, exist_ok=True) |
| |
|
| | logger.info(f"π Salesforce-specific distillation: {teacher_model} β {output_dir}") |
| | logger.info(f"π PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}") |
| |
|
| | start_time = time.time() |
| |
|
| | try: |
| | import torch |
| | from transformers import AutoModel, AutoTokenizer |
| |
|
| | |
| | logger.info("π§ Loading model with enhanced device settings...") |
| |
|
| | |
| | try: |
| | logger.info("π Attempting with to_empty() method...") |
| |
|
| | |
| | tokenizer = AutoTokenizer.from_pretrained(teacher_model, trust_remote_code=True) |
| |
|
| | |
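| | # Load onto the meta device so the full checkpoint is not materialized during init; storage is |
| | # allocated on the target device below via to_empty(), which (per PyTorch) does not copy weight values. |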
| | model = AutoModel.from_pretrained( |
| | teacher_model, |
| | trust_remote_code=True, |
| | torch_dtype=torch.float16, |
| | device_map="meta", |
| | ) |
| |
|
| | |
| | if torch.cuda.is_available(): |
| | device = torch.device("cuda") |
| | |
| | model = model.to_empty(device=device) |
| | else: |
| | device = torch.device("cpu") |
| | model = model.to_empty(device=device) |
| |
|
| | |
| | model = model.to(torch.float16 if torch.cuda.is_available() else torch.float32) |
| |
|
| | logger.info("✅ Successfully loaded with to_empty() method") |
| |
|
| | except Exception as e: |
| | logger.warning(f"to_empty() method failed: {e}") |
| |
|
| | |
| | logger.info("π Falling back to SentenceTransformer method...") |
| | sentence_model = load_model_with_flash_attention( |
| | teacher_model, |
| | device="cpu", |
| | ) |
| |
|
| | |
| | if torch.cuda.is_available(): |
| | sentence_model = sentence_model.to("cuda") |
| |
|
| | |
| | model = sentence_model[0].auto_model |
| | tokenizer = sentence_model.tokenizer |
| |
|
| | logger.info("✅ Successfully loaded with SentenceTransformer method") |
| |
|
| | |
| | from distiller.model2vec.distill.distillation import distill_from_model |
| |
|
| | distilled_model = distill_from_model( |
| | model=model, |
| | tokenizer=tokenizer, |
| | pca_dims=int(pca_dims), |
| | apply_zipf=bool(distillation_config.apply_zipf), |
| | sif_coefficient=float(distillation_config.sif_coefficient), |
| | ) |
| |
|
| | logger.info("✅ Core distillation completed successfully") |
| |
|
| | |
| | distilled_model.save_pretrained(str(output_path)) |
| | logger.info(f"πΎ Model saved to {output_path}") |
| |
|
| | |
| | logger.info(f"Model type: {type(distilled_model)}") |
| | if hasattr(distilled_model, "embedding"): |
| | logger.info(f"Embedding shape: {distilled_model.embedding.shape}") |
| | logger.info(f"Embedding dtype: {distilled_model.embedding.dtype}") |
| |
|
| | total_time = time.time() - start_time |
| | logger.info(f"π Salesforce distillation completed in {total_time:.2f} seconds") |
| |
|
| | |
| | if "sentence_model" in locals(): |
| | del sentence_model |
| | del model |
| | if torch.cuda.is_available(): |
| | torch.cuda.empty_cache() |
| |
|
| | return distilled_model |
| |
|
| | except Exception: |
| | logger.exception(f"β Salesforce-specific distillation failed for {teacher_model}") |
| | return None |
| |
|
| |
|
| | def baai_bge_model_distillation( |
| | teacher_model: str, |
| | output_dir: str, |
| | pca_dims: int | None = None, |
| | ) -> Any: |
| | """Special distillation function for BAAI BGE models that handles Qwen2Tokenizer compatibility issues.""" |
| | if pca_dims is None: |
| | pca_dims = int(distillation_config.optimal_pca_dims) |
| |
|
| | output_path = Path(output_dir) |
| | output_path.mkdir(parents=True, exist_ok=True) |
| |
|
| | logger.info(f"π BAAI BGE-specific distillation: {teacher_model} β {output_dir}") |
| | logger.info(f"π PCA dims: {pca_dims}, SIF: {distillation_config.sif_coefficient}") |
| |
|
| | start_time = time.time() |
| |
|
| | try: |
| | import torch |
| | from transformers import AutoModel, AutoTokenizer |
| |
|
| | logger.info("π§ Loading BAAI model with tokenizer workaround...") |
| |
|
| | |
| | success = False |
| |
|
| | |
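| | # First attempt: load via SentenceTransformer, which may sidestep the Qwen2Tokenizer backend_tokenizer issue flagged in check_model_compatibility. |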
| | try: |
| | logger.info("π Attempting with SentenceTransformer wrapper...") |
| | sentence_model = load_model_with_flash_attention(teacher_model) |
| |
|
| | |
| | model = sentence_model[0].auto_model |
| | tokenizer = sentence_model.tokenizer |
| |
|
| | |
| | # Quick smoke test: make sure the tokenizer actually works before distilling |
| | tokenizer.encode("test", return_tensors="pt") |
| | logger.info("✅ SentenceTransformer method successful") |
| | success = True |
| |
|
| | except Exception as e: |
| | logger.warning(f"SentenceTransformer method failed: {e}") |
| |
|
| | |
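| | # Second attempt: load the model directly and, if its own tokenizer cannot be constructed, fall back to a generic BERT tokenizer. |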
| | try: |
| | logger.info("π Attempting with tokenizer replacement...") |
| | from transformers import BertTokenizerFast |
| |
|
| | |
| | model = AutoModel.from_pretrained(teacher_model, trust_remote_code=True) |
| |
|
| | |
| | try: |
| | |
| | tokenizer = AutoTokenizer.from_pretrained(teacher_model, trust_remote_code=True) |
| | except Exception: |
| | |
| | logger.info("π Falling back to BERT tokenizer...") |
| | tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") |
| |
|
| | logger.info("✅ Tokenizer replacement method successful") |
| | success = True |
| |
|
| | except Exception as e2: |
| | logger.warning(f"Tokenizer replacement method failed: {e2}") |
| |
|
| | if not success: |
| | logger.error("β All BAAI model loading methods failed") |
| | return None |
| |
|
| | |
| | from distiller.model2vec.distill.distillation import distill_from_model |
| |
|
| | distilled_model = distill_from_model( |
| | model=model, |
| | tokenizer=tokenizer, |
| | pca_dims=int(pca_dims), |
| | apply_zipf=bool(distillation_config.apply_zipf), |
| | sif_coefficient=float(distillation_config.sif_coefficient), |
| | ) |
| |
|
| | logger.info("✅ Core distillation completed successfully") |
| |
|
| | |
| | distilled_model.save_pretrained(str(output_path)) |
| | logger.info(f"πΎ Model saved to {output_path}") |
| |
|
| | |
| | logger.info(f"Model type: {type(distilled_model)}") |
| | if hasattr(distilled_model, "embedding"): |
| | logger.info(f"Embedding shape: {distilled_model.embedding.shape}") |
| | logger.info(f"Embedding dtype: {distilled_model.embedding.dtype}") |
| |
|
| | total_time = time.time() - start_time |
| | logger.info(f"π BAAI BGE distillation completed in {total_time:.2f} seconds") |
| |
|
| | |
| | if "sentence_model" in locals(): |
| | del sentence_model |
| | del model |
| | if torch.cuda.is_available(): |
| | torch.cuda.empty_cache() |
| |
|
| | return distilled_model |
| |
|
| | except Exception: |
| | logger.exception(f"β BAAI BGE-specific distillation failed for {teacher_model}") |
| | return None |
| |
|
| |
|
| | def clear_tokenlearn_checkpoints(features_dir: Path, trained_dir: Path) -> None: |
| | """Clear tokenlearn checkpoint markers to force re-execution of steps.""" |
| | featurization_marker = features_dir / ".featurization_complete" |
| | training_marker = trained_dir / ".training_complete" |
| |
|
| | if featurization_marker.exists(): |
| | featurization_marker.unlink() |
| | logger.info(f"ποΈ Cleared featurization checkpoint: {featurization_marker}") |
| |
|
| | if training_marker.exists(): |
| | training_marker.unlink() |
| | logger.info(f"ποΈ Cleared training checkpoint: {training_marker}") |
| |
|
| |
|
| | def verify_featurization_output(features_dir: Path) -> bool: |
| | """Verify that featurization output files actually exist and are valid.""" |
| | if not features_dir.exists(): |
| | return False |
| |
|
| | |
| |
|
| | |
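| | # A non-empty glob for any of these artifact types counts as completed featurization; the list() wrapper is needed because an empty generator would still be truthy. |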
| | return any(list(features_dir.glob(file_pattern)) for file_pattern in ["*.npy", "*.json", "*.pt", "*.pkl"]) |
| |
|
| |
|
| | def verify_training_output(trained_dir: Path) -> bool: |
| | """Verify that training output files actually exist and are valid.""" |
| | if not trained_dir.exists(): |
| | return False |
| |
|
| | |
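| | # Accept either a flat layout with model files at the top level or the model/ and model_weighted/ sub-layouts checked below. |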
| | model_files = ["config.json", "model.safetensors", "modules.json", "tokenizer.json"] |
| | for model_file in model_files: |
| | if (trained_dir / model_file).exists(): |
| | return True |
| |
|
| | |
| | for subdir in ["model", "model_weighted"]: |
| | subdir_path = trained_dir / subdir |
| | if subdir_path.exists(): |
| | for model_file in model_files: |
| | if (subdir_path / model_file).exists(): |
| | return True |
| |
|
| | return False |
| |
|
| |
|
| | def _prepare_tokenlearn_dataset(tokenlearn_dir: Path) -> tuple[str, str | None, str]: |
| | """ |
| | Prepare dataset for tokenlearn featurization. |
| | |
| | Returns: |
| | Tuple of (dataset_path, dataset_name, text_key) for tokenlearn |
| | """ |
| | if distillation_config.use_optimized_dataset: |
| | return _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir) |
| | return _prepare_original_dataset_for_tokenlearn() |
| |
|
| |
|
| | def _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir: Path) -> tuple[str, str | None, str]: |
| | """Prepare custom optimized dataset for tokenlearn featurization.""" |
| | logger.info("π― Preparing custom optimized dataset for tokenlearn...") |
| |
|
| | |
| | from .dataset import create_optimized_dataset, load_optimized_dataset |
| |
|
| | |
| | custom_dataset_dir = ( |
| | Path(distillation_config.custom_dataset_path) |
| | if distillation_config.custom_dataset_path |
| | else Path("code_model2vec/dataset") |
| | ) |
| | tokenlearn_dataset_dir = tokenlearn_dir / "custom_dataset" |
| |
|
| | |
| | if not custom_dataset_dir.exists() or not (custom_dataset_dir / "train.parquet").exists(): |
| | logger.info("π Custom dataset not found - creating optimized dataset...") |
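| | # Split the sample budget evenly; the divisor presumably matches the six CodeSearchNet languages. |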
| | create_optimized_dataset( |
| | max_samples_per_lang=distillation_config.tokenlearn_max_samples // 6, |
| | output_dir=custom_dataset_dir, |
| | create_multiple_formats=False, |
| | ) |
| |
|
| | |
| | logger.info(f"π Loading custom dataset from {custom_dataset_dir}") |
| | train_df = load_optimized_dataset(output_dir=custom_dataset_dir, split="train") |
| |
|
| | |
| | tokenlearn_dataset_dir.mkdir(parents=True, exist_ok=True) |
| |
|
| | |
| | train_json_path = tokenlearn_dataset_dir / "train.json" |
| |
|
| | |
| |
|
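| | # The dataset is written as JSON-lines keyed by "text" (matching the text_key returned below), e.g.: |
| | # {"text": "def add(a, b):\n    return a + b"} |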
| | with train_json_path.open("w") as f: |
| | for text in train_df["text"]: |
| | json.dump({"text": text}, f) |
| | f.write("\n") |
| |
|
| | logger.info(f"✅ Prepared custom dataset with {len(train_df)} samples for tokenlearn") |
| | logger.info(f"πΎ Saved JSON dataset to {train_json_path}") |
| |
|
| | |
| | return str(train_json_path), None, "text" |
| |
|
| |
|
| | def _prepare_original_dataset_for_tokenlearn() -> tuple[str, str | None, str]: |
| | """Prepare original dataset for tokenlearn featurization (uses C4 by default following POTION approach).""" |
| | logger.info("π Using C4 dataset for tokenlearn (following POTION approach)...") |
| | return ( |
| | str(distillation_config.tokenlearn_dataset), |
| | str(distillation_config.tokenlearn_dataset_name), |
| | str(distillation_config.tokenlearn_text_key), |
| | ) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | typer.run(main) |
| |
|