File size: 9,671 Bytes
c658ad5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""Agentic forecasting: a strict tool-calling loop that orchestrates TimesFM 2.5 + EVM tools.

Uses the OpenAI-compatible Fireworks API (Kimi K2.6 Turbo by default) - so the captured traces
are ALREADY in the messages+tool_calls shape used to SFT MiniCPM5-1B (no lossy conversion).

STRICT tool-calling: the result comes ONLY from a `finish` tool call. No text/JSON fallbacks -
a model that cannot reliably call tools is out of scope. If the model replies without a tool
call it is nudged to use one. With return_trace=True the full OpenAI `messages` list is returned
for distillation.

Auth: FIREWORKS_API_KEY (env, or a gitignored .env in the repo root).
"""
from __future__ import annotations

import json
import os

import numpy as np

from . import evm, forecasting

# Defaults used by make_client() and agent_forecast() below.
BASE_URL = "https://api.fireworks.ai/inference/v1"        # OpenAI-compatible endpoint
MODEL = "accounts/fireworks/routers/kimi-k2p6-turbo"      # default `model` argument of agent_forecast
MAX_TOKENS = 32000                                        # default `max_tokens` per completion request


def _load_env():
    """Populate ``os.environ`` from a ``.env`` file in the current working directory.

    Does nothing when ``FIREWORKS_API_KEY`` is already set to a non-empty value, or when
    no ``.env`` file exists in ``os.getcwd()``. Otherwise every ``KEY=VALUE`` line is applied
    with ``os.environ.setdefault``, so variables already present in the environment are
    never overwritten. Blank lines, ``#`` comment lines and lines without ``=`` are skipped;
    whitespace around key and value, and quotes around the value, are stripped.
    """
    if os.environ.get("FIREWORKS_API_KEY"):
        return
    path = os.path.join(os.getcwd(), ".env")
    if not os.path.exists(path):
        return
    # Context manager: the handle is closed deterministically (the previous bare open()
    # leaked it). Explicit UTF-8 keeps parsing independent of the platform locale.
    with open(path, encoding="utf-8") as fh:
        for raw in fh:
            line = raw.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            # partition() splits on the FIRST '=', so values may themselves contain '='.
            key, _, val = line.partition("=")
            os.environ.setdefault(key.strip(), val.strip().strip('"').strip("'"))


def make_client(base_url=BASE_URL):
    """Build an OpenAI-SDK client for `base_url`, authenticated with FIREWORKS_API_KEY.

    `_load_env()` runs first so a local .env file can supply the key. The `openai`
    package is imported inside the function rather than at module import time.

    Raises:
        RuntimeError: if FIREWORKS_API_KEY is still unset (or empty) after loading .env.
    """
    _load_env()
    from openai import OpenAI

    api_key = os.environ.get("FIREWORKS_API_KEY")
    if api_key:
        # Long timeout plus SDK-level retries (backoff on 429/503).
        return OpenAI(base_url=base_url, api_key=api_key, timeout=900, max_retries=8)
    raise RuntimeError("FIREWORKS_API_KEY not set (export it or put it in a .env file)")


# OpenAI function-calling tool specs (the exact shape used for MiniCPM5 SFT via apply_chat_template).
# This list is sent as `tools=` on every chat-completions request made by agent_forecast, which
# dispatches on the tool `name`. The description strings are prompt text the model reads.
TOOLS = [
    # forecast_series -> handled by _tool_forecast_series(project, k, which, horizon).
    {"type": "function", "function": {
        "name": "forecast_series",
        "description": ("Forecast the continuation of the project's cumulative EV or AC with the TimesFM "
                        "2.5 time-series model. Returns P10/P50/P90 cumulative paths over `horizon` future "
                        "periods, and for EV the period each path reaches BAC (completion). Use a horizon "
                        "large enough for the path to reach BAC."),
        "parameters": {"type": "object", "properties": {
            "which": {"type": "string", "enum": ["ev", "ac"], "description": "which cumulative series to forecast"},
            "horizon": {"type": "integer", "description": "number of future periods to forecast"}},
            "required": ["which", "horizon"]}}},
    # evm_metrics -> handled by _tool_evm(project, k); takes no arguments.
    {"type": "function", "function": {
        "name": "evm_metrics",
        "description": ("Classical Earned Value Management metrics from the observed data: SPI, CPI, the "
                        "three standard EAC formulas, and an earned-schedule finish-period estimate."),
        "parameters": {"type": "object", "properties": {}, "required": []}}},
    # finish -> terminal tool: its arguments become the `forecast` returned by agent_forecast.
    {"type": "function", "function": {
        "name": "finish",
        "description": ("Submit your FINAL forecast and END your turn. Call once you have the TimesFM "
                        "forecast AND the EVM metrics and have reconciled them. Earned value cannot exceed "
                        "BAC; if TimesFM under-projects on a stalled/intermittent series, weigh the EVM EAC "
                        "formulas (e.g. BAC/CPI) instead."),
        "parameters": {"type": "object", "properties": {
            "finish_period": {"type": "number", "description": "projected completion period (where cumulative EV reaches BAC)"},
            "eac": {"type": "number", "description": "estimate at completion = projected final total cost (currency)"},
            "p_overrun": {"type": "number", "description": "probability between 0 and 1 that final cost exceeds 110% of BAC"},
            "schedule_reasoning": {"type": "string", "description": "1-2 sentences: schedule outlook and the evidence for it"},
            "cost_reasoning": {"type": "string", "description": "1-2 sentences: cost outlook and the evidence for it"}},
            "required": ["finish_period", "eac", "p_overrun", "schedule_reasoning", "cost_reasoning"]}}},
]

# System prompt: the first message of every agent_forecast conversation. It describes the
# intended tool workflow (evm_metrics -> forecast_series -> reconcile -> finish).
_SYSTEM = (
    "You are a senior project-controls forecaster. Workflow: (1) call evm_metrics for SPI/CPI and the "
    "EAC formulas; (2) call forecast_series for EV (and AC if useful) with a horizon large enough to reach "
    "BAC; (3) reconcile the TimesFM forecast with the EVM formulas, deciding which to trust; (4) call the "
    "finish tool with your final numbers. Earned value cannot exceed BAC. Act ONLY by calling tools, and "
    "ALWAYS end by calling finish."
)

# Follow-up user message that agent_forecast appends whenever the model answers without any
# tool call, steering it back to the tools instead of accepting a plain-text reply.
_NUDGE = ("Do not reply in plain text - act only via tools. If you still need evidence, call forecast_series "
          "or evm_metrics; otherwise call finish now with finish_period, eac, p_overrun, schedule_reasoning "
          "and cost_reasoning.")


def _tool_forecast_series(project, k, which, horizon):
    """Tool backend for `forecast_series`: forecast cumulative EV or AC beyond period `k`.

    The first `k` cumulative observations are converted to increments with
    `evm.to_increments`, forecast with `forecasting.timesfm_forecast`, and the forecast
    increments are re-accumulated on top of the last observed cumulative value.

    Args:
        project: object exposing `ev`, `ac` (cumulative series) and `bac`.
        k: number of observed periods; only the first `k` points of the series are used.
        which: "ev" or "ac" - the cumulative series to forecast.
        horizon: number of future periods to forecast (converted with `int`).

    Returns:
        dict with "horizon", "cumulative_paths" ({"q10"|"q50"|"q90": list of rounded
        cumulative values}) and, for EV only, "reaches_bac_period" mapping each quantile to
        the first period whose rounded value is >= 99.9% of BAC, or None if it never is.

    Raises:
        ValueError: if `which` is not "ev"/"ac", `horizon` is < 1, or there are no observed
            points. The agent loop reports exceptions back to the model as tool errors, so
            it can correct the call instead of silently receiving an AC forecast.
    """
    if which not in ("ev", "ac"):
        # Previously any other value (e.g. "EV") silently fell through to the AC branch.
        raise ValueError(f"which must be 'ev' or 'ac', got {which!r}")
    horizon = int(horizon)
    if horizon < 1:
        raise ValueError(f"horizon must be a positive integer, got {horizon}")
    is_ev = which == "ev"
    cum = project.ev[:k] if is_ev else project.ac[:k]
    if len(cum) == 0:
        raise ValueError("no observed periods to forecast from")

    f = forecasting.timesfm_forecast(evm.to_increments(cum), horizon,
                                     device="cpu", forecast_context_len=128, bac=project.bac)
    last = float(cum[-1])
    quantiles = ("q10", "q50", "q90")
    paths, out = {}, {"horizon": horizon}
    for q in quantiles:
        c = last + np.cumsum(f[q])
        if is_ev:
            c = np.minimum(c, project.bac)       # earned value cannot exceed BAC
        paths[q] = [round(float(x)) for x in c]
    out["cumulative_paths"] = paths
    if is_ev:
        reaches = {}
        for q in quantiles:
            # Computed on the rounded path so the period agrees with the numbers shown.
            hit = np.array(paths[q]) >= 0.999 * project.bac
            # argmax of a boolean mask is the first True; +1 converts the 0-based offset
            # past period k into a period number.
            reaches[q] = (k + int(np.argmax(hit)) + 1) if hit.any() else None
        out["reaches_bac_period"] = reaches
    return out


def _tool_evm(project, k):
    """Tool backend for `evm_metrics`: EVM figures computed from the first `k` observations.

    Returns a dict with rounded "SPI" and "CPI" (3 decimals, from `evm.latest`), the
    finite entries of `evm.all_eacs` rounded to integers under "EAC_formulas", and
    `evm.forecast_finish` rounded to one decimal under "earned_schedule_finish_period".
    """
    planned, earned, actual = project.pv[:k], project.ev[:k], project.ac[:k]
    snapshot = evm.latest(planned, earned, actual, project.bac)
    spi = round(snapshot["spi"], 3)
    cpi = round(snapshot["cpi"], 3)

    finite_eacs = {}
    for method, value in evm.all_eacs(planned, earned, actual, project.bac).items():
        if np.isfinite(value):               # non-finite estimates are left out
            finite_eacs[method] = round(value)

    es_finish = float(evm.forecast_finish(planned, earned, project.planned_finish))
    return {
        "SPI": spi,
        "CPI": cpi,
        "EAC_formulas": finite_eacs,
        "earned_schedule_finish_period": round(es_finish, 1),
    }


def _user_prompt(project, k):
    """Render the opening user message: project header plus PV/EV/AC observed through period `k`.

    Series values are shown as integers (`round(float(x))`) and only the first `k`
    points of each series are included.
    """
    pv_seen, ev_seen, ac_seen = (
        [round(float(value)) for value in series[:k]]
        for series in (project.pv, project.ev, project.ac)
    )
    parts = [
        f"Project '{project.name}'. BAC (budget at completion) = {round(project.bac)}. "
        f"Baseline finish period = {project.planned_finish}. Observed through period {k}.",
        f"Cumulative PV: {pv_seen}",
        f"Cumulative EV: {ev_seen}",
        f"Cumulative AC: {ac_seen}",
        "Forecast the remaining work and finish with your final numbers.",
    ]
    return "\n".join(parts)


def _parse_tool_args(raw):
    """Decode a tool call's JSON argument string; return `(args_dict, None)` or `(None, error)`."""
    try:
        parsed = json.loads(raw or "{}")
    except (TypeError, ValueError) as e:         # json.JSONDecodeError subclasses ValueError
        return None, f"invalid JSON arguments: {e}"
    if not isinstance(parsed, dict):
        return None, "arguments must be a JSON object"
    return parsed, None


