// JustinTX's picture
// Add files using upload-large-folder tool
// 1fd0050 verified
#include <bits/stdc++.h>
using namespace std;
// One generated instance: a 1-indexed n x n matrix A (row/column 0 are padding),
// the requested rank k, and the precomputed k-th smallest entry of A.
// Every member has a default so a default-constructed TestCase is fully
// initialized (previously n, k and answer were left indeterminate).
struct TestCase {
    int n = 0;
    long long k = 0;
    vector<vector<long long>> A;
    long long answer = 0;
};
// Shared 64-bit Mersenne Twister with a fixed seed so generated matrices are
// reproducible; each draw advances this single global state.
mt19937_64 rng_gen{42};
// Builds an n x n test matrix (1-indexed; row/column 0 stay 0) whose entry
// (i, j) is valfn(i, j), filled in row-major order, and records the k-th
// smallest entry (1-based) as tc.answer.
// Throws std::out_of_range unless n >= 1 and 1 <= k <= n*n (the old code
// indexed all[k-1] out of bounds in that case).
TestCase gen_matrix(int n, long long k, function<long long(int,int)> valfn) {
    const long long total = static_cast<long long>(n) * n;
    if (n < 1 || k < 1 || k > total)
        throw out_of_range("gen_matrix: k must be in [1, n*n]");

    TestCase tc;
    tc.n = n;
    tc.k = k;
    tc.A.assign(n + 1, vector<long long>(n + 1, 0));

    vector<long long> all;
    all.reserve(static_cast<size_t>(total));
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) {
            tc.A[i][j] = valfn(i, j);
            all.push_back(tc.A[i][j]);
        }
    }
    // Only the k-th order statistic is needed, so linear-time selection
    // replaces a full sort; the selected value is identical.
    nth_element(all.begin(), all.begin() + (k - 1), all.end());
    tc.answer = all[k - 1];
    return tc;
}
// Multiplication-table instance: A[i][j] = i * j (computed in 64-bit).
TestCase gen_multiplicative(int n, long long k) {
    auto product = [](int row, int col) -> long long {
        return static_cast<long long>(row) * col;
    };
    return gen_matrix(n, k, product);
}
// Multiplication table shifted by n: A[i][j] = (i + n) * (j + n).
TestCase gen_shifted(int n, long long k) {
    const long long offset = n;
    auto shifted_product = [offset](int row, int col) -> long long {
        return (row + offset) * (col + offset);
    };
    return gen_matrix(n, k, shifted_product);
}
// Addition-table instance: A[i][j] = i + j (many repeated values).
TestCase gen_additive(int n, long long k) {
    auto sum = [](int row, int col) -> long long {
        return static_cast<long long>(row) + col;
    };
    return gen_matrix(n, k, sum);
}
// Builds an n x n test matrix (1-indexed; row/column 0 stay 0) with
// A[i][j] = i*1000000 + j*1000 + noise, where noise in [0, 500) is drawn from
// rng_gen in row-major order. It then enforces non-decreasing rows and columns,
// and records the k-th smallest entry (1-based) as tc.answer.
// Throws std::out_of_range unless n >= 1 and 1 <= k <= n*n (the old code
// indexed all[k-1] out of bounds in that case).
TestCase gen_random_sorted(int n, long long k) {
    const long long total = static_cast<long long>(n) * n;
    if (n < 1 || k < 1 || k > total)
        throw out_of_range("gen_random_sorted: k must be in [1, n*n]");

    TestCase tc;
    tc.n = n;
    tc.k = k;
    auto& A = tc.A;
    A.assign(n + 1, vector<long long>(n + 1, 0));
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++)
            A[i][j] = static_cast<long long>(i) * 1000000 + static_cast<long long>(j) * 1000 +
                      static_cast<long long>(rng_gen() % 500);

    // Prefix maxima along rows, then along columns. With the spacing above
    // (step 1000 per column and 1000000 per row, noise < 500) the matrix is
    // already sorted, so these passes are a safety net if the constants change.
    // The column pass keeps rows sorted because the element-wise max of two
    // non-decreasing rows is non-decreasing.
    for (int i = 1; i <= n; i++)
        for (int j = 2; j <= n; j++)
            A[i][j] = max(A[i][j], A[i][j - 1]);
    for (int j = 1; j <= n; j++)
        for (int i = 2; i <= n; i++)
            A[i][j] = max(A[i][j], A[i - 1][j]);

    vector<long long> all;
    all.reserve(static_cast<size_t>(total));
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++)
            all.push_back(A[i][j]);
    // Only the k-th order statistic is needed; linear-time selection suffices.
    nth_element(all.begin(), all.begin() + (k - 1), all.end());
    tc.answer = all[k - 1];
    return tc;
}
// Finds the k-th smallest entry (1-based k = tc.k) of the 1-indexed n x n
// matrix tc.A, assuming every row and every column is non-decreasing.
// query_count is the number of distinct cells read; re-reading a cell is free.
// k is expected to lie in [1, n*n].
//
// Strategy:
//   * If the target rank is close to a corner (min(k, n*n-k+1) + n small),
//     pop cells from that corner with a heap.
//   * Otherwise keep a band p_lo < answer <= p_hi. It is stored as per-row
//     cutoffs jLo[i] / jHi[i] (number of entries <= p_lo / <= p_hi in row i)
//     with totals cLo < k <= cHi. The band is shrunk with staircase counts;
//     then its cells are enumerated and the (k - cLo)-th smallest is selected.
struct Solver {
    // Read budget the band phase plans around before enumerating.
    static constexpr long long kQueryBudget = 49500;
    // Use the corner heap when min(k, n*n - k + 1) + n is at most this.
    static constexpr long long kHeapLimit = 24000;

