feat: enhance inference and logging capabilities with SWD tracing
Updated .env.example to include new environment variables for master and worker agents, and added support for SWD tracing in inference.py. The SwdTraceWriter class was introduced to log SWD snapshots to a specified file, improving the logging mechanism. Adjusted README.md to reflect changes in API key requirements and SWD tracing options.
- .env.example +48 -6
- .gitignore +1 -0
- README.md +3 -3
- inference.py +98 -12
- server/llm_env.py +88 -0
- server/reward.py +4 -7
- server/worker_client.py +4 -7
.env.example
CHANGED
|
@@ -1,17 +1,59 @@
|
|
| 1 |
-
#
|
|
|
|
|
|
|
| 2 |
HF_TOKEN=
|
| 3 |
OPENAI_API_KEY=
|
| 4 |
API_BASE_URL=https://router.huggingface.co/v1
|
| 5 |
MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
|
| 6 |
|
| 7 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
CORP_TASK_ID=e1_launch_readiness
|
| 9 |
|
| 10 |
-
#
|
|
|
|
|
|
|
|
|
|
| 11 |
CORP_STUB_WORKERS=1
|
| 12 |
-
CORP_DISABLE_LLM_JUDGE=1
|
| 13 |
-
CORP_WORKER_MODEL=Qwen/Qwen2.5-7B-Instruct
|
| 14 |
-
CORP_JUDGE_MODEL=Qwen/Qwen2.5-7B-Instruct
|
| 15 |
|
| 16 |
# Hugging Face Space
|
| 17 |
PORT=7860
|
|
|
|
| 1 |
+
# ---------------------------------------------------------------------------
|
| 2 |
+
# Global fallbacks (OpenAI-compatible). Used when role-specific vars are unset.
|
| 3 |
+
# ---------------------------------------------------------------------------
|
| 4 |
HF_TOKEN=
|
| 5 |
OPENAI_API_KEY=
|
| 6 |
API_BASE_URL=https://router.huggingface.co/v1
|
| 7 |
MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
|
| 8 |
|
| 9 |
+
# ---------------------------------------------------------------------------
|
| 10 |
+
# Master agent (inference.py) — planner that reads/writes the SWD
|
| 11 |
+
# ---------------------------------------------------------------------------
|
| 12 |
+
CORP_MASTER_API_KEY=
|
| 13 |
+
CORP_MASTER_BASE_URL=
|
| 14 |
+
CORP_MASTER_MODEL=
|
| 15 |
+
|
| 16 |
+
# ---------------------------------------------------------------------------
|
| 17 |
+
# Frozen workers (delegate) — optional per-agent key/router/model
|
| 18 |
+
# Naming: CORP_WORKER_<AGENT_ID_UPPER>_API_KEY / _BASE_URL / _MODEL
|
| 19 |
+
# Example for dev_agent -> CORP_WORKER_DEV_AGENT_API_KEY
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
CORP_WORKER_DEFAULT_API_KEY=
|
| 22 |
+
CORP_WORKER_DEFAULT_BASE_URL=
|
| 23 |
+
CORP_WORKER_DEFAULT_MODEL=
|
| 24 |
+
|
| 25 |
+
CORP_WORKER_DEV_AGENT_API_KEY=
|
| 26 |
+
CORP_WORKER_DEV_AGENT_BASE_URL=
|
| 27 |
+
CORP_WORKER_DEV_AGENT_MODEL=
|
| 28 |
+
|
| 29 |
+
CORP_WORKER_HR_AGENT_API_KEY=
|
| 30 |
+
CORP_WORKER_HR_AGENT_BASE_URL=
|
| 31 |
+
CORP_WORKER_HR_AGENT_MODEL=
|
| 32 |
+
|
| 33 |
+
CORP_WORKER_FINANCE_AGENT_API_KEY=
|
| 34 |
+
CORP_WORKER_FINANCE_AGENT_BASE_URL=
|
| 35 |
+
CORP_WORKER_FINANCE_AGENT_MODEL=
|
| 36 |
+
|
| 37 |
+
CORP_WORKER_MODEL=
|
| 38 |
+
|
| 39 |
+
# ---------------------------------------------------------------------------
|
| 40 |
+
# LLM judge (server/reward.py) — separate endpoint from master/workers
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
CORP_JUDGE_API_KEY=
|
| 43 |
+
CORP_JUDGE_BASE_URL=
|
| 44 |
+
CORP_JUDGE_MODEL=Qwen/Qwen2.5-7B-Instruct
|
| 45 |
+
CORP_DISABLE_LLM_JUDGE=1
|
| 46 |
+
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
# Episode + inference logging
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
CORP_TASK_ID=e1_launch_readiness
|
| 51 |
|
| 52 |
+
# Append-only SWD evolution log (separate from console). Use .jsonl for one JSON
|
| 53 |
+
# object per line; any other extension (e.g. .txt) gets human-readable JSON blocks.
|
| 54 |
+
CORP_SWD_TRACE_FILE=logs/swd_trace.jsonl
|
| 55 |
+
|
| 56 |
CORP_STUB_WORKERS=1
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
# Hugging Face Space
|
| 59 |
PORT=7860
|
.gitignore
CHANGED
|
@@ -51,3 +51,4 @@ task.md
|
|
| 51 |
walkthrough.md
|
| 52 |
72b_eval.txt
|
| 53 |
.cursor/plans/corp-env_rewrite_plan_952c3fcd.plan.md
|
|
|
|
|
|
| 51 |
walkthrough.md
|
| 52 |
72b_eval.txt
|
| 53 |
.cursor/plans/corp-env_rewrite_plan_952c3fcd.plan.md
|
| 54 |
+
_test_swd.jsonl
|
README.md
CHANGED
|
@@ -53,11 +53,11 @@ uv run server
|
|
| 53 |
|
| 54 |
## Baseline inference (master agent)
|
| 55 |
|
| 56 |
-
Requires `
|
| 57 |
|
| 58 |
```powershell
|
| 59 |
uv run python inference.py
|
| 60 |
-
uv run python inference.py --tasks e1_launch_readiness --max-steps 25
|
| 61 |
```
|
| 62 |
|
| 63 |
## OpenEnv validation
|
|
@@ -79,7 +79,7 @@ docker run -p 7860:7860 --env-file .env.example corp-env
|
|
| 79 |
|
| 80 |
## Configuration
|
| 81 |
|
| 82 |
-
See [`.env.example`](.env.example) for `CORP_TASK_ID`, `CORP_STUB_WORKERS`,
|
| 83 |
|
| 84 |
## License
|
| 85 |
|
|
|
|
| 53 |
|
| 54 |
## Baseline inference (master agent)
|
| 55 |
|
| 56 |
+
Requires a **master** API key (`CORP_MASTER_API_KEY`, or `HF_TOKEN` / `OPENAI_API_KEY` as fallback). Without it, `inference.py` runs a short **deterministic E1** smoke test using stub workers. Optional **per-worker** and **judge** keys/URLs are in [`.env.example`](.env.example). Set `CORP_SWD_TRACE_FILE` or pass `--swd-trace path.jsonl` to append SWD snapshots to a file separate from console logs.
|
| 57 |
|
| 58 |
```powershell
|
| 59 |
uv run python inference.py
|
| 60 |
+
uv run python inference.py --tasks e1_launch_readiness --max-steps 25 --swd-trace logs/run.jsonl
|
| 61 |
```
|
| 62 |
|
| 63 |
## OpenEnv validation
|
|
|
|
| 79 |
|
| 80 |
## Configuration
|
| 81 |
|
| 82 |
+
See [`.env.example`](.env.example) for master/worker/judge API routing, `CORP_TASK_ID`, `CORP_STUB_WORKERS`, and `CORP_SWD_TRACE_FILE`.
|
| 83 |
|
| 84 |
## License
|
| 85 |
|
inference.py
CHANGED
|
@@ -10,19 +10,22 @@ import os
|
|
| 10 |
import re
|
| 11 |
import textwrap
|
| 12 |
import time
|
| 13 |
-
from
|
|
|
|
|
|
|
| 14 |
|
| 15 |
from dotenv import load_dotenv
|
| 16 |
from openai import OpenAI
|
| 17 |
|
| 18 |
from corp_env.models import CorpAction, CorpObservation
|
| 19 |
from server.environment import CorpEnvironment
|
|
|
|
| 20 |
|
| 21 |
load_dotenv()
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
|
| 27 |
BENCHMARK = "corp-env"
|
| 28 |
MAX_HISTORY_MESSAGES = 40
|
|
@@ -78,6 +81,66 @@ def log_end(task: str, steps: int, score: float, rewards: List[float]) -> None:
|
|
| 78 |
print(f"[END] task={task} steps={steps} score={score:.3f} rewards={rs}", flush=True)
|
| 79 |
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
def extract_json(raw_text: str) -> dict:
|
| 82 |
cleaned = raw_text.strip()
|
| 83 |
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned)
|
|
@@ -146,7 +209,12 @@ def trim_history(messages: list, max_messages: int = MAX_HISTORY_MESSAGES) -> No
|
|
| 146 |
messages.pop(1)
|
| 147 |
|
| 148 |
|
| 149 |
-
def run_episode(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
os.environ["CORP_TASK_ID"] = task_id
|
| 151 |
os.environ.setdefault("CORP_STUB_WORKERS", "1")
|
| 152 |
|
|
@@ -157,6 +225,8 @@ def run_episode(client: OpenAI, task_id: str, max_steps: int) -> tuple[float, in
|
|
| 157 |
|
| 158 |
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
|
| 159 |
obs = env.reset(task_id=task_id)
|
|
|
|
|
|
|
| 160 |
messages = [
|
| 161 |
{"role": "system", "content": SYSTEM_PROMPT},
|
| 162 |
{"role": "user", "content": build_observation_message(0, obs)},
|
|
@@ -208,6 +278,8 @@ def run_episode(client: OpenAI, task_id: str, max_steps: int) -> tuple[float, in
|
|
| 208 |
total += float(obs.reward or 0.0)
|
| 209 |
steps = step
|
| 210 |
log_step(step, alog[:200], float(obs.reward or 0.0), obs.done, obs.error)
|
|
|
|
|
|
|
| 211 |
messages.append({"role": "user", "content": build_observation_message(step, obs)})
|
| 212 |
if obs.done:
|
| 213 |
break
|
|
@@ -216,12 +288,14 @@ def run_episode(client: OpenAI, task_id: str, max_steps: int) -> tuple[float, in
|
|
| 216 |
return total, steps, rewards
|
| 217 |
|
| 218 |
|
| 219 |
-
def deterministic_e1_smoke() -> None:
|
| 220 |
"""Offline smoke: E1 solved with stub workers (no master LLM)."""
|
| 221 |
os.environ["CORP_TASK_ID"] = "e1_launch_readiness"
|
| 222 |
os.environ["CORP_STUB_WORKERS"] = "1"
|
| 223 |
env = CorpEnvironment()
|
| 224 |
obs = env.reset(task_id="e1_launch_readiness")
|
|
|
|
|
|
|
| 225 |
seq = [
|
| 226 |
CorpAction(action_type="delegate", agent_id="dev_agent", payload="Assess launch readiness"),
|
| 227 |
CorpAction(action_type="delegate", agent_id="hr_agent", payload="Staffing sign-off"),
|
|
@@ -241,6 +315,8 @@ def deterministic_e1_smoke() -> None:
|
|
| 241 |
total += r
|
| 242 |
rlist.append(r)
|
| 243 |
log_step(i, act.action_type, r, obs.done, obs.error)
|
|
|
|
|
|
|
| 244 |
log_end("e1_launch_readiness", len(seq), total, rlist)
|
| 245 |
|
| 246 |
|
|
@@ -253,21 +329,31 @@ def main() -> None:
|
|
| 253 |
help="Comma-separated task ids",
|
| 254 |
)
|
| 255 |
parser.add_argument("--max-steps", type=int, default=30, help="Max steps per episode")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
args = parser.parse_args()
|
| 257 |
|
| 258 |
-
|
|
|
|
|
|
|
| 259 |
print(
|
| 260 |
-
"No
|
| 261 |
-
"Set keys to run the LLM master on --tasks.",
|
| 262 |
flush=True,
|
| 263 |
)
|
| 264 |
-
|
|
|
|
| 265 |
return
|
| 266 |
|
| 267 |
-
client = OpenAI(
|
| 268 |
for tid in [t.strip() for t in args.tasks.split(",") if t.strip()]:
|
| 269 |
ms = args.max_steps * 2 if tid == "h1_acquisition_defence" else args.max_steps
|
| 270 |
-
|
|
|
|
| 271 |
|
| 272 |
|
| 273 |
if __name__ == "__main__":
|
|
|
|
| 10 |
import re
|
| 11 |
import textwrap
|
| 12 |
import time
|
| 13 |
+
from datetime import datetime, timezone
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Any, Dict, List, Optional
|
| 16 |
|
| 17 |
from dotenv import load_dotenv
|
| 18 |
from openai import OpenAI
|
| 19 |
|
| 20 |
from corp_env.models import CorpAction, CorpObservation
|
| 21 |
from server.environment import CorpEnvironment
|
| 22 |
+
from server.llm_env import openai_client_kwargs_master
|
| 23 |
|
| 24 |
load_dotenv()
|
| 25 |
|
| 26 |
+
MASTER_KWARGS = openai_client_kwargs_master()
|
| 27 |
+
MASTER_API_KEY = MASTER_KWARGS.get("api_key")
|
| 28 |
+
MODEL_NAME = os.getenv("CORP_MASTER_MODEL") or os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
|
| 29 |
|
| 30 |
BENCHMARK = "corp-env"
|
| 31 |
MAX_HISTORY_MESSAGES = 40
|
|
|
|
| 81 |
print(f"[END] task={task} steps={steps} score={score:.3f} rewards={rs}", flush=True)
|
| 82 |
|
| 83 |
|
| 84 |
+
class SwdTraceWriter:
    """Append SWD snapshots to a dedicated file (not mixed with console logs).

    The output format is selected from the file name: a path ending in
    ``.jsonl`` receives exactly one JSON object per line (the session header
    included, so the file stays parseable line by line); any other path
    receives human-readable text blocks. A ``None`` or blank path turns the
    writer into a no-op.
    """

    def __init__(self, path: Optional[str], task_id: str) -> None:
        """Create parent directories and append a session header.

        Args:
            path: Destination file. ``None`` or whitespace disables tracing.
            task_id: Task identifier recorded in the session header.
        """
        # Normalise blank input to None so ``self.path`` has a single "off" value.
        self.path = (path or "").strip() or None
        self.task_id = task_id
        self._jsonl = bool(self.path and self.path.lower().endswith(".jsonl"))
        if not self.path:
            return

        target = Path(self.path)
        target.parent.mkdir(parents=True, exist_ok=True)
        ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%SZ")

        if self._jsonl:
            # A free-text banner would break line-by-line JSON parsing, so the
            # header is emitted as a regular JSON record instead.
            header = json.dumps(
                {"phase": "trace_start", "task_id": task_id, "started_utc": ts},
                ensure_ascii=False,
            ) + "\n"
        else:
            header = (
                f"\n{'=' * 72}\n"
                f"# CORP-ENV SWD trace | task={task_id} | started_utc={ts}\n"
                f"{'=' * 72}\n"
            )

        with target.open("a", encoding="utf-8") as f:
            f.write(header)

    def write(
        self,
        *,
        phase: str,
        step_index: int,
        action: "Optional[CorpAction]",
        obs: "CorpObservation",
    ) -> None:
        """Append one snapshot of the observation's SWD to the trace file.

        Args:
            phase: Label for the snapshot (callers pass e.g. ``"after_reset"``).
            step_index: Caller-side step counter.
            action: Action that produced ``obs``; ``None`` for the observation
                returned by reset.
            obs: Observation whose ``turn``, ``reward``, ``done``, ``error``
                and ``swd`` attributes are recorded.
        """
        if not self.path:
            return

        action_blob: Dict[str, Any]
        if action is None:
            action_blob = {"note": "initial observation after reset"}
        else:
            action_blob = action.model_dump(mode="json", exclude_none=True)

        if self._jsonl:
            record = {
                "phase": phase,
                "step_index": step_index,
                "env_turn": obs.turn,
                "reward": obs.reward,
                "done": obs.done,
                "error": obs.error,
                "action": action_blob,
                "swd": obs.swd,
            }
            payload = json.dumps(record, ensure_ascii=False) + "\n"
        else:
            parts = [
                f"\n--- {phase} step_index={step_index} env_turn={obs.turn} "
                f"reward={obs.reward} done={obs.done} ---\n"
            ]
            # The JSONL format records the error; keep the text format on par.
            if obs.error:
                parts.append(f"error: {obs.error}\n")
            parts.append(f"action: {json.dumps(action_blob, indent=2, ensure_ascii=False)}\n")
            parts.append(f"swd:\n{json.dumps(obs.swd, indent=2, ensure_ascii=False)}\n")
            payload = "".join(parts)

        with Path(self.path).open("a", encoding="utf-8") as f:
            f.write(payload)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
def extract_json(raw_text: str) -> dict:
|
| 145 |
cleaned = raw_text.strip()
|
| 146 |
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned)
|
|
|
|
| 209 |
messages.pop(1)
|
| 210 |
|
| 211 |
|
| 212 |
+
def run_episode(
|
| 213 |
+
client: OpenAI,
|
| 214 |
+
task_id: str,
|
| 215 |
+
max_steps: int,
|
| 216 |
+
swd_trace: Optional[SwdTraceWriter],
|
| 217 |
+
) -> tuple[float, int, List[float]]:
|
| 218 |
os.environ["CORP_TASK_ID"] = task_id
|
| 219 |
os.environ.setdefault("CORP_STUB_WORKERS", "1")
|
| 220 |
|
|
|
|
| 225 |
|
| 226 |
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
|
| 227 |
obs = env.reset(task_id=task_id)
|
| 228 |
+
if swd_trace:
|
| 229 |
+
swd_trace.write(phase="after_reset", step_index=0, action=None, obs=obs)
|
| 230 |
messages = [
|
| 231 |
{"role": "system", "content": SYSTEM_PROMPT},
|
| 232 |
{"role": "user", "content": build_observation_message(0, obs)},
|
|
|
|
| 278 |
total += float(obs.reward or 0.0)
|
| 279 |
steps = step
|
| 280 |
log_step(step, alog[:200], float(obs.reward or 0.0), obs.done, obs.error)
|
| 281 |
+
if swd_trace:
|
| 282 |
+
swd_trace.write(phase="after_step", step_index=step, action=action, obs=obs)
|
| 283 |
messages.append({"role": "user", "content": build_observation_message(step, obs)})
|
| 284 |
if obs.done:
|
| 285 |
break
|
|
|
|
| 288 |
return total, steps, rewards
|
| 289 |
|
| 290 |
|
| 291 |
+
def deterministic_e1_smoke(swd_trace: Optional[SwdTraceWriter] = None) -> None:
|
| 292 |
"""Offline smoke: E1 solved with stub workers (no master LLM)."""
|
| 293 |
os.environ["CORP_TASK_ID"] = "e1_launch_readiness"
|
| 294 |
os.environ["CORP_STUB_WORKERS"] = "1"
|
| 295 |
env = CorpEnvironment()
|
| 296 |
obs = env.reset(task_id="e1_launch_readiness")
|
| 297 |
+
if swd_trace:
|
| 298 |
+
swd_trace.write(phase="after_reset", step_index=0, action=None, obs=obs)
|
| 299 |
seq = [
|
| 300 |
CorpAction(action_type="delegate", agent_id="dev_agent", payload="Assess launch readiness"),
|
| 301 |
CorpAction(action_type="delegate", agent_id="hr_agent", payload="Staffing sign-off"),
|
|
|
|
| 315 |
total += r
|
| 316 |
rlist.append(r)
|
| 317 |
log_step(i, act.action_type, r, obs.done, obs.error)
|
| 318 |
+
if swd_trace:
|
| 319 |
+
swd_trace.write(phase="after_step", step_index=i, action=act, obs=obs)
|
| 320 |
log_end("e1_launch_readiness", len(seq), total, rlist)
|
| 321 |
|
| 322 |
|
|
|
|
| 329 |
help="Comma-separated task ids",
|
| 330 |
)
|
| 331 |
parser.add_argument("--max-steps", type=int, default=30, help="Max steps per episode")
|
| 332 |
+
parser.add_argument(
|
| 333 |
+
"--swd-trace",
|
| 334 |
+
type=str,
|
| 335 |
+
default=os.getenv("CORP_SWD_TRACE_FILE", ""),
|
| 336 |
+
help="Append SWD evolution to this file (.jsonl recommended). Overrides CORP_SWD_TRACE_FILE.",
|
| 337 |
+
)
|
| 338 |
args = parser.parse_args()
|
| 339 |
|
| 340 |
+
trace_path = (args.swd_trace or "").strip() or None
|
| 341 |
+
|
| 342 |
+
if not MASTER_API_KEY:
|
| 343 |
print(
|
| 344 |
+
"No master API key (set CORP_MASTER_API_KEY or HF_TOKEN / OPENAI_API_KEY) - "
|
| 345 |
+
"running deterministic E1 smoke only. Set keys to run the LLM master on --tasks.",
|
| 346 |
flush=True,
|
| 347 |
)
|
| 348 |
+
tw = SwdTraceWriter(trace_path, "e1_launch_readiness") if trace_path else None
|
| 349 |
+
deterministic_e1_smoke(swd_trace=tw)
|
| 350 |
return
|
| 351 |
|
| 352 |
+
client = OpenAI(**MASTER_KWARGS)
|
| 353 |
for tid in [t.strip() for t in args.tasks.split(",") if t.strip()]:
|
| 354 |
ms = args.max_steps * 2 if tid == "h1_acquisition_defence" else args.max_steps
|
| 355 |
+
tw = SwdTraceWriter(trace_path, tid) if trace_path else None
|
| 356 |
+
run_episode(client, tid, max_steps=ms, swd_trace=tw)
|
| 357 |
|
| 358 |
|
| 359 |
if __name__ == "__main__":
|
server/llm_env.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Resolve OpenAI-compatible API key and base URL per role (master, worker, judge)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from typing import Any, Dict, Optional
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _first(*values: Optional[str]) -> Optional[str]:
    """Return the first value that is non-empty after stripping, else ``None``.

    ``None`` entries and whitespace-only entries are skipped; the returned
    value has surrounding whitespace removed.
    """
    for candidate in values:
        if candidate is None:
            continue
        text = str(candidate).strip()
        if text:
            return text
    return None
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def openai_client_kwargs_master() -> Dict[str, Any]:
    """Build OpenAI client kwargs for the master model (inference loop).

    The API key is the first non-blank of ``CORP_MASTER_API_KEY``,
    ``HF_TOKEN``, ``OPENAI_API_KEY``, ``API_KEY``. The base URL is the first
    non-blank of ``CORP_MASTER_BASE_URL``, ``API_BASE_URL``,
    ``OPENAI_BASE_URL``. Entries that resolve to nothing are left out of the
    returned dict.
    """
    key_vars = ("CORP_MASTER_API_KEY", "HF_TOKEN", "OPENAI_API_KEY", "API_KEY")
    url_vars = ("CORP_MASTER_BASE_URL", "API_BASE_URL", "OPENAI_BASE_URL")
    resolved_key = _first(*[os.getenv(name) for name in key_vars])
    resolved_url = _first(*[os.getenv(name) for name in url_vars])
    return _kwargs(resolved_key, resolved_url)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def openai_client_kwargs_worker(canonical_agent_id: str) -> Dict[str, Any]:
    """Build OpenAI client kwargs for a frozen worker (dev_agent, hr_agent, finance_agent).

    Per-agent variables are derived from the agent id upper-cased with
    hyphens turned into underscores, e.g. ``dev_agent`` ->
    ``CORP_WORKER_DEV_AGENT_API_KEY`` / ``CORP_WORKER_DEV_AGENT_BASE_URL``.
    They take precedence over ``CORP_WORKER_DEFAULT_*``, which in turn take
    precedence over the global variables.
    """
    # NOTE(review): the global key fallback here tries OPENAI_API_KEY before
    # HF_TOKEN, whereas openai_client_kwargs_master tries HF_TOKEN first;
    # confirm the differing order is intentional.
    env_suffix = canonical_agent_id.upper().replace("-", "_")
    key_vars = (
        f"CORP_WORKER_{env_suffix}_API_KEY",
        "CORP_WORKER_DEFAULT_API_KEY",
        "OPENAI_API_KEY",
        "HF_TOKEN",
        "API_KEY",
    )
    url_vars = (
        f"CORP_WORKER_{env_suffix}_BASE_URL",
        "CORP_WORKER_DEFAULT_BASE_URL",
        "API_BASE_URL",
        "OPENAI_BASE_URL",
    )
    resolved_key = _first(*[os.getenv(name) for name in key_vars])
    resolved_url = _first(*[os.getenv(name) for name in url_vars])
    return _kwargs(resolved_key, resolved_url)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def openai_client_kwargs_judge() -> Dict[str, Any]:
    """Build OpenAI client kwargs for the optional LLM judge (reward).

    The API key is the first non-blank of ``CORP_JUDGE_API_KEY``,
    ``OPENAI_API_KEY``, ``HF_TOKEN``, ``API_KEY``. The base URL is the first
    non-blank of ``CORP_JUDGE_BASE_URL``, ``API_BASE_URL``,
    ``OPENAI_BASE_URL``. Entries that resolve to nothing are left out of the
    returned dict.
    """
    key_vars = ("CORP_JUDGE_API_KEY", "OPENAI_API_KEY", "HF_TOKEN", "API_KEY")
    url_vars = ("CORP_JUDGE_BASE_URL", "API_BASE_URL", "OPENAI_BASE_URL")
    resolved_key = _first(*[os.getenv(name) for name in key_vars])
    resolved_url = _first(*[os.getenv(name) for name in url_vars])
    return _kwargs(resolved_key, resolved_url)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _kwargs(api_key: Optional[str], base_url: Optional[str]) -> Dict[str, Any]:
    """Return a kwargs dict holding only the truthy ``api_key`` / ``base_url``.

    Falsy values (``None`` or empty string) are dropped so the key is absent
    from the result rather than present with an empty value.
    """
    pairs = (("api_key", api_key), ("base_url", base_url))
    return {name: value for name, value in pairs if value}
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def worker_model_for(canonical_agent_id: str) -> str:
    """Resolve the model name for a worker agent.

    Lookup order: ``CORP_WORKER_<AGENT>_MODEL`` (agent id upper-cased,
    hyphens as underscores), ``CORP_WORKER_DEFAULT_MODEL``,
    ``CORP_WORKER_MODEL``, ``MODEL_NAME``; falls back to
    ``Qwen/Qwen2.5-7B-Instruct`` when none is set to a non-blank value.
    """
    env_suffix = canonical_agent_id.upper().replace("-", "_")
    env_names = (
        f"CORP_WORKER_{env_suffix}_MODEL",
        "CORP_WORKER_DEFAULT_MODEL",
        "CORP_WORKER_MODEL",
        "MODEL_NAME",
    )
    chosen = _first(*[os.getenv(name) for name in env_names])
    if chosen:
        return chosen
    return "Qwen/Qwen2.5-7B-Instruct"
|
server/reward.py
CHANGED
|
@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Optional
|
|
| 9 |
|
| 10 |
from openai import OpenAI
|
| 11 |
|
|
|
|
| 12 |
from server.swd import (
|
| 13 |
REQUIRED_TOP_LEVEL,
|
| 14 |
VALID_PHASES,
|
|
@@ -51,16 +52,15 @@ def compute_swd_coherence(swd: Dict[str, Any]) -> float:
|
|
| 51 |
def call_llm_judge(swd: Dict[str, Any], task_goal: str) -> float:
|
| 52 |
"""
|
| 53 |
Fast LLM judge (optional). Returns score in [0, 1] from YES count / 3.
|
| 54 |
-
|
| 55 |
"""
|
| 56 |
if os.getenv("CORP_DISABLE_LLM_JUDGE", "").lower() in ("1", "true", "yes"):
|
| 57 |
return 0.0
|
| 58 |
|
| 59 |
-
|
| 60 |
-
if not
|
| 61 |
return 0.0
|
| 62 |
|
| 63 |
-
base_url = os.getenv("API_BASE_URL") or os.getenv("OPENAI_BASE_URL")
|
| 64 |
model = os.getenv("CORP_JUDGE_MODEL", "Qwen/Qwen2.5-7B-Instruct")
|
| 65 |
|
| 66 |
prompt = f"""
|
|
@@ -82,9 +82,6 @@ Q1: YES/NO
|
|
| 82 |
Q2: YES/NO
|
| 83 |
Q3: YES/NO
|
| 84 |
"""
|
| 85 |
-
kwargs: Dict[str, Any] = {"api_key": api_key}
|
| 86 |
-
if base_url:
|
| 87 |
-
kwargs["base_url"] = base_url
|
| 88 |
client = OpenAI(**kwargs)
|
| 89 |
resp = client.chat.completions.create(
|
| 90 |
model=model,
|
|
|
|
| 9 |
|
| 10 |
from openai import OpenAI
|
| 11 |
|
| 12 |
+
from server.llm_env import openai_client_kwargs_judge
|
| 13 |
from server.swd import (
|
| 14 |
REQUIRED_TOP_LEVEL,
|
| 15 |
VALID_PHASES,
|
|
|
|
| 52 |
def call_llm_judge(swd: Dict[str, Any], task_goal: str) -> float:
|
| 53 |
"""
|
| 54 |
Fast LLM judge (optional). Returns score in [0, 1] from YES count / 3.
|
| 55 |
+
Uses CORP_JUDGE_* then global API keys (see server/llm_env.py). No call without a key.
|
| 56 |
"""
|
| 57 |
if os.getenv("CORP_DISABLE_LLM_JUDGE", "").lower() in ("1", "true", "yes"):
|
| 58 |
return 0.0
|
| 59 |
|
| 60 |
+
kwargs = openai_client_kwargs_judge()
|
| 61 |
+
if not kwargs.get("api_key"):
|
| 62 |
return 0.0
|
| 63 |
|
|
|
|
| 64 |
model = os.getenv("CORP_JUDGE_MODEL", "Qwen/Qwen2.5-7B-Instruct")
|
| 65 |
|
| 66 |
prompt = f"""
|
|
|
|
| 82 |
Q2: YES/NO
|
| 83 |
Q3: YES/NO
|
| 84 |
"""
|
|
|
|
|
|
|
|
|
|
| 85 |
client = OpenAI(**kwargs)
|
| 86 |
resp = client.chat.completions.create(
|
| 87 |
model=model,
|
server/worker_client.py
CHANGED
|
@@ -8,6 +8,7 @@ from typing import Optional
|
|
| 8 |
from openai import OpenAI
|
| 9 |
|
| 10 |
from server.agents.prompts import WORKER_PROMPTS
|
|
|
|
| 11 |
|
| 12 |
STUB_OUTPUTS = {
|
| 13 |
"dev_agent": (
|
|
@@ -44,19 +45,15 @@ def call_worker_model(
|
|
| 44 |
if os.getenv("CORP_STUB_WORKERS", "").lower() in ("1", "true", "yes"):
|
| 45 |
return call_model_stub(canonical_agent_id, task_description)
|
| 46 |
|
| 47 |
-
|
| 48 |
-
if not api_key:
|
| 49 |
return call_model_stub(canonical_agent_id, task_description)
|
| 50 |
|
| 51 |
-
|
| 52 |
-
model = os.getenv("CORP_WORKER_MODEL") or os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-7B-Instruct"
|
| 53 |
system = WORKER_PROMPTS.get(
|
| 54 |
canonical_agent_id,
|
| 55 |
"You are a concise corporate advisor. Plain prose only.",
|
| 56 |
)
|
| 57 |
-
kwargs = {"api_key": api_key}
|
| 58 |
-
if base_url:
|
| 59 |
-
kwargs["base_url"] = base_url
|
| 60 |
client = OpenAI(**kwargs)
|
| 61 |
resp = client.chat.completions.create(
|
| 62 |
model=model,
|
|
|
|
| 8 |
from openai import OpenAI
|
| 9 |
|
| 10 |
from server.agents.prompts import WORKER_PROMPTS
|
| 11 |
+
from server.llm_env import openai_client_kwargs_worker, worker_model_for
|
| 12 |
|
| 13 |
STUB_OUTPUTS = {
|
| 14 |
"dev_agent": (
|
|
|
|
| 45 |
if os.getenv("CORP_STUB_WORKERS", "").lower() in ("1", "true", "yes"):
|
| 46 |
return call_model_stub(canonical_agent_id, task_description)
|
| 47 |
|
| 48 |
+
kwargs = openai_client_kwargs_worker(canonical_agent_id)
|
| 49 |
+
if not kwargs.get("api_key"):
|
| 50 |
return call_model_stub(canonical_agent_id, task_description)
|
| 51 |
|
| 52 |
+
model = worker_model_for(canonical_agent_id)
|
|
|
|
| 53 |
system = WORKER_PROMPTS.get(
|
| 54 |
canonical_agent_id,
|
| 55 |
"You are a concise corporate advisor. Plain prose only.",
|
| 56 |
)
|
|
|
|
|
|
|
|
|
|
| 57 |
client = OpenAI(**kwargs)
|
| 58 |
resp = client.chat.completions.create(
|
| 59 |
model=model,
|