narcolepticchicken committed on
Commit
6e766eb
·
verified ·
1 Parent(s): af936b6

Upload jobs/occ_debate_extended_baselines.py

Browse files
Files changed (1) hide show
  1. jobs/occ_debate_extended_baselines.py +238 -0
jobs/occ_debate_extended_baselines.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """OCC Debate — Iso-Round + Random-Drop Baselines. H200 run. 3 seeds × 6 conditions.
3
+ FIXED: run_random_drop now returns 4 values (consistent with run_occ)."""
4
+
5
+ import json, os, sys, time, random, torch
6
+ from pathlib import Path
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+
9
# Runtime configuration; every value can be overridden through the environment.
MODEL = os.environ.get("MODEL", "Qwen/Qwen3-Coder-30B-A3B-Instruct")

# Comma-separated integer seeds. Blank entries (e.g. from a trailing comma)
# are skipped instead of crashing int("").
SEEDS = [
    int(s)
    for s in os.environ.get("SEEDS", "42,123,456").split(",")
    if s.strip()
]
if not SEEDS:
    # The per-seed averages computed later divide by len(SEEDS).
    raise ValueError("SEEDS must contain at least one integer seed")

# Results directory; created eagerly so a bad path fails before the model loads.
OUT = Path(os.environ.get("OUT_DIR", "/app/results"))
OUT.mkdir(parents=True, exist_ok=True)

# Wall-clock reference for the elapsed-time prefix printed by log().
START = time.time()
14
+
15
def log(msg):
    """Print *msg* prefixed with the whole seconds elapsed since START.

    Output is flushed immediately so progress is visible in streamed job logs.
    """
    elapsed = time.time() - START
    print(f"[+{elapsed:5.0f}s] {msg}", flush=True)
17
+
18
# Yes/no questions put to the debaters. "q" is the prompt text and "truth" is
# the label a majority vote is scored against.
# NOTE(review): a few labels look debatable (e.g. the CPython 3.13 GIL and the
# "entropy always increases" items) — confirm before drawing conclusions.
DEBATE_TOPICS = [
    {"q": "Is Python faster than C for numerical computation?", "truth": "no"},
    {"q": "Does water boil at 100°C at all altitudes?", "truth": "no"},
    {"q": "Can quantum computers break RSA-2048 today (2026)?", "truth": "no"},
    {"q": "Is the Earth's core hotter than the surface of the Sun?", "truth": "yes"},
    {"q": "Does the Moon have an atmosphere?", "truth": "no"},
    {"q": "Can sound travel through a vacuum?", "truth": "no"},
    {"q": "Is JavaScript single-threaded by default in browsers?", "truth": "yes"},
    {"q": "Does DNA replication occur in the nucleus of eukaryotic cells?", "truth": "yes"},
    {"q": "Can a protein structure be determined with 100% certainty from X-ray?", "truth": "no"},
    {"q": "Is gradient descent guaranteed to find global min for convex functions?", "truth": "yes"},
    {"q": "Can GPT-4 reliably solve novel math proofs without supervision?", "truth": "no"},
    {"q": "Is P vs NP solved as of 2026?", "truth": "no"},
    {"q": "Do all metals expand when heated?", "truth": "no"},
    {"q": "Is the speed of light constant in all reference frames?", "truth": "yes"},
    {"q": "Can a program determine if an arbitrary program halts?", "truth": "no"},
    {"q": "Is the Earth flat?", "truth": "no"},
    {"q": "Does CO2 make up >1% of Earth's atmosphere?", "truth": "no"},
    {"q": "Can classical computers efficiently simulate quantum?", "truth": "no"},
    {"q": "Is the golden ratio exactly (1+√5)/2?", "truth": "yes"},
    {"q": "Can 1-hidden-layer NN approximate any continuous function?", "truth": "yes"},
    {"q": "Does entropy always increase in isolated systems?", "truth": "yes"},
    {"q": "Is Python's GIL removed in CPython 3.13+?", "truth": "yes"},
    {"q": "Do sharks get cancer?", "truth": "yes"},
    {"q": "Is Antarctica a country?", "truth": "no"},
    {"q": "Can humans survive without gut bacteria?", "truth": "yes"},
    {"q": "Do all birds fly?", "truth": "no"},
    {"q": "Is lightning hotter than the Sun's surface?", "truth": "yes"},
    {"q": "Can finite-tape TM recognize all recursive languages?", "truth": "no"},
    {"q": "Is the Riemann Hypothesis proved as of 2026?", "truth": "no"},
    {"q": "Does gravitational lensing confirm GR?", "truth": "yes"},
]
50
+
51
# Process-wide cache so the checkpoint is loaded at most once.
_model = None
_tok = None


