"""Gradio Space: PDF -> structured-JSON extraction with a hybrid
LayoutLMv3 + BiLSTM + CRF token-classification model.

Pipeline: pdfplumber word extraction -> LayoutLMv3 tokenization with
normalized bounding boxes -> BIO tag prediction -> structured JSON.
"""

import gradio as gr
import torch
import torch.nn as nn
import pdfplumber
import json
import os
import re

from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model
from TorchCRF import CRF

# ---------------------------------------------------------
# 1. CONFIGURATION
# ---------------------------------------------------------
MODEL_FILENAME = "layoutlmv3_bilstm_crf_hybrid.pth"
BASE_MODEL_ID = "microsoft/layoutlmv3-base"

# Labels: 11 standard BIO tags + 2 special tokens = 13 total.
# NOTE: if output labels look "scrambled" (e.g. questions detected as
# options), the training label order may have differed — try moving
# "UNK" and "PAD" to the BEGINNING of this list (indices 0 and 1).
LABELS = [
    "O",
    "B-QUESTION", "I-QUESTION",
    "B-OPTION", "I-OPTION",
    "B-ANSWER", "I-ANSWER",
    "B-SECTION_HEADING", "I-SECTION_HEADING",
    "B-PASSAGE", "I-PASSAGE",
    "UNK", "PAD",  # added to match the 13-label count in the saved weights
]
LABEL2ID = {l: i for i, l in enumerate(LABELS)}
ID2LABEL = {i: l for l, i in LABEL2ID.items()}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = LayoutLMv3TokenizerFast.from_pretrained(BASE_MODEL_ID)
# ---------------------------------------------------------
# 2. MODEL ARCHITECTURE (LayoutLMv3 + BiLSTM + CRF)
# ---------------------------------------------------------
class HybridModel(nn.Module):
    """LayoutLMv3 encoder followed by a 2-layer BiLSTM and a CRF tagger.

    Dimensions are fixed by the saved checkpoint:
    LSTM weight [1024, 768] => hidden_size 256 (1024 / 4 gates);
    classifier weight [13, 512] => bidirectional output 256 * 2 = 512.
    """

    def __init__(self, num_labels):
        super().__init__()
        self.layoutlm = LayoutLMv3Model.from_pretrained(BASE_MODEL_ID)
        lstm_hidden_size = 256  # derived from the checkpoint weight shapes
        self.lstm = nn.LSTM(
            input_size=768,          # LayoutLMv3-base hidden size
            hidden_size=lstm_hidden_size,
            num_layers=2,            # checkpoint contains 'l1' weights => 2 layers
            batch_first=True,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(0.1)
        # Bidirectional => classifier input is 2 * lstm_hidden_size (512).
        self.classifier = nn.Linear(lstm_hidden_size * 2, num_labels)
        self.crf = CRF(num_labels)

    def forward(self, input_ids, bbox, attention_mask, labels=None):
        """Return mean negative log-likelihood when ``labels`` is given,
        otherwise the Viterbi-decoded tag sequences (list of lists of ids)."""
        outputs = self.layoutlm(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state        # [batch, seq, 768]
        lstm_output, _ = self.lstm(sequence_output)        # [batch, seq, 512]
        emissions = self.classifier(self.dropout(lstm_output))  # [batch, seq, num_labels]
        if labels is not None:
            # Training / evaluation: CRF returns per-sequence log-likelihood.
            log_likelihood = self.crf(emissions, labels, mask=attention_mask.bool())
            return -log_likelihood.mean()
        # Inference: best tag path per sequence.
        return self.crf.viterbi_decode(emissions, mask=attention_mask.bool())


# ---------------------------------------------------------
# 3. MODEL LOADING
# ---------------------------------------------------------
model = None  # lazily-initialized singleton; populated by load_model()


def load_model():
    """Build the HybridModel and load checkpoint weights (cached globally).

    Raises FileNotFoundError when the checkpoint file is absent and a
    RuntimeError (with the original cause chained) on a state-dict mismatch.
    """
    global model
    if model is None:
        print(f"🔄 Loading model from {MODEL_FILENAME}...")
        if not os.path.exists(MODEL_FILENAME):
            raise FileNotFoundError(f"❌ Model file '{MODEL_FILENAME}' not found.")
        model = HybridModel(num_labels=len(LABELS))
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints (weights_only=True would be safer if the file
        # is a plain state dict — TODO confirm).
        state_dict = torch.load(MODEL_FILENAME, map_location=device)
        try:
            model.load_state_dict(state_dict)
        except RuntimeError as e:
            raise RuntimeError(f"❌ Weight mismatch! \nYour model has {len(LABELS)} labels defined in script.\nCheck if 'LABELS' list needs reordering or resizing.\nDetailed Error: {e}") from e
        model.to(device)
        model.eval()
        print("✅ Model loaded successfully.")
    return model
# ---------------------------------------------------------
# 4. JSON CONVERSION LOGIC
# ---------------------------------------------------------
def convert_bio_to_structured_json(predictions):
    """Convert per-word BIO predictions into a list of structured items.

    ``predictions`` is a list of pages, each ``{"data": [{"word", "predicted_label"}, ...]}``.
    Output items are question dicts (``question``/``options``/``answer``/
    ``passage``/``text``) plus an optional leading ``METADATA`` item holding
    any header text seen before the first B-QUESTION.
    """
    structured_data = []
    current_item = None
    current_option_key = None
    current_passage_buffer = []
    current_text_buffer = []
    first_question_started = False
    last_entity_type = None
    just_finished_i_option = False
    is_in_new_passage = False

    def finalize_passage_to_item(item, passage_buffer):
        # Flush the accumulated passage words into item['passage'] and empty
        # the shared buffer in place.
        if passage_buffer:
            passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip()
            if item.get('passage'):
                item['passage'] += ' ' + passage_text
            else:
                item['passage'] = passage_text
            passage_buffer.clear()

    # Flatten the page-separated predictions into one token stream.
    flat_predictions = []
    for page in predictions:
        flat_predictions.extend(page['data'])

    for token in flat_predictions:
        word = token['word']
        label = token['predicted_label']
        entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None

        # Special / outside tags only contribute to the raw text buffer.
        if label in ["UNK", "PAD", "O"]:
            current_text_buffer.append(word)
            continue

        current_text_buffer.append(word)
        previous_entity_type = last_entity_type
        is_passage_label = (entity_type == 'PASSAGE')

        # Before the first question, only B-QUESTION and passage tags matter;
        # everything else is header noise kept in the text buffer.
        if not first_question_started:
            if label != 'B-QUESTION' and not is_passage_label:
                just_finished_i_option = False
                is_in_new_passage = False
                continue
            if is_passage_label:
                current_passage_buffer.append(word)
                last_entity_type = 'PASSAGE'
                just_finished_i_option = False
                is_in_new_passage = False
                continue

        if label == 'B-QUESTION':
            if not first_question_started:
                # Emit any pre-question header/passage as a METADATA item.
                header_text = ' '.join(current_text_buffer[:-1]).strip()
                if header_text or current_passage_buffer:
                    metadata_item = {'type': 'METADATA', 'passage': ''}
                    finalize_passage_to_item(metadata_item, current_passage_buffer)
                    if header_text:
                        metadata_item['text'] = header_text
                    structured_data.append(metadata_item)
                first_question_started = True
                current_text_buffer = [word]
            if current_item is not None:
                # Close the previous question before opening a new one.
                finalize_passage_to_item(current_item, current_passage_buffer)
                current_item['text'] = ' '.join(current_text_buffer[:-1]).strip()
                structured_data.append(current_item)
                current_text_buffer = [word]
            current_item = {
                'question': word, 'options': {}, 'answer': '', 'passage': '', 'text': ''
            }
            current_option_key = None
            last_entity_type = 'QUESTION'
            just_finished_i_option = False
            is_in_new_passage = False
            continue

        if current_item is not None:
            if is_in_new_passage:
                # While inside a post-option passage, words accumulate in
                # 'new_passage'; a non-passage B-/I- tag ends that mode.
                if 'new_passage' not in current_item:
                    current_item['new_passage'] = word
                else:
                    current_item['new_passage'] += f' {word}'
                if label.startswith('B-') or (label.startswith('I-') and entity_type != 'PASSAGE'):
                    is_in_new_passage = False
                if label.startswith(('B-', 'I-')):
                    last_entity_type = entity_type
                continue
            is_in_new_passage = False

            if label.startswith('B-'):
                if entity_type in ['QUESTION', 'OPTION', 'ANSWER', 'SECTION_HEADING']:
                    finalize_passage_to_item(current_item, current_passage_buffer)
                    current_passage_buffer = []
                last_entity_type = entity_type
                if entity_type == 'PASSAGE':
                    # A passage right after a completed option starts a
                    # 'new_passage' (mid-question reading passage).
                    if previous_entity_type == 'OPTION' and just_finished_i_option:
                        current_item['new_passage'] = word
                        is_in_new_passage = True
                    else:
                        current_passage_buffer.append(word)
                elif entity_type == 'OPTION':
                    current_option_key = word
                    current_item['options'][current_option_key] = word
                    just_finished_i_option = False
                elif entity_type == 'ANSWER':
                    current_item['answer'] = word
                    current_option_key = None
                    just_finished_i_option = False
                elif entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                    just_finished_i_option = False
            elif label.startswith('I-'):
                if entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                elif entity_type == 'PASSAGE':
                    if previous_entity_type == 'OPTION' and just_finished_i_option:
                        current_item['new_passage'] = word
                        is_in_new_passage = True
                    else:
                        if not current_passage_buffer:
                            last_entity_type = 'PASSAGE'
                        current_passage_buffer.append(word)
                elif entity_type == 'OPTION' and current_option_key is not None:
                    current_item['options'][current_option_key] += f' {word}'
                    just_finished_i_option = True
                elif entity_type == 'ANSWER':
                    current_item['answer'] += f' {word}'

        # NOTE(review): the source placed this catch-all recompute at the end
        # of the iteration; it makes 'just_finished_i_option' track whether
        # the last consumed tag was an OPTION — confirm against training code.
        just_finished_i_option = (entity_type == 'OPTION')

    # Close the final open question.
    if current_item is not None:
        finalize_passage_to_item(current_item, current_passage_buffer)
        current_item['text'] = ' '.join(current_text_buffer).strip()
        structured_data.append(current_item)

    # Final cleanup: collapse runs of whitespace in free-text fields.
    for item in structured_data:
        if 'text' in item:
            item['text'] = re.sub(r'\s{2,}', ' ', item['text']).strip()
        if 'new_passage' in item:
            item['new_passage'] = re.sub(r'\s{2,}', ' ', item['new_passage']).strip()
    return structured_data
# ---------------------------------------------------------
# 5. PROCESSING PIPELINE
# ---------------------------------------------------------
def process_pdf(pdf_file):
    """Run the full PDF -> structured-JSON pipeline for the Gradio UI.

    Returns ``(output_file_path, status_message)``; on failure the path is
    None and the message carries the error text plus traceback.
    """
    if pdf_file is None:
        return None, "⚠️ Please upload a PDF file."
    try:
        active_model = load_model()

        # A. Extract words and normalized bounding boxes per page.
        extracted_pages = []
        with pdfplumber.open(pdf_file.name) as pdf:
            for page_idx, page in enumerate(pdf.pages):
                width, height = page.width, page.height
                page_tokens = []
                page_bboxes = []
                for w in page.extract_words():
                    # Normalize to the 0-1000 coordinate space LayoutLM
                    # expects, clamping against rounding overshoot.
                    x0 = int((w['x0'] / width) * 1000)
                    top = int((w['top'] / height) * 1000)
                    x1 = int((w['x1'] / width) * 1000)
                    bottom = int((w['bottom'] / height) * 1000)
                    box = [max(0, min(x0, 1000)), max(0, min(top, 1000)),
                           max(0, min(x1, 1000)), max(0, min(bottom, 1000))]
                    page_tokens.append(w['text'])
                    page_bboxes.append(box)
                extracted_pages.append({"page_id": page_idx, "tokens": page_tokens, "bboxes": page_bboxes})

        # B. Per-page inference.
        raw_predictions = []
        for page in extracted_pages:
            tokens = page['tokens']
            bboxes = page['bboxes']
            if not tokens:
                continue
            # (return_offsets_mapping dropped: the offsets were never read.)
            encoding = tokenizer(
                tokens,
                boxes=bboxes,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=512,
            )
            input_ids = encoding.input_ids.to(device)
            bbox = encoding.bbox.to(device)
            attention_mask = encoding.attention_mask.to(device)
            with torch.no_grad():
                # viterbi_decode returns List[List[int]]; batch size is 1,
                # so [0] selects the single decoded tag sequence.
                pred_tags = active_model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)[0]

            # Align subword predictions back to words: keep only the tag of
            # each word's first subword (specials have word_id None).
            aligned_data = []
            prev_word_idx = None
            for i, word_idx in enumerate(encoding.word_ids()):
                if word_idx is None:
                    continue
                if word_idx != prev_word_idx and i < len(pred_tags):
                    label_str = ID2LABEL.get(pred_tags[i], "O")
                    aligned_data.append({"word": tokens[word_idx], "predicted_label": label_str})
                prev_word_idx = word_idx
            raw_predictions.append({"data": aligned_data})

        # C. BIO tags -> structured JSON, written to disk for download.
        final_json = convert_bio_to_structured_json(raw_predictions)
        output_filename = "structured_output.json"
        with open(output_filename, "w", encoding="utf-8") as f:
            json.dump(final_json, f, indent=2, ensure_ascii=False)
        return output_filename, f"✅ Success! Processed {len(extracted_pages)} pages. Extracted {len(final_json)} items."
    except Exception as e:
        import traceback
        return None, f"❌ Error:\n{str(e)}\n\nTraceback:\n{traceback.format_exc()}"


# ---------------------------------------------------------
# 6. GRADIO INTERFACE
# ---------------------------------------------------------
iface = gr.Interface(
    fn=process_pdf,
    inputs=gr.File(label="Upload PDF", file_types=[".pdf"]),
    outputs=[
        gr.File(label="Download JSON Output"),
        gr.Textbox(label="Status Log", lines=10),
    ],
    title="LayoutLMv3 + BiLSTM Hybrid Model Inference",
    description="Upload a document to extract structured data using the custom Hybrid LayoutLMv3 model.",
    flagging_mode="never",
)

if __name__ == "__main__":
    iface.launch()