// JustinTX's picture
// Add files using upload-large-folder tool
// 1fd0050 verified
#include <bits/stdc++.h>
using namespace std;
// One generated problem instance: a 1-indexed n x n matrix A (the generators
// in this file fill rows/columns 1..n and leave index 0 as padding), the
// 1-based rank k to select, and the expected k-th smallest value.
// All scalar members default to 0 so a default-constructed instance is never
// left holding indeterminate values.
struct TestCase {
    int n = 0;
    long long k = 0;
    vector<vector<long long>> A;
    long long answer = 0;
};
// Global 64-bit Mersenne Twister with a fixed seed of 42, so the random test
// matrices are identical on every run.
mt19937_64 rng_gen{42};
// Builds a 1-indexed n x n test matrix with A[i][j] = valfn(i, j) for
// 1 <= i, j <= n (row 0 and column 0 stay 0). It also records the k-th
// smallest entry (1-based k) as the expected answer.
// Throws std::out_of_range if n < 1 or k is not in [1, n*n].
TestCase gen_matrix(int n, long long k, function<long long(int,int)> valfn) {
    const long long total = static_cast<long long>(n) * n;
    if (n < 1 || k < 1 || k > total)
        throw out_of_range("gen_matrix: k must lie in [1, n*n]");

    TestCase tc;
    tc.n = n;
    tc.k = k;
    tc.A.assign(n + 1, vector<long long>(n + 1, 0));

    vector<long long> all;
    all.reserve(static_cast<size_t>(total));
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++) {
            tc.A[i][j] = valfn(i, j);
            all.push_back(tc.A[i][j]);
        }

    // Only the k-th order statistic is needed, so selection suffices.
    nth_element(all.begin(), all.begin() + (k - 1), all.end());
    tc.answer = all[k - 1];
    return tc;
}
// Multiplication-table instance: A[i][j] = i * j, selecting the k-th smallest.
TestCase gen_multiplicative(int n, long long k) {
    const auto product = [](int row, int col) -> long long {
        return static_cast<long long>(row) * col;
    };
    return gen_matrix(n, k, product);
}
// Multiplication table offset by n on both axes: A[i][j] = (i + n) * (j + n).
TestCase gen_shifted(int n, long long k) {
    const auto shifted_product = [offset = n](int row, int col) -> long long {
        return static_cast<long long>(row + offset) * (col + offset);
    };
    return gen_matrix(n, k, shifted_product);
}
// Sum-table instance: A[i][j] = i + j. Every anti-diagonal holds one repeated
// value, so this case has many duplicates.
TestCase gen_additive(int n, long long k) {
    const auto sum = [](int row, int col) -> long long { return row + col; };
    return gen_matrix(n, k, sum);
}
// Builds a 1-indexed n x n matrix whose rows and columns are non-decreasing.
// Each raw entry is A[i][j] = i*C + j*D + noise, with noise in [0, 500) drawn
// from the global rng_gen in row-major order. Row-wise and then column-wise
// prefix maxima are applied afterwards. The k-th smallest entry (1-based k)
// is stored as the expected answer.
// Throws std::out_of_range if n < 1 or k is not in [1, n*n].
TestCase gen_random_sorted(int n, long long k) {
    const long long total = static_cast<long long>(n) * n;
    if (n < 1 || k < 1 || k > total)
        throw out_of_range("gen_random_sorted: k must lie in [1, n*n]");

    TestCase tc;
    tc.n = n;
    tc.k = k;
    tc.A.assign(n + 1, vector<long long>(n + 1, 0));

    const long long C = 1000000, D = 1000;
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++)
            tc.A[i][j] = static_cast<long long>(i) * C + static_cast<long long>(j) * D
                         + static_cast<long long>(rng_gen() % 500);

    // With C and D both larger than the noise range, the raw values already
    // increase strictly along rows and columns. These passes keep both orders
    // monotone even if the constants are changed.
    for (int i = 1; i <= n; i++)
        for (int j = 2; j <= n; j++)
            tc.A[i][j] = max(tc.A[i][j], tc.A[i][j - 1]);
    for (int j = 1; j <= n; j++)
        for (int i = 2; i <= n; i++)
            tc.A[i][j] = max(tc.A[i][j], tc.A[i - 1][j]);

    vector<long long> all;
    all.reserve(static_cast<size_t>(total));
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++)
            all.push_back(tc.A[i][j]);

    // Only the k-th order statistic is needed, so selection suffices.
    nth_element(all.begin(), all.begin() + (k - 1), all.end());
    tc.answer = all[k - 1];
    return tc;
}
// Finds the k-th smallest entry (1-based k) of the 1-indexed n x n matrix
// tc.A while reading as few distinct cells as possible. The algorithm assumes
// every row and column is non-decreasing. All cell reads go through
// do_query(), which caches values and counts each distinct cell once in
// query_count.
//
// Strategy:
//   * If k, or its mirror rank n*n - k + 1, is small, select directly with a
//     row-merging heap started from the matching corner.
//   * Otherwise, maintain a band of candidate cells between two staircases:
//       - posLo: per-row counts of cells <= a value known to rank below k;
//       - posHi: per-row counts of cells <= hi, where at least k cells are
//         <= hi.
//     Each round samples a pivot from the band and narrows the band with a
//     staircase walk. Once the remaining work fits in the query budget, the
//     band is finished off with heap selection.
struct Solver {
    // Query budget the band-enumeration step is allowed to spend.
    static constexpr long long kQueryBudget = 49500;
    // Direct corner heap selection is used when (rank + n) fits under this.
    static constexpr long long kDirectHeapLimit = 24000;
    // Maximum number of pivot refinement rounds before giving up.
    static constexpr int kMaxRounds = 50;

