FiberGate / scripts /05_evaluate.py
AzizMiladi's picture
chore: git mv scripts, UI, dev tools, docs into folders
70c46cc
Raw
History Blame
9.15 kB
"""
STEP 5 — Evaluate both models (document classifier and field extractor) on the test set
Output: outputs/evaluation_report.json + printed classification and field-extraction reports
"""
import json
import torch
import numpy as np
from pathlib import Path
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
from transformers import (
LayoutLMv3ForSequenceClassification,
LayoutLMv3ForTokenClassification,
LayoutLMv3Processor,
)
from sklearn.metrics import classification_report
# ── CONFIG ──────────────────────────────────────────────────────────────────
# Test split: JSON file of test records, iterated record-by-record in main().
TEST_JSON = "data_combined/combined_test_v2.json"
# Label mappings; main() reads its "doc_classes" and "field_labels" entries.
MAPPINGS = "data2/label_mappings.json"
# Fine-tuned model directories. resolve_model_path() falls back to the
# highest-numbered checkpoint-* subdirectory when no saved model sits directly inside.
CLASSIFIER_MODEL = "models/classifier"
EXTRACTOR_MODEL = "models/extractor_v3"
# Processor sequence length: every page is padded/truncated to this many tokens.
MAX_LENGTH = 512
# Images whose longer side exceeds this many pixels are downscaled by load_image().
MAX_IMAGE_SIDE = 2048
# Upper bound on the number of words taken per page in build_words_boxes().
MAX_WORDS = 354
# OCR words with a confidence below this are dropped (missing/unparseable
# confidences are treated as 100, i.e. kept).
MIN_CONF = 30
def resolve_model_path(model_dir):
    """Return the directory that should be passed to ``from_pretrained``.

    If ``model_dir`` itself contains ``config.json``, ``model.safetensors`` or
    ``pytorch_model.bin`` it is returned as-is. Otherwise the ``checkpoint-<N>``
    subdirectory with the largest integer ``N`` is returned. Subdirectories whose
    suffix is not an integer (e.g. ``checkpoint-final``) are ignored rather than
    crashing the numeric sort.

    Args:
        model_dir: Path (str or ``os.PathLike``) to a model output directory.

    Returns:
        A ``pathlib.Path`` to the directory holding the model.

    Raises:
        FileNotFoundError: If neither a saved model nor a numbered checkpoint exists.
    """
    model_path = Path(model_dir)
    markers = ("config.json", "model.safetensors", "pytorch_model.bin")
    if any((model_path / name).exists() for name in markers):
        return model_path

    def _step(path):
        # "checkpoint-1200" -> "1200"; int() accepts exactly what isdecimal() accepts.
        return path.name.split("-")[-1]

    checkpoints = [
        p for p in model_path.glob("checkpoint-*")
        if p.is_dir() and _step(p).isdecimal()
    ]
    if checkpoints:
        return max(checkpoints, key=lambda p: int(_step(p)))
    raise FileNotFoundError(
        f"No saved model found in {model_path}. Expected config.json, model.safetensors, "
        f"pytorch_model.bin, or a checkpoint-<step> directory."
    )
