// Source: uploaded by JustinTX ("Add files using upload-large-folder tool"),
// commit 1fd0050 (verified).
#include <bits/stdc++.h>
using namespace std;
// One benchmark instance, as built by the generators in this file:
// a 1-based n x n matrix A (row 0 and column 0 are unused padding),
// the 1-based rank k to look up, and the expected k-th smallest entry of A.
// The expected answer is computed by brute force when the case is generated.
// All scalar members are value-initialized, so a default-constructed TestCase
// never holds indeterminate values. It remains an aggregate.
struct TestCase {
    int n{};                      // matrix dimension
    long long k{};                // rank to find (1-based, duplicates counted)
    vector<vector<long long>> A;  // (n+1) x (n+1); entries 1..n in each dimension are used
    long long answer{};           // expected k-th smallest value of A[1..n][1..n]
};
// Shared, fixed-seed engine used by gen_random_sorted. Every case that draws
// from it depends on the order in which such cases are generated.
mt19937_64 rng_gen{42};
// Builds an n x n test matrix (1-based; row 0 and column 0 are zero padding)
// with A[i][j] = valfn(i, j), evaluated in row-major order.
// Records the k-th smallest entry (1-based k, duplicates counted) as the
// expected answer.
// Throws invalid_argument when n < 1 or k is outside [1, n*n]; such inputs
// previously caused an out-of-bounds read.
TestCase gen_matrix(int n, long long k, function<long long(int,int)> valfn) {
    const long long cells = static_cast<long long>(n) * n;
    if (n < 1 || k < 1 || k > cells) {
        throw invalid_argument("gen_matrix: requires n >= 1 and 1 <= k <= n*n");
    }

    TestCase tc;
    tc.n = n;
    tc.k = k;
    tc.A.assign(n + 1, vector<long long>(n + 1, 0));

    vector<long long> all;
    all.reserve(static_cast<size_t>(cells));
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) {
            tc.A[i][j] = valfn(i, j);
            all.push_back(tc.A[i][j]);
        }
    }

    // Only the k-th order statistic is needed, so a full sort is unnecessary.
    nth_element(all.begin(), all.begin() + (k - 1), all.end());
    tc.answer = all[k - 1];
    return tc;
}
// Multiplication-table case: A[i][j] = i * j, computed in 64-bit arithmetic.
// Rows and columns are non-decreasing, and many values repeat (i*j == j*i).
TestCase gen_multiplicative(int n, long long k) {
    const auto product = [](int row, int col) -> long long {
        return static_cast<long long>(row) * col;
    };
    return gen_matrix(n, k, product);
}
// Shifted multiplication table: A[i][j] = (i + n) * (j + n).
// All entries lie in [(n+1)^2, (2n)^2].
TestCase gen_shifted(int n, long long k) {
    const auto shiftedProduct = [n](int row, int col) -> long long {
        const long long rowFactor = row + n;
        return rowFactor * (col + n);
    };
    return gen_matrix(n, k, shiftedProduct);
}
// Additive case: A[i][j] = i + j.
// There are only 2n - 1 distinct values (2..2n), so duplicates are heavy.
TestCase gen_additive(int n, long long k) {
    const auto sum = [](int row, int col) -> long long {
        return row + col;
    };
    return gen_matrix(n, k, sum);
}
// Builds an n x n matrix (1-based; row 0 and column 0 are zero padding) of
// pseudo-random entries whose rows and columns are non-decreasing.
// Records its k-th smallest entry (1-based k, duplicates counted) as the
// expected answer.
// Entries start as i*1000000 + j*1000 + noise, where noise is in [0, 500) and
// drawn from the global rng_gen in row-major order.
// Prefix-max passes along rows and then columns enforce monotonicity. Because
// the 1000 step exceeds the noise range, these passes are a safeguard.
// Throws invalid_argument when n < 1 or k is outside [1, n*n]; such inputs
// previously caused an out-of-bounds read.
TestCase gen_random_sorted(int n, long long k) {
    const long long cells = static_cast<long long>(n) * n;
    if (n < 1 || k < 1 || k > cells) {
        throw invalid_argument("gen_random_sorted: requires n >= 1 and 1 <= k <= n*n");
    }

    TestCase tc;
    tc.n = n;
    tc.k = k;
    tc.A.assign(n + 1, vector<long long>(n + 1, 0));

    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) {
            const long long noise = static_cast<long long>(rng_gen() % 500);
            tc.A[i][j] = static_cast<long long>(i) * 1000000 + static_cast<long long>(j) * 1000 + noise;
        }
    }
    // Make every row non-decreasing...
    for (int i = 1; i <= n; i++) {
        for (int j = 2; j <= n; j++) tc.A[i][j] = max(tc.A[i][j], tc.A[i][j - 1]);
    }
    // ...then every column. The maximum of non-decreasing rows is still
    // non-decreasing, so rows stay sorted.
    for (int j = 1; j <= n; j++) {
        for (int i = 2; i <= n; i++) tc.A[i][j] = max(tc.A[i][j], tc.A[i - 1][j]);
    }

    vector<long long> all;
    all.reserve(static_cast<size_t>(cells));
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) all.push_back(tc.A[i][j]);
    }
    // Only the k-th order statistic is needed, so a full sort is unnecessary.
    nth_element(all.begin(), all.begin() + (k - 1), all.end());
    tc.answer = all[k - 1];
    return tc;
}
// Finds the k-th smallest entry of tc.A (1-based, n x n) and counts how many
// distinct cells it reads (query_count).
// The counting walk in countLeq is only correct when every row and every
// column of tc.A is non-decreasing.
//
// Strategy:
//   (1) sample cell values;
//   (2) binary-search over the sorted samples using staircase counts;
//   (3) continue with a numeric binary search while the candidate band is
//       still wide;
//   (4) read every cell in the remaining band and select the answer with
//       nth_element.
struct Solver {
    const TestCase& tc;  // borrowed, not owned: must outlive this Solver
    int query_count;     // distinct cells read so far; repeat reads are free
    // Values of cells already read, keyed by r * (n + 1) + c.
    // A hash map keeps memory proportional to the cells actually read and
    // works for any n. The former fixed 2002*2002 table indexed out of bounds
    // for n >= 2002. It also used -1 as its "unread" marker, so a cell whose
    // value was -1 was re-counted on every read.
    unordered_map<long long, long long> memo;
    int n;

