# Commit 0bb564a (new-human)
# feat: support training on real CSV data with a closed model-inference loop
from __future__ import annotations
import csv
from dataclasses import dataclass
from random import Random
from typing import List, Tuple, Dict, Any
@dataclass(frozen=True)
class Dataset:
    """Immutable train/validation split for a binary-classification task.

    For synthetic data, ``true_w``/``true_b`` hold the generating linear
    model; for real data they are left empty/zero since the ground truth
    is unknown.
    """

    x_train: List[List[float]]  # training feature rows
    y_train: List[int]          # training labels (0/1)
    x_val: List[List[float]]    # validation feature rows
    y_val: List[int]            # validation labels (0/1)
    true_w: List[float]         # generating weights ([] when unknown)
    true_b: float               # generating bias (0.0 when unknown)
    # Column names for the features; None when not provided.
    # (Was annotated `List[str]` with a None default — a type error.)
    feature_names: List[str] | None = None
def _dot(a: List[float], b: List[float]) -> float:
return sum(x * y for x, y in zip(a, b))
def load_csv_dataset(*, csv_path: str, target_col: str, seed: int, train_ratio: float) -> Dataset:
    """Load a numeric binary-classification dataset from a CSV file.

    The file must have a header row. Rows that are empty, too short, or
    contain non-numeric values are skipped. If ``target_col`` is not in
    the header, the last column is used as the target.

    Args:
        csv_path: Path to a UTF-8 CSV file with a header row.
        target_col: Name of the label column.
        seed: Seed for the shuffle that produces the train/val split.
        train_ratio: Fraction of rows used for training; the split point
            is clamped so both partitions are non-empty.

    Returns:
        A Dataset with ``true_w=[]`` / ``true_b=0.0`` (ground truth is
        unknown for real data).

    Raises:
        ValueError: If no row could be parsed.
    """
    with open(csv_path, "r", encoding="utf-8") as f:
        reader = csv.reader(f)
        header = next(reader)
        try:
            target_idx = header.index(target_col)
        except ValueError:
            # Fall back to the last column. Use a non-negative index: the
            # original `-1` sentinel never matched an enumerate() index, so
            # the label column leaked into feature_names and the features.
            target_idx = len(header) - 1
        feature_names = [col for i, col in enumerate(header) if i != target_idx]
        xs: List[List[float]] = []
        ys: List[int] = []
        for row in reader:
            if not row:
                continue
            try:
                y = int(float(row[target_idx]))
                x = [float(val) for i, val in enumerate(row) if i != target_idx]
            except (ValueError, IndexError):
                continue  # skip malformed, short, or non-numeric rows
            xs.append(x)
            ys.append(y)
    n_samples = len(xs)
    if n_samples == 0:
        raise ValueError("No valid data found in CSV.")
    idx = list(range(n_samples))
    Random(seed).shuffle(idx)
    # Clamp the cut so neither split is empty.
    cut = max(1, min(n_samples - 1, int(n_samples * train_ratio)))
    train_idx = idx[:cut]
    val_idx = idx[cut:]
    return Dataset(
        x_train=[xs[i] for i in train_idx],
        y_train=[ys[i] for i in train_idx],
        x_val=[xs[i] for i in val_idx],
        y_val=[ys[i] for i in val_idx],
        true_w=[],   # unknown for real data
        true_b=0.0,  # unknown for real data
        feature_names=feature_names,
    )
def make_synthetic_binary_classification(
    *, seed: int, n_samples: int, n_features: int, train_ratio: float
) -> Dataset:
    """Generate a linearly-separable-ish binary dataset with Gaussian noise.

    Samples standard-normal feature vectors, labels each by the sign of a
    random linear model's score plus N(0, 0.2) noise, then shuffles and
    splits into train/validation partitions (both guaranteed non-empty).
    """
    rng = Random(seed)
    # Ground-truth linear model that defines the labels.
    weights = [rng.uniform(-1.0, 1.0) for _ in range(n_features)]
    bias = rng.uniform(-0.5, 0.5)

    features: List[List[float]] = []
    labels: List[int] = []
    for _ in range(n_samples):
        point = [rng.gauss(0.0, 1.0) for _ in range(n_features)]
        # Noisy decision score; label is its sign.
        score = _dot(weights, point) + bias + rng.gauss(0.0, 0.2)
        features.append(point)
        labels.append(1 if score > 0 else 0)

    order = list(range(n_samples))
    rng.shuffle(order)
    # Clamp the split point so both partitions are non-empty.
    split = max(1, min(n_samples - 1, int(n_samples * train_ratio)))
    train_part = order[:split]
    val_part = order[split:]
    return Dataset(
        x_train=[features[i] for i in train_part],
        y_train=[labels[i] for i in train_part],
        x_val=[features[i] for i in val_part],
        y_val=[labels[i] for i in val_part],
        true_w=weights,
        true_b=bias,
        feature_names=[f"feature_{i}" for i in range(n_features)],
    )