sky2 / tests /config /test_llm_params.py
JustinTX's picture
Add files using upload-large-folder tool
af83196 verified
"""Tests for LLM config: optional temperature/top_p and api_base routing."""
from dataclasses import fields
from unittest.mock import AsyncMock, patch
import pytest
from skydiscover.config import LLMConfig, LLMModelConfig
# Read the declared default of LLMConfig's ``api_base`` dataclass field so the
# assertions below follow the library's default rather than a hard-coded URL.
_OPENAI_DEFAULT_API_BASE: str = next(
    spec.default for spec in fields(LLMConfig) if spec.name == "api_base"
)
class TestLLMConfigDefaults:
    """Check the temperature/top_p values LLMConfig reports after construction."""

    def test_default_temperature(self):
        """With no temperature argument, the config reports 0.7."""
        config = LLMConfig(name="test-model")
        assert config.temperature == 0.7

    def test_default_top_p_is_none(self):
        """With no top_p argument, the config reports None."""
        config = LLMConfig(name="test-model")
        assert config.top_p is None

    def test_explicit_none_temperature(self):
        """Passing temperature=None is kept as None on the config."""
        config = LLMConfig(name="test-model", temperature=None)
        assert config.temperature is None

    def test_explicit_none_top_p(self):
        """Passing top_p=None is kept as None on the config."""
        config = LLMConfig(name="test-model", top_p=None)
        assert config.top_p is None

    def test_both_none(self):
        """Both sampling parameters may be None at the same time."""
        config = LLMConfig(name="test-model", temperature=None, top_p=None)
        assert config.temperature is None
        assert config.top_p is None
class TestApiBaseRouting:
    """Check the api_base each entry of ``LLMConfig.models`` ends up with."""

    def test_unknown_model_preserves_local_api_base(self):
        """An explicitly supplied api_base reaches a model named 'my-custom-local-model'."""
        local_endpoint = "http://localhost:11434/v1"
        model = LLMModelConfig(name="my-custom-local-model")
        config = LLMConfig(
            name="my-custom-local-model",
            api_base=local_endpoint,
            models=[model],
        )
        assert config.models[0].api_base == local_endpoint

    def test_unknown_model_gets_openai_default(self):
        """Without an explicit api_base, the model gets LLMConfig's declared default."""
        model = LLMModelConfig(name="my-custom-local-model")
        config = LLMConfig(name="my-custom-local-model", models=[model])
        assert config.models[0].api_base == _OPENAI_DEFAULT_API_BASE

    def test_mixed_providers_with_local_api_base(self):
        """The 'anthropic/...' model gets the Anthropic URL; the other keeps the local one."""
        candidates = [
            LLMModelConfig(name="anthropic/claude-3-sonnet"),
            LLMModelConfig(name="my-local-model"),
        ]
        config = LLMConfig(api_base="http://localhost:11434/v1", models=candidates)
        assert config.models[0].api_base == "https://api.anthropic.com/v1/"
        assert config.models[1].api_base == "http://localhost:11434/v1"
class TestOpenAILLMParams:
    """Check which sampling parameters ``OpenAILLM.generate`` passes to ``_call_api``."""

    def _make_llm(self, temperature=0.7, top_p=0.95):
        """Build an OpenAILLM from a throwaway model config.

        ``skydiscover.llm.openai.openai.OpenAI`` is patched while the wrapper
        is constructed; the config points at a localhost URL with a fake key
        and no retries.
        """
        from skydiscover.llm.openai import OpenAILLM

        model_cfg = LLMModelConfig(
            name="test-model",
            temperature=temperature,
            top_p=top_p,
            api_base="http://localhost:1234/v1",
            api_key="fake",
            timeout=10,
            retries=0,
            retry_delay=0,
        )
        with patch("skydiscover.llm.openai.openai.OpenAI"):
            return OpenAILLM(model_cfg)

    async def _captured_params(self, llm, **generate_overrides):
        """Run ``generate`` against a mocked ``_call_api`` and return its first positional arg."""
        llm._call_api = AsyncMock(return_value="response")
        await llm.generate(
            system_message="sys",
            messages=[{"role": "user", "content": "user"}],
            **generate_overrides,
        )
        return llm._call_api.call_args.args[0]

    @pytest.mark.asyncio
    async def test_params_include_temperature_and_top_p(self):
        """Both values appear in the params when both are set."""
        llm = self._make_llm(temperature=0.5, top_p=0.9)
        sent = await self._captured_params(llm, temperature=0.5, top_p=0.9)
        assert sent["temperature"] == 0.5
        assert sent["top_p"] == 0.9

    @pytest.mark.asyncio
    async def test_params_exclude_none_top_p(self):
        """A None top_p is left out of the params; temperature is still present."""
        llm = self._make_llm(top_p=None)
        sent = await self._captured_params(llm)
        assert "top_p" not in sent
        assert "temperature" in sent

    @pytest.mark.asyncio
    async def test_params_exclude_none_temperature(self):
        """A None temperature is left out of the params; top_p is still present."""
        llm = self._make_llm(temperature=None)
        sent = await self._captured_params(llm)
        assert "temperature" not in sent
        assert "top_p" in sent

    @pytest.mark.asyncio
    async def test_params_exclude_both_none(self):
        """Neither key appears in the params when both values are None."""
        llm = self._make_llm(temperature=None, top_p=None)
        sent = await self._captured_params(llm)
        assert "temperature" not in sent
        assert "top_p" not in sent