Update modeling_falcon_ocr.py
#7
by griffintaur - opened
- modeling_falcon_ocr.py +22 -7
modeling_falcon_ocr.py
CHANGED
|
@@ -777,6 +777,12 @@ class FalconOCRForCausalLM(PreTrainedModel):
|
|
| 777 |
crop_origins: list[tuple[int, int]] = [] # (image_idx, det_idx)
|
| 778 |
|
| 779 |
for img_idx, (pil_img, dets) in enumerate(zip(pil_images, all_layout_dets)):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 780 |
img_w, img_h = pil_img.size
|
| 781 |
for det_idx, det in enumerate(dets):
|
| 782 |
cat_key = det["category"].strip().lower()
|
|
@@ -819,12 +825,21 @@ class FalconOCRForCausalLM(PreTrainedModel):
|
|
| 819 |
# --- Reassemble per-image results ---
|
| 820 |
results: list[list[dict]] = [[] for _ in range(len(pil_images))]
|
| 821 |
for (img_idx, det_idx), text in zip(crop_origins, flat_texts):
|
| 822 |
-
|
| 823 |
-
|
| 824 |
-
|
| 825 |
-
|
| 826 |
-
|
| 827 |
-
|
| 828 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 829 |
|
| 830 |
return results
|
|
|
|
| 777 |
crop_origins: list[tuple[int, int]] = [] # (image_idx, det_idx)
|
| 778 |
|
| 779 |
for img_idx, (pil_img, dets) in enumerate(zip(pil_images, all_layout_dets)):
|
| 780 |
+
if not dets or (len(dets) == 1 and dets[0]["category"].strip().lower() == "image"):
|
| 781 |
+
prompt = f"<|image|>{CATEGORY_PROMPTS['plain']}\n<|OCR_PLAIN|>"
|
| 782 |
+
flat_crops.append((pil_img, prompt))
|
| 783 |
+
crop_origins.append((img_idx, -1))
|
| 784 |
+
continue
|
| 785 |
+
|
| 786 |
img_w, img_h = pil_img.size
|
| 787 |
for det_idx, det in enumerate(dets):
|
| 788 |
cat_key = det["category"].strip().lower()
|
|
|
|
| 825 |
# --- Reassemble per-image results ---
|
| 826 |
results: list[list[dict]] = [[] for _ in range(len(pil_images))]
|
| 827 |
for (img_idx, det_idx), text in zip(crop_origins, flat_texts):
|
| 828 |
+
if det_idx == -1:
|
| 829 |
+
img_w, img_h = pil_images[img_idx].size
|
| 830 |
+
results[img_idx].append({
|
| 831 |
+
"category": "plain",
|
| 832 |
+
"bbox": [0, 0, img_w, img_h],
|
| 833 |
+
"score": 1.0,
|
| 834 |
+
"text": text,
|
| 835 |
+
})
|
| 836 |
+
else:
|
| 837 |
+
det = all_layout_dets[img_idx][det_idx]
|
| 838 |
+
results[img_idx].append({
|
| 839 |
+
"category": det["category"],
|
| 840 |
+
"bbox": det["bbox"],
|
| 841 |
+
"score": det["score"],
|
| 842 |
+
"text": text,
|
| 843 |
+
})
|
| 844 |
|
| 845 |
return results
|