# codemalt/src/distiller/beam_utils.py
# refactor(beam-utils): use direct file operations for beam volumes
# (commit 12d70ca, author: Sarthak)
"""
Beam Cloud Utilities for Model Distillation and Evaluation.
This module provides comprehensive utilities for managing Beam volumes, checkpoints,
and file operations across distillation, evaluation, and analysis workflows.
Features:
- Volume management (direct file operations when mounted)
- Checkpoint operations (save, load, cleanup, resume)
- File transfer utilities (copy, move, sync)
- Evaluation result management
- Model artifact handling
- Distributed storage optimization
"""
# ruff: noqa: S603, S607, PLW1510
import json
import logging
import shutil
import subprocess
import time
from pathlib import Path
from typing import Any
# Configure logging
logger = logging.getLogger(__name__)
def _is_running_on_beam() -> bool:
"""
Detect if we're running on Beam platform or locally.
On Beam, volumes are mounted as directories. Locally, we need to use beam CLI.
"""
import os
# Check for Beam environment variables
beam_env_vars = [
"BEAM_TASK_ID",
"BEAM_FUNCTION_ID",
"BEAM_RUN_ID",
"BEAM_JOB_ID",
"BEAM_CONTAINER_ID",
]
for env_var in beam_env_vars:
if os.environ.get(env_var):
return True
# Check for common Beam mount paths
beam_mount_paths = [
"/volumes", # Common Beam volume mount
"/mnt/beam",
"/var/beam",
"/beam",
]
return any(Path(mount_path).exists() for mount_path in beam_mount_paths)
def _check_beam_cli_available() -> bool:
"""
Check if beam CLI is available for local file operations.
Returns:
True if beam CLI is available, False otherwise
"""
try:
result = subprocess.run(["beam", "--version"], capture_output=True, text=True, timeout=10)
return result.returncode == 0
except (FileNotFoundError, subprocess.TimeoutExpired):
return False
class BeamVolumeManager:
"""Manager for Beam distributed storage volumes using direct file operations."""
def __init__(self, volume_name: str, mount_path: str = "./data") -> None:
"""
Initialize Beam Volume Manager.
Args:
volume_name: Name of the Beam volume
mount_path: Local mount path for the volume (should match Beam function mount path)
"""
self.volume_name = volume_name
self.mount_path = Path(mount_path)
self.mount_path.mkdir(parents=True, exist_ok=True)
def exists(self) -> bool:
"""Check if the volume mount path exists."""
return self.mount_path.exists()
def list_contents(self, subpath: str = "") -> list[dict[str, Any]]:
"""List contents of the volume directory."""
try:
target_path = self.mount_path / subpath if subpath else self.mount_path
if not target_path.exists():
logger.warning(f"⚠️ Path does not exist: {target_path}")
return []
contents: list[dict[str, Any]] = []
for item in target_path.iterdir():
stat = item.stat()
contents.append(
{
"name": item.name,
"size": f"{stat.st_size / (1024 * 1024):.2f}MB" if item.is_file() else "0MB",
"modified": time.ctime(stat.st_mtime),
"is_dir": item.is_dir(),
"path": str(item.relative_to(self.mount_path)),
}
)
return sorted(contents, key=lambda x: (not x["is_dir"], x["name"]))
except Exception:
logger.exception("❌ Error listing contents")
return []
def copy_file(self, src: str | Path, dst: str | Path) -> bool:
"""Copy a file within the volume."""
try:
src_path = self.mount_path / src if not Path(src).is_absolute() else Path(src)
dst_path = self.mount_path / dst if not Path(dst).is_absolute() else Path(dst)
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(src_path, dst_path)
logger.info(f"📄 Copied {src_path.name} to {dst_path}")
return True
except Exception:
logger.exception("❌ Error copying file")
return False
def copy_directory(self, src: str | Path, dst: str | Path) -> bool:
"""Copy a directory within the volume."""
try:
src_path = self.mount_path / src if not Path(src).is_absolute() else Path(src)
dst_path = self.mount_path / dst if not Path(dst).is_absolute() else Path(dst)
if dst_path.exists():
shutil.rmtree(dst_path)
shutil.copytree(src_path, dst_path)
logger.info(f"📁 Copied directory {src_path.name} to {dst_path}")
return True
except Exception:
logger.exception("❌ Error copying directory")
return False
def move_file(self, src: str | Path, dst: str | Path) -> bool:
"""Move a file within the volume."""
try:
src_path = self.mount_path / src if not Path(src).is_absolute() else Path(src)
dst_path = self.mount_path / dst if not Path(dst).is_absolute() else Path(dst)
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(src_path), str(dst_path))
logger.info(f"➡️ Moved {src_path.name} to {dst_path}")
return True
except Exception:
logger.exception("❌ Error moving file")
return False
def remove_file(self, file_path: str | Path) -> bool:
"""Remove a file from the volume."""
try:
target_path = self.mount_path / file_path if not Path(file_path).is_absolute() else Path(file_path)
if target_path.exists():
if target_path.is_file():
target_path.unlink()
logger.info(f"🗑️ Removed file: {target_path.name}")
else:
logger.warning(f"⚠️ Path is not a file: {target_path}")
return False
return True
logger.warning(f"⚠️ File does not exist: {target_path}")
return False
except Exception:
logger.exception("❌ Error removing file")
return False
def remove_directory(self, dir_path: str | Path) -> bool:
"""Remove a directory from the volume."""
try:
target_path = self.mount_path / dir_path if not Path(dir_path).is_absolute() else Path(dir_path)
if target_path.exists() and target_path.is_dir():
shutil.rmtree(target_path)
logger.info(f"🗑️ Removed directory: {target_path.name}")
return True
logger.warning(f"⚠️ Directory does not exist: {target_path}")
return False
except Exception:
logger.exception("❌ Error removing directory")
return False
def cleanup_old_files(self, pattern: str = "*", older_than_days: int = 7, subpath: str = "") -> list[str]:
"""Clean up old files in the volume based on age and pattern."""
try:
target_path = self.mount_path / subpath if subpath else self.mount_path
if not target_path.exists():
return []
cutoff_time = time.time() - (older_than_days * 24 * 3600)
removed_files: list[str] = []
for item in target_path.rglob(pattern):
if item.is_file() and item.stat().st_mtime < cutoff_time:
try:
item.unlink()
removed_files.append(str(item.relative_to(self.mount_path)))
logger.info(f"🧹 Removed old file: {item.name}")
except Exception as e:
logger.warning(f"⚠️ Could not remove {item.name}: {e}")
if removed_files:
logger.info(f"🧹 Cleaned up {len(removed_files)} old files")
return removed_files
except Exception:
logger.exception("❌ Error during cleanup")
return []
def get_size(self, subpath: str = "") -> int:
"""Get total size of volume or subpath in bytes."""
try:
target_path = self.mount_path / subpath if subpath else self.mount_path
if not target_path.exists():
return 0
total_size = 0
for item in target_path.rglob("*"):
if item.is_file():
total_size += item.stat().st_size
return total_size
except Exception:
logger.exception("❌ Error calculating size")
return 0
class BeamCheckpointManager:
    """Manager for checkpoint operations on Beam volumes with stage-based organization.

    Checkpoints are JSON files named ``{checkpoint_prefix}_{stage}_step_{N}.json``
    stored inside a per-stage subdirectory of ``{mount_path}/{checkpoint_prefix}``.
    """

    def __init__(self, volume_manager: BeamVolumeManager, checkpoint_prefix: str = "checkpoints") -> None:
        """
        Initialize Checkpoint Manager.

        Args:
            volume_manager: BeamVolumeManager instance
            checkpoint_prefix: Prefix for checkpoint files
        """
        self.volume = volume_manager
        self.checkpoint_prefix = checkpoint_prefix
        self.checkpoint_base_dir = self.volume.mount_path / checkpoint_prefix
        self.checkpoint_base_dir.mkdir(parents=True, exist_ok=True)

    def _get_stage_dir(self, stage: str) -> Path:
        """Get (creating if needed) the stage-specific checkpoint directory."""
        stage_dir = self.checkpoint_base_dir / stage
        stage_dir.mkdir(parents=True, exist_ok=True)
        return stage_dir

    def _checkpoint_filename(self, stage: str, step: int) -> str:
        """Build the canonical checkpoint filename for a stage/step pair."""
        return f"{self.checkpoint_prefix}_{stage}_step_{step}.json"

    def _step_from_stem(self, stage: str, stem: str) -> int | None:
        """
        Extract the step number from a checkpoint file stem, or None if unparsable.

        Prefix-based parsing is used (not ``split("_")`` indexing) so that
        stage names or prefixes containing underscores are handled correctly.
        """
        expected = f"{self.checkpoint_prefix}_{stage}_step_"
        if not stem.startswith(expected):
            return None
        try:
            return int(stem[len(expected):])
        except ValueError:
            return None

    def _find_stage_checkpoints(self, stage: str) -> list[tuple[int, Path]]:
        """Return (step, path) pairs for all parsable checkpoints of a stage."""
        stage_dir = self._get_stage_dir(stage)
        pattern = f"{self.checkpoint_prefix}_{stage}_step_*.json"
        found: list[tuple[int, Path]] = []
        for checkpoint_file in stage_dir.glob(pattern):
            step = self._step_from_stem(stage, checkpoint_file.stem)
            if step is not None:
                found.append((step, checkpoint_file))
        return found

    def save_checkpoint(self, stage: str, data: dict[str, Any], step: int = 0) -> bool:
        """Save checkpoint to volume in stage-specific directory."""
        try:
            checkpoint_path = self._get_stage_dir(stage) / self._checkpoint_filename(stage, step)
            with checkpoint_path.open("w") as f:
                # default=str keeps non-JSON-native values (e.g. datetimes) serializable.
                json.dump(data, f, indent=2, default=str)
            logger.info(f"💾 Saved checkpoint: {stage} step {step}")
            return True
        except Exception:
            logger.exception("❌ Error saving checkpoint")
            return False

    def load_checkpoint(self, stage: str, step: int = 0) -> dict[str, Any] | None:
        """Load checkpoint from volume stage-specific directory; None when absent."""
        try:
            checkpoint_path = self._get_stage_dir(stage) / self._checkpoint_filename(stage, step)
            if checkpoint_path.exists():
                with checkpoint_path.open("r") as f:
                    data = json.load(f)
                logger.info(f"📂 Loaded checkpoint: {stage} step {step}")
                return data
            logger.info(f"Info: No checkpoint found: {stage} step {step}")
            return None
        except Exception:
            logger.exception("❌ Error loading checkpoint")
            return None

    def get_latest_checkpoint(self, stage: str) -> tuple[int, dict[str, Any]] | None:
        """Get the highest-step checkpoint for a stage, as (step, data)."""
        try:
            stage_checkpoints = self._find_stage_checkpoints(stage)
            if not stage_checkpoints:
                logger.info(f"Info: No checkpoints found for stage: {stage}")
                return None
            latest_step, latest_file = max(stage_checkpoints, key=lambda x: x[0])
            with latest_file.open("r") as f:
                data = json.load(f)
            logger.info(f"📂 Found latest checkpoint: {stage} step {latest_step}")
            return latest_step, data
        except Exception:
            logger.exception("❌ Error finding latest checkpoint")
            return None

    def cleanup_old_checkpoints(self, stage: str, keep_latest: int = 3) -> list[str]:
        """Clean up old checkpoints, keeping only the latest N; return removed filenames."""
        try:
            stage_checkpoints = self._find_stage_checkpoints(stage)
            stage_checkpoints.sort(key=lambda x: x[0], reverse=True)
            removed_files: list[str] = []
            # Everything past the first keep_latest entries is stale.
            for _step, checkpoint_file in stage_checkpoints[keep_latest:]:
                try:
                    checkpoint_file.unlink()
                    removed_files.append(checkpoint_file.name)
                    logger.info(f"🧹 Removed old checkpoint: {checkpoint_file.name}")
                except Exception as e:
                    logger.warning(f"⚠️ Could not remove {checkpoint_file.name}: {e}")
            if removed_files:
                logger.info(f"🧹 Cleaned up {len(removed_files)} old checkpoints for {stage}")
            return removed_files
        except Exception:
            logger.exception("❌ Error cleaning up checkpoints")
            return []

    def _describe_checkpoint(self, stage: str, checkpoint_file: Path) -> dict[str, Any]:
        """Build the listing entry for one checkpoint file (step 0 when unparsable)."""
        step = self._step_from_stem(stage, checkpoint_file.stem)
        stat = checkpoint_file.stat()
        return {
            "stage": stage,
            "step": step if step is not None else 0,
            "filename": checkpoint_file.name,
            "size": f"{stat.st_size / 1024:.1f}KB",
            "modified": time.ctime(stat.st_mtime),
        }

    def list_checkpoints(self, stage: str | None = None) -> list[dict[str, Any]]:
        """List all checkpoints, optionally filtered by stage.

        Fix: steps are parsed by stripping the known filename prefix instead of
        ``stem.split("_")[3]``, which reported step 0 for any stage name (or
        prefix) containing underscores.
        """
        try:
            checkpoints: list[dict[str, Any]] = []
            if stage:
                stage_names = [stage]
            else:
                stage_names = [d.name for d in self.checkpoint_base_dir.iterdir() if d.is_dir()]
            for stage_name in stage_names:
                stage_dir = self._get_stage_dir(stage_name)
                pattern = f"{self.checkpoint_prefix}_{stage_name}_*.json"
                checkpoints.extend(
                    self._describe_checkpoint(stage_name, checkpoint_file)
                    for checkpoint_file in stage_dir.glob(pattern)
                )
            return sorted(checkpoints, key=lambda x: (x["stage"], x["step"]))
        except Exception:
            logger.exception("❌ Error listing checkpoints")
            return []
