Sarthak committed on
Commit ·
ef6935e
1
Parent(s): ee673cb
feat(distiller): configure beam functions with resource settings
Browse files
This commit introduces a configuration system for Beam functions, allowing for better resource management, execution settings, and environment configuration for different types of Beam jobs, such as distillation and evaluation. It also simplifies function deployment by providing pre-defined configurations and utilities for creating @function decorator kwargs.
- src/distiller/config.py +140 -24
src/distiller/config.py
CHANGED
|
@@ -37,6 +37,63 @@ def setup_logging(level: int = logging.INFO) -> None:
|
|
| 37 |
GPU_NAME = GpuType.A100_40
|
| 38 |
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
# Volume configurations for different workflows
|
| 41 |
class VolumeConfig(BaseModel):
|
| 42 |
"""Volume configuration container."""
|
|
@@ -64,13 +121,8 @@ VOLUMES: dict[str, VolumeConfig] = {
|
|
| 64 |
# Default volume name for all workflows
|
| 65 |
DEFAULT_VOLUME = "primary"
|
| 66 |
|
| 67 |
-
#
|
| 68 |
-
BEAM_ENV_SETTINGS: dict[str, str] =
|
| 69 |
-
"TOKENIZERS_PARALLELISM": "false",
|
| 70 |
-
"CUDA_LAUNCH_BLOCKING": "0",
|
| 71 |
-
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True,max_split_size_mb:512",
|
| 72 |
-
"TORCH_CUDNN_V8_API_ENABLED": "1",
|
| 73 |
-
}
|
| 74 |
|
| 75 |
# Common Python packages for Beam images
|
| 76 |
COMMON_PACKAGES: list[str] = [
|
|
@@ -79,6 +131,7 @@ COMMON_PACKAGES: list[str] = [
|
|
| 79 |
"datasets>=3.2.0",
|
| 80 |
"sentence-transformers>=4.1.0",
|
| 81 |
"model2vec[train]>=0.5.0",
|
|
|
|
| 82 |
"numpy>=1.26.4",
|
| 83 |
"scikit-learn>=1.6.1",
|
| 84 |
"pandas>=2.0.0",
|
|
@@ -86,9 +139,12 @@ COMMON_PACKAGES: list[str] = [
|
|
| 86 |
"plotly>=5.0.0",
|
| 87 |
"matplotlib>=3.7.0",
|
| 88 |
"seaborn>=0.12.0",
|
|
|
|
|
|
|
|
|
|
| 89 |
]
|
| 90 |
|
| 91 |
-
# Create common Beam image
|
| 92 |
IMAGE = Image(python_version="python3.12").add_python_packages(COMMON_PACKAGES)
|
| 93 |
|
| 94 |
# =============================================================================
|
|
@@ -109,8 +165,7 @@ TEACHER_MODELS: list[str] = [
|
|
| 109 |
"sentence-transformers/all-MiniLM-L6-v2",
|
| 110 |
"sentence-transformers/all-mpnet-base-v2",
|
| 111 |
"sentence-transformers/paraphrase-MiniLM-L6-v2",
|
| 112 |
-
"
|
| 113 |
-
"nomic-ai/CodeRankEmbed",
|
| 114 |
]
|
| 115 |
|
| 116 |
# Default evaluation models for comparison
|
|
@@ -125,6 +180,7 @@ DEFAULT_EVALUATION_MODELS: list[str] = [
|
|
| 125 |
"microsoft/graphcodebert-base",
|
| 126 |
"minishlab/potion-base-8M",
|
| 127 |
"minishlab/potion-retrieval-32M",
|
|
|
|
| 128 |
"nomic-ai/nomic-embed-text-v2-moe",
|
| 129 |
"Qodo/Qodo-Embed-1-1.5B",
|
| 130 |
"Salesforce/codet5-base",
|
|
@@ -132,9 +188,7 @@ DEFAULT_EVALUATION_MODELS: list[str] = [
|
|
| 132 |
"sentence-transformers/all-MiniLM-L6-v2",
|
| 133 |
"sentence-transformers/all-mpnet-base-v2",
|
| 134 |
"sentence-transformers/paraphrase-MiniLM-L6-v2",
|
| 135 |
-
"
|
| 136 |
-
"nomic-ai/nomic-embed-code",
|
| 137 |
-
"nomic-ai/CodeRankEmbed",
|
| 138 |
]
|
| 139 |
|
| 140 |
|
|
@@ -150,12 +204,12 @@ class DistillationConfig(BaseModel):
|
|
| 150 |
sif_coefficient: float = 1e-3
|
| 151 |
apply_zipf: bool = True
|
| 152 |
|
| 153 |
-
#
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
|
| 160 |
|
| 161 |
distillation_config = DistillationConfig()
|
|
@@ -196,11 +250,8 @@ class CodeSearchNetConfig(BaseModel):
|
|
| 196 |
|
| 197 |
codesearchnet_config = CodeSearchNetConfig()
|
| 198 |
|
| 199 |
-
# Training dataset
|
| 200 |
-
|
| 201 |
-
"codesearchnet": "sentence-transformers/codesearchnet",
|
| 202 |
-
"code_search_net": "code_search_net",
|
| 203 |
-
}
|
| 204 |
|
| 205 |
# =============================================================================
|
| 206 |
# OUTPUT DIRECTORY CONFIGURATION
|
|
@@ -337,3 +388,68 @@ def format_filename(pattern_key: str, **kwargs: Any) -> str:
|
|
| 337 |
def get_safe_model_name(model_name: str) -> str:
|
| 338 |
"""Convert model name to filesystem-safe name."""
|
| 339 |
return "".join(c for c in model_name if c.isalnum() or c in ("-", "_", ".")).replace("/", "_")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
GPU_NAME = GpuType.A100_40
|
| 38 |
|
| 39 |
|
| 40 |
+
# Comprehensive Beam function configuration
class BeamFunctionConfig(BaseModel):
    """Complete configuration for Beam @function decorator parameters.

    Instances are declarative resource/execution profiles; see BEAM_CONFIGS
    for the per-job-type presets and create_beam_function_kwargs for how a
    profile is turned into actual decorator keyword arguments.
    """

    # Resource allocation
    cpu: float = 2.0  # Number of CPU cores
    memory: int = 8192  # Memory in MiB (8192 MiB = 8 GiB)
    gpu: str = "A100_40"  # GPU type name; mapped to GpuType at decoration time

    # Execution settings
    timeout: int = 3600 * 12  # 12-hour timeout for long distillation jobs
    retries: int = 2  # Retry failed tasks up to 2 times
    headless: bool = False  # Keep the client connected during execution

    # Optional settings (only forwarded to the decorator when truthy)
    callback_url: str | None = None  # Webhook URL for task completion
    name: str | None = None  # Function name for deployment
    task_policy: Any | None = None  # Task lifecycle policy (opaque to this module)
    retry_for: list[str] | None = None  # Specific exceptions to retry on

    # Environment and dependencies.
    # NOTE: mutable defaults are safe here only because pydantic deep-copies
    # field defaults per instance (unlike plain class attributes).
    secrets: list[str] = ["HF_ACCESS_TOKEN"]  # Required secrets
    env_vars: dict[str, str] = {
        "TOKENIZERS_PARALLELISM": "false",
        "CUDA_LAUNCH_BLOCKING": "0",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "TORCH_CUDNN_V8_API_ENABLED": "1",
        # Flash attention environment variables
        "FLASH_ATTENTION_FORCE_USE": "1",
        "TORCH_COMPILE_DISABLE": "1",
    }
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# Configuration for different types of Beam jobs
BEAM_CONFIGS: dict[str, BeamFunctionConfig] = {
    "distillation": BeamFunctionConfig(
        cpu=4.0,
        memory=16384,  # 16 GiB for distillation (comment previously mislabeled this as 8GB)
        gpu="A100_40",
        timeout=3600 * 12,  # 12 hours
        retries=2,
        secrets=["HF_ACCESS_TOKEN"],
    ),
    "evaluation": BeamFunctionConfig(
        cpu=2.0,
        memory=8192,  # 8 GiB for evaluation
        gpu="A100_40",  # same GPU as distillation (earlier "smaller GPU" comment was stale)
        timeout=3600 * 4,  # 4 hours
        retries=3,
        secrets=["HF_ACCESS_TOKEN"],
    ),
}

# Default beam configuration
DEFAULT_BEAM_CONFIG = BEAM_CONFIGS["distillation"]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
# Volume configurations for different workflows
|
| 98 |
class VolumeConfig(BaseModel):
|
| 99 |
"""Volume configuration container."""
|
|
|
|
| 121 |
# Default volume name for all workflows
|
| 122 |
DEFAULT_VOLUME = "primary"
|
| 123 |
|
| 124 |
+
# Legacy environment settings (now part of BeamFunctionConfig)
# NOTE(review): this aliases DEFAULT_BEAM_CONFIG.env_vars rather than copying
# it, so mutating BEAM_ENV_SETTINGS also mutates the default config's env —
# confirm all callers treat it as read-only.
BEAM_ENV_SETTINGS: dict[str, str] = DEFAULT_BEAM_CONFIG.env_vars
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
# Common Python packages for Beam images
|
| 128 |
COMMON_PACKAGES: list[str] = [
|
|
|
|
| 131 |
"datasets>=3.2.0",
|
| 132 |
"sentence-transformers>=4.1.0",
|
| 133 |
"model2vec[train]>=0.5.0",
|
| 134 |
+
"tokenlearn>=0.2.0",
|
| 135 |
"numpy>=1.26.4",
|
| 136 |
"scikit-learn>=1.6.1",
|
| 137 |
"pandas>=2.0.0",
|
|
|
|
| 139 |
"plotly>=5.0.0",
|
| 140 |
"matplotlib>=3.7.0",
|
| 141 |
"seaborn>=0.12.0",
|
| 142 |
+
"typer>=0.16.0",
|
| 143 |
+
"pydantic>=2.11.5",
|
| 144 |
+
"hatchling>=1.27.0",
|
| 145 |
]
|
| 146 |
|
| 147 |
+
# Create common Beam image without flash-attn due to PyTorch version conflicts
|
| 148 |
IMAGE = Image(python_version="python3.12").add_python_packages(COMMON_PACKAGES)
|
| 149 |
|
| 150 |
# =============================================================================
|
|
|
|
| 165 |
"sentence-transformers/all-MiniLM-L6-v2",
|
| 166 |
"sentence-transformers/all-mpnet-base-v2",
|
| 167 |
"sentence-transformers/paraphrase-MiniLM-L6-v2",
|
| 168 |
+
"jinaai/jina-embeddings-v2-base-code",
|
|
|
|
| 169 |
]
|
| 170 |
|
| 171 |
# Default evaluation models for comparison
|
|
|
|
| 180 |
"microsoft/graphcodebert-base",
|
| 181 |
"minishlab/potion-base-8M",
|
| 182 |
"minishlab/potion-retrieval-32M",
|
| 183 |
+
"minishlab/potion-multilingual-128M",
|
| 184 |
"nomic-ai/nomic-embed-text-v2-moe",
|
| 185 |
"Qodo/Qodo-Embed-1-1.5B",
|
| 186 |
"Salesforce/codet5-base",
|
|
|
|
| 188 |
"sentence-transformers/all-MiniLM-L6-v2",
|
| 189 |
"sentence-transformers/all-mpnet-base-v2",
|
| 190 |
"sentence-transformers/paraphrase-MiniLM-L6-v2",
|
| 191 |
+
"jinaai/jina-embeddings-v2-base-code",
|
|
|
|
|
|
|
| 192 |
]
|
| 193 |
|
| 194 |
|
|
|
|
| 204 |
sif_coefficient: float = 1e-3
|
| 205 |
apply_zipf: bool = True
|
| 206 |
|
| 207 |
+
# Tokenlearn-specific parameters (POTION approach)
|
| 208 |
+
tokenlearn_dataset: str = "sentence-transformers/codesearchnet" # Dataset for tokenlearn featurization
|
| 209 |
+
tokenlearn_dataset_name: str = "pair" # Use 'pair' configuration (only available config)
|
| 210 |
+
tokenlearn_text_key: str = "code" # Text field to use from the dataset ('code' or 'comment')
|
| 211 |
+
tokenlearn_timeout_featurize: int = 21600 # 6 hour timeout for featurization (dataset needs ~5 hours)
|
| 212 |
+
tokenlearn_timeout_train: int = 7200 # 2 hour timeout for training
|
| 213 |
|
| 214 |
|
| 215 |
distillation_config = DistillationConfig()
|
|
|
|
| 250 |
|
| 251 |
codesearchnet_config = CodeSearchNetConfig()
|
| 252 |
|
| 253 |
+
# Training dataset configuration
# Hugging Face dataset id — presumably consumed by the distillation pipeline;
# matches DistillationConfig.tokenlearn_dataset (verify they stay in sync).
TRAINING_DATASET: str = "sentence-transformers/codesearchnet"
|
|
|
|
|
|
|
|
|
|
| 255 |
|
| 256 |
# =============================================================================
|
| 257 |
# OUTPUT DIRECTORY CONFIGURATION
|
|
|
|
| 388 |
def get_safe_model_name(model_name: str) -> str:
    """Convert a model name to a filesystem-safe name.

    Slashes (e.g. the "/" in "org/model") become underscores; any other
    character that is not alphanumeric or one of "-", "_", "." is dropped.

    Args:
        model_name: Model identifier, e.g. a Hugging Face repo id.

    Returns:
        A string safe to use as a file or directory name.
    """
    # Bug fix: the previous implementation filtered out "/" in the generator
    # BEFORE calling .replace("/", "_"), making the replace a dead no-op and
    # collapsing "org/model" to "orgmodel". Replace slashes first instead.
    sanitized = model_name.replace("/", "_")
    return "".join(c for c in sanitized if c.isalnum() or c in ("-", "_", "."))
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def get_beam_config(job_type: str = "distillation") -> BeamFunctionConfig:
    """Get Beam configuration for a specific job type.

    Unknown job types fall back to DEFAULT_BEAM_CONFIG.
    """
    return BEAM_CONFIGS.get(job_type, DEFAULT_BEAM_CONFIG)
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def create_beam_function_kwargs(
    job_type: str = "distillation", volume_config: VolumeConfig | None = None
) -> dict[str, Any]:
    """Create kwargs dictionary for @function decorator.

    Combines the job-type resource profile (see BEAM_CONFIGS) with the volume
    settings into the keyword arguments accepted by Beam's ``@function``.
    """
    from beam import Volume

    cfg = get_beam_config(job_type)
    vol_cfg = volume_config or get_volume_config()

    # Translate a GPU string into the SDK's GpuType where one is known;
    # unrecognized strings are passed through unchanged.
    gpu_value = cfg.gpu
    if isinstance(gpu_value, str):
        known_gpus = {
            "A100_40": GpuType.A100_40,
            "A100_80": GpuType.A100_80,
            "T4": GpuType.T4,
            "A10G": GpuType.A10G,
            "NoGPU": GpuType.NoGPU,
        }
        gpu_value = known_gpus.get(gpu_value, gpu_value)

    decorator_kwargs: dict[str, Any] = {
        "cpu": cfg.cpu,
        "memory": cfg.memory,
        "gpu": gpu_value,
        "image": IMAGE,
        "timeout": cfg.timeout,
        "retries": cfg.retries,
        "headless": cfg.headless,
        "volumes": [Volume(name=vol_cfg.name, mount_path=vol_cfg.mount_path)],
        "secrets": cfg.secrets,
        "env": cfg.env_vars,
    }

    # Optional decorator parameters are forwarded only when set (truthy),
    # matching how the decorator treats absent vs. None arguments.
    optional_params = {
        "callback_url": cfg.callback_url,
        "name": cfg.name,
        "task_policy": cfg.task_policy,
        "retry_for": cfg.retry_for,
    }
    decorator_kwargs.update({key: value for key, value in optional_params.items() if value})

    return decorator_kwargs
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def get_distillation_function_kwargs() -> dict[str, Any]:
    """Return the @function decorator kwargs for distillation jobs."""
    return create_beam_function_kwargs(job_type="distillation")
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def get_evaluation_function_kwargs() -> dict[str, Any]:
    """Return the @function decorator kwargs for evaluation jobs."""
    return create_beam_function_kwargs(job_type="evaluation")
|