// JustinTX's picture
// Add files using upload-large-folder tool
// 1fd0050 verified
#include <bits/stdc++.h>
using namespace std;
// One generated instance for the k-th-smallest solver.
struct TestCase {
    int n;                          // the matrix is n x n
    long long k;                    // 1-based rank of the requested element
    vector<vector<long long>> A;    // generators fill A[1..n][1..n]; row/column 0 are padding
    long long answer;               // expected k-th smallest entry of A[1..n][1..n]
};
// Shared generator used by gen_random_sorted(); the fixed seed makes the
// generated matrices identical from run to run.
mt19937_64 rng_gen{42};
// Builds an n x n test matrix whose cell (i, j) (1-based) is valfn(i, j) and
// records the k-th smallest entry (1-based k) as the expected answer.
// Row and column 0 of tc.A are zero padding, so cells are addressed 1-based.
// Throws std::out_of_range if k is not in [1, n*n] (previously this indexed
// out of bounds).
TestCase gen_matrix(int n, long long k, const function<long long(int, int)>& valfn) {
    TestCase tc;
    tc.n = n;
    tc.k = k;
    tc.A.assign(n + 1, vector<long long>(n + 1, 0));

    vector<long long> all;
    all.reserve(static_cast<size_t>(n) * static_cast<size_t>(n));
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) {
            tc.A[i][j] = valfn(i, j);
            all.push_back(tc.A[i][j]);
        }
    }

    if (k < 1 || k > static_cast<long long>(all.size()))
        throw out_of_range("gen_matrix: k must be in [1, n*n]");
    // Only one order statistic is needed, so linear-time selection suffices.
    nth_element(all.begin(), all.begin() + (k - 1), all.end());
    tc.answer = all[k - 1];
    return tc;
}
// Multiplication-table matrix: A[i][j] = i * j.
TestCase gen_multiplicative(int n, long long k) {
    auto product = [](int row, int col) -> long long {
        return static_cast<long long>(row) * col;
    };
    return gen_matrix(n, k, product);
}
// Multiplication table shifted by n: A[i][j] = (i + n) * (j + n).
TestCase gen_shifted(int n, long long k) {
    auto shiftedProduct = [n](int row, int col) -> long long {
        return static_cast<long long>(row + n) * (col + n);
    };
    return gen_matrix(n, k, shiftedProduct);
}
// Addition-table matrix: A[i][j] = i + j.
TestCase gen_additive(int n, long long k) {
    auto sum = [](int row, int col) -> long long {
        const int s = row + col;
        return s;
    };
    return gen_matrix(n, k, sum);
}
// Builds an n x n matrix with non-decreasing rows and columns from noisy values
// i*1000000 + j*1000 + (rng_gen() % 500), then records the k-th smallest entry
// (1-based k) as the expected answer. Consumes exactly n*n draws from the
// shared rng_gen, in row-major order.
// Row and column 0 of tc.A are zero padding, so cells are addressed 1-based.
// Throws std::out_of_range if k is not in [1, n*n] (previously this indexed
// out of bounds).
TestCase gen_random_sorted(int n, long long k) {
    TestCase tc;
    tc.n = n;
    tc.k = k;
    tc.A.assign(n + 1, vector<long long>(n + 1, 0));
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++)
            tc.A[i][j] = static_cast<long long>(i) * 1000000 + static_cast<long long>(j) * 1000 + (rng_gen() % 500);

    // The noise can break monotonicity; restore it with running maxima. A row
    // pass followed by a column pass keeps rows sorted too, because the column
    // maximum of a row-sorted matrix is itself row-sorted.
    for (int i = 1; i <= n; i++)
        for (int j = 2; j <= n; j++) tc.A[i][j] = max(tc.A[i][j], tc.A[i][j - 1]);
    for (int j = 1; j <= n; j++)
        for (int i = 2; i <= n; i++) tc.A[i][j] = max(tc.A[i][j], tc.A[i - 1][j]);

    vector<long long> all;
    all.reserve(static_cast<size_t>(n) * static_cast<size_t>(n));
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++) all.push_back(tc.A[i][j]);

    if (k < 1 || k > static_cast<long long>(all.size()))
        throw out_of_range("gen_random_sorted: k must be in [1, n*n]");
    // Only one order statistic is needed, so linear-time selection suffices.
    nth_element(all.begin(), all.begin() + (k - 1), all.end());
    tc.answer = all[k - 1];
    return tc;
}
// Finds the k-th smallest entry (1-based k = tc.k) of the n x n matrix tc.A,
// whose rows and columns are non-decreasing, while reading as few distinct
// cells as possible. Every read goes through do_query(), which charges each
// distinct cell to query_count at most once.
//
// solve() strategy:
//   1. Trivial ranks (n == 1, k == 1, k == n*n) cost a single read.
//   2. When min(k, n*n-k+1) + n <= kHeapThreshold, a per-row heap merge from
//      the nearer end is used directly.
//   3. Otherwise it binary-searches on the value. Each probe is a staircase
//      walk (countLeq) over the still-undecided columns (jLo[i], jHi[i]] of
//      each row. Once that region can be finished within the remaining budget,
//      the answer is extracted by enumeration or a bounded heap merge.
struct Solver {
    // Soft cap on distinct reads that solve() plans around. Presumably chosen
    // to stay below the 50000-read scoring cutoff in main(); confirm before
    // changing.
    static constexpr long long kQueryBudget = 49500;
    // Use the direct heap merge when min(k, n*n-k+1) + n is at most this.
    static constexpr long long kHeapThreshold = 24000;

