FiberGate / tools /debug_extractor.py
AzizMiladi's picture
chore: git mv scripts, UI, dev tools, docs into folders
70c46cc
Raw
History Blame
2.4 kB
"""
Debug script to check if the extractor model is predicting entities or just "O" labels.
"""
import torch
from pathlib import Path
from PIL import Image
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor
# Relative path handed to resolve_model_path() to locate the extractor model.
EXTRACTOR_MODEL = "models/extractor_v3"
# Passed to the processor as max_length; with padding="max_length" and
# truncation=True every encoded input ends up exactly this long.
MAX_LENGTH = 512
def resolve_model_path(model_dir):
    """Return the directory under *model_dir* that holds a saved model.

    If *model_dir* itself contains ``config.json``, ``model.safetensors`` or
    ``pytorch_model.bin`` it is returned as-is. Otherwise the
    ``checkpoint-<step>`` sub-directory with the highest integer step is
    returned.

    Args:
        model_dir: Path (str or ``Path``) of the model output directory.

    Returns:
        A ``Path`` to the directory to load the model from.

    Raises:
        FileNotFoundError: If neither model files nor a numbered checkpoint
            directory can be found.
    """
    model_path = Path(model_dir)
    marker_files = ("config.json", "model.safetensors", "pytorch_model.bin")
    if any((model_path / name).exists() for name in marker_files):
        return model_path

    numbered = []
    for candidate in model_path.glob("checkpoint-*"):
        if not candidate.is_dir():
            continue
        try:
            step = int(candidate.name.split("-")[-1])
        except ValueError:
            # Directories such as "checkpoint-best" have no step number to
            # rank by; skip them instead of aborting the whole lookup.
            continue
        numbered.append((step, candidate))

    if numbered:
        return max(numbered, key=lambda item: item[0])[1]
    raise FileNotFoundError(f"No saved model found in {model_path}")
# Load model
print("Loading extractor model...")
model_path = resolve_model_path(EXTRACTOR_MODEL)
print(f" Using checkpoint: {model_path}")
# apply_ocr=False because the words and boxes are supplied explicitly below.
processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
model = LayoutLMv3ForTokenClassification.from_pretrained(model_path)
model.eval()

# Create dummy data
print("\nTesting with dummy data...")
image = Image.new("RGB", (1000, 1000), color=(255, 255, 255))
words = ["Reference_Urbanisme", "12345", "DLPI", "Code12"]
boxes = [[100, 100, 200, 200], [250, 100, 350, 200], [400, 100, 500, 200], [550, 100, 650, 200]]
encoding = processor(
    image, words, boxes=boxes,
    max_length=MAX_LENGTH, padding="max_length",
    truncation=True, return_tensors="pt"
)

# Run inference
with torch.no_grad():
    outputs = model(**encoding)
pred_ids = outputs.logits.argmax(-1).squeeze().tolist()
word_ids = encoding.word_ids(batch_index=0)
id2label = model.config.id2label


def _label_for(pred_id):
    """Map a predicted class id to its label name, defaulting to "O".

    transformers normalises ``config.id2label`` keys to ``int`` when the
    config is loaded, so a string-key lookup alone always misses and reports
    "O" for every token. Try the int key first and keep the string key as a
    fallback for mappings that were not normalised.
    """
    if pred_id in id2label:
        return id2label[pred_id]
    return id2label.get(str(pred_id), "O")


print(f"\nPredicted IDs: {pred_ids[:20]}")  # First 20
print(f"\nWord IDs: {word_ids[:20]}")
print("\nPredictions by word:")
prev_word = None
for pos, word_idx in enumerate(word_ids[:20]):
    # Positions without a word id, and continuation sub-tokens of the word
    # just reported, are skipped so each word is printed once.
    if word_idx is None or word_idx == prev_word:
        continue
    label = _label_for(pred_ids[pos])
    print(f" Word {word_idx}: pred_id={pred_ids[pos]}, label='{label}'")
    prev_word = word_idx

# Count label distribution
# NOTE: this counts every position in pred_ids, including those whose word id
# is None (which the per-word listing above skips).
from collections import Counter
label_counts = Counter(_label_for(pid) for pid in pred_ids)
print(f"\nLabel distribution in {len(pred_ids)} predictions:")
for label, count in label_counts.most_common():
    print(f" {label}: {count}")