""" Local LLM agent via Ollama — for testing only. Called from inference.py when USE_LOCAL=true. """ import json import requests import os from dotenv import load_dotenv load_dotenv() from models import WhyDidItFailAction # LOCAL_MODEL = "tinyllama" LOCAL_MODEL = os.getenv("LOCAL_MODEL", "phi3") LOCAL_URL = os.getenv("LOCAL_URL", "http://127.0.0.1:11434/api/generate") def _call(prompt: str) -> str: res = requests.post( LOCAL_URL, json={"model": LOCAL_MODEL, "prompt": prompt, "stream": False, "format": "json"}, ) return res.json()["response"] def _extract_json(text: str) -> dict: """Extract the first JSON object from the response. Handles three common small-model output patterns: - Pure JSON object: {"action_type": ...} - JSON wrapped in prose: Sure! Here: {"action_type": ...} - JSON wrapped in a list: [{"action_type": ...}] Uses bracket counting so nested objects don't break the extraction. """ # Try parsing the whole text first (clean output case) try: parsed = json.loads(text) if isinstance(parsed, dict): return parsed if isinstance(parsed, list) and parsed and isinstance(parsed[0], dict): return parsed[0] except json.JSONDecodeError: pass # Fall back: find first { and walk balanced braces start = text.find("{") if start == -1: raise ValueError(f"No JSON object found in response: {text!r}") depth = 0 for i, ch in enumerate(text[start:], start): if ch == "{": depth += 1 elif ch == "}": depth -= 1 if depth == 0: return json.loads(text[start : i + 1]) raise ValueError(f"Unbalanced braces in response: {text!r}") _KNOWN_FIELDS = set(WhyDidItFailAction.model_fields) def get_action(step: int, prompt: str) -> WhyDidItFailAction: """Call the local LLM and parse the response into a WhyDidItFailAction.""" text = "" try: text = _call(prompt).strip() data = _extract_json(text) filtered = {k: v for k, v in data.items() if k in _KNOWN_FIELDS} # phi3 sometimes echoes the feedback signal instead of an action, e.g.: # {"feedback": "...", "source_to_investigate": "inspect_config"} # Recover action_type from source_to_investigate when possible. if "action_type" not in filtered or filtered.get("action_type") is None: _valid_actions = {"inspect_logs", "inspect_config", "inspect_gradients", "submit_diagnosis"} src = data.get("source_to_investigate", "") if isinstance(src, str) and src in _valid_actions: filtered["action_type"] = src else: raise ValueError(f"action_type missing in parsed output: {data}") return WhyDidItFailAction(**filtered) except Exception as exc: print(f" [LOCAL] parse failed (step {step}): {exc} | raw: {text!r}", flush=True) # Step-based progression: avoid re-inspecting the same source. from typing import cast, Literal _fallback = ["inspect_logs", "inspect_config", "inspect_gradients", "submit_diagnosis"] action_type = cast( Literal["inspect_logs", "inspect_config", "inspect_gradients", "submit_diagnosis"], _fallback[min(step - 1, len(_fallback) - 1)], ) diagnosis = "unknown" if action_type == "submit_diagnosis" else None return WhyDidItFailAction(action_type=action_type, diagnosis=diagnosis, suggested_fix=None, reasoning=None)