// Experiment harness: selects the k-th smallest entry of an n x n matrix whose
// rows and columns are non-decreasing, counts how many distinct cells are read,
// and prints per-iteration diagnostics of the pivot-based narrowing.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <functional>
#include <queue>
#include <stdexcept>
#include <tuple>
#include <vector>

// A generated instance. A is 1-indexed (row 0 and column 0 are unused padding);
// answer is the brute-force k-th smallest (1-based) entry of A[1..n][1..n].
struct TestCase {
    int n;
    long long k;
    std::vector<std::vector<long long>> A;
    long long answer;
};

// Builds A[i][j] = valfn(i, j) for 1 <= i, j <= n and records the k-th smallest
// entry (1-based k) as the reference answer.
//
// Solver is only correct if valfn yields non-decreasing rows and columns; this
// function does not check that.
// Throws std::out_of_range if k is outside [1, n*n].
TestCase gen_matrix(int n, long long k, std::function<long long(int, int)> valfn) {
    TestCase tc;
    tc.n = n;
    tc.k = k;
    tc.A.assign(n + 1, std::vector<long long>(n + 1, 0));
    std::vector<long long> all;
    all.reserve(static_cast<std::size_t>(n) * static_cast<std::size_t>(n));
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) {
            tc.A[i][j] = valfn(i, j);
            all.push_back(tc.A[i][j]);
        }
    }
    if (k < 1 || k > static_cast<long long>(all.size())) {
        throw std::out_of_range("gen_matrix: k must be in [1, n*n]");
    }
    // Only one order statistic is needed, so a full sort is unnecessary.
    const auto kth = all.begin() + static_cast<std::ptrdiff_t>(k - 1);
    std::nth_element(all.begin(), kth, all.end());
    tc.answer = *kth;
    return tc;
}

// Selects the tc.k-th smallest entry of tc.A while counting distinct cell reads.
//
// Strategy (see solve()):
//  1. If the rank is close to either end, best-first search from a corner.
//  2. Otherwise keep a per-row candidate column interval [L[i], R[i]] and
//     repeatedly split the candidates around a sampled pivot using a staircase
//     walk, switching to a k-way row merge once that fits in the query budget.
struct Solver {
    using Entry = std::tuple<long long, int, int>;  // (value, row, col)
    using MinHeap = std::priority_queue<Entry, std::vector<Entry>, std::greater<Entry>>;
    using MaxHeap = std::priority_queue<Entry>;

    // Tuning constants (values unchanged from the original experiment).
    static constexpr long long kCornerHeapLimit = 24000;  // corner search if min-rank + n <= this
    static constexpr long long kQueryBudget = 49500;      // row merge if it fits in what is left
    static constexpr int kMaxIters = 100;                 // cap on pivot iterations

    const TestCase& tc;
    int query_count;              // number of distinct cells read so far
    std::vector<long long> memo;  // cached cell values, indexed by cell(r, c)
    std::vector<char> known;      // known[cell(r, c)] != 0 once memo holds that cell
    int n;
    // Diagnostics.
    int num_iters;                      // iteration at which a row-merge fallback ran (0 if none)
    std::vector<int> walk_costs;        // per iteration: reads spent in staircase walks
    std::vector<int> sample_costs;      // per iteration: reads spent sampling the pivot
    std::vector<long long> cand_sizes;  // per iteration: candidates alive at its start
    std::vector<double> split_ratios;   // per iteration: min(part, rest) / candidates

    explicit Solver(const TestCase& t) : tc(t), query_count(0), n(t.n), num_iters(0) {
        const std::size_t cells =
            static_cast<std::size_t>(n + 1) * static_cast<std::size_t>(n + 1);
        memo.assign(cells, 0);
        known.assign(cells, 0);
    }

    // Flat index of cell (r, c); sized for any n (no fixed 2001 stride).
    std::size_t cell(int r, int c) const {
        return static_cast<std::size_t>(r) * static_cast<std::size_t>(n + 1) +
               static_cast<std::size_t>(c);
    }

    // Returns A[r][c]; only the first read of each cell increments query_count.
    // A separate `known` flag is used so any value (including -1) can be cached.
    long long do_query(int r, int c) {
        const std::size_t key = cell(r, c);
        if (!known[key]) {
            known[key] = 1;
            memo[key] = tc.A[r][c];
            query_count++;
        }
        return memo[key];
    }

