slipstream-webgpu / src /agent_forecaster.py
ashaibani's picture
Slipstream WebGPU (in-browser agent)
c658ad5 verified
"""Agentic forecasting: a strict tool-calling loop that orchestrates TimesFM 2.5 + EVM tools.
Uses the OpenAI-compatible Fireworks API (Kimi K2.6 Turbo by default) - so the captured traces
are ALREADY in the messages+tool_calls shape used to SFT MiniCPM5-1B (no lossy conversion).
STRICT tool-calling: the result comes ONLY from a `finish` tool call. No text/JSON fallbacks -
a model that cannot reliably call tools is out of scope. If the model replies without a tool
call it is nudged to use one. With return_trace=True the full OpenAI `messages` list is returned
for distillation.
Auth: FIREWORKS_API_KEY (env, or a gitignored .env in the repo root).
"""
from __future__ import annotations
import json
import os
import numpy as np
from . import evm, forecasting
# Fireworks serves an OpenAI-compatible REST surface, so the stock `openai` SDK is reused as-is.
BASE_URL: str = "https://api.fireworks.ai/inference/v1"
# Default model route used by agent_forecast (the Kimi K2.6 Turbo router on Fireworks).
MODEL: str = "accounts/fireworks/routers/kimi-k2p6-turbo"
# Completion budget sent as `max_tokens` on every chat.completions request.
MAX_TOKENS: int = 32_000
def _load_env():
if os.environ.get("FIREWORKS_API_KEY"):
return
p = os.path.join(os.getcwd(), ".env")
if os.path.exists(p):
for line in open(p):
line = line.strip()
if line and not line.startswith("#") and "=" in line:
key, _, val = line.partition("=")
os.environ.setdefault(key.strip(), val.strip().strip('"').strip("'"))
def make_client(base_url=BASE_URL):
    """Build an `openai.OpenAI` client aimed at the Fireworks OpenAI-compatible endpoint.

    Loads FIREWORKS_API_KEY via _load_env() first. The `openai` package is imported lazily,
    so this module can be imported without it installed. The client uses a 900 s request
    timeout and up to 8 SDK-level retries, so rate-limit/overload responses are retried
    with backoff.

    Raises:
        RuntimeError: if FIREWORKS_API_KEY is still unset after reading the environment/.env.
    """
    _load_env()
    from openai import OpenAI

    api_key = os.environ.get("FIREWORKS_API_KEY")
    if api_key:
        return OpenAI(base_url=base_url, api_key=api_key, timeout=900, max_retries=8)
    raise RuntimeError("FIREWORKS_API_KEY not set (export it or put it in a .env file)")
def _function_tool(name, description, properties=None, required=None):
    """Wrap one function schema in the OpenAI ``{"type": "function", ...}`` tool envelope.

    Key order (name, description, parameters{type, properties, required}) is kept stable
    because these specs are rendered verbatim into traces / the chat template.
    """
    return {"type": "function", "function": {
        "name": name,
        "description": description,
        "parameters": {
            "type": "object",
            "properties": {} if properties is None else properties,
            "required": [] if required is None else required,
        },
    }}


