Spaces:
Runtime error
Runtime error
Add doc strings
Browse files
utils.py
CHANGED
|
@@ -79,7 +79,50 @@ def get_gpu(gpu: DEVICE_TYPE) -> ENCODER_DEVICE_TYPE:
|
|
| 79 |
|
| 80 |
|
| 81 |
def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
def _slice_embeddings(s_idx: int, n_sentences: List[int]):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
_result = []
|
| 84 |
for count in n_sentences:
|
| 85 |
_result.append(embeddings[s_idx:s_idx + count])
|
|
@@ -107,6 +150,37 @@ def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> Em
|
|
| 107 |
|
| 108 |
|
| 109 |
def is_nested_list_of_type(lst_obj, element_type, depth: int) -> bool:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
if depth == 0:
|
| 111 |
return isinstance(lst_obj, element_type)
|
| 112 |
elif depth > 0:
|
|
|
|
| 79 |
|
| 80 |
|
| 81 |
def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
|
| 82 |
+
"""
|
| 83 |
+
Slice embeddings into segments based on the provided number of sentences per segment.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
- embeddings (np.ndarray): The array of embeddings to be sliced.
|
| 87 |
+
- num_sentences (Union[List[int], List[List[int]]]):
|
| 88 |
+
- If a list of integers: Specifies the number of embeddings to take in each slice.
|
| 89 |
+
- If a list of lists of integers: Specifies multiple nested levels of slicing.
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
- List[np.ndarray]: A list of numpy arrays where each array represents a slice of embeddings.
|
| 93 |
+
|
| 94 |
+
Raises:
|
| 95 |
+
- TypeError: If `num_sentences` is not of type List[int] or List[List[int]].
|
| 96 |
+
|
| 97 |
+
Example Usage:
|
| 98 |
+
|
| 99 |
+
```python
|
| 100 |
+
embeddings = np.random.rand(10, 5)
|
| 101 |
+
num_sentences = [3, 2, 5]
|
| 102 |
+
result = slice_embeddings(embeddings, num_sentences)
|
| 103 |
+
# `result` will be a list of numpy arrays:
|
| 104 |
+
# [embeddings[:3], embeddings[3:5], embeddings[5:]]
|
| 105 |
+
|
| 106 |
+
num_sentences_nested = [[2, 1], [3, 4]]
|
| 107 |
+
result_nested = slice_embeddings(embeddings, num_sentences_nested)
|
| 108 |
+
# `result_nested` will be a nested list of numpy arrays:
|
| 109 |
+
# [[embeddings[:2], embeddings[2:3]], [embeddings[3:6], embeddings[6:]]]
|
| 110 |
+
|
| 111 |
+
slice_embeddings(embeddings, "invalid") # Raises a TypeError
|
| 112 |
+
```
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
def _slice_embeddings(s_idx: int, n_sentences: List[int]):
|
| 116 |
+
"""
|
| 117 |
+
Helper function to slice embeddings starting from index `s_idx`.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
- s_idx (int): Starting index for slicing.
|
| 121 |
+
- n_sentences (List[int]): List specifying number of sentences in each slice.
|
| 122 |
+
|
| 123 |
+
Returns:
|
| 124 |
+
- Tuple[List[np.ndarray], int]: A tuple containing a list of sliced embeddings and the next starting index.
|
| 125 |
+
"""
|
| 126 |
_result = []
|
| 127 |
for count in n_sentences:
|
| 128 |
_result.append(embeddings[s_idx:s_idx + count])
|
|
|
|
| 150 |
|
| 151 |
|
| 152 |
def is_nested_list_of_type(lst_obj, element_type, depth: int) -> bool:
|
| 153 |
+
"""
|
| 154 |
+
Check if the given object is a nested list of a specific type up to a specified depth.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
- lst_obj: The object to check, expected to be a list or a single element.
|
| 158 |
+
- element_type: The type that each element in the nested list should match.
|
| 159 |
+
- depth (int): The depth of nesting to check. Must be non-negative.
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
- bool: True if lst_obj is a nested list of the specified type up to the given depth, False otherwise.
|
| 163 |
+
|
| 164 |
+
Raises:
|
| 165 |
+
- ValueError: If depth is negative.
|
| 166 |
+
|
| 167 |
+
Example:
|
| 168 |
+
```python
|
| 169 |
+
# Test cases
|
| 170 |
+
is_nested_list_of_type("test", str, 0) # Returns True
|
| 171 |
+
is_nested_list_of_type([1, 2, 3], str, 0) # Returns False
|
| 172 |
+
is_nested_list_of_type(["apple", "banana"], str, 1) # Returns True
|
| 173 |
+
is_nested_list_of_type([[1, 2], [3, 4]], int, 2) # Returns True
|
| 174 |
+
is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2) # Returns False
|
| 175 |
+
is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3) # Returns True
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
Explanation:
|
| 179 |
+
- The function checks if `lst_obj` is a nested list of elements of type `element_type` up to `depth` levels deep.
|
| 180 |
+
- If `depth` is 0, it checks if `lst_obj` itself is of type `element_type`.
|
| 181 |
+
- If `depth` is greater than 0, it recursively checks each level of nesting to ensure all elements match `element_type`.
|
| 182 |
+
- Raises a `ValueError` if `depth` is negative, as depth must be a non-negative integer.
|
| 183 |
+
"""
|
| 184 |
if depth == 0:
|
| 185 |
return isinstance(lst_obj, element_type)
|
| 186 |
elif depth > 0:
|