Spaces:
Runtime error
Runtime error
| """Test LLM chain.""" | |
| from tempfile import TemporaryDirectory | |
| from typing import Dict, List, Union | |
| from unittest.mock import patch | |
| import pytest | |
| from langchain.chains.llm import LLMChain | |
| from langchain.chains.loading import load_chain | |
| from langchain.prompts.base import BaseOutputParser | |
| from langchain.prompts.prompt import PromptTemplate | |
| from tests.unit_tests.llms.fake_llm import FakeLLM | |
class FakeOutputParser(BaseOutputParser):
    """Output parser for tests: tokenizes the completion on whitespace."""

    def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
        """Split *text* on whitespace and return the resulting token list."""
        tokens = text.split()
        return tokens
@pytest.fixture
def fake_llm_chain() -> LLMChain:
    """Pytest fixture: an ``LLMChain`` over a ``FakeLLM``.

    Every test below requests ``fake_llm_chain`` as a parameter, so this
    must be registered as a fixture; without the ``@pytest.fixture``
    decorator pytest reports "fixture 'fake_llm_chain' not found".

    Returns:
        An ``LLMChain`` with a single input variable ``bar`` and the
        output key ``text1``.
    """
    prompt = PromptTemplate(input_variables=["bar"], template="This is a {bar}:")
    return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1")
def test_serialization(fake_llm_chain: LLMChain) -> None:
    """Round-trip the chain through a JSON file and check equality."""
    with TemporaryDirectory() as temp_dir:
        path = temp_dir + "/llm.json"
        fake_llm_chain.save(path)
        restored = load_chain(path)
        assert restored == fake_llm_chain
def test_missing_inputs(fake_llm_chain: LLMChain) -> None:
    """Calling the chain without its required input raises ValueError."""
    bad_inputs = {"foo": "bar"}
    with pytest.raises(ValueError):
        fake_llm_chain(bad_inputs)
def test_valid_call(fake_llm_chain: LLMChain) -> None:
    """A valid call merges the inputs with the generated output."""
    result = fake_llm_chain({"bar": "baz"})
    assert result == {"bar": "baz", "text1": "foo"}

    # With a stop sequence the fake LLM responds with `bar` instead.
    result = fake_llm_chain({"bar": "baz", "stop": ["foo"]})
    assert result == {"bar": "baz", "stop": ["foo"], "text1": "bar"}
def test_predict_method(fake_llm_chain: LLMChain) -> None:
    """``predict`` returns only the output string, not the input dict."""
    assert fake_llm_chain.predict(bar="baz") == "foo"
def test_predict_and_parse() -> None:
    """``predict_and_parse`` runs the prompt's output parser on the completion."""
    parser = FakeOutputParser()
    prompt = PromptTemplate(
        input_variables=["foo"], template="{foo}", output_parser=parser
    )
    fake_llm = FakeLLM(queries={"foo": "foo bar"})
    chain = LLMChain(prompt=prompt, llm=fake_llm)
    parsed = chain.predict_and_parse(foo="foo")
    assert parsed == ["foo", "bar"]