| import pytest |
| from unittest.mock import patch |
| import sys |
| import os |
| sys.path.insert(0, 'backend') |
| from semantic_search.retrieve import * |
|
|
class TestRetrieveUnitTests:
    """Unit tests for the helpers imported from ``semantic_search.retrieve``."""

    def setup_method(self):
        """Build fresh fixtures before every test method.

        ``sample_dict_arr`` has the same shape as the output of
        ``check_for_repeat_words_v2``: one dict per repeated word, listing the
        ``(symptom_name, distance)`` pairs that contain that word.

        ``nad`` is a list of ``(symptom_name, distance, metadata)`` triples;
        tests use it both as direct input and as the mocked return value of
        ``get_combined_output``.

        ``context`` is a list of per-symptom context dicts used as the mocked
        return value of ``return_context``.
        """
        self.sample_dict_arr = [
            {'word': 'a', 'symptoms_word_is_repeated_in': [('a1', 0.5), ('a2', 0.1)]},
            {'word': 'b', 'symptoms_word_is_repeated_in': [('b1', 0.2), ('b2', 0.2), ('b3', 0.6)]},
            {'word': 'c', 'symptoms_word_is_repeated_in': [('c1', 0.9), ('c2', 0.5), ('c3', 0.4)]},
        ]
        self.nad = [
            ("cough", 0.1, {"sources": "doc1", "risk": "low"}),
            ("fever", 0.3, {"sources": "doc2", "risk": "medium"}),
            ("headache", 0.2, {"sources": "doc3", "risk": "low"}),
        ]
        self.context = [
            {'sources': 'doc1', 'risk score': 3, 'rarity': 5},
            {'sources': 'doc2', 'risk score': 6, 'rarity': 4},
            {'sources': 'doc3', 'risk score': 5, 'rarity': 7},
        ]