    Solver(const TestCase& t) : tc(t), query_count(0), n(t.n) { memo.reserve(1 << 16); }

    // Returns A[r][c] for 1 <= r, c <= n.
    // Charges one query the first time a given cell is read.
    long long do_query(int r, int c) {
        const long long key = static_cast<long long>(r) * (n + 1) + c;
        auto [it, inserted] = memo.try_emplace(key, 0);
        if (inserted) {
            ++query_count;
            it->second = tc.A[r][c];
        }
        return it->second;
    }

    // Counts the entries <= mid with a staircase walk.
    // Returns {count, cutoff}, where cutoff[i] is the number of entries in
    // row i that are <= mid.
    // jLo / jHi must be such per-row cutoffs for thresholds known to be below
    // / at-or-above mid; only columns in (jLo[i], jHi[i]] are read.
    // The column pointer j carries over from row to row: with non-decreasing
    // columns, a row never has more entries <= mid than the row above it.
    pair<long long, vector<int>> countLeq(long long mid, const vector<int>& jLo, const vector<int>& jHi) {
        vector<int> cutoff(n + 1, 0);
        long long cnt = 0;
        int j = min(n, jHi[1]);
        for (int i = 1; i <= n; i++) {
            const int lo = jLo[i];
            const int hi = min(n, jHi[i]);
            if (hi <= lo) {
                // The lower and upper bounds already pin this row's count.
                cutoff[i] = lo;
            } else {
                j = min(j, hi);
                while (j > lo && do_query(i, j) > mid) j--;
                // j either stopped on an entry <= mid, or ran down to lo.
                cutoff[i] = max(j, lo);
            }
            cnt += cutoff[i];
        }
        return {cnt, cutoff};
    }

