Spaces:
Runtime error
Runtime error
| """Test pipeline functionality.""" | |
| from typing import Dict, List | |
| import pytest | |
| from pydantic import BaseModel | |
| from langchain.chains.base import Chain | |
| from langchain.chains.sequential import SequentialChain, SimpleSequentialChain | |
class FakeChain(Chain, BaseModel):
    """Fake Chain for testing purposes.

    Every output variable is set to the values of ``input_variables``
    (in that order) joined by single spaces, followed by the suffix ``"foo"``.
    """

    # Names of the keys this chain reads from its input dict.
    input_variables: List[str]
    # Names of the keys this chain writes; all of them receive the same value.
    output_variables: List[str]

    @property
    def input_keys(self) -> List[str]:
        """Input keys this chain expects."""
        # Must be a property: chain composition reads ``chain.input_keys``
        # as an attribute (e.g. wraps it in ``set(...)``), not as a call.
        return self.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Output keys this chain returns."""
        return self.output_variables

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Map every output variable to ``"<space-joined inputs>foo"``.

        Args:
            inputs: Mapping that contains every key listed in
                ``input_variables``.

        Returns:
            Dict with one entry per output variable, all holding the same
            string.

        Raises:
            KeyError: If a key from ``input_variables`` is missing in
                ``inputs``.
        """
        # The value does not depend on the output name, so compute it once.
        joined = " ".join(inputs[k] for k in self.input_variables)
        return {var: f"{joined}foo" for var in self.output_variables}
def test_sequential_usage_single_inputs() -> None:
    """A two-step pipeline threads a single input through both steps."""
    first = FakeChain(input_variables=["foo"], output_variables=["bar"])
    second = FakeChain(input_variables=["bar"], output_variables=["baz"])
    pipeline = SequentialChain(chains=[first, second], input_variables=["foo"])
    # Inputs are echoed back alongside the final step's output.
    assert pipeline({"foo": "123"}) == {"baz": "123foofoo", "foo": "123"}
def test_sequential_usage_multiple_inputs() -> None:
    """Steps that consume several inputs receive all of them in order."""
    first = FakeChain(input_variables=["foo", "test"], output_variables=["bar"])
    second = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"])
    pipeline = SequentialChain(
        chains=[first, second], input_variables=["foo", "test"]
    )
    result = pipeline({"foo": "123", "test": "456"})
    assert result == {
        "foo": "123",
        "test": "456",
        "baz": "123 456foo 123foo",
    }
def test_sequential_usage_multiple_outputs() -> None:
    """Only the last step's outputs (plus inputs) are returned by default."""
    first = FakeChain(input_variables=["foo"], output_variables=["bar", "test"])
    second = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"])
    pipeline = SequentialChain(chains=[first, second], input_variables=["foo"])
    result = pipeline({"foo": "123"})
    assert result == {"foo": "123", "baz": "123foo 123foo"}
def test_sequential_missing_inputs() -> None:
    """Construction fails when a step needs an input nobody provides."""
    upstream = FakeChain(input_variables=["foo"], output_variables=["bar"])
    downstream = FakeChain(
        input_variables=["bar", "test"], output_variables=["baz"]
    )
    steps = [upstream, downstream]
    with pytest.raises(ValueError):
        # "test" is neither a declared input nor produced by an earlier step.
        SequentialChain(chains=steps, input_variables=["foo"])
def test_sequential_bad_outputs() -> None:
    """Requesting an output that no step produces is rejected."""
    upstream = FakeChain(input_variables=["foo"], output_variables=["bar"])
    downstream = FakeChain(input_variables=["bar"], output_variables=["baz"])
    steps = [upstream, downstream]
    with pytest.raises(ValueError):
        # Only "bar" and "baz" are ever produced, never "test".
        SequentialChain(
            chains=steps, input_variables=["foo"], output_variables=["test"]
        )
def test_sequential_valid_outputs() -> None:
    """Explicitly requested intermediate and final outputs are returned."""
    upstream = FakeChain(input_variables=["foo"], output_variables=["bar"])
    downstream = FakeChain(input_variables=["bar"], output_variables=["baz"])
    pipeline = SequentialChain(
        chains=[upstream, downstream],
        input_variables=["foo"],
        output_variables=["bar", "baz"],
    )
    # With return_only_outputs the original "foo" input is not echoed back.
    result = pipeline({"foo": "123"}, return_only_outputs=True)
    assert result == {"bar": "123foo", "baz": "123foofoo"}
def test_sequential_overlapping_inputs() -> None:
    """A name may not be both a pipeline input and a step output."""
    upstream = FakeChain(input_variables=["foo"], output_variables=["bar", "test"])
    downstream = FakeChain(input_variables=["bar"], output_variables=["baz"])
    steps = [upstream, downstream]
    with pytest.raises(ValueError):
        # "test" is declared as input while the first step also emits it.
        SequentialChain(chains=steps, input_variables=["foo", "test"])
def test_simple_sequential_functionality() -> None:
    """SimpleSequentialChain pipes one value through under fixed key names."""
    first = FakeChain(input_variables=["foo"], output_variables=["bar"])
    second = FakeChain(input_variables=["bar"], output_variables=["baz"])
    pipeline = SimpleSequentialChain(chains=[first, second])
    assert pipeline({"input": "123"}) == {"input": "123", "output": "123foofoo"}
def test_multi_input_errors() -> None:
    """SimpleSequentialChain rejects a step that takes more than one input."""
    single = FakeChain(input_variables=["foo"], output_variables=["bar"])
    double_input = FakeChain(
        input_variables=["bar", "foo"], output_variables=["baz"]
    )
    with pytest.raises(ValueError):
        SimpleSequentialChain(chains=[single, double_input])
def test_multi_output_errors() -> None:
    """SimpleSequentialChain rejects a step that yields more than one output."""
    double_output = FakeChain(
        input_variables=["foo"], output_variables=["bar", "grok"]
    )
    single = FakeChain(input_variables=["bar"], output_variables=["baz"])
    with pytest.raises(ValueError):
        SimpleSequentialChain(chains=[double_output, single])