"""Local Hugging Face LLM wrapper with helpers for validated JSON output."""

from __future__ import annotations

import json
import re
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple

# ``transformers`` is heavy and only needed to construct ``LocalLLM``.  Guard
# the import so the pure-Python JSON helpers below remain importable (and
# testable) without it; the original failure is re-raised from
# ``LocalLLM.__init__``.
_PIPELINE_IMPORT_ERROR: Optional[ImportError] = None
try:
    from transformers import pipeline
except ImportError as exc:
    pipeline = None
    _PIPELINE_IMPORT_ERROR = exc

# Greedy last-resort matcher: spans from the first "{" to the last "}".
_JSON_RE = re.compile(r"\{.*\}", re.DOTALL)

# A JSON object wrapped in a Markdown code fence, optionally tagged "json".
_FENCE_RE = re.compile(
    r"```(?:json)?\s*(\{.*?\})\s*```",
    flags=re.DOTALL | re.IGNORECASE,
)

_DECODER = json.JSONDecoder()

# Retry budget used when neither the caller nor ``llm.config`` supplies one.
_DEFAULT_MAX_RETRIES = 3


@dataclass(frozen=True)
class LLMConfig:
    """Immutable generation settings for :class:`LocalLLM`."""

    # Hugging Face model id passed to ``pipeline(model=...)``.
    model_name: str = "google/flan-t5-base"
    # Upper bound on generated tokens per call.
    max_new_tokens: int = 512
    # Sampling temperature; values <= 0 select greedy decoding.
    temperature: float = 0.2
    # Nucleus-sampling threshold; only used when sampling is enabled.
    top_p: float = 0.95
    # Default attempt budget for ``generate_validated_json``.
    max_retries: int = 3


class LocalLLM:
    """
    Local HF LLM wrapper.

    Uses the ``text2text-generation`` pipeline (T5-style instruction models).
    Designed for stable structured JSON output.
    """

    def __init__(self, config: Optional[LLMConfig] = None):
        """Build the generation pipeline for ``config`` (defaults apply if None).

        Raises:
            ImportError: if ``transformers`` could not be imported.
        """
        if pipeline is None:
            raise ImportError(
                "LocalLLM requires the 'transformers' package."
            ) from _PIPELINE_IMPORT_ERROR

        self.config = config or LLMConfig()
        self.pipe = pipeline(
            "text2text-generation",
            model=self.config.model_name,
        )

    def generate(self, prompt: str) -> str:
        """Run the model on ``prompt`` and return the stripped generated text.

        Sampling (``temperature`` / ``top_p``) is used only when the configured
        temperature is positive; otherwise decoding is greedy, because a
        non-positive temperature is not valid together with ``do_sample=True``.
        """
        cfg = self.config
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": cfg.max_new_tokens,
            "num_return_sequences": 1,
        }
        if cfg.temperature > 0:
            gen_kwargs.update(
                do_sample=True,
                temperature=cfg.temperature,
                top_p=cfg.top_p,
            )
        else:
            # Greedy decoding: sampling knobs are deliberately not passed.
            gen_kwargs["do_sample"] = False

        outputs = self.pipe(prompt, **gen_kwargs)
        return outputs[0]["generated_text"].strip()


def _first_json_object(text: str) -> Optional[str]:
    """Return the first substring of ``text`` that parses as a JSON object."""
    start = text.find("{")
    while start != -1:
        try:
            _, end = _DECODER.raw_decode(text, start)
        except (ValueError, RecursionError):
            # Not a valid object at this brace; try the next one.
            start = text.find("{", start + 1)
        else:
            return text[start:end]
    return None


def extract_json(text: str) -> Optional[str]:
    """
    Extract the first JSON object from model output.

    Handles cases where the model wraps JSON in explanations. Candidates are
    tried in this order:

    1. the first object inside a Markdown code fence;
    2. the first substring that actually parses as a JSON object;
    3. the greedy span from the first ``{`` to the last ``}``, returned even
       though it does not parse so the caller can report a parse error.

    Returns:
        The candidate JSON text, or ``None`` when ``text`` has no brace pair.
    """
    text = text.strip()

    if "```" in text:
        blocks = _FENCE_RE.findall(text)
        if blocks:
            return blocks[0].strip()

    candidate = _first_json_object(text)
    if candidate is not None:
        return candidate

    match = _JSON_RE.search(text)
    if not match:
        return None
    return match.group(0).strip()


def safe_json_loads(s: str) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Parse ``s`` as a JSON object without raising.

    Returns:
        ``(data, None)`` on success, or ``(None, message)`` when ``s`` is not
        valid JSON or its top-level value is not an object.  Exactly one
        element of the pair is ``None``.
    """
    try:
        data = json.loads(s)
    except (ValueError, TypeError, RecursionError) as exc:
        return None, str(exc)

    if not isinstance(data, dict):
        # e.g. "null" or "[1, 2]": valid JSON, but not the promised dict.
        return None, f"Expected a JSON object, got {type(data).__name__}"
    return data, None


def generate_validated_json(
    llm: LocalLLM,
    prompt: str,
    validator_fn: Callable[[Dict[str, Any]], Any],
    max_retries: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Generate JSON with retries.

    Args:
        llm: object exposing ``generate(prompt) -> str``.
        prompt: prompt sent unchanged on every attempt.
        validator_fn: called with the parsed object; must raise on invalid
            JSON or invalid schema.  Its return value is ignored.
        max_retries: maximum number of attempts.  When ``None``, uses
            ``llm.config.max_retries`` if available, otherwise 3.

    Returns:
        The first parsed object accepted by ``validator_fn``.

    Raises:
        RuntimeError: if no attempt produced a valid object.  When the last
            failure came from ``validator_fn``, that exception is chained as
            the cause.
    """
    if max_retries is None:
        # getattr keeps duck-typed ``llm`` objects without ``.config`` working.
        max_retries = getattr(
            getattr(llm, "config", None), "max_retries", _DEFAULT_MAX_RETRIES
        )

    last_error: Optional[str] = None
    last_raw: Optional[str] = None
    last_exc: Optional[BaseException] = None

    for attempt in range(1, max_retries + 1):
        raw = llm.generate(prompt)
        last_raw = raw
        last_exc = None

        js = extract_json(raw)
        if js is None:
            last_error = f"No JSON found in output (attempt {attempt}). Raw: {raw[:200]}..."
            continue

        data, err = safe_json_loads(js)
        if data is None:
            last_error = f"JSON parse error (attempt {attempt}): {err}"
            continue

        try:
            validator_fn(data)
        except Exception as exc:  # validators may raise any exception type
            last_error = f"Schema validation error (attempt {attempt}): {exc}"
            last_exc = exc
            continue
        return data

    raise RuntimeError(
        "Failed to generate valid JSON after retries.\n"
        f"Last error: {last_error}\n"
        f"Last raw output: {last_raw}"
    ) from last_exc