kfoughali commited on
Commit
36a5fc5
·
verified ·
1 Parent(s): 0713715

Update config.py

Browse files
Files changed (1) hide show
  1. config.py +18 -124
config.py CHANGED
@@ -1,85 +1,26 @@
1
  """
2
- Configuration, constants, and data classes for Enhanced SPG compression.
3
- RESEARCH-GRADE: All parameters configurable, no hardcoding.
 
4
  """
5
 
6
  import json
7
  import hashlib
 
 
 
 
8
  from dataclasses import dataclass, field, asdict
 
9
  from enum import Enum
10
- from typing import List, Optional, NamedTuple, Dict, Any
11
  from datetime import datetime
12
  import torch
13
  import transformers
14
- import logging
15
 
16
  # Configure logging
17
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
18
  logger = logging.getLogger(__name__)
19
 
20
- # Model configurations - NO HARDCODING
21
- SUPPORTED_MODELS: Dict[str, Dict[str, Any]] = {
22
- "gpt2": {
23
- "name": "gpt2",
24
- "requires_auth": False,
25
- "max_context": 1024,
26
- "default_dtype": "float16"
27
- },
28
- "llama2-7b": {
29
- "name": "meta-llama/Llama-2-7b-hf",
30
- "requires_auth": True,
31
- "max_context": 4096,
32
- "default_dtype": "float16"
33
- },
34
- "mistral-7b": {
35
- "name": "mistralai/Mistral-7B-v0.1",
36
- "requires_auth": False,
37
- "max_context": 8192,
38
- "default_dtype": "float16"
39
- },
40
- "opt-1.3b": {
41
- "name": "facebook/opt-1.3b",
42
- "requires_auth": False,
43
- "max_context": 2048,
44
- "default_dtype": "float16"
45
- }
46
- }
47
-
48
- # Benchmark configurations - NO HARDCODING
49
- # FIXED: Changed "perplexity" to "wikitext" for consistency
50
- BENCHMARK_CONFIGS: Dict[str, Dict[str, Any]] = {
51
- "wikitext": { # CHANGED from "perplexity" to "wikitext"
52
- "type": "wikitext", # CHANGED
53
- "default_samples": 50,
54
- "default_prefill": 512,
55
- "default_generation": 64
56
- },
57
- "niah": {
58
- "type": "needle_in_haystack",
59
- "depths": [10, 25, 50, 75, 90], # Percentage depths
60
- "needle": "The secret password is BANANA",
61
- "default_samples": 10,
62
- "default_context": 4096
63
- },
64
- "ruler": {
65
- "type": "ruler",
66
- "max_seq_lengths": [1024, 2048, 4096, 8192],
67
- "default_samples": 10,
68
- "default_n_facts": 10
69
- },
70
- "scbench": {
71
- "type": "shared_context",
72
- "num_turns": [5, 10, 20],
73
- "default_samples": 10,
74
- "default_context": 2048
75
- },
76
- "longbench": {
77
- "type": "longbench",
78
- "subsets": ["narrativeqa", "qasper", "multifieldqa_en", "hotpotqa", "2wikimqa"],
79
- "default_samples": 20,
80
- "max_context": 8192
81
- }
82
- }
83
 
84
  class CompressionType(Enum):
85
  """RocketKV-enhanced SPG methods with explicit validation."""
@@ -89,12 +30,14 @@ class CompressionType(Enum):
89
  ENHANCED_SPG = "enhanced_spg"
90
  PROGRESSIVE_SPG = "progressive_spg"
91
 
 
92
  class PrecisionLevel(NamedTuple):
93
  """Precision level configuration with validation."""
94
  threshold: float
95
  bits: Optional[int]
96
  name: str
97
 
 
98
  @dataclass
99
  class ResearchConstants:
100
  """All constants/thresholds from validated research - NO HARDCODING."""
