#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <random>

// Modulus applied to every accumulated cost.
constexpr long long P = 998244353;
// Largest supported d; all fixed-size tables below are sized from it.
constexpr int MAXD = 30;
// Number of (i, x) states when d == MAXD: i and x each range over [0, MAXD].
constexpr int MAXSTATES = (MAXD + 1) * (MAXD + 1);

// Values stored in status_arr for each state during one eval_fast() call.
constexpr std::int8_t kUnvisited = 0;   // not reached yet
constexpr std::int8_t kInProgress = 1;  // on the DFS stack; reaching it again means a cycle
constexpr std::int8_t kDone = 2;        // ret_instr / cost_arr for this state are final

int d;
int g1_arr[MAXD], gy_arr[MAXD];
int ret_instr[MAXSTATES];
long long cost_arr[MAXSTATES];
std::int8_t status_arr[MAXSTATES];

// One pending evaluation on the explicit DFS stack.
struct Frame {
    int state;          // encoded (i, x), see encode()
    std::int8_t phase;  // 0: not started, 1: waiting on dep1, 2: waiting on dep2
    int dep1, dep2;     // encoded dependency states (-1 until assigned)
};
// Every frame below the top is a distinct in-progress state, so the depth never
// exceeds (d + 1) * (d + 1) + 1 <= MAXSTATES + 1.
Frame stk[MAXSTATES * 2];
int stk_top;

// Packs state (i, x), with i, x in [0, d], into a row-major index in [0, (d + 1)^2).
inline int encode(int i, int x) { return i * (d + 1) + x; }

// Evaluates state (0, 0) of the recurrence defined by d, g1_arr and gy_arr,
// memoizing every state it touches, using an explicit stack instead of recursion.
//
// For a state (i, x):
//   * i < d and x == i + 1: result g1_arr[i], cost 1.
//   * i == d and x == 0:    result -1, cost 1.
//   * otherwise: evaluate dep1 = (gy_arr[i], i + 1) when i < d, or (d, 1) when
//     i == d; let j be dep1's result, evaluate dep2 = (j, x); this state's result
//     is dep2's result and its cost is (cost(dep1) + cost(dep2) + 1) mod P.
//
// Returns the cost of (0, 0) modulo P, or -1 if:
//   * evaluation reaches a state that is still in progress (a cycle);
//   * a result used as a row index lies outside [0, d];
//   * gy_arr[i] lies outside [0, d];
//   * d lies outside [0, MAXD].
long long eval_fast() {
    // The global tables only have room for d in [0, MAXD].
    if (d < 0 || d > MAXD) return -1;
    const int ns = (d + 1) * (d + 1);
    std::memset(status_arr, kUnvisited, ns);
    stk_top = 0;
    stk[stk_top++] = {encode(0, 0), 0, -1, -1};
    while (stk_top > 0) {
        Frame& f = stk[stk_top - 1];
        const int s = f.state;
        const int i = s / (d + 1);
        const int x = s % (d + 1);
        if (f.phase == 0) {
            if (status_arr[s] == kDone) {  // already memoized
                stk_top--;
                continue;
            }
            if (status_arr[s] == kInProgress) return -1;  // cycle
            status_arr[s] = kInProgress;
            if (i < d) {
                if (x == i + 1) {
                    ret_instr[s] = g1_arr[i];
                    cost_arr[s] = 1;
                    status_arr[s] = kDone;
                    stk_top--;
                    continue;
                }
                // gy_arr[i] is used as a row index; reject values that would
                // produce a state outside the tables.
                if (gy_arr[i] < 0 || gy_arr[i] > d) return -1;
                f.dep1 = encode(gy_arr[i], i + 1);
            } else {
                if (x == 0) {
                    ret_instr[s] = -1;
                    cost_arr[s] = 1;
                    status_arr[s] = kDone;
                    stk_top--;
                    continue;
                }
                f.dep1 = encode(d, 1);
            }
            f.phase = 1;
            stk[stk_top++] = {f.dep1, 0, -1, -1};
        } else if (f.phase == 1) {
            if (status_arr[f.dep1] != kDone) return -1;
            const int j = ret_instr[f.dep1];
            if (j < 0 || j > d) return -1;  // dep1's result must be a valid row
            f.dep2 = encode(j, x);
            f.phase = 2;
            stk[stk_top++] = {f.dep2, 0, -1, -1};
        } else {
            if (status_arr[f.dep2] != kDone) return -1;
            ret_instr[s] = ret_instr[f.dep2];
            cost_arr[s] = (cost_arr[f.dep1] + cost_arr[f.dep2] + 1) % P;
            status_arr[s] = kDone;
            stk_top--;
        }
    }
    return status_arr[encode(0, 0)] == kDone ? cost_arr[encode(0, 0)] : -1;
}

// Benchmark: 100000 evaluations at d = 26 on inputs drawn from a fixed-seed
// mt19937, reporting the elapsed time and how many evaluations returned >= 0.
int main() {
    d = 26;
    std::mt19937 rng(42);
    const auto t0 = std::chrono::steady_clock::now();
    int total = 0, valid = 0;
    for (int iter = 0; iter < 100000; iter++) {
        for (int j = 0; j < d; j++) {
            g1_arr[j] = j + 1;
            gy_arr[j] = static_cast<int>(rng() % (j + 1));
        }
        const int nc = static_cast<int>(rng() % 4);
        for (int c = 0; c < nc; c++) {
            // Draw the value before the index. This is the order C++17 gives the
            // single expression `g1_arr[rng() % d] = rng() % (d + 1);` (right
            // operand of '=' first), but spelled out so it does not depend on it.
            const int value = static_cast<int>(rng() % (d + 1));
            const int index = static_cast<int>(rng() % d);
            g1_arr[index] = value;
        }
        total++;
        const long long T = eval_fast();
        if (T >= 0) valid++;
    }
    const auto t1 = std::chrono::steady_clock::now();
    const double ms =
        std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count() / 1000.0;
    std::printf("d=%d: %d evals in %.1f ms (%.1f us/eval), %d valid (%.1f%%)\n", d, total, ms,
                ms * 1000 / total, valid, 100.0 * valid / total);
    return 0;
}