def encode(processor, image, words, boxes):
    """Run the LayoutLMv3 processor on one page and return PyTorch tensors.

    The page is padded and truncated to exactly ``MAX_LENGTH`` tokens, so every
    encoding has the same sequence length.
    """
    options = dict(
        max_length=MAX_LENGTH,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    return processor(image, words, boxes=boxes, **options)
def load_image(image_path):
    """Load a page image as RGB, downscaled so its longer side is at most MAX_IMAGE_SIDE.

    A blank white 1654x2339 RGB image is returned when ``image_path`` is falsy
    or does not point to a regular file, so callers always receive an image.

    Args:
        image_path: Path to the image file, or a falsy value.

    Returns:
        A PIL image in RGB mode.
    """
    if not image_path or not Path(image_path).is_file():
        return Image.new("RGB", (1654, 2339), (255, 255, 255))
    # The context manager closes the file handle deterministically; convert()
    # returns a fully loaded copy that remains valid after the file is closed.
    with Image.open(image_path) as src:
        image = src.convert("RGB")
    if max(image.size) > MAX_IMAGE_SIDE:
        # thumbnail() resizes in place and preserves the aspect ratio.
        image.thumbnail((MAX_IMAGE_SIDE, MAX_IMAGE_SIDE))
    return image
def vertical_boxes_norm(words_count, img_h):
    """Stack ``words_count`` full-width boxes down the page in 0-1000 coordinates.

    This mirrors ``vertical_boxes_px``: each row is ``img_h // words_count``
    pixels high (at least 1), converted to the 0-1000 scale LayoutLMv3 expects.
    Coordinates are clamped to [0, 1000], because with more words than pixel
    rows the stacked boxes would otherwise run past the bottom of the page.
    A non-positive ``img_h`` is replaced by a nominal height of 1000 instead of
    dividing by zero.

    Args:
        words_count: Number of boxes to generate; ``<= 0`` yields ``[]``.
        img_h: Page height in pixels.

    Returns:
        A list of ``[x0, y0, x1, y1]`` integer boxes, one per word.
    """
    if words_count <= 0:
        return []
    if img_h <= 0:
        img_h = 1000
    word_h = max(img_h // words_count, 1)

    def _to_norm(px):
        return min(max(int(px / img_h * 1000), 0), 1000)

    return [
        [0, _to_norm(i * word_h), 1000, _to_norm((i + 1) * word_h)]
        for i in range(words_count)
    ]
def vertical_boxes_px(words_count, img_w, img_h):
    """Stack ``words_count`` full-width pixel boxes from the top of the page.

    Each box spans the full width ``img_w`` and is ``img_h // words_count``
    pixels tall (never less than 1). A non-positive count yields ``[]``.
    """
    if words_count > 0:
        row_height = max(img_h // words_count, 1)
        boxes = []
        for row in range(words_count):
            top = row * row_height
            boxes.append([0, top, img_w, (row + 1) * row_height])
        return boxes
    return []
def load_ocr_json(rec):
    """Load the OCR sidecar JSON referenced by a record, best-effort.

    The path is taken from ``rec["ocr_path"]``, falling back to
    ``rec["ocr_json_path"]``.

    Args:
        rec: A test record (mapping) that may reference an OCR JSON file.

    Returns:
        The decoded JSON object as a dict. Returns ``None`` in any of these cases:
        no path is given, the path is not a regular file, the file cannot be read
        or decoded as UTF-8 JSON, or its top level is not a JSON object (the
        caller accesses the result with ``.get``).
    """
    p = rec.get("ocr_path") or rec.get("ocr_json_path")
    if not p:
        return None
    path = Path(p)
    if not path.is_file():
        return None
    try:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError: unreadable file; ValueError covers UnicodeDecodeError and
        # json.JSONDecodeError. Missing OCR is tolerated, so degrade to None.
        return None
    return data if isinstance(data, dict) else None
def build_words_boxes(rec):
    """Collect the words and boxes to feed the processor for one record.

    When the record's OCR JSON (via ``load_ocr_json``) has non-empty "words"
    and "bboxes_norm":
      * the first MAX_WORDS entries of each OCR list are taken;
      * words whose "confs" value is below MIN_CONF are dropped (a missing or
        non-numeric confidence counts as 100);
      * pixel boxes come from "bboxes", or are derived from the normalized box
        and the record's image size when "bboxes" is too short.

    NOTE(review): truncation to MAX_WORDS happens before confidence filtering,
    so fewer than MAX_WORDS words may survive even if later words are confident.

    If no OCR word survives, or there is no usable OCR JSON, the whitespace-split
    "ocr_text" (capped at MAX_WORDS, or ["[PAD]"] when empty) is laid out as
    stacked full-width boxes instead.

    Returns:
        ``(words, normalized_boxes, pixel_boxes)`` — three lists of equal length.
    """
    width = rec.get("image_width", 1654)
    height = rec.get("image_height", 2339)
    ocr = load_ocr_json(rec)

    if ocr and ocr.get("words") and ocr.get("bboxes_norm"):
        raw_words = ocr.get("words", [])[:MAX_WORDS]
        raw_norm = ocr.get("bboxes_norm", [])[:MAX_WORDS]
        raw_px = ocr.get("bboxes", [])[:MAX_WORDS]
        raw_confs = ocr.get("confs", [])[:MAX_WORDS]
        n_confs = len(raw_confs)
        n_px = len(raw_px)

        kept_words, kept_norm, kept_px = [], [], []
        for idx, (word, norm_box) in enumerate(zip(raw_words, raw_norm)):
            raw_conf = raw_confs[idx] if idx < n_confs else 100
            try:
                confidence = float(raw_conf)
            except Exception:
                confidence = 100
            if confidence < MIN_CONF:
                continue

            kept_words.append(word)
            kept_norm.append(norm_box)
            if idx < n_px:
                kept_px.append(raw_px[idx])
            else:
                # Scale the 0-1000 box back to pixels: even indices are x, odd are y.
                kept_px.append([
                    int(norm_box[k] / 1000 * (width if k % 2 == 0 else height))
                    for k in range(4)
                ])

        if kept_words:
            return kept_words, kept_norm, kept_px

    fallback = (rec.get("ocr_text", "") or "").split()[:MAX_WORDS] or ["[PAD]"]
    count = len(fallback)
    return fallback, vertical_boxes_norm(count, height), vertical_boxes_px(count, width, height)
def _accuracy(y_true, y_pred):
    """Return the fraction of positions where y_true equals y_pred, or None if empty.

    None (written as JSON null) replaces the nan that a mean over an empty array
    produces, which json.dump would otherwise emit as invalid JSON ``NaN``.
    """
    if not y_true:
        return None
    return float((np.asarray(y_true) == np.asarray(y_pred)).mean())


def _format_accuracy(value):
    """Render an accuracy as a percentage, or 'n/a' when it is undefined."""
    return "n/a" if value is None else f"{value:.1%}"


def _print_report(title, y_true, y_pred, target_names):
    """Print a titled scikit-learn classification report over all known labels.

    Passing ``labels`` explicitly keeps ``target_names`` aligned even when some
    classes never occur in y_true or y_pred. Without it, sklearn raises ValueError
    on a size mismatch.
    """
    print("=" * 60)
    print(title)
    print("=" * 60)
    if not y_true:
        print("No samples to evaluate.\n")
        return
    print(classification_report(
        y_true, y_pred,
        labels=list(range(len(target_names))),
        target_names=target_names,
        zero_division=0,
    ))


def _word_labels(num_words, word_boxes_px, anno_boxes, anno_labels):
    """Give each word the label of the first annotation box containing its pixel centre.

    Words whose centre lies in no annotation box keep label 0.
    """
    labels = [0] * num_words
    for i, bbox_px in enumerate(word_boxes_px):
        wcx = (bbox_px[0] + bbox_px[2]) / 2
        wcy = (bbox_px[1] + bbox_px[3]) / 2
        for abox, lid in zip(anno_boxes, anno_labels):
            if abox[0] <= wcx <= abox[2] and abox[1] <= wcy <= abox[3]:
                labels[i] = lid
                break
    return labels


def _align_token_labels(word_ids, word_labels, preds, id2label, field_label2id, num_fields):
    """Return parallel (true, predicted) field-label id lists, one entry per word.

    Only the first sub-token of each word is scored. Special/padding tokens
    (word id None) are skipped. Predicted B-/I- tags are collapsed to the bare
    field name before being mapped to an id; unknown names map to 0.
    """
    true_tok, pred_tok = [], []
    prev = None
    for pos, wi in enumerate(word_ids):
        if wi is None or wi == prev:
            prev = wi
            continue
        lbl = word_labels[wi] if wi < len(word_labels) else 0
        # Ensure true label is within known field range
        if not isinstance(lbl, int) or lbl < 0 or lbl >= num_fields:
            lbl = 0
        # id2label keys may be ints or (after a JSON round-trip) strings.
        pred_label = id2label.get(preds[pos], id2label.get(str(preds[pos]), "O"))
        if pred_label.startswith(("B-", "I-")):
            pred_label = pred_label[2:]
        true_tok.append(lbl)
        pred_tok.append(field_label2id.get(pred_label, 0))
        prev = wi
    return true_tok, pred_tok


def main():
    """Evaluate the document classifier and the field extractor on TEST_JSON.

    Prints a classification report for each model and writes their accuracies to
    outputs/evaluation_report.json. An accuracy is written as null when the
    corresponding model had no samples to score.
    """
    with open(MAPPINGS, encoding="utf-8") as f:
        mappings = json.load(f)
    with open(TEST_JSON, encoding="utf-8") as f:
        test_data = json.load(f)
    doc_classes = mappings["doc_classes"]
    field_labels = mappings["field_labels"]
    field_label2id = {label: index for index, label in enumerate(field_labels)}

    processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
    classifier = LayoutLMv3ForSequenceClassification.from_pretrained(resolve_model_path(CLASSIFIER_MODEL))
    extractor = LayoutLMv3ForTokenClassification.from_pretrained(resolve_model_path(EXTRACTOR_MODEL))
    classifier.eval()
    extractor.eval()
    extractor_id2label = extractor.config.id2label
    print(f"Evaluating on {len(test_data)} test samples...\n")

    true_classes, pred_classes = [], []
    all_true_tokens, all_pred_tokens = [], []
    extraction_samples = 0

    # Single pass: each page is loaded and encoded once and fed to both models.
    for rec in test_data:
        image = load_image(rec.get("image_path"))
        words, word_boxes, word_boxes_px = build_words_boxes(rec)
        encoding = encode(processor, image, words, word_boxes)

        # ── Classification ───────────────────────────────────────────────────
        with torch.no_grad():
            pred_id = classifier(**encoding).logits.argmax(-1).item()
        true_classes.append(rec["doc_class_id"])
        pred_classes.append(pred_id)

        # ── Extraction (only records with annotated boxes) ───────────────────
        if not rec.get("boxes"):
            continue
        extraction_samples += 1
        word_labels = _word_labels(
            len(words), word_boxes_px, rec["boxes"], rec.get("box_label_ids", [])
        )
        with torch.no_grad():
            # Batch of one: take row 0 -> one predicted label id per token position.
            preds = extractor(**encoding).logits.argmax(-1)[0].tolist()
        true_tok, pred_tok = _align_token_labels(
            encoding.word_ids(batch_index=0), word_labels, preds,
            extractor_id2label, field_label2id, len(field_labels),
        )
        all_true_tokens.extend(true_tok)
        all_pred_tokens.extend(pred_tok)

    _print_report("CLASSIFICATION REPORT", true_classes, pred_classes, doc_classes)
    _print_report("FIELD EXTRACTION REPORT", all_true_tokens, all_pred_tokens, field_labels)

    clf_accuracy = _accuracy(true_classes, pred_classes)
    ext_accuracy = _accuracy(all_true_tokens, all_pred_tokens)

    # ── Save report ──────────────────────────────────────────────────────────
    Path("outputs").mkdir(exist_ok=True)
    report = {
        "classification_accuracy": None if clf_accuracy is None else round(clf_accuracy, 4),
        "extraction_accuracy": None if ext_accuracy is None else round(ext_accuracy, 4),
        "test_samples": len(test_data),
        "extraction_samples": extraction_samples,
    }
    with open("outputs/evaluation_report.json", "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print(f"\nβœ… Classification accuracy : {_format_accuracy(clf_accuracy)}")
    print(f"βœ… Extraction accuracy : {_format_accuracy(ext_accuracy)}")
    print("Report saved to: outputs/evaluation_report.json")
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()