"""
STEP 5 — Evaluate both models on the test set

Output: outputs/evaluation_report.json + printed classification report
"""
import json
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from sklearn.metrics import classification_report
from transformers import (
    LayoutLMv3ForSequenceClassification,
    LayoutLMv3ForTokenClassification,
    LayoutLMv3Processor,
)

# Disable Pillow's pixel-count limit so that very large scans can be opened.
Image.MAX_IMAGE_PIXELS = None

# ── CONFIG ──────────────────────────────────────────────────────────────────
TEST_JSON = "data_combined/combined_test_v2.json"
MAPPINGS = "data2/label_mappings.json"
CLASSIFIER_MODEL = "models/classifier"
EXTRACTOR_MODEL = "models/extractor_v3"
MAX_LENGTH = 512        # token sequence length passed to the processor
MAX_IMAGE_SIDE = 2048   # images with a longer side are downscaled to this
MAX_WORDS = 354         # at most this many words are taken per document
MIN_CONF = 30           # OCR words with a confidence below this are dropped

# Page size used when a record has no usable image or no image dimensions.
_DEFAULT_IMAGE_WIDTH = 1654
_DEFAULT_IMAGE_HEIGHT = 2339

# Presence of any of these files marks a directory as a saved model.
_MODEL_MARKER_FILES = ("config.json", "model.safetensors", "pytorch_model.bin")


def _checkpoint_step(path):
    """Return the numeric suffix of a ``checkpoint-<step>`` directory, or None."""
    suffix = path.name.split("-")[-1]
    return int(suffix) if suffix.isdecimal() else None


def resolve_model_path(model_dir):
    """Locate the directory to load a model from.

    Returns ``model_dir`` itself when it contains a config or weights file;
    otherwise returns the ``checkpoint-<step>`` sub-directory with the highest
    numeric step. Checkpoint directories whose suffix is not a number (for
    example ``checkpoint-best``) are ignored rather than crashing the lookup.

    Raises:
        FileNotFoundError: if neither a saved model nor a numbered checkpoint
            directory is found.
    """
    model_path = Path(model_dir)
    if any((model_path / name).exists() for name in _MODEL_MARKER_FILES):
        return model_path

    best_step, best_path = -1, None
    for candidate in model_path.glob("checkpoint-*"):
        if not candidate.is_dir():
            continue
        step = _checkpoint_step(candidate)
        if step is not None and step > best_step:
            best_step, best_path = step, candidate
    if best_path is not None:
        return best_path

    raise FileNotFoundError(
        f"No saved model found in {model_path}. "
        "Expected model.safetensors, pytorch_model.bin, or a checkpoint-* directory."
    )


def encode(processor, image, words, boxes):
    """Run the processor on one document, padded/truncated to MAX_LENGTH tokens."""
    return processor(
        image,
        words,
        boxes=boxes,
        max_length=MAX_LENGTH,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )


def load_image(image_path):
    """Load a page image as RGB, downscaled so no side exceeds MAX_IMAGE_SIDE.

    A blank white page of the default size is returned when ``image_path`` is
    empty or does not exist, so evaluation can proceed on text and boxes alone.
    """
    if not image_path or not Path(image_path).exists():
        return Image.new(
            "RGB", (_DEFAULT_IMAGE_WIDTH, _DEFAULT_IMAGE_HEIGHT), (255, 255, 255)
        )
    # The context manager releases the file handle; convert() returns an
    # independent in-memory image that stays valid after the file is closed.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    if max(image.size) > MAX_IMAGE_SIDE:
        image.thumbnail((MAX_IMAGE_SIDE, MAX_IMAGE_SIDE))
    return image