class BeamModelManager:
    """Manager for model artifacts on Beam volumes."""

    def __init__(self, volume_manager: BeamVolumeManager, model_prefix: str = "models") -> None:
        """
        Initialize Model Manager.

        Args:
            volume_manager: BeamVolumeManager instance
            model_prefix: Prefix for model files
        """
        self.volume = volume_manager
        self.model_prefix = model_prefix
        self.model_dir = self.volume.mount_path / model_prefix
        self.model_dir.mkdir(parents=True, exist_ok=True)

    def save_model(self, model_name: str, local_model_path: str | Path) -> bool:
        """Copy a local model file or directory into the volume under ``model_name``."""
        try:
            source = Path(local_model_path)
            if not source.exists():
                logger.error(f"❌ Model path does not exist: {source}")
                return False
            destination = self.model_dir / model_name
            if source.is_dir():
                # Replace any previous copy of the directory wholesale.
                if destination.exists():
                    shutil.rmtree(destination)
                shutil.copytree(source, destination)
                logger.info(f"💾 Saved model directory {model_name}")
            else:
                # Single files are stored inside a directory named after the model.
                destination.mkdir(exist_ok=True)
                shutil.copy2(source, destination / source.name)
                logger.info(f"💾 Saved model file {model_name}")
            return True
        except Exception:
            logger.exception("❌ Error saving model")
            return False

    def load_model(self, model_name: str, local_model_path: str | Path = "./models") -> bool:
        """Copy a model out of the volume into a local directory."""
        try:
            target_root = Path(local_model_path)
            target_root.mkdir(parents=True, exist_ok=True)
            source = self.model_dir / model_name
            target = target_root / model_name
            if not source.exists():
                logger.error(f"❌ Model does not exist: {model_name}")
                return False
            # Clear any stale copy at the destination before copying.
            if target.exists():
                if target.is_dir():
                    shutil.rmtree(target)
                else:
                    target.unlink()
            if source.is_dir():
                shutil.copytree(source, target)
            else:
                shutil.copy2(source, target)
            logger.info(f"📥 Loaded model {model_name}")
            return True
        except Exception:
            logger.exception("❌ Error loading model")
            return False

    def list_models(self) -> list[dict[str, str]]:
        """List model directories stored in the volume, sorted by name."""
        try:
            if not self.model_dir.exists():
                return []
            catalog: list[dict[str, str]] = []
            for entry in self.model_dir.iterdir():
                if not entry.is_dir():
                    continue  # models are stored as directories (see save_model)
                nbytes = sum(f.stat().st_size for f in entry.rglob("*") if f.is_file())
                catalog.append(
                    {
                        "name": entry.name,
                        "size": f"{nbytes / (1024 * 1024):.1f}MB",
                        "modified": time.ctime(entry.stat().st_mtime),
                    }
                )
            return sorted(catalog, key=lambda m: m["name"])
        except Exception:
            logger.exception("❌ Error listing models")
            return []

    def remove_model(self, model_name: str) -> bool:
        """Delete a model (file or directory) from the volume; False when absent."""
        try:
            target = self.model_dir / model_name
            if not target.exists():
                logger.warning(f"⚠️ Model does not exist: {model_name}")
                return False
            if target.is_dir():
                shutil.rmtree(target)
            else:
                target.unlink()
            logger.info(f"🗑️ Removed model: {model_name}")
            return True
        except Exception:
            logger.exception("❌ Error removing model")
            return False
