| from __future__ import annotations
|
|
|
| import json |
| import os |
| import signal |
| import socket |
| import subprocess |
| import sys |
| import threading |
| import time |
| from http.server import BaseHTTPRequestHandler, HTTPServer |
| from pathlib import Path |
|
|
| import httpx
|
|
|
# Absolute directory containing this script; the uvicorn app and
# inference.py are both launched with this as their working directory.
_SCRIPT_PATH = Path(__file__).resolve()
PROJECT_ROOT = _SCRIPT_PATH.parent

# Task names played end-to-end by main(), in order.
TASKS = ["refund_triage_easy", "cross_function_brief_medium", "executive_escalation_hard"]
|
|
|
|
def print_check(name: str, passed: bool, detail: str = "") -> None:
    """Print one ``PASS: <name>`` / ``FAIL: <name>`` result line.

    Args:
        name: Human-readable name of the check.
        passed: Truthy for a passing check, falsy for a failing one.
        detail: Optional extra context, appended as `` - <detail>`` when non-empty.
    """
    label = "FAIL"
    if passed:
        label = "PASS"
    line = f"{label}: {name}"
    if detail:
        line += f" - {detail}"
    print(line)
|
|
|
|
|
def find_free_port() -> int:
    """Return a TCP port on 127.0.0.1 that the OS reported as free.

    The probe socket is closed before returning, so another process could
    in principle grab the port before the caller binds it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Port 0 asks the OS to pick any currently unused port.
        probe.bind(("127.0.0.1", 0))
        probe.listen(1)
        _host, port = probe.getsockname()
        return int(port)
    finally:
        probe.close()
|
|
|
|
|
def wait_for_server(base_url: str, timeout: float = 20.0, poll_interval: float = 0.5) -> bool:
    """Poll ``{base_url}/health`` until it answers HTTP 200 or *timeout* elapses.

    Args:
        base_url: Root URL of the server, without a trailing slash.
        timeout: Total seconds to keep polling before giving up.
        poll_interval: Seconds to sleep between unsuccessful attempts.

    Returns:
        True as soon as ``/health`` returns 200, False if the deadline passes first.
    """
    # monotonic() is immune to wall-clock jumps while we wait.
    deadline = time.monotonic() + timeout
    with httpx.Client(timeout=2.0) as client:
        while time.monotonic() < deadline:
            try:
                response = client.get(f"{base_url}/health")
            except httpx.HTTPError:
                # Connection refused / timed out while the server is still booting;
                # retry after the sleep below.
                pass
            else:
                if response.status_code == 200:
                    return True
            # Back off after *every* failed attempt, including non-200 replies, so a
            # server answering e.g. 503 during startup is not hammered in a tight loop.
            time.sleep(poll_interval)
    return False
|
|
|
|
|
def greedy_action(observation: dict) -> dict:
    """Choose the next action with a deterministic keyword-overlap heuristic.

    Decision order:

    1. If something is already selected and either ``step_number >= 3`` or at
       least 70% of ``token_budget`` is used: set a resolution plan when no
       ``plan_draft`` exists yet, otherwise submit the report.
    2. If a selected chunk has more than ``max(120, token_budget // 3)``
       tokens, summarize the largest such chunk at a 0.5 compression ratio.
    3. Otherwise inspect the unselected chunk whose keywords best match the
       query (Jaccard similarity of lower-cased words), preferring fewer
       tokens and then the smaller ``chunk_id`` on ties.
    4. If no unselected chunk remains, submit the report.

    Args:
        observation: Observation dict. Always reads ``query`` and
            ``available_chunks`` (each with ``chunk_id``, ``keywords``,
            ``tokens``); optionally ``selected_chunks`` and ``plan_draft``;
            ``step_number``, ``token_budget`` and ``total_tokens_used`` are
            read once something is selected.

    Returns:
        An action payload suitable for ``POST /step``.
    """
    submit_report = {
        "action_type": "submit_report",
        "answer": "A concise grounded incident operations brief using the prioritized artifacts.",
    }
    query_terms = set(observation["query"].lower().split())
    selected = set(observation.get("selected_chunks", []))
    chunks = observation["available_chunks"]

    def relevance_key(chunk: dict) -> tuple[float, int, str]:
        # Negated Jaccard score so min() picks the best match first, then the
        # cheaper chunk, then the lexicographically smaller id.
        keyword_terms = set(" ".join(chunk["keywords"]).lower().split())
        union = query_terms | keyword_terms
        score = len(query_terms & keyword_terms) / len(union) if union else 0.0
        return (-score, chunk["tokens"], chunk["chunk_id"])

    if selected:
        token_budget = observation["token_budget"]
        # Stop gathering evidence after three steps or once 70% of the budget is spent.
        if observation["step_number"] >= 3 or observation["total_tokens_used"] >= int(token_budget * 0.7):
            if not observation.get("plan_draft"):
                return {
                    "action_type": "set_resolution_plan",
                    "plan": "Verify evidence, protect customers, and publish only grounded actions.",
                }
            return submit_report

        selected_chunks = [chunk for chunk in chunks if chunk["chunk_id"] in selected]
        if selected_chunks:
            # Largest selected chunk; ties broken by the smaller chunk_id.
            largest = min(selected_chunks, key=lambda chunk: (-chunk["tokens"], chunk["chunk_id"]))
            if largest["tokens"] > max(120, token_budget // 3):
                return {
                    "action_type": "summarize_artifact",
                    "artifact_id": largest["chunk_id"],
                    "compression_ratio": 0.5,
                }

    unselected = [chunk for chunk in chunks if chunk["chunk_id"] not in selected]
    if unselected:
        best = min(unselected, key=relevance_key)
        return {"action_type": "inspect_artifact", "artifact_id": best["chunk_id"]}
    return submit_report
|
|
|
|
def planner_action(client: httpx.Client, base_url: str, fallback_observation: dict) -> dict:
    """Ask the server's ``/optimize-step`` planner for the next action.

    Best-effort: any transport error, non-200 status, undecodable body, or a
    JSON value that is not an object falls back to the local greedy heuristic.

    Args:
        client: Open HTTP client used for the request.
        base_url: Root URL of the environment server.
        fallback_observation: Observation passed to ``greedy_action`` when the
            planner cannot be used.

    Returns:
        An action dict for ``POST /step``.
    """
    try:
        response = client.post(f"{base_url}/optimize-step")
    except httpx.HTTPError:
        # Planner unreachable or timed out; use the local heuristic instead.
        return greedy_action(fallback_observation)

    if response.status_code == 200:
        try:
            action = response.json()
        except ValueError:
            # Body was not valid JSON.
            action = None
        # Only forward a JSON object; anything else would be an invalid /step payload.
        if isinstance(action, dict):
            return action

    return greedy_action(fallback_observation)
|
|
|
|
|
def run_task(
    client: httpx.Client,
    base_url: str,
    task_name: str,
    max_steps: int = 100,
) -> tuple[bool, float]:
    """Play *task_name* to completion via ``/reset`` and ``/step``.

    Actions come from ``planner_action``. A failed reset or step, or an
    episode that is still not ``done`` after *max_steps* steps, is reported
    with ``print_check`` and counts as a failure.

    Args:
        client: Open HTTP client used for all requests.
        base_url: Root URL of the environment server.
        task_name: Task to pass to ``/reset``.
        max_steps: Upper bound on ``/step`` calls, so a server that never
            reports ``done`` cannot hang the validator.

    Returns:
        ``(passed, final_score)``. ``passed`` is True when the last step's
        reward lies within ``[0.0, 1.0]``. ``final_score`` is that reward, or
        ``0.0`` on failure.
    """
    reset = client.post(f"{base_url}/reset", json={"task_name": task_name})
    if reset.status_code != 200:
        print_check(f"reset {task_name}", False, reset.text)
        return False, 0.0

    observation = reset.json()["observation"]
    final_score = 0.0
    for _ in range(max_steps):
        action = planner_action(client, base_url, observation)
        step = client.post(f"{base_url}/step", json=action)
        if step.status_code != 200:
            print_check(f"step {task_name}", False, step.text)
            return False, 0.0
        body = step.json()
        observation = body["observation"]
        final_score = float(body["reward"])
        if body["done"]:
            break
    else:
        # Loop exhausted without the episode finishing.
        print_check(f"task {task_name} termination", False, f"not done after {max_steps} steps")
        return False, 0.0

    in_range = 0.0 <= final_score <= 1.0
    print_check(f"task {task_name} score range", in_range, f"score={final_score:.4f}")
    return in_range, final_score
|
|
|
|
|
def run_inference_script(base_url: str) -> bool:
    """Run ``inference.py`` against the app, routing its LLM calls to a local mock.

    A throwaway OpenAI-compatible proxy is started on a loopback port. Every
    POST it receives is recorded and answered with a canned chat completion
    whose content is a ``submit_report`` action. ``inference.py`` is then run
    from PROJECT_ROOT with this environment:

    - ``RAG_ENV_URL`` set to *base_url*;
    - ``API_BASE_URL`` pointing at the proxy;
    - ``API_KEY`` set;
    - a decoy ``HF_TOKEN``;
    - ``ALLOW_BASELINE_FALLBACK`` removed.

    Args:
        base_url: Root URL of the running environment server.

    Returns:
        True only if all of the following hold:

        - the script exits with status 0 within 120 seconds;
        - it prints ``[START]``;
        - it prints an ``[END]`` line containing `` score=``;
        - at least one request hit ``/v1/chat/completions`` on the proxy;
        - at least one request carried ``Bearer <API_KEY>`` (not the decoy
          ``HF_TOKEN``).

        A timeout yields False instead of propagating ``TimeoutExpired``.
    """
    requests_seen: list[dict[str, str | None]] = []

    class ProxyHandler(BaseHTTPRequestHandler):
        """Records each POST and replies with a fixed chat-completion payload."""

        def do_POST(self):
            length = int(self.headers.get("Content-Length", "0"))
            body = self.rfile.read(length).decode("utf-8")
            requests_seen.append(
                {
                    "path": self.path,
                    "authorization": self.headers.get("Authorization"),
                    "body": body,
                }
            )
            payload = {
                "id": "chatcmpl-validate",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": "validator-proxy",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": json.dumps(
                                {
                                    "action_type": "submit_report",
                                    "answer": "Validated via proxy [support_003]",
                                }
                            ),
                        },
                        "finish_reason": "stop",
                    }
                ],
            }
            encoded = json.dumps(payload).encode("utf-8")
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(encoded)))
            self.end_headers()
            self.wfile.write(encoded)

        def log_message(self, format: str, *args):
            # Suppress the default per-request log line on stderr.
            return

    # Binding port 0 lets the OS assign a free port atomically, avoiding the
    # race between probing for a free port and binding it.
    proxy_server = HTTPServer(("127.0.0.1", 0), ProxyHandler)
    proxy_port = proxy_server.server_address[1]
    proxy_thread = threading.Thread(target=proxy_server.serve_forever, daemon=True)
    proxy_thread.start()

    try:
        env = os.environ.copy()
        env["RAG_ENV_URL"] = base_url
        env.pop("ALLOW_BASELINE_FALLBACK", None)
        env["API_BASE_URL"] = f"http://127.0.0.1:{proxy_port}/v1"
        env["API_KEY"] = "offline-validation-token"
        # Decoy: the script must authenticate with API_KEY, not this legacy token.
        env["HF_TOKEN"] = "legacy-should-not-win"
        try:
            process = subprocess.run(
                [sys.executable, "inference.py"],
                cwd=PROJECT_ROOT,
                capture_output=True,
                text=True,
                timeout=120,
                env=env,
            )
        except subprocess.TimeoutExpired:
            # A hung script is a failed check, not a crash of the whole validator.
            return False

        stdout = process.stdout or ""
        end_lines = [line for line in stdout.splitlines() if "[END]" in line]
        has_start = "[START]" in stdout
        has_end = bool(end_lines)
        end_has_score = any(" score=" in line for line in end_lines)
        proxy_called = any(request["path"] == "/v1/chat/completions" for request in requests_seen)
        auth_ok = any(request["authorization"] == "Bearer offline-validation-token" for request in requests_seen)
        return process.returncode == 0 and has_start and has_end and end_has_score and proxy_called and auth_ok
    finally:
        proxy_server.shutdown()
        proxy_server.server_close()
|
|
|
|
def _json_body(response: httpx.Response) -> dict:
    """Return the JSON object in a 200 *response*, or ``{}`` otherwise.

    Non-200 statuses, bodies that are not valid JSON, and JSON values that
    are not objects all yield ``{}``. Callers can then report a failed check
    instead of crashing the validator with a traceback.
    """
    if response.status_code != 200:
        return {}
    try:
        body = response.json()
    except ValueError:
        return {}
    return body if isinstance(body, dict) else {}


def _stop_server(process: subprocess.Popen) -> None:
    """Terminate *process*, escalating to a hard kill if it is still alive after 5 s."""
    if process.poll() is not None:
        return
    process.terminate()
    try:
        process.wait(timeout=5)
    except subprocess.TimeoutExpired:
        # Popen.kill() sends SIGKILL on POSIX and calls TerminateProcess on
        # Windows, so no per-platform branch is needed.
        process.kill()
        # Reap the killed child so it does not linger as a zombie.
        process.wait()


def main() -> int:
    """Launch the app under uvicorn, run every validation check, and report.

    Starts ``uvicorn app:app`` on a free loopback port from PROJECT_ROOT and
    waits for ``/health``. It then exercises ``/health``, ``/reset``,
    ``/step``, ``/state`` and ``/optimize-prompt``, runs ``inference.py``,
    and plays every task in TASKS. The server process is always stopped
    before returning.

    Returns:
        Process exit code: 0 if every check passed, 1 otherwise.
    """
    port = find_free_port()
    base_url = f"http://127.0.0.1:{port}"
    command = [sys.executable, "-m", "uvicorn", "app:app", "--host", "127.0.0.1", "--port", str(port)]
    process = subprocess.Popen(command, cwd=PROJECT_ROOT)

    try:
        if not wait_for_server(base_url):
            print_check("server startup", False, "Timed out waiting for /health")
            return 1
        print_check("server startup", True)

        all_passed = True
        with httpx.Client(timeout=10.0) as client:
            health_ok = _json_body(client.get(f"{base_url}/health")).get("status") == "ok"
            print_check("GET /health", health_ok)
            all_passed &= health_ok

            reset_body = _json_body(client.post(f"{base_url}/reset", json={"task_name": "refund_triage_easy"}))
            reset_ok = "observation" in reset_body
            print_check("POST /reset", reset_ok)
            all_passed &= reset_ok

            initial_observation = reset_body.get("observation")
            if not isinstance(initial_observation, dict):
                initial_observation = {}
            # Inspect the first chunk that has an id; without one, submit a report
            # so /step is still exercised.
            first_chunk_id = next(
                (
                    chunk["chunk_id"]
                    for chunk in initial_observation.get("available_chunks", [])
                    if chunk.get("chunk_id")
                ),
                None,
            )
            if first_chunk_id:
                step_payload = {"action_type": "inspect_artifact", "artifact_id": first_chunk_id}
            else:
                step_payload = {
                    "action_type": "submit_report",
                    "answer": "No chunk available for validation.",
                }
            step = client.post(f"{base_url}/step", json=step_payload)
            step_ok = isinstance(_json_body(step).get("reward"), float)
            print_check("POST /step", step_ok)
            all_passed &= step_ok

            state_ok = "selected_chunks" in _json_body(client.get(f"{base_url}/state"))
            print_check("GET /state", state_ok)
            all_passed &= state_ok

            optimize_prompt = client.post(
                f"{base_url}/optimize-prompt",
                json={
                    "prompt": "Draft a customer-safe admin compromise update with rollback safeguards and cite evidence.",
                    "corpus_family": "enterprise_v2",
                    "compression_mode": "grounded",
                },
            )
            optimize_body = _json_body(optimize_prompt)
            grounding = optimize_body.get("grounding")
            optimize_ok = (
                "optimized_prompt" in optimize_body
                and "context_tuning" in optimize_body
                and isinstance(grounding, dict)
                and optimize_body.get("optimization_mode") == "grounded"
                and bool(grounding.get("citation_ready"))
            )
            print_check("POST /optimize-prompt", optimize_ok)
            all_passed &= optimize_ok

            inference_ok = run_inference_script(base_url)
            print_check("python inference.py", inference_ok)
            all_passed &= inference_ok

            for task_name in TASKS:
                passed, _ = run_task(client, base_url, task_name)
                all_passed &= passed

        return 0 if all_passed else 1
    finally:
        _stop_server(process)
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s 0/1 result as the process exit status.
    sys.exit(main())
|
|
|