@@ -173,6 +116,7 @@ class ResearchConstants:
173
  MIN_COMPRESSION_RATIO: float = 1.0
174
  MAX_COMPRESSION_RATIO: float = 1000.0
175
 
 
176
  @dataclass
177
  class EnhancedSPGConfig:
178
  """Research-grade configuration with RocketKV-style 450x compression support."""
@@ -248,9 +192,6 @@ class EnhancedSPGConfig:
248
  stage_compression_min: float = 2.0 # Minimum stage compression ratio
249
  stage_compression_max: float = 500.0 # Maximum stage compression ratio (INCREASED for 450x)
250
 
251
- # Flash Attention support
252
- use_flash_attention: bool = False # Try to use Flash Attention if available
253
-
254
  def __post_init__(self):
255
  """Validate all parameters - fail fast on invalid config."""
256
  constants = ResearchConstants()
@@ -340,6 +281,7 @@ class EnhancedSPGConfig:
340
  else:
341
  return self.kernel_size_xlarge_seq
342
 
 
343
  @dataclass
344
  class ProvingConfig:
345
  """Configuration for attestable proof generation and verification - NO HARDCODING."""
@@ -364,6 +306,7 @@ class ProvingConfig:
364
  if not 0 < self.ppl_tolerance < 1:
365
  raise ValueError(f"ppl_tolerance must be in (0, 1), got {self.ppl_tolerance}")
366
 
 
367
  @dataclass
368
  class CompressionConfig:
369
  """Research-grade configuration for RocketKV-enhanced SPG methods."""
@@ -371,10 +314,6 @@ class CompressionConfig:
371
  compression_type: CompressionType = CompressionType.ENHANCED_SPG
372
  seed: int = 42
373
 
374
- # Model selection
375
- model_key: str = "gpt2" # Key into SUPPORTED_MODELS
376
- model_name: str = field(init=False) # Will be set in __post_init__
377
-
378
  # Enhanced SPG configuration
379
  enhanced_spg_config: EnhancedSPGConfig = field(default_factory=EnhancedSPGConfig)
380
 
@@ -398,25 +337,10 @@ class CompressionConfig:
398
  dataset_config: str = "wikitext-2-raw-v1"
399
  dataset_split: str = "test"
400
 
401
- # Benchmark configuration
402
- benchmark_type: str = "wikitext" # wikitext, niah, ruler, scbench, longbench
403
- benchmark_subset: Optional[str] = None # For longbench subsets
404
-
405
- # NIAH-specific parameters
406
- niah_needle: str = field(default_factory=lambda: BENCHMARK_CONFIGS["niah"]["needle"])
407
- niah_depth_percent: float = 50.0
408
-
409
- # RULER-specific parameters
410
- ruler_max_seq_length: int = 4096
411
-
412
- # SCBench-specific parameters
413
- scbench_num_turns: int = 10
414
-
415
  # Memory and system settings
416
  clear_cache_between_runs: bool = True
417
  use_memory_snapshot: bool = True
418
  fail_on_cpu_fallback: bool = True # CHANGED: Default to True for strict compliance
419
- use_flash_attention: bool = False # Try to use Flash Attention if available
420
 
421
  # Output settings
422
  generate_latex: bool = True
@@ -430,25 +354,14 @@ class CompressionConfig:
430
  timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
431
 
432
  def __post_init__(self):
433
- """Comprehensive validation - FAIL FAST on any invalid parameter, NO SILENT DEFAULTS."""
434
  constants = ResearchConstants()
435
 
