File size: 828 Bytes
11ac7be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import List


def sigmoid(z: float) -> float:
    """Return the logistic function 1 / (1 + e**-z) of ``z``.

    The two branches are algebraically equivalent. Each one only passes a
    non-positive argument to ``math.exp``, so the exponential never exceeds
    1.0 and cannot overflow for large ``|z|``.
    """
    if z < 0:
        # e**z / (1 + e**z) == 1 / (1 + e**-z), rewritten so exp sees z <= 0.
        t = math.exp(z)
        return t / (1.0 + t)
    return 1.0 / (1.0 + math.exp(-z))


def dot(a: List[float], b: List[float]) -> float:
    """Return the dot product ``sum(a[i] * b[i])`` of two equal-length vectors.

    Raises:
        ValueError: if ``a`` and ``b`` have different lengths. Without this
            check ``zip`` would silently truncate to the shorter input. A
            weight/feature count mismatch would then produce a wrong score
            instead of an error.
    """
    if len(a) != len(b):
        raise ValueError(f"dot: length mismatch ({len(a)} != {len(b)})")
    return sum(x * y for x, y in zip(a, b))


@dataclass
class LogisticRegression:
    """Binary logistic-regression model: p(y=1 | x) = sigmoid(w . x + b).

    Holds the model parameters and provides single-sample prediction methods.
    """

    w: List[float]  # one weight per input feature
    b: float  # bias (intercept) term

    @classmethod
    def init(cls, n_features: int) -> "LogisticRegression":
        """Return a model with ``n_features`` zero weights and a zero bias.

        Builds the instance via ``cls`` so that calling ``init`` on a subclass
        returns an instance of that subclass.

        Raises:
            ValueError: if ``n_features`` is negative.
        """
        if n_features < 0:
            raise ValueError(f"n_features must be non-negative, got {n_features}")
        return cls(w=[0.0] * n_features, b=0.0)

    def predict_proba_one(self, x: List[float]) -> float:
        """Return sigmoid(w . x + b), the modelled probability of class 1 for ``x``."""
        return sigmoid(dot(self.w, x) + self.b)

    def predict_one(self, x: List[float], threshold: float = 0.5) -> int:
        """Return 1 if the class-1 probability for ``x`` is >= ``threshold``, else 0."""
        return 1 if self.predict_proba_one(x) >= threshold else 0