Update modeling_falcon_ocr.py

#7
Files changed (1) hide show
  1. modeling_falcon_ocr.py +22 -7
modeling_falcon_ocr.py CHANGED
@@ -777,6 +777,12 @@ class FalconOCRForCausalLM(PreTrainedModel):
777
  crop_origins: list[tuple[int, int]] = [] # (image_idx, det_idx)
778
 
779
  for img_idx, (pil_img, dets) in enumerate(zip(pil_images, all_layout_dets)):
 
 
 
 
 
 
780
  img_w, img_h = pil_img.size
781
  for det_idx, det in enumerate(dets):
782
  cat_key = det["category"].strip().lower()
@@ -819,12 +825,21 @@ class FalconOCRForCausalLM(PreTrainedModel):
819
  # --- Reassemble per-image results ---
820
  results: list[list[dict]] = [[] for _ in range(len(pil_images))]
821
  for (img_idx, det_idx), text in zip(crop_origins, flat_texts):
822
- det = all_layout_dets[img_idx][det_idx]
823
- results[img_idx].append({
824
- "category": det["category"],
825
- "bbox": det["bbox"],
826
- "score": det["score"],
827
- "text": text,
828
- })
 
 
 
 
 
 
 
 
 
829
 
830
  return results
 
777
  crop_origins: list[tuple[int, int]] = [] # (image_idx, det_idx)
778
 
779
  for img_idx, (pil_img, dets) in enumerate(zip(pil_images, all_layout_dets)):
780
+ if not dets or (len(dets) == 1 and dets[0]["category"].strip().lower() == "image"):
781
+ prompt = f"<|image|>{CATEGORY_PROMPTS['plain']}\n<|OCR_PLAIN|>"
782
+ flat_crops.append((pil_img, prompt))
783
+ crop_origins.append((img_idx, -1))
784
+ continue
785
+
786
  img_w, img_h = pil_img.size
787
  for det_idx, det in enumerate(dets):
788
  cat_key = det["category"].strip().lower()
 
825
  # --- Reassemble per-image results ---
826
  results: list[list[dict]] = [[] for _ in range(len(pil_images))]
827
  for (img_idx, det_idx), text in zip(crop_origins, flat_texts):
828
+ if det_idx == -1:
829
+ img_w, img_h = pil_images[img_idx].size
830
+ results[img_idx].append({
831
+ "category": "plain",
832
+ "bbox": [0, 0, img_w, img_h],
833
+ "score": 1.0,
834
+ "text": text,
835
+ })
836
+ else:
837
+ det = all_layout_dets[img_idx][det_idx]
838
+ results[img_idx].append({
839
+ "category": det["category"],
840
+ "bbox": det["bbox"],
841
+ "score": det["score"],
842
+ "text": text,
843
+ })
844
 
845
  return results