Spaces:
Running
Running
File size: 9,671 Bytes
"""Agentic forecasting: a strict tool-calling loop that orchestrates TimesFM 2.5 + EVM tools.
Uses the OpenAI-compatible Fireworks API (Kimi K2.6 Turbo by default) - so the captured traces
are ALREADY in the messages+tool_calls shape used to SFT MiniCPM5-1B (no lossy conversion).
STRICT tool-calling: the result comes ONLY from a `finish` tool call. No text/JSON fallbacks -
a model that cannot reliably call tools is out of scope. If the model replies without a tool
call it is nudged to use one. With return_trace=True the full OpenAI `messages` list is returned
for distillation.
Auth: FIREWORKS_API_KEY (env, or a gitignored .env in the repo root).
"""
from __future__ import annotations
import json
import os
import numpy as np
from . import evm, forecasting
# Default `base_url` handed to the OpenAI client built by make_client().
BASE_URL = "https://api.fireworks.ai/inference/v1" # OpenAI-compatible endpoint
# Default model id for agent_forecast() (Kimi K2.6 Turbo per the module docstring).
MODEL = "accounts/fireworks/routers/kimi-k2p6-turbo"
# Default `max_tokens` sent with every chat-completion request in agent_forecast().
MAX_TOKENS = 32000
def _load_env():
    """Populate os.environ from a `.env` file in the current working directory.

    Does nothing when FIREWORKS_API_KEY is already set, or when no `.env` file
    exists. Each `KEY=VALUE` line is applied with `setdefault`, so variables that
    are already present in the environment are never overwritten. Blank lines,
    `#` comments, lines without `=` and lines with an empty key are skipped; an
    optional leading `export ` is accepted and surrounding quotes are stripped
    from the value.
    """
    if os.environ.get("FIREWORKS_API_KEY"):
        return
    path = os.path.join(os.getcwd(), ".env")
    try:
        # Context manager closes the handle; EAFP avoids the exists()/open() race.
        with open(path, encoding="utf-8") as fh:
            lines = fh.read().splitlines()
    except FileNotFoundError:
        return
    for line in lines:
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        if line.startswith("export "):  # tolerate shell-style `export KEY=VALUE`
            line = line[len("export "):].lstrip()
        key, _, val = line.partition("=")
        key = key.strip()
        if not key:  # "=value" would make os.environ raise on an empty name
            continue
        os.environ.setdefault(key, val.strip().strip('"').strip("'"))
def make_client(base_url=BASE_URL):
    """Return an OpenAI SDK client pointed at `base_url` and authenticated with FIREWORKS_API_KEY.

    The environment is first topped up via `_load_env()`. Raises RuntimeError when
    the key is still missing afterwards.
    """
    _load_env()
    from openai import OpenAI  # imported lazily, only when a client is actually built

    api_key = os.environ.get("FIREWORKS_API_KEY")
    if api_key:
        # Long timeout and SDK-level retries (backoff on 429/503).
        return OpenAI(base_url=base_url, api_key=api_key, timeout=900, max_retries=8)
    raise RuntimeError("FIREWORKS_API_KEY not set (export it or put it in a .env file)")
def _function_tool(name, description, properties, required):
    """Wrap one tool definition in the OpenAI function-calling envelope."""
    return {"type": "function", "function": {
        "name": name,
        "description": description,
        "parameters": {"type": "object", "properties": properties, "required": required}}}


