import pytest
from unittest.mock import patch
import sys
import os

# Make the `semantic_search` package importable. The path is relative to the
# current working directory, so this presumably assumes pytest is launched
# from the directory that contains `backend/` — confirm.
sys.path.insert(0, 'backend')

# NOTE(review): consider replacing the wildcard with explicit imports of the
# helpers exercised below.
from semantic_search.retrieve import *  # noqa: E402,F403


class TestRetrieveUnitTests:
    """Unit tests for the helpers in ``semantic_search.retrieve``.

    Where a helper calls other functions of the module, the test patches
    them via ``unittest.mock.patch("semantic_search.retrieve.<name>")``,
    so that only the function under test runs.
    """

    def setup_method(self):
        """Build the fixtures shared by several tests (fresh for each test)."""
        # Each entry: a word and the (symptom name, score) pairs it is repeated
        # in. Per test_find_max, the lowest score is treated as the best one.
        self.sample_dict_arr = [
            {'word': 'a', 'symptoms_word_is_repeated_in': [('a1', 0.5), ('a2', 0.1)]},
            {'word': 'b', 'symptoms_word_is_repeated_in': [('b1', 0.2), ('b2', 0.2), ('b3', 0.6)]},
            {'word': 'c', 'symptoms_word_is_repeated_in': [('c1', 0.9), ('c2', 0.5), ('c3', 0.4)]},
        ]
        # (symptom name, distance, metadata) triples: the same shape that
        # get_combined_output returns.
        self.nad = [
            ("cough", 0.1, {"sources": "doc1", "risk": "low"}),
            ("fever", 0.3, {"sources": "doc2", "risk": "medium"}),
            ("headache", 0.2, {"sources": "doc3", "risk": "low"}),
        ]
        # Metadata dicts used as the patched return value of return_context.
        self.context = [
            {'sources': 'doc1', 'risk score': 3, 'rarity': 5},
            {'sources': 'doc2', 'risk score': 6, 'rarity': 4},
            {'sources': 'doc3', 'risk score': 5, 'rarity': 7},
        ]

    def test_get_first_db_column(self):
        """Check a known entry of the first DB column."""
        # NOTE(review): nothing is patched here, so this presumably reads the
        # real data store and depends on its contents and order. It is closer
        # to an integration test than a unit test.
        column_data, _ = get_first_db_column()
        assert column_data[10] == "hoarse voice"

    def test_create_id(self):
        """create_id yields one sequential 'idN' string per input item."""
        sample_col_data = ['a', 'b', 'c', 'd']
        id_list = create_id(sample_col_data)
        assert len(id_list) == len(sample_col_data)
        assert id_list == ['id1', 'id2', 'id3', 'id4']

    def test_get_compiled_list(self):
        """get_compiled_list concatenates image, history and query terms in order."""
        sample_img = ['a', 'b']
        sample_hist = ['c', 'd']
        sample_uq = ['e', 'f']
        compiled_list = get_compiled_list(sample_img, sample_hist, sample_uq)
        assert compiled_list == ['a', 'b', 'c', 'd', 'e', 'f']

        # An empty middle list contributes nothing.
        sample_img = ['a', 'b']
        sample_hist = []
        sample_uq = ['e', 'f']
        compiled_list = get_compiled_list(sample_img, sample_hist, sample_uq)
        assert compiled_list == ['a', 'b', 'e', 'f']

    def test_query_db(self):
        """query_db issues exactly one query against the module's collection."""
        fake_compiled_list = ['a', 'b', 'c']
        with patch("semantic_search.retrieve.collection") as mock_collection:
            query_db(fake_compiled_list)
            mock_collection.query.assert_called_once()

    def test_get_combined_output(self):
        """get_combined_output zips documents, distances and metadatas."""
        fake_input = {
            'documents': [['cough', 'fever']],
            'distances': [[0.2, 0.5]],
            'metadatas': [[{'source': 'doc1'}, {'source': 'doc2'}]],
        }
        output = get_combined_output(fake_input)
        # The distances come back rescaled (0.2 -> 0.0, 0.5 -> 1.0).
        assert output == [
            ('cough', 0.0, {'source': 'doc1'}),
            ('fever', 1.0, {'source': 'doc2'}),
        ]

    def test_remove_stopwords(self):
        """remove_stopwords attaches the non-stopword tokens to each symptom."""
        fake_symptom_list = [
            ("eye of the tiger", 0.5),
            ("the pink pony club", 0.2),
            ("of or up symptoms when to the", 0.3),
        ]
        filtered_fake_symptom_list = remove_stopwords(fake_symptom_list)
        assert filtered_fake_symptom_list == [
            ('eye of the tiger', ['eye', 'tiger'], 0.5),
            ('the pink pony club', ['pink', 'pony', 'club'], 0.2),
            ('of or up symptoms when to the', [], 0.3),
        ]

    def test_check_for_repeat_words(self):
        """check_for_repeat_words_v2 reports words shared by several symptoms."""
        # No token appears in more than one symptom -> nothing reported.
        sample_input = [
            ('eye of the tiger', ['eye', 'tiger'], 0.5),
            ('the pink pony club', ['pink', 'pony', 'club'], 0.2),
            ('of or up symptoms when to the', [], 0.3),
        ]
        dict_arr = check_for_repeat_words_v2(sample_input)
        assert dict_arr == []

        # 'pink' appears in two symptoms -> one entry listing both.
        sample_input = [
            ('pink tiger', ['pink', 'tiger'], 0.5),
            ('the pink pony club', ['pink', 'pony', 'club'], 0.2),
            ('of or up symptoms when to the', [], 0.3),
        ]
        dict_arr = check_for_repeat_words_v2(sample_input)
        assert dict_arr == [
            {'word': 'pink',
             'symptoms_word_is_repeated_in': [('pink tiger', 0.5), ('the pink pony club', 0.2)]},
        ]

    def test_get_final_symptom_list(self):
        """get_final_symptom_list keeps only results within the distance cutoff."""
        max_dist = 0.2
        fake_combined_input = ["irrelevant"]
        with patch("semantic_search.retrieve.query_db") as mock_query_db, \
                patch("semantic_search.retrieve.get_combined_output") as mock_get_combined, \
                patch("semantic_search.retrieve.remove_repeated") as mock_remove_repeated:
            mock_query_db.return_value = "fake_db_output"
            mock_get_combined.return_value = self.nad
            # Identity, so de-duplication does not affect the result.
            mock_remove_repeated.side_effect = lambda x: x
            output = get_final_symptom_list(max_dist, fake_combined_input)
            # Only cough (0.1) survives; headache at exactly 0.2 is excluded.
            assert output == [('cough', 0.1, {'sources': 'doc1', 'risk': 'low'})]

    def test_get_symptom_name_list(self):
        """get_symptom_name_list extracts the name from each triple, in order."""
        sample_list = [
            ("cough", 0.12, {"source": "doc1"}),
            ("fever", 0.34, {"source": "doc2"}),
            ("headache", 0.56, {"source": "doc3"}),
        ]
        output = get_symptom_name_list(sample_list)
        assert output == ['cough', 'fever', 'headache']

    def test_get_risk_and_rarity(self):
        """get_risk_and_rarity returns a (risk score, rarity) pair per context."""
        with patch("semantic_search.retrieve.return_context") as mock_context:
            mock_context.return_value = self.context
            output = get_risk_and_rarity(self.nad)
            assert output == [(3, 5), (6, 4), (5, 7)]

    def test_get_sources(self):
        """get_sources returns the distinct source documents as a set."""
        with patch("semantic_search.retrieve.return_context") as mock_context:
            mock_context.return_value = self.context
            output = get_sources(self.nad)
            assert output == {'doc3', 'doc1', 'doc2'}

    def test_return_context(self):
        """return_context returns the metadata dict of each triple, in order."""
        output = return_context(self.nad)
        assert output == [
            {'sources': 'doc1', 'risk': 'low'},
            {'sources': 'doc2', 'risk': 'medium'},
            {'sources': 'doc3', 'risk': 'low'},
        ]

    def test_context_and_name(self):
        """context_and_name maps each symptom name to its metadata dict."""
        with patch("semantic_search.retrieve.get_symptom_name_list") as mock_get_name_list:
            mock_get_name_list.return_value = ['cough', 'fever', 'headache']
            output = context_and_name(self.nad)
            assert output == {
                'cough': {'sources': 'doc1', 'risk': 'low'},
                'fever': {'sources': 'doc2', 'risk': 'medium'},
                'headache': {'sources': 'doc3', 'risk': 'low'},
            }

    def test_calculate_risk(self):
        """calculate_risk combines the (risk, rarity) pairs into one score."""
        with patch("semantic_search.retrieve.get_risk_and_rarity") as mock_get_risk_and_rarity:
            mock_get_risk_and_rarity.return_value = [(3, 5), (6, 4), (5, 7)]
            output = calculate_risk(self.nad)
            assert output == 6

    def test_find_max(self):
        """find_max picks the name with the lowest score; ties go to the first."""
        sample_input = [('b', 0.2), ('c', 0.3), ('a', 0.1)]
        best = find_max(sample_input)
        assert best == 'a'

        sample_input_2 = [('a', 0.1), ('b', 0.1)]
        assert find_max(sample_input_2) == 'a'

    def test_create_max_array(self):
        """create_max_array returns the best symptom for each repeated word."""
        test_max_arr = create_max_array(self.sample_dict_arr)
        assert len(test_max_arr) == len(self.sample_dict_arr)
        assert test_max_arr == ['a2', 'b1', 'c3']

    def test_find_non_max(self):
        """Smoke test: find_non_max runs on a single word's symptom list."""
        test_entry = [('c1', 0.9), ('c2', 0.5), ('c3', 0.4)]
        # TODO(review): this only checks that the call does not raise. Add an
        # assertion on the result once its expected value is confirmed.
        find_non_max(test_entry)

    def test_find_symptoms_to_remove(self):
        """find_symptoms_to_remove lists every non-best symptom across all words."""
        output = find_symptoms_to_remove(self.sample_dict_arr)
        assert output == [('a1', 0.5), ('b2', 0.2), ('b3', 0.6), ('c1', 0.9), ('c2', 0.5)]

    def test_remove_repeated_symptoms(self):
        """remove_repeated_symptoms de-duplicates and drops requested symptoms."""
        sample_fsl_1 = ['a', 'b', 'c', 'd', 'e']
        sample_fsl_2 = ['a', 'a', 'a', 'a', 'a']  # a word repeated more than once
        sample_fsl_3 = ['a', 'a', 'b', 'a', 'b', 'c']  # multiple repeated words

        ffsl_1 = remove_repeated_symptoms(sample_fsl_1)
        ffsl_2 = remove_repeated_symptoms(sample_fsl_2)
        ffsl_3 = remove_repeated_symptoms(sample_fsl_3)

        # Nothing is removed from a list with no repeated words.
        assert ffsl_1 == sample_fsl_1
        # Every extra instance of 'a' is removed.
        assert len(ffsl_2) == 1
        # Both repeated words are collapsed and the unique 'c' is kept.
        assert ffsl_3 == ['a', 'b', 'c']

        # Explicitly requested symptoms are dropped as well.
        sample_symptoms_to_remove = ['a', 'e']
        ffsl_1_post = remove_repeated_symptoms(sample_fsl_1, sample_symptoms_to_remove)
        for letter in sample_symptoms_to_remove:
            assert letter not in ffsl_1_post