    // Phase-1 helper: reads up to `budget` distinct cells and returns their
    // values sorted, with duplicates removed.
    // Cells come first from an evenly spaced g x g grid (g = floor(sqrt(budget))),
    // then from uniformly random positions.
    // The RNG seed depends only on n and k, so runs are reproducible.
    vector<long long> sampleValues(long long budget, long long k) {
        // Guard: the random fill below needs `budget` distinct cells to exist.
        budget = min(budget, static_cast<long long>(n) * n);
        vector<long long> vals;
        mt19937_64 rng(1469598103934665603ULL ^ static_cast<uint64_t>(n) * 1181783497276652981ULL ^
                       (static_cast<uint64_t>(k) << 1));
        unordered_set<long long> seen;
        const auto record = [&](int r, int c) {
            if (seen.insert(static_cast<long long>(r) * (n + 1) + c).second) vals.push_back(do_query(r, c));
        };

        const int g = static_cast<int>(floor(sqrt(static_cast<double>(budget))));
        for (int ri = 1; ri <= g && static_cast<long long>(vals.size()) < budget; ri++) {
            const int r = max(1, min(n, static_cast<int>((ri * static_cast<long long>(n)) / (g + 1))));
            for (int ci = 1; ci <= g && static_cast<long long>(vals.size()) < budget; ci++) {
                const int c = max(1, min(n, static_cast<int>((ci * static_cast<long long>(n)) / (g + 1))));
                record(r, c);
            }
        }
        while (static_cast<long long>(vals.size()) < budget) {
            const int r = 1 + static_cast<int>(rng() % n);
            const int c = 1 + static_cast<int>(rng() % n);
            record(r, c);
        }

        sort(vals.begin(), vals.end());
        vals.erase(unique(vals.begin(), vals.end()), vals.end());
        return vals;
    }