class BeamEvaluationManager:
    """Manager for evaluation results on Beam volumes."""

    def __init__(
        self,
        volume_manager: BeamVolumeManager,
        results_prefix: str = "evaluation_results",
    ) -> None:
        """
        Initialize Evaluation Manager.

        Args:
            volume_manager: BeamVolumeManager instance
            results_prefix: Prefix for evaluation result files
        """
        self.volume = volume_manager
        self.results_prefix = results_prefix
        self.results_dir = self.volume.mount_path / results_prefix
        self.results_dir.mkdir(parents=True, exist_ok=True)

    def _results_path(self, model_name: str, eval_type: str) -> Path:
        """Canonical result-file path: ``<eval_type>_eval_<model>.json`` ("/" flattened to "_")."""
        return self.results_dir / f"{eval_type}_eval_{model_name.replace('/', '_')}.json"

    def save_evaluation_results(
        self, model_name: str, results: dict[str, Any], eval_type: str = "codesearchnet"
    ) -> bool:
        """Persist evaluation results as pretty-printed JSON on the volume."""
        try:
            with self._results_path(model_name, eval_type).open("w") as f:
                # default=str keeps non-JSON-native values serializable.
                json.dump(results, f, indent=2, default=str)
            logger.info(f"💾 Saved evaluation results for {model_name}")
            return True
        except Exception:
            logger.exception("❌ Error saving evaluation results")
            return False

    def load_evaluation_results(self, model_name: str, eval_type: str = "codesearchnet") -> dict[str, Any] | None:
        """Read previously saved evaluation results; None when absent or on error."""
        try:
            path = self._results_path(model_name, eval_type)
            if not path.exists():
                logger.info(f"Info: No evaluation results found for {model_name}")
                return None
            with path.open("r") as f:
                results = json.load(f)
            logger.info(f"📂 Loaded evaluation results for {model_name}")
            return results
        except Exception:
            logger.exception("❌ Error loading evaluation results")
            return None

    def list_evaluation_results(self, eval_type: str | None = None) -> list[dict[str, str]]:
        """List stored result files, optionally restricted to one eval type."""
        try:
            if not self.results_dir.exists():
                return []
            listing: list[dict[str, str]] = []
            for entry in self.results_dir.glob("*.json"):
                if eval_type is not None and not entry.name.startswith(f"{eval_type}_eval_"):
                    continue
                # Derive a display name from the filename.
                display_name = entry.name.replace("_eval_", "_").replace(".json", "")
                if eval_type:
                    display_name = display_name.replace(f"{eval_type}_", "")
                info = entry.stat()
                listing.append(
                    {
                        "model_name": display_name,
                        "filename": entry.name,
                        "size": f"{info.st_size / 1024:.1f}KB",
                        "modified": time.ctime(info.st_mtime),
                    }
                )
            return sorted(listing, key=lambda r: r["model_name"])
        except Exception:
            logger.exception("❌ Error listing evaluation results")
            return []

    def remove_evaluation_results(self, model_name: str, eval_type: str = "codesearchnet") -> bool:
        """Delete a stored result file; False when it does not exist."""
        try:
            path = self._results_path(model_name, eval_type)
            if not path.exists():
                logger.warning(f"⚠️ Evaluation results do not exist for {model_name}")
                return False
            path.unlink()
            logger.info(f"🗑️ Removed evaluation results for {model_name}")
            return True
        except Exception:
            logger.exception("❌ Error removing evaluation results")
            return False
