Theo Viel commited on
Commit
7bf8a91
·
1 Parent(s): af7c76e

minor fct simplification

Browse files
Files changed (2) hide show
  1. Demo.ipynb +2 -2
  2. post_processing/graphic_elt_pp.py +2 -6
Demo.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a2a102251661c90c8754034554147835394b71507dc496eb1f09a34665381eb6
3
- size 979765
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:777844f15f935804cd331020a2338c7db3670a96b0dc6d3430b2602cde666e68
3
+ size 989346
post_processing/graphic_elt_pp.py CHANGED
@@ -89,7 +89,6 @@ def expand_boxes(
89
  def retrieve_title(
90
  boxes: npt.NDArray[np.float64],
91
  labels: npt.NDArray[np.int_],
92
- scores: npt.NDArray[np.float64],
93
  classes: List[str],
94
  ) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int_], npt.NDArray[np.float64]]:
95
  """
@@ -101,13 +100,10 @@ def retrieve_title(
101
  Args:
102
  boxes (numpy.ndarray): Array of bounding boxes with shape (N, 4).
103
  labels (numpy.ndarray): Array of labels with shape (N,).
104
- scores (numpy.ndarray): Array of confidence scores with shape (N,).
105
  classes (list): List of class labels.
106
 
107
  Returns:
108
- numpy.ndarray [N x 4]: Array of bounding boxes (unchanged).
109
- numpy.ndarray [N]: Array of labels (potentially modified).
110
- numpy.ndarray [N]: Array of scores (unchanged).
111
  """
112
  if classes.index("chart_title") not in labels:
113
  widths = boxes[:, 2] - boxes[:, 0]
@@ -115,4 +111,4 @@ def retrieve_title(
115
  replaced = np.argmax(scores) if max(scores) > 0 else None
116
  if replaced is not None:
117
  labels[replaced] = classes.index("chart_title")
118
- return boxes, labels, scores
 
89
  def retrieve_title(
90
  boxes: npt.NDArray[np.float64],
91
  labels: npt.NDArray[np.int_],
 
92
  classes: List[str],
93
  ) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int_], npt.NDArray[np.float64]]:
94
  """
 
100
  Args:
101
  boxes (numpy.ndarray): Array of bounding boxes with shape (N, 4).
102
  labels (numpy.ndarray): Array of labels with shape (N,).
 
103
  classes (list): List of class labels.
104
 
105
  Returns:
106
+ numpy.ndarray [N]: Array of labels.
 
 
107
  """
108
  if classes.index("chart_title") not in labels:
109
  widths = boxes[:, 2] - boxes[:, 0]
 
111
  replaced = np.argmax(scores) if max(scores) > 0 else None
112
  if replaced is not None:
113
  labels[replaced] = classes.index("chart_title")
114
+ return labels