import itertools

import torch
from safetensors.torch import load_file
def load_model(path='model.safetensors'):
    """Load a safetensors checkpoint and return its tensors as a dict.

    Args:
        path: Filesystem path to the ``.safetensors`` file.

    Returns:
        Mapping of tensor names to ``torch.Tensor`` values, as produced
        by ``safetensors.torch.load_file``.
    """
    weights = load_file(path)
    return weights
def d_latch(e, d, q_prev, weights):
    """D-Latch: E=1 transparent (Q=D), E=0 hold (Q=Q_prev).

    Evaluates a two-layer perceptron network implementing a D latch.
    Layer 1 computes four AND-style gates over (E, D, Q_prev); layer 2
    ORs them into the complementary outputs Q and Qn.

    Args:
        e: Enable input (0 or 1).
        d: Data input (0 or 1).
        q_prev: Previous latch state (0 or 1).
        weights: Dict mapping '<name>.weight' (shape (1, in_features))
            and '<name>.bias' (shape (1,)) to torch tensors, for the
            six neurons named below. Assumed to come from load_model()
            — TODO confirm key layout against the checkpoint.

    Returns:
        Tuple (q, qn) of ints in {0, 1}; qn is the complement of q for
        correctly trained weights.
    """
    def _fire(name, x):
        # One perceptron with a hard step activation: 1 iff w·x + b >= 0.
        logit = x @ weights[f'{name}.weight'].T + weights[f'{name}.bias']
        return int((logit >= 0).item())

    inp = torch.tensor([float(e), float(d), float(q_prev)])
    # Hidden layer: the four product terms of the latch equations.
    hidden = torch.tensor([
        float(_fire('e_and_d', inp)),            # E AND D
        float(_fire('e_and_notd', inp)),         # E AND NOT D
        float(_fire('note_and_qprev', inp)),     # NOT E AND Q_prev
        float(_fire('note_and_notqprev', inp)),  # NOT E AND NOT Q_prev
    ])
    # Output layer: Q = (E·D) + (¬E·Q_prev), Qn = (E·¬D) + (¬E·¬Q_prev).
    return _fire('q', hidden), _fire('qn', hidden)
if __name__ == '__main__':
    # Print the full D-latch truth table for every (E, D, Q_prev) combination.
    weights = load_model()
    print('D-Latch:')
    print('E D Q_prev | Q Qn')
    for e, d, q_prev in itertools.product((0, 1), repeat=3):
        q, qn = d_latch(e, d, q_prev, weights)
        print(f'{e} {d} {q_prev} | {q} {qn}')