def create_beam_utilities(
    volume_name: str, mount_path: str = "./data"
) -> tuple[BeamVolumeManager, BeamCheckpointManager, BeamModelManager, BeamEvaluationManager]:
    """
    Create a complete set of Beam utilities sharing one volume manager.

    Args:
        volume_name: Name of the Beam volume
        mount_path: Local mount path for the volume

    Returns:
        Tuple of (volume_manager, checkpoint_manager, model_manager, evaluation_manager)
    """
    volume = BeamVolumeManager(volume_name, mount_path)
    return (
        volume,
        BeamCheckpointManager(volume),
        BeamModelManager(volume),
        BeamEvaluationManager(volume),
    )
def cleanup_beam_workspace(volume_name: str, mount_path: str = "./data", confirm: bool = False) -> bool:
    """
    Clean up entire Beam workspace including all data in the mounted volume.

    Args:
        volume_name: Name of the volume to clean up
        mount_path: Mount path of the volume
        confirm: If True, skip confirmation prompt

    Returns:
        True if cleanup successful, False otherwise
    """
    if not confirm:
        # Interactive guard: anything other than "y"/"Y" aborts.
        answer = input(f"⚠️ This will delete all data in volume '{volume_name}' at '{mount_path}'. Continue? (y/N): ")
        if answer.lower() != "y":
            logger.info("Cleanup cancelled")
            return False
    try:
        workspace = BeamVolumeManager(volume_name, mount_path)
        if not workspace.exists():
            logger.info(f"Volume mount path does not exist: {mount_path}")
            return True
        # Report what is about to be deleted.
        logger.info(f"🗑️ Will delete {len(workspace.list_contents())} items from volume '{volume_name}'")
        # Best-effort removal of every top-level entry under the mount path.
        for entry in workspace.mount_path.iterdir():
            try:
                if entry.is_dir():
                    shutil.rmtree(entry)
                    logger.info(f"🗑️ Removed directory: {entry.name}")
                else:
                    entry.unlink()
                    logger.info(f"🗑️ Removed file: {entry.name}")
            except Exception as e:
                logger.warning(f"⚠️ Could not remove {entry.name}: {e}")
        logger.info(f"✅ Successfully cleaned up Beam workspace: {volume_name}")
        return True
    except Exception:
        logger.exception("❌ Error during cleanup")
        return False
