"""
STEP 3 - Train Field Extraction Model (LayoutLMv3 Token Classification)
v3 - fixes 9 bugs identified across previous audits.

CHANGELOG vs v2:

FIX 1 - Dimension rescaling (NEW, v3 critical)
----------------------------------------------
Annotation bboxes in combined_*.json were made on resized images
(e.g., 1654×2339) but the OCR was run on differently-sized images
(e.g., 1700×2200, 1698×2337). v2 used annotation bboxes verbatim against
OCR coordinates, so spatial matching missed by ~6-10% per axis.
Fix: rescale annotation bboxes to OCR coordinate space using
`image_width`/`image_height` from the record vs `width`/`height` from
the OCR file.
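
As a rough illustration of the FIX 1 rescaling, the sketch below maps one box
from a 1654×2339 annotation image to a 1700×2200 OCR image. `rescale_box` and
the sample box are illustrative only; the helper this script actually uses is
`rescale_boxes`, defined further down.

```python
def rescale_box(box, src_size, dst_size):
    # Map [x0, y0, x1, y1] from annotation-image pixels to OCR-image pixels.
    src_w, src_h = src_size
    dst_w, dst_h = dst_size
    # Multiply before dividing so exact ratios stay exact in floating point.
    return [
        int(box[0] * dst_w / src_w),
        int(box[1] * dst_h / src_h),
        int(box[2] * dst_w / src_w),
        int(box[3] * dst_h / src_h),
    ]


# Hypothetical box covering the right half of the page below y=1000.
scaled = rescale_box([827, 1000, 1654, 2339], (1654, 2339), (1700, 2200))
```

Each axis is scaled independently, so a box keeps its relative position on the
page even when the two images have different aspect ratios.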

FIX 2 - kept_bboxes parallel list in pass 2 (from previous report)
------------------------------------------------------------------
v2 pass 2 looked up `bboxes[i]` where i was the FILTERED index but
bboxes was the RAW list: silent index drift after any conf-filtered word.
Fix: track `kept_bboxes` aligned to `word_labels`.
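
The drift described in FIX 2 can be reproduced with a toy example. All words,
confidences and boxes below are made up for illustration; only the pattern
(filtering one list but indexing the unfiltered one) mirrors the v2 bug.

```python
words = ["PC", "075", "123", "X"]
confs = [90, 20, 80, 70]  # "075" falls below the confidence threshold
bboxes = [[0, 0, 10, 10], [10, 0, 20, 10], [20, 0, 30, 10], [30, 0, 40, 10]]
min_conf = 30

kept_words, kept_bboxes = [], []
for word, box, conf in zip(words, bboxes, confs):
    if conf < min_conf:
        continue
    # Append to both lists together so they stay index-aligned.
    kept_words.append(word)
    kept_bboxes.append(box)

i = kept_words.index("123")  # index in the FILTERED list
drifted = bboxes[i]          # v2 behaviour: box of the dropped word "075"
aligned = kept_bboxes[i]     # v3 behaviour: box of "123"
```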

FIX 3 - MIN_CONF lowered 60 -> 30 (from previous report)
--------------------------------------------------------
Many critical reference numbers (PC, DP, PA codes) have OCR conf 30-50
because of compact fonts. At MIN_CONF=60 they were silently dropped.
Lowering to 30 recovers them with low risk of training on garbage.

FIX 4 - OCR/image path remapping (NEW, v3)
------------------------------------------
combined_*.json contains Windows absolute paths (C:\\...). On Linux
training machines these never resolve. Added OCR_BASE_REMAP that
rewrites Windows paths to a configurable local base.

FIX 5 - Siret label_id bug
--------------------------
combined_*.json has 17 records with `box_labels=['...', 'Siret', ...]`
and `box_label_ids=[..., 0, ...]`, i.e. Siret maps to "O" (background).
Either it's a labelling mistake or Siret is missing from
label_mappings.json. v3 strips Siret annotations before training.
TODO: decide with the data team whether Siret should be added as label 13.

FIX 6 - Class weights from TOKEN counts, not BOX counts (NEW, v3)
-----------------------------------------------------------------
v2 computed weights from the 863 box-level annotation counts. But the
model loss is per-token, and after BIO expansion + sub-word tokenisation
there are ~50,000 tokens of which 95% are "O". Computing weights from
box counts gives "O" weight=5, but in token space "O" should have
weight≈0.5. v3 estimates token counts by multiplying box count by an
average-words-per-box factor, then computing inverse-frequency.
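
A minimal sketch of the inverse-frequency weighting behind FIX 6, using the
same formula as `estimate_bio_weights` below (total / (n_classes * count)) but
without its caps. The label names and counts are invented round numbers
(95% "O" out of 50,000 tokens, 4 words per box), not measured values.

```python
def inverse_frequency_weights(counts):
    # weight = total / (n_classes * count); rarer classes get larger weights.
    total = sum(counts.values())
    n_classes = len(counts)
    return {label: total / (n_classes * max(c, 1)) for label, c in counts.items()}


# Illustrative token counts: 625 boxes -> 625 B- tokens and 3 * 625 I- tokens.
token_counts = {"O": 47500, "B-field": 625, "I-field": 1875}
weights = inverse_frequency_weights(token_counts)
```

With token-level counts the background class ends up well below 1 while the
field classes sit far above it, which is the opposite of what box-level counts
suggested in v2.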
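
FIX 7 - Span-level (entity-level) F1 added (NEW, v3)
----------------------------------------------------
v2 reports BIO-token F1 only. v3 also computes per-field span F1, which
is what users actually care about. It uses a small built-in BIO span
extractor rather than seqeval, and runs it over the flattened, masked
label sequence, so spans can merge across document boundaries and the
figure is approximate.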
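
To show why span F1 and token accuracy can disagree, here is a small
self-contained sketch of BIO span extraction and exact-match span F1. The
function names and the tag sequences are illustrative; the metric code this
script runs is in `make_compute_metrics` below.

```python
def bio_to_spans(tags):
    # Return a set of (field, start, end) spans from a list of BIO tag strings.
    spans, field, start = [], None, None
    for i, tag in enumerate(list(tags) + ["O"]):  # sentinel closes a trailing span
        starts_new = tag.startswith("B-") or (
            tag.startswith("I-") and tag[2:] != field
        )
        if field is not None and (tag == "O" or starts_new):
            spans.append((field, start, i - 1))
            field = None
        if tag != "O" and field is None:
            field, start = tag[2:], i
    return set(spans)


def span_f1(true_tags, pred_tags):
    # A predicted span counts only if field, start and end all match.
    true_spans, pred_spans = bio_to_spans(true_tags), bio_to_spans(pred_tags)
    tp = len(true_spans & pred_spans)
    precision = tp / max(len(pred_spans), 1)
    recall = tp / max(len(true_spans), 1)
    return 2 * precision * recall / max(precision + recall, 1e-9)


true_tags = ["B-A", "I-A", "O", "B-B"]
pred_tags = ["B-A", "O", "O", "B-B"]  # one token wrong, so span A is cut short
score = span_f1(true_tags, pred_tags)
```

Here three of four tokens are correct, yet only one of the two spans matches
exactly, so the span F1 is 0.5.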

FIX 8 - Train/val/test split documentation (NEW, v3)
----------------------------------------------------
combined_*.json has 92 PDFs whose pages appear in BOTH train and val/test.
v3 logs this and recommends regenerating splits at the SOURCE-PDF level.
Until splits are regenerated, val/test F1 is overestimated.
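
One possible way to regenerate splits at the source-PDF level is to hash the
PDF identifier, so every page of a PDF lands in the same split. This is only a
sketch: the split fractions and the example file names are assumptions, and
the `_p` page-suffix convention is taken from the contamination check in
`main()`.

```python
import hashlib


def pdf_id(image_file):
    # Strip the page suffix, e.g. "fiche_p3.png" -> "fiche".
    return image_file.rsplit("_p", 1)[0]


def split_for(image_file, val_frac=0.1, test_frac=0.1):
    # Hash the PDF id (not the page) so all pages of one PDF share a split.
    digest = hashlib.sha256(pdf_id(image_file).encode()).hexdigest()
    u = (int(digest[:8], 16) % 10000) / 10000.0
    if u < test_frac:
        return "test"
    if u < test_frac + val_frac:
        return "val"
    return "train"


same_pdf = {split_for(f"fiche_p{page}.png") for page in range(1, 6)}
```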

FIX 9 - Reproducible unannotated sampling
-----------------------------------------
v3 uses a hashed record ID instead of random.random() so the sampling
decision is deterministic per-record across runs and resumes.
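
The idea behind FIX 9, sketched in isolation: hash the record ID into a bucket
in [0, 1) and compare it with the sample rate. `keep_record` and the `doc_*`
IDs are illustrative; the script's own helper is `deterministic_keep`.

```python
import hashlib


def keep_record(record_id, sample_rate):
    # The same ID always hashes to the same bucket, so the decision is stable.
    digest = hashlib.sha256(str(record_id).encode()).hexdigest()
    bucket = int(digest[:8], 16) % 10000
    return bucket / 10000.0 < sample_rate


ids = [f"doc_{i}" for i in range(1000)]  # hypothetical record IDs
first_run = [keep_record(r, 0.20) for r in ids]
second_run = [keep_record(r, 0.20) for r in ids]
kept_fraction = sum(first_run) / len(ids)
```

Unlike `random.random()`, the result does not depend on call order or seed
state, so resuming a run or reordering records keeps the same subset.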
"""
import json
import os
import random
import hashlib
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from transformers import (
    LayoutLMv3Config,
    LayoutLMv3ForSequenceClassification,
    LayoutLMv3ForTokenClassification,
    LayoutLMv3Processor,
    TrainingArguments,
    Trainer,
)
import warnings
warnings.filterwarnings("ignore")

# -- CONFIG --------------------------------------------------------------------
BASE_DIR = Path(__file__).resolve().parent
DATA_DIR = BASE_DIR / "data_combined"
TRAIN_JSON = DATA_DIR / "combined_train_v3.json"
VAL_JSON = DATA_DIR / "combined_val_v3.json"
TEST_JSON = DATA_DIR / "combined_test_v3.json"
MAPPINGS = BASE_DIR / "data2" / "label_mappings.json"
MODEL_OUTPUT = BASE_DIR / "models" / "extractor_v3"
CLASSIFIER_CKPT = BASE_DIR / "models" / "classifier"
FALLBACK_BASE = "microsoft/layoutlmv3-base"

# Path remapping: Windows paths in combined_*.json -> local Linux path
# Set this to wherever you copied the original dataset on the training machine.
# Example: WINDOWS_PREFIX="C:\\Users\\azizmohamed.miladi_a\\Desktop\\GuichetOI_ML"
#          LINUX_PREFIX="/data/GuichetOI_ML"
WINDOWS_PREFIX = os.environ.get(
    "OCR_WIN_PREFIX",
    "C:\\Users\\azizmohamed.miladi_a\\Desktop\\GuichetOI_ML"
)
LINUX_PREFIX = os.environ.get(
    "OCR_LINUX_PREFIX",
    "/data/GuichetOI_ML"
)

MAX_WORDS = 300  # was 354; at ~1.6 wp/word, 354 words overflowed the MAX_LENGTH=512 wordpiece budget
MAX_LENGTH = 512
BATCH_SIZE = 2
GRAD_ACCUM = 4
EPOCHS = 15
LEARNING_RATE = 2e-5
WARMUP_STEPS = 248
WEIGHT_DECAY = 0.01
UNANNOTATED_SAMPLE_RATE = 0.20
MIN_CONF = 30  # was 60 in v2, see FIX 3
# Average words inside an annotation bbox, used for token-level weight estimation
AVG_TOKENS_PER_BOX = 4.0


# -- BIO LABEL BUILDER ---------------------------------------------------------
def build_bio_labels(base_field_labels):
    """Expand flat field labels into a BIO label list plus both lookup dicts."""
    bio_labels = ["O"]
    for lbl in base_field_labels:
        if lbl == "O":
            continue
        bio_labels.append(f"B-{lbl}")
        bio_labels.append(f"I-{lbl}")
    label2id = {l: i for i, l in enumerate(bio_labels)}
    id2label = {i: l for i, l in enumerate(bio_labels)}
    return bio_labels, label2id, id2label


# -- PATH REMAPPING (FIX 4) ----------------------------------------------------
def remap_path(p: str) -> str:
    if not p:
        return p
    if Path(p).exists():
        return p
    if p.startswith(WINDOWS_PREFIX):
        p = p.replace(WINDOWS_PREFIX, LINUX_PREFIX, 1)
    # Normalise any remaining Windows separators for the local OS.
    return p.replace("\\", os.sep)


# -- OCR JSON LOADER (FIX 4) ---------------------------------------------------
def load_ocr_json(ocr_path):
    """Return the parsed OCR JSON, or None if the path is missing or unreadable."""
    p = remap_path(ocr_path)
    if not p or not Path(p).exists():
        return None
    try:
        with open(p, encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        return None


# -- BBOX RESCALING (FIX 1 - CRITICAL) -----------------------------------------
def rescale_boxes(boxes, src_w, src_h, dst_w, dst_h):
    """Rescale annotation boxes from annotation-image coords -> OCR-image coords."""
    if (src_w, src_h) == (dst_w, dst_h):
        return boxes
    sx = dst_w / src_w
    sy = dst_h / src_h
    return [[int(b[0]*sx), int(b[1]*sy), int(b[2]*sx), int(b[3]*sy)] for b in boxes]


# -- LABEL ASSIGNMENT (FIX 1, 2, 3, 10 combined) -------------------------------
# Wordpiece budget the tokenizer can fit (MAX_LENGTH minus a small safety
# margin for special tokens like CLS/SEP and padding alignment).
WP_BUDGET = MAX_LENGTH - 4


def assign_word_labels_exact(ocr_data, anno_boxes, anno_label_ids,
                             flat_label2id, bio_label2id,
                             tokenizer=None, min_conf=MIN_CONF):
    """Exact spatial matching with all 4 fixes applied.

    FIX 10 (v3.1) - annotation-preserving, wordpiece-aware truncation:
    Naively slicing words to [:MAX_WORDS] discarded annotations past that
    index. Worse, the tokenizer then truncated again at MAX_LENGTH=512
    WORDPIECES, and French OCR averages ~1.6-2.6 wp/word, so 300 OCR
    words ≈ 480-780 wp. Logement annotations sit at the bottom of fiches
    (word indices 200-300), so >90% of Nb_log_pro / Nb_log_res labels were
    silently truncated, never reaching the model or the eval metrics.

    Fix: walk ALL conf-filtered words, compute wordpieces per word via
    the tokenizer, then greedy-include in original reading order: every
    annotated word is kept; unannotated words fill the remaining
    wordpiece budget (WP_BUDGET) from the start. Annotated words shift
    to earlier sequence positions and survive tokenizer truncation.
    """
    words_raw = ocr_data["words"]
    bboxes = ocr_data["bboxes"]
    bboxes_norm = ocr_data["bboxes_norm"]
    confs = ocr_data["confs"]
    O_flat = flat_label2id["O"]

    # -- Pass 1: walk all conf-filtered words, assign flat id ------------------
    kept = []  # list of (word, bbox_px, bbox_norm, flat_id)
    for word, bbox_px, bbox_norm, conf in zip(words_raw, bboxes, bboxes_norm, confs):
        if conf < min_conf:
            continue
        # A word belongs to the first annotation box containing its centre.
        wcx = (bbox_px[0] + bbox_px[2]) / 2
        wcy = (bbox_px[1] + bbox_px[3]) / 2
        assigned = O_flat
        for abox, albl_id in zip(anno_boxes, anno_label_ids):
            if abox[0] <= wcx <= abox[2] and abox[1] <= wcy <= abox[3]:
                assigned = albl_id
                break
        kept.append((word, bbox_px, bbox_norm, assigned))

    # -- FIX 10: wordpiece-aware greedy selection ------------------------------
    if kept and tokenizer is not None:
        # LayoutLMv3's full tokenizer expects pre-split word lists with boxes.
        # tokenizer.tokenize() works on a single string and returns subword
        # pieces, which is exactly what we need to count wordpieces per word.
        wp_per_word = [
            max(len(tokenizer.tokenize(w)), 1)
            for w, _, _, _ in kept
        ]
        anno_flags = [x[3] != O_flat for x in kept]
        # Select only if EITHER budget is exceeded; otherwise leave kept as-is.
        if sum(wp_per_word) > WP_BUDGET or len(kept) > MAX_WORDS:
            cum_wp = 0
            cum_words = 0
            chosen = []
            for item, is_anno, wp in zip(kept, anno_flags, wp_per_word):
                if is_anno:
                    # Always include annotated. Pathological docs where
                    # annotations alone exceed budget get tokenizer-truncated
                    # at the tail; accept that small loss rather than drop
                    # all annotations.
                    chosen.append(item)
                    cum_wp += wp
                    cum_words += 1
                elif cum_wp + wp <= WP_BUDGET and cum_words < MAX_WORDS:
                    chosen.append(item)
                    cum_wp += wp
                    cum_words += 1
                # else: skip this unannotated word
            kept = chosen
    elif len(kept) > MAX_WORDS:
        # No tokenizer available: fall back to plain word-count truncation
        kept = kept[:MAX_WORDS]

    # -- Unpack into the parallel arrays the rest of the function expects ------
    words_out = [x[0] for x in kept]
    kept_bboxes = [x[1] for x in kept]
    norm_boxes_out = [x[2] for x in kept]
    word_labels = [x[3] for x in kept]

    # -- Pass 2: convert flat -> BIO -------------------------------------------
    box_seen = {}
    bio_labels_out = []
    id2flat = {v: k for k, v in flat_label2id.items()}
    for i, flat_id in enumerate(word_labels):
        if flat_id == flat_label2id["O"]:
            bio_labels_out.append(bio_label2id["O"])
            continue
        bbox_px = kept_bboxes[i]  # FIX 2: use aligned list
        wcx = (bbox_px[0] + bbox_px[2]) / 2
        wcy = (bbox_px[1] + bbox_px[3]) / 2
        matched_box_idx = None
        for bi, abox in enumerate(anno_boxes):
            if abox[0] <= wcx <= abox[2] and abox[1] <= wcy <= abox[3]:
                matched_box_idx = bi
                break
        if matched_box_idx is None:
            bio_labels_out.append(bio_label2id["O"])
            continue
        base_name = id2flat.get(anno_label_ids[matched_box_idx], "O")
        if base_name == "O":
            bio_labels_out.append(bio_label2id["O"])
            continue
        # First word seen in a box gets B-, later words in the same box get I-.
        if matched_box_idx not in box_seen:
            box_seen[matched_box_idx] = True
            tag = f"B-{base_name}"
        else:
            tag = f"I-{base_name}"
        bio_labels_out.append(bio_label2id.get(tag, bio_label2id["O"]))
    return words_out, norm_boxes_out, bio_labels_out


# -- FALLBACK (kept for diagnostics; should rarely fire after FIX 4) -----------
def assign_word_labels_fallback(ocr_text, anno_boxes, anno_label_ids,
                                img_w, img_h, flat_label2id, bio_label2id):
    """Label words without OCR boxes by stacking them as full-width rows."""
    words = (ocr_text or "").split()[:MAX_WORDS] or ["[PAD]"]
    O_bio = bio_label2id["O"]
    word_labels_flat = [flat_label2id["O"]] * len(words)
    word_h = max(img_h // max(len(words), 1), 1)
    word_boxes = []
    for i in range(len(words)):
        y0, y1 = i * word_h, (i + 1) * word_h
        word_boxes.append([0, y0, img_w, y1])
        # Take the first annotation box that overlaps this row vertically.
        for bbox, lbl_id in zip(anno_boxes, anno_label_ids):
            if y0 < bbox[3] and y1 > bbox[1]:
                word_labels_flat[i] = lbl_id
                break
    norm_boxes = [
        [max(0, min(int(b[0]/img_w*1000), 999)), max(0, min(int(b[1]/img_h*1000), 999)),
         max(0, min(int(b[2]/img_w*1000), 1000)), max(0, min(int(b[3]/img_h*1000), 1000))]
        for b in word_boxes
    ]
    id2flat = {v: k for k, v in flat_label2id.items()}
    box_seen = {}
    bio_labels = []
    for i, fid in enumerate(word_labels_flat):
        base = id2flat.get(fid, "O")
        if base == "O":
            bio_labels.append(O_bio)
            continue
        # find which box matched
        y0, y1 = i * word_h, (i + 1) * word_h
        mb = None
        for bi, (bbox, lbl_id) in enumerate(zip(anno_boxes, anno_label_ids)):
            if y0 < bbox[3] and y1 > bbox[1] and lbl_id == fid:
                mb = bi
                break
        key = mb if mb is not None else fid
        tag = f"B-{base}" if key not in box_seen else f"I-{base}"
        box_seen[key] = True
        bio_labels.append(bio_label2id.get(tag, O_bio))
    return words, norm_boxes, bio_labels


# -- WEIGHTED TRAINER ----------------------------------------------------------
class WeightedTrainer(Trainer):
    """Trainer that applies per-class weights in the cross-entropy loss."""

    def __init__(self, class_weights, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        weights = torch.tensor(self.class_weights, dtype=torch.float, device=logits.device)
        # Positions labelled -100 (special tokens, sub-word continuations) are ignored.
        loss_fn = nn.CrossEntropyLoss(weight=weights, ignore_index=-100)
        loss = loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
        return (loss, outputs) if return_outputs else loss


# -- BIO TOKEN-LEVEL WEIGHT ESTIMATION (FIX 6) ---------------------------------
def estimate_bio_weights(records, flat_field_labels, bio_label2id,
                         avg_tokens_per_box=AVG_TOKENS_PER_BOX,
                         o_token_estimate_per_doc=200):
    """Estimate BIO-token class weights from the training records."""
    box_counts = {l: 0 for l in flat_field_labels}
    for r in records:
        for lid in r.get("box_label_ids", []):
            if 0 <= lid < len(flat_field_labels):
                box_counts[flat_field_labels[lid]] += 1
    n_docs = len(records)
    estimated_o_tokens = n_docs * o_token_estimate_per_doc

    # Estimated TOKEN counts per BIO label
    bio_counts = {l: 0 for l in bio_label2id}
    bio_counts["O"] = estimated_o_tokens
    for fname in flat_field_labels:
        if fname == "O":
            continue
        b = box_counts[fname]
        bio_counts[f"B-{fname}"] = b  # 1 B per box
        bio_counts[f"I-{fname}"] = int(b * (avg_tokens_per_box - 1))

    # Inverse-frequency weights: total / (n_classes * count)
    total = sum(bio_counts.values())
    n = len(bio_counts)
    weights = [1.0] * n
    for lbl, idx in bio_label2id.items():
        c = max(bio_counts.get(lbl, 1), 1)
        weights[idx] = total / (n * c)
    # Cap O weight at 1.0 so background tokens don't get over-emphasised
    weights[bio_label2id["O"]] = min(weights[bio_label2id["O"]], 1.0)
    # Cap field weights at 5.0 to keep loss stable
    for i in range(len(weights)):
        weights[i] = min(weights[i], 5.0)
    return weights, bio_counts


# -- BACKBONE LOADER -----------------------------------------------------------
def load_token_classifier_from_classifier_ckpt(ckpt_path, num_labels, id2label, label2id):
    print(f" Loading classifier checkpoint: {ckpt_path}")
    seq_model = LayoutLMv3ForSequenceClassification.from_pretrained(ckpt_path)
    seq_state = seq_model.state_dict()
    # Keep only backbone weights; the classification head is re-initialised.
    backbone_state = {k: v for k, v in seq_state.items()
                      if not k.startswith("classifier") and not k.startswith("pooler")}
    config = LayoutLMv3Config.from_pretrained(ckpt_path)
    config.num_labels = num_labels
    config.id2label = id2label
    config.label2id = label2id
    token_model = LayoutLMv3ForTokenClassification(config)
    missing, unexpected = token_model.load_state_dict(backbone_state, strict=False)
    print(f" Backbone keys transferred: {len(backbone_state)} / {len(seq_state)}")
    return token_model


# -- DATASET -------------------------------------------------------------------
def deterministic_keep(record_id, sample_rate):
    """Hash-based deterministic sampling decision (FIX 9)."""
    h = int(hashlib.sha256(str(record_id).encode()).hexdigest()[:8], 16)
    return (h % 10000) / 10000.0 < sample_rate


class ExtractionDataset(Dataset):
    def __init__(self, json_path, processor, flat_label2id, bio_label2id,
                 unannotated_sample_rate=UNANNOTATED_SAMPLE_RATE, is_train=True):
        with open(json_path, encoding="utf-8") as f:
            all_records = json.load(f)
        self.processor = processor
        self.flat_label2id = flat_label2id
        self.bio_label2id = bio_label2id
        self.is_train = is_train

        # FIX 5: strip Siret annotations (label_id=0 is invalid for Siret)
        n_siret_stripped = 0
        for r in all_records:
            if "Siret" in r.get("box_labels", []):
                keep_idx = [i for i, l in enumerate(r["box_labels"]) if l != "Siret"]
                if len(keep_idx) < len(r["box_labels"]):
                    n_siret_stripped += len(r["box_labels"]) - len(keep_idx)
                    r["boxes"] = [r["boxes"][i] for i in keep_idx]
                    r["box_labels"] = [r["box_labels"][i] for i in keep_idx]
                    r["box_label_ids"] = [r["box_label_ids"][i] for i in keep_idx]
        if n_siret_stripped:
            print(f" Stripped {n_siret_stripped} Siret annotations (mapped to O; likely a label bug)")

        # FIX 9: deterministic unannotated sampling
        if is_train:
            self.records = []
            skipped = 0
            for r in all_records:
                has_boxes = bool(r.get("boxes"))
                if not has_boxes:
                    # Fall back to a stable field rather than id(r): a memory
                    # address changes between runs and would defeat FIX 9.
                    rec_id = r.get("id") or r.get("image_file") or r.get("ocr_path", "")
                    if not deterministic_keep(rec_id, unannotated_sample_rate):
                        skipped += 1
                        continue
                self.records.append(r)
            print(f" Unannotated records dropped (deterministic sampling): {skipped}")
        else:
            self.records = all_records

        # OCR availability stats
        ocr_avail = sum(1 for r in self.records if load_ocr_json(r.get("ocr_path", "")) is not None)
        print(f" Loaded {len(self.records)} records | with annotations: "
              f"{sum(1 for r in self.records if r.get('boxes'))} | "
              f"OCR JSON available: {ocr_avail}/{len(self.records)}")
        if ocr_avail < len(self.records) * 0.5:
            print(" WARNING: <50% of records have resolvable OCR paths!")
            print(" Set OCR_LINUX_PREFIX env var to your OCR directory.")
            print(f" Currently using: {LINUX_PREFIX}")

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        rec = self.records[idx]
        anno_img_w = rec.get("image_width", 1654)
        anno_img_h = rec.get("image_height", 2339)
        img_path = remap_path(rec.get("image_path", ""))
        if img_path and Path(img_path).exists():
            image = Image.open(img_path).convert("RGB")
        else:
            # Missing image: substitute a blank white page of the annotated size.
            image = Image.new("RGB", (anno_img_w, anno_img_h), color=(255, 255, 255))
        anno_boxes = rec.get("boxes", [])
        anno_labels = rec.get("box_label_ids", [])
        ocr_data = load_ocr_json(rec.get("ocr_path", ""))
        if ocr_data is not None:
            # FIX 1: RESCALE annotation boxes to OCR coordinate space
            ocr_w, ocr_h = ocr_data["width"], ocr_data["height"]
            rescaled_boxes = rescale_boxes(anno_boxes, anno_img_w, anno_img_h, ocr_w, ocr_h)
            words, norm_boxes, word_bio = assign_word_labels_exact(
                ocr_data, rescaled_boxes, anno_labels,
                self.flat_label2id, self.bio_label2id,
                tokenizer=self.processor.tokenizer,
            )
        else:
            # Fallback (much worse; make sure FIX 4 path remapping works)
            words, norm_boxes, word_bio = assign_word_labels_fallback(
                rec.get("ocr_text", ""), anno_boxes, anno_labels,
                anno_img_w, anno_img_h, self.flat_label2id, self.bio_label2id,
            )
        if not words:
            words, norm_boxes, word_bio = ["[PAD]"], [[0, 0, 0, 0]], [self.bio_label2id["O"]]
        encoding = self.processor(
            image, words, boxes=norm_boxes,
            max_length=MAX_LENGTH, padding="max_length",
            truncation=True, return_tensors="pt",
        )
        seq_len = encoding["input_ids"].shape[1]
        labels = [-100] * seq_len
        word_ids = encoding.word_ids(batch_index=0)
        # Label only the first sub-word of each word; everything else is -100.
        prev = None
        for pos, wid in enumerate(word_ids):
            if wid is None:
                labels[pos] = -100
            elif wid != prev:
                labels[pos] = (word_bio[wid] if wid < len(word_bio)
                               else self.bio_label2id["O"])
            else:
                labels[pos] = -100
            prev = wid
        return {
            "input_ids": encoding["input_ids"].squeeze(),
            "attention_mask": encoding["attention_mask"].squeeze(),
            "bbox": encoding["bbox"].squeeze(),
            "pixel_values": encoding["pixel_values"].squeeze(),
            "labels": torch.tensor(labels, dtype=torch.long),
        }


# -- METRICS (FIX 7: token + span F1) ------------------------------------------
def make_compute_metrics(bio_id2label):
    """Returns a closure that computes BOTH token-level and span-level metrics."""
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=-1)
        mask = labels != -100
        flat_p, flat_l = preds[mask], labels[mask]
        metrics = {"token_accuracy": float((flat_p == flat_l).mean())}

        # Token-level per-class F1
        n_labels = max(flat_l.max(), flat_p.max()) + 1
        for i in range(int(n_labels)):
            name = bio_id2label.get(i, f"id_{i}")
            tp = int(((flat_p == i) & (flat_l == i)).sum())
            fp = int(((flat_p == i) & (flat_l != i)).sum())
            fn = int(((flat_p != i) & (flat_l == i)).sum())
            sup = tp + fn
            if sup == 0 and tp + fp == 0:
                continue
            prec = tp / max(tp + fp, 1)
            rec = tp / max(tp + fn, 1)
            f1 = 2 * prec * rec / max(prec + rec, 1e-9)
            metrics[f"f1_{name}"] = float(f1)

        # Span-level (entity-level) F1 via simple BIO span extraction
        def to_spans(seq):
            spans = []
            cur_field, start = None, None
            for j, lid in enumerate(seq):
                ln = bio_id2label.get(int(lid), "O")
                if ln == "O":
                    if cur_field is not None:
                        spans.append((cur_field, start, j-1))
                    cur_field, start = None, None
                elif ln.startswith("B-"):
                    if cur_field is not None:
                        spans.append((cur_field, start, j-1))
                    cur_field, start = ln[2:], j
                else:  # I-
                    base = ln[2:]
                    if cur_field == base:
                        pass
                    else:
                        # An I- tag with no matching open span starts a new one.
                        if cur_field is not None:
                            spans.append((cur_field, start, j-1))
                        cur_field, start = base, j
            if cur_field is not None:
                spans.append((cur_field, start, len(seq)-1))
            return set(spans)

        # Spans come from the masked flat arrays, so this is approximate
        # (we don't have batch boundaries here, but per-class span-F1 is still useful)
        all_pred_spans = to_spans(flat_p.tolist())
        all_true_spans = to_spans(flat_l.tolist())
        per_field = {}
        for s in all_true_spans | all_pred_spans:
            per_field.setdefault(s[0], {"tp": 0, "fp": 0, "fn": 0})
        for s in all_true_spans:
            if s in all_pred_spans:
                per_field[s[0]]["tp"] += 1
            else:
                per_field[s[0]]["fn"] += 1
        for s in all_pred_spans:
            if s not in all_true_spans:
                per_field[s[0]]["fp"] += 1
        for fname, c in per_field.items():
            p = c["tp"] / max(c["tp"] + c["fp"], 1)
            r = c["tp"] / max(c["tp"] + c["fn"], 1)
            f = 2*p*r / max(p+r, 1e-9)
            metrics[f"span_f1_{fname}"] = float(f)

        # Macro span-F1 across fields (excluding O)
        non_o = [v for k, v in metrics.items() if k.startswith("span_f1_") and k != "span_f1_O"]
        if non_o:
            metrics["macro_span_f1"] = float(np.mean(non_o))
        return metrics
    return compute_metrics


# -- MAIN ----------------------------------------------------------------------
def main():
    random.seed(42)
    with open(MAPPINGS, encoding="utf-8") as f:
        mappings = json.load(f)
    flat_field_labels = mappings["field_labels"]
    flat_label2id = mappings["field2id"]
    bio_labels, bio_label2id, bio_id2label = build_bio_labels(flat_field_labels)
    num_labels = len(bio_labels)
    print(f"\nBIO label set ({num_labels} labels)")

    # FIX 6: token-level weight estimation
    with open(TRAIN_JSON, encoding="utf-8") as f:
        train_records = json.load(f)
    class_weights, bio_counts = estimate_bio_weights(
        train_records, flat_field_labels, bio_label2id)
    print("Estimated BIO token counts and weights (top 8):")
    for l, c in sorted(bio_counts.items(), key=lambda x: -x[1])[:8]:
        print(f" {l:<32} count≈{int(c):6d} weight={class_weights[bio_label2id[l]]:.3f}")

    # FIX 8: split contamination check
    def pdf_id(r):
        # Page images are named <pdf>_p<page>...; strip the page suffix.
        return r["image_file"].rsplit("_p", 1)[0]

    train_pdfs = {pdf_id(r) for r in train_records}
    with open(VAL_JSON, encoding="utf-8") as f:
        val_records = json.load(f)
    val_pdfs = {pdf_id(r) for r in val_records}
    leak = train_pdfs & val_pdfs
    if leak:
        print(f"\nWARNING: TRAIN/VAL CONTAMINATION: {len(leak)} PDFs span both splits.")
        print(" Val F1 will be OVERESTIMATED. Re-split by PDF before re-training.")
        print(f" Example leaked PDFs (first 3): {list(leak)[:3]}")

    processor = LayoutLMv3Processor.from_pretrained(FALLBACK_BASE, apply_ocr=False)
    ckpt = Path(CLASSIFIER_CKPT) if CLASSIFIER_CKPT else None
    if ckpt and ckpt.exists():
        print("\nLoading backbone from classifier checkpoint")
        model = load_token_classifier_from_classifier_ckpt(
            str(ckpt), num_labels, bio_id2label, bio_label2id)
    else:
        print("\nNo classifier checkpoint; using base LayoutLMv3")
        model = LayoutLMv3ForTokenClassification.from_pretrained(
            FALLBACK_BASE, num_labels=num_labels,
            id2label=bio_id2label, label2id=bio_label2id)

    print("\nBuilding datasets:")
    train_dataset = ExtractionDataset(TRAIN_JSON, processor, flat_label2id, bio_label2id, is_train=True)
    val_dataset = ExtractionDataset(VAL_JSON, processor, flat_label2id, bio_label2id, is_train=False)

    training_args = TrainingArguments(
        output_dir=str(MODEL_OUTPUT),  # pass a plain string path
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        weight_decay=WEIGHT_DECAY,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="macro_span_f1",  # FIX 7: span F1, not token acc
        greater_is_better=True,
        logging_dir="outputs/logs_extractor_v3",
        logging_steps=10,
        report_to="none",
        fp16=torch.cuda.is_available(),
        dataloader_num_workers=2,
    )
    trainer = WeightedTrainer(
        class_weights=class_weights,
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=make_compute_metrics(bio_id2label),
    )
    print("\nStarting v3 training (FIX 1-9 applied)...")
    trainer.train()
    print(f"\nTraining complete. Model -> {MODEL_OUTPUT}")
    results = trainer.evaluate()
    for k, v in results.items():
        if isinstance(v, float):
            print(f" {k}: {v:.4f}")


if __name__ == "__main__":
    main()