# Tool specs in OpenAI function-calling format; the same objects are used verbatim when the
# traces are rendered for MiniCPM5 SFT via apply_chat_template, so their shape must not drift.
TOOLS = [
    _function_tool(
        "forecast_series",
        ("Forecast the continuation of the project's cumulative EV or AC with the TimesFM "
         "2.5 time-series model. Returns P10/P50/P90 cumulative paths over `horizon` future "
         "periods, and for EV the period each path reaches BAC (completion). Use a horizon "
         "large enough for the path to reach BAC."),
        {
            "which": {"type": "string", "enum": ["ev", "ac"],
                      "description": "which cumulative series to forecast"},
            "horizon": {"type": "integer",
                        "description": "number of future periods to forecast"},
        },
        ["which", "horizon"],
    ),
    _function_tool(
        "evm_metrics",
        ("Classical Earned Value Management metrics from the observed data: SPI, CPI, the "
         "three standard EAC formulas, and an earned-schedule finish-period estimate."),
    ),
    _function_tool(
        "finish",
        ("Submit your FINAL forecast and END your turn. Call once you have the TimesFM "
         "forecast AND the EVM metrics and have reconciled them. Earned value cannot exceed "
         "BAC; if TimesFM under-projects on a stalled/intermittent series, weigh the EVM EAC "
         "formulas (e.g. BAC/CPI) instead."),
        {
            "finish_period": {"type": "number",
                              "description": "projected completion period (where cumulative EV reaches BAC)"},
            "eac": {"type": "number",
                    "description": "estimate at completion = projected final total cost (currency)"},
            "p_overrun": {"type": "number",
                          "description": "probability between 0 and 1 that final cost exceeds 110% of BAC"},
            "schedule_reasoning": {"type": "string",
                                   "description": "1-2 sentences: schedule outlook and the evidence for it"},
            "cost_reasoning": {"type": "string",
                               "description": "1-2 sentences: cost outlook and the evidence for it"},
        },
        ["finish_period", "eac", "p_overrun", "schedule_reasoning", "cost_reasoning"],
    ),
]
_SYSTEM = (
"You are a senior project-controls forecaster. Workflow: (1) call evm_metrics for SPI/CPI and the "
"EAC formulas; (2) call forecast_series for EV (and AC if useful) with a horizon large enough to reach "
"BAC; (3) reconcile the TimesFM forecast with the EVM formulas, deciding which to trust; (4) call the "
"finish tool with your final numbers. Earned value cannot exceed BAC. Act ONLY by calling tools, and "
"ALWAYS end by calling finish."
)
_NUDGE = ("Do not reply in plain text - act only via tools. If you still need evidence, call forecast_series "
"or evm_metrics; otherwise call finish now with finish_period, eac, p_overrun, schedule_reasoning "
"and cost_reasoning.")
def _tool_forecast_series(project, k, which, horizon):
cum = project.ev[:k] if which == "ev" else project.ac[:k]
f = forecasting.timesfm_forecast(evm.to_increments(cum), int(horizon),
device="cpu", forecast_context_len=128, bac=project.bac)
last = float(cum[-1])
paths, out = {}, {"horizon": int(horizon)}
for q in ("q10", "q50", "q90"):
c = last + np.cumsum(f[q])
if which == "ev":
c = np.minimum(c, project.bac)
paths[q] = [round(float(x)) for x in c]
out["cumulative_paths"] = paths
if which == "ev":
out["reaches_bac_period"] = {
q: (k + int(np.argmax(np.array(paths[q]) >= 0.999 * project.bac)) + 1)
if (np.array(paths[q]) >= 0.999 * project.bac).any() else None
for q in ("q10", "q50", "q90")}
return out
def _finite_round(x, ndigits=None):
    """Round `x` like built-in round(), but return None for NaN/inf so json.dumps emits null."""
    x = float(x)
    return round(x, ndigits) if np.isfinite(x) else None


def _tool_evm(project, k):
    """Implement the `evm_metrics` tool from the first `k` observed periods.

    Returns a dict with:
        SPI, CPI: from evm.latest, rounded to 3 dp;
        EAC_formulas: {formula name: EAC rounded to whole currency} from evm.all_eacs,
            omitting non-finite formulas;
        earned_schedule_finish_period: evm.forecast_finish(...) rounded to 1 dp.
    SPI, CPI and the earned-schedule period become None when non-finite (e.g. if a ratio's
    denominator is zero — depends on the evm helpers). This keeps the tool message valid
    JSON: json.dumps would otherwise write the non-standard tokens NaN/Infinity.
    """
    pv, ev, ac = project.pv[:k], project.ev[:k], project.ac[:k]
    s = evm.latest(pv, ev, ac, project.bac)
    eacs = evm.all_eacs(pv, ev, ac, project.bac)
    return {
        "SPI": _finite_round(s["spi"], 3),
        "CPI": _finite_round(s["cpi"], 3),
        # float() first so numpy scalar types serialise as plain JSON numbers.
        "EAC_formulas": {m: round(float(v)) for m, v in eacs.items() if np.isfinite(v)},
        "earned_schedule_finish_period": _finite_round(
            evm.forecast_finish(pv, ev, project.planned_finish), 1),
    }
def _user_prompt(project, k):
pv, ev, ac = project.pv[:k], project.ev[:k], project.ac[:k]
return (
f"Project '{project.name}'. BAC (budget at completion) = {round(project.bac)}. "
f"Baseline finish period = {project.planned_finish}. Observed through period {k}.\n"
f"Cumulative PV: {[round(float(x)) for x in pv]}\n"
f"Cumulative EV: {[round(float(x)) for x in ev]}\n"
f"Cumulative AC: {[round(float(x)) for x in ac]}\n"
"Forecast the remaining work and finish with your final numbers."
)
def _accumulate_usage(totals, u):
    """Add one response's prompt / completion / cache-read token counts into `totals` in place."""
    if not u:
        return
    totals["prompt_tokens"] += getattr(u, "prompt_tokens", 0) or 0
    totals["completion_tokens"] += getattr(u, "completion_tokens", 0) or 0
    details = getattr(u, "prompt_tokens_details", None)
    if details:
        totals["cached_tokens"] += getattr(details, "cached_tokens", 0) or 0


