File size: 2,917 Bytes
7d2fea2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import re
from pathlib import Path
from typing import TypeVar

from openai import OpenAI
from pydantic import BaseModel, ValidationError

from models.config import ModelSettings

T = TypeVar("T", bound=BaseModel)


class LLMClient:
    """Small wrapper around an OpenAI-SDK chat-completions client.

    The client is built from ``settings.base_url`` and ``settings.api_key``;
    the model name, temperature and token limit for every request are read
    from the same ``settings`` object.  When ``debug`` is true and
    ``debug_dir`` is set, each raw response is also written to a text file
    in that directory.
    """

    def __init__(self, settings: ModelSettings, debug: bool = False, debug_dir: Path | None = None):
        """Create the underlying ``OpenAI`` client and remember the options.

        Args:
            settings: Model configuration; ``base_url``, ``api_key``, ``name``,
                ``temperature`` and ``max_tokens`` are read from it.
            debug: Enable dumping of raw responses to ``debug_dir``.
            debug_dir: Directory for the dumps; nothing is written if ``None``.
        """
        self.client = OpenAI(base_url=settings.base_url, api_key=settings.api_key)
        self.settings = settings
        self.debug = debug
        self.debug_dir = debug_dir

    def chat(self, messages: list[dict], schema: type[BaseModel] | None = None) -> str:
        """Send ``messages`` to the model and return the reply text.

        Args:
            messages: Chat messages passed straight to the completions API.
            schema: Optional pydantic model; forwarded to ``_extra_body`` so
                its JSON schema can be attached to the request.

        Returns:
            The content of the first choice, or ``""`` when it is empty/None.
        """
        extra_body = _extra_body(self.settings, schema)
        response = self.client.chat.completions.create(
            model=self.settings.name,
            messages=messages,
            temperature=self.settings.temperature,
            max_tokens=self.settings.max_tokens,
            extra_body=extra_body,
        )
        content = response.choices[0].message.content or ""
        if self.debug and self.debug_dir:
            self.debug_dir.mkdir(parents=True, exist_ok=True)
            # Dumps are numbered by how many already exist in the directory.
            count = len(list(self.debug_dir.glob("llm_response_*.txt")))
            (self.debug_dir / f"llm_response_{count + 1}.txt").write_text(content, encoding="utf-8")
        return content

    def chat_structured(self, messages: list[dict], schema: type[T], max_attempts: int = 2) -> T:
        """Ask the model for JSON and validate it against ``schema``.

        Each reply has a surrounding Markdown fence removed and is then parsed
        with ``schema.model_validate_json``.  On a validation failure the
        model's own reply and a corrective user message are appended to a
        private copy of the conversation (the caller's list is not mutated)
        and the request is retried.

        Args:
            messages: Initial chat messages.
            schema: Pydantic model class the reply must validate against.
            max_attempts: Total number of requests to make (default ``2``,
                i.e. one retry).

        Returns:
            A validated instance of ``schema``.

        Raises:
            ValidationError: The last validation error if every attempt fails.
            ValueError: If no attempt was made (``max_attempts`` < 1).
        """
        last_error: Exception | None = None
        working_messages = list(messages)
        for attempt in range(max_attempts):
            raw = self.chat(working_messages, schema=schema)
            cleaned = _strip_markdown_fences(raw)
            try:
                return schema.model_validate_json(cleaned)
            except ValidationError as exc:
                last_error = exc
            if attempt + 1 >= max_attempts:
                # No further request will be sent, so there is nothing to append.
                break
            # Echo the model's reply back as an assistant turn so that "the
            # previous response" in the correction actually exists in the
            # conversation and user/assistant turns keep alternating.
            working_messages.append({"role": "assistant", "content": raw})
            working_messages.append(
                {
                    "role": "user",
                    "content": (
                        "The previous response did not match the required JSON schema. "
                        f"Validation error: {last_error}. Return only valid JSON."
                    ),
                }
            )
        raise last_error or ValueError("Structured LLM response could not be parsed")


def _extra_body(settings: ModelSettings, schema: type[BaseModel] | None = None) -> dict | None:
    """Build the extra request-body fields for a chat completion call.

    Args:
        settings: Model configuration; only ``settings.name`` is inspected
            (case-insensitively).
        schema: Optional pydantic model class.  When given, its JSON schema
            (``schema.model_json_schema()``) is stored under ``"format"``.

    Returns:
        A dict of extra fields, or ``None`` when there is nothing to add.
        Model names starting with ``"qwen"`` get ``top_k = 20``; those names
        and any name containing ``"nemotron"`` get
        ``chat_template_kwargs = {"enable_thinking": False}``.
    """
    payload: dict = {}
    if schema:
        payload["format"] = schema.model_json_schema()

    lowered = settings.name.lower()
    is_qwen = lowered.startswith("qwen")
    if is_qwen:
        payload["top_k"] = 20
    if is_qwen or "nemotron" in lowered:
        payload["chat_template_kwargs"] = {"enable_thinking": False}

    return payload if payload else None


def _strip_markdown_fences(text: str) -> str:
    """Return ``text`` without a Markdown code fence that wraps all of it.

    The input is whitespace-trimmed first.  If the whole trimmed text is a
    single triple-backtick fence (optionally tagged ``json``, any case), the
    trimmed fence contents are returned; otherwise the trimmed text is
    returned unchanged.
    """
    body = text.strip()
    fenced = re.fullmatch(
        r"```(?:json)?\s*(.*?)\s*```",
        body,
        flags=re.DOTALL | re.IGNORECASE,
    )
    if fenced is None:
        return body
    return fenced.group(1).strip()