    // Returns the tc.k-th smallest entry, or -1 if the pivot loop gives up
    // (no candidates left, or kMaxIters iterations without finishing).
    long long solve() {
        const long long k = tc.k;
        const long long N2 = static_cast<long long>(n) * n;
        if (n == 1) return do_query(1, 1);

        const long long k_from_top = N2 - k + 1;  // same rank counted from the largest
        if (std::min(k, k_from_top) + n <= kCornerHeapLimit) {
            return k <= k_from_top ? corner_smallest(k) : corner_largest(k_from_top);
        }

        std::vector<int> L(n + 1, 1), R(n + 1, n);
        long long k_rem = k;  // rank of the answer among the current candidates
        for (int iter = 0; iter < kMaxIters; iter++) {
            std::vector<int> active;
            long long total_cand = 0;
            for (int i = 1; i <= n; i++) {
                if (L[i] <= R[i]) {
                    active.push_back(i);
                    total_cand += R[i] - L[i] + 1;
                }
            }
            const long long na = static_cast<long long>(active.size());
            if (total_cand == 0) break;
            if (total_cand == 1) return do_query(active[0], L[active[0]]);

            const long long budget = kQueryBudget - query_count;
            if (k_rem + na <= budget) {
                num_iters = iter;
                return merge_smallest(active, L, R, k_rem);
            }
            const long long rev_k = total_cand - k_rem + 1;
            if (rev_k + na <= budget) {
                num_iters = iter;
                return merge_largest(active, L, R, rev_k);
            }

            const int qc_before = query_count;
            const long long pivot = sample_pivot(active, L, R, k_rem, total_cand);
            const int qc_after_sample = query_count;

            std::vector<int> bound = staircase(active, L, R, pivot, /*strict=*/false);
            long long below = count_within(active, L, R, bound);
            bool answer_is_pivot = false;
            if (below == total_cand) {
                // The pivot is the largest candidate, so splitting on "<= pivot"
                // would keep everything and the loop would never progress (this
                // happens with tied values). Split on "< pivot" instead: either
                // the answer is the pivot itself, or the pivot gets discarded.
                bound = staircase(active, L, R, pivot, /*strict=*/true);
                below = count_within(active, L, R, bound);
                answer_is_pivot = below < k_rem;
            }
            const int qc_after_walk = query_count;

            sample_costs.push_back(qc_after_sample - qc_before);
            walk_costs.push_back(qc_after_walk - qc_after_sample);
            cand_sizes.push_back(total_cand);
            split_ratios.push_back(static_cast<double>(std::min(below, total_cand - below)) /
                                   static_cast<double>(total_cand));

            if (answer_is_pivot) return pivot;
            if (below >= k_rem) {
                for (int i : active) R[i] = std::min(R[i], bound[i]);
            } else {
                k_rem -= below;
                for (int i : active) L[i] = std::max(L[i], bound[i] + 1);
            }
        }
        return -1;
    }

   private:
    // k-th smallest of the whole matrix by best-first expansion from (1, 1).
    // Reads at most 2k + 1 cells.
    long long corner_smallest(long long k) {
        MinHeap pq;
        std::vector<char> seen(memo.size(), 0);
        pq.emplace(do_query(1, 1), 1, 1);
        seen[cell(1, 1)] = 1;
        long long result = -1;
        for (long long popped = 0; popped < k; popped++) {
            const Entry top = pq.top();
            pq.pop();
            result = std::get<0>(top);
            const int r = std::get<1>(top);
            const int c = std::get<2>(top);
            if (r + 1 <= n && !seen[cell(r + 1, c)]) {
                seen[cell(r + 1, c)] = 1;
                pq.emplace(do_query(r + 1, c), r + 1, c);
            }
            if (c + 1 <= n && !seen[cell(r, c + 1)]) {
                seen[cell(r, c + 1)] = 1;
                pq.emplace(do_query(r, c + 1), r, c + 1);
            }
        }
        return result;
    }

    // kk-th largest of the whole matrix by best-first expansion from (n, n).
    // Reads at most 2kk + 1 cells.
    long long corner_largest(long long kk) {
        MaxHeap pq;
        std::vector<char> seen(memo.size(), 0);
        pq.emplace(do_query(n, n), n, n);
        seen[cell(n, n)] = 1;
        long long result = -1;
        for (long long popped = 0; popped < kk; popped++) {
            const Entry top = pq.top();
            pq.pop();
            result = std::get<0>(top);
            const int r = std::get<1>(top);
            const int c = std::get<2>(top);
            if (r - 1 >= 1 && !seen[cell(r - 1, c)]) {
                seen[cell(r - 1, c)] = 1;
                pq.emplace(do_query(r - 1, c), r - 1, c);
            }
            if (c - 1 >= 1 && !seen[cell(r, c - 1)]) {
                seen[cell(r, c - 1)] = 1;
                pq.emplace(do_query(r, c - 1), r, c - 1);
            }
        }
        return result;
    }

    // k-th smallest among candidates A[i][L[i]..R[i]] (i in active) via a
    // k-way merge of the rows, each of which is sorted.
    long long merge_smallest(const std::vector<int>& active, const std::vector<int>& L,
                             const std::vector<int>& R, long long k) {
        MinHeap pq;
        for (int i : active) pq.emplace(do_query(i, L[i]), i, L[i]);
        for (long long t = 1; t < k; t++) {
            const int r = std::get<1>(pq.top());
            const int c = std::get<2>(pq.top());
            pq.pop();
            if (c + 1 <= R[r]) pq.emplace(do_query(r, c + 1), r, c + 1);
        }
        return std::get<0>(pq.top());
    }

