# Commit 454e47c (Sarthak): feat: overhaul distiller package with unified CLI,
# enhanced evaluation, and modular structure
| """ | |
| Common utilities for the distiller package. | |
| This module provides shared functionality used across multiple components | |
| including model discovery, result management, and initialization helpers. | |
| """ | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from types import TracebackType | |
| from typing import Any | |
| from .beam_utils import ( | |
| BeamCheckpointManager, | |
| BeamEvaluationManager, | |
| BeamModelManager, | |
| BeamVolumeManager, | |
| create_beam_utilities, | |
| ) | |
| from .config import VolumeConfig, get_safe_model_name, get_volume_config, setup_logging | |
| logger = logging.getLogger(__name__) | |
| # ============================================================================= | |
| # BEAM UTILITIES MANAGEMENT | |
| # ============================================================================= | |
class BeamContext:
    """Context manager for Beam utilities with consistent initialization.

    Entering the context builds the four Beam managers for the configured
    volume and yields them as a tuple; exiting only logs the outcome and
    never suppresses exceptions.
    """

    def __init__(self, workflow: str, volume_config: VolumeConfig | None = None) -> None:
        """
        Initialize Beam context.

        Args:
            workflow: Workflow type (distill, evaluate, benchmark, etc.)
            volume_config: Optional custom volume config, otherwise inferred from workflow
        """
        self.workflow = workflow
        self.volume_config = volume_config or get_volume_config()
        # Managers are created lazily in __enter__; None until then.
        self.volume_manager: BeamVolumeManager | None = None
        self.checkpoint_manager: BeamCheckpointManager | None = None
        self.model_manager: BeamModelManager | None = None
        self.evaluation_manager: BeamEvaluationManager | None = None

    def __enter__(self) -> tuple[BeamVolumeManager, BeamCheckpointManager, BeamModelManager, BeamEvaluationManager]:
        """Enter context and initialize utilities."""
        logger.info(f"🚀 Initializing Beam utilities for {self.workflow}")
        logger.info(f"📁 Volume: {self.volume_config.name} at {self.volume_config.mount_path}")
        (
            self.volume_manager,
            self.checkpoint_manager,
            self.model_manager,
            self.evaluation_manager,
        ) = create_beam_utilities(self.volume_config.name, self.volume_config.mount_path)
        return (
            self.volume_manager,
            self.checkpoint_manager,
            self.model_manager,
            self.evaluation_manager,
        )

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit context, logging success or the propagating error."""
        if exc_type is None:
            logger.info(f"✅ Beam context for {self.workflow} completed successfully")
        else:
            logger.error(f"❌ Error in Beam context for {self.workflow}: {exc_val}")
def get_beam_utilities() -> tuple[BeamVolumeManager, BeamCheckpointManager, BeamModelManager, BeamEvaluationManager]:
    """
    Get Beam utilities for a specific workflow.

    Returns:
        Tuple of (volume_manager, checkpoint_manager, model_manager, evaluation_manager)
    """
    cfg = get_volume_config()
    return create_beam_utilities(cfg.name, cfg.mount_path)


# =============================================================================
# MODEL DISCOVERY
# =============================================================================
def discover_simplified_models(base_path: str | Path = ".") -> list[str]:
    """
    Discover simplified distillation models in the specified directory.

    A directory counts as a model when it contains a ``config.json`` or a
    ``model.safetensors`` file.

    Args:
        base_path: Base path to search for models

    Returns:
        List of model paths, deduplicated and sorted alphabetically
    """
    base = Path(base_path)

    # Look for models in common locations
    search_patterns = [
        "code_model2vec/final/**/",
        "final/**/",
        "code_model2vec_*/",
        "*/config.json",
        "*.safetensors",
    ]

    discovered_models = []
    for pattern in search_patterns:
        for match in base.glob(pattern):
            if match.is_dir():
                # Check if it's a valid model directory
                if (match / "config.json").exists() or (match / "model.safetensors").exists():
                    discovered_models.append(str(match))
            elif match.name == "config.json" or match.suffix == ".safetensors":
                # Register the containing directory when a marker file is found.
                # Fix: the "*.safetensors" pattern previously matched files that
                # were then silently dropped by the config.json-only branch.
                discovered_models.append(str(match.parent))

    # Remove duplicates and sort
    unique_models = sorted(set(discovered_models))

    logger.info(f"🔍 Discovered {len(unique_models)} models in {base_path}")
    for model in unique_models:
        logger.info(f" 📁 {model}")

    return unique_models
def validate_model_path(model_path: str | Path, volume_manager: BeamVolumeManager | None = None) -> str | None:
    """
    Validate and resolve model path, checking local filesystem and Beam volumes.

    Args:
        model_path: Path to model (can be local path or HuggingFace model name)
        volume_manager: Optional volume manager for Beam volume checks

    Returns:
        Resolved model path or None if not found
    """
    raw = str(model_path)
    path = Path(model_path)

    # A name like "org/model" that is neither absolute nor present locally is
    # treated as a HuggingFace hub identifier.
    if "/" in raw and not raw.startswith("/") and not path.exists():
        logger.info(f"📥 Treating as HuggingFace model: {model_path}")
        return raw

    # Check local filesystem
    if path.exists():
        logger.info(f"✅ Found local model: {model_path}")
        return str(path)

    # Check Beam volume if available
    if volume_manager is not None:
        mount = Path(volume_manager.mount_path)
        candidate = mount / path.name
        if candidate.exists():
            logger.info(f"✅ Found model in Beam volume: {candidate}")
            return str(candidate)
        # Fall back to the volume root itself holding a model
        if (mount / "config.json").exists():
            logger.info(f"✅ Found model in Beam volume root: {mount}")
            return str(mount)

    logger.warning(f"⚠️ Model not found: {model_path}")
    return None


# =============================================================================
# RESULT MANAGEMENT
# =============================================================================
def _write_json_file(directory: Path, filename: str, payload: dict[str, Any]) -> Path:
    """Create *directory* if needed and write *payload* as pretty-printed JSON.

    Returns the path written. ``default=str`` keeps non-serializable values
    (e.g. Paths, datetimes) from raising.
    """
    directory.mkdir(parents=True, exist_ok=True)
    filepath = directory / filename
    with filepath.open("w") as f:
        json.dump(payload, f, indent=2, default=str)
    return filepath


def save_results_with_backup(
    results: dict[str, Any],
    primary_path: str | Path,
    model_name: str,
    result_type: str = "evaluation",
    volume_manager: BeamVolumeManager | None = None,
    evaluation_manager: BeamEvaluationManager | None = None,
) -> bool:
    """
    Save results with multiple backup strategies.

    Each save target is attempted independently; a failure in one is logged
    and does not prevent the others.

    Args:
        results: Results dictionary to save
        primary_path: Primary save location
        model_name: Model name for filename generation
        result_type: Type of results (evaluation, benchmark, etc.)
        volume_manager: Optional volume manager for Beam storage
        evaluation_manager: Optional evaluation manager for specialized storage

    Returns:
        True if saved successfully to at least one location
    """
    success_count = 0
    safe_name = get_safe_model_name(model_name)
    filename = f"{result_type}_{safe_name}.json"

    # Save to primary location
    try:
        filepath = _write_json_file(Path(primary_path), filename, results)
        logger.info(f"💾 Saved {result_type} results to: {filepath}")
        success_count += 1
    except Exception as e:
        logger.warning(f"⚠️ Failed to save to primary location: {e}")

    # Save to Beam volume if available
    if volume_manager:
        try:
            volume_dir = Path(volume_manager.mount_path) / f"{result_type}_results"
            filepath = _write_json_file(volume_dir, filename, results)
            logger.info(f"💾 Saved {result_type} results to Beam volume: {filepath}")
            success_count += 1
        except Exception as e:
            logger.warning(f"⚠️ Failed to save to Beam volume: {e}")

    # Save via evaluation manager if available and appropriate
    if evaluation_manager and result_type == "evaluation":
        try:
            if evaluation_manager.save_evaluation_results(model_name, results):
                logger.info(f"💾 Saved via evaluation manager for {model_name}")
                success_count += 1
        except Exception as e:
            logger.warning(f"⚠️ Failed to save via evaluation manager: {e}")

    return success_count > 0
def load_existing_results(
    model_name: str,
    result_type: str = "evaluation",
    search_paths: list[str | Path] | None = None,
    volume_manager: BeamVolumeManager | None = None,
    evaluation_manager: BeamEvaluationManager | None = None,
) -> dict[str, Any] | None:
    """
    Load existing results from multiple possible locations.

    Locations are tried in order: caller-supplied search paths, the Beam
    volume (when a manager is given), then the evaluation manager (only for
    ``result_type == "evaluation"``). The first successful load wins.

    Args:
        model_name: Model name to search for
        result_type: Type of results to load
        search_paths: Additional paths to search
        volume_manager: Optional volume manager
        evaluation_manager: Optional evaluation manager

    Returns:
        Results dictionary if found, None otherwise
    """
    filename = f"{result_type}_{get_safe_model_name(model_name)}.json"

    # Search in provided paths
    for directory in search_paths or []:
        candidate = Path(directory) / filename
        if not candidate.exists():
            continue
        try:
            with candidate.open("r") as f:
                loaded = json.load(f)
        except Exception as e:
            logger.warning(f"⚠️ Failed to load from {candidate}: {e}")
        else:
            logger.info(f"📂 Loaded existing {result_type} results from: {candidate}")
            return loaded

    # Search in Beam volume
    if volume_manager:
        candidate = Path(volume_manager.mount_path) / f"{result_type}_results" / filename
        if candidate.exists():
            try:
                with candidate.open("r") as f:
                    loaded = json.load(f)
            except Exception as e:
                logger.warning(f"⚠️ Failed to load from Beam volume: {e}")
            else:
                logger.info(f"📂 Loaded existing {result_type} results from Beam volume: {candidate}")
                return loaded

    # Try evaluation manager
    if evaluation_manager and result_type == "evaluation":
        try:
            loaded = evaluation_manager.load_evaluation_results(model_name)
        except Exception as e:
            logger.warning(f"⚠️ Failed to load via evaluation manager: {e}")
        else:
            if loaded:
                logger.info(f"📂 Loaded existing {result_type} results via evaluation manager")
                return loaded

    logger.info(f"ℹ️ No existing {result_type} results found for {model_name}")
    return None


# =============================================================================
# WORKFLOW HELPERS
# =============================================================================
def print_workflow_summary(
    workflow_name: str,
    total_items: int,
    processed_items: int,
    skipped_items: int,
    execution_time: float | None = None,
) -> None:
    """Print a standardized workflow summary.

    Args:
        workflow_name: Human-readable name of the completed workflow.
        total_items: Total number of items considered.
        processed_items: Items newly processed in this run.
        skipped_items: Items skipped because results already existed.
        execution_time: Optional wall-clock duration in seconds.
    """
    logger.info(f"\n✅ {workflow_name} complete!")
    logger.info(f"📊 Total items: {total_items}")
    logger.info(f"✨ Newly processed: {processed_items}")
    logger.info(f"⏭️ Skipped (already done): {skipped_items}")
    # Fix: compare against None so a legitimate 0.0-second duration is still
    # reported (the previous truthiness check silently dropped it).
    if execution_time is not None:
        logger.info(f"⏱️ Execution time: {execution_time:.2f} seconds")
def check_existing_results(
    items: list[str],
    result_type: str,
    search_paths: list[str | Path] | None = None,
    volume_manager: BeamVolumeManager | None = None,
) -> tuple[list[str], list[str]]:
    """
    Check which items already have results and which need processing.

    Args:
        items: List of items (model names, etc.) to check
        result_type: Type of results to check for
        search_paths: Paths to search for existing results
        volume_manager: Optional volume manager

    Returns:
        Tuple of (items_to_process, items_to_skip)
    """
    pending: list[str] = []
    done: list[str] = []
    for item in items:
        # A truthy result dict means this item was already handled.
        found = load_existing_results(item, result_type, search_paths, volume_manager)
        (done if found else pending).append(item)
    return pending, done


# =============================================================================
# INITIALIZATION
# =============================================================================
def initialize_distiller_logging(level: int = logging.INFO) -> None:
    """Initialize logging for the distiller package.

    Delegates handler/format configuration to ``setup_logging`` and then
    emits a startup message at INFO level.
    """
    setup_logging(level)
    logger.info("🚀 Distiller package initialized")
# Ensure logging is set up when module is imported
# NOTE(review): deliberate import-time side effect — importing this module
# configures package-wide logging.
initialize_distiller_logging()