"""Fallback simulation loop for auto-healing validation failures."""

import torch
from dataclasses import dataclass, replace
from typing import List, Optional, Tuple

from rich.console import Console
from rich.table import Table

from .dryrun import dry_run, DryRunResult
from .matrix import get_gpu_info, has_bitsandbytes

console = Console()


@dataclass
class ConfigCandidate:
    """Represents a configuration candidate for testing."""

    model: str
    precision: str
    seq_len: int
    batch_size: int
    lora: bool
    lora_targets: Optional[List[str]] = None
    gradient_checkpointing: bool = False
    dataset: str = "wikitext"
    text_field: Optional[str] = None


@dataclass
class FallbackAttempt:
    """Represents a single fallback attempt."""

    attempt_num: int
    config: ConfigCandidate
    result: DryRunResult
    strategy: str
    notes: str


class FallbackSimulator:
    """Handles fallback simulation and auto-healing."""

    def __init__(self):
        try:
            self.gpu = get_gpu_info()
        except Exception:
            # GPU probing failed; fall back to a generic CUDA GPU with
            # conservative capability defaults.
            self.gpu = type('GpuInfo', (), {
                'available': True,
                'name': 'Unknown GPU',
                'total_bytes': 0,
                'free_bytes': 0,
                'cc_major': 7,
                'cc_minor': 0,
                'bf16_supported': True,
            })()
        self.attempts: List[FallbackAttempt] = []

    def reset_gpu_state(self):
        """Reset GPU state to clear any CUDA errors."""
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        except Exception:
            pass

    def classify_error(self, error: str) -> str:
        """Classify the error type from an error message."""
        error_lower = error.lower()

        if "out of memory" in error_lower or "oom" in error_lower:
            return "oom"
        elif "bf16" in error_lower and "not supported" in error_lower:
            return "precision"
        elif "fp16" in error_lower and "not supported" in error_lower:
            return "precision"
        elif "4-bit" in error_lower and "not supported" in error_lower:
            return "precision"
        elif "bitsandbytes" in error_lower:
            return "precision"
        elif "seq_len" in error_lower and "model limit" in error_lower:
            return "seq_len"
        elif "position" in error_lower and "embedding" in error_lower:
            return "seq_len"
        elif "lora" in error_lower and "target" in error_lower:
            return "lora"
        elif "cuda error" in error_lower and "assert" in error_lower:
            # Device-side asserts most often stem from out-of-range position
            # ids, so treat them as a sequence-length problem.
            return "seq_len"
        else:
            return "unknown"
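
    # Illustrative classifications (hypothetical error strings, not exact
    # dry_run output):
    #   classify_error("CUDA out of memory. Tried to allocate 2.00 GiB") -> "oom"
    #   classify_error("bf16 is not supported on this device")           -> "precision"
    #   classify_error("seq_len 4096 exceeds model limit 2048")          -> "seq_len"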

    def apply_fallback_strategy(self, config: ConfigCandidate, error_type: str) -> Optional[ConfigCandidate]:
        """Apply a fallback strategy based on the error type."""
        # Work on a copy so the failed candidate stays intact in self.attempts.
        new_config = replace(config)

        if error_type == "precision":
            if config.precision == "bf16" and not self.gpu.bf16_supported:
                new_config.precision = "fp16"
                return new_config
            elif config.precision == "qlora4bit" and not has_bitsandbytes():
                new_config.precision = "fp16"
                return new_config
            elif config.precision in ("bf16", "fp16") and not self.gpu.available:
                new_config.precision = "fp32"
                return new_config

        elif error_type == "oom":
            # Try the cheapest mitigation first: halve the batch, enable
            # gradient checkpointing, shorten the sequence, then step down
            # to a lighter precision.
            if config.batch_size > 1:
                new_config.batch_size = max(1, config.batch_size // 2)
                return new_config
            elif not config.gradient_checkpointing:
                new_config.gradient_checkpointing = True
                return new_config
            elif config.seq_len > 512:
                new_config.seq_len = max(512, config.seq_len // 2)
                return new_config
            elif config.precision in ("bf16", "fp32"):
                new_config.precision = "fp16"
                return new_config
            elif config.precision == "fp16" and has_bitsandbytes() and self.gpu.available:
                new_config.precision = "qlora4bit"
                return new_config

        elif error_type == "seq_len":
            if config.seq_len > 1024:
                new_config.seq_len = 1024
                return new_config
            elif config.seq_len > 512:
                new_config.seq_len = 512
                return new_config

        elif error_type == "lora":
            if config.lora and config.lora_targets:
                # Fall back to the attention projections most models expose.
                new_config.lora_targets = ["q_proj", "v_proj"]
                return new_config

        # No applicable strategy left for this error type.
        return None
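
    # Worked example of the OOM ladder above (hypothetical numbers): a
    # candidate failing with "oom" at batch_size=8 comes back with
    # batch_size=4, then 2, then 1; further OOMs enable gradient
    # checkpointing, halve seq_len toward 512, and finally step precision
    # down (bf16/fp32 -> fp16 -> qlora4bit).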

    def simulate_fallbacks(self, initial_config: ConfigCandidate, max_attempts: int = 10) -> Tuple[bool, Optional[ConfigCandidate]]:
        """Simulate fallback attempts until success or max attempts reached."""
        current_config = initial_config
        attempt_num = 0

        console.print("\n[bold blue]🔄 Starting Auto-Heal Simulation Loop[/bold blue]")
        console.print(f"[dim]Max attempts: {max_attempts}[/dim]\n")

        attempts_table = Table(title="Fallback Simulation Attempts")
        attempts_table.add_column("Attempt", style="cyan", width=8)
        attempts_table.add_column("Precision", style="white", width=10)
        attempts_table.add_column("Seq Len", style="white", width=8)
        attempts_table.add_column("Batch", style="white", width=6)
        attempts_table.add_column("LoRA", style="white", width=6)
        attempts_table.add_column("Grad Check", style="white", width=10)
        attempts_table.add_column("Result", style="white", width=8)
        attempts_table.add_column("Strategy", style="yellow", width=20)

        while attempt_num < max_attempts:
            attempt_num += 1

            # Clear any lingering CUDA state from the previous attempt.
            self.reset_gpu_state()

            result = dry_run(
                model_id_or_path=current_config.model,
                precision=current_config.precision,
                seq_len=current_config.seq_len,
                batch_size=current_config.batch_size,
                lora=current_config.lora,
                lora_targets=current_config.lora_targets,
            )

            strategy = "Initial attempt" if attempt_num == 1 else f"Fallback #{attempt_num - 1}"

            attempt = FallbackAttempt(
                attempt_num=attempt_num,
                config=current_config,
                result=result,
                strategy=strategy,
                notes="",
            )
            self.attempts.append(attempt)

            result_text = "✅ PASS" if result.ok else "❌ FAIL"
            attempts_table.add_row(
                str(attempt_num),
                current_config.precision,
                str(current_config.seq_len),
                str(current_config.batch_size),
                "Yes" if current_config.lora else "No",
                "Yes" if current_config.gradient_checkpointing else "No",
                result_text,
                strategy,
            )

            if result.ok:
                console.print(attempts_table)
                console.print(f"\n[bold green]✅ SUCCESS![/bold green] Auto-healing found working configuration at attempt {attempt_num}")
                return True, current_config

            error_type = self.classify_error(result.error or "unknown")
            next_config = self.apply_fallback_strategy(current_config, error_type)

            if next_config is None:
                console.print(attempts_table)
                console.print("\n[bold red]❌ FAILED[/bold red] No more fallback strategies available")
                return False, None

            if error_type == "oom":
                attempt.notes = (
                    f"OOM detected, retrying with batch={next_config.batch_size}, "
                    f"grad_ckpt={next_config.gradient_checkpointing}, "
                    f"seq_len={next_config.seq_len}, precision={next_config.precision}"
                )
            elif error_type == "precision":
                attempt.notes = f"Precision {current_config.precision} not supported, switching to {next_config.precision}"
            elif error_type == "seq_len":
                attempt.notes = f"Sequence length {current_config.seq_len} too long, reducing to {next_config.seq_len}"
            elif error_type == "lora":
                attempt.notes = "LoRA target modules not found, using defaults"

            current_config = next_config

        console.print(attempts_table)
        console.print(f"\n[bold red]❌ FAILED[/bold red] Max attempts ({max_attempts}) reached")
        return False, None

    def generate_yaml_config(self, config: ConfigCandidate) -> str:
        """Generate a YAML-style config block for the working configuration."""
        yaml_lines = [
            "# AUTO-HEALED CONFIG PATCH",
            f"model: {config.model}",
            f"precision: {config.precision}",
            f"seq_len: {config.seq_len}",
            f"batch_size: {config.batch_size}",
            f"lora: {str(config.lora).lower()}",
            f"gradient_checkpointing: {str(config.gradient_checkpointing).lower()}",
            f"dataset: {config.dataset}",
        ]

        if config.lora_targets:
            # Emit a YAML flow sequence rather than the Python list repr.
            yaml_lines.append(f"lora_targets: [{', '.join(config.lora_targets)}]")

        if config.text_field:
            yaml_lines.append(f"text_field: {config.text_field}")

        return "\n".join(yaml_lines)
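

# Minimal usage sketch. Assumes this module is executed within its package
# (e.g. via ``python -m``) so the relative imports resolve; the candidate
# values below are illustrative, not recommendations.
if __name__ == "__main__":
    simulator = FallbackSimulator()
    candidate = ConfigCandidate(
        model="gpt2",  # hypothetical model id
        precision="bf16",
        seq_len=2048,
        batch_size=8,
        lora=True,
        lora_targets=["q_proj", "v_proj"],
    )
    ok, healed = simulator.simulate_fallbacks(candidate, max_attempts=10)
    if ok and healed is not None:
        # Print the auto-healed configuration as a YAML patch.
        console.print(simulator.generate_yaml_config(healed))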