# Sammi1211 — "Initial push" (commit d7f62d0)
"""Pure-NumPy MLP matching the Android TFLite model exactly.
Shapes, dtypes, and tensor order are dictated by server/config.py and
must match `gen_tflite/make_model.py` byte-for-byte.
"""
from __future__ import annotations
from typing import List
import numpy as np
from flwr.common import Parameters
from server.config import LAYER_SHAPES
def sort_layer_keys(keys: List[str]) -> List[str]:
    """Order layer keys by the integer that follows their one-char prefix.

    Plain lexicographic sorting would put ``"a10"`` before ``"a2"``; this
    compares the numeric suffixes instead, e.g.
    ``["a2", "a10", "a0"]`` → ``["a0", "a2", "a10"]``.
    The input is not modified; a new list is returned.
    """

    def _suffix_index(key: str) -> int:
        # Everything after the first character is parsed as a base-10 int.
        return int(key[1:])

    ordered = list(keys)
    ordered.sort(key=_suffix_index)
    return ordered
def init_weights(*, seed: int = 0) -> List[np.ndarray]:
    """Build a fresh set of model tensors in ``LAYER_SHAPES`` order.

    Two-dimensional tensors (weight matrices) are drawn Glorot/Xavier-uniform
    from ``U(-limit, limit)`` with ``limit = sqrt(6 / (fan_in + fan_out))``,
    where ``shape`` is read as ``(fan_in, fan_out)`` — the layout used by
    ``forward``'s ``x @ w`` products. Every other tensor (the biases) is
    zero-filled.

    Args:
        seed: Seed for ``np.random.default_rng``. Defaults to 0, so a bare
            ``init_weights()`` call keeps producing the same tensors as
            before; pass another value for an independent initialisation.

    Returns:
        One ``float32`` array per ``LAYER_SHAPES`` entry, in the same order.
    """
    rng = np.random.default_rng(seed=seed)
    out: List[np.ndarray] = []
    for _name, shape in LAYER_SHAPES:
        if len(shape) == 2:
            fan_in, fan_out = shape
            limit = float(np.sqrt(6.0 / (fan_in + fan_out)))
            out.append(rng.uniform(-limit, limit, size=shape).astype(np.float32))
        else:
            out.append(np.zeros(shape, dtype=np.float32))
    return out
def weights_to_parameters(weights: List[np.ndarray]) -> Parameters:
    """Pack NumPy weights into a Flower ``Parameters`` message.

    Each tensor is serialised as raw little-endian float32 C-contiguous
    bytes. The byte order is pinned explicitly with ``"<f4"`` rather than
    ``np.float32`` (which is host-native). This keeps the wire format the same
    whatever the server's endianness or the input array's own byte order.
    ``tensor_type`` is the literal string ``"ND"`` — the Android client
    hard-codes this.

    Args:
        weights: Arrays in the order ``parameters_to_weights`` unpacks them
            (``LAYER_SHAPES`` order); each is converted to little-endian
            float32.

    Returns:
        A ``Parameters`` holding one ``bytes`` entry per input array.
    """
    wire_dtype = np.dtype("<f4")
    tensors = [
        np.ascontiguousarray(w, dtype=wire_dtype).tobytes()
        for w in weights
    ]
    return Parameters(tensors=tensors, tensor_type="ND")
def parameters_to_weights(parameters: Parameters) -> List[np.ndarray]:
    """Inverse of ``weights_to_parameters``; fails loud on size mismatch.

    Each entry of ``parameters.tensors`` is decoded as raw little-endian
    float32 bytes and reshaped to the matching ``LAYER_SHAPES`` entry.

    Args:
        parameters: Message whose ``tensors`` are in ``LAYER_SHAPES`` order.

    Returns:
        Writable, native-endian float32 arrays in ``LAYER_SHAPES`` order.
        They do not alias the message's byte buffers.

    Raises:
        ValueError: if the tensor count differs from ``len(LAYER_SHAPES)``,
            or a tensor's byte length is not ``prod(shape) * 4``.
    """
    if len(parameters.tensors) != len(LAYER_SHAPES):
        raise ValueError(
            f"expected {len(LAYER_SHAPES)} tensors, got {len(parameters.tensors)}"
        )
    # Decode with an explicit byte order so the result does not depend on
    # the host's endianness (np.float32 alone means "native").
    wire_dtype = np.dtype("<f4")
    out: List[np.ndarray] = []
    for tensor_bytes, (name, shape) in zip(parameters.tensors, LAYER_SHAPES):
        expected = int(np.prod(shape)) * wire_dtype.itemsize
        if len(tensor_bytes) != expected:
            raise ValueError(
                f"tensor {name}: expected {expected} bytes for shape {shape}, "
                f"got {len(tensor_bytes)}"
            )
        decoded = np.frombuffer(tensor_bytes, dtype=wire_dtype)
        # astype() always allocates a new array. The result is therefore
        # writable and detached from the read-only frombuffer view.
        out.append(decoded.astype(np.float32).reshape(shape))
    return out
def forward(weights: List[np.ndarray], x: np.ndarray) -> np.ndarray:
    """Server-side inference helper intended to match the TFLite graph.

    Computes ``relu(relu(x @ w1 + b1) @ w2 + b2) @ w3 + b3``: two ReLU
    hidden layers followed by a linear output layer.

    Args:
        weights: Exactly six tensors ``[w1, b1, w2, b2, w3, b3]``.
        x: Input features (array or array-like). Converted to float32
            before any arithmetic.

    Returns:
        The output layer's activations as a float32 array.
    """
    w1, b1, w2, b2, w3, b3 = weights
    # Cast the input up front so every matmul/add/ReLU runs in float32, as
    # the float32 weights do. Previously only the final result was cast, so a
    # float64 input was evaluated in double precision and could differ from
    # the float32 pipeline in the low-order bits.
    x = np.asarray(x, dtype=np.float32)
    h1 = np.maximum(x @ w1 + b1, 0.0)
    h2 = np.maximum(h1 @ w2 + b2, 0.0)
    return (h2 @ w3 + b3).astype(np.float32)