Spaces:
Sleeping
Sleeping
"""
STEP 5 — Evaluate both models on the test set
Output: outputs/evaluation_report.json + printed classification report
"""
| import json | |
| import torch | |
| import numpy as np | |
| from pathlib import Path | |
| from PIL import Image | |
| Image.MAX_IMAGE_PIXELS = None | |
| from transformers import ( | |
| LayoutLMv3ForSequenceClassification, | |
| LayoutLMv3ForTokenClassification, | |
| LayoutLMv3Processor, | |
| ) | |
| from sklearn.metrics import classification_report | |
# ── CONFIG ──────────────────────────────────────────────────────────────────
# JSON file with the test records that main() iterates for both evaluations.
TEST_JSON = "data_combined/combined_test_v2.json"
# JSON file from which main() reads the "doc_classes" and "field_labels" lists.
MAPPINGS = "data2/label_mappings.json"
# Directories handed to resolve_model_path() to locate the saved models.
CLASSIFIER_MODEL = "models/classifier"
EXTRACTOR_MODEL = "models/extractor_v3"
# Passed as max_length to the processor in encode(); inputs are padded and
# truncated to this many tokens.
MAX_LENGTH = 512
# load_image() shrinks any image whose longer side exceeds this many pixels.
MAX_IMAGE_SIDE = 2048
# build_words_boxes() keeps at most this many words per document.
MAX_WORDS = 354
# OCR words whose confidence is below this value are dropped in
# build_words_boxes().
# NOTE(review): presumably on a 0-100 scale (missing values default to 100) —
# confirm against the OCR producer.
MIN_CONF = 30
def resolve_model_path(model_dir):
    """Locate the directory that holds a saved model.

    Args:
        model_dir: Training output directory (``str`` or ``Path``).

    Returns:
        ``model_dir`` as a ``Path`` when it directly contains ``config.json``,
        ``model.safetensors`` or ``pytorch_model.bin``; otherwise the
        ``checkpoint-<N>`` sub-directory with the largest numeric ``N``.

    Raises:
        FileNotFoundError: If neither a saved model nor a numbered
            ``checkpoint-*`` directory is found.
    """
    model_path = Path(model_dir)
    marker_files = ("config.json", "model.safetensors", "pytorch_model.bin")
    if any((model_path / name).exists() for name in marker_files):
        return model_path

    # Only checkpoint-<digits> directories are ranked. Names such as
    # "checkpoint-best" used to crash the int() conversion with ValueError.
    numbered = []
    for candidate in model_path.glob("checkpoint-*"):
        suffix = candidate.name.rsplit("-", 1)[-1]
        if candidate.is_dir() and suffix.isdecimal():
            numbered.append((int(suffix), candidate))
    if numbered:
        return max(numbered, key=lambda item: item[0])[1]

    raise FileNotFoundError(
        f"No saved model found in {model_path}. Expected model.safetensors, pytorch_model.bin, or a checkpoint-* directory."
    )
def encode(processor, image, words, boxes):
    """Run ``processor`` on one page and return its PyTorch-tensor encoding.

    The words and their boxes are passed through unchanged; the output is
    padded and truncated to ``MAX_LENGTH`` tokens.
    """
    processor_kwargs = {
        "boxes": boxes,
        "max_length": MAX_LENGTH,
        "padding": "max_length",
        "truncation": True,
        "return_tensors": "pt",
    }
    return processor(image, words, **processor_kwargs)
def load_image(image_path):
    """Load a page image as RGB, or return a blank page if it is unavailable.

    Args:
        image_path: Path to the image file; may be empty/``None``.

    Returns:
        A PIL RGB image whose longer side is at most ``MAX_IMAGE_SIDE``
        pixels. When ``image_path`` is falsy or does not exist, a white
        1654x2339 placeholder is returned instead (the same size that
        ``build_words_boxes`` assumes when a record carries no dimensions).
    """
    if not image_path or not Path(image_path).exists():
        return Image.new("RGB", (1654, 2339), (255, 255, 255))

    # Image.open() is lazy: the file stays open until the pixel data is read.
    # The context manager releases the handle even if decoding fails.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    if max(image.size) > MAX_IMAGE_SIDE:
        # thumbnail() resizes in place and keeps the aspect ratio.
        image.thumbnail((MAX_IMAGE_SIDE, MAX_IMAGE_SIDE))
    return image