    const TestCase& tc;       // matrix under test; must outlive the Solver
    int query_count;          // number of distinct cells read so far
    vector<long long> memo;   // cached cell values, indexed by cell_key(r, c)
    int n;                    // matrix dimension (tc.n)
    vector<char> known;       // known[key] != 0 once memo[key] holds a value

    Solver(const TestCase& t) : tc(t), query_count(0), n(t.n) {
        // One slot per (r, c) with 0 <= r, c <= n. This supports any n, and a
        // separate flag vector means no matrix value has to serve as a
        // "not cached" sentinel.
        const size_t side = static_cast<size_t>(max(n, 0)) + 1;
        memo.assign(side * side, 0);
        known.assign(side * side, 0);
    }

    // Flat index of cell (r, c) in memo/known.
    size_t cell_key(int r, int c) const {
        return static_cast<size_t>(r) * static_cast<size_t>(n + 1) + static_cast<size_t>(c);
    }

    // Returns tc.A[r][c] for 1 <= r, c <= n. The cell is counted as a query
    // only the first time it is read.
    long long do_query(int r, int c) {
        const size_t key = cell_key(r, c);
        if (!known[key]) {
            known[key] = 1;
            memo[key] = tc.A[r][c];
            query_count++;
        }
        return memo[key];
    }

    // Staircase walk starting from the top-right corner. For each row i,
    // pos[i] becomes the number of leading cells that do not exceed x: cells
    // < x when `strict` is set, cells <= x otherwise. Returns the sum of
    // pos[1..n]. Because columns are non-decreasing, the boundary only moves
    // left as i grows. Once it reaches 0, the remaining rows contribute 0 and
    // are not visited.
    long long staircase_count(long long x, vector<int>& pos, bool strict) {
        pos.assign(n + 1, 0);
        long long cnt = 0;
        int j = n;
        for (int i = 1; i <= n; i++) {
            while (j >= 1) {
                const long long v = do_query(i, j);
                const bool exceeds = strict ? (v >= x) : (v > x);
                if (!exceeds) break;
                j--;
            }
            pos[i] = j;
            cnt += j;
            if (j == 0) break;
        }
        return cnt;
    }

    // Number of entries <= x; pos[i] receives the count for row i.
    long long count_leq(long long x, vector<int>& pos) {
        return staircase_count(x, pos, false);
    }

    // Number of entries < x; pos[i] receives the count for row i.
    long long count_lt(long long x, vector<int>& pos) {
        return staircase_count(x, pos, true);
    }