def get_workspace_info(volume_name: str, mount_path: str = "./data") -> dict[str, Any]:
    """
    Get information about the Beam workspace.

    Args:
        volume_name: Name of the volume
        mount_path: Mount path of the volume

    Returns:
        Dictionary with workspace information (an ``error`` key on failure)
    """
    try:
        workspace = BeamVolumeManager(volume_name, mount_path)
        if not workspace.exists():
            return {
                "volume_name": volume_name,
                "mount_path": mount_path,
                "exists": False,
                "size": 0,
                "contents": [],
            }
        listing = workspace.list_contents()
        nbytes = workspace.get_size()
        return {
            "volume_name": volume_name,
            "mount_path": str(workspace.mount_path),
            "exists": True,
            "size": nbytes,
            "size_mb": f"{nbytes / (1024 * 1024):.1f}MB",
            "num_items": len(listing),
            "contents": listing[:10],  # First 10 items
        }
    except Exception:
        logger.exception("❌ Error getting workspace info")
        return {
            "volume_name": volume_name,
            "mount_path": mount_path,
            "error": "Error occurred",
        }
# Example usage functions
def example_distillation_workflow() -> None:
    """Example of using Beam utilities for distillation workflow."""
    volume_name = "gte_qwen2_m2v_code"
    mount_path = "./gte_qwen2_m2v_code"  # Should match Beam function mount path

    # Build the full utility set; the model manager is unused in this demo.
    volume_mgr, checkpoint_mgr, _model_mgr, eval_mgr = create_beam_utilities(volume_name, mount_path)

    # Bail out early when the volume is not mounted.
    if not volume_mgr.exists():
        logger.warning(f"Volume {volume_name} not found at {mount_path}")
        return
    logger.info(f"Volume {volume_name} is mounted at {mount_path}")

    # Save a demo checkpoint.
    checkpoint_mgr.save_checkpoint(
        "training",
        {
            "epoch": 1,
            "loss": 0.25,
            "model_state": "dummy_state",
            "timestamp": time.time(),
        },
        step=1000,
    )

    # List what exists for this stage.
    checkpoints = checkpoint_mgr.list_checkpoints("training")
    logger.info(f"Found {len(checkpoints)} training checkpoints")

    # Save demo evaluation results.
    eval_mgr.save_evaluation_results(
        "gte_qwen2_m2v_code",
        {
            "model_name": "gte_qwen2_m2v_code",
            "overall": {"ndcg@10": 0.35, "mrr": 0.42},
            "timestamp": time.time(),
        },
    )

    # Summarize the workspace.
    info = get_workspace_info(volume_name, mount_path)
    logger.info(f"Workspace info: {info}")
