# Spaces: Sleeping
"""
Debug script to check whether the extractor model predicts entity labels or only "O" labels.
"""
from collections import Counter
from pathlib import Path

import torch
from PIL import Image
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor
# Directory containing the fine-tuned extractor: either a saved model or
# the training output with checkpoint-<step> subdirectories.
EXTRACTOR_MODEL: str = "models/extractor_v3"

# Every encoding is padded/truncated to exactly this many tokens.
MAX_LENGTH: int = 512
def resolve_model_path(model_dir):
    """Return the directory from which the extractor model should be loaded.

    If *model_dir* itself holds a saved model (``config.json``,
    ``model.safetensors`` or ``pytorch_model.bin``), it is returned as-is.
    Otherwise the ``checkpoint-<step>`` subdirectory with the highest integer
    step is returned. Checkpoint directories whose suffix is not an integer
    (e.g. ``checkpoint-best``) are ignored rather than crashing the lookup.

    Args:
        model_dir: ``str`` or ``Path`` to a model or training-output directory.

    Returns:
        ``Path`` of the model directory or of the latest numbered checkpoint.

    Raises:
        FileNotFoundError: if neither a saved model nor a numbered checkpoint
            directory exists under *model_dir*.
    """
    model_path = Path(model_dir)

    marker_files = ("config.json", "model.safetensors", "pytorch_model.bin")
    if any((model_path / name).exists() for name in marker_files):
        return model_path

    # Collect (step, path) only for directories with a purely numeric suffix;
    # int() on anything else would raise ValueError.
    numbered = []
    for candidate in model_path.glob("checkpoint-*"):
        step = candidate.name.rsplit("-", 1)[-1]
        if candidate.is_dir() and step.isdecimal():
            numbered.append((int(step), candidate))

    if numbered:
        return max(numbered, key=lambda item: item[0])[1]
    raise FileNotFoundError(f"No saved model found in {model_path}")
# ---- Load the fine-tuned extractor and the base LayoutLMv3 processor ----
print("Loading extractor model...")
model_path = resolve_model_path(EXTRACTOR_MODEL)
print(f" Using checkpoint: {model_path}")

# OCR is turned off: the words and their boxes are passed in explicitly.
processor = LayoutLMv3Processor.from_pretrained(
    "microsoft/layoutlmv3-base",
    apply_ocr=False,
)
model = LayoutLMv3ForTokenClassification.from_pretrained(model_path)
model.eval()  # inference mode (disables dropout)
# ---- Build a synthetic single-page input ----
print("\nTesting with dummy data...")
image = Image.new("RGB", (1000, 1000), color=(255, 255, 255))  # blank white page
words = ["Reference_Urbanisme", "12345", "DLPI", "Code12"]
# One 100x100 box per word, left to right on a single row. On a 1000x1000
# page, pixel coordinates coincide with LayoutLMv3's 0-1000 box scale.
boxes = [[left, 100, left + 100, 200] for left in (100, 250, 400, 550)]

encoding = processor(
    image,
    words,
    boxes=boxes,
    max_length=MAX_LENGTH,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
# ---- Run inference and report what the model predicts ----
with torch.no_grad():
    outputs = model(**encoding)

# Take batch element 0 explicitly; .squeeze() would also collapse the
# sequence axis if it ever had length 1.
pred_ids = outputs.logits.argmax(-1)[0].tolist()
word_ids = encoding.word_ids(batch_index=0)
id2label = model.config.id2label


def _label_name(mapping, pred_id):
    """Map a predicted class id to its label name.

    Transformers' ``PretrainedConfig`` converts ``id2label`` keys to ``int``
    when a config is loaded, so the integer id is tried first. A ``str`` key
    is only a fallback. Unknown ids are shown explicitly instead of being
    reported as "O", which would hide exactly what this script checks for.
    """
    if pred_id in mapping:
        return mapping[pred_id]
    return mapping.get(str(pred_id), f"<unknown id {pred_id}>")


print(f"\nPredicted IDs: {pred_ids[:20]}")  # First 20
print(f"\nWord IDs: {word_ids[:20]}")

# Position of the first sub-token of every real word. Special tokens and
# padding have word_id None and are skipped; the whole sequence is scanned so
# no word is dropped.
first_token_positions = []
prev_word = None
for pos, word_idx in enumerate(word_ids):
    if word_idx is None or word_idx == prev_word:
        continue
    first_token_positions.append(pos)
    prev_word = word_idx

print("\nPredictions by word:")
word_labels = []
for pos in first_token_positions:
    label = _label_name(id2label, pred_ids[pos])
    word_labels.append(label)
    print(f" Word {word_ids[pos]}: pred_id={pred_ids[pos]}, label='{label}'")

# Count labels over real words only; counting every position would be
# dominated by padding and special tokens.
label_counts = Counter(word_labels)
print(
    f"\nLabel distribution over {len(word_labels)} words "
    f"(first sub-token of each; {len(pred_ids)} token positions in total):"
)
for label, count in label_counts.most_common():
    print(f" {label}: {count}")