    // Returns the t-th smallest (1-based) value among the cells (i, c) with
    // lo[i] < c <= hi[i]. The rows' ascending runs are merged with a min-heap.
    // Requires 1 <= t <= number of such cells.
    long long kth_smallest_in_band(const vector<int>& lo, const vector<int>& hi, long long t) {
        using Cell = tuple<long long, int, int>;  // (value, row, column)
        priority_queue<Cell, vector<Cell>, greater<>> pq;
        for (int i = 1; i <= n; i++) {
            const int first = lo[i] + 1;
            if (first <= hi[i]) pq.emplace(do_query(i, first), i, first);
        }
        long long result = 0;
        for (long long step = 0; step < t; step++) {
            const auto [v, r, c] = pq.top();
            pq.pop();
            result = v;
            if (c + 1 <= hi[r]) pq.emplace(do_query(r, c + 1), r, c + 1);
        }
        return result;
    }

    // Returns the t-th largest (1-based) value among the cells (i, c) with
    // lo[i] < c <= hi[i]. The rows' descending runs are merged with a max-heap.
    // Requires 1 <= t <= number of such cells.
    long long kth_largest_in_band(const vector<int>& lo, const vector<int>& hi, long long t) {
        using Cell = tuple<long long, int, int>;  // (value, row, column)
        priority_queue<Cell> pq;
        for (int i = 1; i <= n; i++) {
            const int last = hi[i];
            if (lo[i] + 1 <= last) pq.emplace(do_query(i, last), i, last);
        }
        long long result = 0;
        for (long long step = 0; step < t; step++) {
            const auto [v, r, c] = pq.top();
            pq.pop();
            result = v;
            if (c - 1 > lo[r]) pq.emplace(do_query(r, c - 1), r, c - 1);
        }
        return result;
    }