def _finish_required_fields():
    """Return the `required` argument names of the `finish` tool as declared in TOOLS."""
    for spec in TOOLS:
        fn = spec["function"]
        if fn["name"] == "finish":
            return list(fn["parameters"].get("required", []))
    return []


def _dispatch_tool(project, k, name, raw_args):
    """Execute one tool call; return `(tool_output, finish_args)`.

    `finish_args` is the validated argument dict when `name` is a well-formed `finish`
    call and None in every other case (including a rejected `finish`).
    """
    args, err = _parse_tool_args(raw_args)
    try:
        if name == "evm_metrics":                # takes no arguments, so bad args are harmless
            return _tool_evm(project, k), None
        if name not in ("forecast_series", "finish"):
            return {"error": "unknown tool"}, None
        if err is not None:
            return {"error": err}, None
        if name == "forecast_series":
            return _tool_forecast_series(project, k, **args), None
        missing = [field for field in _finish_required_fields() if field not in args]
        if missing:
            return {"error": "finish is missing required fields: " + ", ".join(missing)}, None
        return {"ok": True}, args
    except Exception as e:
        # Deliberately broad: any tool failure is reported to the model as a tool error
        # so it can retry, rather than aborting the whole loop.
        return {"error": f"{type(e).__name__}: {e}"}, None


