Spaces:
Sleeping
Sleeping
Merge pull request #10 from ping98k/codex/add-id-to-each-player
Browse files
main.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
from dotenv import load_dotenv
|
| 2 |
-
load_dotenv("./local.env")
|
| 3 |
import os, json, re, ast, gradio as gr
|
| 4 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 5 |
from tqdm import tqdm
|
|
@@ -30,7 +30,7 @@ class SimpleProgress:
|
|
| 30 |
return f"{self.prefix} {self.count}/{self.total} - ETA {eta}"
|
| 31 |
|
| 32 |
NUM_TOP_PICKS_DEFAULT = int(os.getenv("NUM_TOP_PICKS", 3))
|
| 33 |
-
POOL_SIZE_DEFAULT = int(os.getenv("POOL_SIZE",
|
| 34 |
MAX_WORKERS_DEFAULT = int(os.getenv("MAX_WORKERS", 100))
|
| 35 |
NUM_GENERATIONS_DEFAULT = int(os.getenv("NUM_GENERATIONS", 10))
|
| 36 |
API_BASE_DEFAULT = os.getenv("OPENAI_API_BASE", "")
|
|
@@ -40,9 +40,9 @@ PAIRWISE_FILTER_DEFAULT = os.getenv("ENABLE_PAIRWISE_FILTER", "true").lower() ==
|
|
| 40 |
GENERATE_MODEL_DEFAULT = os.getenv("GENERATE_MODEL", "gpt-4o-mini")
|
| 41 |
SCORE_MODEL_DEFAULT = os.getenv("SCORE_MODEL", "gpt-4o-mini")
|
| 42 |
PAIRWISE_MODEL_DEFAULT = os.getenv("PAIRWISE_MODEL", "gpt-4o-mini")
|
| 43 |
-
GENERATE_TEMPERATURE_DEFAULT = float(os.getenv("GENERATE_TEMPERATURE", "
|
| 44 |
-
SCORE_TEMPERATURE_DEFAULT = float(os.getenv("SCORE_TEMPERATURE", "0.
|
| 45 |
-
PAIRWISE_TEMPERATURE_DEFAULT = float(os.getenv("PAIRWISE_TEMPERATURE", "0.
|
| 46 |
SCORE_WITH_INSTRUCTION_DEFAULT = os.getenv("PASS_INSTRUCTION_TO_SCORE", "true").lower() == "true"
|
| 47 |
PAIRWISE_WITH_INSTRUCTION_DEFAULT = os.getenv("PASS_INSTRUCTION_TO_PAIRWISE", "true").lower() == "true"
|
| 48 |
CRITERIA_DEFAULT = "Factuality,Instruction Following,Precision"
|
|
@@ -131,10 +131,12 @@ def run_tournament(
|
|
| 131 |
f"Total tokens: {prompt_tokens + completion_tokens}"
|
| 132 |
)
|
| 133 |
|
| 134 |
-
def log_completion(prefix: str, text: str):
|
| 135 |
disp = text.replace("\n", " ")
|
| 136 |
if len(disp) > 100:
|
| 137 |
disp = disp[:100] + "…"
|
|
|
|
|
|
|
| 138 |
return log(f"{prefix}{disp}")
|
| 139 |
def log(msg):
|
| 140 |
process_log.append(msg)
|
|
@@ -153,12 +155,15 @@ def run_tournament(
|
|
| 153 |
add_usage(usage)
|
| 154 |
yield from log(f"{len(all_players)} players generated")
|
| 155 |
for i, p in enumerate(all_players, 1):
|
| 156 |
-
yield from log_completion(f"Completion {i}: ", p)
|
| 157 |
def criteria_block():
|
| 158 |
return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))
|
| 159 |
|
| 160 |
if enable_score_filter:
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
| 162 |
text, usage = prompt_score(
|
| 163 |
instruction,
|
| 164 |
criteria_list,
|
|
@@ -172,7 +177,7 @@ def run_tournament(
|
|
| 172 |
return_usage=True,
|
| 173 |
)
|
| 174 |
add_usage(usage)
|
| 175 |
-
score_outputs.append(text)
|
| 176 |
data = _clean_json(text)
|
| 177 |
if "scores" in data and isinstance(data["scores"], list):
|
| 178 |
vals = data["scores"]
|
|
@@ -183,7 +188,7 @@ def run_tournament(
|
|
| 183 |
with ThreadPoolExecutor(max_workers=max_workers) as ex:
|
| 184 |
prog = SimpleProgress(len(all_players), "Scoring")
|
| 185 |
scores = {}
|
| 186 |
-
for p, s in zip(
|
| 187 |
scores[p] = s
|
| 188 |
yield from log(prog.step())
|
| 189 |
hist_fig = plt.figure()
|
|
@@ -191,8 +196,8 @@ def run_tournament(
|
|
| 191 |
yield from log("Histogram generated")
|
| 192 |
top_players = sorted(all_players, key=scores.get, reverse=True)[:pool_size]
|
| 193 |
yield from log(f"Filtered to {len(top_players)} players with best scores")
|
| 194 |
-
for i, txt in enumerate(score_outputs, 1):
|
| 195 |
-
yield from log_completion(f"Score completion {i}: ", txt)
|
| 196 |
else:
|
| 197 |
top_players = all_players
|
| 198 |
if enable_pairwise_filter:
|
|
|
|
| 1 |
from dotenv import load_dotenv
|
| 2 |
+
load_dotenv("./local.env",override=True)
|
| 3 |
import os, json, re, ast, gradio as gr
|
| 4 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 5 |
from tqdm import tqdm
|
|
|
|
| 30 |
return f"{self.prefix} {self.count}/{self.total} - ETA {eta}"
|
| 31 |
|
| 32 |
NUM_TOP_PICKS_DEFAULT = int(os.getenv("NUM_TOP_PICKS", 3))
|
| 33 |
+
POOL_SIZE_DEFAULT = int(os.getenv("POOL_SIZE", 6))
|
| 34 |
MAX_WORKERS_DEFAULT = int(os.getenv("MAX_WORKERS", 100))
|
| 35 |
NUM_GENERATIONS_DEFAULT = int(os.getenv("NUM_GENERATIONS", 10))
|
| 36 |
API_BASE_DEFAULT = os.getenv("OPENAI_API_BASE", "")
|
|
|
|
| 40 |
GENERATE_MODEL_DEFAULT = os.getenv("GENERATE_MODEL", "gpt-4o-mini")
|
| 41 |
SCORE_MODEL_DEFAULT = os.getenv("SCORE_MODEL", "gpt-4o-mini")
|
| 42 |
PAIRWISE_MODEL_DEFAULT = os.getenv("PAIRWISE_MODEL", "gpt-4o-mini")
|
| 43 |
+
GENERATE_TEMPERATURE_DEFAULT = float(os.getenv("GENERATE_TEMPERATURE", "0.9"))
|
| 44 |
+
SCORE_TEMPERATURE_DEFAULT = float(os.getenv("SCORE_TEMPERATURE", "0.6"))
|
| 45 |
+
PAIRWISE_TEMPERATURE_DEFAULT = float(os.getenv("PAIRWISE_TEMPERATURE", "0.6"))
|
| 46 |
SCORE_WITH_INSTRUCTION_DEFAULT = os.getenv("PASS_INSTRUCTION_TO_SCORE", "true").lower() == "true"
|
| 47 |
PAIRWISE_WITH_INSTRUCTION_DEFAULT = os.getenv("PASS_INSTRUCTION_TO_PAIRWISE", "true").lower() == "true"
|
| 48 |
CRITERIA_DEFAULT = "Factuality,Instruction Following,Precision"
|
|
|
|
| 131 |
f"Total tokens: {prompt_tokens + completion_tokens}"
|
| 132 |
)
|
| 133 |
|
| 134 |
+
def log_completion(prefix: str, text: str, player_id: int | None = None):
|
| 135 |
disp = text.replace("\n", " ")
|
| 136 |
if len(disp) > 100:
|
| 137 |
disp = disp[:100] + "…"
|
| 138 |
+
if player_id is not None:
|
| 139 |
+
prefix = f"{prefix}(ID {player_id}) "
|
| 140 |
return log(f"{prefix}{disp}")
|
| 141 |
def log(msg):
|
| 142 |
process_log.append(msg)
|
|
|
|
| 155 |
add_usage(usage)
|
| 156 |
yield from log(f"{len(all_players)} players generated")
|
| 157 |
for i, p in enumerate(all_players, 1):
|
| 158 |
+
yield from log_completion(f"Completion {i}: ", p, i)
|
| 159 |
def criteria_block():
|
| 160 |
return "\n".join(f"{i + 1}) {c}" for i, c in enumerate(criteria_list))
|
| 161 |
|
| 162 |
if enable_score_filter:
|
| 163 |
+
players_with_ids = list(enumerate(all_players, 1))
|
| 164 |
+
|
| 165 |
+
def score(item):
|
| 166 |
+
idx, player = item
|
| 167 |
text, usage = prompt_score(
|
| 168 |
instruction,
|
| 169 |
criteria_list,
|
|
|
|
| 177 |
return_usage=True,
|
| 178 |
)
|
| 179 |
add_usage(usage)
|
| 180 |
+
score_outputs.append((idx, text))
|
| 181 |
data = _clean_json(text)
|
| 182 |
if "scores" in data and isinstance(data["scores"], list):
|
| 183 |
vals = data["scores"]
|
|
|
|
| 188 |
with ThreadPoolExecutor(max_workers=max_workers) as ex:
|
| 189 |
prog = SimpleProgress(len(all_players), "Scoring")
|
| 190 |
scores = {}
|
| 191 |
+
for (idx, p), s in zip(players_with_ids, ex.map(score, players_with_ids)):
|
| 192 |
scores[p] = s
|
| 193 |
yield from log(prog.step())
|
| 194 |
hist_fig = plt.figure()
|
|
|
|
| 196 |
yield from log("Histogram generated")
|
| 197 |
top_players = sorted(all_players, key=scores.get, reverse=True)[:pool_size]
|
| 198 |
yield from log(f"Filtered to {len(top_players)} players with best scores")
|
| 199 |
+
for i, (idx, txt) in enumerate(score_outputs, 1):
|
| 200 |
+
yield from log_completion(f"Score completion {i}: ", txt, idx)
|
| 201 |
else:
|
| 202 |
top_players = all_players
|
| 203 |
if enable_pairwise_filter:
|