436
- # Set model name from key - FAIL FAST if invalid
437
- if self.model_key not in SUPPORTED_MODELS:
438
- raise ValueError(f"model_key {self.model_key} not in SUPPORTED_MODELS: {list(SUPPORTED_MODELS.keys())}")
439
- self.model_name = SUPPORTED_MODELS[self.model_key]["name"]
440
- logger.info(f"Model selected: {self.model_name} (key: {self.model_key})")
441
-
442
- # Validate benchmark type - FAIL FAST if invalid
443
- if self.benchmark_type not in BENCHMARK_CONFIGS:
444
- raise ValueError(f"benchmark_type {self.benchmark_type} not in BENCHMARK_CONFIGS: {list(BENCHMARK_CONFIGS.keys())}")
445
- logger.info(f"Benchmark selected: {self.benchmark_type}")
446
-
447
- # Validate core parameters - NO MAGIC NUMBERS
448
  if not isinstance(self.seed, int) or self.seed < 0:
449
  raise ValueError(f"seed must be non-negative integer, got {self.seed}")
450
 
451
- # Validate evaluation parameters with explicit bounds
452
  if not constants.MIN_EVAL_SAMPLES <= self.eval_samples <= constants.MAX_EVAL_SAMPLES:
453
  logger.warning(f"eval_samples {self.eval_samples} outside recommended range [{constants.MIN_EVAL_SAMPLES}, {constants.MAX_EVAL_SAMPLES}]")
454
 
@@ -461,33 +374,14 @@ class CompressionConfig:
461
  if not 1 <= self.n_seeds <= 10:
462
  logger.warning(f"n_seeds {self.n_seeds} outside recommended range [1, 10]")
463
 
464
- # Validate statistical parameters - EXPLICIT BOUNDS
465
  if not 0.5 <= self.confidence_level < 1.0:
466
  raise ValueError(f"confidence_level must be in [0.5, 1.0), got {self.confidence_level}")
467
 
468
  if not 100 <= self.n_bootstrap <= 10000:
469
  logger.warning(f"n_bootstrap {self.n_bootstrap} outside recommended range [100, 10000]")
470
 
471
- # Validate benchmark-specific parameters
472
- if self.benchmark_type == "longbench" and not self.benchmark_subset:
473
- logger.warning("LongBench selected but no subset specified")
474
-
475
- if self.benchmark_type == "niah" and not self.niah_needle:
476
- raise ValueError("NIAH benchmark requires niah_needle to be set")
477
-
478
- if self.benchmark_type == "ruler" and self.ruler_max_seq_length <= 0:
479
- raise ValueError(f"ruler_max_seq_length must be positive, got {self.ruler_max_seq_length}")
480
-
481
- if self.benchmark_type == "scbench" and self.scbench_num_turns <= 0:
482
- raise ValueError(f"scbench_num_turns must be positive, got {self.scbench_num_turns}")
483
-
484
- # Pass Flash Attention setting to EnhancedSPGConfig
485
- self.enhanced_spg_config.use_flash_attention = self.use_flash_attention
486
-
487
- logger.info("Configuration validated successfully - STRICT COMPLIANCE")
488
- logger.info(f"Target compression: {self.enhanced_spg_config.target_compression_ratio}x")
489
- logger.info(f"Fail on CPU fallback: {self.fail_on_cpu_fallback}")
490
- logger.info(f"Proving enabled: {self.proving.enabled}")
491
 
492
  def to_json(self) -> str:
493
  """Export config for reproducibility."""
 
1
  """
2
+ Configuration module for Enhanced SPG compression.
3
+ Contains all research constants, configuration classes, and validation logic.
4
+ STRICT COMPLIANCE: No hardcoding, all parameters from config.
5
  """
6
 
7
  import json
8
  import hashlib
9
+ import logging
10
+ import sys
11
+ import os
12
+ import platform
13
  from dataclasses import dataclass, field, asdict
14
+ from typing import List, Optional, NamedTuple, Any
15
  from enum import Enum
 
16
  from datetime import datetime
17
  import torch
18
  import transformers
 
19
 
20
  # Configure logging
21
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
22
  logger = logging.getLogger(__name__)
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  class CompressionType(Enum):
26
  """RocketKV-enhanced SPG methods with explicit validation."""
 
30
  ENHANCED_SPG = "enhanced_spg"
