# fatimaxa's picture
# Upload 112 files
# 00bd0c6 verified
import pytest
from unittest.mock import patch
import sys
import os
sys.path.insert(0, 'backend')
from semantic_search.retrieve import *
class TestRetrieveUnitTests:
    """Unit tests for the helpers star-imported from ``semantic_search.retrieve``."""

    def setup_method(self):
        """Build the fixtures shared by the tests below.

        pytest runs this before every test method, so each test gets fresh
        copies of the fixtures.
        """
        # Repeated-word records: each maps a word to the (symptom, score)
        # pairs it occurs in.
        self.sample_dict_arr = [
            {'word': 'a', 'symptoms_word_is_repeated_in': [('a1', 0.5), ('a2', 0.1)]},
            {'word': 'b', 'symptoms_word_is_repeated_in': [('b1', 0.2), ('b2', 0.2), ('b3', 0.6)]},
            {'word': 'c', 'symptoms_word_is_repeated_in': [('c1', 0.9), ('c2', 0.5), ('c3', 0.4)]},
        ]
        # (symptom name, distance, metadata) triples.
        self.nad = [
            ("cough", 0.1, {"sources": "doc1", "risk": "low"}),
            ("fever", 0.3, {"sources": "doc2", "risk": "medium"}),
            ("headache", 0.2, {"sources": "doc3", "risk": "low"}),
        ]
        # Stand-in for what ``return_context`` yields when it is patched out.
        self.context = [
            {'sources': 'doc1', 'risk score': 3, 'rarity' : 5},
            {'sources': 'doc2', 'risk score': 6, 'rarity' : 4},
            {'sources': 'doc3', 'risk score': 5, 'rarity' : 7},
        ]

    def test_get_first_db_column(self):
        """Check a known entry of the column returned by get_first_db_column.

        NOTE(review): nothing is patched here, so this presumably reads the
        real data source -- confirm that it is available wherever tests run.
        """
        # The second returned value is not needed by this test.
        column_data, _context_data = get_first_db_column()
        print(f"SEEING : {column_data[10]} ")
        assert column_data[10] == "hoarse voice"

    def test_create_id(self):
        """Test that create_id yields one sequential 'idN' per input entry."""
        sample_col_data = ['a', 'b', 'c', 'd']
        id_list = create_id(sample_col_data)
        assert len(id_list) == len(sample_col_data)
        assert id_list == ['id1', 'id2', 'id3', 'id4']

    def test_get_compiled_list(self):
        """Test that get_compiled_list concatenates its three lists in order."""
        sample_img = ['a', 'b']
        sample_hist = ['c', 'd']
        sample_uq = ['e', 'f']
        compiled_list = get_compiled_list(sample_img, sample_hist, sample_uq)
        assert compiled_list == ['a', 'b', 'c', 'd', 'e', 'f']

        # An empty middle list must contribute nothing.
        sample_img = ['a', 'b']
        sample_hist = []
        sample_uq = ['e', 'f']
        compiled_list = get_compiled_list(sample_img, sample_hist, sample_uq)
        assert compiled_list == ['a', 'b', 'e', 'f']

    def test_query_db(self):
        """Test that query_db issues exactly one query on the collection."""
        fake_compiled_list = ['a', 'b', 'c']
        # Patch the module-level collection so no real database is touched.
        with patch("semantic_search.retrieve.collection") as mock_collection:
            query_db(fake_compiled_list)
            mock_collection.query.assert_called_once()

    def test_get_combined_output(self):
        """Test that get_combined_output zips documents, distances and metadata."""
        fake_input = {
            'documents': [['cough', 'fever']],
            'distances': [[0.2, 0.5]],
            'metadatas': [[{'source': 'doc1'}, {'source': 'doc2'}]],
        }
        output = get_combined_output(fake_input)
        print(output)
        # The expected distances (0.0 and 1.0) differ from the raw inputs
        # (0.2 and 0.5): the function is expected to rescale them.
        assert output == [
            ('cough', 0.0, {'source': 'doc1'}),
            ('fever', 1.0, {'source': 'doc2'}),
        ]

    def test_remove_stopwords(self):
        """Test remove stopwords"""
        fake_symptom_list = [
            ("eye of the tiger", 0.5),
            ("the pink pony club", 0.2),
            ("of or up symptoms when to the", 0.3),
        ]
        filtered_fake_symptom_list = remove_stopwords(fake_symptom_list)
        # Each result keeps the original phrase, the surviving words and the score.
        assert filtered_fake_symptom_list == [
            ('eye of the tiger', ['eye', 'tiger'], 0.5),
            ('the pink pony club', ['pink', 'pony', 'club'], 0.2),
            ('of or up symptoms when to the', [], 0.3),
        ]

    def test_check_for_repeat_words(self):
        """Test check for repeat words"""
        # No word is shared between phrases -> nothing is reported.
        sample_input = [
            ('eye of the tiger', ['eye', 'tiger'], 0.5),
            ('the pink pony club', ['pink', 'pony', 'club'], 0.2),
            ('of or up symptoms when to the', [], 0.3),
        ]
        dict_arr = check_for_repeat_words_v2(sample_input)
        assert dict_arr == []

        # 'pink' occurs in two phrases -> a single record listing both.
        sample_input = [
            ('pink tiger', ['pink', 'tiger'], 0.5),
            ('the pink pony club', ['pink', 'pony', 'club'], 0.2),
            ('of or up symptoms when to the', [], 0.3),
        ]
        dict_arr = check_for_repeat_words_v2(sample_input)
        assert dict_arr == [
            {
                'word': 'pink',
                'symptoms_word_is_repeated_in': [('pink tiger', 0.5), ('the pink pony club', 0.2)],
            }
        ]

    def test_get_final_symptom_list(self):
        """Test get final symptom list method"""
        max_dist = 0.2
        fake_combined_input = ["irrelevant"]
        with patch("semantic_search.retrieve.query_db") as mock_query_db, \
                patch("semantic_search.retrieve.get_combined_output") as mock_get_combined, \
                patch("semantic_search.retrieve.remove_repeated") as mock_remove_repeated:
            mock_query_db.return_value = "fake_db_output"
            mock_get_combined.return_value = self.nad
            # Pass-through, so only the distance filtering is under test.
            mock_remove_repeated.side_effect = lambda x: x
            output = get_final_symptom_list(max_dist, fake_combined_input)
            # 'headache' sits exactly at max_dist (0.2) and is expected to be dropped.
            assert output == [('cough', 0.1, {'sources': 'doc1', 'risk': 'low'})]

    def test_get_symptom_name_list(self):
        """Test get symptom name list."""
        sample_list = [
            ("cough", 0.12, {"source": "doc1"}),
            ("fever", 0.34, {"source": "doc2"}),
            ("headache", 0.56, {"source": "doc3"}),
        ]
        output = get_symptom_name_list(sample_list)
        assert output == ['cough', 'fever', 'headache']

    def test_get_risk_and_rarity(self):
        """Test get risk and rarity"""
        with patch("semantic_search.retrieve.return_context") as mock_context:
            mock_context.return_value = self.context
            output = get_risk_and_rarity(self.nad)
            # One ('risk score', 'rarity') pair per context entry.
            assert output == [(3, 5), (6, 4), (5, 7)]

    def test_get_sources(self):
        """Test get sources method"""
        with patch("semantic_search.retrieve.return_context") as mock_context:
            mock_context.return_value = self.context
            output = get_sources(self.nad)
            print(output)
            assert output == {'doc3', 'doc1', 'doc2'}

    def test_return_context(self):
        """Test return context method"""
        output = return_context(self.nad)
        print(output)
        # Expected: the metadata dict (third element) of every triple, in order.
        assert output == [
            {'sources': 'doc1', 'risk': 'low'},
            {'sources': 'doc2', 'risk': 'medium'},
            {'sources': 'doc3', 'risk': 'low'},
        ]

    def test_context_and_name(self):
        """Test that context_and_name maps each symptom name to its metadata."""
        with patch("semantic_search.retrieve.get_symptom_name_list") as mock_get_name_list:
            mock_get_name_list.return_value = ['cough', 'fever', 'headache']
            output = context_and_name(self.nad)
            assert output == {
                'cough': {'sources': 'doc1', 'risk': 'low'},
                'fever': {'sources': 'doc2', 'risk': 'medium'},
                'headache': {'sources': 'doc3', 'risk': 'low'},
            }

    def test_calculate_risk(self):
        """Test the calculate risk method on nad inputs"""
        with patch("semantic_search.retrieve.get_risk_and_rarity") as mock_get_risk_and_rarity:
            mock_get_risk_and_rarity.return_value = [(3, 5), (6, 4), (5, 7)]
            output = calculate_risk(self.nad)
            assert output == 6

    def test_find_max(self):
        """Test that find_max returns the name paired with the lowest score."""
        sample_input = [('b', 0.2), ('c', 0.3), ('a', 0.1)]
        best = find_max(sample_input)
        assert best == 'a'

        # On a tie the first entry is expected.
        sample_input_2 = [('a', 0.1), ('b', 0.1)]
        assert find_max(sample_input_2) == 'a'

    def test_create_max_array(self):
        """Test that create_max_array returns one best symptom per word record."""
        test_max_arr = create_max_array(self.sample_dict_arr)
        assert len(test_max_arr) == len(self.sample_dict_arr)
        assert test_max_arr == ['a2', 'b1', 'c3']

    def test_find_non_max(self):
        """Test find non max method"""
        test_entry = [('c1', 0.9), ('c2', 0.5), ('c3', 0.4)]
        # NOTE(review): there is no assertion here, so this only verifies that
        # the call does not raise. The expected return value is not visible
        # from this file -- add an assertion once it is confirmed.
        find_non_max(test_entry)

    def test_find_symptoms_to_remove(self):
        """Test find symptoms to remove method"""
        output = find_symptoms_to_remove(self.sample_dict_arr)
        print(output)
        assert output == [('a1', 0.5), ('b2', 0.2), ('b3', 0.6), ('c1', 0.9), ('c2', 0.5)]

    def test_remove_repeated_symptoms(self):
        """Test the remove repeated symptoms method."""
        sample_fsl_1 = ['a', 'b', 'c', 'd', 'e']
        sample_fsl_2 = ['a', 'a', 'a', 'a', 'a']  # testing a list where word is repeated more than once
        sample_fsl_3 = ['a', 'a', 'b', 'a', 'b', 'c']  # testing a list with multiple repeated words

        ffsl_1 = remove_repeated_symptoms(sample_fsl_1)
        ffsl_2 = remove_repeated_symptoms(sample_fsl_2)
        ffsl_3 = remove_repeated_symptoms(sample_fsl_3)

        # Confirm nothing was removed from a list with no repeated words.
        # Compared against a literal rather than ``sample_fsl_1`` itself, so the
        # check cannot pass trivially if the function returns its mutated input.
        assert ffsl_1 == ['a', 'b', 'c', 'd', 'e']
        assert len(ffsl_2) == 1  # confirm every single extra instance of 'a' was removed from list 2
        assert ffsl_3 == ['a', 'b', 'c']  # confirm that both a and b repeated words were removed, and that the non repeated word c was kept

        # Explicitly listed symptoms must not survive either.
        sample_symptoms_to_remove = ['a', 'e']
        ffsl_1_post = remove_repeated_symptoms(sample_fsl_1, sample_symptoms_to_remove)
        for letter in sample_symptoms_to_remove:
            assert letter not in ffsl_1_post