# OpenAI function-calling tool specs (the exact shape used for MiniCPM5 SFT via apply_chat_template).
# Order matters for the captured traces: forecast_series, evm_metrics, finish.
TOOLS = [
    _function_tool(
        "forecast_series",
        ("Forecast the continuation of the project's cumulative EV or AC with the TimesFM "
         "2.5 time-series model. Returns P10/P50/P90 cumulative paths over `horizon` future "
         "periods, and for EV the period each path reaches BAC (completion). Use a horizon "
         "large enough for the path to reach BAC."),
        {
            "which": {"type": "string", "enum": ["ev", "ac"],
                      "description": "which cumulative series to forecast"},
            "horizon": {"type": "integer",
                        "description": "number of future periods to forecast"},
        },
        ["which", "horizon"],
    ),
    _function_tool(
        "evm_metrics",
        ("Classical Earned Value Management metrics from the observed data: SPI, CPI, the "
         "three standard EAC formulas, and an earned-schedule finish-period estimate."),
        {},
        [],
    ),
    _function_tool(
        "finish",
        ("Submit your FINAL forecast and END your turn. Call once you have the TimesFM "
         "forecast AND the EVM metrics and have reconciled them. Earned value cannot exceed "
         "BAC; if TimesFM under-projects on a stalled/intermittent series, weigh the EVM EAC "
         "formulas (e.g. BAC/CPI) instead."),
        {
            "finish_period": {"type": "number",
                              "description": "projected completion period (where cumulative EV reaches BAC)"},
            "eac": {"type": "number",
                    "description": "estimate at completion = projected final total cost (currency)"},
            "p_overrun": {"type": "number",
                          "description": "probability between 0 and 1 that final cost exceeds 110% of BAC"},
            "schedule_reasoning": {"type": "string",
                                   "description": "1-2 sentences: schedule outlook and the evidence for it"},
            "cost_reasoning": {"type": "string",
                               "description": "1-2 sentences: cost outlook and the evidence for it"},
        },
        ["finish_period", "eac", "p_overrun", "schedule_reasoning", "cost_reasoning"],
    ),
]
# System prompt sent as the first message of every agent_forecast() conversation:
# it lays out the four-step workflow and tells the model to act only through tools.
_SYSTEM = (
"You are a senior project-controls forecaster. Workflow: (1) call evm_metrics for SPI/CPI and the "
"EAC formulas; (2) call forecast_series for EV (and AC if useful) with a horizon large enough to reach "
"BAC; (3) reconcile the TimesFM forecast with the EVM formulas, deciding which to trust; (4) call the "
"finish tool with your final numbers. Earned value cannot exceed BAC. Act ONLY by calling tools, and "
"ALWAYS end by calling finish."
)
# User message appended by agent_forecast() whenever the model replies without any tool call.
_NUDGE = ("Do not reply in plain text - act only via tools. If you still need evidence, call forecast_series "
"or evm_metrics; otherwise call finish now with finish_period, eac, p_overrun, schedule_reasoning "
"and cost_reasoning.")
def _tool_forecast_series(project, k, which, horizon):
    """Run the `forecast_series` tool: project cumulative EV or AC `horizon` periods ahead.

    Args:
        project: object exposing `ev`, `ac` (cumulative series) and `bac`.
        k: number of observed periods; only the first `k` points are used.
        which: "ev" or "ac" - the cumulative series to continue.
        horizon: number of future periods (coerced with `int`).

    Returns:
        A JSON-serialisable dict with `horizon`, `cumulative_paths` (rounded q10/q50/q90
        cumulative values) and, for EV only, `reaches_bac_period`: per quantile, the
        period at which the path first reaches 99.9% of BAC, or None if it never does.

    Raises:
        ValueError: if `which` is not "ev"/"ac", `horizon` is below 1, or no periods
            have been observed. The caller reports the message back to the model.
    """
    # Validate up front: previously any unknown `which` was silently forecast as AC.
    if which not in ("ev", "ac"):
        raise ValueError(f"which must be 'ev' or 'ac', got {which!r}")
    horizon = int(horizon)
    if horizon < 1:
        raise ValueError(f"horizon must be a positive integer, got {horizon}")
    is_ev = which == "ev"
    cum = project.ev[:k] if is_ev else project.ac[:k]
    if len(cum) == 0:
        raise ValueError("no observed periods to forecast from")

    # The model is fed per-period increments; its quantile outputs are re-accumulated
    # below on top of the last observed cumulative value.
    f = forecasting.timesfm_forecast(evm.to_increments(cum), horizon,
                                     device="cpu", forecast_context_len=128, bac=project.bac)
    last = float(cum[-1])
    quantiles = ("q10", "q50", "q90")
    paths = {}
    for q in quantiles:
        c = last + np.cumsum(f[q])
        if is_ev:
            c = np.minimum(c, project.bac)  # earned value cannot exceed BAC
        paths[q] = [round(float(x)) for x in c]
    out = {"horizon": horizon, "cumulative_paths": paths}

    if is_ev:
        threshold = 0.999 * project.bac
        reaches = {}
        for q in quantiles:
            hit = np.asarray(paths[q]) >= threshold  # computed once per quantile
            # argmax gives the first True index; +1 converts the offset to a period after k.
            reaches[q] = (k + int(np.argmax(hit)) + 1) if hit.any() else None
        out["reaches_bac_period"] = reaches
    return out
def _tool_evm(project, k):
    """Run the `evm_metrics` tool on the first `k` observed periods.

    Returns a dict with rounded SPI and CPI, the finite EAC estimates keyed by
    formula name, and the earned-schedule finish period rounded to one decimal.
    """
    pv, ev, ac = (series[:k] for series in (project.pv, project.ev, project.ac))
    snapshot = evm.latest(pv, ev, ac, project.bac)

    report = {"SPI": round(snapshot["spi"], 3), "CPI": round(snapshot["cpi"], 3)}

    finite_eacs = {}
    for method, value in evm.all_eacs(pv, ev, ac, project.bac).items():
        if np.isfinite(value):  # drop formulas that produced inf/nan
            finite_eacs[method] = round(value)
    report["EAC_formulas"] = finite_eacs

    finish_period = evm.forecast_finish(pv, ev, project.planned_finish)
    report["earned_schedule_finish_period"] = round(float(finish_period), 1)
    return report