    using Cell = tuple<long long, int, int>;  // (value, row, column)

    const TestCase& tc;
    int query_count;          // number of distinct cells read so far
    vector<long long> memo;   // cached values, indexed by r * (n + 1) + c
    vector<char> known;       // known[key] != 0 iff memo[key] holds a read value
    int n;
    int walk_count;           // number of countLeq() calls

    // The cache is sized from t.n (not a fixed 2002 x 2002 buffer), so any n
    // is supported.
    Solver(const TestCase& t)
        : tc(t),
          query_count(0),
          memo(static_cast<size_t>(t.n + 1) * static_cast<size_t>(t.n + 1), 0),
          known(static_cast<size_t>(t.n + 1) * static_cast<size_t>(t.n + 1), 0),
          n(t.n),
          walk_count(0) {}

    // Returns tc.A[r][c] for 1 <= r, c <= n. Only the first request for a given
    // cell is counted. The separate `known` flag (instead of a -1 sentinel)
    // lets every value, including -1, be cached.
    long long do_query(int r, int c) {
        const size_t key = static_cast<size_t>(r) * static_cast<size_t>(n + 1) + static_cast<size_t>(c);
        if (!known[key]) {
            known[key] = 1;
            memo[key] = tc.A[r][c];
            query_count++;
        }
        return memo[key];
    }

    // Staircase walk: returns {count, cutoff}, where cutoff[i] is the number of
    // entries <= mid in row i and count is their sum.
    // Assumes, as solve() maintains, that for every row
    //   jLo[i] <= (entries <= mid in row i) <= jHi[i],
    // so only columns (jLo[i], jHi[i]] are inspected. The column pointer j never
    // moves right during a call, so one walk reads at most 2n new cells.
    pair<long long, vector<int>> countLeq(long long mid, const vector<int>& jLo, const vector<int>& jHi) {
        walk_count++;
        vector<int> cutoff(n + 1, 0);
        long long cnt = 0;
        int j = min(n, jHi[1]);
        for (int i = 1; i <= n; i++) {
            const int lo = jLo[i];
            const int hi = min(n, jHi[i]);
            if (hi <= lo) {  // this row is already fully decided
                cutoff[i] = lo;
                cnt += lo;
                continue;
            }
            if (j > hi) j = hi;
            // Row i is non-decreasing: step left until A[i][j] <= mid or we
            // reach the known-good prefix.
            while (j > lo && do_query(i, j) > mid) j--;
            cutoff[i] = max(j, lo);
            cnt += cutoff[i];
        }
        return {cnt, std::move(cutoff)};
    }