def vertical_boxes_norm(words_count, img_h):
    """Build synthetic full-width boxes stacked top to bottom, on a 0-1000 scale.

    Used when no OCR geometry is available: word ``i`` gets the ``i``-th
    horizontal stripe of the page.

    Args:
        words_count: Number of boxes to generate.
        img_h: Page height in pixels.

    Returns:
        A list of ``[0, top, 1000, bottom]`` boxes, one per word; empty when
        ``words_count <= 0``. Every coordinate lies within ``0..1000``.
    """
    if words_count <= 0:
        return []

    if img_h <= 0:
        # No usable pixel height (a zero height used to raise
        # ZeroDivisionError): split the normalized range evenly instead.
        return [
            [0, i * 1000 // words_count, 1000, (i + 1) * 1000 // words_count]
            for i in range(words_count)
        ]

    word_h = max(img_h // words_count, 1)

    def to_norm(y_px):
        # When there are more words than pixel rows, word_h is forced to 1 and
        # the stripes run past the page bottom; clamp to stay within 0..1000.
        return min(int(y_px / img_h * 1000), 1000)

    return [
        [0, to_norm(i * word_h), 1000, to_norm((i + 1) * word_h)]
        for i in range(words_count)
    ]
def vertical_boxes_px(words_count, img_w, img_h):
    """Build synthetic full-width boxes stacked top to bottom, in pixels.

    Pixel-space counterpart of ``vertical_boxes_norm``: word ``i`` gets the
    ``i``-th horizontal stripe of the page.

    Args:
        words_count: Number of boxes to generate.
        img_w: Page width in pixels (used as every box's right edge).
        img_h: Page height in pixels.

    Returns:
        A list of ``[0, top, img_w, bottom]`` boxes, one per word; empty when
        ``words_count <= 0``.
    """
    if words_count <= 0:
        return []

    word_h = max(img_h // words_count, 1)
    # When there are more words than pixel rows, word_h is forced to 1 and the
    # stripes would run past the page bottom; keep them inside the page.
    bottom_limit = max(img_h, 0)
    return [
        [0, min(i * word_h, bottom_limit), img_w, min((i + 1) * word_h, bottom_limit)]
        for i in range(words_count)
    ]
def load_ocr_json(rec):
    """Load the OCR side-car JSON referenced by a record, if there is one.

    Args:
        rec: Record dict; the path is read from ``"ocr_path"`` or, failing
            that, ``"ocr_json_path"``.

    Returns:
        The parsed JSON content, or ``None`` when the record names no file or
        the file is missing, unreadable, or not valid UTF-8 JSON.
    """
    ocr_path = rec.get("ocr_path") or rec.get("ocr_json_path")
    if not ocr_path:
        return None

    try:
        with Path(ocr_path).open(encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError):
        # Best effort: OSError covers missing or unreadable files, ValueError
        # covers json.JSONDecodeError and UnicodeDecodeError. Anything else is
        # a programming error and should surface.
        return None
def build_words_boxes(rec):
    """Return the words and boxes used as model input for one record.

    Preference order:
      1. The OCR side-car JSON (``words`` + ``bboxes_norm``, optionally
         ``bboxes`` and ``confs``), with words below ``MIN_CONF`` removed.
      2. The record's ``ocr_text`` split on whitespace, with synthetic
         vertically stacked boxes; ``["[PAD]"]`` if there is no text at all.

    Args:
        rec: Record dict. ``image_width`` / ``image_height`` default to
            1654 / 2339 when missing or null.

    Returns:
        A ``(words, boxes_norm, boxes_px)`` tuple of equally long lists.
    """
    # `or` (rather than a .get default) also covers keys that are present but
    # null, which would otherwise break the pixel-box arithmetic below.
    img_w = rec.get("image_width") or 1654
    img_h = rec.get("image_height") or 2339

    ocr = load_ocr_json(rec)
    if isinstance(ocr, dict) and ocr.get("words") and ocr.get("bboxes_norm"):
        # NOTE(review): the MAX_WORDS cut happens before the confidence
        # filter, so fewer than MAX_WORDS words may survive. Kept as is so the
        # inputs stay the same as before.
        words_raw = ocr["words"][:MAX_WORDS]
        bnorm_raw = ocr["bboxes_norm"][:MAX_WORDS]
        bpx_raw = (ocr.get("bboxes") or [])[:MAX_WORDS]
        confs_raw = (ocr.get("confs") or [])[:MAX_WORDS]

        words, bnorm, bpx = [], [], []
        for i, (word, box_norm) in enumerate(zip(words_raw, bnorm_raw)):
            # A missing or unparseable confidence counts as fully confident.
            conf = confs_raw[i] if i < len(confs_raw) else 100
            try:
                conf_val = float(conf)
            except (TypeError, ValueError, OverflowError):
                conf_val = 100
            if conf_val < MIN_CONF:
                continue

            words.append(word)
            bnorm.append(box_norm)
            if i < len(bpx_raw):
                bpx.append(bpx_raw[i])
            else:
                # No pixel box stored: scale the 0-1000 box back to pixels.
                bpx.append([
                    int(box_norm[0] / 1000 * img_w),
                    int(box_norm[1] / 1000 * img_h),
                    int(box_norm[2] / 1000 * img_w),
                    int(box_norm[3] / 1000 * img_h),
                ])
        if words:
            return words, bnorm, bpx

    # Fallback: plain text with synthetic geometry.
    words = (rec.get("ocr_text", "") or "").split()[:MAX_WORDS] or ["[PAD]"]
    return (
        words,
        vertical_boxes_norm(len(words), img_h),
        vertical_boxes_px(len(words), img_w, img_h),
    )
def _assign_word_labels(word_boxes_px, anno_boxes, anno_labels):
    """Label each word with the first annotation box containing its centre (0 if none)."""
    word_labels = [0] * len(word_boxes_px)
    for i, bbox_px in enumerate(word_boxes_px):
        centre_x = (bbox_px[0] + bbox_px[2]) / 2
        centre_y = (bbox_px[1] + bbox_px[3]) / 2
        for abox, label_id in zip(anno_boxes, anno_labels):
            if abox[0] <= centre_x <= abox[2] and abox[1] <= centre_y <= abox[3]:
                word_labels[i] = label_id
                break
    return word_labels


def _accuracy(true_ids, pred_ids):
    """Return the fraction of matching positions, or 0.0 when there is nothing to compare."""
    if not true_ids:
        # The mean of an empty array is NaN, which json.dump would write as
        # the invalid JSON token `NaN`.
        return 0.0
    return float((np.array(true_ids) == np.array(pred_ids)).mean())


def main():
    """Evaluate the document classifier and the field extractor on the test set.

    Reads the label mappings and test records, runs both models one record at
    a time, prints a classification report for each, and writes the two
    accuracies plus the sample count to ``outputs/evaluation_report.json``.
    """
    with open(MAPPINGS, encoding="utf-8") as f:
        mappings = json.load(f)
    with open(TEST_JSON, encoding="utf-8") as f:
        test_data = json.load(f)

    doc_classes = mappings["doc_classes"]
    field_labels = mappings["field_labels"]
    field_label2id = {label: index for index, label in enumerate(field_labels)}

    processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
    classifier = LayoutLMv3ForSequenceClassification.from_pretrained(resolve_model_path(CLASSIFIER_MODEL))
    extractor = LayoutLMv3ForTokenClassification.from_pretrained(resolve_model_path(EXTRACTOR_MODEL))
    classifier.eval()
    extractor.eval()

    print(f"Evaluating on {len(test_data)} test samples...\n")

    # ── Classification evaluation ────────────────────────────────────────────
    true_classes = []
    pred_classes = []
    for rec in test_data:
        image = load_image(rec.get("image_path"))
        words, boxes, _ = build_words_boxes(rec)
        encoding = encode(processor, image, words, boxes)
        with torch.no_grad():
            logits = classifier(**encoding).logits
        true_classes.append(rec["doc_class_id"])
        pred_classes.append(logits.argmax(-1).item())

    print("=" * 60)
    print("CLASSIFICATION REPORT")
    print("=" * 60)
    if true_classes:
        # Pass `labels` explicitly: without it, classification_report raises
        # when the test split and predictions do not cover every class in
        # `target_names`.
        print(classification_report(
            true_classes, pred_classes,
            labels=list(range(len(doc_classes))),
            target_names=doc_classes,
            zero_division=0,
        ))
    else:
        print("No test samples to evaluate.")
    clf_accuracy = _accuracy(true_classes, pred_classes)

    # ── Extraction evaluation ────────────────────────────────────────────────
    all_true_tokens = []
    all_pred_tokens = []
    extractor_id2label = extractor.config.id2label
    for rec in test_data:
        # Only records that carry annotation boxes can be scored.
        if not rec.get("boxes"):
            continue
        image = load_image(rec.get("image_path"))
        words, word_boxes, word_boxes_px = build_words_boxes(rec)
        encoding = encode(processor, image, words, word_boxes)
        word_ids = encoding.word_ids(batch_index=0)

        # Ground truth per word, derived from the annotation boxes.
        word_labels = _assign_word_labels(
            word_boxes_px, rec.get("boxes", []), rec.get("box_label_ids", [])
        )

        with torch.no_grad():
            # Batch size is 1, so take the single sequence of token predictions.
            preds = extractor(**encoding).logits.argmax(-1)[0].tolist()

        prev = None
        for pos, wi in enumerate(word_ids):
            # Score only the first sub-token of each word; skip special and
            # padding tokens (word id None) and continuation sub-tokens.
            if wi is None or wi == prev:
                prev = wi
                continue
            prev = wi

            lbl = word_labels[wi] if wi < len(word_labels) else 0
            # Ensure the true label is within the known field range.
            if not isinstance(lbl, int) or lbl < 0 or lbl >= len(field_labels):
                lbl = 0

            # id2label keys may be ints or strings depending on how the config
            # was loaded; try both before falling back to "O".
            pred_label = extractor_id2label.get(
                preds[pos], extractor_id2label.get(str(preds[pos]), "O")
            )
            # Strip the B-/I- prefix so the prediction maps onto field_labels.
            if pred_label.startswith(("B-", "I-")):
                pred_label = pred_label[2:]

            all_true_tokens.append(lbl)
            all_pred_tokens.append(field_label2id.get(pred_label, 0))

    print("=" * 60)
    print("FIELD EXTRACTION REPORT")
    print("=" * 60)
    if all_true_tokens:
        print(classification_report(
            all_true_tokens, all_pred_tokens,
            labels=list(range(len(field_labels))),
            target_names=field_labels,
            zero_division=0,
        ))
    else:
        print("No annotated tokens to evaluate.")
    ext_accuracy = _accuracy(all_true_tokens, all_pred_tokens)

    # ── Save report ──────────────────────────────────────────────────────────
    Path("outputs").mkdir(exist_ok=True)
    report = {
        "classification_accuracy": round(float(clf_accuracy), 4),
        "extraction_accuracy": round(float(ext_accuracy), 4),
        "test_samples": len(test_data),
    }
    with open("outputs/evaluation_report.json", "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    print(f"\n✓ Classification accuracy : {clf_accuracy:.1%}")
    print(f"✓ Extraction accuracy : {ext_accuracy:.1%}")
    print("Report saved to: outputs/evaluation_report.json")
# Run the evaluation only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()