def download_evaluation_results_from_beam(
    volume_name: str,
    remote_results_dir: str = "evaluation_results",
    local_results_dir: str = "code_model2vec/evaluation_results",
) -> bool:
    """
    Download evaluation result files from Beam volume to local directory.

    On Beam the volume is a mounted directory, so each JSON result is copied
    with plain file operations and then deleted from the volume. When running
    locally the volume is unreachable and the call is a logged no-op.

    Args:
        volume_name: Name of the Beam volume
        remote_results_dir: Directory path in the Beam volume containing results
        local_results_dir: Local directory to download results to

    Returns:
        True if download successful, False otherwise
    """
    try:
        destination = Path(local_results_dir)
        destination.mkdir(parents=True, exist_ok=True)

        if not _is_running_on_beam():
            # Local environments cannot reach Beam volumes directly; a proper
            # Beam storage API or CLI tool would be required.
            logger.info("ℹ️ Evaluation results download from local environment not supported")
            logger.info("ℹ️ Evaluation results are only accessible when running on Beam platform")
            return True

        source_dir = Path(volume_name) / remote_results_dir
        if not source_dir.exists():
            logger.info("ℹ️ No evaluation results directory found on Beam")
            return True

        fetched: list[str] = []
        for result_file in source_dir.glob("*.json"):
            try:
                shutil.copy2(result_file, destination / result_file.name)
                fetched.append(result_file.name)
                logger.info(f"📥 Downloaded: {result_file.name}")
                # Remove from the volume only after a successful copy.
                result_file.unlink()
                logger.info(f"🗑️ Deleted from volume: {result_file.name}")
            except Exception as e:
                logger.warning(f"⚠️ Failed to download {result_file.name}: {e}")

        if fetched:
            logger.info(f"✅ Downloaded {len(fetched)} evaluation result files")
        else:
            logger.info("ℹ️ No new evaluation files to download")
        return True
    except Exception:
        logger.exception("❌ Error downloading evaluation results from Beam")
        return False
def download_specific_evaluation_file(
    volume_name: str,
    model_name: str,
    remote_results_dir: str = "evaluation_results",
    local_results_dir: str = "code_model2vec/evaluation_results",
    file_prefix: str = "codesearchnet_eval",
) -> bool:
    """
    Download a specific evaluation or benchmark result file from Beam volume using direct file operations.

    Args:
        volume_name: Name of the Beam volume
        model_name: Name of the model whose results to download
        remote_results_dir: Directory path in the Beam volume containing results
        local_results_dir: Local directory to download results to
        file_prefix: Prefix for the file (e.g., 'codesearchnet_eval', 'benchmark')

    Returns:
        True if download successful, False otherwise
    """
    try:
        local_path = Path(local_results_dir)
        local_path.mkdir(parents=True, exist_ok=True)
        # Filenames follow "<prefix>_<model>.json" with "/" flattened to "_".
        safe_model_name = model_name.replace("/", "_")
        filename = f"{file_prefix}_{safe_model_name}.json"
        # When running on Beam, the volume is mounted as a directory.
        remote_file_path = Path(volume_name) / remote_results_dir / filename
        local_file_path = local_path / filename
        if not remote_file_path.exists():
            logger.warning(f"⚠️ No {file_prefix} results found for {model_name} on Beam")
            return False
        # shutil comes from the module-level import; the previous
        # function-local "import shutil" was redundant and has been removed.
        shutil.copy2(remote_file_path, local_file_path)
        logger.info(f"📥 Downloaded {file_prefix} results for {model_name}")
        # Delete the file from the Beam volume after a successful download.
        remote_file_path.unlink()
        logger.info(f"🗑️ Deleted {file_prefix} results for {model_name} from volume")
        return True
    except Exception:
        logger.exception(f"❌ Error downloading {file_prefix} results for {model_name}")
        return False
