Spaces:
Build error
Build error
Update PanopticQuality.py
Browse files- PanopticQuality.py +3 -4
PanopticQuality.py
CHANGED
|
@@ -113,6 +113,7 @@ class PQMetric(evaluate.Metric):
|
|
| 113 |
self.stuff = stuff if stuff is not None else DEFAULT_STUFF
|
| 114 |
self.per_class = per_class
|
| 115 |
self.split_sq_rq = split_sq_rq
|
|
|
|
| 116 |
self.pq_metric = PanopticQuality(
|
| 117 |
things=set([self.label2id[label] for label in self.label2id.keys() if label not in self.stuff]),
|
| 118 |
stuffs=set([self.label2id[label] for label in self.label2id.keys() if label in self.stuff]),
|
|
@@ -172,8 +173,6 @@ class PQMetric(evaluate.Metric):
|
|
| 172 |
fn = self.pq_metric.metric.false_negatives.clone()
|
| 173 |
iou = self.pq_metric.metric.iou_sum.clone()
|
| 174 |
|
| 175 |
-
things_stuffs = sorted(self.pq_metric.things) + sorted(self.pq_metric.stuffs)
|
| 176 |
-
|
| 177 |
# compute scores
|
| 178 |
result = self.pq_metric.compute() # shape : (n_classes (sorted things + sorted stuffs), scores (pq, sq, rq))
|
| 179 |
|
|
@@ -183,10 +182,10 @@ class PQMetric(evaluate.Metric):
|
|
| 183 |
if not self.split_sq_rq:
|
| 184 |
result = result.T
|
| 185 |
result_dict["scores"] = {self.id2label[numeric_label]: result[i].tolist() \
|
| 186 |
-
for i, numeric_label in enumerate(things_stuffs)}
|
| 187 |
result_dict["scores"].update({"ALL": result.mean(axis=0).tolist()})
|
| 188 |
result_dict["numbers"] = {self.id2label[numeric_label]: [tp[i].item(), fp[i].item(), fn[i].item(), iou[i].item()] \
|
| 189 |
-
for i, numeric_label in enumerate(things_stuffs)}
|
| 190 |
result_dict["numbers"].update({"ALL": [tp.sum().item(), fp.sum().item(), fn.sum().item(), iou.sum().item()]})
|
| 191 |
else:
|
| 192 |
result_dict["scores"] = {"ALL": result.tolist() if self.split_sq_rq else [result.tolist()]}
|
|
|
|
| 113 |
self.stuff = stuff if stuff is not None else DEFAULT_STUFF
|
| 114 |
self.per_class = per_class
|
| 115 |
self.split_sq_rq = split_sq_rq
|
| 116 |
+
self.things_stuffs = sorted(self.pq_metric.things) + sorted(self.pq_metric.stuffs)
|
| 117 |
self.pq_metric = PanopticQuality(
|
| 118 |
things=set([self.label2id[label] for label in self.label2id.keys() if label not in self.stuff]),
|
| 119 |
stuffs=set([self.label2id[label] for label in self.label2id.keys() if label in self.stuff]),
|
|
|
|
| 173 |
fn = self.pq_metric.metric.false_negatives.clone()
|
| 174 |
iou = self.pq_metric.metric.iou_sum.clone()
|
| 175 |
|
|
|
|
|
|
|
| 176 |
# compute scores
|
| 177 |
result = self.pq_metric.compute() # shape : (n_classes (sorted things + sorted stuffs), scores (pq, sq, rq))
|
| 178 |
|
|
|
|
| 182 |
if not self.split_sq_rq:
|
| 183 |
result = result.T
|
| 184 |
result_dict["scores"] = {self.id2label[numeric_label]: result[i].tolist() \
|
| 185 |
+
for i, numeric_label in enumerate(self.things_stuffs)}
|
| 186 |
result_dict["scores"].update({"ALL": result.mean(axis=0).tolist()})
|
| 187 |
result_dict["numbers"] = {self.id2label[numeric_label]: [tp[i].item(), fp[i].item(), fn[i].item(), iou[i].item()] \
|
| 188 |
+
for i, numeric_label in enumerate(self.things_stuffs)}
|
| 189 |
result_dict["numbers"].update({"ALL": [tp.sum().item(), fp.sum().item(), fn.sum().item(), iou.sum().item()]})
|
| 190 |
else:
|
| 191 |
result_dict["scores"] = {"ALL": result.tolist() if self.split_sq_rq else [result.tolist()]}
|