def get_model():
    """Return the shared ``(model, tokenizer)`` pair, loading both on first call.

    The tokenizer's pad token is set to its EOS token, and the model is loaded
    in bfloat16 with ``device_map="auto"``.
    """
    global _model, _tok
    if _model is not None:
        return _model, _tok

    log(f"Loading {MODEL}...")
    _tok = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
    _tok.pad_token = _tok.eos_token
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    log(f"Loaded. Device: {_model.device}")
    return _model, _tok
63
+
64
def generate(prompt, max_tokens=512):
    """Sample a continuation of *prompt* from the shared model.

    The prompt is truncated to 2048 tokens. Sampling uses temperature 0.7 and
    top-p 0.9.

    Returns:
        ``(text, n_new)``: the decoded continuation (special tokens stripped)
        and the number of tokens produced beyond the prompt.
    """
    model, tok = get_model()
    encoded = tok(prompt, return_tensors="pt", truncation=True, max_length=2048)
    encoded = encoded.to(model.device)
    prompt_len = encoded.input_ids.shape[1]

    sampling = dict(
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tok.eos_token_id,
    )
    with torch.no_grad():
        out = model.generate(**encoded, **sampling)

    n_new = out.shape[1] - prompt_len
    continuation = out[0][prompt_len:]
    return tok.decode(continuation, skip_special_tokens=True), n_new
72
+
73
def _alpha_words(s):
    """Split *s* into maximal runs of alphabetic characters."""
    return "".join(ch if ch.isalpha() else " " for ch in s).split()


def position(text):
    """Classify a debater's reply as ``"yes"``, ``"no"`` or ``"unclear"``.

    Checks are tried in order:

    1. The first word of the first line, ignoring leading markdown/quote
       decoration, is exactly "yes" or "no". This covers ``YES:``, ``No,``,
       ``**NO:**`` and a bare ``Yes``.
    2. The first line contains "clearly/definitely yes|no" or
       "answer is yes|no" as whole words, so "answer is not ..." does not
       count as "no".
    3. Whichever of the whole words "yes" / "no" is more frequent in the full
       reply; equal counts (including zero) give ``"unclear"``.
    """
    stripped = text.strip()
    first_line = stripped.split("\n")[0].lower()

    # 1. Leading stance word. Only decoration characters are skipped, so a
    #    reply such as "Yesterday ..." is not mistaken for "yes".
    lead = first_line.lstrip("*_#>\"'`([ \t-")
    head = ""
    for ch in lead:
        if not ch.isalpha():
            break
        head += ch
    if head in ("yes", "no"):
        return head

    # 2. Stance phrases on the first line. Space-padding the normalised line
    #    forces whole-word matches.
    padded = " " + " ".join(_alpha_words(first_line)) + " "
    for stance in ("yes", "no"):
        cues = (f" clearly {stance} ", f" definitely {stance} ", f" answer is {stance} ")
        if any(cue in padded for cue in cues):
            return stance

    # 3. Majority of whole-word mentions over the entire reply.
    words = _alpha_words(stripped.lower())
    n_yes = words.count("yes")
    n_no = words.count("no")
    if n_yes > n_no:
        return "yes"
    if n_no > n_yes:
        return "no"
    return "unclear"
81
+
82
def score_arg(text):
    """Heuristic argument-quality score in ``[0.0, 1.0]``, in steps of 0.25.

    One quarter point is awarded for each of:

    * an explicit stance word ("yes", "no", "true", "false") as a whole word;
    * a reasoning connective ("because", "therefore", "since", "due to");
    * an evidence cue ("evidence", "proven", "known", "research", "study");
    * a length strictly between 20 and 500 whitespace-separated words.
    """
    lowered = text.lower()
    # Whole-word matching for stance words: as substrings, "no" also matches
    # "know", "another", "cannot", ... which made this bonus nearly unconditional.
    words = set("".join(ch if ch.isalpha() else " " for ch in lowered).split())

    score = 0.0
    if words & {"yes", "no", "true", "false"}:
        score += 0.25
    # Substring matching is kept for these cues so that multi-word "due to"
    # and inflected forms ("studying", "researchers") still count.
    if any(cue in lowered for cue in ("because", "therefore", "since", "due to")):
        score += 0.25
    if any(cue in lowered for cue in ("evidence", "proven", "known", "research", "study")):
        score += 0.25
    if 20 < len(text.split()) < 500:
        score += 0.25
    return score
