File size: 3,915 Bytes
0157ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""NVIDIA NIM settings (fixed values, no env config)."""

from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator

from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS


class NimSettings(BaseModel):
    """Fixed NVIDIA NIM settings (not configurable via env).

    Every field has a default, so ``NimSettings()`` is always valid. The
    ``mode="before"`` validators below run ahead of Pydantic's own type
    and range checks. They map ``None`` / ``""`` to the field default (or
    to ``None`` for optional fields) and coerce other inputs to the
    declared type. Range constraints declared via ``Field(ge=..., le=...)``
    are then enforced by Pydantic on the coerced value.

    Unknown keyword arguments are rejected (``extra="forbid"``).
    """

    temperature: float = Field(
        1.0, ge=0.0, le=2.0, description="Sampling temperature, must be >=0 and <=2."
    )
    top_p: float = Field(
        1.0, ge=0.0, le=1.0, description="Nucleus sampling probability. [0,1]"
    )
    # Accepts -1 or any non-negative int (enforced in ``validate_top_k``).
    # NOTE(review): -1 presumably means "top-k disabled" -- confirm against
    # the NIM API documentation.
    top_k: int = -1
    max_tokens: int = Field(
        ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
        ge=1,
        description="Maximum number of tokens in output.",
    )
    presence_penalty: float = Field(0.0, ge=-2.0, le=2.0)
    frequency_penalty: float = Field(0.0, ge=-2.0, le=2.0)
    min_p: float = Field(
        0.0, ge=0.0, le=1.0, description="Minimum probability threshold [0,1]."
    )
    repetition_penalty: float = Field(
        1.0, ge=0.0, description="Penalty for repeated tokens. Must be >=0."
    )
    seed: int | None = None
    stop: str | None = None
    parallel_tool_calls: bool = True
    ignore_eos: bool = False
    min_tokens: int = Field(0, ge=0, description="Minimum tokens in the response.")
    chat_template: str | None = None
    request_id: str | None = None

    model_config = ConfigDict(extra="forbid")

    @field_validator("top_k", mode="before")
    @classmethod
    def validate_top_k(cls, v, info: ValidationInfo) -> int:
        """Coerce ``top_k`` to an int, allowing only -1 or a value >= 0.

        ``None`` and ``""`` fall back to the default of -1.

        Raises:
            ValueError: If ``v`` cannot be converted to an int, or if the
                converted value is below -1.
        """
        if v is None or v == "":
            return -1
        try:
            int_v = int(v)
        except (TypeError, ValueError) as err:
            # Pydantic v2 only turns ValueError/AssertionError into a
            # ValidationError; a bare TypeError from int() would escape
            # model validation, so normalise it like the other validators.
            raise ValueError(
                f"{info.field_name} must be an int. Got {type(v).__name__}."
            ) from err
        # Range check stays outside the try so its message is not re-wrapped.
        if int_v < -1:
            raise ValueError(f"{info.field_name} must be -1 or >= 0")
        return int_v

    @field_validator(
        "temperature",
        "top_p",
        "min_p",
        "presence_penalty",
        "frequency_penalty",
        "repetition_penalty",
        mode="before",
    )
    @classmethod
    def validate_float_fields(cls, v, info: ValidationInfo) -> float:
        """Coerce a float-typed sampling field, defaulting blank input.

        ``None`` and ``""`` are replaced by the per-field default listed
        below (these mirror the defaults declared on the fields).

        Raises:
            ValueError: If ``v`` cannot be converted to a float.
        """
        field_defaults = {
            "temperature": 1.0,
            "top_p": 1.0,
            "min_p": 0.0,
            "presence_penalty": 0.0,
            "frequency_penalty": 0.0,
            "repetition_penalty": 1.0,
        }
        if v is None or v == "":
            # field_name is typed Optional; fall back to a known key.
            key = info.field_name or "temperature"
            return field_defaults.get(key, 1.0)
        try:
            val = float(v)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"{info.field_name} must be a float. Got {type(v).__name__}."
            ) from err
        return val

    @field_validator("max_tokens", "min_tokens", mode="before")
    @classmethod
    def validate_int_fields(cls, v, info: ValidationInfo) -> int:
        """Coerce a token-count field to an int, defaulting blank input.

        ``None`` and ``""`` are replaced by the per-field default listed
        below (these mirror the defaults declared on the fields).

        Raises:
            ValueError: If ``v`` cannot be converted to an int.
        """
        field_defaults = {
            "max_tokens": ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
            "min_tokens": 0,
        }
        if v is None or v == "":
            # field_name is typed Optional; fall back to a known key.
            key = info.field_name or "max_tokens"
            return field_defaults.get(key, ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS)
        try:
            val = int(v)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"{info.field_name} must be an int. Got {type(v).__name__}."
            ) from err
        return val

    @field_validator("seed", mode="before")
    @classmethod
    def parse_optional_int(cls, v, info: ValidationInfo) -> int | None:
        """Coerce ``seed`` to an int, mapping ``None`` / ``""`` to ``None``.

        Raises:
            ValueError: If a non-blank ``v`` cannot be converted to an int.
        """
        if v == "" or v is None:
            return None
        try:
            return int(v)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"{info.field_name} must be an int or empty/None."
            ) from err

    @field_validator("stop", "chat_template", "request_id", mode="before")
    @classmethod
    def parse_optional_str(cls, v, info: ValidationInfo) -> str | None:
        """Normalise an optional string field.

        An empty string becomes ``None``; ``None`` and ``str`` values pass
        through unchanged; any other value is converted with ``str()``.
        """
        if v == "":
            return None
        if v is not None and not isinstance(v, str):
            return str(v)
        return v