    // rank-th smallest value among the cells (i, j) with jLo[i] < j <= jHi[i].
    // Each row is sorted, so a min-heap holding the next unread cell of every
    // row yields the cells in ascending order. Requires 1 <= rank <= the
    // number of such cells; returns -1 when rank < 1.
    long long select_smallest(long long rank, const vector<int>& jLo, const vector<int>& jHi) {
        priority_queue<Cell, vector<Cell>, greater<>> pq;
        for (int i = 1; i <= n; i++) {
            const int L = jLo[i] + 1, R = jHi[i];
            if (L >= 1 && L <= n && L <= R) pq.emplace(do_query(i, L), i, L);
        }
        long long result = -1;
        for (long long t = 0; t < rank; t++) {
            auto [v, r, c] = pq.top();
            pq.pop();
            result = v;
            if (c + 1 <= jHi[r]) pq.emplace(do_query(r, c + 1), r, c + 1);
        }
        return result;
    }

    // rank-th largest value among the cells (i, j) with jLo[i] < j <= jHi[i].
    // This is the mirror image of select_smallest(), using a max-heap that
    // walks each row from its right end.
    long long select_largest(long long rank, const vector<int>& jLo, const vector<int>& jHi) {
        priority_queue<Cell> pq;
        for (int i = 1; i <= n; i++) {
            const int L = jLo[i] + 1, R = jHi[i];
            if (R >= 1 && R <= n && L <= R) pq.emplace(do_query(i, R), i, R);
        }
        long long result = -1;
        for (long long t = 0; t < rank; t++) {
            auto [v, r, c] = pq.top();
            pq.pop();
            result = v;
            if (c - 1 >= jLo[r] + 1) pq.emplace(do_query(r, c - 1), r, c - 1);
        }
        return result;
    }

    // rank-th smallest value among the cells (i, j) with jLo[i] < j <= jHi[i],
    // found by reading every such cell and running a linear-time selection.
    long long select_by_enumeration(long long rank, const vector<int>& jLo, const vector<int>& jHi) {
        size_t total = 0;
        for (int i = 1; i <= n; i++)
            if (jHi[i] > jLo[i]) total += static_cast<size_t>(jHi[i] - jLo[i]);
        vector<long long> cand;
        cand.reserve(total);
        for (int i = 1; i <= n; i++)
            for (int j = jLo[i] + 1; j <= jHi[i]; j++) cand.push_back(do_query(i, j));
        nth_element(cand.begin(), cand.begin() + (rank - 1), cand.end());
        return cand[rank - 1];
    }

