Sarthak commited on
Commit
d820ac9
·
1 Parent(s): c9e9334

feat(config): allow multiple GPU types for training and simplify GPU handling

Browse files

This change allows training jobs to be configured with either a single GPU type or a list of GPU types, providing more flexibility in resource allocation. It also adds a function to get function kwargs specifically for training jobs.

The `create_beam_function_kwargs` function was updated to accept both single and list of GPU types without converting from string, simplifying the GPU type handling.

Files changed (1) hide show
  1. src/distiller/config.py +18 -18
src/distiller/config.py CHANGED
@@ -33,9 +33,6 @@ def setup_logging(level: int = logging.INFO) -> None:
33
  # BEAM CLOUD CONFIGURATION
34
  # =============================================================================
35
 
36
- # Beam execution settings
37
- GPU_NAME = GpuType.A100_40
38
-
39
 
40
  # Comprehensive Beam function configuration
41
  class BeamFunctionConfig(BaseModel):
@@ -44,7 +41,7 @@ class BeamFunctionConfig(BaseModel):
44
  # Resource allocation
45
  cpu: float = 2.0 # Number of CPU cores
46
  memory: int = 8192 # Memory in MiB (8GB)
47
- gpu: str = "A100_40" # GPU type
48
 
49
  # Execution settings
50
  timeout: int = 3600 * 12 # 12 hours timeout for long distillation jobs
@@ -75,7 +72,15 @@ BEAM_CONFIGS: dict[str, BeamFunctionConfig] = {
75
  "distillation": BeamFunctionConfig(
76
  cpu=4.0,
77
  memory=16384, # 8GB for distillation
78
- gpu="A100_40",
 
 
 
 
 
 
 
 
79
  timeout=3600 * 12, # 12 hours
80
  retries=2,
81
  secrets=["HF_ACCESS_TOKEN"],
@@ -83,7 +88,7 @@ BEAM_CONFIGS: dict[str, BeamFunctionConfig] = {
83
  "evaluation": BeamFunctionConfig(
84
  cpu=2.0,
85
  memory=8192, # 8GB for evaluation
86
- gpu="A100_40", # Smaller GPU for evaluation
87
  timeout=3600 * 4, # 4 hours
88
  retries=3,
89
  secrets=["HF_ACCESS_TOKEN"],
@@ -408,18 +413,8 @@ def create_beam_function_kwargs(
408
 
409
  # Convert GPU string to proper type if needed
410
  gpu_type = config.gpu
411
- if isinstance(gpu_type, str):
412
- # Map string to GpuType if it's a known type
413
- gpu_mapping = {
414
- "A100_40": GpuType.A100_40,
415
- "A100_80": GpuType.A100_80,
416
- "T4": GpuType.T4,
417
- "A10G": GpuType.A10G,
418
- "NoGPU": GpuType.NoGPU,
419
- }
420
- gpu_type = gpu_mapping.get(config.gpu, config.gpu)
421
-
422
- kwargs = {
423
  "cpu": config.cpu,
424
  "memory": config.memory,
425
  "gpu": gpu_type,
@@ -450,6 +445,11 @@ def get_distillation_function_kwargs() -> dict[str, Any]:
450
  return create_beam_function_kwargs("distillation")
451
 
452
 
 
 
 
 
 
453
  def get_evaluation_function_kwargs() -> dict[str, Any]:
454
  """Get function kwargs specifically for evaluation jobs."""
455
  return create_beam_function_kwargs("evaluation")
 
33
  # BEAM CLOUD CONFIGURATION
34
  # =============================================================================
35
 
 
 
 
36
 
37
  # Comprehensive Beam function configuration
38
  class BeamFunctionConfig(BaseModel):
 
41
  # Resource allocation
42
  cpu: float = 2.0 # Number of CPU cores
43
  memory: int = 8192 # Memory in MiB (8GB)
44
+ gpu: GpuType | list[GpuType] = GpuType.A100_40 # GPU type
45
 
46
  # Execution settings
47
  timeout: int = 3600 * 12 # 12 hours timeout for long distillation jobs
 
72
  "distillation": BeamFunctionConfig(
73
  cpu=4.0,
74
  memory=16384, # 8GB for distillation
75
+ gpu=GpuType.A100_40,
76
+ timeout=3600 * 12, # 12 hours
77
+ retries=2,
78
+ secrets=["HF_ACCESS_TOKEN"],
79
+ ),
80
+ "training": BeamFunctionConfig(
81
+ cpu=4.0,
82
+ memory=16384, # 8GB for distillation
83
+ gpu=[GpuType.H100, GpuType.A100_40],
84
  timeout=3600 * 12, # 12 hours
85
  retries=2,
86
  secrets=["HF_ACCESS_TOKEN"],
 
88
  "evaluation": BeamFunctionConfig(
89
  cpu=2.0,
90
  memory=8192, # 8GB for evaluation
91
+ gpu=GpuType.A100_40, # Smaller GPU for evaluation
92
  timeout=3600 * 4, # 4 hours
93
  retries=3,
94
  secrets=["HF_ACCESS_TOKEN"],
 
413
 
414
  # Convert GPU string to proper type if needed
415
  gpu_type = config.gpu
416
+
417
+ kwargs: dict[str, Any] = {
 
 
 
 
 
 
 
 
 
 
418
  "cpu": config.cpu,
419
  "memory": config.memory,
420
  "gpu": gpu_type,
 
445
  return create_beam_function_kwargs("distillation")
446
 
447
 
448
+ def get_training_function_kwargs() -> dict[str, Any]:
449
+ """Get function kwargs specifically for training jobs."""
450
+ return create_beam_function_kwargs("training")
451
+
452
+
453
  def get_evaluation_function_kwargs() -> dict[str, Any]:
454
  """Get function kwargs specifically for evaluation jobs."""
455
  return create_beam_function_kwargs("evaluation")