FiberGate / tools /debug_training.py
AzizMiladi's picture
chore: git mv scripts, UI, dev tools, docs into folders
70c46cc
Raw
History Blame
3.22 kB
"""
Debug script: check whether a freshly initialised LayoutLMv3 token classifier
can learn (overfit) a single synthetic batch.
"""
import torch
import json
from pathlib import Path
from PIL import Image
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification, LayoutLMv3Config
from train_extractor_v3 import load_token_classifier_from_classifier_ckpt, build_bio_labels
from collections import Counter
from typing import Mapping, Optional, Sequence

# Setup
# NOTE(review): CLASSIFIER_CKPT is not read anywhere in this script — confirm whether it is still needed.
CLASSIFIER_CKPT = Path("models/classifier")
num_bio_labels = 25
BASE_MODEL = "microsoft/layoutlmv3-base"
MAX_LENGTH = 512
# Label value ignored by torch.nn.CrossEntropyLoss (its default ignore_index); used for
# special tokens, padding and non-first sub-word pieces.
IGNORE_INDEX = -100
# Word index -> label id for the dummy sentence below; unlisted words get 0 (O).
# Assumed id scheme: 0=O, 1=B-Reference_Urbanisme, 3=B-DLPI — TODO confirm against build_bio_labels().
DUMMY_WORD_LABELS = {0: 1, 2: 3}


def build_word_labels(
    word_ids: Sequence[Optional[int]],
    word_label_map: Mapping[int, int],
    default_label: int = 0,
    ignore_index: int = IGNORE_INDEX,
) -> list:
    """Return one label per token, supervising only the first sub-token of each word.

    Args:
        word_ids: Word index for every token position (``None`` for special or
            padding tokens), e.g. from ``BatchEncoding.word_ids``.
        word_label_map: Label id for selected word indices.
        default_label: Label for words not present in ``word_label_map``.
        ignore_index: Label for positions that must not be supervised.

    Returns:
        A list with the same length as ``word_ids``.
    """
    labels = []
    prev = None
    for wid in word_ids:
        if wid is None:
            labels.append(ignore_index)
            continue
        # Continuation sub-tokens of the same word stay unsupervised.
        labels.append(ignore_index if wid == prev else word_label_map.get(wid, default_label))
        prev = wid
    return labels


def supervised_accuracy(pred_ids, label_ids, ignore_index: int = IGNORE_INDEX) -> float:
    """Return the fraction of supervised positions (label != ignore_index) predicted correctly.

    Returns NaN when no position is supervised.
    """
    pairs = [(p, t) for p, t in zip(pred_ids, label_ids) if t != ignore_index]
    if not pairs:
        return float("nan")
    return sum(p == t for p, t in pairs) / len(pairs)


def _predict(model, encoding) -> list:
    """Run an eval-mode, no-grad forward pass; return argmax label ids for the first sequence."""
    model.eval()
    with torch.no_grad():
        return model(**encoding).logits.argmax(-1)[0].tolist()


def main(seed: int = 0, num_steps: int = 10, lr: float = 1e-4) -> None:
    """Train a randomly initialised LayoutLMv3 token classifier on one dummy batch.

    Prints predictions before and after training, the loss every few steps, and
    label statistics, so you can see whether the model is able to fit the batch.

    Args:
        seed: Seed for torch's RNG (weight init, dropout) to make runs repeatable.
        num_steps: Number of optimizer steps on the single batch.
        lr: Adam learning rate.
    """
    torch.manual_seed(seed)

    # Model with the base architecture but random weights (constructed from config only).
    config = LayoutLMv3Config.from_pretrained(BASE_MODEL)
    config.num_labels = num_bio_labels
    model = LayoutLMv3ForTokenClassification(config)

    try:
        processor = LayoutLMv3Processor.from_pretrained(BASE_MODEL, apply_ocr=False)
    except (OSError, ValueError) as err:
        # Without a processor there is nothing to encode; stop here rather than
        # failing later on an undefined `encoding`.
        print(f"Could not load processor: {err}")
        return

    # Dummy single-page document: blank white image with four words on one line.
    image = Image.new("RGB", (1000, 1000), color=(255, 255, 255))
    words = ["Reference", "12345", "DLPI", "Code"]
    boxes = [[100, 100, 200, 200], [250, 100, 350, 200], [400, 100, 500, 200], [550, 100, 650, 200]]
    encoding = processor(
        image, words, boxes=boxes,
        max_length=MAX_LENGTH, padding="max_length",
        truncation=True, return_tensors="pt",
    )

    word_ids = encoding.word_ids(batch_index=0)
    # Shape (1, seq_len) so labels carry the same batch dimension as input_ids.
    labels = torch.tensor([build_word_labels(word_ids, DUMMY_WORD_LABELS)], dtype=torch.long)
    expected = labels[0].tolist()

    # Baseline in eval mode: a freshly constructed module is in train mode, and
    # active dropout would make these predictions non-deterministic.
    pred_ids_before = _predict(model, encoding)
    print(f"Before training (first 20 pred_ids): {pred_ids_before[:20]}")
    print(f"Expected labels (first 20): {expected[:20]}")

    # Try to overfit the single batch.
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for step in range(num_steps):
        optimizer.zero_grad()
        loss = model(**encoding, labels=labels).loss
        loss.backward()
        optimizer.step()
        if step % 3 == 0 or step == num_steps - 1:
            print(f"Step {step}: loss={loss.item():.4f}")

    pred_ids_after = _predict(model, encoding)
    print(f"\nAfter training (first 20 pred_ids): {pred_ids_after[:20]}")

    # Counts cover every position (padding included); accuracy looks only at supervised tokens.
    before_counts = Counter(pred_ids_before)
    after_counts = Counter(pred_ids_after)
    print(f"\nBefore - unique labels: {len(before_counts)}, label 0 (O) count: {before_counts.get(0, 0)}")
    print(f"After - unique labels: {len(after_counts)}, label 0 (O) count: {after_counts.get(0, 0)}")
    print(
        f"Supervised-token accuracy: before={supervised_accuracy(pred_ids_before, expected):.2%}, "
        f"after={supervised_accuracy(pred_ids_after, expected):.2%}"
    )


if __name__ == "__main__":
    main()