31
  PROGRESSIVE_SPG = "progressive_spg"
32
 
33
+
34
  class PrecisionLevel(NamedTuple):
35
  """Precision level configuration with validation."""
36
  threshold: float
37
  bits: Optional[int]
38
  name: str
39
 
40
+
41
  @dataclass
42
  class ResearchConstants:
43
  """All constants/thresholds from validated research - NO HARDCODING."""
 
116
  MIN_COMPRESSION_RATIO: float = 1.0
117
  MAX_COMPRESSION_RATIO: float = 1000.0
118
 
119
+
120
  @dataclass
121
  class EnhancedSPGConfig:
122
  """Research-grade configuration with RocketKV-style 450x compression support."""
 
192
  stage_compression_min: float = 2.0 # Minimum stage compression ratio
193
  stage_compression_max: float = 500.0 # Maximum stage compression ratio (INCREASED for 450x)
194
 
 
 
 
195
  def __post_init__(self):
196
  """Validate all parameters - fail fast on invalid config."""
197
  constants = ResearchConstants()
 
281
  else:
282
  return self.kernel_size_xlarge_seq
283
 
284
+
285
  @dataclass
286
  class ProvingConfig:
287
  """Configuration for attestable proof generation and verification - NO HARDCODING."""
 
306
  if not 0 < self.ppl_tolerance < 1:
307
  raise ValueError(f"ppl_tolerance must be in (0, 1), got {self.ppl_tolerance}")
308
 
309
+
310
  @dataclass
311
  class CompressionConfig:
312
  """Research-grade configuration for RocketKV-enhanced SPG methods."""
 
314
  compression_type: CompressionType = CompressionType.ENHANCED_SPG
315
  seed: int = 42
316
 
 
 
 
 
317
  # Enhanced SPG configuration
318
  enhanced_spg_config: EnhancedSPGConfig = field(default_factory=EnhancedSPGConfig)
319
 
 
337
  dataset_config: str = "wikitext-2-raw-v1"
338
  dataset_split: str = "test"
339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  # Memory and system settings
341
  clear_cache_between_runs: bool = True
342
  use_memory_snapshot: bool = True
343
  fail_on_cpu_fallback: bool = True # CHANGED: Default to True for strict compliance
 
344
 
345
  # Output settings
346
  generate_latex: bool = True
 
354
  timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
355
 
356
  def __post_init__(self):
357
+ """Comprehensive validation - fail fast on any invalid parameter."""
358
  constants = ResearchConstants()
359
 
360
+ # Validate core parameters
 
 
 
 
 
 
 
 
 
 
 
361
  if not isinstance(self.seed, int) or self.seed < 0:
362
  raise ValueError(f"seed must be non-negative integer, got {self.seed}")
363
 
364
+ # Validate evaluation parameters
365
  if not constants.MIN_EVAL_SAMPLES <= self.eval_samples <= constants.MAX_EVAL_SAMPLES:
366
  logger.warning(f"eval_samples {self.eval_samples} outside recommended range [{constants.MIN_EVAL_SAMPLES}, {constants.MAX_EVAL_SAMPLES}]")
367
 
 
374
  if not 1 <= self.n_seeds <= 10:
375
  logger.warning(f"n_seeds {self.n_seeds} outside recommended range [1, 10]")
376
 
377
+ # Validate statistical parameters
378
  if not 0.5 <= self.confidence_level < 1.0:
379
  raise ValueError(f"confidence_level must be in [0.5, 1.0), got {self.confidence_level}")
380
 
381
  if not 100 <= self.n_bootstrap <= 10000:
382
  logger.warning(f"n_bootstrap {self.n_bootstrap} outside recommended range [100, 10000]")
383
 
384
+ logger.info("RocketKV-enhanced SPG config validated successfully")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
 
386
  def to_json(self) -> str:
387
  """Export config for reproducibility."""