Spaces:
Paused
Paused
| import pytest | |
| from tests.utils import wrap_test_forked, get_llama | |
| from src.enums import DocumentSubset | |
def test_cli(monkeypatch):
    """Run one CLI turn against the 'gptj' model with scoring disabled.

    ``builtins.input`` is patched to return a fixed question, so the CLI
    never blocks on stdin. The test expects exactly one generation back and
    checks that it contains a known sentence about the Earth.
    """
    question = "What is the Earth?"
    monkeypatch.setattr('builtins.input', lambda _: question)

    # Local import: src.gen is only loaded when this test actually runs.
    from src.gen import main

    generations = main(
        base_model='gptj',
        cli=True,
        cli_loop=False,
        score_model='None',
    )

    assert len(generations) == 1
    answer = generations[0]
    assert "The Earth is a planet in our solar system" in answer
# NOTE(review): wrap_test_forked is imported at the top of this file but is not
# applied to any test here. If it is the per-test process-isolation decorator
# its name suggests, it likely belongs on this test as well -- confirm.
@pytest.mark.parametrize("base_model", ["gptj", "gpt4all_llama"])
def test_cli_langchain(base_model, monkeypatch):
    """Ask the CLI about the bundled test documents via the 'UserData' collection.

    ``base_model`` is supplied by the parametrize decorator. No fixture of that
    name is defined in this module, so without the decorator pytest reports
    "fixture 'base_model' not found" before the body runs.

    Expects exactly one generation that mentions the test image's file name
    and contains at least one of several phrasings of an answer about the cat.
    """
    from tests.utils import make_user_path_test
    user_path = make_user_path_test()

    query = "What is the cat doing?"
    # Feed the fixed question to the CLI instead of reading stdin.
    monkeypatch.setattr('builtins.input', lambda _: query)

    from src.gen import main
    all_generations = main(base_model=base_model, cli=True, cli_loop=False, score_model='None',
                           langchain_mode='UserData',
                           user_path=user_path,
                           langchain_modes=['UserData', 'MyData'],
                           document_subset=DocumentSubset.Relevant.name,
                           verbose=True)
    print(all_generations)
    assert len(all_generations) == 1
    response = all_generations[0]

    # The generation is expected to name the test image file (presumably as a
    # cited source).
    assert "pexels-evg-kowalievska-1170986_small.jpg" in response

    # Model output varies between runs and models; accept any known-good phrasing.
    expected_phrases = (
        "looking out the window",
        "staring out the window at the city skyline",
        "what the cat is doing",
        "question about a cat",
        "The prompt asks for an answer to a question",
        "The prompt asks what the cat in the scenario is doing",
        "The prompt asks why H2O.ai",
    )
    assert any(phrase in response for phrase in expected_phrases), response
def test_cli_langchain_llamacpp(monkeypatch):
    """Ask the llama.cpp-backed CLI about the bundled test documents.

    Uses the prompt type and model path returned by ``get_llama()``. Expects
    exactly one generation that mentions the test image's file name and
    contains at least one accepted phrasing of an answer about the cat.
    """
    prompt_type, full_path = get_llama()

    from tests.utils import make_user_path_test
    user_path = make_user_path_test()

    query = "What is the cat doing?"
    # Feed the fixed question to the CLI instead of reading stdin.
    monkeypatch.setattr('builtins.input', lambda _: query)

    from src.gen import main
    all_generations = main(base_model='llama', cli=True, cli_loop=False, score_model='None',
                           langchain_mode='UserData',
                           model_path_llama=full_path,
                           prompt_type=prompt_type,
                           user_path=user_path,
                           langchain_modes=['UserData', 'MyData'],
                           document_subset=DocumentSubset.Relevant.name,
                           verbose=True)
    print(all_generations)
    assert len(all_generations) == 1
    response = all_generations[0]

    assert "pexels-evg-kowalievska-1170986_small.jpg" in response

    # "The cat is sitting on a window" already covers the longer
    # "... window seat and looking out [the window]" variants, so those were
    # dropped as redundant; the accepted set of outputs is unchanged.
    expected_phrases = (
        "The cat is sitting on a window",
        "staring out the window at the city skyline",
        "The cat is likely relaxing and enjoying",
        "cat in the image is",
        "The cat is sitting and looking out the window at the view of the city outside.",
    )
    assert any(phrase in response for phrase in expected_phrases), response
def test_cli_llamacpp(monkeypatch):
    """Run one plain (langchain 'Disabled') CLI turn on the llama.cpp model.

    The model is configured from ``get_llama()``. No user document path is
    given. The single generation must contain one of a few known
    self-introduction phrasings.
    """
    prompt_type, full_path = get_llama()

    question = "Who are you?"
    monkeypatch.setattr('builtins.input', lambda _: question)

    from src.gen import main

    mode = 'Disabled'
    run_kwargs = dict(
        base_model='llama',
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode=mode,
        prompt_type=prompt_type,
        model_path_llama=full_path,
        user_path=None,
        langchain_modes=[mode],
        document_subset=DocumentSubset.Relevant.name,
        verbose=True,
    )
    generations = main(**run_kwargs)
    print(generations)

    assert len(generations) == 1
    reply = generations[0]
    acceptable = (
        "I'm a software engineer with a passion for building scalable",
        "how can I assist",
        "am a virtual assistant",
        "My name is John.",
    )
    assert any(snippet in reply for snippet in acceptable)
def test_cli_h2ogpt(monkeypatch):
    """Run one CLI turn against the h2oGPT 6.9B OIG/OASST1 model.

    Scoring is disabled and ``builtins.input`` is patched to a fixed question.
    Exactly one generation is expected, and it must describe the Earth using
    one of two known phrasings.
    """
    question = "What is the Earth?"
    monkeypatch.setattr('builtins.input', lambda _: question)

    from src.gen import main

    model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
    generations = main(base_model=model_name, cli=True, cli_loop=False,
                       score_model='None')

    assert len(generations) == 1
    answer = generations[0]
    acceptable = (
        "The Earth is a planet in the Solar System",
        "The Earth is the third planet",
    )
    assert any(snippet in answer for snippet in acceptable)
def test_cli_langchain_h2ogpt(monkeypatch):
    """Ask the h2oGPT 6.9B CLI about the bundled test documents ('UserData').

    Exactly one generation is expected. It must mention the test image's file
    name and describe the cat looking or staring out of the window.
    """
    from tests.utils import make_user_path_test
    docs_path = make_user_path_test()

    question = "What is the cat doing?"
    monkeypatch.setattr('builtins.input', lambda _: question)

    from src.gen import main

    generations = main(
        base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b',
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode='UserData',
        user_path=docs_path,
        langchain_modes=['UserData', 'MyData'],
        document_subset=DocumentSubset.Relevant.name,
        verbose=True,
    )
    print(generations)

    assert len(generations) == 1
    answer = generations[0]
    assert "pexels-evg-kowalievska-1170986_small.jpg" in answer
    acceptable = (
        "looking out the window",
        "staring out the window at the city skyline",
    )
    assert any(snippet in answer for snippet in acceptable)