def download_model_from_beam(
    volume_name: str,
    model_name: str,
    local_dir: str,
) -> bool:
    """
    Download a model from Beam volume to local directory using direct file operations.

    Args:
        volume_name: Name of the Beam volume
        model_name: Name of the model to download
        local_dir: Local directory to download model to

    Returns:
        True if download successful, False otherwise
    """
    try:
        local_path = Path(local_dir)
        local_path.mkdir(parents=True, exist_ok=True)
        # When running on Beam, the volume is mounted as a directory.
        remote_model_path = Path(volume_name) / "models" / model_name
        local_model_path = local_path / model_name
        if not remote_model_path.exists():
            logger.warning(f"⚠️ Model {model_name} not found in Beam volume at {remote_model_path}")
            return False
        # Replace any stale local copy, then copy the whole model directory.
        # shutil comes from the module-level import; the previous
        # function-local "import shutil" was redundant and has been removed.
        if local_model_path.exists():
            shutil.rmtree(local_model_path)
        shutil.copytree(remote_model_path, local_model_path)
        logger.info(f"📥 Downloaded model {model_name} from Beam to {local_dir}")
        return True
    except Exception as e:
        logger.warning(f"⚠️ Failed to download model {model_name} from Beam: {e}")
        return False
def upload_model_to_beam(
    volume_name: str,
    model_name: str,
    local_dir: str,
) -> bool:
    """
    Upload a model from local directory to Beam volume using direct file operations.

    Args:
        volume_name: Name of the Beam volume
        model_name: Name for the model on Beam
        local_dir: Local directory containing the model

    Returns:
        True if upload successful, False otherwise
    """
    try:
        local_path = Path(local_dir)
        if not local_path.exists():
            logger.error(f"❌ Local model directory does not exist: {local_dir}")
            return False
        # When running on Beam, the volume is mounted as a directory.
        remote_models_dir = Path(volume_name) / "models"
        remote_models_dir.mkdir(parents=True, exist_ok=True)
        remote_model_path = remote_models_dir / model_name
        # Replace any existing remote copy wholesale.
        # shutil comes from the module-level import; the previous
        # function-local "import shutil" was redundant and has been removed.
        if remote_model_path.exists():
            shutil.rmtree(remote_model_path)
        shutil.copytree(local_path, remote_model_path)
        logger.info(f"📤 Uploaded model {model_name} to Beam from {local_dir}")
        return True
    except Exception as e:
        logger.warning(f"⚠️ Failed to upload model {model_name} to Beam: {e}")
        return False
def download_checkpoints_from_beam(
    volume_name: str,
    stage: str | None = None,
    remote_checkpoints_dir: str = "checkpoints",
    local_checkpoints_dir: str = "code_model2vec/checkpoints",
) -> bool:
    """
    Download checkpoint files from Beam volume to local directory using direct file operations.

    Args:
        volume_name: Name of the Beam volume
        stage: Specific stage to download (e.g., 'distillation', 'training'), or None for all
        remote_checkpoints_dir: Directory path in the Beam volume containing checkpoints
        local_checkpoints_dir: Local directory to download checkpoints to

    Returns:
        True if download successful, False otherwise
    """
    try:
        local_path = Path(local_checkpoints_dir)
        local_path.mkdir(parents=True, exist_ok=True)
        # When running on Beam, the volume is mounted as a directory.
        remote_base_path = Path(volume_name) / remote_checkpoints_dir
        if not remote_base_path.exists():
            # Nothing stored yet - not an error.
            logger.info(f"ℹ️ No checkpoint directory found at {remote_base_path}")
            return True
        # Collect the checkpoint files to fetch.
        if stage:
            local_stage_dir = local_path / stage
            local_stage_dir.mkdir(parents=True, exist_ok=True)
            remote_stage_dir = remote_base_path / stage
            remote_files = (
                list(remote_stage_dir.glob(f"checkpoints_{stage}_*.json")) if remote_stage_dir.exists() else []
            )
        else:
            remote_files = []
            for stage_dir in remote_base_path.iterdir():
                if stage_dir.is_dir():
                    remote_files.extend(stage_dir.glob("checkpoints_*.json"))
        downloaded_files = []
        for checkpoint_file in remote_files:
            # Mirror the remote per-stage layout locally.
            # NOTE(review): assumes stage names contain no underscores
            # ("checkpoints_<stage>_step_N.json"); an underscored stage would be
            # bucketed under its first word - TODO confirm upstream naming.
            file_stage = checkpoint_file.name.split("_")[1] if "_" in checkpoint_file.name else "unknown"
            local_stage_dir = local_path / file_stage
            local_stage_dir.mkdir(parents=True, exist_ok=True)
            local_file_path = local_stage_dir / checkpoint_file.name
            try:
                # shutil comes from the module-level import; the previous
                # per-iteration "import shutil" inside this loop was redundant.
                shutil.copy2(checkpoint_file, local_file_path)
                downloaded_files.append(checkpoint_file.name)
                logger.info(f"📥 Downloaded checkpoint: {checkpoint_file.name}")
            except Exception as e:
                logger.warning(f"⚠️ Failed to download checkpoint {checkpoint_file.name}: {e}")
        if downloaded_files:
            logger.info(f"✅ Downloaded {len(downloaded_files)} checkpoint files")
        else:
            logger.info("ℹ️ No new checkpoint files to download")
        return True
    except Exception:
        logger.exception("❌ Error downloading checkpoints from Beam")
        return False
