|
|
from unittest.mock import MagicMock, patch |
|
|
from urllib.parse import urljoin |
|
|
|
|
|
import pytest |
|
|
from langchain_ollama import ChatOllama |
|
|
from langflow.components.models import ChatOllamaComponent |
|
|
|
|
|
|
|
|
@pytest.fixture
def component():
    """Provide a freshly constructed ChatOllamaComponent to each test."""
    instance = ChatOllamaComponent()
    return instance
|
|
|
|
|
|
|
|
@patch("httpx.Client.get")
def test_get_model_success(mock_get, component):
    """get_model should hit <base>/api/tags once and return the model names it lists."""
    # Stub HTTP response: status check passes, body lists two models.
    fake_response = MagicMock()
    fake_response.raise_for_status.return_value = None
    fake_response.json.return_value = {
        "models": [{"name": "model1"}, {"name": "model2"}],
    }
    mock_get.return_value = fake_response

    ollama_url = "http://localhost:11434"
    names = component.get_model(ollama_url)

    # The patched `get` is a plain MagicMock on the class, so it is recorded
    # without `self` — only the URL argument appears in the call.
    mock_get.assert_called_once_with(urljoin(ollama_url, "/api/tags"))
    assert names == ["model1", "model2"]
|
|
|
|
|
|
|
|
@patch("httpx.Client.get")
def test_get_model_failure(mock_get, component):
    """get_model should wrap any HTTP failure in a ValueError.

    Like test_get_model_success, this passes the Ollama *base* URL; get_model
    appends "/api/tags" itself. The assertion on `mock_get` confirms the
    error comes from the attempted request rather than from some earlier step.
    """
    mock_get.side_effect = Exception("HTTP request failed")

    base_url = "http://localhost:11434"

    with pytest.raises(ValueError, match="Could not retrieve models"):
        component.get_model(base_url)

    mock_get.assert_called_once_with(urljoin(base_url, "/api/tags"))
|
|
|
|
|
|
|
|
def test_update_build_config_mirostat_disabled(component):
    """Choosing "Disabled" for mirostat marks both knobs advanced and clears their values."""
    config = {
        "mirostat_eta": {"advanced": False, "value": 0.1},
        "mirostat_tau": {"advanced": False, "value": 5},
    }

    result = component.update_build_config(config, "Disabled", "mirostat")

    knobs = ("mirostat_eta", "mirostat_tau")
    for knob in knobs:
        assert result[knob]["advanced"] is True
    for knob in knobs:
        assert result[knob]["value"] is None
|
|
|
|
|
|
|
|
def test_update_build_config_mirostat_enabled(component):
    """Choosing "Mirostat 2.0" exposes both knobs and fills in eta=0.2, tau=10."""
    config = {
        "mirostat_eta": {"advanced": False, "value": None},
        "mirostat_tau": {"advanced": False, "value": None},
    }

    result = component.update_build_config(config, "Mirostat 2.0", "mirostat")

    eta, tau = result["mirostat_eta"], result["mirostat_tau"]
    assert eta["advanced"] is False
    assert tau["advanced"] is False
    assert eta["value"] == 0.2
    assert tau["value"] == 10
|
|
|
|
|
|
|
|
@patch("httpx.Client.get")
def test_update_build_config_model_name(mock_get, component):
    """Refreshing "model_name" populates its options from the (mocked) Ollama model list."""
    response_stub = MagicMock()
    response_stub.json.return_value = {
        "models": [{"name": "model1"}, {"name": "model2"}],
    }
    response_stub.raise_for_status.return_value = None
    mock_get.return_value = response_stub

    # No base URL is configured here; the component decides which endpoint to query.
    config = {
        "base_url": {"load_from_db": False, "value": None},
        "model_name": {"options": []},
    }

    result = component.update_build_config(config, None, "model_name")

    assert result["model_name"]["options"] == ["model1", "model2"]
|
|
|
|
|
|
|
|
def test_update_build_config_keep_alive(component):
    """The keep-alive flag maps "Keep" -> "-1" and "Immediately" -> "0", hiding the raw field.

    The same config dict is fed through both calls in sequence, so the second
    call sees the state left by the first.
    """
    config = {"keep_alive": {"value": None, "advanced": False}}
    flag_field = "keep_alive_flag"

    for choice, expected_value in (("Keep", "-1"), ("Immediately", "0")):
        result = component.update_build_config(config, choice, flag_field)
        assert result["keep_alive"]["value"] == expected_value
        assert result["keep_alive"]["advanced"] is True
|
|
|
|
|
|
|
|
def test_build_model(component):
    """build_model should return a langchain_ollama ChatOllama bound to the configured endpoint/model.

    The previous version patched ``langchain_community.chat_models.ChatOllama``.
    That is a different class from the ``langchain_ollama.ChatOllama`` this test
    asserts against, so the patch never affected the component. It also forced
    ``langchain_community`` to be importable and built a real ChatOllama at
    module-import time. It has been removed together with its unused mock
    argument.
    """
    component.base_url = "http://localhost:11434"
    component.model_name = "llama3.1"
    component.mirostat = "Mirostat 2.0"
    component.mirostat_eta = 0.2
    component.mirostat_tau = 10.0
    component.temperature = 0.2
    component.verbose = True

    model = component.build_model()

    assert isinstance(model, ChatOllama)
    assert model.base_url == "http://localhost:11434"
    assert model.model == "llama3.1"
|
|
|