import torch from safetensors.torch import load_file def load_model(path='model.safetensors'): return load_file(path) def d_latch(e, d, q_prev, weights): """D-Latch: E=1 transparent (Q=D), E=0 hold (Q=Q_prev).""" inp = torch.tensor([float(e), float(d), float(q_prev)]) e_and_d = int((inp @ weights['e_and_d.weight'].T + weights['e_and_d.bias'] >= 0).item()) e_and_notd = int((inp @ weights['e_and_notd.weight'].T + weights['e_and_notd.bias'] >= 0).item()) note_and_qprev = int((inp @ weights['note_and_qprev.weight'].T + weights['note_and_qprev.bias'] >= 0).item()) note_and_notqprev = int((inp @ weights['note_and_notqprev.weight'].T + weights['note_and_notqprev.bias'] >= 0).item()) l1 = torch.tensor([float(e_and_d), float(e_and_notd), float(note_and_qprev), float(note_and_notqprev)]) q = int((l1 @ weights['q.weight'].T + weights['q.bias'] >= 0).item()) qn = int((l1 @ weights['qn.weight'].T + weights['qn.bias'] >= 0).item()) return q, qn if __name__ == '__main__': w = load_model() print('D-Latch:') print('E D Q_prev | Q Qn') for e in range(2): for d in range(2): for q_prev in range(2): q, qn = d_latch(e, d, q_prev, w) print(f'{e} {d} {q_prev} | {q} {qn}')