"""Agentic forecasting: a strict tool-calling loop that orchestrates TimesFM 2.5 + EVM tools. Uses the OpenAI-compatible Fireworks API (Kimi K2.6 Turbo by default) - so the captured traces are ALREADY in the messages+tool_calls shape used to SFT MiniCPM5-1B (no lossy conversion). STRICT tool-calling: the result comes ONLY from a `finish` tool call. No text/JSON fallbacks - a model that cannot reliably call tools is out of scope. If the model replies without a tool call it is nudged to use one. With return_trace=True the full OpenAI `messages` list is returned for distillation. Auth: FIREWORKS_API_KEY (env, or a gitignored .env in the repo root). """ from __future__ import annotations import json import os import numpy as np from . import evm, forecasting BASE_URL = "https://api.fireworks.ai/inference/v1" # OpenAI-compatible endpoint MODEL = "accounts/fireworks/routers/kimi-k2p6-turbo" MAX_TOKENS = 32000 def _load_env(): if os.environ.get("FIREWORKS_API_KEY"): return p = os.path.join(os.getcwd(), ".env") if os.path.exists(p): for line in open(p): line = line.strip() if line and not line.startswith("#") and "=" in line: key, _, val = line.partition("=") os.environ.setdefault(key.strip(), val.strip().strip('"').strip("'")) def make_client(base_url=BASE_URL): _load_env() from openai import OpenAI key = os.environ.get("FIREWORKS_API_KEY") if not key: raise RuntimeError("FIREWORKS_API_KEY not set (export it or put it in a .env file)") return OpenAI(base_url=base_url, api_key=key, timeout=900, max_retries=8) # backoff on 429/503 # OpenAI function-calling tool specs (the exact shape used for MiniCPM5 SFT via apply_chat_template). TOOLS = [ {"type": "function", "function": { "name": "forecast_series", "description": ("Forecast the continuation of the project's cumulative EV or AC with the TimesFM " "2.5 time-series model. Returns P10/P50/P90 cumulative paths over `horizon` future " "periods, and for EV the period each path reaches BAC (completion). Use a horizon " "large enough for the path to reach BAC."), "parameters": {"type": "object", "properties": { "which": {"type": "string", "enum": ["ev", "ac"], "description": "which cumulative series to forecast"}, "horizon": {"type": "integer", "description": "number of future periods to forecast"}}, "required": ["which", "horizon"]}}}, {"type": "function", "function": { "name": "evm_metrics", "description": ("Classical Earned Value Management metrics from the observed data: SPI, CPI, the " "three standard EAC formulas, and an earned-schedule finish-period estimate."), "parameters": {"type": "object", "properties": {}, "required": []}}}, {"type": "function", "function": { "name": "finish", "description": ("Submit your FINAL forecast and END your turn. Call once you have the TimesFM " "forecast AND the EVM metrics and have reconciled them. Earned value cannot exceed " "BAC; if TimesFM under-projects on a stalled/intermittent series, weigh the EVM EAC " "formulas (e.g. BAC/CPI) instead."), "parameters": {"type": "object", "properties": { "finish_period": {"type": "number", "description": "projected completion period (where cumulative EV reaches BAC)"}, "eac": {"type": "number", "description": "estimate at completion = projected final total cost (currency)"}, "p_overrun": {"type": "number", "description": "probability between 0 and 1 that final cost exceeds 110% of BAC"}, "schedule_reasoning": {"type": "string", "description": "1-2 sentences: schedule outlook and the evidence for it"}, "cost_reasoning": {"type": "string", "description": "1-2 sentences: cost outlook and the evidence for it"}}, "required": ["finish_period", "eac", "p_overrun", "schedule_reasoning", "cost_reasoning"]}}}, ] _SYSTEM = ( "You are a senior project-controls forecaster. Workflow: (1) call evm_metrics for SPI/CPI and the " "EAC formulas; (2) call forecast_series for EV (and AC if useful) with a horizon large enough to reach " "BAC; (3) reconcile the TimesFM forecast with the EVM formulas, deciding which to trust; (4) call the " "finish tool with your final numbers. Earned value cannot exceed BAC. Act ONLY by calling tools, and " "ALWAYS end by calling finish." ) _NUDGE = ("Do not reply in plain text - act only via tools. If you still need evidence, call forecast_series " "or evm_metrics; otherwise call finish now with finish_period, eac, p_overrun, schedule_reasoning " "and cost_reasoning.") def _tool_forecast_series(project, k, which, horizon): cum = project.ev[:k] if which == "ev" else project.ac[:k] f = forecasting.timesfm_forecast(evm.to_increments(cum), int(horizon), device="cpu", forecast_context_len=128, bac=project.bac) last = float(cum[-1]) paths, out = {}, {"horizon": int(horizon)} for q in ("q10", "q50", "q90"): c = last + np.cumsum(f[q]) if which == "ev": c = np.minimum(c, project.bac) paths[q] = [round(float(x)) for x in c] out["cumulative_paths"] = paths if which == "ev": out["reaches_bac_period"] = { q: (k + int(np.argmax(np.array(paths[q]) >= 0.999 * project.bac)) + 1) if (np.array(paths[q]) >= 0.999 * project.bac).any() else None for q in ("q10", "q50", "q90")} return out def _tool_evm(project, k): pv, ev, ac = project.pv[:k], project.ev[:k], project.ac[:k] s = evm.latest(pv, ev, ac, project.bac) return { "SPI": round(s["spi"], 3), "CPI": round(s["cpi"], 3), "EAC_formulas": {m: round(v) for m, v in evm.all_eacs(pv, ev, ac, project.bac).items() if np.isfinite(v)}, "earned_schedule_finish_period": round(float(evm.forecast_finish(pv, ev, project.planned_finish)), 1), } def _user_prompt(project, k): pv, ev, ac = project.pv[:k], project.ev[:k], project.ac[:k] return ( f"Project '{project.name}'. BAC (budget at completion) = {round(project.bac)}. " f"Baseline finish period = {project.planned_finish}. Observed through period {k}.\n" f"Cumulative PV: {[round(float(x)) for x in pv]}\n" f"Cumulative EV: {[round(float(x)) for x in ev]}\n" f"Cumulative AC: {[round(float(x)) for x in ac]}\n" "Forecast the remaining work and finish with your final numbers." ) def agent_forecast(project, k, *, client=None, model=MODEL, max_iters=8, max_tokens=MAX_TOKENS, temperature=0.3, return_trace=False): """Run the strict tool-calling loop. Returns {forecast, tool_calls, n_api_calls[, messages]}. `forecast` is the parsed `finish` arguments dict, or None if the model never called finish.""" client = client or make_client() messages = [{"role": "system", "content": _SYSTEM}, {"role": "user", "content": _user_prompt(project, k)}] result, calls, n_api = None, [], 0 usage = {"prompt_tokens": 0, "completion_tokens": 0, "cached_tokens": 0} for _ in range(max_iters): resp = client.chat.completions.create(model=model, messages=messages, tools=TOOLS, max_tokens=max_tokens, temperature=temperature) n_api += 1 u = resp.usage if u: # track input/output/cache-read tokens usage["prompt_tokens"] += getattr(u, "prompt_tokens", 0) or 0 usage["completion_tokens"] += getattr(u, "completion_tokens", 0) or 0 ptd = getattr(u, "prompt_tokens_details", None) if ptd: usage["cached_tokens"] += getattr(ptd, "cached_tokens", 0) or 0 msg = resp.choices[0].message tcs = msg.tool_calls or [] assistant = {"role": "assistant", "content": msg.content or ""} if tcs: assistant["tool_calls"] = [ {"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}} for tc in tcs] messages.append(assistant) if not tcs: # no escape hatch: force a tool call messages.append({"role": "user", "content": _NUDGE}) continue for tc in tcs: name = tc.function.name calls.append(name) try: args = json.loads(tc.function.arguments or "{}") except Exception: args = {} try: if name == "forecast_series": out = _tool_forecast_series(project, k, **args) elif name == "evm_metrics": out = _tool_evm(project, k) elif name == "finish": result, out = args, {"ok": True} else: out = {"error": "unknown tool"} except Exception as e: out = {"error": f"{type(e).__name__}: {e}"} messages.append({"role": "tool", "tool_call_id": tc.id, "content": json.dumps(out)}) if result is not None: break ret = {"forecast": result, "tool_calls": calls, "n_api_calls": n_api, "usage": usage} if return_trace: ret["messages"] = messages return ret