89
+
90
def vote(positions, truth):
    """Majority vote over the debaters' stances.

    ``"unclear"`` entries are ignored. Ties go to the stance that was cast
    first. The previous ``max(set(votes), ...)`` broke ties by set iteration
    order, which varies between interpreter runs under string-hash
    randomisation and so made tied topics irreproducible.

    Args:
        positions: iterable of "yes" / "no" / "unclear" stances.
        truth: the expected stance for the topic.

    Returns:
        ``(is_correct, winner)``; ``(False, "unclear")`` when nobody cast a
        usable vote.
    """
    tally = {}
    for stance in positions:
        if stance != "unclear":
            tally[stance] = tally.get(stance, 0) + 1
    if not tally:
        return False, "unclear"
    # dicts iterate in insertion order and max() keeps the first maximal key,
    # so the earliest-cast stance wins a tie.
    winner = max(tally, key=tally.get)
    return winner == truth, winner
95
+
96
+ # ── CONDITIONS ──
97
+
98
def run_equal_1round():
    """Baseline: every agent answers each topic exactly once.

    Agents 0-2 are prompted as "honest" and agent 3 as "adversarial". Each
    topic is decided by a majority vote over the four stances.

    Returns:
        ``(correct, tokens)``: topics whose vote matched the truth label, and
        the total number of generated tokens.
    """
    n_correct = 0
    n_tokens = 0
    for topic in DEBATE_TOPICS:
        stances = []
        for agent in range(4):
            persona = "adversarial" if agent >= 3 else "honest"
            reply, used = generate(
                f"Question: {topic['q']}\nYou are a {persona} debater. Start your answer with YES: or NO: followed by a brief explanation.\n",
                512,
            )
            stances.append(position(reply))
            n_tokens += used
        is_right, _ = vote(stances, topic["truth"])
        if is_right:
            n_correct += 1
    return n_correct, n_tokens
109
+
110
def run_equal_3round():
    """Iso-round baseline: all four agents answer each topic in three rounds.

    The twelve stances collected per topic (3 rounds x 4 agents) are pooled
    into a single majority vote. Each prompt includes only the round number,
    not the earlier replies.

    Returns:
        ``(correct, tokens)``: topics whose pooled vote matched the truth
        label, and the total number of generated tokens.
    """
    n_correct, n_tokens = 0, 0
    for topic in DEBATE_TOPICS:
        pooled = []
        for round_no in (1, 2, 3):
            for agent in range(4):
                persona = "adversarial" if agent >= 3 else "honest"
                reply, used = generate(
                    f"Question: {topic['q']}\nYou are a {persona} debater. Round {round_no}/3. Start your answer with YES: or NO: followed by a brief explanation.\n",
                    512,
                )
                pooled.append(position(reply))
                n_tokens += used
        is_right, _ = vote(pooled, topic["truth"])
        if is_right:
            n_correct += 1
    return n_correct, n_tokens
123
+
124
def run_random_drop(drop_prob=0.25):
    """Random-silencing baseline: agents are dropped by chance, not by credit.

    For every topic each of the four agents (0-2 "honest", 3 "adversarial") is
    independently silenced with probability *drop_prob*. The remaining agents
    answer once and their stances are majority-voted.

    Args:
        drop_prob: per-agent, per-topic probability of being silenced.

    Returns:
        ``(correct, tokens, denied, topics_ran)``, the same 4-tuple shape as
        ``run_occ``. ``topics_ran`` is always the full topic count because
        this condition never stops early.
    """
    correct, tokens, denied = 0, 0, 0
    for t in DEBATE_TOPICS:
        positions = []
        for a in range(4):
            if random.random() < drop_prob:
                denied += 1
                continue
            persona = "honest" if a < 3 else "adversarial"
            text, ntok = generate(
                f"Question: {t['q']}\nYou are a {persona} debater. Start your answer with YES: or NO: followed by a brief explanation.\n",
                512,
            )
            positions.append(position(text))
            tokens += ntok
        ok, _ = vote(positions, t["truth"])
        if ok:
            correct += 1
    # Derive the topic count from the data instead of hard-coding 30, so the
    # accuracy denominator stays right if DEBATE_TOPICS changes.
    return correct, tokens, denied, len(DEBATE_TOPICS)