def upload_checkpoints_to_beam(
    volume_name: str,
    stage: str | None = None,
    local_checkpoints_dir: str = "code_model2vec/checkpoints",
    remote_checkpoints_dir: str = "checkpoints",
) -> bool:
    """
    Upload checkpoint files from local directory to Beam volume using direct file operations.

    Args:
        volume_name: Name of the Beam volume
        stage: Specific stage to upload (e.g., 'distillation', 'training'), or None for all
        local_checkpoints_dir: Local directory containing checkpoints
        remote_checkpoints_dir: Directory path in the Beam volume to store checkpoints

    Returns:
        True if upload successful, False otherwise
    """
    try:
        local_path = Path(local_checkpoints_dir)
        if not local_path.exists():
            logger.warning(f"⚠️ Local checkpoints directory does not exist: {local_checkpoints_dir}")
            return True  # Not an error - no checkpoints to upload
        # When running on Beam, the volume is mounted as a directory.
        remote_base_path = Path(volume_name) / remote_checkpoints_dir
        remote_base_path.mkdir(parents=True, exist_ok=True)
        # Find checkpoint files to upload.
        if stage:
            stage_dir = local_path / stage
            checkpoint_files = list(stage_dir.glob(f"checkpoints_{stage}_*.json")) if stage_dir.exists() else []
        else:
            checkpoint_files = []
            for subdir in local_path.iterdir():
                if subdir.is_dir():
                    checkpoint_files.extend(subdir.glob("checkpoints_*.json"))
        if not checkpoint_files:
            logger.info(f"ℹ️ No checkpoint files found to upload for stage: {stage or 'all'}")
            return True
        uploaded_files = []
        for checkpoint_file in checkpoint_files:
            # Mirror the local per-stage layout remotely.
            # NOTE(review): assumes stage names contain no underscores - TODO confirm.
            file_stage = checkpoint_file.name.split("_")[1] if "_" in checkpoint_file.name else "unknown"
            remote_stage_dir = remote_base_path / file_stage
            remote_stage_dir.mkdir(parents=True, exist_ok=True)
            remote_file_path = remote_stage_dir / checkpoint_file.name
            try:
                # shutil comes from the module-level import; the previous
                # per-iteration "import shutil" inside this loop was redundant.
                shutil.copy2(checkpoint_file, remote_file_path)
                uploaded_files.append(checkpoint_file.name)
                logger.info(f"📤 Uploaded checkpoint: {checkpoint_file.name}")
            except Exception as e:
                logger.warning(f"⚠️ Failed to upload checkpoint {checkpoint_file.name}: {e}")
        if uploaded_files:
            logger.info(f"✅ Uploaded {len(uploaded_files)} checkpoint files")
            return True
        # Files were found but every individual copy failed.
        return False
    except Exception:
        logger.exception("❌ Error uploading checkpoints to Beam")
        return False
def sync_checkpoints_from_beam(
    volume_name: str,
    stage: str,
    local_checkpoints_dir: str = "code_model2vec/checkpoints",
) -> bool:
    """
    Sync specific stage checkpoints from Beam to local directory.

    Thin convenience wrapper around ``download_checkpoints_from_beam`` using
    the default remote directory.

    Args:
        volume_name: Name of the Beam volume
        stage: Stage to sync (e.g., 'distillation', 'training')
        local_checkpoints_dir: Local directory for checkpoints

    Returns:
        True if sync successful, False otherwise
    """
    logger.info(f"🔄 Syncing {stage} checkpoints from Beam...")
    return download_checkpoints_from_beam(
        volume_name,
        stage,
        "checkpoints",
        local_checkpoints_dir,
    )
def sync_checkpoints_to_beam(
    volume_name: str,
    stage: str,
    local_checkpoints_dir: str = "code_model2vec/checkpoints",
) -> bool:
    """
    Sync specific stage checkpoints from local directory to Beam.

    Thin convenience wrapper around ``upload_checkpoints_to_beam`` using the
    default remote directory.

    Args:
        volume_name: Name of the Beam volume
        stage: Stage to sync (e.g., 'distillation', 'training')
        local_checkpoints_dir: Local directory containing checkpoints

    Returns:
        True if sync successful, False otherwise
    """
    logger.info(f"🔄 Syncing {stage} checkpoints to Beam...")
    return upload_checkpoints_to_beam(
        volume_name,
        stage,
        local_checkpoints_dir,
        "checkpoints",
    )
if __name__ == "__main__":
    # Example usage: run the demo distillation workflow with INFO-level
    # logging to the console. Requires the demo volume to be mounted.
    logging.basicConfig(level=logging.INFO)
    example_distillation_workflow()