    const TestCase& tc;
    int query_count;
    vector<long long> memo;  // cached cell values, indexed by r * (n + 1) + c
    int n;
    vector<char> known;      // known[key] != 0 once memo[key] has been read

    Solver(const TestCase& t) : tc(t), query_count(0), n(t.n) {
        // Size the cache from n: the old fixed 2002 x 2002 table with key
        // r*2001+c overflowed for n > 2000. Track "already read" separately, so
        // a genuine value of -1 is not mistaken for an empty slot and recounted.
        const size_t side = n > 0 ? static_cast<size_t>(n) + 1 : 1;
        memo.assign(side * side, 0);
        known.assign(side * side, 0);
    }

    // Returns tc.A[r][c]; the read is counted only the first time the cell is touched.
    long long do_query(int r, int c) {
        const size_t key = static_cast<size_t>(r) * (static_cast<size_t>(n) + 1) + static_cast<size_t>(c);
        if (!known[key]) {
            known[key] = 1;
            memo[key] = tc.A[r][c];
            query_count++;
        }
        return memo[key];
    }

    // Corner-heap selection. If k is in the lower half of the ranks, pop k
    // cells from (1,1) with a min-heap; otherwise pop n*n-k+1 cells from (n,n)
    // with a max-heap. Each pop pushes the unseen neighbours, which is valid
    // because rows and columns are sorted. Returns -1 if no pop occurs.
    long long heap_select(long long k, long long total) {
        using Cell = tuple<long long, int, int>;
        vector<vector<bool>> vis(n + 1, vector<bool>(n + 1, false));
        long long result = -1;
        if (k <= total - k + 1) {
            priority_queue<Cell, vector<Cell>, greater<>> pq;
            pq.emplace(do_query(1, 1), 1, 1);
            vis[1][1] = true;
            for (long long i = 0; i < k; i++) {
                auto [v, r, c] = pq.top();
                pq.pop();
                result = v;
                if (r + 1 <= n && !vis[r + 1][c]) { vis[r + 1][c] = true; pq.emplace(do_query(r + 1, c), r + 1, c); }
                if (c + 1 <= n && !vis[r][c + 1]) { vis[r][c + 1] = true; pq.emplace(do_query(r, c + 1), r, c + 1); }
            }
        } else {
            const long long kk = total - k + 1;
            priority_queue<Cell> pq;
            pq.emplace(do_query(n, n), n, n);
            vis[n][n] = true;
            for (long long i = 0; i < kk; i++) {
                auto [v, r, c] = pq.top();
                pq.pop();
                result = v;
                if (r - 1 >= 1 && !vis[r - 1][c]) { vis[r - 1][c] = true; pq.emplace(do_query(r - 1, c), r - 1, c); }
                if (c - 1 >= 1 && !vis[r][c - 1]) { vis[r][c - 1] = true; pq.emplace(do_query(r, c - 1), r, c - 1); }
            }
        }
        return result;
    }