def _user_prompt(project, k):
    """Build the opening user message: a project header plus the cumulative PV/EV/AC
    observed through period `k`, each value rounded to an integer."""
    observed = (project.pv[:k], project.ev[:k], project.ac[:k])
    pv_vals, ev_vals, ac_vals = ([round(float(v)) for v in series] for series in observed)

    header = (f"Project '{project.name}'. BAC (budget at completion) = {round(project.bac)}. "
              f"Baseline finish period = {project.planned_finish}. Observed through period {k}.")
    parts = [
        header,
        f"Cumulative PV: {pv_vals}",
        f"Cumulative EV: {ev_vals}",
        f"Cumulative AC: {ac_vals}",
        "Forecast the remaining work and finish with your final numbers.",
    ]
    return "\n".join(parts)
def _parse_tool_arguments(raw, required):
    """Decode one tool call's JSON argument string.

    Returns `(args, None)` on success, or `(None, message)` when the payload is not
    valid JSON, is not a JSON object, or lacks one of the `required` keys.
    """
    try:
        args = json.loads(raw or "{}")
    except (TypeError, ValueError) as e:  # json.JSONDecodeError is a ValueError
        return None, f"arguments are not valid JSON: {e}"
    if not isinstance(args, dict):
        return None, "arguments must be a JSON object"
    missing = [key for key in required if key not in args]
    if missing:
        return None, "missing required argument(s): " + ", ".join(missing)
    return args, None


def agent_forecast(project, k, *, client=None, model=MODEL, max_iters=8, max_tokens=MAX_TOKENS,
                   temperature=0.3, return_trace=False):
    """Run the strict tool-calling loop for `project` observed through period `k`.

    Args:
        project: project object passed through to the prompt builder and the tools.
        k: number of observed periods.
        client: OpenAI-compatible client; built with make_client() when omitted.
        model, max_tokens, temperature: forwarded to chat.completions.create.
        max_iters: maximum number of API round-trips before giving up.
        return_trace: when True, include the full `messages` list in the result.

    Returns:
        {forecast, tool_calls, n_api_calls, usage[, messages]}. `forecast` is the parsed
        `finish` arguments dict, or None if the model never made a valid `finish` call
        within `max_iters`. `tool_calls` lists every tool name the model invoked, in
        order. `usage` accumulates prompt, completion and cached token counts.

    A tool call whose arguments are malformed (invalid JSON, not an object, or missing
    a field the tool spec marks as required) is answered with an `error` tool message
    so the model can retry; in particular a malformed `finish` is never accepted as
    the final forecast.
    """
    client = client or make_client()
    # Required argument names per tool, taken from the specs actually sent to the model.
    required_by_tool = {spec["function"]["name"]: spec["function"]["parameters"].get("required", [])
                        for spec in TOOLS}
    messages = [{"role": "system", "content": _SYSTEM},
                {"role": "user", "content": _user_prompt(project, k)}]
    result, calls, n_api = None, [], 0
    usage = {"prompt_tokens": 0, "completion_tokens": 0, "cached_tokens": 0}
    for _ in range(max_iters):
        resp = client.chat.completions.create(model=model, messages=messages, tools=TOOLS,
                                              max_tokens=max_tokens, temperature=temperature)
        n_api += 1
        u = resp.usage
        if u:  # track input/output/cache-read tokens
            usage["prompt_tokens"] += getattr(u, "prompt_tokens", 0) or 0
            usage["completion_tokens"] += getattr(u, "completion_tokens", 0) or 0
            ptd = getattr(u, "prompt_tokens_details", None)
            if ptd:
                usage["cached_tokens"] += getattr(ptd, "cached_tokens", 0) or 0
        msg = resp.choices[0].message
        tcs = msg.tool_calls or []
        # Echo the assistant turn back in plain-dict form so the trace stays serialisable.
        assistant = {"role": "assistant", "content": msg.content or ""}
        if tcs:
            assistant["tool_calls"] = [
                {"id": tc.id, "type": "function",
                 "function": {"name": tc.function.name, "arguments": tc.function.arguments}}
                for tc in tcs]
        messages.append(assistant)
        if not tcs:  # no escape hatch: force a tool call
            messages.append({"role": "user", "content": _NUDGE})
            continue
        for tc in tcs:
            name = tc.function.name
            calls.append(name)
            if name not in required_by_tool:
                out = {"error": "unknown tool"}
            else:
                args, problem = _parse_tool_arguments(tc.function.arguments, required_by_tool[name])
                if problem is not None:
                    out = {"error": problem}
                else:
                    try:
                        if name == "forecast_series":
                            out = _tool_forecast_series(project, k, **args)
                        elif name == "evm_metrics":
                            out = _tool_evm(project, k)
                        elif name == "finish":
                            result, out = args, {"ok": True}
                        else:
                            out = {"error": "unknown tool"}
                    except Exception as e:
                        # Tool boundary: any failure is reported to the model, not raised.
                        out = {"error": f"{type(e).__name__}: {e}"}
            # Every tool call gets exactly one tool message, success or error.
            messages.append({"role": "tool", "tool_call_id": tc.id, "content": json.dumps(out)})
        if result is not None:
            break
    ret = {"forecast": result, "tool_calls": calls, "n_api_calls": n_api, "usage": usage}
    if return_trace:
        ret["messages"] = messages
    return ret
|