Sarthak committed on
Commit
ef6935e
·
1 Parent(s): ee673cb

feat(distiller): configure beam functions with resource settings

Browse files

This commit introduces a configuration system for Beam functions, allowing for better resource management, execution settings, and environment configurations for different types of Beam jobs like distillation and evaluation. It also simplifies the function deployment process by providing pre-defined configurations and utilities for creating @function decorator kwargs.

Files changed (1) hide show
  1. src/distiller/config.py +140 -24
src/distiller/config.py CHANGED
@@ -37,6 +37,63 @@ def setup_logging(level: int = logging.INFO) -> None:
37
  GPU_NAME = GpuType.A100_40
38
 
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  # Volume configurations for different workflows
41
  class VolumeConfig(BaseModel):
42
  """Volume configuration container."""
@@ -64,13 +121,8 @@ VOLUMES: dict[str, VolumeConfig] = {
64
  # Default volume name for all workflows
65
  DEFAULT_VOLUME = "primary"
66
 
67
- # Beam environment settings
68
- BEAM_ENV_SETTINGS: dict[str, str] = {
69
- "TOKENIZERS_PARALLELISM": "false",
70
- "CUDA_LAUNCH_BLOCKING": "0",
71
- "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True,max_split_size_mb:512",
72
- "TORCH_CUDNN_V8_API_ENABLED": "1",
73
- }
74
 
75
  # Common Python packages for Beam images
76
  COMMON_PACKAGES: list[str] = [
@@ -79,6 +131,7 @@ COMMON_PACKAGES: list[str] = [
79
  "datasets>=3.2.0",
80
  "sentence-transformers>=4.1.0",
81
  "model2vec[train]>=0.5.0",
 
82
  "numpy>=1.26.4",
83
  "scikit-learn>=1.6.1",
84
  "pandas>=2.0.0",
@@ -86,9 +139,12 @@ COMMON_PACKAGES: list[str] = [
86
  "plotly>=5.0.0",
87
  "matplotlib>=3.7.0",
88
  "seaborn>=0.12.0",
 
 
 
89
  ]
90
 
91
- # Create common Beam image
92
  IMAGE = Image(python_version="python3.12").add_python_packages(COMMON_PACKAGES)
93
 
94
  # =============================================================================
@@ -109,8 +165,7 @@ TEACHER_MODELS: list[str] = [
109
  "sentence-transformers/all-MiniLM-L6-v2",
110
  "sentence-transformers/all-mpnet-base-v2",
111
  "sentence-transformers/paraphrase-MiniLM-L6-v2",
112
- "nomic-ai/nomic-embed-code",
113
- "nomic-ai/CodeRankEmbed",
114
  ]
115
 
116
  # Default evaluation models for comparison
@@ -125,6 +180,7 @@ DEFAULT_EVALUATION_MODELS: list[str] = [
125
  "microsoft/graphcodebert-base",
126
  "minishlab/potion-base-8M",
127
  "minishlab/potion-retrieval-32M",
 
128
  "nomic-ai/nomic-embed-text-v2-moe",
129
  "Qodo/Qodo-Embed-1-1.5B",
130
  "Salesforce/codet5-base",
@@ -132,9 +188,7 @@ DEFAULT_EVALUATION_MODELS: list[str] = [
132
  "sentence-transformers/all-MiniLM-L6-v2",
133
  "sentence-transformers/all-mpnet-base-v2",
134
  "sentence-transformers/paraphrase-MiniLM-L6-v2",
135
- "nvidia/NV-Embed-v2",
136
- "nomic-ai/nomic-embed-code",
137
- "nomic-ai/CodeRankEmbed",
138
  ]
139
 
140
 
@@ -150,12 +204,12 @@ class DistillationConfig(BaseModel):
150
  sif_coefficient: float = 1e-3
151
  apply_zipf: bool = True
152
 
153
- # Training parameters (used when --train flag is enabled)
154
- training_epochs: int = 2
155
- learning_rate: float = 1e-4
156
- batch_size: int = 32
157
- max_training_samples: int = 50000
158
- teacher_model_config: dict[str, Any] = {}
159
 
160
 
161
  distillation_config = DistillationConfig()
@@ -196,11 +250,8 @@ class CodeSearchNetConfig(BaseModel):
196
 
197
  codesearchnet_config = CodeSearchNetConfig()
198
 
199
- # Training dataset configurations
200
- TRAINING_DATASETS: dict[str, str] = {
201
- "codesearchnet": "sentence-transformers/codesearchnet",
202
- "code_search_net": "code_search_net",
203
- }
204
 
205
  # =============================================================================
206
  # OUTPUT DIRECTORY CONFIGURATION
@@ -337,3 +388,68 @@ def format_filename(pattern_key: str, **kwargs: Any) -> str:
337
  def get_safe_model_name(model_name: str) -> str:
338
  """Convert model name to filesystem-safe name."""
339
  return "".join(c for c in model_name if c.isalnum() or c in ("-", "_", ".")).replace("/", "_")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  GPU_NAME = GpuType.A100_40
38
 
39
 
40
# Comprehensive Beam function configuration
class BeamFunctionConfig(BaseModel):
    """Complete configuration for Beam @function decorator parameters."""

    # Resource allocation
    cpu: float = 2.0  # Number of CPU cores
    memory: int = 8192  # Memory in MiB (8 GiB)
    gpu: str = "A100_40"  # GPU type name (resolved to a GpuType member at call time)

    # Execution settings
    timeout: int = 3600 * 12  # 12-hour timeout for long distillation jobs
    retries: int = 2  # Retry failed tasks up to 2 times
    headless: bool = False  # Keep the client connected during execution

    # Optional settings (only forwarded to @function when set/truthy)
    callback_url: str | None = None  # Webhook URL for task completion
    name: str | None = None  # Function name for deployment
    task_policy: Any | None = None  # Task lifecycle policy (opaque Beam object)
    retry_for: list[str] | None = None  # Specific exceptions to retry on

    # Environment and dependencies
    # NOTE: mutable defaults are safe here — pydantic copies field defaults
    # per model instance rather than sharing one object across instances.
    secrets: list[str] = ["HF_ACCESS_TOKEN"]  # Required secrets
    env_vars: dict[str, str] = {
        "TOKENIZERS_PARALLELISM": "false",
        "CUDA_LAUNCH_BLOCKING": "0",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "TORCH_CUDNN_V8_API_ENABLED": "1",
        # Flash attention environment variables
        "FLASH_ATTENTION_FORCE_USE": "1",
        "TORCH_COMPILE_DISABLE": "1",
    }
71
+
72
+
73
# Configuration for different types of Beam jobs
BEAM_CONFIGS: dict[str, BeamFunctionConfig] = {
    "distillation": BeamFunctionConfig(
        cpu=4.0,
        memory=16384,  # 16 GiB for distillation
        gpu="A100_40",
        timeout=3600 * 12,  # 12 hours
        retries=2,
        secrets=["HF_ACCESS_TOKEN"],
    ),
    "evaluation": BeamFunctionConfig(
        cpu=2.0,
        memory=8192,  # 8 GiB for evaluation
        gpu="A100_40",  # Same GPU class as distillation
        timeout=3600 * 4,  # 4 hours
        retries=3,
        secrets=["HF_ACCESS_TOKEN"],
    ),
}

# Default beam configuration, used as the fallback for unknown job types
DEFAULT_BEAM_CONFIG = BEAM_CONFIGS["distillation"]
95
+
96
+
97
  # Volume configurations for different workflows
98
  class VolumeConfig(BaseModel):
99
  """Volume configuration container."""
 
121
  # Default volume name for all workflows
122
  DEFAULT_VOLUME = "primary"
123
 
124
# Legacy environment settings (now part of BeamFunctionConfig).
# Copy the dict instead of aliasing it: mutating BEAM_ENV_SETTINGS must not
# silently corrupt DEFAULT_BEAM_CONFIG (i.e. BEAM_CONFIGS["distillation"]).
BEAM_ENV_SETTINGS: dict[str, str] = dict(DEFAULT_BEAM_CONFIG.env_vars)
 
 
 
 
 
126
 
127
  # Common Python packages for Beam images
128
  COMMON_PACKAGES: list[str] = [
 
131
  "datasets>=3.2.0",
132
  "sentence-transformers>=4.1.0",
133
  "model2vec[train]>=0.5.0",
134
+ "tokenlearn>=0.2.0",
135
  "numpy>=1.26.4",
136
  "scikit-learn>=1.6.1",
137
  "pandas>=2.0.0",
 
139
  "plotly>=5.0.0",
140
  "matplotlib>=3.7.0",
141
  "seaborn>=0.12.0",
142
+ "typer>=0.16.0",
143
+ "pydantic>=2.11.5",
144
+ "hatchling>=1.27.0",
145
  ]
146
 
147
+ # Create common Beam image without flash-attn due to PyTorch version conflicts
148
  IMAGE = Image(python_version="python3.12").add_python_packages(COMMON_PACKAGES)
149
 
150
  # =============================================================================
 
165
  "sentence-transformers/all-MiniLM-L6-v2",
166
  "sentence-transformers/all-mpnet-base-v2",
167
  "sentence-transformers/paraphrase-MiniLM-L6-v2",
168
+ "jinaai/jina-embeddings-v2-base-code",
 
169
  ]
170
 
171
  # Default evaluation models for comparison
 
180
  "microsoft/graphcodebert-base",
181
  "minishlab/potion-base-8M",
182
  "minishlab/potion-retrieval-32M",
183
+ "minishlab/potion-multilingual-128M",
184
  "nomic-ai/nomic-embed-text-v2-moe",
185
  "Qodo/Qodo-Embed-1-1.5B",
186
  "Salesforce/codet5-base",
 
188
  "sentence-transformers/all-MiniLM-L6-v2",
189
  "sentence-transformers/all-mpnet-base-v2",
190
  "sentence-transformers/paraphrase-MiniLM-L6-v2",
191
+ "jinaai/jina-embeddings-v2-base-code",
 
 
192
  ]
193
 
194
 
 
204
  sif_coefficient: float = 1e-3
205
  apply_zipf: bool = True
206
 
207
+ # Tokenlearn-specific parameters (POTION approach)
208
+ tokenlearn_dataset: str = "sentence-transformers/codesearchnet" # Dataset for tokenlearn featurization
209
+ tokenlearn_dataset_name: str = "pair" # Use 'pair' configuration (only available config)
210
+ tokenlearn_text_key: str = "code" # Text field to use from the dataset ('code' or 'comment')
211
+ tokenlearn_timeout_featurize: int = 21600 # 6 hour timeout for featurization (dataset needs ~5 hours)
212
+ tokenlearn_timeout_train: int = 7200 # 2 hour timeout for training
213
 
214
 
215
  distillation_config = DistillationConfig()
 
250
 
251
  codesearchnet_config = CodeSearchNetConfig()
252
 
253
# Training dataset configuration — a single dataset name replaces the
# previous TRAINING_DATASETS mapping removed in this change
TRAINING_DATASET: str = "sentence-transformers/codesearchnet"
 
 
 
255
 
256
  # =============================================================================
257
  # OUTPUT DIRECTORY CONFIGURATION
 
388
  def get_safe_model_name(model_name: str) -> str:
389
  """Convert model name to filesystem-safe name."""
390
  return "".join(c for c in model_name if c.isalnum() or c in ("-", "_", ".")).replace("/", "_")
391
+
392
+
393
def get_beam_config(job_type: str = "distillation") -> BeamFunctionConfig:
    """Return the Beam configuration registered for *job_type*.

    Unknown job types fall back to ``DEFAULT_BEAM_CONFIG``.
    """
    return BEAM_CONFIGS.get(job_type, DEFAULT_BEAM_CONFIG)
398
+
399
+
400
def create_beam_function_kwargs(
    job_type: str = "distillation", volume_config: VolumeConfig | None = None
) -> dict[str, Any]:
    """Create kwargs dictionary for @function decorator.

    Resolves the job-type configuration, the volume to mount, and the GPU
    type, then assembles the keyword arguments Beam's ``@function``
    decorator expects. Optional settings are included only when truthy.
    """
    from beam import Volume

    cfg = get_beam_config(job_type)
    vol = volume_config or get_volume_config()

    # Resolve a GPU given as a string to its GpuType enum member when the
    # name is known; unrecognized strings are passed through unchanged.
    gpu = cfg.gpu
    if isinstance(gpu, str):
        known_gpus = {
            "A100_40": GpuType.A100_40,
            "A100_80": GpuType.A100_80,
            "T4": GpuType.T4,
            "A10G": GpuType.A10G,
            "NoGPU": GpuType.NoGPU,
        }
        gpu = known_gpus.get(cfg.gpu, cfg.gpu)

    function_kwargs: dict[str, Any] = {
        "cpu": cfg.cpu,
        "memory": cfg.memory,
        "gpu": gpu,
        "image": IMAGE,
        "timeout": cfg.timeout,
        "retries": cfg.retries,
        "headless": cfg.headless,
        "volumes": [Volume(name=vol.name, mount_path=vol.mount_path)],
        "secrets": cfg.secrets,
        "env": cfg.env_vars,
    }

    # Fold in optional parameters, skipping any that are unset/falsy.
    optional = {
        "callback_url": cfg.callback_url,
        "name": cfg.name,
        "task_policy": cfg.task_policy,
        "retry_for": cfg.retry_for,
    }
    function_kwargs.update({key: value for key, value in optional.items() if value})

    return function_kwargs
446
+
447
+
448
def get_distillation_function_kwargs() -> dict[str, Any]:
    """Return @function kwargs preconfigured for distillation jobs."""
    return create_beam_function_kwargs(job_type="distillation")
451
+
452
+
453
def get_evaluation_function_kwargs() -> dict[str, Any]:
    """Return @function kwargs preconfigured for evaluation jobs."""
    return create_beam_function_kwargs(job_type="evaluation")