    // Counts entries <= pivot, given that each row's count lies in
    // [jLo[i], jHi[i]] and both bounds are non-increasing in i. That holds
    // whenever p_lo <= pivot <= p_hi. Writes per-row counts to cutoff and
    // returns their sum, using a single monotone staircase pass.
    long long staircase_count(long long pivot, const vector<int>& jLo,
                              const vector<int>& jHi, vector<int>& cutoff) {
        cutoff.assign(n + 1, 0);
        long long cnt = 0;
        int j = jHi[1];
        for (int i = 1; i <= n; i++) {
            const int lo_j = jLo[i];
            const int hi_j = jHi[i];
            if (hi_j <= lo_j) { cutoff[i] = lo_j; cnt += lo_j; continue; }
            if (j > hi_j) j = hi_j;
            while (j > lo_j && do_query(i, j) > pivot) j--;
            cutoff[i] = j > lo_j ? j : lo_j;
            cnt += cutoff[i];
        }
        return cnt;
    }

    long long solve() {
        const long long k = tc.k;
        const long long N2 = static_cast<long long>(n) * n;
        if (n == 1) return do_query(1, 1);
        const long long heap_k = min(k, N2 - k + 1);
        if (heap_k + n <= kHeapLimit) return heap_select(k, N2);

        vector<int> jLo(n + 1, 0), jHi(n + 1, n), cutoff;
        long long cLo = 0, cHi = N2;

        // Phase 1: binary search over the values of row ceil(k/n). The row is
        // sorted, so successive pivots stay inside [p_lo, p_hi]. Only the probed
        // midpoints are read; the old code read the whole row up front.
        const int pivot_row = max(1, min(n, static_cast<int>((k + n - 1) / n)));
        int lo_idx = 1, hi_idx = n;
        while (lo_idx <= hi_idx && cHi - cLo > 0) {
            const int mid_idx = (lo_idx + hi_idx) / 2;
            const long long pivot = do_query(pivot_row, mid_idx);
            const long long cnt = staircase_count(pivot, jLo, jHi, cutoff);
            if (cnt >= k) {
                jHi.swap(cutoff);
                cHi = cnt;
                hi_idx = mid_idx - 1;
            } else {
                jLo.swap(cutoff);
                cLo = cnt;
                lo_idx = mid_idx + 1;
            }
            if (cHi - cLo <= kQueryBudget - query_count) break;  // band is enumerable
        }

        // Phase 2: while the band is too wide to enumerate, split it at a cell of
        // its widest row. The cell is chosen at the quantile of the target rank.
        while (true) {
            const long long W = cHi - cLo;
            const long long budget = kQueryBudget - query_count;
            if (W <= budget) break;
            if (budget < 2LL * n + 100) break;  // cannot afford more walks

            int best_row = -1, best_width = 0;
            for (int i = 1; i <= n; i++) {
                const int w = jHi[i] - jLo[i];
                if (w > best_width) { best_width = w; best_row = i; }
            }
            if (best_row == -1 || best_width == 0) break;

            // Only the pivot cell is read here. The old code also read the whole
            // active segment of the row without using it.
            const int lo_j = jLo[best_row] + 1;
            const int hi_j = jHi[best_row];
            const long long need_rank = k - cLo;  // rank of the answer inside the band
            const double frac = static_cast<double>(need_rank) / W;
            int target_col = lo_j + static_cast<int>(frac * (hi_j - lo_j));
            target_col = max(lo_j, min(hi_j, target_col));
            // The column is past jLo and not past jHi, so p_lo < pivot <= p_hi.
            const long long pivot = do_query(best_row, target_col);

            long long cnt = staircase_count(pivot, jLo, jHi, cutoff);
            if (cnt < k) {
                // The pivot cell itself is counted, so cnt > cLo: strict progress.
                jLo.swap(cutoff);
                cLo = cnt;
                continue;
            }
            if (cnt == cHi) {
                // No progress: pivot counts exactly like p_hi (e.g. heavy ties).
                // The old code then repeated this identical step forever. Recount
                // with "< pivot" (i.e. <= pivot - 1, still >= p_lo). Either that
                // count is below k, which proves the answer is pivot, or it
                // excludes the pivot cell and strictly shrinks the band.
                if (pivot == numeric_limits<long long>::min()) return pivot;
                cnt = staircase_count(pivot - 1, jLo, jHi, cutoff);
                if (cnt < k) return pivot;  // #(< pivot) < k <= #(<= pivot)
            }
            jHi.swap(cutoff);
            cHi = cnt;
        }

        // Phase 3: the answer is the (k - cLo)-th smallest cell in the band.
        const long long W = cHi - cLo;
        const long long rank = k - cLo;
        if (W <= 0) return do_query(1, 1);  // unreachable while cLo < k <= cHi
        vector<long long> cand;
        cand.reserve(static_cast<size_t>(W));
        for (int i = 1; i <= n; i++)
            for (int j = jLo[i] + 1; j <= jHi[i]; j++)
                cand.push_back(do_query(i, j));
        if (rank <= 0 || cand.empty()) return do_query(1, 1);
        if (rank > static_cast<long long>(cand.size())) return do_query(n, n);
        nth_element(cand.begin(), cand.begin() + (rank - 1), cand.end());
        return cand[rank - 1];
    }
};
int main() {
struct TestDef { string name; function<TestCase()> gen; };
vector<TestDef> tests;
tests.push_back({"additive n=100 k=5000", []{ return gen_additive(100, 5000); }});
tests.push_back({"mult n=100 k=5000", []{ return gen_multiplicative(100, 5000); }});
tests.push_back({"additive n=500 k=125000", []{ return gen_additive(500, 125000); }});
tests.push_back({"mult n=500 k=125000", []{ return gen_multiplicative(500, 125000); }});
tests.push_back({"random n=500 k=125000", []{ return gen_random_sorted(500, 125000); }});
tests.push_back({"additive n=2000 k=2000000", []{ return gen_additive(2000, 2000000); }});
tests.push_back({"mult n=2000 k=2000000", []{ return gen_multiplicative(2000, 2000000); }});
tests.push_back({"shifted n=2000 k=2000000", []{ return gen_shifted(2000, 2000000); }});
tests.push_back({"random n=2000 k=2000000", []{ return gen_random_sorted(2000, 2000000); }});
tests.push_back({"mult n=2000 k=100000", []{ return gen_multiplicative(2000, 100000); }});
tests.push_back({"mult n=2000 k=3900000", []{ return gen_multiplicative(2000, 3900000); }});
for (auto& t : tests) {
auto tc = t.gen();
Solver s(tc);
long long result = s.solve();
bool correct = (result == tc.answer);
int used = s.query_count;
double score = !correct ? 0.0 : (used <= tc.n ? 1.0 : (used >= 50000 ? 0.0 : (50000.0 - used) / (50000.0 - tc.n)));
printf("%-45s q=%6d %s score=%.4f\n", t.name.c_str(), used, correct ? "OK" : "WRONG", score);
}
}