Coverage for tinytroupe / extraction / normalizer.py: 0%

52 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-28 17:48 +0000

1import pandas as pd 

2from typing import Union, List 

3 

4from tinytroupe.extraction import logger 

5 

6from tinytroupe import openai_utils 

7import tinytroupe.utils as utils 

class Normalizer:
    """
    A mechanism to normalize passages, concepts and other textual elements.

    On construction the input elements are sent to the LLM (together with the
    requested number ``n`` of outputs) and the parsed JSON reply is stored in
    ``self.normalized_elements``. Afterwards, ``normalize`` maps elements onto
    that result, caching every mapping in ``self.normalizing_map``.
    """

    def __init__(self, elements:List[str], n:int, verbose:bool=False):
        """
        Normalizes the specified elements.

        Args:
            elements (list): The elements to normalize. Duplicates are dropped, keeping the first occurrence of each.
            n (int): The number of normalized elements to output.
            verbose (bool, optional): Whether to print debug messages. Defaults to False.
        """
        # Ensure elements are unique. dict.fromkeys preserves first-occurrence order, whereas
        # the iteration order of a set of strings changes between interpreter runs (hash
        # randomization), which would make the rendered prompt non-reproducible.
        self.elements = list(dict.fromkeys(elements))

        self.n = n
        self.verbose = verbose

        # a JSON-based structure, where each output element is a key to a list of input elements that were merged into it
        self.normalized_elements = None
        # a dict that maps each input element to its normalized output. This will be used as cache later.
        self.normalizing_map = {}

        rendering_configs = {"n": n,
                             "elements": self.elements}

        result = self._query_llm_for_json("normalizer.system.mustache",
                                          "normalizer.user.mustache",
                                          rendering_configs)
        logger.debug(result)
        if self.verbose:
            print(result)

        self.normalized_elements = result

    def _query_llm_for_json(self, system_template:str, user_template:str, rendering_configs:dict):
        """
        Renders the given templates, sends them to the LLM and returns the JSON found in the reply.

        The raw reply is always logged at debug level, and also printed when ``self.verbose`` is set.

        Args:
            system_template (str): File name of the system prompt template (looked up in the "extraction" module folder).
            user_template (str): File name of the user prompt template (looked up in the "extraction" module folder).
            rendering_configs (dict): The values used to render both templates.

        Returns:
            Whatever ``utils.extract_json`` extracts from the reply's "content" field.
        """
        messages = utils.compose_initial_LLM_messages_with_templates(system_template, user_template,
                                                                     base_module_folder="extraction",
                                                                     rendering_configs=rendering_configs)

        next_message = openai_utils.client().send_message(messages, temperature=0.1)

        debug_msg = f"Normalization result message: {next_message}"
        logger.debug(debug_msg)
        if self.verbose:
            print(debug_msg)

        return utils.extract_json(next_message["content"])

    def normalize(self, element_or_elements:Union[str, List[str]]) -> List[str]:
        """
        Normalizes the specified element or elements.

        This method uses a caching mechanism to improve performance. If an element has been normalized before,
        its normalized form is stored in a cache (self.normalizing_map). When the same element needs to be
        normalized again, the method will first check the cache and use the stored normalized form if available,
        instead of normalizing the element again. Only elements missing from the cache are sent to the LLM,
        each of them once per call even if it appears several times in the input.

        The order of elements in the output will be the same as in the input, and repeated input elements
        yield repeated output elements.

        Args:
            element_or_elements (Union[str, List[str]]): The element or elements to normalize.

        Returns:
            list: The normalized elements, one per input element and in input order. Note that a list is
                returned even when the input was a single string (in that case it has exactly one item).

        Raises:
            ValueError: If element_or_elements is neither a string nor a list.
            AssertionError: If the LLM reply is not a list, or does not contain exactly one item per
                element that was sent for normalization.
        """
        if isinstance(element_or_elements, str):
            denormalized_elements = [element_or_elements]
        elif isinstance(element_or_elements, list):
            denormalized_elements = element_or_elements
        else:
            raise ValueError("The element_or_elements must be either a string or a list.")

        # Elements not in the cache yet, de-duplicated in first-occurrence order so that the
        # LLM is never asked about the same element twice within a single request.
        elements_to_normalize = list(dict.fromkeys(
            element for element in denormalized_elements
            if element not in self.normalizing_map
        ))

        if elements_to_normalize:
            rendering_configs = {"categories": self.normalized_elements,
                                 "elements": elements_to_normalize}

            normalized_elements_from_llm = self._query_llm_for_json("normalizer.applier.system.mustache",
                                                                    "normalizer.applier.user.mustache",
                                                                    rendering_configs)

            # Raised explicitly instead of using `assert`, so the checks are not stripped under
            # `python -O`. AssertionError is kept so that existing callers see the same exception type.
            if not isinstance(normalized_elements_from_llm, list):
                raise AssertionError("The normalized element must be a list.")
            if len(normalized_elements_from_llm) != len(elements_to_normalize):
                raise AssertionError("The number of normalized elements must be equal to the number of elements to normalize.")

            # The reply is positional: the i-th answer belongs to the i-th element that was sent.
            self.normalizing_map.update(zip(elements_to_normalize, normalized_elements_from_llm))

        return [self.normalizing_map[element] for element in denormalized_elements]

115