"""Pure-NumPy MLP matching the Android TFLite model exactly. Shapes, dtypes, and tensor order are dictated by server/config.py and must match `gen_tflite/make_model.py` byte-for-byte. """ from __future__ import annotations from typing import List import numpy as np from flwr.common import Parameters from server.config import LAYER_SHAPES def sort_layer_keys(keys: List[str]) -> List[str]: """Sort ``["a2","a10","a0"]`` → ``["a0","a2","a10"]`` (numeric, not lex).""" return sorted(keys, key=lambda k: int(k[1:])) def init_weights() -> List[np.ndarray]: """Glorot-uniform init for weight matrices, zeros for biases.""" rng = np.random.default_rng(seed=0) out: List[np.ndarray] = [] for _name, shape in LAYER_SHAPES: if len(shape) == 2: fan_in, fan_out = shape limit = float(np.sqrt(6.0 / (fan_in + fan_out))) out.append(rng.uniform(-limit, limit, size=shape).astype(np.float32)) else: out.append(np.zeros(shape, dtype=np.float32)) return out def weights_to_parameters(weights: List[np.ndarray]) -> Parameters: """Pack NumPy weights into a Flower Parameters message. Each tensor is serialised as raw little-endian float32 C-contiguous bytes. tensor_type is the literal string "ND" — the Android client hard-codes this. """ tensors = [ np.ascontiguousarray(w, dtype=np.float32).tobytes() for w in weights ] return Parameters(tensors=tensors, tensor_type="ND") def parameters_to_weights(parameters: Parameters) -> List[np.ndarray]: """Inverse of weights_to_parameters; fails loud on size mismatch.""" if len(parameters.tensors) != len(LAYER_SHAPES): raise ValueError( f"expected {len(LAYER_SHAPES)} tensors, got {len(parameters.tensors)}" ) out: List[np.ndarray] = [] for tensor_bytes, (name, shape) in zip(parameters.tensors, LAYER_SHAPES): expected = int(np.prod(shape)) * 4 if len(tensor_bytes) != expected: raise ValueError( f"tensor {name}: expected {expected} bytes for shape {shape}, " f"got {len(tensor_bytes)}" ) out.append( np.frombuffer(tensor_bytes, dtype=np.float32).reshape(shape).copy() ) return out def forward(weights: List[np.ndarray], x: np.ndarray) -> np.ndarray: """Server-side inference helper — matches the TFLite graph exactly.""" w1, b1, w2, b2, w3, b3 = weights h1 = np.maximum(x @ w1 + b1, 0.0) h2 = np.maximum(h1 @ w2 + b2, 0.0) return (h2 @ w3 + b3).astype(np.float32)