139
+
140
def run_occ(pool_size, cost, max_earn=4):
    """Credit-gated debate: an agent may speak only if it can pay *cost*.

    Each of the four agents starts with ``pool_size // 4`` credits, which is
    also its cap. Before every 8th topic (except the first) each agent loses
    one credit, floored at zero. A speaking agent pays *cost* and earns
    ``int(score_arg(reply) * max_earn)`` back. Agents that cannot pay are
    counted as denied for that topic.

    Args:
        pool_size: total credit pool, split evenly across the four agents.
        cost: credits charged per utterance.
        max_earn: credits earned by a reply with a perfect quality score.

    Returns:
        ``(correct, tokens, denied, topics_ran)``. ``topics_ran`` is smaller
        than the topic count only if the run stopped early because no agent
        could afford to speak.
    """
    correct, tokens, denied = 0, 0, 0
    cap = pool_size // 4
    credits = [cap] * 4
    topics_ran = len(DEBATE_TOPICS)

    for i, t in enumerate(DEBATE_TOPICS):
        if i > 0 and i % 8 == 0:
            credits = [max(0, c - 1) for c in credits]

        positions = []
        for a in range(4):
            if credits[a] < cost:
                denied += 1
                continue
            persona = "honest" if a < 3 else "adversarial"
            text, ntok = generate(
                f"Question: {t['q']}\nYou are a {persona} debater. Start your answer with YES: or NO: followed by a brief explanation.\n",
                512,
            )
            tokens += ntok
            earned = int(score_arg(text) * max_earn)
            credits[a] = min(cap, credits[a] - cost + earned)
            positions.append(position(text))

        ok, _ = vote(positions, t["truth"])
        if ok:
            correct += 1

        # Credits are per-agent and only grow when an agent speaks, so once
        # every agent is below `cost` nobody can ever speak again. The old
        # `sum(credits) < cost` check missed that state and kept scoring
        # topics with no speakers.
        if max(credits) < cost:
            topics_ran = i + 1
            break

    return correct, tokens, denied, topics_ran
164
+
165
+ # ── MAIN ──
166
+
167
# (name, runner, has_denied). Runners flagged has_denied return
# (correct, tokens, denied, topics_ran); the others return (correct, tokens)
# and are scored against the full topic list.
CONDITIONS = [
    ("equal_1round", run_equal_1round, False),
    ("equal_3round", run_equal_3round, False),
    ("random_drop", lambda: run_random_drop(0.25), True),
    ("occ_240_5", lambda: run_occ(240, 5), True),
    ("occ_180_3", lambda: run_occ(180, 3), True),
    ("occ_120_3", lambda: run_occ(120, 3), True),
]

all_results = {"model": MODEL, "seeds": {}}

for seed in SEEDS:
    # Seed every RNG that influences sampling or random silencing.
    torch.manual_seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    log(f"\n{'='*60}")
    log(f"SEED {seed}")
    log(f"{'='*60}")

    # Load up front so the load time is not charged to the first condition.
    get_model()

    seed_results = {}
    for name, fn, has_denied in CONDITIONS:
        log(f"--- {name} ---")
        t0 = time.time()
        outcome = fn()
        if has_denied:
            c, tk, d, r = outcome
        else:
            c, tk = outcome
            d = None
            # Use the real topic count rather than a hard-coded 30.
            r = len(DEBATE_TOPICS)
        acc = c / r if r else 0.0

        record = {"correct": c, "total": r, "accuracy": acc, "tokens": tk}
        if has_denied:
            record["denied"] = d
            log(f" {c}/{r} ({acc:.3f}), {tk} tok, {d} denied ({time.time()-t0:.0f}s)")
        else:
            log(f" {c}/{r} ({acc:.3f}), {tk} tok ({time.time()-t0:.0f}s)")
        seed_results[name] = record

    all_results["seeds"][str(seed)] = seed_results

# ── SUMMARY ──
log(f"\n{'='*60}")
log("MULTI-SEED SUMMARY (6 CONDITIONS)")
log(f"{'='*60}")
log(f"{'Condition':<18} {'Mean':>6} {'Min':>6} {'Max':>6} {'Range':>6} {'Tokens':>8}")
log(f"{'-'*18} {'-'*6} {'-'*6} {'-'*6} {'-'*6} {'-'*8}")

# Compute each condition's statistics once and reuse them for both the
# printed table and the saved JSON.
summary = {}
for name, _, _ in CONDITIONS:
    accs = [all_results["seeds"][str(s)][name]["accuracy"] for s in SEEDS]
    toks = [all_results["seeds"][str(s)][name]["tokens"] for s in SEEDS]
    mean = sum(accs) / len(accs)
    mn = min(accs)
    mx = max(accs)
    meantok = sum(toks) / len(toks)
    log(f" {name:<18} {mean:6.3f} {mn:6.3f} {mx:6.3f} {mx-mn:6.3f} {meantok:8.0f}")
    summary[name] = {"mean": mean, "min": mn, "max": mx, "mean_tokens": meantok}

# ── SAVE ──
all_results["summary"] = summary

path = OUT / "debate_extended_baselines_results.json"
path.write_text(json.dumps(all_results, indent=2))
log(f"\nSaved -> {path}")
log(f"Total elapsed: {time.time()-START:.0f}s")