def _finish_error(args):
    """Return why `args` is not an acceptable `finish` payload, or None if it is.

    Validates against the `finish` schema in TOOLS. The payload must be a JSON object
    containing every required field. "number" fields must be finite real numbers (not
    booleans) and "string" fields must be strings. p_overrun must lie in [0, 1].
    """
    if not isinstance(args, dict):
        return "finish arguments must be a JSON object"
    params = next(t["function"]["parameters"] for t in TOOLS if t["function"]["name"] == "finish")
    missing = [key for key in params["required"] if key not in args]
    if missing:
        return "finish is missing required field(s): " + ", ".join(missing)
    for key, spec in params["properties"].items():
        if key not in args:
            continue
        val = args[key]
        if spec.get("type") == "number":
            if isinstance(val, bool) or not isinstance(val, (int, float)) or not np.isfinite(val):
                return f"finish field {key!r} must be a finite number, got {val!r}"
        elif spec.get("type") == "string" and not isinstance(val, str):
            return f"finish field {key!r} must be a string, got {val!r}"
    if not 0.0 <= args["p_overrun"] <= 1.0:
        return f"finish field 'p_overrun' must be between 0 and 1, got {args['p_overrun']!r}"
    return None


def _run_tool(project, k, name, raw_args):
    """Execute one model tool call.

    Returns (output, finish_args). `output` is the JSON-serialisable dict sent back as the
    tool message. `finish_args` is the validated `finish` payload, or None for every other
    outcome: other tools, errors, or an invalid finish that the model must retry.
    """
    try:
        args = json.loads(raw_args or "{}")
    except (TypeError, ValueError) as e:
        if name != "evm_metrics":  # evm_metrics takes no arguments, so bad JSON is harmless there
            return {"error": f"invalid JSON arguments: {e}"}, None
        args = {}
    try:
        if name == "forecast_series":
            return _tool_forecast_series(project, k, **args), None
        if name == "evm_metrics":
            return _tool_evm(project, k), None
        if name == "finish":
            problem = _finish_error(args)
            if problem is not None:
                return {"error": problem}, None
            return {"ok": True}, args
        return {"error": "unknown tool"}, None
    except Exception as e:  # surface tool failures to the model so it can recover, not abort the run
        return {"error": f"{type(e).__name__}: {e}"}, None


def agent_forecast(project, k, *, client=None, model=MODEL, max_iters=8, max_tokens=MAX_TOKENS,
                   temperature=0.3, return_trace=False):
    """Run the strict tool-calling loop for `project` observed through period `k`.

    Each iteration does the following:
      1. Send the conversation (with TOOLS) to the chat endpoint and add the token usage.
      2. Append the assistant turn.
      3. Execute every tool call in that turn.
    A reply with no tool call is answered with a nudge message instead. The loop stops
    after a `finish` call whose arguments pass validation, or after `max_iters` API calls.
    An invalid `finish` call (unparsable JSON, missing or mistyped fields, p_overrun
    outside [0, 1]) is returned to the model as a tool error instead of being accepted.

    Args:
        project: project object (name, bac, planned_finish, pv/ev/ac cumulative series).
        k: number of observed periods the agent may use.
        client: OpenAI-compatible client; created with make_client() when None.
        model, max_tokens, temperature: passed through to chat.completions.create.
        max_iters: maximum number of API calls.
        return_trace: also return the full OpenAI-format message list (for distillation).

    Returns:
        A dict with these keys:
          - forecast: the accepted `finish` arguments, or None if no valid finish was made.
          - tool_calls: tool names in call order.
          - n_api_calls: number of API calls made.
          - usage: summed prompt/completion/cached token counts.
          - messages: the full conversation including tool results; present only when
            return_trace=True.
    """
    client = client or make_client()
    messages = [{"role": "system", "content": _SYSTEM},
                {"role": "user", "content": _user_prompt(project, k)}]
    result, calls, n_api = None, [], 0
    usage = {"prompt_tokens": 0, "completion_tokens": 0, "cached_tokens": 0}
    for _ in range(max_iters):
        resp = client.chat.completions.create(model=model, messages=messages, tools=TOOLS,
                                              max_tokens=max_tokens, temperature=temperature)
        n_api += 1
        _accumulate_usage(usage, resp.usage)
        msg = resp.choices[0].message
        tcs = msg.tool_calls or []
        assistant = {"role": "assistant", "content": msg.content or ""}
        if tcs:
            assistant["tool_calls"] = [
                {"id": tc.id, "type": "function",
                 "function": {"name": tc.function.name, "arguments": tc.function.arguments}}
                for tc in tcs]
        messages.append(assistant)
        if not tcs:  # no escape hatch: force a tool call
            messages.append({"role": "user", "content": _NUDGE})
            continue
        for tc in tcs:
            calls.append(tc.function.name)
            out, accepted = _run_tool(project, k, tc.function.name, tc.function.arguments)
            if accepted is not None:
                result = accepted
            messages.append({"role": "tool", "tool_call_id": tc.id, "content": json.dumps(out)})
        if result is not None:
            break
    ret = {"forecast": result, "tool_calls": calls, "n_api_calls": n_api, "usage": usage}
    if return_trace:
        ret["messages"] = messages
    return ret