File size: 8,697 Bytes
00bd0c6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 | import pytest
from unittest.mock import patch
import sys
import os
sys.path.insert(0, 'backend')
from semantic_search.retrieve import *
class TestRetrieveUnitTests:
    """Unit tests for the helpers in ``semantic_search.retrieve``.

    Collaborators are replaced with ``unittest.mock.patch`` where a test needs
    to isolate a single function. ``test_get_first_db_column`` is the one test
    that is not mocked and runs against real data.
    """

    def setup_method(self):
        """Build fixtures shared by several tests (re-created before each test)."""
        # Same shape as check_for_repeat_words_v2 output: a repeated word plus
        # the (symptom, distance) pairs it appears in.
        self.sample_dict_arr = [
            {'word': 'a', 'symptoms_word_is_repeated_in': [('a1', 0.5), ('a2', 0.1)]},
            {'word': 'b', 'symptoms_word_is_repeated_in': [('b1', 0.2), ('b2', 0.2), ('b3', 0.6)]},
            {'word': 'c', 'symptoms_word_is_repeated_in': [('c1', 0.9), ('c2', 0.5), ('c3', 0.4)]},
        ]
        # (name, distance, metadata) triples -- the shape get_combined_output
        # returns (see test_get_combined_output).
        self.nad = [
            ("cough", 0.1, {"sources": "doc1", "risk": "low"}),
            ("fever", 0.3, {"sources": "doc2", "risk": "medium"}),
            ("headache", 0.2, {"sources": "doc3", "risk": "low"}),
        ]
        # Stand-in for return_context() output in tests that patch it.
        self.context = [
            {'sources': 'doc1', 'risk score': 3, 'rarity': 5},
            {'sources': 'doc2', 'risk score': 6, 'rarity': 4},
            {'sources': 'doc3', 'risk score': 5, 'rarity': 7},
        ]

    def test_get_first_db_column(self):
        """Check a known entry in the first column from get_first_db_column.

        Not mocked: the expected value depends on the real data that
        get_first_db_column() reads.
        """
        column_data, _ = get_first_db_column()
        assert column_data[10] == "hoarse voice"

    def test_create_id(self):
        """create_id yields one sequential, 1-based 'idN' string per input item."""
        sample_col_data = ['a', 'b', 'c', 'd']
        id_list = create_id(sample_col_data)
        assert len(id_list) == len(sample_col_data)
        assert id_list == ['id1', 'id2', 'id3', 'id4']

    def test_get_compiled_list(self):
        """get_compiled_list concatenates its three lists in argument order."""
        sample_img = ['a', 'b']
        sample_hist = ['c', 'd']
        sample_uq = ['e', 'f']
        compiled_list = get_compiled_list(sample_img, sample_hist, sample_uq)
        assert compiled_list == ['a', 'b', 'c', 'd', 'e', 'f']

        # An empty middle list contributes nothing.
        sample_img = ['a', 'b']
        sample_hist = []
        sample_uq = ['e', 'f']
        compiled_list = get_compiled_list(sample_img, sample_hist, sample_uq)
        assert compiled_list == ['a', 'b', 'e', 'f']

    def test_query_db(self):
        """query_db issues exactly one query on the module-level ``collection``."""
        fake_compiled_list = ['a', 'b', 'c']
        with patch("semantic_search.retrieve.collection") as mock_collection:
            query_db(fake_compiled_list)
            mock_collection.query.assert_called_once()

    def test_get_combined_output(self):
        """get_combined_output merges documents, distances and metadatas into triples.

        NOTE(review): the expected distances (0.2 -> 0.0, 0.5 -> 1.0) look like
        min-max normalisation; confirm against the implementation.
        """
        fake_input = {
            'documents': [['cough', 'fever']],
            'distances': [[0.2, 0.5]],
            'metadatas': [[{'source': 'doc1'}, {'source': 'doc2'}]],
        }
        output = get_combined_output(fake_input)
        assert output == [('cough', 0.0, {'source': 'doc1'}), ('fever', 1.0, {'source': 'doc2'})]

    def test_remove_stopwords(self):
        """remove_stopwords pairs each phrase with its non-stop-word tokens."""
        fake_symptom_list = [
            ("eye of the tiger", 0.5),
            ("the pink pony club", 0.2),
            ("of or up symptoms when to the", 0.3),
        ]
        filtered_fake_symptom_list = remove_stopwords(fake_symptom_list)
        # Every token of the last phrase (including 'symptoms' and 'when')
        # is expected to be filtered out.
        assert filtered_fake_symptom_list == [
            ('eye of the tiger', ['eye', 'tiger'], 0.5),
            ('the pink pony club', ['pink', 'pony', 'club'], 0.2),
            ('of or up symptoms when to the', [], 0.3),
        ]

    def test_check_for_repeat_words(self):
        """check_for_repeat_words_v2 reports only words shared by several symptoms."""
        # No token appears in more than one phrase -> nothing reported.
        sample_input = [
            ('eye of the tiger', ['eye', 'tiger'], 0.5),
            ('the pink pony club', ['pink', 'pony', 'club'], 0.2),
            ('of or up symptoms when to the', [], 0.3),
        ]
        dict_arr = check_for_repeat_words_v2(sample_input)
        assert dict_arr == []

        # 'pink' appears in two phrases -> one entry listing both.
        sample_input = [
            ('pink tiger', ['pink', 'tiger'], 0.5),
            ('the pink pony club', ['pink', 'pony', 'club'], 0.2),
            ('of or up symptoms when to the', [], 0.3),
        ]
        dict_arr = check_for_repeat_words_v2(sample_input)
        assert dict_arr == [
            {'word': 'pink',
             'symptoms_word_is_repeated_in': [('pink tiger', 0.5), ('the pink pony club', 0.2)]},
        ]

    def test_get_final_symptom_list(self):
        """get_final_symptom_list keeps only candidates within max_dist.

        NOTE(review): 'headache' (distance 0.2, equal to max_dist) is expected
        to be dropped, implying a strict comparison -- confirm.
        """
        max_dist = 0.2
        fake_combined_input = ["irrelevant"]
        with patch("semantic_search.retrieve.query_db") as mock_query_db, \
                patch("semantic_search.retrieve.get_combined_output") as mock_get_combined, \
                patch("semantic_search.retrieve.remove_repeated") as mock_remove_repeated:
            mock_query_db.return_value = "fake_db_output"
            mock_get_combined.return_value = self.nad
            # Pass-through so only the distance filtering is under test.
            mock_remove_repeated.side_effect = lambda x: x
            output = get_final_symptom_list(max_dist, fake_combined_input)
        assert output == [('cough', 0.1, {'sources': 'doc1', 'risk': 'low'})]

    def test_get_symptom_name_list(self):
        """get_symptom_name_list extracts the name from each (name, dist, meta) triple."""
        sample_list = [
            ("cough", 0.12, {"source": "doc1"}),
            ("fever", 0.34, {"source": "doc2"}),
            ("headache", 0.56, {"source": "doc3"}),
        ]
        output = get_symptom_name_list(sample_list)
        assert output == ['cough', 'fever', 'headache']

    def test_get_risk_and_rarity(self):
        """get_risk_and_rarity pairs each context's 'risk score' with its 'rarity'."""
        with patch("semantic_search.retrieve.return_context") as mock_context:
            mock_context.return_value = self.context
            output = get_risk_and_rarity(self.nad)
        assert output == [(3, 5), (6, 4), (5, 7)]

    def test_get_sources(self):
        """get_sources returns the distinct 'sources' values as a set."""
        with patch("semantic_search.retrieve.return_context") as mock_context:
            mock_context.return_value = self.context
            output = get_sources(self.nad)
        assert output == {'doc3', 'doc1', 'doc2'}

    def test_return_context(self):
        """return_context extracts the metadata dict from each triple, in order."""
        output = return_context(self.nad)
        assert output == [
            {'sources': 'doc1', 'risk': 'low'},
            {'sources': 'doc2', 'risk': 'medium'},
            {'sources': 'doc3', 'risk': 'low'},
        ]

    def test_context_and_name(self):
        """context_and_name maps each symptom name to its metadata dict."""
        with patch("semantic_search.retrieve.get_symptom_name_list") as mock_get_name_list:
            mock_get_name_list.return_value = ['cough', 'fever', 'headache']
            output = context_and_name(self.nad)
        assert output == {
            'cough': {'sources': 'doc1', 'risk': 'low'},
            'fever': {'sources': 'doc2', 'risk': 'medium'},
            'headache': {'sources': 'doc3', 'risk': 'low'},
        }

    def test_calculate_risk(self):
        """calculate_risk over fixed (risk, rarity) pairs yields the pinned value 6."""
        with patch("semantic_search.retrieve.get_risk_and_rarity") as mock_get_risk_and_rarity:
            mock_get_risk_and_rarity.return_value = [(3, 5), (6, 4), (5, 7)]
            output = calculate_risk(self.nad)
        assert output == 6

    def test_find_max(self):
        """find_max returns the label with the smallest value; the first wins a tie."""
        sample_input = [('b', 0.2), ('c', 0.3), ('a', 0.1)]
        best = find_max(sample_input)
        assert best == 'a'

        sample_input_2 = [('a', 0.1), ('b', 0.1)]
        assert find_max(sample_input_2) == 'a'

    def test_create_max_array(self):
        """create_max_array picks the smallest-distance symptom per entry (first on ties)."""
        test_max_arr = create_max_array(self.sample_dict_arr)
        assert len(test_max_arr) == len(self.sample_dict_arr)
        # 'b1' beats 'b2' (both 0.2) because it comes first.
        assert test_max_arr == ['a2', 'b1', 'c3']

    def test_find_non_max(self):
        """Smoke test: find_non_max runs on a (symptom, distance) list without error."""
        test_entry = [('c1', 0.9), ('c2', 0.5), ('c3', 0.4)]
        result = find_non_max(test_entry)
        # TODO(review): this test asserts nothing. Presumably `result` should be
        # [('c1', 0.9), ('c2', 0.5)] (cf. test_find_symptoms_to_remove); confirm
        # find_non_max's return contract and add the assertion.
        del result

    def test_find_symptoms_to_remove(self):
        """find_symptoms_to_remove returns every non-best pair across all entries."""
        output = find_symptoms_to_remove(self.sample_dict_arr)
        # Everything except the picks asserted in test_create_max_array
        # ('a2', 'b1', 'c3').
        assert output == [('a1', 0.5), ('b2', 0.2), ('b3', 0.6), ('c1', 0.9), ('c2', 0.5)]

    def test_remove_repeated_symptoms(self):
        """Test the remove repeated symptoms method."""
        sample_fsl_1 = ['a', 'b', 'c', 'd', 'e']
        sample_fsl_2 = ['a', 'a', 'a', 'a', 'a']  # word repeated more than once
        sample_fsl_3 = ['a', 'a', 'b', 'a', 'b', 'c']  # several repeated words
        ffsl_1 = remove_repeated_symptoms(sample_fsl_1)
        ffsl_2 = remove_repeated_symptoms(sample_fsl_2)
        ffsl_3 = remove_repeated_symptoms(sample_fsl_3)
        # Compare against a literal rather than sample_fsl_1 itself, so an
        # in-place mutation of the input cannot make this pass trivially.
        assert ffsl_1 == ['a', 'b', 'c', 'd', 'e']  # nothing removed when there are no repeats
        assert ffsl_2 == ['a']  # every extra instance of 'a' removed
        assert ffsl_3 == ['a', 'b', 'c']  # both 'a' and 'b' deduplicated, 'c' kept, order preserved

        sample_symptoms_to_remove = ['a', 'e']
        ffsl_1_post = remove_repeated_symptoms(sample_fsl_1, sample_symptoms_to_remove)
        for letter in sample_symptoms_to_remove:
            assert letter not in ffsl_1_post
|