def vertical_boxes_norm(words_count, img_h):
    """Build synthetic full-width boxes stacked top to bottom, on a 0-1000 scale.

    Used when no OCR geometry is available. Returns one ``[0, y0, 1000, y1]``
    box per word, or an empty list when ``words_count`` is not positive.
    """
    if words_count <= 0:
        return []
    word_h = max(img_h // words_count, 1)

    def scaled(rows):
        # word_h is forced to at least 1, so with more words than pixel rows
        # the raw value would exceed 1000; cap it to stay on the 0-1000 scale.
        return min(int(rows * word_h / img_h * 1000), 1000)

    return [[0, scaled(i), 1000, scaled(i + 1)] for i in range(words_count)]


def vertical_boxes_px(words_count, img_w, img_h):
    """Build synthetic full-width boxes stacked top to bottom, in pixels.

    Pixel-space counterpart of ``vertical_boxes_norm``: one
    ``[0, y0, img_w, y1]`` box per word, or an empty list when ``words_count``
    is not positive.
    """
    if words_count <= 0:
        return []
    word_h = max(img_h // words_count, 1)
    return [[0, i * word_h, img_w, (i + 1) * word_h] for i in range(words_count)]


def load_ocr_json(rec):
    """Load the OCR JSON referenced by a record, on a best-effort basis.

    The path is read from ``rec["ocr_path"]`` or ``rec["ocr_json_path"]``.
    Returns the parsed dict, or None when no path is given, the file is
    missing or unreadable, the content is not valid JSON, or the top-level
    value is not an object.
    """
    ocr_path = rec.get("ocr_path") or rec.get("ocr_json_path")
    if not ocr_path:
        return None
    ocr_file = Path(ocr_path)
    if not ocr_file.exists():
        return None
    try:
        with open(ocr_file, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # ValueError covers both JSON syntax errors and UTF-8 decode errors.
        return None
    # Callers use .get() on the result, so anything but a dict is unusable.
    return data if isinstance(data, dict) else None


def build_words_boxes(rec):
    """Return ``(words, boxes_norm, boxes_px)`` for one test record.

    OCR output is preferred: up to MAX_WORDS words are taken and those with a
    confidence below MIN_CONF are dropped (a missing or unparsable confidence
    counts as 100). Pixel boxes come from the OCR ``bboxes`` list when
    present, otherwise they are derived from the normalised boxes and the
    record's image dimensions.

    When no OCR words survive, the record's ``ocr_text`` is split on
    whitespace (or replaced by a single ``"[PAD]"`` word when empty) and
    paired with synthetic stacked boxes.
    """
    # ``or`` also replaces an explicit null/0 dimension with the default.
    img_w = rec.get("image_width") or _DEFAULT_IMAGE_WIDTH
    img_h = rec.get("image_height") or _DEFAULT_IMAGE_HEIGHT

    ocr = load_ocr_json(rec)
    if ocr and ocr.get("words") and ocr.get("bboxes_norm"):
        words_raw = ocr["words"][:MAX_WORDS]
        bnorm_raw = ocr["bboxes_norm"][:MAX_WORDS]
        bpx_raw = (ocr.get("bboxes") or [])[:MAX_WORDS]
        confs_raw = (ocr.get("confs") or [])[:MAX_WORDS]

        words, bnorm, bpx = [], [], []
        for i, (word, box_norm) in enumerate(zip(words_raw, bnorm_raw)):
            conf = confs_raw[i] if i < len(confs_raw) else 100
            try:
                conf_val = float(conf)
            except (TypeError, ValueError):
                conf_val = 100
            if conf_val < MIN_CONF:
                continue

            words.append(word)
            bnorm.append(box_norm)
            if i < len(bpx_raw):
                bpx.append(bpx_raw[i])
            else:
                # No pixel box recorded: rescale the 0-1000 box to pixels.
                bpx.append([
                    int(box_norm[0] / 1000 * img_w),
                    int(box_norm[1] / 1000 * img_h),
                    int(box_norm[2] / 1000 * img_w),
                    int(box_norm[3] / 1000 * img_h),
                ])
        if words:
            return words, bnorm, bpx

    words = (rec.get("ocr_text", "") or "").split()[:MAX_WORDS] or ["[PAD]"]
    return (
        words,
        vertical_boxes_norm(len(words), img_h),
        vertical_boxes_px(len(words), img_w, img_h),
    )


def _accuracy(y_true, y_pred):
    """Return the fraction of matching labels, or None when there are none."""
    if not y_true:
        return None
    return float((np.array(y_true) == np.array(y_pred)).mean())


def _format_accuracy(value):
    """Format an accuracy as a percentage, or ``n/a`` when it is undefined."""
    return "n/a" if value is None else f"{value:.1%}"


def _evaluate_classifier(test_data, processor, classifier):
    """Return ``(true_ids, predicted_ids)`` of the document class per record."""
    true_classes, pred_classes = [], []
    for rec in test_data:
        image = load_image(rec.get("image_path"))
        words, boxes, _ = build_words_boxes(rec)
        encoding = encode(processor, image, words, boxes)
        with torch.no_grad():
            logits = classifier(**encoding).logits
        true_classes.append(rec["doc_class_id"])
        pred_classes.append(logits.argmax(-1).item())
    return true_classes, pred_classes


def _evaluate_extractor(test_data, processor, extractor, field_labels):
    """Return ``(true_ids, predicted_ids)`` of the field label per word.

    Only records with annotation ``boxes`` are scored. A word's true label is
    that of the first annotation box containing the word's centre (0 when
    none does). Its prediction is taken from the word's first sub-token, with
    any ``B-``/``I-`` prefix removed and then mapped through ``field_labels``.
    """
    field_label2id = {label: index for index, label in enumerate(field_labels)}
    id2label = extractor.config.id2label
    all_true, all_pred = [], []

    for rec in test_data:
        if not rec.get("boxes"):
            continue
        image = load_image(rec.get("image_path"))
        words, word_boxes, word_boxes_px = build_words_boxes(rec)
        encoding = encode(processor, image, words, word_boxes)
        word_ids = encoding.word_ids(batch_index=0)

        # Build true labels per word from the annotation boxes.
        anno_boxes = rec.get("boxes", [])
        anno_labels = rec.get("box_label_ids", [])
        word_labels = [0] * len(words)
        for i, bbox_px in enumerate(word_boxes_px):
            wcx = (bbox_px[0] + bbox_px[2]) / 2
            wcy = (bbox_px[1] + bbox_px[3]) / 2
            for abox, lid in zip(anno_boxes, anno_labels):
                if abox[0] <= wcx <= abox[2] and abox[1] <= wcy <= abox[3]:
                    word_labels[i] = lid
                    break

        with torch.no_grad():
            # Index the single batch element explicitly instead of squeeze().
            preds = extractor(**encoding).logits.argmax(-1)[0].tolist()

        prev = None
        for pos, wi in enumerate(word_ids):
            # Skip special/padding tokens and all but a word's first sub-token.
            if wi is None or wi == prev:
                prev = wi
                continue
            prev = wi

            lbl = word_labels[wi] if wi < len(word_labels) else 0
            # Ensure true label is within known field range.
            if not isinstance(lbl, int) or lbl < 0 or lbl >= len(field_labels):
                lbl = 0

            # id2label may be keyed by int or by str depending on how the
            # config was loaded, so try both before defaulting to "O".
            pred_label = id2label.get(preds[pos], id2label.get(str(preds[pos]), "O"))
            if pred_label.startswith(("B-", "I-")):
                pred_label = pred_label[2:]

            all_true.append(lbl)
            all_pred.append(field_label2id.get(pred_label, 0))

    return all_true, all_pred


def main():
    """Evaluate the classifier and the extractor on the test set.

    Prints a classification report for each model and writes the two
    accuracies plus the sample count to ``outputs/evaluation_report.json``.
    An accuracy that cannot be computed (nothing to score) is written as
    ``null`` and printed as ``n/a`` instead of NaN.
    """
    with open(MAPPINGS, encoding="utf-8") as f:
        mappings = json.load(f)
    with open(TEST_JSON, encoding="utf-8") as f:
        test_data = json.load(f)

    doc_classes = mappings["doc_classes"]
    field_labels = mappings["field_labels"]

    processor = LayoutLMv3Processor.from_pretrained(
        "microsoft/layoutlmv3-base", apply_ocr=False
    )
    classifier = LayoutLMv3ForSequenceClassification.from_pretrained(
        resolve_model_path(CLASSIFIER_MODEL)
    )
    extractor = LayoutLMv3ForTokenClassification.from_pretrained(
        resolve_model_path(EXTRACTOR_MODEL)
    )
    classifier.eval()
    extractor.eval()

    print(f"Evaluating on {len(test_data)} test samples...\n")

    # ── Classification evaluation ────────────────────────────────────────────
    true_classes, pred_classes = _evaluate_classifier(test_data, processor, classifier)

    print("=" * 60)
    print("CLASSIFICATION REPORT")
    print("=" * 60)
    if true_classes:
        # Passing ``labels`` keeps the report aligned with ``target_names``
        # even when some classes never occur in the test set or predictions;
        # without it sklearn rejects the length mismatch.
        print(classification_report(
            true_classes,
            pred_classes,
            labels=list(range(len(doc_classes))),
            target_names=doc_classes,
            zero_division=0,
        ))
    else:
        print("No test samples to evaluate.")
    clf_accuracy = _accuracy(true_classes, pred_classes)

    # ── Extraction evaluation ────────────────────────────────────────────────
    all_true_tokens, all_pred_tokens = _evaluate_extractor(
        test_data, processor, extractor, field_labels
    )

    print("=" * 60)
    print("FIELD EXTRACTION REPORT")
    print("=" * 60)
    if all_true_tokens:
        print(classification_report(
            all_true_tokens,
            all_pred_tokens,
            labels=list(range(len(field_labels))),
            target_names=field_labels,
            zero_division=0,
        ))
    else:
        print("No annotated test samples to evaluate.")
    ext_accuracy = _accuracy(all_true_tokens, all_pred_tokens)

    # ── Save report ──────────────────────────────────────────────────────────
    Path("outputs").mkdir(exist_ok=True)
    report = {
        "classification_accuracy": None if clf_accuracy is None else round(clf_accuracy, 4),
        "extraction_accuracy": None if ext_accuracy is None else round(ext_accuracy, 4),
        "test_samples": len(test_data),
    }
    with open("outputs/evaluation_report.json", "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    print(f"\n✅ Classification accuracy : {_format_accuracy(clf_accuracy)}")
    print(f"✅ Extraction accuracy : {_format_accuracy(ext_accuracy)}")
    print("Report saved to: outputs/evaluation_report.json")


if __name__ == "__main__":
    main()