File size: 3,593 Bytes
e8c7211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26630c7
e8c7211
26630c7
 
 
 
e8c7211
26630c7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""
Local LLM agent via Ollama — for testing only.
Called from inference.py when USE_LOCAL=true.
"""

import json
import requests
import os
from dotenv import load_dotenv

load_dotenv()

from models import WhyDidItFailAction

# LOCAL_MODEL = "tinyllama"
LOCAL_MODEL = os.getenv("LOCAL_MODEL", "phi3")
LOCAL_URL   = os.getenv("LOCAL_URL", "http://127.0.0.1:11434/api/generate")


def _call(prompt: str, timeout: float = 120.0) -> str:
    """POST *prompt* to the local Ollama generate endpoint and return the raw text.

    Sends a non-streaming request with ``format: json`` so the model is nudged
    toward emitting a JSON object (parsing happens in ``_extract_json``).

    Args:
        prompt: Full prompt text forwarded verbatim to the model.
        timeout: Seconds to wait for the HTTP response. ``requests`` has no
            default timeout, so without one a dead/hung server blocks forever.

    Returns:
        The model's ``response`` field as a string.

    Raises:
        requests.HTTPError: If the server returns a non-2xx status
            (clearer than the ``KeyError: 'response'`` it would otherwise cause).
        requests.Timeout: If the server does not answer within *timeout*.
    """
    res = requests.post(
        LOCAL_URL,
        json={"model": LOCAL_MODEL, "prompt": prompt, "stream": False, "format": "json"},
        timeout=timeout,
    )
    # Fail loudly on HTTP errors; the caller's fallback path handles the exception.
    res.raise_for_status()
    return res.json()["response"]


def _extract_json(text: str) -> dict:
    """Extract the first JSON object from the response.

    Handles three common small-model output patterns:
      - Pure JSON object:        {"action_type": ...}
      - JSON wrapped in prose:   Sure! Here: {"action_type": ...}
      - JSON wrapped in a list:  [{"action_type": ...}]
    Uses bracket counting so nested objects don't break the extraction.
    """
    # Try parsing the whole text first (clean output case)
    try:
        parsed = json.loads(text)
        if isinstance(parsed, dict):
            return parsed
        if isinstance(parsed, list) and parsed and isinstance(parsed[0], dict):
            return parsed[0]
    except json.JSONDecodeError:
        pass

    # Fall back: find first { and walk balanced braces
    start = text.find("{")
    if start == -1:
        raise ValueError(f"No JSON object found in response: {text!r}")

    depth = 0
    for i, ch in enumerate(text[start:], start):
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return json.loads(text[start : i + 1])

    raise ValueError(f"Unbalanced braces in response: {text!r}")


_KNOWN_FIELDS = set(WhyDidItFailAction.model_fields)


def get_action(step: int, prompt: str) -> WhyDidItFailAction:
    """Call the local LLM and parse the response into a WhyDidItFailAction.

    Args:
        step: 1-based step number; on parse failure it selects a deterministic
            fallback action so consecutive failures still make progress.
        prompt: Full prompt text forwarded to the local model.

    Returns:
        A WhyDidItFailAction — parsed from the model output when possible,
        otherwise a step-based fallback. Never raises: any transport or
        parse failure is logged and absorbed by the fallback path.
    """
    # Local import: only needed for type narrowing on the fallback path.
    from typing import Literal, cast

    text = ""
    try:
        text = _call(prompt).strip()
        data = _extract_json(text)
        filtered = {k: v for k, v in data.items() if k in _KNOWN_FIELDS}

        # phi3 sometimes echoes the feedback signal instead of an action, e.g.:
        # {"feedback": "...", "source_to_investigate": "inspect_config"}
        # Recover action_type from source_to_investigate when possible.
        if filtered.get("action_type") is None:
            _valid_actions = {"inspect_logs", "inspect_config", "inspect_gradients", "submit_diagnosis"}
            src = data.get("source_to_investigate", "")
            if isinstance(src, str) and src in _valid_actions:
                filtered["action_type"] = src
            else:
                raise ValueError(f"action_type missing in parsed output: {data}")

        return WhyDidItFailAction(**filtered)
    except Exception as exc:  # deliberate boundary: log and fall back, never crash the loop
        print(f"  [LOCAL] parse failed (step {step}): {exc} | raw: {text!r}", flush=True)
        # Step-based progression: avoid re-inspecting the same source.
        _fallback = ["inspect_logs", "inspect_config", "inspect_gradients", "submit_diagnosis"]
        # Clamp below at 0 so an out-of-range step (e.g. 0) cannot negative-index
        # and skip straight to submit_diagnosis.
        idx = min(max(step - 1, 0), len(_fallback) - 1)
        action_type = cast(
            Literal["inspect_logs", "inspect_config", "inspect_gradients", "submit_diagnosis"],
            _fallback[idx],
        )
        diagnosis = "unknown" if action_type == "submit_diagnosis" else None
        return WhyDidItFailAction(action_type=action_type, diagnosis=diagnosis, suggested_fix=None, reasoning=None)