    // Returns the k-th smallest entry of tc.A. If the remaining query budget
    // becomes too small for another walk before the search converges, the
    // current upper bound hiVal is returned as a best-effort answer.
    long long solve() {
        const long long k = tc.k;
        const long long NLL = static_cast<long long>(n) * n;
        if (n == 1 || k == 1) return do_query(1, 1);
        if (k == NLL) return do_query(n, n);

        // Columns (jLo[i], jHi[i]] of row i are still undecided; initially all.
        vector<int> jLo(n + 1, 0), jHi(n + 1, n);

        // Near either end of the order a heap merge from that end is cheapest.
        const long long fromTop = NLL - k + 1;  // rank of the answer counted from the largest
        if (min(k, fromTop) + n <= kHeapThreshold) {
            return k <= fromTop ? select_smallest(k, jLo, jHi) : select_largest(fromTop, jLo, jHi);
        }

        // Invariant: count(<= loVal) == cLo < k <= cHi == count(<= hiVal), and
        // jLo / jHi hold the per-row cutoffs at loVal / hiVal.
        long long cLo = 0, cHi = NLL;
        long long loVal = do_query(1, 1) - 1;
        long long hiVal = do_query(n, n);

        // Rows 1..rBound (rBound = ceil(k / n)) are all <= A[rBound][n] in a
        // sorted matrix, giving at least rBound * n >= k entries: a tighter
        // upper bound.
        const int rBound = max(1, min(n, static_cast<int>((k + n - 1) / n)));
        const long long initHi = do_query(rBound, n);
        auto [ch, cutH] = countLeq(initHi, jLo, jHi);
        if (ch >= k) {
            jHi = std::move(cutH);
            cHi = ch;
            hiVal = initHi;
        }

        while (cHi - cLo > 0) {
            const long long budget = kQueryBudget - query_count;
            const long long W = cHi - cLo;          // undecided cells
            const long long needSmall = k - cLo;    // answer's rank from the bottom of the region
            const long long needLarge = cHi - k + 1; // answer's rank from the top of the region

            int nonempty = 0;
            for (int i = 1; i <= n; i++)
                if (jHi[i] > jLo[i]) nonempty++;

            // Finish directly if the cheapest extraction fits (with slack of 10).
            if (min({needSmall + nonempty, needLarge + nonempty, W}) + 10 <= budget) {
                if (W <= needSmall + nonempty && W <= needLarge + nonempty)
                    return select_by_enumeration(needSmall, jLo, jHi);
                if (needSmall <= needLarge) return select_smallest(needSmall, jLo, jHi);
                return select_largest(needLarge, jLo, jHi);
            }

            // A walk reads at most 2n cells; keep a margin of 200.
            if (budget < 2 * n + 200) break;
            if (loVal >= hiVal) break;
            long long midVal = loVal + (hiVal - loVal) / 2;
            if (midVal <= loVal) midVal = loVal + 1;
            // (loVal, hiVal] holds a single integer: hiVal is the answer.
            if (midVal >= hiVal) break;

            auto [cnt, cut] = countLeq(midVal, jLo, jHi);
            if (cnt >= k) {
                jHi = std::move(cut);
                cHi = cnt;
                hiVal = midVal;
            } else {
                jLo = std::move(cut);
                cLo = cnt;
                loVal = midVal;
            }
        }
        return hiVal;
    }
};
int main() {
struct TestDef { string name; function<TestCase()> gen; };
vector<TestDef> tests;
tests.push_back({"additive n=2000 k=2000000", []{ return gen_additive(2000, 2000000); }});
tests.push_back({"mult n=2000 k=2000000", []{ return gen_multiplicative(2000, 2000000); }});
tests.push_back({"shifted n=2000 k=2000000", []{ return gen_shifted(2000, 2000000); }});
tests.push_back({"random n=2000 k=2000000", []{ return gen_random_sorted(2000, 2000000); }});
tests.push_back({"mult n=2000 k=100000", []{ return gen_multiplicative(2000, 100000); }});
tests.push_back({"mult n=2000 k=3900000", []{ return gen_multiplicative(2000, 3900000); }});
// Test type 3, 4, 5 from interactor
tests.push_back({"type4 n=2000 k=2000000", []{ return gen_matrix(2000, 2000000, [](int i, int j)->long long { return i + 2LL*j; }); }});
tests.push_back({"type5 n=2000 k=2000000", []{ return gen_matrix(2000, 2000000, [](int i, int j)->long long { return 2LL*i + j; }); }});
for (auto& t : tests) {
auto tc = t.gen();
Solver s(tc);
long long result = s.solve();
bool correct = (result == tc.answer);
int used = s.query_count;
double score = !correct ? 0.0 : (used <= tc.n ? 1.0 : (used >= 50000 ? 0.0 : (50000.0 - used) / (50000.0 - tc.n)));
printf("%-45s q=%6d walks=%2d %s score=%.4f\n", t.name.c_str(), used, s.walk_count, correct ? "OK" : "WRONG", score);
}
}