""" Debug script to check if the extractor model is predicting entities or just "O" labels. """ import torch from pathlib import Path from PIL import Image from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor EXTRACTOR_MODEL = "models/extractor_v3" MAX_LENGTH = 512 def resolve_model_path(model_dir): model_path = Path(model_dir) if (model_path / "config.json").exists() or (model_path / "model.safetensors").exists() or (model_path / "pytorch_model.bin").exists(): return model_path checkpoints = [p for p in model_path.glob("checkpoint-*") if p.is_dir()] if checkpoints: return max(checkpoints, key=lambda p: int(p.name.split("-")[-1])) raise FileNotFoundError(f"No saved model found in {model_path}") # Load model print("Loading extractor model...") model_path = resolve_model_path(EXTRACTOR_MODEL) print(f" Using checkpoint: {model_path}") processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False) model = LayoutLMv3ForTokenClassification.from_pretrained(model_path) model.eval() # Create dummy data print("\nTesting with dummy data...") image = Image.new("RGB", (1000, 1000), color=(255, 255, 255)) words = ["Reference_Urbanisme", "12345", "DLPI", "Code12"] boxes = [[100, 100, 200, 200], [250, 100, 350, 200], [400, 100, 500, 200], [550, 100, 650, 200]] encoding = processor( image, words, boxes=boxes, max_length=MAX_LENGTH, padding="max_length", truncation=True, return_tensors="pt" ) # Run inference with torch.no_grad(): outputs = model(**encoding) pred_ids = outputs.logits.argmax(-1).squeeze().tolist() word_ids = encoding.word_ids(batch_index=0) id2label = model.config.id2label print(f"\nPredicted IDs: {pred_ids[:20]}") # First 20 print(f"\nWord IDs: {word_ids[:20]}") print("\nPredictions by word:") prev_word = None for pos, word_idx in enumerate(word_ids[:20]): if word_idx is None or word_idx == prev_word: continue label = id2label.get(str(pred_ids[pos]), "O") print(f" Word {word_idx}: pred_id={pred_ids[pos]}, label='{label}'") prev_word = word_idx # Count label distribution from collections import Counter label_counts = Counter(id2label.get(str(pid), "O") for pid in pred_ids) print(f"\nLabel distribution in {len(pred_ids)} predictions:") for label, count in label_counts.most_common(): print(f" {label}: {count}")