Spaces:
Running
Running
| from __future__ import annotations | |
| import math | |
| from dataclasses import dataclass | |
| from random import Random | |
| from typing import Dict, List, Optional | |
| from .logistic import LogisticRegression | |
| from .metrics import evaluate | |
@dataclass
class TrainResult:
    """Outcome of a training run: the final model and its metric history.

    Declared as a dataclass so it can be built with keyword arguments,
    e.g. ``TrainResult(model=..., history=...)``.
    """

    # Model as it stood when training stopped.
    model: "LogisticRegression"
    # Metric dictionaries collected during training, in the order recorded.
    history: List[Dict[str, float]]
| def _l2_norm(v: List[float]) -> float: | |
| return math.sqrt(sum(x * x for x in v)) | |
def train_logreg_sgd(
    *,
    seed: int,
    xs_train: List[List[float]],
    ys_train: List[int],
    xs_val: List[List[float]],
    ys_val: List[int],
    epochs: int,
    lr: float,
    l2: float = 0.0,
    grad_clip: Optional[float] = None,
    loss_eps: float = 1e-12,
) -> TrainResult:
    """Train a ``LogisticRegression`` model with per-sample SGD.

    Each epoch visits every training sample once, in an order shuffled by a
    ``Random(seed)`` instance, and updates the model after every sample.
    After each epoch the model is scored on both the training and the
    validation data via ``evaluate`` and the results are appended to the
    history.

    Args:
        seed: Seed for the random generator that shuffles the sample order.
        xs_train: Training feature rows; every row must have the same length.
        ys_train: Training labels, one per row of ``xs_train``.
        xs_val: Validation feature rows, passed to ``evaluate`` unchanged.
        ys_val: Validation labels, passed to ``evaluate`` unchanged.
        epochs: Number of passes over the training data; values below 1
            result in no training and an empty history.
        lr: Step size applied to every gradient update.
        l2: Coefficient of the ``l2 * w`` term added to the weight gradient.
            The bias gradient carries no such term.
        grad_clip: When set to a positive value, the weight gradient is
            rescaled whenever its L2 norm exceeds this threshold, and the
            bias gradient is multiplied by the same factor. ``None`` or a
            non-positive value disables clipping.
        loss_eps: Forwarded to ``evaluate`` as its ``eps`` argument.

    Returns:
        A ``TrainResult`` holding the trained model and one history entry per
        completed epoch with the keys ``epoch``, ``train_loss``,
        ``train_acc``, ``val_loss`` and ``val_acc``.

    Raises:
        ValueError: If ``xs_train`` and ``ys_train`` differ in length, or if
            any training row's length differs from that of the first row.
    """
    # Validate up front: the update below pairs features with weights via
    # zip(), which stops at the shorter input. A ragged row would therefore
    # be silently mishandled (and a short one would shrink model.w for good)
    # instead of failing loudly.
    if len(xs_train) != len(ys_train):
        raise ValueError(
            "xs_train and ys_train must have the same length "
            f"(got {len(xs_train)} and {len(ys_train)})"
        )
    n_features = len(xs_train[0]) if xs_train else 0
    for row_no, row in enumerate(xs_train):
        if len(row) != n_features:
            raise ValueError(
                f"xs_train[{row_no}] has {len(row)} features, "
                f"expected {n_features}"
            )

    # Resolve the clipping switch once; it does not change between samples.
    clip_at: Optional[float] = (
        grad_clip if grad_clip is not None and grad_clip > 0 else None
    )

    r = Random(seed)
    # NOTE(review): assumes LogisticRegression.init(n) yields n weights in
    # model.w -- confirm against the model class.
    model = LogisticRegression.init(n_features)
    history: List[Dict[str, float]] = []
    idx = list(range(len(xs_train)))

    for epoch in range(1, epochs + 1):
        r.shuffle(idx)
        for i in idx:
            x = xs_train[i]
            y = ys_train[i]

            # Residual between the model's output for x and the label.
            p = model.predict_proba_one(x)
            g = p - y
            grad_w = [g * xi + (l2 * wi) for xi, wi in zip(x, model.w)]
            grad_b = g

            if clip_at is not None:
                # The norm is measured over the weight gradient only; the
                # bias gradient just follows the same scale factor.
                norm = _l2_norm(grad_w)
                if norm > clip_at:
                    scale = clip_at / (norm + 1e-12)
                    grad_w = [gw * scale for gw in grad_w]
                    grad_b *= scale

            model.w = [wi - lr * gw for wi, gw in zip(model.w, grad_w)]
            model.b = model.b - lr * grad_b

        train_loss, train_acc = evaluate(model, xs_train, ys_train, eps=loss_eps)
        val_loss, val_acc = evaluate(model, xs_val, ys_val, eps=loss_eps)
        history.append(
            {
                "epoch": float(epoch),
                "train_loss": float(train_loss),
                "train_acc": float(train_acc),
                "val_loss": float(val_loss),
                "val_acc": float(val_acc),
            }
        )

        # Stop once a loss is NaN/inf; that epoch's entry has already been
        # recorded above, so the failure is visible in the history.
        if not (math.isfinite(train_loss) and math.isfinite(val_loss)):
            break

    return TrainResult(model=model, history=history)