// JustinTX's picture
// Add files using upload-large-folder tool
// 1fd0050 verified
#include <bits/stdc++.h>
using namespace std;
// One benchmark instance: a 1-indexed n x n matrix (row/column 0 are unused
// padding in the generators), the 1-based rank k to look up, and the expected
// k-th smallest entry. Scalar members default to 0 so a default-constructed
// TestCase never holds indeterminate values.
struct TestCase {
int n = 0;                      // matrix dimension
long long k = 0;                // 1-based rank being queried
vector<vector<long long>> A;    // entries at A[1..n][1..n]
long long answer = 0;           // expected k-th smallest entry of A
};
// Shared 64-bit Mersenne Twister with a fixed seed (42), so the "random"
// benchmark matrices come out the same given the same sequence of draws.
mt19937_64 rng_gen{42};
// Builds an n x n test matrix (1-indexed; row/column 0 stay 0 as padding) with
// A[i][j] = valfn(i, j), and records its k-th smallest entry (1-based k) as the
// expected answer.
// Throws std::out_of_range if k is not in [1, n*n] (this also rejects n <= 0),
// instead of indexing past the end of the value buffer.
TestCase gen_matrix(int n, long long k, function<long long(int,int)> valfn) {
    const long long total = n > 0 ? (long long)n * n : 0;
    if (k < 1 || k > total) throw out_of_range("gen_matrix: k must be in [1, n*n]");
    TestCase tc; tc.n = n; tc.k = k;
    tc.A.assign(n + 1, vector<long long>(n + 1, 0));
    vector<long long> all;
    all.reserve(static_cast<size_t>(total));
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) {
            tc.A[i][j] = valfn(i, j);
            all.push_back(tc.A[i][j]);
        }
    }
    // Only the k-th order statistic is needed; selection is O(n^2) instead of a full sort.
    nth_element(all.begin(), all.begin() + (k - 1), all.end());
    tc.answer = all[k - 1];
    return tc;
}
// Test matrix with A[i][j] = i * j (computed in 64-bit).
TestCase gen_multiplicative(int n, long long k) {
    auto product = [](int row, int col) -> long long {
        return static_cast<long long>(row) * col;
    };
    return gen_matrix(n, k, product);
}
// Test matrix with A[i][j] = (i + n) * (j + n): a product matrix shifted away
// from the origin.
TestCase gen_shifted(int n, long long k) {
    auto shifted_product = [n](int row, int col) -> long long {
        const long long lhs = static_cast<long long>(row + n);
        return lhs * (col + n);
    };
    return gen_matrix(n, k, shifted_product);
}
// Test matrix with A[i][j] = i + j (many tied values along anti-diagonals).
TestCase gen_additive(int n, long long k) {
    auto sum = [](int row, int col) -> long long {
        return row + col;
    };
    return gen_matrix(n, k, sum);
}
// Builds an n x n test matrix (1-indexed; row/column 0 stay 0 as padding).
// Each entry is i*1000000 + j*1000 + (rng_gen() % 500). Rows and then columns
// are passed through running maxima so both are non-decreasing. The k-th
// smallest entry (1-based k) is recorded as the expected answer.
// Throws std::out_of_range if k is not in [1, n*n]; the check happens before
// any random numbers are drawn.
TestCase gen_random_sorted(int n, long long k) {
    const long long total = n > 0 ? (long long)n * n : 0;
    if (k < 1 || k > total) throw out_of_range("gen_random_sorted: k must be in [1, n*n]");
    TestCase tc; tc.n = n; tc.k = k;
    tc.A.assign(n + 1, vector<long long>(n + 1, 0));
    // Draw order (row-major) is kept so results match earlier runs for the same seed.
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++)
            tc.A[i][j] = (long long)i * 1000000 + (long long)j * 1000 + (rng_gen() % 500);
    // Running maxima along rows, then along columns. The column pass keeps rows
    // sorted because it takes element-wise maxima of already-sorted rows.
    for (int i = 1; i <= n; i++) for (int j = 2; j <= n; j++) tc.A[i][j] = max(tc.A[i][j], tc.A[i][j-1]);
    for (int j = 1; j <= n; j++) for (int i = 2; i <= n; i++) tc.A[i][j] = max(tc.A[i][j], tc.A[i-1][j]);
    vector<long long> all;
    all.reserve(static_cast<size_t>(total));
    for (int i = 1; i <= n; i++) for (int j = 1; j <= n; j++) all.push_back(tc.A[i][j]);
    // Only the k-th order statistic is needed; selection avoids a full sort.
    nth_element(all.begin(), all.begin() + (k - 1), all.end());
    tc.answer = all[k - 1];
    return tc;
}
// Finds the k-th smallest entry (1-based k) of the n x n matrix in a TestCase,
// assuming every row and every column is non-decreasing. It also counts how many
// distinct cells it reads (query_count).
//
// Strategy (see solve()):
//   * If k, or its mirror rank n*n - k + 1, is small, a best-first heap walk
//     from the matching corner is exact and cheap.
//   * Otherwise each row keeps a candidate column window [L[i], R[i]]. The
//     windows are narrowed by counting entries <= a pivot with a staircase
//     walk. Once the remaining rank is affordable, a heap over the active rows
//     finishes the job.
struct Solver {
const TestCase& tc;        // instance being solved; must outlive the Solver
int query_count;           // distinct cells read so far
vector<long long> memo;    // cached cell values, key = r * (n + 1) + c
vector<char> known;        // known[key] != 0 once memo[key] has been filled
int n;                     // matrix dimension (copied from tc.n)

Solver(const TestCase& t) : tc(t), query_count(0), n(t.n) {
    // Size the cache from n so any dimension fits. The old fixed 2002*2002
    // table overflowed for n > 2000. "Already read" is tracked separately so
    // a genuine value of -1 is not mistaken for a cache miss.
    const size_t side = static_cast<size_t>(max(n, 0)) + 1;
    memo.assign(side * side, 0);
    known.assign(side * side, 0);
}

// Returns A[r][c] (1-indexed). query_count is charged only on the first read
// of each cell.
long long do_query(int r, int c) {
    const size_t key = static_cast<size_t>(r) * static_cast<size_t>(n + 1) + static_cast<size_t>(c);
    if (!known[key]) {
        known[key] = 1;
        query_count++;
        memo[key] = tc.A[r][c];
    }
    return memo[key];
}

// Returns the k-th smallest entry, or -1 if the narrowing loop gives up after
// 100 rounds.
long long solve() {
    long long k = tc.k;
    long long total = (long long)n * n;
    if (n == 1) return do_query(1, 1);
    // Rank measured from whichever end of the sorted order is closer.
    long long heap_k = min(k, total - k + 1);
    if (heap_k + n <= 24000) {
        if (k <= total - k + 1) {
            // Min-heap expansion from the top-left corner: the i-th pop is the
            // i-th smallest entry, because a cell's up/left neighbours are never larger.
            priority_queue<tuple<long long, int, int>, vector<tuple<long long, int, int>>, greater<>> pq;
            vector<vector<bool>> vis(n + 1, vector<bool>(n + 1, false));
            pq.emplace(do_query(1, 1), 1, 1); vis[1][1] = true;
            long long result = -1;
            for (long long i = 0; i < k; i++) {
                auto [v, r, c] = pq.top(); pq.pop(); result = v;
                if (r+1<=n && !vis[r+1][c]) { vis[r+1][c]=true; pq.emplace(do_query(r+1,c),r+1,c); }
                if (c+1<=n && !vis[r][c+1]) { vis[r][c+1]=true; pq.emplace(do_query(r,c+1),r,c+1); }
            }
            return result;
        } else {
            // Mirror image: max-heap expansion from the bottom-right corner,
            // popping the (total - k + 1)-th largest entry.
            long long kk = total - k + 1;
            priority_queue<tuple<long long, int, int>> pq;
            vector<vector<bool>> vis(n + 1, vector<bool>(n + 1, false));
            pq.emplace(do_query(n,n),n,n); vis[n][n]=true;
            long long result = -1;
            for (long long i = 0; i < kk; i++) {
                auto [v, r, c] = pq.top(); pq.pop(); result = v;
                if (r-1>=1 && !vis[r-1][c]) { vis[r-1][c]=true; pq.emplace(do_query(r-1,c),r-1,c); }
                if (c-1>=1 && !vis[r][c-1]) { vis[r][c-1]=true; pq.emplace(do_query(r,c-1),r,c-1); }
            }
            return result;
        }
    }
    // Candidate window per row: columns [L[i], R[i]] may still hold the answer.
    // Cells left of L[i] were counted below it (k_rem was reduced accordingly).
    // Cells right of R[i] rank above it. k_rem is the answer's rank among the
    // remaining candidates.
    vector<int> L(n + 1, 1), R(n + 1, n);
    long long k_rem = k;
    // Value bounds, used only to choose interpolation probes below.
    long long val_lo = do_query(1, 1) - 1; // below a[1][1], the minimum of a sorted matrix
    long long val_hi = do_query(n, n);     // a[n][n], the maximum of a sorted matrix
    for (int iter = 0; iter < 100; iter++) {
        vector<int> active;
        long long total_cand = 0;
        for (int i = 1; i <= n; i++) {
            if (L[i] <= R[i]) { active.push_back(i); total_cand += R[i] - L[i] + 1; }
        }
        int na = active.size();
        if (total_cand == 0) break;
        // A single candidate left (one active row with L == R) is the answer.
        if (total_cand == 1) return do_query(active[0], L[active[0]]);
        long long budget = 49500 - query_count;
        // Finish with a per-row min-heap when k_rem pops fit the query budget.
        if (k_rem + na <= budget) {
            priority_queue<tuple<long long, int, int>, vector<tuple<long long, int, int>>, greater<>> pq;
            for (int i : active) pq.emplace(do_query(i, L[i]), i, L[i]);
            for (long long t = 1; t < k_rem; t++) {
                auto [v, r, c] = pq.top(); pq.pop();
                if (c + 1 <= R[r]) pq.emplace(do_query(r, c + 1), r, c + 1);
            }
            return get<0>(pq.top());
        }
        // Or with a per-row max-heap when the mirrored rank fits.
        long long rev_k = total_cand - k_rem + 1;
        if (rev_k + na <= budget) {
            priority_queue<tuple<long long, int, int>> pq;
            for (int i : active) pq.emplace(do_query(i, R[i]), i, R[i]);
            for (long long t = 1; t < rev_k; t++) {
                auto [v, r, c] = pq.top(); pq.pop();
                if (c - 1 >= L[r]) pq.emplace(do_query(r, c - 1), r, c - 1);
            }
            return get<0>(pq.top());
        }
        // Pivot: the median of one probe per sampled row. Each probe is taken at
        // the column where rank k_rem would fall if candidates were spread evenly.
        double target_frac = (double)(k_rem - 0.5) / total_cand;
        vector<long long> pvals;
        int sample_n = max(1, min(na, (int)ceil(sqrt((double)na) * 4)));
        int step = max(1, na / sample_n);
        for (int idx = 0; idx < na; idx += step) {
            int i = active[idx];
            int width = R[i] - L[i] + 1;
            int col = L[i] + (int)(target_frac * width);
            col = max(L[i], min(R[i], col));
            pvals.push_back(do_query(i, col));
        }
        sort(pvals.begin(), pvals.end());
        long long pivot = pvals[pvals.size() / 2];
        // Staircase count of candidates <= pv, returned as {count, p_le}.
        // Active rows are scanned bottom-up. Columns are non-decreasing, so the
        // boundary column j never moves left; that costs at most about n + na reads.
        // p_le[i] is the last column of row i whose value is <= pv. It may exceed
        // R[i]; the count clamps it.
        auto do_walk = [&](long long pv) -> pair<long long, vector<int>> {
            vector<int> p_le(n + 1, 0);
            int j = 0;
            for (int idx = na - 1; idx >= 0; idx--) {
                int i = active[idx];
                j = max(j, L[i]);
                while (j <= R[i] && do_query(i, j) <= pv) j++;
                p_le[i] = j - 1;
            }
            long long cle = 0;
            for (int i : active) {
                int rl = min(p_le[i], R[i]);
                if (rl >= L[i]) cle += rl - L[i] + 1;
            }
            return {cle, p_le};
        };
        auto [cle, p_le] = do_walk(pivot);
        // Counts at the value bounds for interpolation. They are only approximate:
        // they restart at the window's extremes (0 and total_cand) every
        // iteration, while val_lo / val_hi persist across iterations.
        long long cnt_at_lo = 0, cnt_at_hi = total_cand;
        if (cle >= k_rem) {
            if (pivot < val_hi) { val_hi = pivot; cnt_at_hi = cle; }
        } else {
            if (pivot > val_lo) { val_lo = pivot; cnt_at_lo = cle; }
        }
        double split_ratio = (double)min(cle, total_cand - cle) / total_cand;
        // A lopsided split: try up to 3 value-interpolated probes and keep the
        // most balanced split seen.
        if (split_ratio < 0.20 && val_hi > val_lo) {
            long long remaining_budget = 49500 - query_count;
            int max_corrections = min(3, (int)((remaining_budget - 1000) / (2 * na)));
            for (int c = 0; c < max_corrections && split_ratio < 0.20; c++) {
                // Linear interpolation in value space toward a count of k_rem,
                // with the fraction clamped to [0.1, 0.9] to avoid extreme probes.
                long long corr_val;
                if (cnt_at_hi > cnt_at_lo) {
                    double frac = (double)(k_rem - cnt_at_lo) / (cnt_at_hi - cnt_at_lo);
                    frac = max(0.1, min(0.9, frac));
                    corr_val = val_lo + (long long)(frac * (val_hi - val_lo));
                } else {
                    corr_val = val_lo + (val_hi - val_lo) / 2;
                }
                if (corr_val <= val_lo) corr_val = val_lo + 1;
                if (corr_val >= val_hi) corr_val = val_hi - 1;
                if (corr_val <= val_lo || corr_val >= val_hi) break;
                auto [cle2, p_le2] = do_walk(corr_val);
                if (cle2 >= k_rem) {
                    val_hi = corr_val;
                    cnt_at_hi = cle2;
                } else {
                    val_lo = corr_val;
                    cnt_at_lo = cle2;
                }
                double split2 = (double)min(cle2, total_cand - cle2) / total_cand;
                if (split2 > split_ratio) {
                    cle = cle2; p_le = p_le2; split_ratio = split2;
                }
            }
        }
        // No-progress guard. A correction split is adopted only if it improves
        // split_ratio, so cle == total_cand here means the chosen split is the
        // pivot's own walk. The pivot is a candidate value, so it is then the
        // largest candidate and the split would eliminate nothing. Previously
        // that happened with many tied values and the loop spun until it
        // returned -1. Resolve it by also counting candidates strictly below
        // the pivot (values are integers, so "<= pivot - 1" means "< pivot").
        if (cle == total_cand) {
            if (pivot == numeric_limits<long long>::min()) return pivot;
            auto [cle_lt, p_le_lt] = do_walk(pivot - 1);
            // All candidates are <= pivot and fewer than k_rem are below it.
            if (cle_lt < k_rem) return pivot;
            // Otherwise the answer is < pivot: drop every candidate equal to it.
            cle = cle_lt;
            p_le = std::move(p_le_lt);
        }
        if (cle >= k_rem) {
            // The answer is among the cells <= the split value: cut the right ends.
            for (int i : active) R[i] = min(R[i], p_le[i]);
        } else {
            // The answer is above the split value: discard the cle cells on the left.
            k_rem -= cle;
            for (int i : active) L[i] = max(L[i], p_le[i] + 1);
        }
    }
    return -1;
}
};
int main() {
struct TestDef { string name; function<TestCase()> gen; };
vector<TestDef> tests;
tests.push_back({"additive n=2000 k=2000000", []{ return gen_additive(2000, 2000000); }});
tests.push_back({"mult n=2000 k=2000000", []{ return gen_multiplicative(2000, 2000000); }});
tests.push_back({"shifted n=2000 k=2000000", []{ return gen_shifted(2000, 2000000); }});
tests.push_back({"random n=2000 k=2000000", []{ return gen_random_sorted(2000, 2000000); }});
tests.push_back({"mult n=2000 k=100000", []{ return gen_multiplicative(2000, 100000); }});
tests.push_back({"mult n=2000 k=3900000", []{ return gen_multiplicative(2000, 3900000); }});
for (auto& t : tests) {
auto tc = t.gen();
Solver s(tc);
long long result = s.solve();
bool correct = (result == tc.answer);
int used = s.query_count;
double score = !correct ? 0.0 : (used <= tc.n ? 1.0 : (used >= 50000 ? 0.0 : (50000.0 - used) / (50000.0 - tc.n)));
printf("%-45s q=%6d %s score=%.4f\n", t.name.c_str(), used, correct ? "OK" : "WRONG", score);
}
}