def agent_forecast(project, k, *, client=None, model=MODEL, max_iters=8, max_tokens=MAX_TOKENS,
                   temperature=0.3, return_trace=False):
    """Run the strict tool-calling loop for `project` observed through period `k`.

    Each iteration sends the conversation plus TOOLS to the chat-completions API. A reply
    without tool calls is answered with the _NUDGE message; tool calls are executed and
    their JSON results appended as `tool` messages. The loop ends once a valid `finish`
    call has been received or after `max_iters` API calls.

    A `finish` call is accepted only when its arguments are a JSON object containing every
    field the tool declares as required. Unparseable or incomplete arguments are returned
    to the model as a tool error (previously they were accepted as an empty/partial
    forecast and ended the loop).

    Args:
        project: project object passed through to the tool backends and the user prompt.
        k: number of observed periods.
        client: OpenAI-compatible client; built with `make_client()` when omitted.
        model, max_tokens, temperature: forwarded to `chat.completions.create`.
        max_iters: maximum number of API calls.
        return_trace: when True, include the full `messages` list in the result.

    Returns:
        dict with keys "forecast" (the parsed `finish` arguments, or None if the model
        never made a valid `finish` call), "tool_calls" (names of all tools called, in
        order), "n_api_calls", "usage" (summed prompt/completion/cached token counts) and,
        if `return_trace`, "messages".
    """
    client = client or make_client()
    messages = [{"role": "system", "content": _SYSTEM},
                {"role": "user", "content": _user_prompt(project, k)}]
    result, calls, n_api = None, [], 0
    usage = {"prompt_tokens": 0, "completion_tokens": 0, "cached_tokens": 0}
    for _ in range(max_iters):
        resp = client.chat.completions.create(model=model, messages=messages, tools=TOOLS,
                                              max_tokens=max_tokens, temperature=temperature)
        n_api += 1
        u = resp.usage
        if u:                                            # track input/output/cache-read tokens
            usage["prompt_tokens"] += getattr(u, "prompt_tokens", 0) or 0
            usage["completion_tokens"] += getattr(u, "completion_tokens", 0) or 0
            ptd = getattr(u, "prompt_tokens_details", None)
            if ptd:
                usage["cached_tokens"] += getattr(ptd, "cached_tokens", 0) or 0
        msg = resp.choices[0].message
        tcs = msg.tool_calls or []
        # Re-serialise the assistant turn as plain dicts so the trace stays JSON-friendly.
        assistant = {"role": "assistant", "content": msg.content or ""}
        if tcs:
            assistant["tool_calls"] = [
                {"id": tc.id, "type": "function",
                 "function": {"name": tc.function.name, "arguments": tc.function.arguments}} for tc in tcs]
        messages.append(assistant)
        if not tcs:                                          # no escape hatch: force a tool call
            messages.append({"role": "user", "content": _NUDGE})
            continue
        for tc in tcs:
            name = tc.function.name
            calls.append(name)
            out, finish_args = _dispatch_tool(project, k, name, tc.function.arguments)
            if finish_args is not None:
                result = finish_args
            # Every tool call gets a matching tool message, including rejected ones.
            messages.append({"role": "tool", "tool_call_id": tc.id, "content": json.dumps(out)})
        if result is not None:
            break
    ret = {"forecast": result, "tool_calls": calls, "n_api_calls": n_api, "usage": usage}
    if return_trace:
        ret["messages"] = messages
    return ret