| def test_get_first_db_column(self): |
| column_data, context_data = get_first_db_column() |
| print(f"SEEING : {column_data[10]} " ) |
| assert column_data[10] == "hoarse voice" |
| |
| def test_create_id(self): |
| """Test create id here""" |
| sample_col_data = ['a', 'b', 'c', 'd'] |
| id_list = create_id(sample_col_data) |
| assert len(id_list) == len(sample_col_data) |
| assert id_list == ['id1', 'id2', 'id3', 'id4'] |
|
|
| |
| def test_get_compiled_list(self): |
| """Test get_compiled_list here""" |
| sample_img = ['a', 'b'] |
| sample_hist = ['c', 'd'] |
| sample_uq = ['e', 'f'] |
| compiled_list = get_compiled_list(sample_img, sample_hist, sample_uq) |
|
|
| assert compiled_list == ['a', 'b', 'c', 'd', 'e', 'f'] |
| |
| sample_img = ['a', 'b'] |
| sample_hist = [] |
| sample_uq = ['e', 'f'] |
| compiled_list = get_compiled_list(sample_img, sample_hist, sample_uq) |
|
|
| assert compiled_list == ['a', 'b', 'e', 'f'] |
| |
| |
| def test_query_db(self): |
| fake_compiled_list = ['a', 'b', 'c'] |
| with patch("semantic_search.retrieve.collection") as mock_collection: |
| query_db(fake_compiled_list) |
| mock_collection.query.assert_called_once() |
|
|
| |
| def test_get_combined_output(self): |
| fake_input = { |
| 'documents': [['cough', 'fever']], |
| 'distances': [[0.2, 0.5]], |
| 'metadatas': [[{'source': 'doc1'}, {'source': 'doc2'}]] |
| } |
| output = get_combined_output(fake_input) |
| print(output) |
| assert output == [('cough', 0.0, {'source': 'doc1'}), ('fever', 1.0, {'source': 'doc2'})] |
|
|
| |
| def test_remove_stopwords(self): |
| """Test remove stopwords""" |
| fake_symptom_list = [("eye of the tiger", 0.5), ("the pink pony club", 0.2), ("of or up symptoms when to the", 0.3)] |
| filtered_fake_symptom_list = remove_stopwords(fake_symptom_list) |
|
|
| assert filtered_fake_symptom_list == [('eye of the tiger', ['eye', 'tiger'], 0.5), ('the pink pony club', ['pink', 'pony', 'club'], 0.2), ('of or up symptoms when to the', [], 0.3)] |
| |
| def test_check_for_repeat_words(self): |
| """Test check for repeat words""" |
| sample_input = [('eye of the tiger', ['eye', 'tiger'], 0.5), ('the pink pony club', ['pink', 'pony', 'club'], 0.2), ('of or up symptoms when to the', [], 0.3)] |
| dict_arr = check_for_repeat_words_v2(sample_input) |
| assert dict_arr == [] |
|
|
| sample_input = [('pink tiger', ['pink', 'tiger'], 0.5), ('the pink pony club', ['pink', 'pony', 'club'], 0.2), ('of or up symptoms when to the', [], 0.3)] |
| dict_arr = check_for_repeat_words_v2(sample_input) |
| assert dict_arr == [{'word': 'pink', 'symptoms_word_is_repeated_in': [('pink tiger', 0.5), ('the pink pony club', 0.2)]}] |
|
|
|
|
| def test_get_final_symptom_list(self): |
| """Test get final symptom list method""" |
| max_dist = 0.2 |
| fake_combined_input = ["irrelevant"] |
|
|
| with patch("semantic_search.retrieve.query_db") as mock_query_db, \ |
| patch("semantic_search.retrieve.get_combined_output") as mock_get_combined, \ |
| patch("semantic_search.retrieve.remove_repeated") as mock_remove_repeated: |
|
|
| mock_query_db.return_value = "fake_db_output" |
| mock_get_combined.return_value = self.nad |
| mock_remove_repeated.side_effect = lambda x : x |
| output = get_final_symptom_list(max_dist, fake_combined_input) |
| assert output == [('cough', 0.1, {'sources': 'doc1', 'risk': 'low'})] |
|
|
|
|
| def test_get_symptom_name_list(self): |
| """Test get symptom name list.""" |
| sample_list = symptom_name_dist_list = [ |
| ("cough", 0.12, {"source": "doc1"}), |
| ("fever", 0.34, {"source": "doc2"}), |
| ("headache", 0.56, {"source": "doc3"}) |
| ] |
| output = get_symptom_name_list(sample_list) |
| assert output == ['cough', 'fever', 'headache'] |
|
|
| def test_get_risk_and_rarity(self): |
| """Test get risk and clarity""" |
|
|
| with patch("semantic_search.retrieve.return_context") as mock_context: |
| mock_context.return_value = self.context |
| output = get_risk_and_rarity(self.nad) |
| assert output == [(3, 5), (6, 4), (5, 7)] |
|
|
| def test_get_sources(self): |
| """Test get sources method""" |
|
|
| with patch("semantic_search.retrieve.return_context") as mock_context: |
| mock_context.return_value = self.context |
| output = get_sources(self.nad) |
| print(output) |
| assert output == {'doc3', 'doc1', 'doc2'} |
|
|
| def test_return_context(self): |
| """Test return context method""" |
| output = return_context(self.nad) |
| print(output) |
| assert output == [{'sources': 'doc1', 'risk': 'low'}, {'sources': 'doc2', 'risk': 'medium'}, {'sources': 'doc3', 'risk': 'low'}] |
|
|
| def test_context_and_name(self): |
| with patch("semantic_search.retrieve.get_symptom_name_list") as mock_get_name_list: |
| mock_get_name_list.return_value = ['cough', 'fever', 'headache'] |
| output = context_and_name(self.nad) |
| assert output == {'cough': {'sources': 'doc1', 'risk': 'low'}, 'fever': {'sources': 'doc2', 'risk': 'medium'}, 'headache': {'sources': 'doc3', 'risk': 'low'}} |
|
|
|
|
| def test_calculate_risk(self): |
| """Test the calculate risk method on nad inputs""" |
|
|
| with patch("semantic_search.retrieve.get_risk_and_rarity") as mock_get_risk_and_rarity: |
| mock_get_risk_and_rarity.return_value = [(3, 5), (6, 4), (5, 7)] |
| output = calculate_risk(self.nad) |
| assert output == 6 |
|
|
|
|
| def test_find_max(self): |
| sample_input = [('b', 0.2), ('c', 0.3), ('a', 0.1)] |
| best = find_max(sample_input) |
|
|
| assert best == 'a' |
|
|
| sample_input_2 = [('a', 0.1), ('b', 0.1)] |
| assert find_max(sample_input_2) == 'a' |
|
|
| def test_create_max_array(self): |
| test_max_arr = create_max_array(self.sample_dict_arr) |
| |
| assert len(test_max_arr) == len(self.sample_dict_arr) |
| assert test_max_arr == ['a2', 'b1', 'c3'] |
|
|
| def test_find_non_max(self): |
| """Test find non max method""" |
| test_entry = [('c1', 0.9), ('c2', 0.5), ('c3', 0.4)] |
| find_non_max(test_entry) |
| |
| def test_find_symptoms_to_remove(self): |
| """Test find symptoms to remove method""" |
| output = find_symptoms_to_remove(self.sample_dict_arr) |
| print(output) |
| assert output == [('a1', 0.5), ('b2', 0.2), ('b3', 0.6), ('c1', 0.9), ('c2', 0.5)] |
|
|
| def test_remove_repeated_symptoms(self): |
| """Test the remove repeated symptoms method.""" |
|
|
| sample_fsl_1 = ['a', 'b', 'c', 'd', 'e'] |
| sample_fsl_2 = ['a', 'a', 'a', 'a', 'a'] |
| sample_fsl_3 = ['a', 'a', 'b', 'a', 'b', 'c'] |
|
|
| ffsl_1 = remove_repeated_symptoms(sample_fsl_1) |
| ffsl_2 = remove_repeated_symptoms(sample_fsl_2) |
| ffsl_3 = remove_repeated_symptoms(sample_fsl_3) |
| |
| assert ffsl_1 == sample_fsl_1 |
| assert len(ffsl_2) == 1 |
| assert ffsl_3 == ['a', 'b', 'c'] |
|
|
| sample_symptoms_to_remove = ['a', 'e'] |
| ffsl_1_post = remove_repeated_symptoms(sample_fsl_1, sample_symptoms_to_remove) |
|
|
| for letter in sample_symptoms_to_remove: |
| assert letter not in ffsl_1_post |
|
|