File size: 1,797 Bytes
2c11783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
Helpers for converting numpy arrays and other ML outputs to JSON-serializable
Python structures. FastAPI's JSONResponse cannot serialize numpy types directly.
"""

import numpy as np


def ndarray_to_list(arr: np.ndarray) -> list:
    """
    Turn a numpy array into built-in Python lists, nested to the array's depth.

    Delegates to ``ndarray.tolist``, so every element comes back as a native
    Python scalar (``int``, ``float``, ``bool``, ...). Note that a 0-d array
    yields a bare scalar rather than a list.
    """
    nested = arr.tolist()
    return nested


def float32(v) -> float:
    """
    Coerce *v* (a Python or numpy numeric scalar) to a built-in ``float``.

    Despite the name, the result is a Python ``float`` (double precision), not
    a ``np.float32``. The value is handed straight to the ``float()`` builtin,
    so anything it rejects raises ``TypeError``/``ValueError`` here as well.
    """
    as_builtin = float(v)
    return as_builtin


def safe_dict(d: dict) -> dict:
    """
    Return a copy of *d* with numpy values converted to native Python types.

    Conversions applied at every nesting level:

    * ``np.ndarray`` -> nested lists via ``tolist()``
    * ``np.bool_`` / ``np.integer`` / ``np.floating`` -> ``bool`` / ``int`` / ``float``
    * nested ``dict`` -> recursed with ``safe_dict``
    * nested ``list`` -> recursed with ``safe_list``
    * nested ``tuple`` -> elements recursed with ``safe_list``; result stays a tuple

    Dict keys that are numpy scalars are converted as well, because
    ``json.dumps`` only accepts str/int/float/bool/None keys. Any other value
    is passed through unchanged. The input dict itself is not modified.
    """
    out = {}
    for k, v in d.items():
        # json.dumps raises TypeError on keys such as np.int64, so normalize them.
        if isinstance(k, np.bool_):
            k = bool(k)
        elif isinstance(k, np.integer):
            k = int(k)
        elif isinstance(k, np.floating):
            k = float(k)

        if isinstance(v, np.ndarray):
            out[k] = v.tolist()
        elif isinstance(v, np.bool_):
            # np.bool_ is not a subclass of np.integer, so it needs its own branch.
            out[k] = bool(v)
        elif isinstance(v, np.integer):
            out[k] = int(v)
        elif isinstance(v, np.floating):
            out[k] = float(v)
        elif isinstance(v, dict):
            out[k] = safe_dict(v)
        elif isinstance(v, list):
            out[k] = safe_list(v)
        elif isinstance(v, tuple):
            # Recurse into tuple elements too; keep the container a tuple.
            out[k] = tuple(safe_list(list(v)))
        else:
            out[k] = v
    return out


def safe_list(lst: list) -> list:
    """
    Return a copy of *lst* with numpy values converted to native Python types.

    Element-wise counterpart of ``safe_dict``:

    * ``np.ndarray`` -> nested lists via ``tolist()``
    * ``np.bool_`` / ``np.integer`` / ``np.floating`` -> ``bool`` / ``int`` / ``float``
    * nested ``dict`` -> recursed with ``safe_dict``
    * nested ``list`` -> recursed with ``safe_list``
    * nested ``tuple`` -> elements recursed; result stays a tuple

    Any other element is passed through unchanged. The input list itself is
    not modified.
    """
    out = []
    for v in lst:
        if isinstance(v, np.ndarray):
            out.append(v.tolist())
        elif isinstance(v, np.bool_):
            # np.bool_ is not a subclass of np.integer, so it needs its own branch.
            out.append(bool(v))
        elif isinstance(v, np.integer):
            out.append(int(v))
        elif isinstance(v, np.floating):
            out.append(float(v))
        elif isinstance(v, dict):
            out.append(safe_dict(v))
        elif isinstance(v, list):
            out.append(safe_list(v))
        elif isinstance(v, tuple):
            # Recurse into tuple elements too; keep the container a tuple.
            out.append(tuple(safe_list(list(v))))
        else:
            out.append(v)
    return out


def confusion_matrix_to_dict(cm: np.ndarray, class_names: list[str]) -> dict:
    """
    Convert a confusion matrix to a frontend-friendly, JSON-serializable dict.

    Args:
        cm: Confusion matrix; anything ``np.asarray`` accepts (an ndarray or
            nested lists).
        class_names: Class labels, as a plain list or a numpy array. Numpy
            scalar labels (``np.str_``, ``np.int64``, ...) are converted to
            native Python values.

    Returns:
        ``{"matrix": nested lists, "class_names": list, "n_classes": int}``,
        where ``n_classes`` is the number of class names.
    """
    # Iterating an ndarray yields numpy scalars; .item() turns each into its
    # Python equivalent so json.dumps (and FastAPI's JSONResponse) accept it.
    names = [c.item() if isinstance(c, np.generic) else c for c in class_names]
    return {
        "matrix": np.asarray(cm).tolist(),
        "class_names": names,
        "n_classes": len(names),
    }