    // Returns the k-th smallest entry (k = tc.k, 1-based, duplicates counted).
    long long solve() {
        const long long k = tc.k;
        const long long NLL = static_cast<long long>(n) * n;
        // With sorted rows and columns, A[1][1] is the minimum and A[n][n] the maximum.
        if (n == 1) return do_query(1, 1);
        if (k == 1) return do_query(1, 1);
        if (k == NLL) return do_query(n, n);

        // Split the query budget: up to countsBudget staircase counts (each
        // reading up to ~2n cells), a small reserve, a target band width E for
        // phase 4, and whatever remains for sampling.
        const long long kQueryLimit = 50000;
        const int countsBudget = min(30, max(1, 45000 / max(1, 2 * n)));
        const int reserved = 100;
        long long left = kQueryLimit - static_cast<long long>(countsBudget) * 2 * n - reserved;
        if (left < 0) left = 0;
        const int E = static_cast<int>(min(5000LL, max(400LL, left / 3)));
        const long long SBudget = max(0LL, left - E);
        const long long Ssize = min(SBudget, min(6000LL, NLL));

        // Phase 1: sample values.
        const vector<long long> sampleVals = sampleValues(Ssize, k);

        // Phase 2: binary search over sample values.
        // Invariant: cLo < k <= cHi. jLo / jHi are the exact per-row cutoffs
        // whose totals are cLo / cHi.
        vector<int> jLo(n + 1, 0), jHi(n + 1, n);
        long long cLo = 0, cHi = NLL;
        int li = -1;
        int hiIndex = static_cast<int>(sampleVals.size());
        int usedCounts = 0;
        while (usedCounts < countsBudget && hiIndex - li > 1 && cHi - cLo > E) {
            const int midIndex = li + (hiIndex - li) / 2;
            auto [cnt, cutoff] = countLeq(sampleVals[midIndex], jLo, jHi);
            usedCounts++;
            if (cnt >= k) {
                hiIndex = midIndex;
                jHi = std::move(cutoff);
                cHi = cnt;
            } else {
                li = midIndex;
                jLo = std::move(cutoff);
                cLo = cnt;
            }
        }

        // Phase 3: numeric binary search while the band is still too wide.
        long long loVal = li >= 0 ? sampleVals[li]
                                  : (sampleVals.empty() ? do_query(1, 1) : sampleVals.front());
        long long hiVal = hiIndex < static_cast<int>(sampleVals.size())
                              ? sampleVals[hiIndex]
                              : (sampleVals.empty() ? do_query(n, n) : sampleVals.back());
        while (usedCounts < countsBudget && cHi - cLo > E && loVal < hiVal) {
            long long mid = loVal + (hiVal - loVal) / 2;
            if (mid == loVal) mid++;
            if (mid >= hiVal) break;
            auto [cnt, cutoff] = countLeq(mid, jLo, jHi);
            usedCounts++;
            if (cnt >= k) {
                jHi = std::move(cutoff);
                cHi = cnt;
                hiVal = mid;
            } else {
                jLo = std::move(cutoff);
                cLo = cnt;
                loVal = mid;
            }
        }

        // Phase 4: the answer is the (k - cLo)-th smallest cell in the band
        // between the jLo and jHi staircases.
        vector<long long> cand;
        cand.reserve(static_cast<size_t>(cHi - cLo));
        for (int i = 1; i <= n; i++) {
            for (int j = jLo[i] + 1; j <= jHi[i]; j++) cand.push_back(do_query(i, j));
        }
        const long long rank = k - cLo;
        // Defensive fallbacks; unreachable while the invariant above holds.
        if (rank <= 0 || cand.empty()) return loVal;
        if (rank > static_cast<long long>(cand.size())) return hiVal;
        nth_element(cand.begin(), cand.begin() + (rank - 1), cand.end());
        return cand[rank - 1];
    }
};
int main() {
struct TestDef { string name; function<TestCase()> gen; };
vector<TestDef> tests;
tests.push_back({"additive n=100 k=5000", []{ return gen_additive(100, 5000); }});
tests.push_back({"mult n=100 k=5000", []{ return gen_multiplicative(100, 5000); }});
tests.push_back({"additive n=500 k=125000", []{ return gen_additive(500, 125000); }});
tests.push_back({"mult n=500 k=125000", []{ return gen_multiplicative(500, 125000); }});
tests.push_back({"random n=500 k=125000", []{ return gen_random_sorted(500, 125000); }});
tests.push_back({"additive n=2000 k=2000000", []{ return gen_additive(2000, 2000000); }});
tests.push_back({"mult n=2000 k=2000000", []{ return gen_multiplicative(2000, 2000000); }});
tests.push_back({"shifted n=2000 k=2000000", []{ return gen_shifted(2000, 2000000); }});
tests.push_back({"random n=2000 k=2000000", []{ return gen_random_sorted(2000, 2000000); }});
tests.push_back({"mult n=2000 k=100000", []{ return gen_multiplicative(2000, 100000); }});
tests.push_back({"mult n=2000 k=3900000", []{ return gen_multiplicative(2000, 3900000); }});
for (auto& t : tests) {
auto tc = t.gen();
Solver s(tc);
long long result = s.solve();
bool correct = (result == tc.answer);
int used = s.query_count;
double score = !correct ? 0.0 : (used <= tc.n ? 1.0 : (used >= 50000 ? 0.0 : (50000.0 - used) / (50000.0 - tc.n)));
printf("%-45s q=%6d %s score=%.4f\n", t.name.c_str(), used, correct ? "OK" : "WRONG", score);
}
}