Spaces:
Running
Running
| """Agentic forecasting: a strict tool-calling loop that orchestrates TimesFM 2.5 + EVM tools. | |
| Uses the OpenAI-compatible Fireworks API (Kimi K2.6 Turbo by default) - so the captured traces | |
| are ALREADY in the messages+tool_calls shape used to SFT MiniCPM5-1B (no lossy conversion). | |
| STRICT tool-calling: the result comes ONLY from a `finish` tool call. No text/JSON fallbacks - | |
| a model that cannot reliably call tools is out of scope. If the model replies without a tool | |
| call it is nudged to use one. With return_trace=True the full OpenAI `messages` list is returned | |
| for distillation. | |
| Auth: FIREWORKS_API_KEY (env, or a gitignored .env in the repo root). | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import numpy as np | |
| from . import evm, forecasting | |
# Fireworks exposes an OpenAI-compatible inference endpoint, so the stock OpenAI client works.
BASE_URL = "https://api.fireworks.ai/inference/v1"
# Default model identifier sent with every chat-completion request.
MODEL = "accounts/fireworks/routers/kimi-k2p6-turbo"
# Default per-request completion-token limit (passed as `max_tokens`).
MAX_TOKENS = 32_000
def _load_env():
    """Populate ``os.environ`` from a ``.env`` file in the current working directory.

    Does nothing when FIREWORKS_API_KEY is already set (and non-empty) or when no
    ``.env`` file exists. Otherwise every non-blank, non-comment ``KEY=VALUE`` line is
    applied with ``setdefault``, so variables already present in the environment are
    never overwritten. Values have surrounding whitespace, then surrounding double
    quotes, then surrounding single quotes stripped.
    """
    if os.environ.get("FIREWORKS_API_KEY"):
        return
    env_path = os.path.join(os.getcwd(), ".env")
    if not os.path.exists(env_path):
        return
    # Context manager so the file handle is closed even if a line fails to apply.
    with open(env_path, encoding="utf-8") as env_file:
        for raw_line in env_file:
            line = raw_line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, _, val = line.partition("=")
            key = key.strip()
            if not key:
                # "=value" has no variable name; os.environ rejects empty names.
                continue
            os.environ.setdefault(key, val.strip().strip('"').strip("'"))
def make_client(base_url=BASE_URL):
    """Build an OpenAI SDK client pointed at `base_url`.

    Loads a local .env first (via `_load_env`), then reads FIREWORKS_API_KEY from the
    environment.

    Raises:
        RuntimeError: if FIREWORKS_API_KEY is unset or empty.
    """
    _load_env()
    from openai import OpenAI  # imported lazily, only when a client is actually built

    api_key = os.environ.get("FIREWORKS_API_KEY")
    if api_key:
        # Long timeout plus SDK retries, intended to back off on 429/503 responses.
        return OpenAI(base_url=base_url, api_key=api_key, timeout=900, max_retries=8)
    raise RuntimeError("FIREWORKS_API_KEY not set (export it or put it in a .env file)")
# Tool specifications in the OpenAI function-calling format. The same shape is used for
# MiniCPM5 SFT via apply_chat_template, so captured traces need no conversion.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "forecast_series",
            "description": (
                "Forecast the continuation of the project's cumulative EV or AC with the TimesFM 2.5 "
                "time-series model. "
                "Returns P10/P50/P90 cumulative paths over `horizon` future periods, "
                "and for EV the period each path reaches BAC (completion). "
                "Use a horizon large enough for the path to reach BAC."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "which": {
                        "type": "string",
                        "enum": ["ev", "ac"],
                        "description": "which cumulative series to forecast",
                    },
                    "horizon": {
                        "type": "integer",
                        "description": "number of future periods to forecast",
                    },
                },
                "required": ["which", "horizon"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "evm_metrics",
            "description": (
                "Classical Earned Value Management metrics from the observed data: "
                "SPI, CPI, the three standard EAC formulas, "
                "and an earned-schedule finish-period estimate."
            ),
            "parameters": {"type": "object", "properties": {}, "required": []},
        },
    },
    {
        "type": "function",
        "function": {
            "name": "finish",
            "description": (
                "Submit your FINAL forecast and END your turn. "
                "Call once you have the TimesFM forecast AND the EVM metrics and have reconciled them. "
                "Earned value cannot exceed BAC; "
                "if TimesFM under-projects on a stalled/intermittent series, "
                "weigh the EVM EAC formulas (e.g. BAC/CPI) instead."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "finish_period": {
                        "type": "number",
                        "description": "projected completion period (where cumulative EV reaches BAC)",
                    },
                    "eac": {
                        "type": "number",
                        "description": "estimate at completion = projected final total cost (currency)",
                    },
                    "p_overrun": {
                        "type": "number",
                        "description": "probability between 0 and 1 that final cost exceeds 110% of BAC",
                    },
                    "schedule_reasoning": {
                        "type": "string",
                        "description": "1-2 sentences: schedule outlook and the evidence for it",
                    },
                    "cost_reasoning": {
                        "type": "string",
                        "description": "1-2 sentences: cost outlook and the evidence for it",
                    },
                },
                "required": [
                    "finish_period",
                    "eac",
                    "p_overrun",
                    "schedule_reasoning",
                    "cost_reasoning",
                ],
            },
        },
    },
]
# System prompt: fixes the four-step workflow and the tools-only contract.
_SYSTEM = (
    "You are a senior project-controls forecaster. "
    "Workflow: "
    "(1) call evm_metrics for SPI/CPI and the EAC formulas; "
    "(2) call forecast_series for EV (and AC if useful) with a horizon large enough to reach BAC; "
    "(3) reconcile the TimesFM forecast with the EVM formulas, deciding which to trust; "
    "(4) call the finish tool with your final numbers. "
    "Earned value cannot exceed BAC. "
    "Act ONLY by calling tools, and ALWAYS end by calling finish."
)

# User-role reminder appended whenever the model answers without any tool call.
_NUDGE = (
    "Do not reply in plain text - act only via tools. "
    "If you still need evidence, call forecast_series or evm_metrics; "
    "otherwise call finish now with finish_period, eac, p_overrun, schedule_reasoning and cost_reasoning."
)
def _tool_forecast_series(project, k, which, horizon):
    """Forecast the cumulative EV or AC series beyond the first `k` observed periods.

    Args:
        project: object exposing `ev`, `ac` (cumulative series) and `bac`.
        k: number of observed periods; only `series[:k]` is used as history.
        which: "ev" or "ac" - the cumulative series to forecast.
        horizon: number of future periods to forecast (coerced with ``int``).

    Returns:
        dict with "horizon", "cumulative_paths" (rounded q10/q50/q90 cumulative values)
        and, for EV only, "reaches_bac_period": per quantile, the first period at which
        the path reaches 99.9% of BAC, or None if it never does within the horizon.

    Raises:
        ValueError: if `which` is not "ev" or "ac". The caller reports this back to the
            model instead of silently forecasting AC for an unrecognised value.
    """
    if which not in ("ev", "ac"):
        raise ValueError(f"which must be 'ev' or 'ac', got {which!r}")
    is_ev = which == "ev"
    quantiles = ("q10", "q50", "q90")
    steps = int(horizon)

    cum = project.ev[:k] if is_ev else project.ac[:k]
    # The model forecasts per-period increments; they are re-accumulated below.
    f = forecasting.timesfm_forecast(evm.to_increments(cum), steps,
                                     device="cpu", forecast_context_len=128, bac=project.bac)
    last = float(cum[-1])

    paths = {}
    for q in quantiles:
        c = last + np.cumsum(f[q])
        if is_ev:
            c = np.minimum(c, project.bac)  # earned value cannot exceed BAC
        paths[q] = [round(float(x)) for x in c]

    out = {"horizon": steps, "cumulative_paths": paths}
    if is_ev:
        threshold = 0.999 * project.bac  # tolerate rounding just below BAC
        reaches = {}
        for q in quantiles:
            hit = np.asarray(paths[q]) >= threshold  # computed once per quantile
            # argmax gives the first True; +1 converts the 0-based offset to a period after k.
            reaches[q] = int(k + int(np.argmax(hit)) + 1) if hit.any() else None
        out["reaches_bac_period"] = reaches
    return out
def _tool_evm(project, k):
    """Summarise the EVM helpers' output for the first `k` observed periods.

    Returns a dict with SPI and CPI (3 decimals), the finite EAC estimates from
    `evm.all_eacs` (rounded to whole numbers), and the finish period from
    `evm.forecast_finish` (1 decimal).
    """
    bac = project.bac
    pv = project.pv[:k]
    ev = project.ev[:k]
    ac = project.ac[:k]

    snapshot = evm.latest(pv, ev, ac, bac)
    summary = {
        "SPI": round(snapshot["spi"], 3),
        "CPI": round(snapshot["cpi"], 3),
    }

    # Drop non-finite estimates before rounding.
    finite_eacs = {}
    for method, value in evm.all_eacs(pv, ev, ac, bac).items():
        if np.isfinite(value):
            finite_eacs[method] = round(value)
    summary["EAC_formulas"] = finite_eacs

    finish = evm.forecast_finish(pv, ev, project.planned_finish)
    summary["earned_schedule_finish_period"] = round(float(finish), 1)
    return summary
def _user_prompt(project, k):
    """Render the user message: project header plus the first `k` cumulative PV/EV/AC values.

    Series values are rounded to whole numbers; BAC is rounded as well.
    """
    def _whole(series):
        return [round(float(value)) for value in series]

    observed = (("PV", project.pv[:k]), ("EV", project.ev[:k]), ("AC", project.ac[:k]))
    header = (
        f"Project '{project.name}'. BAC (budget at completion) = {round(project.bac)}. "
        f"Baseline finish period = {project.planned_finish}. Observed through period {k}."
    )
    lines = [header]
    for label, series in observed:
        lines.append(f"Cumulative {label}: {_whole(series)}")
    lines.append("Forecast the remaining work and finish with your final numbers.")
    return "\n".join(lines)
def agent_forecast(project, k, *, client=None, model=MODEL, max_iters=8, max_tokens=MAX_TOKENS,
                   temperature=0.3, return_trace=False):
    """Run the strict tool-calling loop for one project observed through period `k`.

    Args:
        project: project object passed through to the tool implementations and prompt.
        k: number of observed periods exposed to the model and the tools.
        client: OpenAI-compatible client; built with `make_client()` when omitted.
        model: model identifier sent with each request.
        max_iters: maximum number of chat-completion requests.
        max_tokens: completion-token limit per request.
        temperature: sampling temperature per request.
        return_trace: when True, include the full `messages` list in the result.

    Returns:
        dict with keys:
          - "forecast": the parsed `finish` arguments dict, or None if the model never
            made a valid `finish` call within `max_iters` requests;
          - "tool_calls": names of every tool the model called, in order;
          - "n_api_calls": number of chat-completion requests made;
          - "usage": summed prompt/completion/cached token counts;
          - "messages": the OpenAI-format conversation (only when `return_trace`).
    """
    client = client or make_client()
    messages = [{"role": "system", "content": _SYSTEM},
                {"role": "user", "content": _user_prompt(project, k)}]
    result, calls, n_api = None, [], 0
    usage = {"prompt_tokens": 0, "completion_tokens": 0, "cached_tokens": 0}
    for _ in range(max_iters):
        resp = client.chat.completions.create(model=model, messages=messages, tools=TOOLS,
                                              max_tokens=max_tokens, temperature=temperature)
        n_api += 1
        u = resp.usage
        if u:  # track input/output/cache-read tokens
            usage["prompt_tokens"] += getattr(u, "prompt_tokens", 0) or 0
            usage["completion_tokens"] += getattr(u, "completion_tokens", 0) or 0
            ptd = getattr(u, "prompt_tokens_details", None)
            if ptd:
                usage["cached_tokens"] += getattr(ptd, "cached_tokens", 0) or 0
        msg = resp.choices[0].message
        tcs = msg.tool_calls or []
        assistant = {"role": "assistant", "content": msg.content or ""}
        if tcs:
            assistant["tool_calls"] = [
                {"id": tc.id, "type": "function",
                 "function": {"name": tc.function.name, "arguments": tc.function.arguments}}
                for tc in tcs]
        messages.append(assistant)
        if not tcs:  # no escape hatch: force a tool call
            messages.append({"role": "user", "content": _NUDGE})
            continue
        for tc in tcs:
            name = tc.function.name
            calls.append(name)
            # Malformed arguments are reported back to the model rather than replaced by
            # {}: an empty dict would let a garbled `finish` call end the run with an
            # empty forecast, and would hide the real cause from the other tools.
            out = None
            try:
                args = json.loads(tc.function.arguments or "{}")
            except (TypeError, ValueError) as e:  # JSONDecodeError is a ValueError
                args, out = None, {"error": f"invalid JSON arguments: {e}"}
            else:
                if not isinstance(args, dict):
                    args, out = None, {"error": "tool arguments must be a JSON object"}
            if args is not None:
                try:
                    if name == "forecast_series":
                        out = _tool_forecast_series(project, k, **args)
                    elif name == "evm_metrics":
                        out = _tool_evm(project, k)
                    elif name == "finish":
                        result, out = args, {"ok": True}
                    else:
                        out = {"error": "unknown tool"}
                except Exception as e:
                    # Deliberate boundary: any tool failure becomes a tool message so the
                    # model can correct its call instead of crashing the loop.
                    out = {"error": f"{type(e).__name__}: {e}"}
            # Every tool call gets exactly one tool response, even on error.
            messages.append({"role": "tool", "tool_call_id": tc.id, "content": json.dumps(out)})
        if result is not None:
            break
    ret = {"forecast": result, "tool_calls": calls, "n_api_calls": n_api, "usage": usage}
    if return_trace:
        ret["messages"] = messages
    return ret