    // Returns the k-th smallest entry. Returns -1 if k is out of range
    // [1, n*n], if n < 1, or if the refinement does not converge within
    // kMaxRounds. If the remaining budget is too small for another round,
    // returns the current upper bound as a best guess.
    long long solve() {
        const long long k = tc.k;
        const long long N2 = static_cast<long long>(n) * n;
        if (n <= 0 || k < 1 || k > N2) return -1;
        if (n == 1 || k == 1) return do_query(1, 1);
        if (k == N2) return do_query(n, n);

        // Small rank from either end: heap-select directly from that corner.
        const long long kMirror = N2 - k + 1;
        if (min(k, kMirror) + n <= kDirectHeapLimit) {
            const vector<int> bandLo(n + 1, 0), bandHi(n + 1, n);
            return k <= kMirror ? kth_smallest_in_band(bandLo, bandHi, k)
                                : kth_largest_in_band(bandLo, bandHi, kMirror);
        }

        // Initial upper bound: the last entry of row ceil(k/n). Every entry of
        // rows 1..ceil(k/n) is <= it, so at least k entries are <= hi.
        const int rBound = min(n, static_cast<int>((k + n - 1) / n));
        long long hi = do_query(rBound, n);
        vector<int> posHi;
        long long cntHi = count_leq(hi, posHi);
        if (cntHi < k) {  // defensive fallback; the bound above should suffice
            hi = do_query(n, n);
            cntHi = count_leq(hi, posHi);
        }

        // Lower staircase: nothing is known to rank below k yet.
        vector<int> posLo(n + 1, 0);
        long long cntLo = 0;

        for (int round = 0; round < kMaxRounds; round++) {
            if (cntHi - cntLo <= 0) return hi;
            const long long needSmall = k - cntLo;      // answer's rank inside the band
            const long long needLarge = cntHi - k + 1;  // same rank counted from the top

            int nonempty = 0;
            for (int i = 1; i <= n; i++)
                if (posHi[i] > posLo[i]) nonempty++;

            // Finish by enumeration once a heap selection fits in the budget:
            // one seed cell per non-empty row, plus at most one new cell per pop.
            const long long budget = kQueryBudget - query_count;
            if (min(needSmall, needLarge) + nonempty + 10 <= budget) {
                return needSmall <= needLarge ? kth_smallest_in_band(posLo, posHi, needSmall)
                                              : kth_largest_in_band(posLo, posHi, needLarge);
            }
            if (budget < 2 * n + 200) return hi;  // no room for another round

            // Number of candidate cells currently in the band.
            long long candTotal = 0;
            for (int i = 1; i <= n; i++) candTotal += max(0, posHi[i] - posLo[i]);
            if (candTotal <= 0) return hi;

            // Sample each non-empty row at the fraction of its band where the
            // answer would sit if ranks were spread uniformly. The
            // width-weighted quantile of those samples becomes the pivot.
            const double target_q = static_cast<double>(needSmall) / static_cast<double>(candTotal);
            vector<pair<long long, long long>> samples;  // (value, band width of its row)
            for (int i = 1; i <= n; i++) {
                const int L = posLo[i] + 1, R = posHi[i];
                if (L > R) continue;
                const int width = R - L + 1;
                const int col = max(L, min(R, L + static_cast<int>(target_q * (width - 1))));
                samples.emplace_back(do_query(i, col), width);
            }
            sort(samples.begin(), samples.end());
            const long long target_w = static_cast<long long>(target_q * candTotal);
            long long pivot = samples.back().first;
            long long cum = 0;
            for (const auto& [v, w] : samples) {
                cum += w;
                if (cum >= target_w) {
                    pivot = v;
                    break;
                }
            }

            vector<int> posPivot;
            const long long cntPivot = count_leq(pivot, posPivot);
            if (cntPivot < k) {
                // Answer lies above the pivot: raise the lower staircase.
                cntLo = cntPivot;
                posLo = std::move(posPivot);
            } else if (cntPivot == cntHi && pivot >= hi) {
                // Pivot did not shrink the band; split on strict "<" instead.
                vector<int> posLess;
                const long long cntLess = count_lt(pivot, posLess);
                if (cntLess < k) return pivot;  // < k entries below pivot, >= k at-or-below
                hi = pivot - 1;  // assumes integer values: "< pivot" == "<= pivot - 1"
                posHi = std::move(posLess);
                cntHi = cntLess;
            } else {
                // Answer is <= pivot: lower the upper staircase.
                hi = pivot;
                posHi = std::move(posPivot);
                cntHi = cntPivot;
            }
        }
        return -1;  // did not converge within kMaxRounds
    }
};
int main() {
struct TestDef {
string name;
function<TestCase()> gen;
};
vector<TestDef> tests;
tests.push_back({"additive n=100 k=5000", []{ return gen_additive(100, 5000); }});
tests.push_back({"multiplicative n=100 k=5000", []{ return gen_multiplicative(100, 5000); }});
tests.push_back({"additive n=500 k=125000", []{ return gen_additive(500, 125000); }});
tests.push_back({"multiplicative n=500 k=125000", []{ return gen_multiplicative(500, 125000); }});
tests.push_back({"random n=500 k=125000", []{ return gen_random_sorted(500, 125000); }});
tests.push_back({"additive n=2000 k=2000000", []{ return gen_additive(2000, 2000000); }});
tests.push_back({"multiplicative n=2000 k=2000000", []{ return gen_multiplicative(2000, 2000000); }});
tests.push_back({"shifted n=2000 k=2000000", []{ return gen_shifted(2000, 2000000); }});
tests.push_back({"random n=2000 k=2000000", []{ return gen_random_sorted(2000, 2000000); }});
tests.push_back({"multiplicative n=2000 k=100000", []{ return gen_multiplicative(2000, 100000); }});
tests.push_back({"multiplicative n=2000 k=3900000", []{ return gen_multiplicative(2000, 3900000); }});
for (auto& t : tests) {
auto tc = t.gen();
Solver s(tc);
long long result = s.solve();
bool correct = (result == tc.answer);
double score;
int used = s.query_count;
if (!correct) score = 0.0;
else if (used <= tc.n) score = 1.0;
else if (used >= 50000) score = 0.0;
else score = (50000.0 - used) / (50000.0 - tc.n);
printf("%-45s n=%4d k=%8lld queries=%6d correct=%s score=%.4f\n",
t.name.c_str(), tc.n, tc.k, used, correct ? "YES" : "NO", score);
}
return 0;
}