File size: 670 Bytes
11ac7be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bb564a
 
 
11ac7be
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from __future__ import annotations

from dataclasses import asdict, dataclass, fields
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class TrainConfig:
    """Immutable bundle of training hyper-parameters and data options.

    The dataclass is frozen, so fields cannot be reassigned after
    construction; derive a modified copy with ``dataclasses.replace``.
    ``to_dict`` / ``from_dict`` round-trip a config through a plain dict, so
    ``TrainConfig.from_dict(cfg.to_dict()) == cfg``.
    """

    seed: int = 42
    n_samples: int = 2000
    n_features: int = 16
    train_ratio: float = 0.8
    epochs: int = 20
    lr: float = 0.2
    l2: float = 0.0
    # NOTE(review): None presumably disables gradient clipping — confirm in the trainer.
    grad_clip: Optional[float] = None
    # NOTE(review): presumably a small constant guarding the loss against log(0) — confirm.
    loss_eps: float = 1e-12
    report_top_k_features: int = 8

    # NOTE(review): presumably None means no CSV is loaded and data of size
    # n_samples x n_features is used instead — verify against the data loader.
    csv_path: Optional[str] = None
    target_col: str = "target"

    def to_dict(self) -> Dict[str, Any]:
        """Return a new plain ``dict`` mapping every field name to its value."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: Dict[str, Any], *, strict: bool = True) -> "TrainConfig":
        """Build a config from a dict such as the one ``to_dict`` returns.

        Keys absent from *d* fall back to the field defaults.

        Args:
            d: Mapping of field name to value.
            strict: When True (the default), keys that are not constructor
                fields are an error. When False, such keys are dropped, which
                allows loading dicts that carry extra entries.

        Returns:
            An instance of ``cls``; calling ``from_dict`` on a subclass yields
            an instance of that subclass rather than a bare ``TrainConfig``.

        Raises:
            TypeError: if *strict* is True and *d* has keys that are not
                constructor fields.
        """
        known = {f.name for f in fields(cls) if f.init}
        # Preserve input order and avoid sorting, which could fail on odd key types.
        unknown = [k for k in d if k not in known]
        if unknown and strict:
            raise TypeError(f"{cls.__name__}.from_dict got unknown keys: {unknown!r}")
        return cls(**{k: v for k, v in d.items() if k in known})