    // kk-th largest among candidates A[i][L[i]..R[i]] (i in active) via a
    // k-way merge of the rows walked right-to-left.
    long long merge_largest(const std::vector<int>& active, const std::vector<int>& L,
                            const std::vector<int>& R, long long kk) {
        MaxHeap pq;
        for (int i : active) pq.emplace(do_query(i, R[i]), i, R[i]);
        for (long long t = 1; t < kk; t++) {
            const int r = std::get<1>(pq.top());
            const int c = std::get<2>(pq.top());
            pq.pop();
            if (c - 1 >= L[r]) pq.emplace(do_query(r, c - 1), r, c - 1);
        }
        return std::get<0>(pq.top());
    }

    // Samples roughly 4*sqrt(#active) rows, reads each at the column matching
    // the answer's relative rank within its interval, and returns the median
    // of the sampled values. The result is always one of the candidates.
    long long sample_pivot(const std::vector<int>& active, const std::vector<int>& L,
                           const std::vector<int>& R, long long k_rem, long long total_cand) {
        const int na = static_cast<int>(active.size());
        const double target_frac = (k_rem - 0.5) / static_cast<double>(total_cand);
        const int sample_n = std::max(
            1, std::min(na, static_cast<int>(std::ceil(std::sqrt(static_cast<double>(na)) * 4))));
        const int step = std::max(1, na / sample_n);
        std::vector<long long> pvals;
        for (int idx = 0; idx < na; idx += step) {
            const int i = active[idx];
            const int width = R[i] - L[i] + 1;
            const int col = std::clamp(L[i] + static_cast<int>(target_frac * width), L[i], R[i]);
            pvals.push_back(do_query(i, col));
        }
        const auto mid = pvals.begin() + static_cast<std::ptrdiff_t>(pvals.size() / 2);
        std::nth_element(pvals.begin(), mid, pvals.end());
        return *mid;
    }

    // Staircase walk: for each active row i, bound[i] is one less than the
    // first column (scanning from max(previous row's stop, L[i])) whose value
    // fails the test (<= pivot, or < pivot when strict); it may exceed R[i].
    // Rows are visited bottom-up so the column pointer only moves right; this
    // is valid because columns are non-decreasing downwards and every cell left
    // of L[i] is below any remaining candidate. Reads at most n + #active cells.
    std::vector<int> staircase(const std::vector<int>& active, const std::vector<int>& L,
                               const std::vector<int>& R, long long pivot, bool strict) {
        std::vector<int> bound(n + 1, 0);
        int j = 0;
        for (auto it = active.rbegin(); it != active.rend(); ++it) {
            const int i = *it;
            j = std::max(j, L[i]);
            while (j <= R[i]) {
                const long long v = do_query(i, j);
                const bool qualifies = strict ? (v < pivot) : (v <= pivot);
                if (!qualifies) break;
                ++j;
            }
            bound[i] = j - 1;
        }
        return bound;
    }

    // Number of candidates A[i][L[i]..min(bound[i], R[i])] over active rows.
    static long long count_within(const std::vector<int>& active, const std::vector<int>& L,
                                  const std::vector<int>& R, const std::vector<int>& bound) {
        long long count = 0;
        for (int i : active) {
            const int hi = std::min(bound[i], R[i]);
            if (hi >= L[i]) count += hi - L[i] + 1;
        }
        return count;
    }
};

namespace {

// Generates one instance, solves it, and prints the result line followed by
// the per-iteration diagnostics table. The instance and solver are released
// before the next case runs.
void run_case(int n, long long k, const std::function<long long(int, int)>& valfn) {
    const TestCase tc = gen_matrix(n, k, valfn);
    Solver s(tc);
    const long long result = s.solve();
    std::printf("Result: %lld, Expected: %lld, Correct: %s, Queries: %d\n", result, tc.answer,
                result == tc.answer ? "YES" : "NO", s.query_count);
    std::printf("Iterations before fallback: %d\n", s.num_iters);
    std::printf("\nPer-iteration breakdown:\n");
    std::printf("%-5s %10s %8s %8s %10s\n", "Iter", "Candidates", "Sample", "Walk", "SplitRatio");
    for (std::size_t i = 0; i < s.walk_costs.size(); i++) {
        std::printf("%-5d %10lld %8d %8d %10.4f\n", static_cast<int>(i), s.cand_sizes[i],
                    s.sample_costs[i], s.walk_costs[i], s.split_ratios[i]);
    }
}

}  // namespace

int main() {
    constexpr int kN = 2000;
    constexpr long long kK = 2000000;

    // Multiplicative: A[i][j] = i * j.
    run_case(kN, kK, [](int i, int j) -> long long { return static_cast<long long>(i) * j; });

    // Additive: A[i][j] = i + j (many tied values).
    std::printf("\n--- Additive n=2000 k=2000000 ---\n");
    run_case(kN, kK, [](int i, int j) -> long long { return i + j; });

    // Shifted multiplicative: A[i][j] = (i + n) * (j + n).
    std::printf("\n--- Shifted n=2000 k=2000000 ---\n");
    run_case(kN, kK, [shift = kN](int i, int j) -> long long {
        return static_cast<long long>(i + shift) * (j + shift);
    });
    return 0;
}