Spaces:
Sleeping
Sleeping
| """Pure-NumPy MLP matching the Android TFLite model exactly. | |
| Shapes, dtypes, and tensor order are dictated by server/config.py and | |
| must match `gen_tflite/make_model.py` byte-for-byte. | |
| """ | |
| from __future__ import annotations | |
| from typing import List | |
| import numpy as np | |
| from flwr.common import Parameters | |
| from server.config import LAYER_SHAPES | |
def sort_layer_keys(keys: List[str]) -> List[str]:
    """Return *keys* ordered by the integer that follows their first character.

    The ordering is numeric rather than lexicographic, so ``"a10"`` sorts after
    ``"a2"``: ``["a2", "a10", "a0"]`` becomes ``["a0", "a2", "a10"]``.
    The input is not modified; a new list is returned. A key whose tail is
    not an integer raises ``ValueError`` (from ``int``).
    """

    def _suffix_as_int(key: str) -> int:
        # Drop the single-character prefix (e.g. "a") and parse the rest.
        return int(key[1:])

    ordered = list(keys)
    ordered.sort(key=_suffix_as_int)
    return ordered
def init_weights() -> List[np.ndarray]:
    """Build the initial parameter list, one float32 array per LAYER_SHAPES entry.

    Rank-2 entries are treated as ``(fan_in, fan_out)`` weight matrices and
    drawn from a Glorot/Xavier-uniform distribution on ``[-limit, limit)``
    with ``limit = sqrt(6 / (fan_in + fan_out))``. Every other entry (the
    biases) is zero-filled. The generator is seeded with 0, so repeated
    calls return identical values.
    """
    generator = np.random.default_rng(seed=0)

    def _make(shape) -> np.ndarray:
        # Biases, or any non-matrix shape, start at zero.
        if len(shape) != 2:
            return np.zeros(shape, dtype=np.float32)
        rows, cols = shape
        bound = float(np.sqrt(6.0 / (rows + cols)))
        sample = generator.uniform(-bound, bound, size=shape)
        return sample.astype(np.float32)

    # The comprehension keeps the LAYER_SHAPES order, so the RNG draws
    # happen in the same sequence on every call.
    return [_make(shape) for _name, shape in LAYER_SHAPES]
def weights_to_parameters(weights: List[np.ndarray]) -> Parameters:
    """Pack NumPy weights into a Flower Parameters message.

    Each tensor is serialised as raw little-endian float32 C-contiguous
    bytes, in the order given. tensor_type is the literal string "ND" — the
    Android client hard-codes this.

    Args:
        weights: Arrays (or array-likes) to serialise. Each one is cast to
            float32 before packing. No count or shape validation is done here.

    Returns:
        A ``Parameters`` whose ``tensors`` list holds one bytes object per
        input array.
    """
    # "<f4" fixes the byte order explicitly. Plain np.float32 means *native*
    # order, which would emit big-endian bytes on a big-endian host and break
    # the little-endian wire format documented above.
    wire_dtype = np.dtype("<f4")
    tensors = [
        np.ascontiguousarray(w, dtype=wire_dtype).tobytes() for w in weights
    ]
    return Parameters(tensors=tensors, tensor_type="ND")
def parameters_to_weights(parameters: Parameters) -> List[np.ndarray]:
    """Inverse of weights_to_parameters; fails loud on size mismatch.

    Decodes each tensor in ``parameters.tensors`` as little-endian float32
    bytes. Each one is reshaped to the ``LAYER_SHAPES`` entry at the same
    position.

    Args:
        parameters: Flower message whose ``tensors`` are raw byte strings.

    Returns:
        One native-byte-order float32 array per layer. Each is an independent,
        writable copy that does not share memory with the message buffer.

    Raises:
        ValueError: if the tensor count differs from ``len(LAYER_SHAPES)``,
            or if any tensor's byte length does not match its shape.
    """
    if len(parameters.tensors) != len(LAYER_SHAPES):
        raise ValueError(
            f"expected {len(LAYER_SHAPES)} tensors, got {len(parameters.tensors)}"
        )
    # Decode with an explicit little-endian dtype so the result matches the
    # wire format written by weights_to_parameters on any host byte order.
    wire_dtype = np.dtype("<f4")
    out: List[np.ndarray] = []
    for tensor_bytes, (name, shape) in zip(parameters.tensors, LAYER_SHAPES):
        expected = int(np.prod(shape)) * wire_dtype.itemsize
        if len(tensor_bytes) != expected:
            raise ValueError(
                f"tensor {name}: expected {expected} bytes for shape {shape}, "
                f"got {len(tensor_bytes)}"
            )
        # frombuffer gives a read-only view of the bytes. astype converts to
        # native float32 and returns a fresh, writable copy.
        out.append(
            np.frombuffer(tensor_bytes, dtype=wire_dtype)
            .reshape(shape)
            .astype(np.float32)
        )
    return out
def forward(weights: List[np.ndarray], x: np.ndarray) -> np.ndarray:
    """Server-side inference helper — intended to mirror the TFLite graph.

    Computes ``relu(relu(x @ w1 + b1) @ w2 + b2) @ w3 + b3`` entirely in
    float32, where ``weights`` is ``[w1, b1, w2, b2, w3, b3]`` in that order
    (presumably the LAYER_SHAPES order — confirm against server/config.py).

    Args:
        weights: Exactly six arrays (three weight/bias pairs).
        x: Input batch. It is cast to float32 before any arithmetic.

    Returns:
        The final-layer output (no activation on the last layer), as float32.

    Raises:
        ValueError: if ``weights`` does not unpack into exactly six items.
    """
    # Cast everything up front so every matmul and add runs in float32, like
    # a float32 TFLite model. Otherwise a float64 (or integer) input promotes
    # the whole pass to float64 and is only rounded at the very end.
    w1, b1, w2, b2, w3, b3 = (np.asarray(w, dtype=np.float32) for w in weights)
    x = np.asarray(x, dtype=np.float32)
    h1 = np.maximum(x @ w1 + b1, 0.0)
    h2 = np.maximum(h1 @ w2 + b2, 0.0)
    return (h2 @ w3 + b3).astype(np.float32)