// JustinTX's picture
// Add files using upload-large-folder tool
// 1fd0050 verified
#include <bits/stdc++.h>
using namespace std;
// One generated test instance: the matrix A is stored as (n+1) x (n+1) and used with
// 1-based indices (row/column 0 are padding), together with the 1-based rank k that is
// asked for and the expected k-th smallest value of A[1..n][1..n].
// Scalars are default-initialized so a default-constructed TestCase never holds garbage.
struct TestCase {
    int n = 0;                     // matrix dimension
    long long k = 0;               // 1-based rank being asked for
    vector<vector<long long>> A;   // A[1..n][1..n] hold the data; index 0 is padding
    long long answer = 0;          // expected k-th smallest of A[1..n][1..n]
};
// Shared fixed-seed 64-bit Mersenne Twister used by gen_random_sorted, so the
// generated random test data is the same on every run.
mt19937_64 rng_gen{42};
// Builds a TestCase whose matrix is A[i][j] = valfn(i, j) for 1 <= i, j <= n
// (row/column 0 are zero padding so indices are 1-based) and whose expected answer
// is the k-th smallest (1-based k) of those n*n values.
// Throws std::out_of_range if k is outside [1, n*n].
// The k-th value is found with nth_element (linear on average) rather than a full
// sort of all n*n cells.
TestCase gen_matrix(int n, long long k, const function<long long(int,int)>& valfn) {
    const size_t cells = n > 0 ? static_cast<size_t>(n) * static_cast<size_t>(n) : 0;
    if (k < 1 || static_cast<unsigned long long>(k) > cells)
        throw out_of_range("gen_matrix: k must be in [1, n*n]");
    TestCase tc;
    tc.n = n;
    tc.k = k;
    tc.A.assign(n + 1, vector<long long>(n + 1, 0));
    vector<long long> all;
    all.reserve(cells);
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) {
            tc.A[i][j] = valfn(i, j);
            all.push_back(tc.A[i][j]);
        }
    }
    // Only the k-th order statistic is needed, not the fully sorted list.
    const auto kth = all.begin() + (k - 1);
    nth_element(all.begin(), kth, all.end());
    tc.answer = *kth;
    return tc;
}
// Multiplication-table matrix: A[i][j] = i * j (1-based), with the product taken in 64 bits.
TestCase gen_multiplicative(int n, long long k) {
    auto product = [](int row, int col) -> long long {
        return static_cast<long long>(row) * col;
    };
    return gen_matrix(n, k, product);
}
// Multiplication table shifted by n: A[i][j] = (i + n) * (j + n) (1-based).
// Each factor is formed as an int and the product is taken in 64 bits.
TestCase gen_shifted(int n, long long k) {
    auto shiftedProduct = [n](int row, int col) -> long long {
        const int r = row + n;
        const int c = col + n;
        return static_cast<long long>(r) * c;
    };
    return gen_matrix(n, k, shiftedProduct);
}
// Additive matrix: A[i][j] = i + j (1-based); the sum is formed as an int, then widened.
TestCase gen_additive(int n, long long k) {
    auto rowPlusCol = [](int row, int col) -> long long {
        return static_cast<long long>(row + col);
    };
    return gen_matrix(n, k, rowPlusCol);
}
// Builds an n x n pseudo-random matrix (1-based; row/column 0 are zero padding) that is
// non-decreasing along every row and every column, and records its k-th smallest cell
// (1-based k) as the expected answer. Draws n*n values from the shared rng_gen stream,
// in row-major order.
// Throws std::out_of_range if k is outside [1, n*n].
TestCase gen_random_sorted(int n, long long k) {
    const size_t cells = n > 0 ? static_cast<size_t>(n) * static_cast<size_t>(n) : 0;
    if (k < 1 || static_cast<unsigned long long>(k) > cells)
        throw out_of_range("gen_random_sorted: k must be in [1, n*n]");
    TestCase tc;
    tc.n = n;
    tc.k = k;
    tc.A.assign(n + 1, vector<long long>(n + 1, 0));
    // Base value grows by 1000 per column and 1000000 per row; the noise term is < 500.
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++)
            tc.A[i][j] = static_cast<long long>(i) * 1000000 + static_cast<long long>(j) * 1000
                         + static_cast<long long>(rng_gen() % 500);
    // Row-wise, then column-wise prefix maxima guarantee monotonicity regardless of the
    // noise (with the current constants the steps already exceed the noise).
    for (int i = 1; i <= n; i++)
        for (int j = 2; j <= n; j++)
            tc.A[i][j] = max(tc.A[i][j], tc.A[i][j - 1]);
    for (int j = 1; j <= n; j++)
        for (int i = 2; i <= n; i++)
            tc.A[i][j] = max(tc.A[i][j], tc.A[i - 1][j]);
    vector<long long> all;
    all.reserve(cells);
    for (int i = 1; i <= n; i++)
        all.insert(all.end(), tc.A[i].begin() + 1, tc.A[i].end());
    // Only the k-th order statistic is needed, not the fully sorted list.
    const auto kth = all.begin() + (k - 1);
    nth_element(all.begin(), kth, all.end());
    tc.answer = *kth;
    return tc;
}
// Offline simulation of the "gpt5.2_2" strategy for finding the k-th smallest entry
// of an n x n matrix through cell queries. Per the original note, the approach is
// simulated exactly, so the logic below is intentionally left as-is.
// Each distinct cell read through get() costs one query; main() scores a run from
// query_count.
// NOTE(review): the staircase counts and heap merges below rely on every row and
// every column of tc.A being non-decreasing, which holds for this file's generators.
struct Solver {
// Matrix under test, held by reference: the TestCase must outlive the Solver.
const TestCase& tc;
// Number of distinct cells fetched so far (memoized re-reads are free).
int query_count;
// Fetched values keyed by x * 2001 + y; -1 marks "not fetched yet".
// NOTE(review): a cell whose true value is -1 would be re-charged on every read, and
// the fixed 2002 * 2002 size only covers n <= 2001 (larger n indexes out of range).
vector<long long> memo;
int n;
Solver(const TestCase& t) : tc(t), query_count(0), n(t.n) { memo.assign(2002 * 2002, -1); }
// Returns A[x][y] (1-based), charging one query the first time that cell is read.
long long get(int x, int y) {
int key = x * 2001 + y;
if (memo[key] != -1) return memo[key];
query_count++;
memo[key] = tc.A[x][y];
return memo[key];
}
// Counts cells with value <= x via a staircase walk from the top-right corner: the
// column pointer j only moves left as the row advances, so at most about 2n cells are read.
// Returns {count, pos}, where pos[i] is the number of cells <= x in row i; once a row
// has none, the walk stops and the remaining rows keep pos[i] = 0.
pair<long long, vector<int>> count_leq(long long x) {
vector<int> pos(n + 1, 0);
long long cnt = 0;
int j = n;
for (int i = 1; i <= n; i++) {
while (j >= 1) {
if (get(i, j) <= x) break;
j--;
}
pos[i] = j;
cnt += j;
if (j == 0) break;
}
return {cnt, pos};
}
// Same staircase walk as count_leq, but counts cells strictly less than x.
pair<long long, vector<int>> count_lt(long long x) {
vector<int> pos(n + 1, 0);
long long cnt = 0;
int j = n;
for (int i = 1; i <= n; i++) {
while (j >= 1) {
if (get(i, j) < x) break;
j--;
}
pos[i] = j;
cnt += j;
if (j == 0) break;
}
return {cnt, pos};
}
// Returns the k-th smallest (1-based k = tc.k) value of A[1..n][1..n].
// Strategy:
//   1. n == 1, k <= 1 and k >= n*n are answered with a single read;
//   2. if the rank counted from the nearer end, plus n initial reads, is at most 49000,
//      heap-merge the rows from that end;
//   3. otherwise keep, per row, a band (posLow[i], posHigh[i]] of candidate columns
//      between a lower and an upper threshold, shrink it with sampled pivots, and
//      heap-merge the band once that fits in the remaining budget.
// NOTE(review): 49000 / 49500 presumably sit just under the 50000-query limit used by
// the scoring in main(); if the budget runs out, the current upper threshold is
// returned, and that may not be the exact answer.
long long solve() {
long long k = tc.k;
long long N2 = (long long)n * n;
if (n == 1) return get(1, 1);
if (k <= 1) return get(1, 1);
if (k >= N2) return get(n, n);
// Heap for small k: merge from the nearer end, costing about
// tSmall = min(k, n*n - k + 1) pops plus n initial reads.
long long tSmall = min(k, N2 - k + 1);
if (tSmall + n <= 49000) {
if (k <= N2 - k + 1) {
// Min-heap k-way merge of the n sorted rows starting at column 1; the k-th pop is the answer.
priority_queue<tuple<long long, int, int>, vector<tuple<long long, int, int>>, greater<>> pq;
for (int i = 1; i <= n; i++) pq.emplace(get(i, 1), i, 1);
long long ans = 0;
for (long long s = 1; s <= k; s++) {
auto [v, r, c] = pq.top(); pq.pop();
ans = v;
if (s == k) break;
if (c < n) pq.emplace(get(r, c + 1), r, c + 1);
}
return ans;
} else {
// Max-heap merge starting at column n; the (n*n - k + 1)-th largest is the k-th smallest.
priority_queue<tuple<long long, int, int>> pq;
for (int i = 1; i <= n; i++) pq.emplace(get(i, n), i, n);
long long ans = 0;
long long t = N2 - k + 1;
for (long long s = 1; s <= t; s++) {
auto [v, r, c] = pq.top(); pq.pop();
ans = v;
if (s == t) break;
if (c > 1) pq.emplace(get(r, c - 1), r, c - 1);
}
return ans;
}
}
// Seeded from the clock, so pivot choices (and hence query counts) can differ between runs.
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());
// Initial upper bound: rBound = ceil(k / n), clamped to [1, n]. With sorted rows and
// columns, every cell in rows 1..rBound is <= A[rBound][n], so at least rBound * n >= k
// cells are <= highTh; the fallback to A[n][n] is defensive.
int rBound = max(1, min(n, (int)((k + n - 1) / n)));
long long highTh = get(rBound, n);
auto [cntHigh, posHigh] = count_leq(highTh);
if (cntHigh < k) {
highTh = get(n, n);
tie(cntHigh, posHigh) = count_leq(highTh);
}
// Lower bound starts empty: no cells are yet known to lie below the answer.
vector<int> posLow(n + 1, 0);
long long cntLow = 0;
// NOTE(review): lowTh is assigned below but never read.
long long lowTh = LLONG_MIN;
// Number of random band cells sampled per pivot round.
const int SAMPLES = 9;
// Loop invariant: cntLow < k <= cntHigh, so the answer lies among the band cells
// (posLow[i], posHigh[i]] of the rows.
while (true) {
// Defensive reset; it cannot trigger for a sorted matrix with k <= n*n.
if (cntHigh < k) {
highTh = get(n, n);
tie(cntHigh, posHigh) = count_leq(highTh);
if (cntHigh < k) break;
}
long long candTotal = cntHigh - cntLow;
if (candTotal <= 0) return highTh;
// Rank of the answer inside the band, counted from its smallest end (needSmall)
// and from its largest end (needLarge).
long long needSmall = k - cntLow;
long long needLarge = cntHigh - k + 1;
// pref[i] = band cells in rows 1..i (used to sample a band cell uniformly);
// nonempty = rows with at least one band cell.
vector<long long> pref(n + 1, 0);
int nonempty = 0;
for (int i = 1; i <= n; i++) {
int len = posHigh[i] - posLow[i];
if (len > 0) nonempty++;
pref[i] = pref[i - 1] + max(0, len);
}
candTotal = pref[n];
if (candTotal <= 0) return highTh;
// Queries still available under this 49500-query cap.
long long remaining = 49500 - query_count;
// Estimated cost of heap-merging the band from its nearer end: the rank, plus one
// initial read per non-empty row, plus slack.
long long enumCost = min(needSmall, needLarge) + nonempty + 10;
if (enumCost <= remaining) {
// Enumerate
if (needSmall <= needLarge) {
// Min-heap over each row's first band cell; the needSmall-th pop is the answer.
priority_queue<tuple<long long, int, int>, vector<tuple<long long, int, int>>, greater<>> pq;
for (int i = 1; i <= n; i++) {
int L = posLow[i] + 1, R = posHigh[i];
if (L >= 1 && L <= n && L <= R) pq.emplace(get(i, L), i, L);
}
long long ans = 0;
for (long long t = 1; t <= needSmall; t++) {
auto [v, r, c] = pq.top(); pq.pop();
ans = v;
if (t == needSmall) break;
if (c < posHigh[r]) pq.emplace(get(r, c + 1), r, c + 1);
}
return ans;
} else {
// Max-heap over each row's last band cell; the needLarge-th pop is the answer.
priority_queue<tuple<long long, int, int>> pq;
for (int i = 1; i <= n; i++) {
int L = posLow[i] + 1, R = posHigh[i];
if (R >= 1 && R <= n && L <= R) pq.emplace(get(i, R), i, R);
}
long long ans = 0;
for (long long t = 1; t <= needLarge; t++) {
auto [v, r, c] = pq.top(); pq.pop();
ans = v;
if (t == needLarge) break;
if (c > posLow[r] + 1) pq.emplace(get(r, c - 1), r, c - 1);
}
return ans;
}
}
// Not enough budget left for another pivot round (a count pass can read about 2n
// cells, plus the samples).
if (remaining < 2 * n + SAMPLES + 200) {
// NOTE(review): this condition equals `enumCost <= remaining`, which was just checked
// and found false with the same values, so this branch cannot run here.
if (nonempty + min(needSmall, needLarge) + 10 <= remaining) {
// same enum
if (needSmall <= needLarge) {
priority_queue<tuple<long long, int, int>, vector<tuple<long long, int, int>>, greater<>> pq;
for (int i = 1; i <= n; i++) {
int L = posLow[i] + 1, R = posHigh[i];
if (L >= 1 && L <= n && L <= R) pq.emplace(get(i, L), i, L);
}
long long ans = 0;
for (long long t = 1; t <= needSmall; t++) {
auto [v, r, c] = pq.top(); pq.pop(); ans = v;
if (t == needSmall) break;
if (c < posHigh[r]) pq.emplace(get(r, c + 1), r, c + 1);
}
return ans;
} else {
priority_queue<tuple<long long, int, int>> pq;
for (int i = 1; i <= n; i++) {
int L = posLow[i] + 1, R = posHigh[i];
if (R >= 1 && R <= n && L <= R) pq.emplace(get(i, R), i, R);
}
long long ans = 0;
for (long long t = 1; t <= needLarge; t++) {
auto [v, r, c] = pq.top(); pq.pop(); ans = v;
if (t == needLarge) break;
if (c > posLow[r] + 1) pq.emplace(get(r, c - 1), r, c - 1);
}
return ans;
}
}
// Out of budget: give up with the current upper threshold, which is not guaranteed
// to be the k-th smallest.
return highTh;
}
// Sample pivots: draw SAMPLES band cells uniformly. Pick a band index, map it to its
// row via pref, then pick a random column inside that row's band.
vector<long long> samp;
for (int s = 0; s < SAMPLES; s++) {
long long pick = 1 + rng() % candTotal;
int row = (int)(lower_bound(pref.begin() + 1, pref.end(), pick) - pref.begin());
if (row < 1 || row > n) continue;
int len = posHigh[row] - posLow[row];
if (len <= 0) continue;
int col = posLow[row] + 1 + rng() % len;
col = max(1, min(n, col));
samp.push_back(get(row, col));
}
if (samp.empty()) return highTh;
// Choose the sample at the answer's relative position within the band, so the pivot
// tends to land near the k-th value.
sort(samp.begin(), samp.end());
double q = (double)needSmall / (double)candTotal;
int idx = (int)(q * (double)(samp.size() - 1));
idx = max(0, min((int)samp.size() - 1, idx));
long long pivot = samp[idx];
auto [cntPivot, posPivot] = count_leq(pivot);
if (cntPivot >= k) {
// Pivot is an upper bound: at least k cells are <= pivot.
if (pivot == highTh && cntPivot == cntHigh) {
// No progress possible with <=. If fewer than k cells are strictly below the
// pivot, the pivot itself is the answer; otherwise tighten to cells < pivot
// (pivot - 1 relies on integer-valued cells).
auto [cntLess, posLess] = count_lt(pivot);
if (cntLess < k) return pivot;
highTh = pivot - 1;
posHigh = posLess;
cntHigh = cntLess;
} else {
highTh = pivot;
posHigh = posPivot;
cntHigh = cntPivot;
}
} else {
// Fewer than k cells are <= pivot, so all of them lie below the answer.
lowTh = pivot;
posLow = posPivot;
cntLow = cntPivot;
}
}
// Only reached through the defensive break above, which cannot happen when k <= n*n.
return 0;
}
};
int main() {
struct TestDef { string name; function<TestCase()> gen; };
vector<TestDef> tests;
tests.push_back({"additive n=100 k=5000", []{ return gen_additive(100, 5000); }});
tests.push_back({"mult n=100 k=5000", []{ return gen_multiplicative(100, 5000); }});
tests.push_back({"additive n=500 k=125000", []{ return gen_additive(500, 125000); }});
tests.push_back({"mult n=500 k=125000", []{ return gen_multiplicative(500, 125000); }});
tests.push_back({"random n=500 k=125000", []{ return gen_random_sorted(500, 125000); }});
tests.push_back({"additive n=2000 k=2000000", []{ return gen_additive(2000, 2000000); }});
tests.push_back({"mult n=2000 k=2000000", []{ return gen_multiplicative(2000, 2000000); }});
tests.push_back({"shifted n=2000 k=2000000", []{ return gen_shifted(2000, 2000000); }});
tests.push_back({"random n=2000 k=2000000", []{ return gen_random_sorted(2000, 2000000); }});
tests.push_back({"mult n=2000 k=100000", []{ return gen_multiplicative(2000, 100000); }});
tests.push_back({"mult n=2000 k=3900000", []{ return gen_multiplicative(2000, 3900000); }});
for (auto& t : tests) {
auto tc = t.gen();
Solver s(tc);
long long result = s.solve();
bool correct = (result == tc.answer);
int used = s.query_count;
double score = !correct ? 0.0 : (used <= tc.n ? 1.0 : (used >= 50000 ? 0.0 : (50000.0 - used) / (50000.0 - tc.n)));
printf("%-45s q=%6d %s score=%.4f\n", t.name.c_str(), used, correct ? "OK" : "WRONG", score);
}
}