Update modelling_magi.py
Browse files — modelling_magi.py (+15, −1)
modelling_magi.py
CHANGED
|
@@ -242,12 +242,15 @@ class MagiModel(PreTrainedModel):
|
|
| 242 |
file.write(transript)
|
| 243 |
return transript
|
| 244 |
|
| 245 |
-
def
|
| 246 |
self, images, annotations, move_to_device_fn=None, apply_sigmoid=True
|
| 247 |
):
|
| 248 |
assert not self.config.disable_detections
|
| 249 |
move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
|
| 250 |
|
|
|
|
|
|
|
|
|
|
| 251 |
inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images, annotations)
|
| 252 |
inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
|
| 253 |
processed_targets = inputs_to_detection_transformer.pop("labels")
|
|
@@ -255,6 +258,7 @@ class MagiModel(PreTrainedModel):
|
|
| 255 |
detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
|
| 256 |
predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
|
| 257 |
predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
|
|
|
|
| 258 |
|
| 259 |
predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)
|
| 260 |
matching_dict = {
|
|
@@ -266,6 +270,7 @@ class MagiModel(PreTrainedModel):
|
|
| 266 |
matched_char_obj_tokens_for_batch = []
|
| 267 |
matched_text_obj_tokens_for_batch = []
|
| 268 |
t2c_tokens_for_batch = []
|
|
|
|
| 269 |
|
| 270 |
text_bboxes_for_batch = []
|
| 271 |
character_bboxes_for_batch = []
|
|
@@ -288,6 +293,7 @@ class MagiModel(PreTrainedModel):
|
|
| 288 |
matched_char_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_char_indices])
|
| 289 |
matched_text_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_text_indices])
|
| 290 |
t2c_tokens_for_batch.append(predicted_t2c_tokens_for_batch[j])
|
|
|
|
| 291 |
|
| 292 |
text_character_affinity_matrices = self._get_text_character_affinity_matrices(
|
| 293 |
character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
|
|
@@ -296,8 +302,16 @@ class MagiModel(PreTrainedModel):
|
|
| 296 |
apply_sigmoid=apply_sigmoid,
|
| 297 |
)
|
| 298 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
return {
|
| 300 |
"text_character_affinity_matrices": text_character_affinity_matrices,
|
|
|
|
| 301 |
"text_bboxes_for_batch": text_bboxes_for_batch,
|
| 302 |
"character_bboxes_for_batch": character_bboxes_for_batch,
|
| 303 |
}
|
|
|
|
| 242 |
file.write(transript)
|
| 243 |
return transript
|
| 244 |
|
| 245 |
+
def get_affinity_matrices_given_annotations(
|
| 246 |
self, images, annotations, move_to_device_fn=None, apply_sigmoid=True
|
| 247 |
):
|
| 248 |
assert not self.config.disable_detections
|
| 249 |
move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
|
| 250 |
|
| 251 |
+
character_bboxes_in_batch = [[bbox for bbox, label in zip(a["bboxes_as_x1y1x2y2"], a["labels"]) if label == 0] for a in annotations]
|
| 252 |
+
crop_embeddings_for_batch = self.predict_crop_embeddings(images, character_bboxes_in_batch, move_to_device_fn)
|
| 253 |
+
|
| 254 |
inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images, annotations)
|
| 255 |
inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
|
| 256 |
processed_targets = inputs_to_detection_transformer.pop("labels")
|
|
|
|
| 258 |
detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
|
| 259 |
predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
|
| 260 |
predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
|
| 261 |
+
predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(detection_transformer_output)
|
| 262 |
|
| 263 |
predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)
|
| 264 |
matching_dict = {
|
|
|
|
| 270 |
matched_char_obj_tokens_for_batch = []
|
| 271 |
matched_text_obj_tokens_for_batch = []
|
| 272 |
t2c_tokens_for_batch = []
|
| 273 |
+
c2c_tokens_for_batch = []
|
| 274 |
|
| 275 |
text_bboxes_for_batch = []
|
| 276 |
character_bboxes_for_batch = []
|
|
|
|
| 293 |
matched_char_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_char_indices])
|
| 294 |
matched_text_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_text_indices])
|
| 295 |
t2c_tokens_for_batch.append(predicted_t2c_tokens_for_batch[j])
|
| 296 |
+
c2c_tokens_for_batch.append(predicted_c2c_tokens_for_batch[j])
|
| 297 |
|
| 298 |
text_character_affinity_matrices = self._get_text_character_affinity_matrices(
|
| 299 |
character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
|
|
|
|
| 302 |
apply_sigmoid=apply_sigmoid,
|
| 303 |
)
|
| 304 |
|
| 305 |
+
character_character_affinity_matrices = self._get_character_character_affinity_matrices(
|
| 306 |
+
character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
|
| 307 |
+
crop_embeddings_for_batch=crop_embeddings_for_batch,
|
| 308 |
+
c2c_tokens_for_batch=c2c_tokens_for_batch,
|
| 309 |
+
apply_sigmoid=apply_sigmoid,
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
return {
|
| 313 |
"text_character_affinity_matrices": text_character_affinity_matrices,
|
| 314 |
+
"character_character_affinity_matrices": character_character_affinity_matrices,
|
| 315 |
"text_bboxes_for_batch": text_bboxes_for_batch,
|
| 316 |
"character_bboxes_for_batch": character_bboxes_for_batch,
|
| 317 |
}
|